# Notebook 05 — Results + Data Analysis (Tasks 1–3)

I use this notebook to regenerate the figures for my writeups. I save everything to `docs/figures/`.

Workflow note: I ran Notebooks 02/03/04 in Colab, then downloaded the repo + artifacts so I could run this notebook locally.

Sections:
- Task 1: training curves + gameplay vs LogicBot
- Task 2: bootstrapping results by round
- Task 3: thinking-time plots (reloaded from Notebook 04 exports if present) + heatmaps (generated here from the v1 checkpoints in this repo)


In [1]:
# I install plotting deps (I keep it explicit).
%pip install -q numpy matplotlib



You should consider upgrading via the '/Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/.venv/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# I locate the repo root and set up `docs/figures/`.
# I usually run Notebooks 02/03/04 in Colab, then download the repo/artifacts and run this notebook locally.
import sys
from pathlib import Path


def _find_repo_root() -> Path:
    # Colab default (if I’m running inside Colab)
    p = Path('/content/repo')
    if p.is_dir():
        if (p / 'minesweeper').exists() and (p / 'models').exists():
            return p
        # An unzipped download can nest the repo one level down.
        kids = [k for k in p.iterdir() if k.is_dir()]
        if len(kids) == 1 and (kids[0] / 'minesweeper').exists() and (kids[0] / 'models').exists():
            return kids[0]

    # Local run: walk upward until I see minesweeper/ + models/
    for q in [Path.cwd(), *Path.cwd().parents]:
        if (q / 'minesweeper').exists() and (q / 'models').exists():
            return q

    raise FileNotFoundError('I could not find repo root (expected minesweeper/ + models/)')


repo_root = _find_repo_root()
sys.path.insert(0, str(repo_root))
print('Repo root:', repo_root)

fig_dir = Path(repo_root) / 'docs' / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)
print('Figure dir:', fig_dir)



Repo root: /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission
Figure dir: /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures


In [3]:
# Task 1 — I plot training curves
# I pasted these metrics straight from my training log so the plots are reproducible.

import numpy as np

metrics = {
    'easy': {
        'train_loss': [0.2876,0.1761,0.1202,0.0898,0.0695,0.0549,0.0440,0.0363,0.0305,0.0261,0.0226,0.0199,0.0177,0.0160,0.0147],
        'val_loss':   [0.2217,0.1457,0.1059,0.0794,0.0652,0.0508,0.0419,0.0354,0.0319,0.0277,0.0252,0.0216,0.0210,0.0202,0.0188],
        'train_f1':   [0.901,0.942,0.960,0.969,0.975,0.980,0.984,0.986,0.989,0.990,0.991,0.992,0.993,0.994,0.994],
        'val_f1':     [0.926,0.953,0.964,0.973,0.978,0.982,0.985,0.987,0.989,0.991,0.991,0.992,0.993,0.993,0.994],
    },
    'medium': {
        'train_loss': [0.4502,0.3280,0.2398,0.1865,0.1553,0.1349,0.1202,0.1090,0.1002,0.0930,0.0871,0.0821,0.0779,0.0740,0.0707],
        'val_loss':   [0.3844,0.2797,0.2091,0.1659,0.1409,0.1229,0.1091,0.1013,0.0923,0.0859,0.0786,0.0756,0.0694,0.0674,0.0632],
        'train_f1':   [0.805,0.863,0.900,0.922,0.936,0.944,0.950,0.955,0.959,0.962,0.964,0.966,0.968,0.969,0.971],
        'val_f1':     [0.838,0.883,0.914,0.931,0.942,0.950,0.955,0.960,0.963,0.965,0.968,0.970,0.971,0.972,0.974],
    },
    'hard': {
        'train_loss': [0.4370,0.3352,0.2879,0.2578,0.2371,0.2220,0.2104,0.2014,0.1938,0.1875,0.1821,0.1773,0.1732,0.1694,0.1661],
        'val_loss':   [0.3728,0.3066,0.2680,0.2416,0.2209,0.2074,0.1950,0.1895,0.1817,0.1735,0.1683,0.1661,0.1603,0.1593,0.1507],
        'train_f1':   [0.791,0.849,0.872,0.887,0.897,0.904,0.909,0.913,0.917,0.920,0.922,0.924,0.926,0.928,0.929],
        'val_f1':     [0.828,0.863,0.883,0.896,0.904,0.912,0.916,0.919,0.922,0.927,0.929,0.930,0.932,0.933,0.937],
    },
}

for name, d in metrics.items():
    n = len(d['train_loss'])
    assert all(len(v) == n for v in d.values())
    print(name, 'epochs:', n, 'final val f1:', d['val_f1'][-1])



easy epochs: 15 final val f1: 0.994
medium epochs: 15 final val f1: 0.974
hard epochs: 15 final val f1: 0.937


In [4]:
# I plot and save curves (per difficulty)
import matplotlib.pyplot as plt


def plot_curves(name: str, d: dict, out_path: Path):
    epochs = np.arange(1, len(d['train_loss']) + 1)

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))

    ax[0].plot(epochs, d['train_loss'], label='train')
    ax[0].plot(epochs, d['val_loss'], label='val')
    ax[0].set_title(f'{name}: loss vs epoch')
    ax[0].set_xlabel('epoch')
    ax[0].set_ylabel('masked BCE loss')
    ax[0].grid(True, alpha=0.3)
    ax[0].legend()

    ax[1].plot(epochs, d['train_f1'], label='train')
    ax[1].plot(epochs, d['val_f1'], label='val')
    ax[1].set_title(f'{name}: mine-F1 vs epoch')
    ax[1].set_xlabel('epoch')
    ax[1].set_ylabel('mine F1 (masked)')
    ax[1].set_ylim(0.0, 1.0)
    ax[1].grid(True, alpha=0.3)
    ax[1].legend()

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


for name, d in metrics.items():
    out = fig_dir / f'task1_{name}_curves.png'
    plot_curves(name, d, out)
    print('wrote', out)



wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task1_easy_curves.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task1_medium_curves.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task1_hard_curves.png


## Task 1 — What I take away

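Most of the gain comes in the first five or so epochs; after that the curves flatten. Train and val track each other closely (at epoch 15, val F1 is within 0.01 of train F1 on every difficulty), so I don’t see obvious overfitting. Easy saturates at val F1 0.994; medium (0.974) and hard (0.937) finish lower because more mines means more ambiguous cells.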
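
To make the "train/val stay close" and "most gains happen early" claims checkable, here is a minimal sketch that computes the final-epoch gap and the share of the total val-F1 gain already reached by epoch 5. It uses the hard-difficulty F1 values from the `metrics` dict above; the helper name `summarize_curve` and the epoch-5 cutoff are my own illustrative choices, not part of the training code.

```python
# Hard-difficulty F1 values copied from the `metrics` dict above.
train_f1 = [0.791, 0.849, 0.872, 0.887, 0.897, 0.904, 0.909, 0.913,
            0.917, 0.920, 0.922, 0.924, 0.926, 0.928, 0.929]
val_f1 = [0.828, 0.863, 0.883, 0.896, 0.904, 0.912, 0.916, 0.919,
          0.922, 0.927, 0.929, 0.930, 0.932, 0.933, 0.937]


def summarize_curve(train, val, early_epoch=5):
    """Return (final val-minus-train gap, share of total val gain reached by early_epoch)."""
    final_gap = val[-1] - train[-1]
    total_gain = val[-1] - val[0]
    early_gain = val[early_epoch - 1] - val[0]
    return final_gap, early_gain / total_gain


gap, early_share = summarize_curve(train_f1, val_f1)
print(f'final val - train F1 gap: {gap:+.3f}')
print(f'share of val-F1 gain reached by epoch 5: {early_share:.0%}')
```

A small positive gap (val slightly above train) is what I would expect if train metrics are averaged over the epoch while val is measured at the end of it, but that is an assumption about the training loop, not something this notebook verifies.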

## Task 1 — Gameplay vs LogicBot (same boards)

I run both bots on the same mine layouts (same first click, then I clone the state).

I report:
- clear rate
- perfect win rate
- avg clicks
- avg mines triggered

I also compute 95% bootstrap confidence intervals.
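
A minimal sketch of the percentile bootstrap behind a "95% bootstrap confidence interval", vectorized with NumPy. The function name `percentile_bootstrap_ci` and the toy win/loss array are illustrative assumptions, not this notebook's evaluation output.

```python
import numpy as np


def percentile_bootstrap_ci(x, n_boot=2000, alpha=0.05, seed=0):
    """Mean of x plus a (1 - alpha) percentile-bootstrap interval for that mean."""
    x = np.asarray(x, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Resample with replacement: one row of indices per bootstrap replicate.
    idx = rng.integers(0, x.size, size=(n_boot, x.size))
    boot_means = x[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2.0, 1.0 - alpha / 2.0])
    return float(x.mean()), float(lo), float(hi)


# Toy data (not real results): 200 games, 80 of them perfect wins.
wins = np.array([1] * 80 + [0] * 120)
mean, lo, hi = percentile_bootstrap_ci(wins)
print(f'mean={mean:.3f}  95% CI=[{lo:.3f}, {hi:.3f}]')
```

The interval covers the mean of a single bot's per-game metric. Because both bots play the same boards, a bootstrap over per-game differences would be the natural next step for a direct comparison, but I do not do that here.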

---

## Task 2 — Bootstrapping plots

I plot:
- perfect win rate by round
- avg_survival by round
- avg_mines_triggered by round
- dataset_samples by round


In [5]:
# Task 1 — I evaluate gameplay: LogicBot vs my NN bot
import random

import torch

# I run Notebook 05 locally, so I default to CPU.
cpu = torch.device('cpu')

from minesweeper.game import GameState, MinesweeperGame
from minesweeper.logic_bot import LogicBot
from models.task1.model import MinePredictor, MinePredictorConfig
from models.task1.policy import select_safest_unrevealed
from models.task2.dataset import _clone_game_fast


TASK1_CKPT_DIR = Path(repo_root) / 'models' / 'task1' / 'checkpoints'

TASK1_DIFFICULTIES = {
    'easy': {'height': 22, 'width': 22, 'num_mines': 50},
    'medium': {'height': 22, 'width': 22, 'num_mines': 80},
    'hard': {'height': 22, 'width': 22, 'num_mines': 100},
}


def load_task1_model(diff_name: str) -> MinePredictor:
    ckpt_path = TASK1_CKPT_DIR / f'task1_{diff_name}.pt'
    if not ckpt_path.exists():
        raise FileNotFoundError(f'Missing checkpoint: {ckpt_path}')

    ckpt = torch.load(ckpt_path, map_location=cpu)
    cfg = MinePredictorConfig(**(ckpt.get('model_cfg') or {}))
    model = MinePredictor(cfg).to(cpu)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _should_continue(g: MinesweeperGame) -> bool:
    gs = g.get_game_state()
    if gs == GameState.PROG:
        return True
    allow = bool(getattr(g, 'allow_mine_triggers', False))
    return bool(allow and gs == GameState.LOST)


def run_logic_bot(game: MinesweeperGame, *, seed: int, max_clicks: int) -> dict:
    bot = LogicBot(game, seed=int(seed))

    clicks = 0
    first_mine_at = None

    while clicks < int(max_clicks) and _should_continue(game):
        prev_m = int(getattr(game, 'mines_triggered', 0) or 0)
        result, action = bot.play_step()

        # LogicBot can emit flag actions; flags aren't clicks and don't change the board.
        if isinstance(action, dict) and action.get('type') == 'flag':
            continue

        if action is not None:
            clicks += 1
            cur_m = int(getattr(game, 'mines_triggered', 0) or 0)
            if first_mine_at is None and prev_m == 0 and cur_m > 0:
                first_mine_at = int(clicks)

        if result in {'Win', 'Done'}:
            break
        if result == 'Lost' and not bool(getattr(game, 'allow_mine_triggers', False)):
            break

    s = game.get_statistics()
    if first_mine_at is None:
        first_mine_at = int(clicks)

    mines = int(s.get('mines_triggered', 0) or 0)
    won = bool(s.get('game_won'))

    return {
        'clicks': int(clicks),
        'first_mine_at': int(first_mine_at),
        'mines_triggered': int(mines),
        'clear': int(won),
        'perfect': int(won and mines == 0),
    }


@torch.no_grad()
def run_nn_bot(game: MinesweeperGame, *, model: MinePredictor, temperature: float, max_clicks: int) -> dict:
    clicks = 0
    first_mine_at = None

    while clicks < int(max_clicks) and _should_continue(game):
        a = select_safest_unrevealed(model, game.get_visible_board(), device=cpu, temperature=float(temperature))
        if a is None:
            break

        prev_m = int(getattr(game, 'mines_triggered', 0) or 0)
        clicks += 1
        _ = game.player_clicks(int(a[0]), int(a[1]), set())
        cur_m = int(getattr(game, 'mines_triggered', 0) or 0)
        if first_mine_at is None and prev_m == 0 and cur_m > 0:
            first_mine_at = int(clicks)

    s = game.get_statistics()
    if first_mine_at is None:
        first_mine_at = int(clicks)

    mines = int(s.get('mines_triggered', 0) or 0)
    won = bool(s.get('game_won'))

    return {
        'clicks': int(clicks),
        'first_mine_at': int(first_mine_at),
        'mines_triggered': int(mines),
        'clear': int(won),
        'perfect': int(won and mines == 0),
    }



In [6]:
# Task 1 — I run the comparison on random boards (same board per episode for both bots)
import numpy as np


def bootstrap_ci(x: np.ndarray, *, n_boot: int = 2000, alpha: float = 0.05, seed: int = 0) -> tuple[float, float, float]:
    x = np.asarray(x, dtype=np.float64)
    if x.size == 0:
        return (float('nan'), float('nan'), float('nan'))

    rng = np.random.default_rng(int(seed))
    n = int(x.size)
    means = []
    for _ in range(int(n_boot)):
        idx = rng.integers(0, n, size=n)
        means.append(float(np.mean(x[idx])))
    means = np.asarray(means, dtype=np.float64)

    lo = float(np.quantile(means, float(alpha) / 2.0))
    hi = float(np.quantile(means, 1.0 - float(alpha) / 2.0))
    return (float(np.mean(x)), lo, hi)


def eval_task1_vs_logic(*, diff_name: str, n_games: int = 200, seed0: int = 0, temperature: float = 1.0) -> dict:
    diff = TASK1_DIFFICULTIES[diff_name]
    h = int(diff['height'])
    w = int(diff['width'])
    m = int(diff['num_mines'])

    model = load_task1_model(diff_name)

    rng = np.random.default_rng(int(seed0))
    seeds = rng.integers(0, 2**31 - 1, size=int(n_games), dtype=np.int64)
    first_rs = rng.integers(0, h, size=int(n_games), dtype=np.int64)
    first_cs = rng.integers(0, w, size=int(n_games), dtype=np.int64)

    out = {
        'logic': {k: [] for k in ['clear', 'perfect', 'clicks', 'first_mine_at', 'mines_triggered']},
        'nn': {k: [] for k in ['clear', 'perfect', 'clicks', 'first_mine_at', 'mines_triggered']},
    }

    max_clicks = int(h) * int(w) + 50

    for i in range(int(n_games)):
        s = int(seeds[i])
        r0 = int(first_rs[i])
        c0 = int(first_cs[i])

        base = MinesweeperGame(height=h, width=w, num_mines=m, seed=int(s))
        setattr(base, 'allow_mine_triggers', True)

        # I initialize actual_board (mines are placed after the first click).
        _ = base.player_clicks(int(r0), int(c0), set())

        g_logic = _clone_game_fast(base)
        g_nn = _clone_game_fast(base)
        setattr(g_logic, 'allow_mine_triggers', True)
        setattr(g_nn, 'allow_mine_triggers', True)

        a = run_logic_bot(g_logic, seed=int(s) + 1337, max_clicks=max_clicks)
        b = run_nn_bot(g_nn, model=model, temperature=float(temperature), max_clicks=max_clicks)

        for k in out['logic'].keys():
            out['logic'][k].append(a[k])
            out['nn'][k].append(b[k])

    # I summarize with bootstrap CIs (mean + 95% interval).
    summary = {}
    for bot_name in ['logic', 'nn']:
        summary[bot_name] = {}
        for k in out[bot_name].keys():
            arr = np.asarray(out[bot_name][k], dtype=np.float64)
            mu, lo, hi = bootstrap_ci(arr, seed=int(seed0) + 77)
            summary[bot_name][k] = {'mean': mu, 'ci_lo': lo, 'ci_hi': hi}

    return {
        'diff': diff,
        'n': int(n_games),
        'raw': out,
        'summary': summary,
    }


N_GAMES = 200  # bump this up for tighter CIs
TEMP = 1.0     # >1 is more exploratory; <1 is more greedy

results_task1_play = {}
for dn in ['easy', 'medium', 'hard']:
    print('running', dn)
    results_task1_play[dn] = eval_task1_vs_logic(diff_name=dn, n_games=int(N_GAMES), seed0=42, temperature=float(TEMP))

# I print a quick summary
for dn in ['easy', 'medium', 'hard']:
    s = results_task1_play[dn]['summary']
    print('\n', dn)
    print('  logic perfect', s['logic']['perfect']['mean'], 'clear', s['logic']['clear']['mean'], 'mines', s['logic']['mines_triggered']['mean'])
    print('  nn    perfect', s['nn']['perfect']['mean'], 'clear', s['nn']['clear']['mean'], 'mines', s['nn']['mines_triggered']['mean'])



running easy




running medium
running hard

 easy
  logic perfect 0.695 clear 1.0 mines 0.42
  nn    perfect 0.4 clear 1.0 mines 1.31

 medium
  logic perfect 0.28 clear 1.0 mines 1.77
  nn    perfect 0.02 clear 1.0 mines 5.04

 hard
  logic perfect 0.005 clear 1.0 mines 5.815
  nn    perfect 0.01 clear 1.0 mines 6.105
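
The `TEMP` comment in the cell above (">1 is more exploratory; <1 is more greedy") can be illustrated with a small sketch. This is not the code of `select_safest_unrevealed`; it assumes one common way a temperature can enter, a softmax over log-safety scores `log(1 - p_mine) / T`. The function name `selection_probs` and the mine probabilities are hypothetical.

```python
import numpy as np


def selection_probs(p_mine, temperature):
    """Softmax over log-safety scores; a lower temperature is closer to argmax."""
    p_mine = np.asarray(p_mine, dtype=np.float64)
    # Assumed scoring rule: log-probability that the cell is safe, scaled by 1/T.
    scores = np.log(np.clip(1.0 - p_mine, 1e-12, 1.0)) / float(temperature)
    scores -= scores.max()  # subtract the max for numerical stability
    w = np.exp(scores)
    return w / w.sum()


# Hypothetical predicted mine probabilities for four unrevealed cells.
p_mine = [0.05, 0.10, 0.40, 0.90]
for t in (0.25, 1.0, 4.0):
    print(t, np.round(selection_probs(p_mine, t), 3))
```

Under this assumed rule, `T = 1` picks each cell in proportion to its predicted safety, `T < 1` concentrates probability on the safest cell, and `T > 1` flattens the distribution.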


In [7]:
# I plot LogicBot vs Task 1 NN bot (bootstrap 95% CIs)

def _mean_ci(dn: str, bot: str, key: str):
    s = results_task1_play[dn]['summary'][bot][key]
    mu = float(s['mean'])
    lo = float(s['ci_lo'])
    hi = float(s['ci_hi'])
    return mu, (mu - lo), (hi - mu)


def plot_task1_comparison(out_path: Path):
    diffs = ['easy', 'medium', 'hard']
    bots = ['logic', 'nn']

    metrics_to_plot = [
        ('perfect', 'perfect win rate (clear with 0 mines)'),
        ('clear', 'clear rate (finish even if mines triggered)'),
        ('clicks', 'avg clicks'),
        ('mines_triggered', 'avg mines triggered'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    axes = axes.reshape(-1)

    width = 0.35
    x = np.arange(len(diffs))

    for ax_i, (key, title) in enumerate(metrics_to_plot):
        ax = axes[ax_i]
        for bi, bot in enumerate(bots):
            means = []
            yerr_lo = []
            yerr_hi = []
            for dn in diffs:
                mu, elo, ehi = _mean_ci(dn, bot, key)
                means.append(mu)
                yerr_lo.append(elo)
                yerr_hi.append(ehi)

            ax.bar(x + (bi - 0.5) * width, means, width=width, label=bot)
            ax.errorbar(
                x + (bi - 0.5) * width,
                means,
                yerr=[yerr_lo, yerr_hi],
                fmt='none',
                ecolor='black',
                capsize=3,
                linewidth=1,
            )

        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(diffs)
        ax.grid(True, axis='y', alpha=0.3)
        if key in {'perfect', 'clear'}:
            ax.set_ylim(0.0, 1.0)
        ax.legend()

    fig.suptitle('Task 1: LogicBot vs NN bot (same random boards; 95% bootstrap CI)')
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print('wrote', out_path)


plot_task1_comparison(fig_dir / 'task1_vs_logic_summary.png')



wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task1_vs_logic_summary.png


In [8]:
# Task 2 — I load my bootstrapping results dict (copied from my notebook output)
# I hardcode it here so the plots are reproducible.

results_task2 = {
    'easy': [
        {
            'difficulty': 'easy',
            'round': 0,
            'kind': 'logic_model_actor',
            'n': 80,
            'perfect_win_rate': 0.95,
            'avg_survival': 0.9952145202425132,
            'avg_cells_opened': 434.0,
            'avg_mines_triggered': 0.05,
            'dataset_samples': 1284,
        }
    ],
    'medium': [
        {
            'difficulty': 'medium',
            'round': 0,
            'kind': 'logic_model_actor',
            'n': 80,
            'perfect_win_rate': 0.325,
            'avg_survival': 0.7453056383788143,
            'avg_cells_opened': 404.0,
            'avg_mines_triggered': 1.25,
            'dataset_samples': 3702,
        },
        {
            'difficulty': 'medium',
            'round': 1,
            'kind': 'critic_actor',
            'n': 80,
            'perfect_win_rate': 0.4125,
            'avg_survival': 0.8049724640593279,
            'avg_cells_opened': 404.0,
            'avg_mines_triggered': 1.1375,
            'dataset_samples': 1500,
        },
        {
            'difficulty': 'medium',
            'round': 2,
            'kind': 'critic_actor',
            'n': 80,
            'perfect_win_rate': 0.3875,
            'avg_survival': 0.8278645651605153,
            'avg_cells_opened': 404.0,
            'avg_mines_triggered': 1.0875,
            'dataset_samples': 1500,
        },
    ],
    'hard': [
        {
            'difficulty': 'hard',
            'round': 0,
            'kind': 'logic_model_actor',
            'n': 80,
            'perfect_win_rate': 0.0625,
            'avg_survival': 0.4453307522010162,
            'avg_cells_opened': 384.0,
            'avg_mines_triggered': 3.35,
            'dataset_samples': 4592,
        },
        {
            'difficulty': 'hard',
            'round': 1,
            'kind': 'critic_actor',
            'n': 80,
            'perfect_win_rate': 0.0375,
            'avg_survival': 0.4277399656135703,
            'avg_cells_opened': 384.0,
            'avg_mines_triggered': 4.4625,
            'dataset_samples': 2500,
        },
        {
            'difficulty': 'hard',
            'round': 2,
            'kind': 'critic_actor',
            'n': 80,
            'perfect_win_rate': 0.05,
            'avg_survival': 0.40203666203711547,
            'avg_cells_opened': 384.0,
            'avg_mines_triggered': 4.925,
            'dataset_samples': 2500,
        },
    ],
}

for d, rows in results_task2.items():
    rows = sorted(rows, key=lambda r: int(r['round']))
    print(d, [(r['round'], r['perfect_win_rate'], r['avg_survival'], r['dataset_samples']) for r in rows])



easy [(0, 0.95, 0.9952145202425132, 1284)]
medium [(0, 0.325, 0.7453056383788143, 3702), (1, 0.4125, 0.8049724640593279, 1500), (2, 0.3875, 0.8278645651605153, 1500)]
hard [(0, 0.0625, 0.4453307522010162, 4592), (1, 0.0375, 0.4277399656135703, 2500), (2, 0.05, 0.40203666203711547, 2500)]


In [9]:
# Task 2 — I plot the bootstrapping results (saved into docs/figures)
import matplotlib.pyplot as plt


def plot_task2_metric(metric: str, title: str, out_name: str, *, ylim=None):
    diffs = ['easy', 'medium', 'hard']
    rounds = sorted({int(r['round']) for d in diffs for r in results_task2[d]})

    x = np.arange(len(rounds))
    width = 0.25

    fig, ax = plt.subplots(figsize=(10, 4))

    for i, d in enumerate(diffs):
        rmap = {int(r['round']): r for r in results_task2[d]}
        vals = [float(rmap[rr][metric]) if rr in rmap else np.nan for rr in rounds]
        ax.bar(x + (i - 1) * width, vals, width=width, label=d)

    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([f'round {r}' for r in rounds])
    ax.grid(True, axis='y', alpha=0.3)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()

    out_path = fig_dir / out_name
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print('wrote', out_path)


plot_task2_metric('perfect_win_rate', 'Task 2: perfect win rate by round', 'task2_perfect_win_rate.png', ylim=(0.0, 1.0))
plot_task2_metric('avg_survival', 'Task 2: avg survival by round', 'task2_avg_survival.png', ylim=(0.0, 1.0))
plot_task2_metric('avg_mines_triggered', 'Task 2: avg mines triggered by round', 'task2_avg_mines_triggered.png')
plot_task2_metric('dataset_samples', 'Task 2: dataset sizes by round', 'task2_dataset_sizes.png')



wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task2_perfect_win_rate.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task2_avg_survival.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task2_avg_mines_triggered.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task2_dataset_sizes.png


---

## Task 3 — Thinking longer

I’m checking three things:
- loss vs thinking steps
- win rate vs thinking steps
- heatmap evolution on the same board

v2 training snapshot (final-epoch values from my Colab training log):
- easy: val loss 0.0176, val F1 0.995 (25 epochs)
- medium: val loss 0.0468, val F1 0.982 (35 epochs)
- hard: val loss 0.1233, val F1 0.949 (40 epochs)

Important note: I couldn’t recover my v2 checkpoints locally (Colab disconnected overnight while I was asleep), so this repo only has the **v1** Task 3 checkpoints. The heatmaps below are therefore generated from **v1** only.

Notebook 04 also exports Task 3 eval outputs into `docs/figures/` (JSON + PNG). Below I reload and replot those if they exist.


## Task 3 — Training log transcript (from my Colab run)

I pasted this directly from my Colab output. I’m keeping it here so the loss-curve plots below are reproducible.
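
A minimal sketch of how one per-epoch line in this log format can be parsed into numbers. The regex and the `parse_epoch_line` helper are illustrative assumptions; the sample line is the first easy epoch from my transcript.

```python
import re

# Matches lines like:
# [easy] epoch 1/25 | train loss 0.3141 ... f1 0.890 | val loss 0.2545 ... f1 0.914
EPOCH_RE = re.compile(
    r'^\[(?P<diff>\w+)\] epoch (?P<epoch>\d+)/(?P<total>\d+)'
    r' \| train loss (?P<train_loss>[\d.]+).*? f1 (?P<train_f1>[\d.]+)'
    r' \| val loss (?P<val_loss>[\d.]+).*? f1 (?P<val_f1>[\d.]+)\s*$'
)


def parse_epoch_line(line):
    """Return a dict of parsed fields, or None if the line is not an epoch line."""
    m = EPOCH_RE.match(line.strip())
    if m is None:
        return None
    d = m.groupdict()
    return {
        'diff': d['diff'],
        'epoch': int(d['epoch']),
        'total': int(d['total']),
        'train_loss': float(d['train_loss']),
        'train_f1': float(d['train_f1']),
        'val_loss': float(d['val_loss']),
        'val_f1': float(d['val_f1']),
    }


line = ('[easy] epoch 1/25 | train loss 0.3141 acc 0.930 prec 0.955 rec 0.834 f1 0.890 '
        '| val loss 0.2545 acc 0.944 prec 0.963 rec 0.869 f1 0.914')
print(parse_epoch_line(line))
print(parse_epoch_line('Training Task 3 model for easy...'))  # header lines give None
```

Returning `None` for non-epoch lines lets the header and blank lines in the transcript be skipped without special cases.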



In [13]:
# Task 3 — Loss curves (parsed from my training log transcript)
# I parse the per-epoch lines in TASK3_TRAIN_LOG to plot train/val loss vs epoch, just like Task 1.

import re
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

TASK3_TRAIN_LOG = r"""
Training Task 3 model for easy...
[easy] train_task3: samples=73058 steps_train=4 epochs=25 batch_size=64 seed=0
[easy] epoch 1/25 | train loss 0.3141 acc 0.930 prec 0.955 rec 0.834 f1 0.890 | val loss 0.2545 acc 0.944 prec 0.963 rec 0.869 f1 0.914
[easy] epoch 2/25 | train loss 0.2009 acc 0.957 prec 0.978 rec 0.894 f1 0.934 | val loss 0.1741 acc 0.964 prec 0.982 rec 0.910 f1 0.945
[easy] epoch 3/25 | train loss 0.1449 acc 0.969 prec 0.982 rec 0.926 f1 0.953 | val loss 0.1366 acc 0.972 prec 0.986 rec 0.929 f1 0.957
[easy] epoch 4/25 | train loss 0.1154 acc 0.975 prec 0.985 rec 0.941 f1 0.962 | val loss 0.1153 acc 0.976 prec 0.985 rec 0.942 f1 0.963
[easy] epoch 5/25 | train loss 0.0962 acc 0.979 prec 0.986 rec 0.952 f1 0.968 | val loss 0.0991 acc 0.979 prec 0.984 rec 0.952 f1 0.968
[easy] epoch 6/25 | train loss 0.0816 acc 0.982 prec 0.986 rec 0.960 f1 0.973 | val loss 0.0851 acc 0.981 prec 0.987 rec 0.957 f1 0.972
[easy] epoch 7/25 | train loss 0.0690 acc 0.984 prec 0.986 rec 0.967 f1 0.976 | val loss 0.0718 acc 0.984 prec 0.989 rec 0.964 f1 0.976
[easy] epoch 8/25 | train loss 0.0578 acc 0.986 prec 0.987 rec 0.973 f1 0.980 | val loss 0.0619 acc 0.987 prec 0.991 rec 0.969 f1 0.980
[easy] epoch 9/25 | train loss 0.0483 acc 0.988 prec 0.987 rec 0.978 f1 0.983 | val loss 0.0507 acc 0.989 prec 0.990 rec 0.977 f1 0.983
[easy] epoch 10/25 | train loss 0.0408 acc 0.990 prec 0.988 rec 0.983 f1 0.985 | val loss 0.0434 acc 0.990 prec 0.990 rec 0.981 f1 0.986
[easy] epoch 11/25 | train loss 0.0345 acc 0.991 prec 0.989 rec 0.986 f1 0.987 | val loss 0.0368 acc 0.992 prec 0.989 rec 0.986 f1 0.988
[easy] epoch 12/25 | train loss 0.0292 acc 0.993 prec 0.990 rec 0.989 f1 0.989 | val loss 0.0353 acc 0.992 prec 0.990 rec 0.987 f1 0.988
[easy] epoch 13/25 | train loss 0.0255 acc 0.994 prec 0.991 rec 0.991 f1 0.991 | val loss 0.0309 acc 0.993 prec 0.991 rec 0.989 f1 0.990
[easy] epoch 14/25 | train loss 0.0223 acc 0.994 prec 0.992 rec 0.992 f1 0.992 | val loss 0.0267 acc 0.994 prec 0.992 rec 0.991 f1 0.991
[easy] epoch 15/25 | train loss 0.0197 acc 0.995 prec 0.992 rec 0.993 f1 0.993 | val loss 0.0267 acc 0.994 prec 0.992 rec 0.991 f1 0.991
[easy] epoch 16/25 | train loss 0.0177 acc 0.996 prec 0.993 rec 0.994 f1 0.993 | val loss 0.0247 acc 0.994 prec 0.991 rec 0.992 f1 0.992
[easy] epoch 17/25 | train loss 0.0162 acc 0.996 prec 0.994 rec 0.995 f1 0.994 | val loss 0.0229 acc 0.995 prec 0.994 rec 0.993 f1 0.993
[easy] epoch 18/25 | train loss 0.0147 acc 0.996 prec 0.994 rec 0.995 f1 0.995 | val loss 0.0235 acc 0.996 prec 0.995 rec 0.991 f1 0.993
[easy] epoch 19/25 | train loss 0.0135 acc 0.997 prec 0.994 rec 0.996 f1 0.995 | val loss 0.0210 acc 0.996 prec 0.993 rec 0.994 f1 0.993
[easy] epoch 20/25 | train loss 0.0126 acc 0.997 prec 0.995 rec 0.996 f1 0.995 | val loss 0.0197 acc 0.996 prec 0.994 rec 0.994 f1 0.994
[easy] epoch 21/25 | train loss 0.0118 acc 0.997 prec 0.995 rec 0.996 f1 0.996 | val loss 0.0192 acc 0.996 prec 0.994 rec 0.994 f1 0.994
[easy] epoch 22/25 | train loss 0.0109 acc 0.997 prec 0.995 rec 0.997 f1 0.996 | val loss 0.0196 acc 0.996 prec 0.995 rec 0.994 f1 0.995
[easy] epoch 23/25 | train loss 0.0105 acc 0.997 prec 0.996 rec 0.997 f1 0.996 | val loss 0.0185 acc 0.997 prec 0.995 rec 0.995 f1 0.995
[easy] epoch 24/25 | train loss 0.0098 acc 0.998 prec 0.996 rec 0.997 f1 0.996 | val loss 0.0176 acc 0.996 prec 0.994 rec 0.996 f1 0.995
[easy] epoch 25/25 | train loss 0.0091 acc 0.998 prec 0.996 rec 0.997 f1 0.997 | val loss 0.0176 acc 0.997 prec 0.995 rec 0.996 f1 0.995

Training Task 3 model for medium...
[medium] train_task3: samples=316187 steps_train=4 epochs=35 batch_size=64 seed=0
[medium] epoch 1/35 | train loss 0.4451 acc 0.883 prec 0.943 rec 0.706 f1 0.807 | val loss 0.3785 acc 0.899 prec 0.934 rec 0.762 f1 0.839
[medium] epoch 2/35 | train loss 0.3379 acc 0.910 prec 0.943 rec 0.790 f1 0.860 | val loss 0.3067 acc 0.918 prec 0.941 rec 0.814 f1 0.873
[medium] epoch 3/35 | train loss 0.2819 acc 0.924 prec 0.941 rec 0.833 f1 0.884 | val loss 0.2582 acc 0.929 prec 0.936 rec 0.854 f1 0.893
[medium] epoch 4/35 | train loss 0.2339 acc 0.935 prec 0.936 rec 0.873 f1 0.903 | val loss 0.2082 acc 0.943 prec 0.948 rec 0.886 f1 0.916
[medium] epoch 5/35 | train loss 0.1932 acc 0.946 prec 0.938 rec 0.904 f1 0.921 | val loss 0.1714 acc 0.951 prec 0.937 rec 0.919 f1 0.928
[medium] epoch 6/35 | train loss 0.1639 acc 0.954 prec 0.944 rec 0.924 f1 0.934 | val loss 0.1487 acc 0.958 prec 0.946 rec 0.931 f1 0.939
[medium] epoch 7/35 | train loss 0.1442 acc 0.960 prec 0.949 rec 0.935 f1 0.942 | val loss 0.1326 acc 0.961 prec 0.944 rec 0.944 f1 0.944
[medium] epoch 8/35 | train loss 0.1300 acc 0.964 prec 0.953 rec 0.943 f1 0.948 | val loss 0.1194 acc 0.966 prec 0.952 rec 0.949 f1 0.950
[medium] epoch 9/35 | train loss 0.1192 acc 0.967 prec 0.956 rec 0.949 f1 0.953 | val loss 0.1090 acc 0.970 prec 0.962 rec 0.950 f1 0.956
[medium] epoch 10/35 | train loss 0.1107 acc 0.970 prec 0.959 rec 0.954 f1 0.956 | val loss 0.1022 acc 0.971 prec 0.961 rec 0.956 f1 0.958
[medium] epoch 11/35 | train loss 0.1037 acc 0.972 prec 0.961 rec 0.958 f1 0.959 | val loss 0.0962 acc 0.972 prec 0.958 rec 0.962 f1 0.960
[medium] epoch 12/35 | train loss 0.0979 acc 0.974 prec 0.963 rec 0.961 f1 0.962 | val loss 0.0903 acc 0.975 prec 0.966 rec 0.961 f1 0.964
[medium] epoch 13/35 | train loss 0.0928 acc 0.975 prec 0.965 rec 0.963 f1 0.964 | val loss 0.0859 acc 0.976 prec 0.969 rec 0.962 f1 0.966
[medium] epoch 14/35 | train loss 0.0884 acc 0.976 prec 0.966 rec 0.965 f1 0.966 | val loss 0.0814 acc 0.977 prec 0.970 rec 0.965 f1 0.967
[medium] epoch 15/35 | train loss 0.0846 acc 0.977 prec 0.968 rec 0.967 f1 0.967 | val loss 0.0786 acc 0.978 prec 0.971 rec 0.966 f1 0.969
[medium] epoch 16/35 | train loss 0.0812 acc 0.978 prec 0.969 rec 0.969 f1 0.969 | val loss 0.0766 acc 0.978 prec 0.967 rec 0.970 f1 0.968
[medium] epoch 17/35 | train loss 0.0783 acc 0.979 prec 0.970 rec 0.970 f1 0.970 | val loss 0.0719 acc 0.980 prec 0.970 rec 0.971 f1 0.971
[medium] epoch 18/35 | train loss 0.0756 acc 0.980 prec 0.971 rec 0.971 f1 0.971 | val loss 0.0698 acc 0.981 prec 0.974 rec 0.970 f1 0.972
[medium] epoch 19/35 | train loss 0.0731 acc 0.981 prec 0.972 rec 0.972 f1 0.972 | val loss 0.0676 acc 0.982 prec 0.976 rec 0.971 f1 0.973
[medium] epoch 20/35 | train loss 0.0709 acc 0.981 prec 0.973 rec 0.974 f1 0.973 | val loss 0.0664 acc 0.982 prec 0.977 rec 0.971 f1 0.974
[medium] epoch 21/35 | train loss 0.0688 acc 0.982 prec 0.973 rec 0.975 f1 0.974 | val loss 0.0660 acc 0.982 prec 0.978 rec 0.971 f1 0.974
[medium] epoch 22/35 | train loss 0.0669 acc 0.982 prec 0.974 rec 0.975 f1 0.975 | val loss 0.0624 acc 0.982 prec 0.975 rec 0.975 f1 0.975
[medium] epoch 23/35 | train loss 0.0651 acc 0.983 prec 0.975 rec 0.976 f1 0.975 | val loss 0.0609 acc 0.983 prec 0.974 rec 0.976 f1 0.975
[medium] epoch 24/35 | train loss 0.0635 acc 0.983 prec 0.975 rec 0.977 f1 0.976 | val loss 0.0590 acc 0.984 prec 0.977 rec 0.976 f1 0.977
[medium] epoch 25/35 | train loss 0.0619 acc 0.984 prec 0.976 rec 0.978 f1 0.977 | val loss 0.0573 acc 0.984 prec 0.977 rec 0.977 f1 0.977
[medium] epoch 26/35 | train loss 0.0605 acc 0.984 prec 0.976 rec 0.978 f1 0.977 | val loss 0.0556 acc 0.984 prec 0.976 rec 0.979 f1 0.978
[medium] epoch 27/35 | train loss 0.0592 acc 0.985 prec 0.977 rec 0.979 f1 0.978 | val loss 0.0555 acc 0.984 prec 0.975 rec 0.979 f1 0.977
[medium] epoch 28/35 | train loss 0.0578 acc 0.985 prec 0.977 rec 0.980 f1 0.978 | val loss 0.0535 acc 0.985 prec 0.978 rec 0.979 f1 0.979
[medium] epoch 29/35 | train loss 0.0567 acc 0.985 prec 0.978 rec 0.980 f1 0.979 | val loss 0.0522 acc 0.986 prec 0.980 rec 0.979 f1 0.979
[medium] epoch 30/35 | train loss 0.0555 acc 0.986 prec 0.978 rec 0.980 f1 0.979 | val loss 0.0511 acc 0.986 prec 0.979 rec 0.980 f1 0.979
[medium] epoch 31/35 | train loss 0.0545 acc 0.986 prec 0.979 rec 0.981 f1 0.980 | val loss 0.0510 acc 0.986 prec 0.981 rec 0.979 f1 0.980
[medium] epoch 32/35 | train loss 0.0535 acc 0.986 prec 0.979 rec 0.981 f1 0.980 | val loss 0.0484 acc 0.986 prec 0.977 rec 0.983 f1 0.980
[medium] epoch 33/35 | train loss 0.0525 acc 0.987 prec 0.979 rec 0.982 f1 0.981 | val loss 0.0489 acc 0.986 prec 0.977 rec 0.983 f1 0.980
[medium] epoch 34/35 | train loss 0.0516 acc 0.987 prec 0.980 rec 0.982 f1 0.981 | val loss 0.0473 acc 0.987 prec 0.980 rec 0.982 f1 0.981
[medium] epoch 35/35 | train loss 0.0508 acc 0.987 prec 0.980 rec 0.983 f1 0.981 | val loss 0.0468 acc 0.987 prec 0.982 rec 0.981 f1 0.982

Training Task 3 model for hard...
[hard] train_task3: samples=621444 steps_train=4 epochs=40 batch_size=64 seed=0
[hard] epoch 1/40 | train loss 0.4389 acc 0.859 prec 0.937 rec 0.683 f1 0.790 | val loss 0.3796 acc 0.880 prec 0.937 rec 0.738 f1 0.826
[hard] epoch 2/40 | train loss 0.3441 acc 0.890 prec 0.937 rec 0.768 f1 0.844 | val loss 0.3224 acc 0.895 prec 0.924 rec 0.795 f1 0.855
[hard] epoch 3/40 | train loss 0.2990 acc 0.904 prec 0.933 rec 0.811 f1 0.868 | val loss 0.2849 acc 0.909 prec 0.935 rec 0.823 f1 0.875
[hard] epoch 4/40 | train loss 0.2670 acc 0.915 prec 0.931 rec 0.841 f1 0.884 | val loss 0.2521 acc 0.920 prec 0.939 rec 0.848 f1 0.891
[hard] epoch 5/40 | train loss 0.2438 acc 0.922 prec 0.932 rec 0.862 f1 0.896 | val loss 0.2355 acc 0.921 prec 0.913 rec 0.880 f1 0.896
[hard] epoch 6/40 | train loss 0.2268 acc 0.928 prec 0.934 rec 0.876 f1 0.904 | val loss 0.2140 acc 0.931 prec 0.937 rec 0.881 f1 0.908
[hard] epoch 7/40 | train loss 0.2141 acc 0.932 prec 0.936 rec 0.885 f1 0.910 | val loss 0.2044 acc 0.934 prec 0.936 rec 0.890 f1 0.913
[hard] epoch 8/40 | train loss 0.2043 acc 0.936 prec 0.938 rec 0.892 f1 0.915 | val loss 0.1933 acc 0.938 prec 0.943 rec 0.894 f1 0.918
[hard] epoch 9/40 | train loss 0.1963 acc 0.938 prec 0.940 rec 0.898 f1 0.918 | val loss 0.1909 acc 0.940 prec 0.944 rec 0.897 f1 0.920
[hard] epoch 10/40 | train loss 0.1899 acc 0.940 prec 0.941 rec 0.902 f1 0.921 | val loss 0.1847 acc 0.940 prec 0.937 rec 0.905 f1 0.921
[hard] epoch 11/40 | train loss 0.1844 acc 0.942 prec 0.943 rec 0.906 f1 0.924 | val loss 0.1799 acc 0.942 prec 0.940 rec 0.908 f1 0.924
[hard] epoch 12/40 | train loss 0.1795 acc 0.944 prec 0.944 rec 0.909 f1 0.926 | val loss 0.1732 acc 0.945 prec 0.948 rec 0.908 f1 0.927
[hard] epoch 13/40 | train loss 0.1753 acc 0.945 prec 0.945 rec 0.912 f1 0.928 | val loss 0.1778 acc 0.944 prec 0.948 rec 0.905 f1 0.926
[hard] epoch 14/40 | train loss 0.1716 acc 0.947 prec 0.946 rec 0.914 f1 0.930 | val loss 0.1658 acc 0.947 prec 0.948 rec 0.914 f1 0.930
[hard] epoch 15/40 | train loss 0.1683 acc 0.948 prec 0.947 rec 0.917 f1 0.931 | val loss 0.1599 acc 0.949 prec 0.948 rec 0.918 f1 0.933
[hard] epoch 16/40 | train loss 0.1653 acc 0.949 prec 0.947 rec 0.919 f1 0.933 | val loss 0.1561 acc 0.950 prec 0.951 rec 0.919 f1 0.935
[hard] epoch 17/40 | train loss 0.1625 acc 0.950 prec 0.948 rec 0.920 f1 0.934 | val loss 0.1588 acc 0.949 prec 0.947 rec 0.920 f1 0.933
[hard] epoch 18/40 | train loss 0.1602 acc 0.950 prec 0.949 rec 0.922 f1 0.935 | val loss 0.1517 acc 0.952 prec 0.954 rec 0.921 f1 0.937
[hard] epoch 19/40 | train loss 0.1577 acc 0.951 prec 0.949 rec 0.923 f1 0.936 | val loss 0.1513 acc 0.952 prec 0.952 rec 0.922 f1 0.937
[hard] epoch 20/40 | train loss 0.1558 acc 0.952 prec 0.950 rec 0.924 f1 0.937 | val loss 0.1480 acc 0.953 prec 0.955 rec 0.923 f1 0.938
[hard] epoch 21/40 | train loss 0.1536 acc 0.953 prec 0.951 rec 0.926 f1 0.938 | val loss 0.1523 acc 0.953 prec 0.965 rec 0.913 f1 0.938
[hard] epoch 22/40 | train loss 0.1519 acc 0.953 prec 0.951 rec 0.927 f1 0.939 | val loss 0.1424 acc 0.955 prec 0.959 rec 0.924 f1 0.941
[hard] epoch 23/40 | train loss 0.1500 acc 0.954 prec 0.952 rec 0.928 f1 0.940 | val loss 0.1410 acc 0.955 prec 0.955 rec 0.928 f1 0.941
[hard] epoch 24/40 | train loss 0.1484 acc 0.954 prec 0.952 rec 0.929 f1 0.940 | val loss 0.1412 acc 0.955 prec 0.957 rec 0.927 f1 0.941
[hard] epoch 25/40 | train loss 0.1469 acc 0.955 prec 0.952 rec 0.930 f1 0.941 | val loss 0.1410 acc 0.955 prec 0.954 rec 0.929 f1 0.941
[hard] epoch 26/40 | train loss 0.1456 acc 0.955 prec 0.953 rec 0.931 f1 0.942 | val loss 0.1368 acc 0.956 prec 0.954 rec 0.932 f1 0.943
[hard] epoch 27/40 | train loss 0.1441 acc 0.956 prec 0.953 rec 0.932 f1 0.942 | val loss 0.1385 acc 0.956 prec 0.957 rec 0.928 f1 0.943
[hard] epoch 28/40 | train loss 0.1429 acc 0.956 prec 0.954 rec 0.932 f1 0.943 | val loss 0.1350 acc 0.957 prec 0.956 rec 0.932 f1 0.944
[hard] epoch 29/40 | train loss 0.1417 acc 0.957 prec 0.954 rec 0.933 f1 0.943 | val loss 0.1351 acc 0.957 prec 0.955 rec 0.933 f1 0.944
[hard] epoch 30/40 | train loss 0.1405 acc 0.957 prec 0.954 rec 0.934 f1 0.944 | val loss 0.1337 acc 0.958 prec 0.961 rec 0.930 f1 0.945
[hard] epoch 31/40 | train loss 0.1394 acc 0.957 prec 0.955 rec 0.934 f1 0.944 | val loss 0.1320 acc 0.958 prec 0.958 rec 0.933 f1 0.945
[hard] epoch 32/40 | train loss 0.1384 acc 0.958 prec 0.955 rec 0.935 f1 0.945 | val loss 0.1294 acc 0.959 prec 0.958 rec 0.935 f1 0.946
[hard] epoch 33/40 | train loss 0.1373 acc 0.958 prec 0.955 rec 0.936 f1 0.945 | val loss 0.1299 acc 0.959 prec 0.958 rec 0.934 f1 0.946
[hard] epoch 34/40 | train loss 0.1365 acc 0.958 prec 0.955 rec 0.936 f1 0.946 | val loss 0.1277 acc 0.959 prec 0.957 rec 0.937 f1 0.947
[hard] epoch 35/40 | train loss 0.1355 acc 0.959 prec 0.956 rec 0.937 f1 0.946 | val loss 0.1277 acc 0.960 prec 0.960 rec 0.935 f1 0.947
[hard] epoch 36/40 | train loss 0.1346 acc 0.959 prec 0.956 rec 0.937 f1 0.946 | val loss 0.1292 acc 0.959 prec 0.958 rec 0.936 f1 0.946
[hard] epoch 37/40 | train loss 0.1338 acc 0.959 prec 0.956 rec 0.938 f1 0.947 | val loss 0.1270 acc 0.959 prec 0.952 rec 0.941 f1 0.946
[hard] epoch 38/40 | train loss 0.1331 acc 0.959 prec 0.956 rec 0.938 f1 0.947 | val loss 0.1284 acc 0.959 prec 0.955 rec 0.938 f1 0.946
[hard] epoch 39/40 | train loss 0.1322 acc 0.960 prec 0.957 rec 0.939 f1 0.948 | val loss 0.1229 acc 0.962 prec 0.964 rec 0.936 f1 0.950
[hard] epoch 40/40 | train loss 0.1315 acc 0.960 prec 0.957 rec 0.939 f1 0.948 | val loss 0.1233 acc 0.961 prec 0.961 rec 0.938 f1 0.949
""".strip()


def parse_task3_epoch_losses(log_text: str):
    # I parse per-epoch train/val loss for each difficulty.
    pat = re.compile(
        r"^(?:\[(?P<diff>[^\]]+)\]\s*)?epoch\s+(?P<epoch>\d+)\/(?P<epochs>\d+)\s*\|\s*"
        r"train\s+loss\s+(?P<tr>\d+(?:\.\d+)?)\b.*?\|\s*"
        r"val\s+loss\s+(?P<va>\d+(?:\.\d+)?)\b",
        flags=re.IGNORECASE | re.MULTILINE,
    )

    out = {}
    for m in pat.finditer(log_text or ''):
        diff = (m.group('diff') or '').strip().lower()
        if diff in {'e', 'ez'}:
            diff = 'easy'
        if diff in {'m', 'med'}:
            diff = 'medium'
        if diff in {'h'}:
            diff = 'hard'
        if diff not in {'easy', 'medium', 'hard'}:
            diff = diff or 'unknown'

        d = out.setdefault(diff, {'epoch': [], 'train_loss': [], 'val_loss': []})
        d['epoch'].append(int(m.group('epoch')))
        d['train_loss'].append(float(m.group('tr')))
        d['val_loss'].append(float(m.group('va')))

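    # I sort each difficulty's series by epoch (stable, so repeated epochs keep their log order).
    for d in out.values():
        order = np.argsort(np.asarray(d['epoch'], dtype=np.int64), kind='stable')
        for k in ['epoch', 'train_loss', 'val_loss']:
            arr = np.asarray(d[k])
            d[k] = [arr[i].item() for i in order]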

    return out


def plot_task3_loss_curves_from_logs(log_text: str, *, out_path: Path) -> dict:
    parsed = parse_task3_epoch_losses(log_text)
    if not parsed:
        raise RuntimeError('No per-epoch loss lines found in TASK3_TRAIN_LOG.')

    diffs = [d for d in ['easy', 'medium', 'hard'] if d in parsed]
    if not diffs:
        diffs = list(parsed.keys())

    fig, axes = plt.subplots(1, len(diffs), figsize=(5 * len(diffs), 4), squeeze=False)
    axes = axes.reshape(-1)

    for ax, dn in zip(axes, diffs):
        d = parsed[dn]
        ax.plot(d['epoch'], d['train_loss'], marker='o', label='train')
        ax.plot(d['epoch'], d['val_loss'], marker='o', label='val')
        ax.set_title(dn)
        ax.set_xlabel('epoch')
        ax.set_ylabel('masked BCE loss')
        ax.grid(True, alpha=0.3)
        ax.legend()

    fig.suptitle('Task 3: train/val loss vs epoch (parsed from my training logs)')
    fig.tight_layout(rect=[0, 0, 1, 0.90])
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print('wrote', out_path)

    return parsed


parsed = plot_task3_loss_curves_from_logs(TASK3_TRAIN_LOG, out_path=fig_dir / 'task3_loss_curves_from_logs.png')

# I print a quick summary like I did for Task 1.
for dn in ['easy', 'medium', 'hard']:
    if dn not in parsed:
        continue
    d = parsed[dn]
    print(dn, 'epochs:', len(d['epoch']), 'final train loss:', d['train_loss'][-1], 'final val loss:', d['val_loss'][-1])

# I also save one small loss curve per difficulty.
for dn in ['easy', 'medium', 'hard']:
    if dn not in parsed:
        continue
    d = parsed[dn]
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(d['epoch'], d['train_loss'], marker='o', label='train')
    ax.plot(d['epoch'], d['val_loss'], marker='o', label='val')
    ax.set_title(f'Task 3 ({dn}): loss vs epoch')
    ax.set_xlabel('epoch')
    ax.set_ylabel('masked BCE loss')
    ax.grid(True, alpha=0.3)
    ax.legend()
    out = fig_dir / f'task3_{dn}_loss_curve.png'
    fig.tight_layout()
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print('wrote', out)



wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_loss_curves_from_logs.png
easy epochs: 25 final train loss: 0.0091 final val loss: 0.0176
medium epochs: 35 final train loss: 0.0508 final val loss: 0.0468
hard epochs: 40 final train loss: 0.1315 final val loss: 0.1233
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_easy_loss_curve.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_medium_loss_curve.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_hard_loss_curve.png
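
The printed summary above reports only the final-epoch losses. A minimal sketch of how the best validation epoch and the final train/val gap could be read off one entry of the parsed dict follows. `summarize_losses` is a hypothetical helper (it is not part of the repo), and the sample values are the last three `[hard]` rows copied from the training log above.

```python
def summarize_losses(d: dict) -> dict:
    """Summarize one difficulty's curve: best validation epoch and final val - train gap."""
    # Index of the smallest validation loss (first occurrence on ties).
    best_i = min(range(len(d['val_loss'])), key=d['val_loss'].__getitem__)
    return {
        'best_epoch': d['epoch'][best_i],
        'best_val_loss': d['val_loss'][best_i],
        # Negative gap means validation loss ended below training loss.
        'final_gap': round(d['val_loss'][-1] - d['train_loss'][-1], 4),
    }


# Last three [hard] epochs from the log (same keys as the parsed dict).
hard_tail = {
    'epoch': [38, 39, 40],
    'train_loss': [0.1331, 0.1322, 0.1315],
    'val_loss': [0.1284, 0.1229, 0.1233],
}
summary = summarize_losses(hard_tail)
print(summary)
```

On these rows the lowest validation loss is at epoch 39 rather than the final epoch. The same call should work on `parsed['hard']`, assuming it keeps the `epoch` / `train_loss` / `val_loss` keys built by the parser.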


In [14]:
# Task 3 — Heatmap evolution (generated here from v1 checkpoints)
# I regenerate this figure directly from the v1 `.pt` checkpoints that exist in this repo.

import numpy as np
import matplotlib.pyplot as plt

try:
    import torch
except Exception as e:
    raise RuntimeError('This section requires PyTorch. Install requirements and restart the kernel.') from e

from pathlib import Path

from minesweeper.game import MinesweeperGame
from minesweeper.logic_bot import LogicBot
from models.task1.encoding import visible_to_int8
from models.task3.model import ThinkingMinePredictor, ThinkingMinePredictorConfig


def _make_partially_revealed_board(*, height: int, width: int, num_mines: int, seed: int, click_steps: int = 25):
    # I generate one deterministic partially-revealed board so step-by-step heatmaps are comparable.
    game = MinesweeperGame(height=int(height), width=int(width), num_mines=int(num_mines), seed=int(seed))
    game.allow_mine_triggers = True

    # First click initializes the hidden board and usually opens a region.
    r0 = int(height // 2)
    c0 = int(width // 2)
    _ = game.player_clicks(r0, c0, set())

    bot = LogicBot(game, seed=int(seed) + 1337)

    clicks = 0
    for _ in range(int(click_steps) * 3):
        result, action = bot.play_step()

        # Count only click actions (flags don’t change the visible board).
        if isinstance(action, dict) and action.get('type') != 'flag' and 'pos' in action:
            clicks += 1

        if result in {'Win', 'Done'}:
            break
        if clicks >= int(click_steps):
            break

    return game.get_visible_board()


def _load_task3_v1(diff_name: str, *, device: torch.device):
    ckpt_path = Path(repo_root) / 'models' / 'task3' / 'checkpoints' / f'task3_{diff_name}_v1_baseline_15ep_loss_heatmap.pt'
    if not ckpt_path.exists():
        raise FileNotFoundError(f'Missing v1 Task 3 checkpoint: {ckpt_path}')

    ckpt = torch.load(ckpt_path, map_location=device)
    mcfg = ckpt.get('model_cfg') or {}
    cfg = ThinkingMinePredictorConfig(**mcfg)
    model = ThinkingMinePredictor(cfg).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model, cfg, ckpt_path


TASK3_PRESETS = {
    'easy': {'height': 22, 'width': 22, 'num_mines': 50},
    'medium': {'height': 22, 'width': 22, 'num_mines': 80},
    'hard': {'height': 22, 'width': 22, 'num_mines': 100},
}

# I check whether v2 checkpoints exist locally; when they don't, I stick to v1 here.
for dn in ['easy', 'medium', 'hard']:
    v2_any = list((Path(repo_root) / 'models' / 'task3' / 'checkpoints').glob(f'task3_{dn}_v2*.pt'))
    if not v2_any:
        print(f'[task3/{dn}] v2 checkpoint not found locally → generating heatmaps from v1 only')


def save_task3_v1_heatmap_grid(diff_name: str, *, steps: int = 8, seed: int = 0):
    preset = TASK3_PRESETS[diff_name]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model, cfg, ckpt_path = _load_task3_v1(diff_name, device=device)

    # One fixed partially revealed board.
    visible = _make_partially_revealed_board(
        height=int(preset['height']),
        width=int(preset['width']),
        num_mines=int(preset['num_mines']),
        seed=int(seed),
        click_steps=25,
    )

    x = visible_to_int8(visible)  # (H,W) in [-1..9]
    xt = torch.from_numpy(x).to(device=device).to(torch.int64).unsqueeze(0)

    # Inference only, so I skip gradient tracking.
    with torch.no_grad():
        _, per_step = model(xt, steps=int(steps), return_all=True)

    # The 2x4 grid shows at most 8 steps; I guard against fewer returned steps.
    # Constrained layout leaves room for the shared colorbar and the suptitle.
    n_show = min(8, len(per_step))
    fig, axes = plt.subplots(2, 4, figsize=(14, 7), constrained_layout=True)
    axes = axes.reshape(-1)
    for ax in axes:
        ax.axis('off')
    for i in range(n_show):
        probs = torch.sigmoid(per_step[i]).squeeze(0).detach().cpu().numpy()
        ax = axes[i]
        im = ax.imshow(probs, vmin=0.0, vmax=1.0, cmap='viridis')
        ax.set_title(f'step {i+1}')

    fig.colorbar(im, ax=axes.tolist(), fraction=0.02, pad=0.02)
    fig.suptitle(f'Task 3 ({diff_name}, v1): heatmaps as I let the model think longer\nckpt={ckpt_path.name}')

    out = fig_dir / f'task3_{diff_name}_v1_heatmaps_steps_1_to_8.png'
    fig.savefig(out, dpi=200)
    plt.close(fig)
    print('wrote', out)


for dn in ['easy', 'medium', 'hard']:
    save_task3_v1_heatmap_grid(dn, steps=8, seed=0)



[task3/easy] v2 checkpoint not found locally → generating heatmaps from v1 only
[task3/medium] v2 checkpoint not found locally → generating heatmaps from v1 only
[task3/hard] v2 checkpoint not found locally → generating heatmaps from v1 only


wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_easy_v1_heatmaps_steps_1_to_8.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_medium_v1_heatmaps_steps_1_to_8.png
wrote /Users/rachitasaini/Desktop/Rutgers/Fall 2026/Deep Learning 01-198-462/minesweeper-dl-submission/docs/figures/task3_hard_v1_heatmaps_steps_1_to_8.png
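
The heatmap grids show visually how the predictions evolve as the step count grows. One rough way to put a number on that is the mean absolute change in per-cell probability between consecutive steps. The sketch below is an assumption-laden illustration, not part of the repo: `step_to_step_change` is a hypothetical helper, it uses plain Python lists instead of tensors so it runs on its own, and the 2x2 logits are toy values rather than model output.

```python
import math


def sigmoid(z: float) -> float:
    # Logistic function: maps a logit to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-z))


def step_to_step_change(per_step_logits):
    """Mean absolute change in probability between consecutive thinking steps.

    per_step_logits: list of 2D grids (lists of rows) of logits, one grid per step.
    Returns a list with len(per_step_logits) - 1 entries.
    """
    probs = [[[sigmoid(v) for v in row] for row in grid] for grid in per_step_logits]
    changes = []
    for prev, cur in zip(probs, probs[1:]):
        diffs = [abs(a - b) for rp, rc in zip(prev, cur) for a, b in zip(rp, rc)]
        changes.append(sum(diffs) / len(diffs))
    return changes


# Hypothetical 2x2 logits for three steps (toy values, not model output).
toy_steps = [
    [[0.0, 0.0], [0.0, 0.0]],   # step 1: every cell at probability 0.5
    [[2.0, -2.0], [0.0, 0.0]],  # step 2: two cells become more decided
    [[2.0, -2.0], [0.0, 0.0]],  # step 3: unchanged from step 2
]
changes = step_to_step_change(toy_steps)
print([round(c, 4) for c in changes])
```

Here the change is about 0.19 from step 1 to step 2 and exactly 0 from step 2 to step 3, which is the pattern a converged heatmap would produce. With the notebook's `per_step`, the input could be built as `[t.squeeze(0).tolist() for t in per_step]`, assuming each entry is a `(1, H, W)` logit tensor as the `squeeze(0)` in the plotting cell suggests.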
