# Notebook 03 — Task 2 (Playing by Move Prediction)

This notebook is my implementation of **Task 2**.

The PDF describes Task 2 for **medium difficulty**, but I run the same actor/critic setup for:

- Easy (22×22, 50 mines)
- Medium (22×22, 80 mines)
- Hard (22×22, 100 mines)

## What my critic predicts
For a given visible board state and a candidate click (row, col), my critic predicts a **survival score** in **[0, 1]**.

I define survival for a full playthrough as:

- Let **T** = total number of clicks taken until the game reaches 100% progress (all safe cells opened), i.e. `WON` if 0 mines were triggered or `DONE` if >=1 mine was triggered.
- Let **F** = the step index (1-indexed) of the **first** mine trigger.
  - If no mine ever triggers, I set **F = T**.
- Then `survival` = F / T.

So:
- survival = 1.0 means either no mine was ever triggered (the perfect case) or the first mine was triggered on the very last click.
- survival = 0.6 means the first mine happened at step 60 of a 100-click completion.

When I report `avg_survival`, it is the **mean of this per-game survival score** over the evaluation episodes.
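
A minimal sketch of the definition above. The helper name `survival_ratio` and its arguments are illustrative only (they are not functions from the repo), and returning 0.0 for an empty playthrough is an assumption that mirrors how the collection code below treats zero-step rollouts.

```python
from typing import Optional


def survival_ratio(total_clicks: int, first_mine_step: Optional[int]) -> float:
    """Return F / T, where F falls back to T when no mine was ever triggered."""
    if total_clicks <= 0:
        # Assumption: an empty playthrough scores 0.0.
        return 0.0
    f = total_clicks if first_mine_step is None else first_mine_step
    return f / total_clicks


print(survival_ratio(100, 60))    # first mine at click 60 of 100 -> 0.6
print(survival_ratio(100, None))  # no mine triggered -> 1.0
```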

## What I optimize in the actor
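- If LogicBot has any provably-safe moves and safe-move ranking is enabled (medium only in my runs), I use the network to **choose the best safe move** (this is where the model can pick a higher-survival safe click than the LogicBot’s default choice). On easy and hard the actor is guess-only and never overrides LogicBot's safe moves.
- When LogicBot has no safe move and is forced to guess, I use the network to choose among the guess candidates.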

For any candidate (row, col), I score it with:
- score = predicted_survival - mine_penalty * P(mine), where `mine_penalty` is 4.0 by default and 6.0 on hard
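
A minimal sketch of this scoring rule, assuming the critic's outputs are already available as plain dictionaries keyed by (row, col). The names `pick_click`, `pred_survival`, and `p_mine` are hypothetical; the real selection lives in `models/task2/policy.py` and also handles exploration (epsilon / top-k), which this sketch leaves out.

```python
from typing import Dict, List, Optional, Tuple

Coord = Tuple[int, int]


def pick_click(
    safe: List[Coord],
    guesses: List[Coord],
    pred_survival: Dict[Coord, float],
    p_mine: Dict[Coord, float],
    mine_penalty: float = 4.0,
) -> Optional[Coord]:
    """Greedy choice: rank safe moves if any exist, otherwise rank guess candidates."""
    pool = safe if safe else guesses
    if not pool:
        return None
    # score = predicted_survival - mine_penalty * P(mine)
    return max(pool, key=lambda rc: pred_survival[rc] - mine_penalty * p_mine[rc])


# Illustrative numbers only (not model output).
pred_survival = {(0, 0): 0.90, (0, 1): 0.80, (5, 5): 0.95}
p_mine = {(0, 0): 0.20, (0, 1): 0.05, (5, 5): 0.30}

# Forced guess: (0, 1) wins because its low mine probability outweighs its lower survival.
print(pick_click([], [(0, 0), (0, 1), (5, 5)], pred_survival, p_mine))
```

With a penalty of 4.0, a 5-point difference in P(mine) costs as much as a 0.2 difference in predicted survival, which is why the penalty dominates on risky guesses.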

## Metrics I report
- `perfect_win_rate`: fraction of games that finish with 100% progress **and** `mines_triggered == 0`.
- `avg_survival`: mean over games of the per-game `survival` = F / T described above.
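
A sketch of how these two metrics aggregate over games, assuming each game is summarized as a small dictionary. The keys `won`, `mines_triggered`, `first_mine_at`, and `steps` are illustrative stand-ins for the per-game statistics, not the exact schema the game object returns.

```python
def summarize(games):
    """Aggregate per-game records into perfect_win_rate and avg_survival."""
    perfect = [g["won"] and g["mines_triggered"] == 0 for g in games]
    survival = [
        # F falls back to T (the step count) when no mine was triggered.
        (g["first_mine_at"] if g["first_mine_at"] is not None else g["steps"]) / max(1, g["steps"])
        for g in games
    ]
    n = max(1, len(games))
    return {"perfect_win_rate": sum(perfect) / n, "avg_survival": sum(survival) / n}


# Two made-up games: one perfect win, one with the first mine at click 30 of 120.
games = [
    {"won": True, "mines_triggered": 0, "first_mine_at": None, "steps": 120},
    {"won": False, "mines_triggered": 2, "first_mine_at": 30, "steps": 120},
]
print(summarize(games))
```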

Assumption: I already unzipped my project so the repo lives at `/content/repo/`.


In [None]:
# Colab installs
# NOTE: I avoid re-installing torch in Colab because it can create checkpoint loading issues
# if the runtime's torch version changes mid-session.
%pip install -q numpy tqdm matplotlib


In [None]:
# Repo root
import sys
from pathlib import Path

repo_root = Path('/content/repo')

# Sometimes zip extraction creates one extra top-level folder; if so, I step into it.
if not ((repo_root / 'minesweeper').exists() and (repo_root / 'models').exists()):
    kids = [p for p in repo_root.iterdir() if p.is_dir()]
    if len(kids) == 1:
        repo_root = kids[0]

if not ((repo_root / 'minesweeper').exists() and (repo_root / 'models').exists()):
    raise FileNotFoundError(f'Bad repo_root: {repo_root}')

sys.path.insert(0, str(repo_root))
print('Repo root:', repo_root)

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))


In [None]:
# Step 1 — Task 2 data + policy utilities
#
# The Task 2 instructions describe two phases:
# 1) Learn to predict a survival score in [0,1] for a candidate click.
#    I define survival as (step of first mine trigger) / (total steps to reach 100% progress).
# 2) Use that network as an Actor on forced-guess states (when LogicBot has no safe move).
#    Then train a Critic on the Actor's own guesses and bootstrap.

import json
import random
import time
from pathlib import Path

import numpy as np
import torch

from minesweeper.game import GameState, MinesweeperGame
from minesweeper.logic_bot import LogicBot
from models.task1.encoding import visible_to_int8
from models.task2.dataset import _clone_game_fast
from models.task2.policy import (
    actor_choose_click_value_map as _actor_choose_click_value_map,
    allowed_coords_from_logic as _allowed_coords_from_logic,
    logic_infer_sets as _logic_infer_sets,
)
from models.task2.value_map_model import BoardValuePredictor, BoardValuePredictorConfig

# I keep evaluation deterministic and use exploration only during collection.
EPS_COLLECT = 0.05
TOPK_COLLECT = 5
EPS_EVAL = 0.0
TOPK_EVAL = 1

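# I default to scoring ALL candidates ("every possible move"), but keep this knob for speed.
# NOTE: the `actor_choose_click` wrapper below does not pass this value on, so it has no effect there.
MAX_CANDIDATES = None  # set to e.g. 128 if you need more speed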

# Optional stability: use LogicBot inference as a mask.
USE_LOGIC_MASK = True

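# Difficulty-aware choice: on easy and hard, I keep the actor "guess-only" (never override LogicBot safe moves).
# On medium, I let the model rank safe moves too.
# This is only the default; Step 3 overrides it per difficulty via USE_MODEL_ON_SAFE_MOVES_BY_DIFF.
USE_MODEL_ON_SAFE_MOVES = True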

# I combine value + mine-prob for action selection: score = value - mine_penalty * P(mine).
# This is the default; Step 3 raises it to 6.0 on hard to reduce catastrophic guess failures.
MINE_PENALTY = 4.0

DIFFICULTIES = {
    'easy': {'height': 22, 'width': 22, 'num_mines': 50},
    'medium': {'height': 22, 'width': 22, 'num_mines': 80},
    'hard': {'height': 22, 'width': 22, 'num_mines': 100},
}

DATA_DIR = Path(repo_root) / 'models' / 'task2' / 'datasets'
DATA_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR = Path(repo_root) / 'models' / 'task2' / 'checkpoints'
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Note: I keep LogicBot inference + actor move-selection in `models/task2/policy.py`.
# (Imported above as `_logic_infer_sets`, `_allowed_coords_from_logic`, and `_actor_choose_click_value_map`.)

@torch.no_grad()
def actor_choose_click(*, model: BoardValuePredictor, game: MinesweeperGame, bot: LogicBot, seed: int, epsilon: float, top_k: int):
    return _actor_choose_click_value_map(
        model=model,
        game=game,
        bot=bot,
        device=device,
        seed=int(seed),
        mine_penalty=float(MINE_PENALTY),
        epsilon=float(epsilon),
        top_k=int(top_k),
        use_logic_mask=bool(USE_LOGIC_MASK),
        use_model_on_safe_moves=bool(USE_MODEL_ON_SAFE_MOVES),
    )


def collect_dataset_for_policy(
    *,
    diff: dict,
    rollout_policy: str,
    model: BoardValuePredictor | None,
    num_games: int,
    seed: int,
    max_steps: int = 512,
    states_per_game: int = 24,
    record_prob: float = 0.25,
    actions_per_state: int = 2,
    target_samples: int | None = None,
    max_games: int | None = None,
) -> dict:
    """Collect sparse (s,a)->(survival_ratio, mine_clicked) supervision.

    Forced-guess samples can be rare, so for actor rounds I sometimes want to collect until I
    hit a target sample count.

    - If target_samples is None: I collect for exactly num_games episodes.
    - If target_samples is set: I keep playing until I reach that many samples, up to max_games.
      (If max_games is None, I use a conservative safety cap.)
    """
    rng = random.Random(int(seed))

    h = int(diff['height'])
    w = int(diff['width'])
    m = int(diff['num_mines'])

    xs = []
    ars = []
    y_survival = []
    y_mine = []
    ep_ids = []

    t0 = time.perf_counter()
    last_print = t0
    tgt = None
    try:
        tgt = int(target_samples) if target_samples is not None else None
    except Exception:
        tgt = None

    # If I'm targeting a sample count, I may need more than num_games episodes.
    # I keep a cap so collection can't run forever.
    if tgt is None:
        games_cap = int(num_games)
    else:
        if max_games is None:
            games_cap = max(int(num_games), int(num_games) * 10)
        else:
            games_cap = max(int(num_games), int(max_games))

    print(
        f"[dataset] start | mines={m} games<={int(games_cap)} target_samples={tgt} max_steps={int(max_steps)} "
        f"states/game~{int(states_per_game)} record_prob={float(record_prob):.3f} actions/state={int(actions_per_state)} "
        f"policy={str(rollout_policy)} eps_collect={float(EPS_COLLECT):.3f} topk_collect={int(TOPK_COLLECT)} mask={bool(USE_LOGIC_MASK)}"
    )

    def _rollout_step(g: MinesweeperGame, bot: LogicBot, *, policy: str, model: BoardValuePredictor | None, seed_local: int, step_i: int):
        kind = str(policy).strip().lower()
        if kind == 'logic':
            _logic_infer_sets(bot)
            a = bot.select_action()
            return None if a is None else (int(a[0]), int(a[1]))
        if model is None:
            raise ValueError('rollout_policy!=logic needs a model')
        return actor_choose_click(model=model, game=g, bot=bot, seed=int(seed_local) + int(step_i), epsilon=float(EPS_COLLECT), top_k=int(TOPK_COLLECT))

    def _rollout_to_end(
        g2: MinesweeperGame,
        *,
        policy: str,
        model2: BoardValuePredictor | None,
        seed_local: int,
        max_extra_steps: int,
        steps0: int = 0,
        first_mine_at0: int | None = None,
    ):
        """Roll forward until DONE/WON or step cap, even if mines are triggered.

        This returns metrics over the *entire* trajectory length, including any already-taken
        prefix of length steps0.
        """
        bot2 = LogicBot(g2, seed=int(seed_local))
        steps2 = 0
        first_mine_at = int(first_mine_at0) if first_mine_at0 is not None else None

        while steps2 < int(max_extra_steps):
            gs = g2.get_game_state()
            if gs == GameState.PROG:
                pass
            else:
                # Continue-after-mine mode keeps state LOST while still allowing more clicks.
                if not (bool(getattr(g2, 'allow_mine_triggers', False)) and gs == GameState.LOST):
                    break

            a2 = _rollout_step(g2, bot2, policy=policy, model=model2, seed_local=seed_local, step_i=steps2)
            if a2 is None:
                break

            prev_mines = int(getattr(g2, 'mines_triggered', 0) or 0)
            steps2 += 1
            _ = g2.player_clicks(int(a2[0]), int(a2[1]), set())
            cur_mines = int(getattr(g2, 'mines_triggered', 0) or 0)
            if first_mine_at is None and prev_mines == 0 and cur_mines > 0:
                first_mine_at = int(steps0) + int(steps2)

        total_steps = int(steps0) + int(steps2)
        if total_steps <= 0:
            return (0, 0.0)

        fm = int(first_mine_at) if first_mine_at is not None else int(total_steps)
        survival = float(fm) / float(total_steps)
        return (total_steps, float(survival))

    states_cap = max(1, int(states_per_game))
    rec_p = float(record_prob)

    for ep in range(int(games_cap)):
        if tgt is not None and int(len(xs)) >= int(tgt):
            break
        now = time.perf_counter()
        if (ep == 0) or ((now - last_print) > 6.0):
            sps = float(len(xs) / max(1e-6, (now - t0)))
            print(f"[dataset] ep {ep}/{int(games_cap)} | samples={len(xs)} ({sps:.1f} samples/s)")
            last_print = now

        game_seed = rng.randint(0, 2**31 - 1)
        g = MinesweeperGame(height=h, width=w, num_mines=m, seed=int(game_seed))
        # I continue after mines trigger so I can measure "when the first mine happens" vs full completion.
        setattr(g, 'allow_mine_triggers', True)

        first = (rng.randrange(h), rng.randrange(w))
        g.player_clicks(int(first[0]), int(first[1]), set())

        bot = LogicBot(g, seed=int(game_seed))

        steps = 0
        recorded = 0
        while steps < int(max_steps):
            gs = g.get_game_state()
            if gs == GameState.PROG:
                pass
            else:
                if not (bool(getattr(g, 'allow_mine_triggers', False)) and gs == GameState.LOST):
                    break

            visible = g.get_visible_board()

            safe_coords = None
            guess_coords = None
            if bool(USE_LOGIC_MASK):
                safe_coords, guess_coords = _allowed_coords_from_logic(bot, g)
            forced_guess = (safe_coords is None) or (int(getattr(safe_coords, 'shape', [0])[0]) == 0)

            a0 = _rollout_step(g, bot, policy=rollout_policy, model=model, seed_local=game_seed, step_i=steps)
            if a0 is None:
                break

            is_logic_phase = (str(rollout_policy).strip().lower() == 'logic')
            mines0 = (int(getattr(g, 'mines_triggered', 0) or 0) == 0)

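            # In actor rounds, I record only forced-guess states (they're rare + highest-signal).
            # In the logic phase, I subsample all states with record_prob.
            # Either way, I stop recording after the first mine trigger or once states_cap is reached.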
            if bool(is_logic_phase):
                do_record = (recorded < states_cap) and mines0 and (rng.random() < rec_p)
            else:
                do_record = (recorded < states_cap) and mines0 and bool(forced_guess)
            if do_record:
                x_int8 = visible_to_int8(visible)

                # Candidate set: chosen action + a few alternatives.
                # - In the logic phase, I consider all unrevealed cells so the model can learn to
                #   prefer better safe moves than the LogicBot's default safe move.
                # - In actor rounds, I focus supervision on guess candidates.
                if bool(is_logic_phase):
                    unrevealed = [(rr, cc) for rr in range(h) for cc in range(w) if visible[rr][cc] == 'E']
                else:
                    if guess_coords is not None and int(getattr(guess_coords, 'shape', [0])[0]) > 0:
                        unrevealed = [(int(rr), int(cc)) for (rr, cc) in guess_coords.tolist()]
                    else:
                        unrevealed = [(rr, cc) for rr in range(h) for cc in range(w) if visible[rr][cc] == 'E']

                candidates = [a0]
                pool = [rc for rc in unrevealed if rc != a0]
                if len(pool) > 0 and int(actions_per_state) > 1:
                    kk = min(len(pool), int(actions_per_state) - 1)
                    candidates.extend(rng.sample(pool, k=kk))

                for (rr, cc) in candidates:
                    if tgt is not None and int(len(xs)) >= int(tgt):
                        break
                    g2 = _clone_game_fast(g)
                    setattr(g2, 'allow_mine_triggers', True)

                    prev_m = int(getattr(g2, 'mines_triggered', 0) or 0)
                    res0 = g2.player_clicks(int(rr), int(cc), set())
                    cur_m = int(getattr(g2, 'mines_triggered', 0) or 0)

                    # Mine label is for the CLICK itself.
                    ym = 1.0 if (cur_m > prev_m) else 0.0

                    # Survival ratio is defined by when the FIRST mine happens (relative to completion).
                    # This candidate click is step 1 of the trajectory we label.
                    first_mine_at0 = 1 if (prev_m == 0 and cur_m > 0) else None

                    steps_total, surv = _rollout_to_end(
                        g2,
                        policy=rollout_policy,
                        model2=model,
                        seed_local=int(game_seed) + 999,
                        max_extra_steps=int(max_steps) - int(steps) - 1,
                        steps0=1,
                        first_mine_at0=first_mine_at0,
                    )

                    ys = float(surv)

                    xs.append(x_int8)
                    ars.append((int(rr), int(cc)))
                    y_survival.append(float(ys))
                    y_mine.append(float(ym))
                    ep_ids.append(int(ep))

                recorded += 1

            steps += 1
            _ = g.player_clicks(int(a0[0]), int(a0[1]), set())

    return {
        'x_visible': np.stack(xs).astype(np.int8) if xs else np.zeros((0, h, w), dtype=np.int8),
        'action_rc': np.asarray(ars, dtype=np.int16) if ars else np.zeros((0, 2), dtype=np.int16),
        'y_survival': np.asarray(y_survival, dtype=np.float32) if y_survival else np.zeros((0,), dtype=np.float32),
        'y_mine': np.asarray(y_mine, dtype=np.float32) if y_mine else np.zeros((0,), dtype=np.float32),
        'episode_id': np.asarray(ep_ids, dtype=np.int32) if ep_ids else np.zeros((0,), dtype=np.int32),
    }


def save_dataset_npz(path: Path, data: dict, meta: dict):
    path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(path, **data, meta_json=json.dumps(meta))


def load_dataset_npz(path: Path) -> dict:
    z = np.load(path, allow_pickle=False)
    out = {k: z[k] for k in z.files if k != 'meta_json'}
    out['meta'] = json.loads(str(z['meta_json'].tolist()))
    return out


In [None]:
# Step 2 — Train a Task 2 value-map model (critic)
#
# This model outputs:
# - value_map      : per-cell logits; sigmoid(value_map) is the predicted survival score in [0,1] for clicking that cell
# - mine_logit_map : per-cell logits; sigmoid(mine_logit_map) is the predicted P(mine | click)
#
# I train on sparse (state, action) supervision by gathering the predicted values/logits at the
# labeled action coordinates.

from dataclasses import asdict

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from models.metrics import regression_metrics
from models.task2.value_map_model import BoardValuePredictor, BoardValuePredictorConfig


class _Task2SparseDataset(Dataset):
    def __init__(self, d: dict, indices: np.ndarray):
        self.x = d['x_visible'][indices]
        self.a = d['action_rc'][indices]
        self.y_surv = d['y_survival'][indices]
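        # Fall back to all-zero mine labels built from the full array, so indexing by `indices` stays valid.
        self.y_mine = d.get('y_mine', np.zeros_like(d['y_survival']))[indices]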

    def __len__(self):
        return int(self.x.shape[0])

    def __getitem__(self, idx: int):
        return {
            'x': torch.from_numpy(self.x[idx]).to(torch.int64),
            'a': torch.from_numpy(self.a[idx]).to(torch.int64),
            'y_survival': torch.tensor(float(self.y_surv[idx]), dtype=torch.float32),
            'y_mine': torch.tensor(float(self.y_mine[idx]), dtype=torch.float32),
        }


def _split_by_episode(episode_id: np.ndarray, *, val_frac: float, seed: int) -> tuple[np.ndarray, np.ndarray]:
    episode_id = np.asarray(episode_id).astype(np.int64)
    uniq = np.unique(episode_id)
    rng = np.random.default_rng(int(seed))
    rng.shuffle(uniq)

    n_val_eps = max(1, int(round(len(uniq) * float(val_frac))))
    val_eps = uniq[:n_val_eps]

    idx = np.arange(len(episode_id), dtype=np.int64)
    # Vectorized membership test: True where the sample's episode is in the validation set.
    is_val = np.isin(episode_id, val_eps)
    val_idx = idx[is_val]
    train_idx = idx[~is_val]
    return train_idx, val_idx
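

# Hedged sanity-check sketch: `_assert_no_episode_leak` is a hypothetical helper, not part of
# the repo and not called anywhere in this notebook. An episode-level split is only useful if
# no episode contributes samples to both sides, so this checks exactly that property.
def _assert_no_episode_leak(episode_id, train_idx, val_idx) -> None:
    train_eps = {int(episode_id[i]) for i in train_idx}
    val_eps = {int(episode_id[i]) for i in val_idx}
    overlap = train_eps & val_eps
    assert not overlap, f"episodes in both splits: {sorted(overlap)}"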


def train_value_map_model(
    *,
    data: dict,
    cfg: BoardValuePredictorConfig,
    epochs: int = 20,
    batch_size: int = 64,
    lr: float = 3e-4,
    weight_decay: float = 1e-2,
    val_frac: float = 0.2,
    seed: int = 0,
    patience: int = 3,
    mine_loss_weight: float = 1.0,
):
    use_cuda = (device.type == 'cuda')

    n = int(data['x_visible'].shape[0])
    n_eps = int(len(np.unique(np.asarray(data['episode_id']).astype(np.int64))))
    print(
        f"[train] samples={n} episodes={n_eps} epochs={int(epochs)} batch_size={int(batch_size)} lr={float(lr)} "
        f"mine_w={float(mine_loss_weight)} patience={int(patience)}"
    )

    if use_cuda:
        try:
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
        except Exception:
            pass

    train_idx, val_idx = _split_by_episode(data['episode_id'], val_frac=val_frac, seed=seed)

    bs = int(batch_size)
    if use_cuda:
        try:
            if 'A100' in torch.cuda.get_device_name(0):
                bs = max(bs, 128)
        except Exception:
            pass

    num_workers = 4 if use_cuda else 0
    dl_common = dict(num_workers=int(num_workers), pin_memory=bool(use_cuda))
    if int(num_workers) > 0:
        dl_common['persistent_workers'] = True

    train_loader = DataLoader(_Task2SparseDataset(data, train_idx), batch_size=int(bs), shuffle=True, **dl_common)
    val_loader = DataLoader(_Task2SparseDataset(data, val_idx), batch_size=int(bs), shuffle=False, **dl_common)

    print(
        f"[train] split | train_samples={len(train_idx)} val_samples={len(val_idx)} "
        f"bs={int(bs)} num_workers={int(num_workers)} amp={bool(use_cuda)}"
    )

    model = BoardValuePredictor(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=float(lr), weight_decay=float(weight_decay))

    use_amp = bool(use_cuda)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    best_rmse = float('inf')
    best_state = None
    best_epoch = 0
    bad_epochs = 0

    def _gather(map_hw: torch.Tensor, a_rc: torch.Tensor) -> torch.Tensor:
        # map_hw: (B,H,W), a_rc: (B,2)
        r = a_rc[:, 0].clamp(0, map_hw.shape[1] - 1)
        c = a_rc[:, 1].clamp(0, map_hw.shape[2] - 1)
        return map_hw[torch.arange(map_hw.shape[0], device=map_hw.device), r, c]

    for epoch in range(1, int(epochs) + 1):
        model.train()
        tr = 0.0
        tr_n = 0

        for batch_i, batch in enumerate(train_loader):
            x = batch['x'].to(device, non_blocking=use_cuda)
            a = batch['a'].to(device, non_blocking=use_cuda)
            y_surv = batch['y_survival'].to(device, non_blocking=use_cuda)
            y_mine = batch['y_mine'].to(device, non_blocking=use_cuda)

            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type='cuda', enabled=use_amp):
                value_map, mine_logit_map = model(x)
                pred_surv = torch.sigmoid(_gather(value_map, a))
                pred_mine_logit = _gather(mine_logit_map, a)

                # Survival regression (target in [0,1])
                loss_surv = F.smooth_l1_loss(pred_surv, y_surv)

                # Mine prediction (auxiliary)
                # Mine clicks are rare but high-impact, so I upweight them.
                w_m = 1.0 + 8.0 * y_mine
                loss_mine = (F.binary_cross_entropy_with_logits(pred_mine_logit, y_mine, reduction='none') * w_m).mean()

                loss = loss_surv + float(mine_loss_weight) * loss_mine

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            tr += float(loss.item())
            tr_n += 1

            if (batch_i % 250) == 0:
                print(f"[train] epoch {epoch}/{epochs} batch {batch_i} | loss {float(loss.item()):.4f}")

        # Validation
        model.eval()
        preds = []
        ys = []
        mine_pred = []
        mine_true = []
        with torch.no_grad():
            for batch in val_loader:
                x = batch['x'].to(device, non_blocking=use_cuda)
                a = batch['a'].to(device, non_blocking=use_cuda)
                y_surv = batch['y_survival'].to(device, non_blocking=use_cuda)
                y_mine = batch['y_mine'].to(device, non_blocking=use_cuda)

                with torch.amp.autocast(device_type='cuda', enabled=use_amp):
                    value_map, mine_logit_map = model(x)
                    pred_surv = torch.sigmoid(_gather(value_map, a))
                    pred_mine_logit = _gather(mine_logit_map, a)

                preds.append(pred_surv.detach().float().cpu())
                ys.append(y_surv.detach().float().cpu())
                mine_pred.append(torch.sigmoid(pred_mine_logit.detach().float().cpu()))
                mine_true.append(y_mine.detach().float().cpu())

        val_metrics = regression_metrics(torch.cat(preds, dim=0), torch.cat(ys, dim=0))
        mp = torch.cat(mine_pred, dim=0)
        mt = torch.cat(mine_true, dim=0)
        mine_acc = float(((mp >= 0.5).float() == (mt >= 0.5).float()).float().mean().item()) if int(mt.numel()) else 0.0

        cur_rmse = float(val_metrics.get('rmse', 0.0) or 0.0)
        print(
            f"epoch {epoch}/{epochs} | "
            f"train loss {tr/max(1,tr_n):.4f} | "
            f"val rmse {cur_rmse:.4f} mae {val_metrics['mae']:.4f} corr {val_metrics['corr']:.3f} | "
            f"mine_acc {mine_acc:.3f}"
        )

        if cur_rmse < best_rmse:
            best_rmse = cur_rmse
            best_epoch = int(epoch)
            bad_epochs = 0
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            bad_epochs += 1
            if int(patience) > 0 and bad_epochs >= int(patience):
                print(f"[train] early stop at epoch {epoch} (best epoch={best_epoch} rmse={best_rmse:.4f})")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"[train] restored best epoch={best_epoch} rmse={best_rmse:.4f}")

    return model


In [None]:
# Step 3 — Task 2: Logic model -> Actor -> Critic bootstrapping (easy / medium / hard)
#
# This follows the instructions provided in the PDF:
# - First learn to predict how long the LogicBot survives after choosing a click.
# - Then use that model as an Actor: score every move (value map) and take a high-value move.
# - Then train a Critic on the Actor's own decisions and bootstrap.

from dataclasses import asdict

import numpy as np
import torch


@torch.no_grad()
def eval_policy(*, diff: dict, policy: str, model: BoardValuePredictor | None, n_games: int = 80, seed0: int = 0, max_steps: int = 512) -> dict:
    rng = random.Random(int(seed0))
    stats = []

    h = int(diff['height'])
    w = int(diff['width'])
    m = int(diff['num_mines'])

    for _ in range(int(n_games)):
        seed = rng.randint(0, 2**31 - 1)
        g = MinesweeperGame(height=h, width=w, num_mines=m, seed=int(seed))
        setattr(g, 'allow_mine_triggers', True)

        first = (rng.randrange(h), rng.randrange(w))
        g.player_clicks(int(first[0]), int(first[1]), set())

        bot = LogicBot(g, seed=int(seed))

        steps = 0
        first_mine_at = None
        while steps < int(max_steps):
            gs = g.get_game_state()
            if gs == GameState.PROG:
                pass
            else:
                if not (bool(getattr(g, 'allow_mine_triggers', False)) and gs == GameState.LOST):
                    break

            kind = str(policy).strip().lower()
            if kind == 'logic':
                _logic_infer_sets(bot)
                a = bot.select_action()
                a = None if a is None else (int(a[0]), int(a[1]))
            else:
                if model is None:
                    raise ValueError('policy!=logic needs a model')
                a = actor_choose_click(
                    model=model,
                    game=g,
                    bot=bot,
                    seed=int(seed) + steps,
                    epsilon=float(EPS_EVAL),
                    top_k=int(TOPK_EVAL),
                )

            if a is None:
                break

            prev_m = int(getattr(g, 'mines_triggered', 0) or 0)
            steps += 1
            _ = g.player_clicks(int(a[0]), int(a[1]), set())
            cur_m = int(getattr(g, 'mines_triggered', 0) or 0)
            if first_mine_at is None and prev_m == 0 and cur_m > 0:
                first_mine_at = int(steps)

        s = g.get_statistics()
        s['steps'] = int(steps)
        s['first_mine_at'] = int(first_mine_at) if first_mine_at is not None else int(steps)
        # Survival ratio = (step of first mine) / (total steps to clear all safe cells)
        s['survival'] = float(s['first_mine_at']) / float(max(1, int(steps)))
        stats.append(s)

    return {
        'n': len(stats),
        'perfect_win_rate': float(np.mean([bool(s.get('game_won')) and float(s.get('mines_triggered', 0) or 0) == 0.0 for s in stats])),
        'avg_survival': float(np.mean([float(s.get('survival', 0) or 0) for s in stats])),
        'avg_cells_opened': float(np.mean([float(s.get('cells_opened', 0) or 0) for s in stats])),
        'avg_mines_triggered': float(np.mean([float(s.get('mines_triggered', 0) or 0) for s in stats])),
    }


# I keep old datasets/checkpoints if they already exist (reruns are fast + reproducible).
OVERWRITE = False

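# Easy has very few forced-guess states, so extra bootstrap rounds can end up training on tiny datasets.
# NUM_ROUNDS counts round 0 (the logic model): 3 means round 0 plus actor rounds 1 and 2,
# and 1 means no actor rounds at all.
NUM_ROUNDS = 3  # default
NUM_ROUNDS_BY_DIFF = {'easy': 1, 'medium': 3, 'hard': 3}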

# I scale data and training thresholds by difficulty so bootstrapping has enough signal.
GAMES_PER_ROUND_DEFAULT = 40
GAMES_PER_ROUND_BY_DIFF = {'easy': 40, 'medium': 80, 'hard': 120}

MIN_SAMPLES_TO_TRAIN_DEFAULT = 200
MIN_SAMPLES_TO_TRAIN_BY_DIFF = {'easy': 200, 'medium': 400, 'hard': 700}

EVAL_GAMES = 80

# Dataset size: samples ~ games * states/game * actions/state.
# I increase states/game for medium/hard to capture more decision points.
STATES_PER_GAME_BY_DIFF = {'easy': 18, 'medium': 28, 'hard': 32}
RECORD_PROB_BY_DIFF = {'easy': 0.25, 'medium': 0.30, 'hard': 0.35}
ACTIONS_PER_STATE_BY_DIFF = {'easy': 2, 'medium': 2, 'hard': 2}

# Actor rounds: I target a minimum number of forced-guess samples so the critic can actually learn
# a better guessing policy than random.
TARGET_ACTOR_SAMPLES_BY_DIFF = {'easy': 0, 'medium': 1500, 'hard': 2500}

# New cache tag so my reruns don’t silently reuse old actor-round datasets/checkpoints.
# (I changed data-collection behavior to actually hit target_samples, and I also make easy "guess-only".)
CACHE_TAG = 'v10_target_samples_easy_guess_only'

all_results = {}

# I run Task 2 for all difficulties, but I only *optimize* bootstrapping behavior on medium/hard.
# Easy and hard are noisy to "optimize" via safe-move ranking, so I keep the actor guess-only there.
# Medium is the main Task 2 target, so I let the model rank safe moves there.
USE_MODEL_ON_SAFE_MOVES_BY_DIFF = {'easy': False, 'medium': True, 'hard': False}

for diff_name, diff in DIFFICULTIES.items():
    # Set the per-difficulty actor behavior.
    USE_MODEL_ON_SAFE_MOVES = bool(USE_MODEL_ON_SAFE_MOVES_BY_DIFF.get(diff_name, True))

    # Slightly stronger mine penalty on hard (no extra training cost; just shifts selection).
    if str(diff_name) == 'hard':
        MINE_PENALTY = 6.0
    else:
        MINE_PENALTY = 4.0

    print('\n------------------------------')
    print('Difficulty:', diff_name, diff)
    print('Cache tag:', CACHE_TAG)
    print('Actor ranks safe moves with model:', USE_MODEL_ON_SAFE_MOVES)

    results = []

    games_per_round = int(GAMES_PER_ROUND_BY_DIFF.get(diff_name, int(GAMES_PER_ROUND_DEFAULT)))
    min_samples_to_train = int(MIN_SAMPLES_TO_TRAIN_BY_DIFF.get(diff_name, int(MIN_SAMPLES_TO_TRAIN_DEFAULT)))

    # ROUND 0: train "logic survivability" model
    ds0 = DATA_DIR / f'task2_{diff_name}_{CACHE_TAG}_round0_logic.npz'
    if ds0.exists() and not bool(OVERWRITE):
        data0 = load_dataset_npz(ds0)
        meta0 = data0['meta']
        print(f"Loaded: {ds0} samples={data0['x_visible'].shape[0]}")
    else:
        data0 = collect_dataset_for_policy(
            diff=diff,
            rollout_policy='logic',
            model=None,
            num_games=int(games_per_round),
            seed=1000,
            states_per_game=int(STATES_PER_GAME_BY_DIFF.get(diff_name, 24)),
            record_prob=float(RECORD_PROB_BY_DIFF.get(diff_name, 0.25)),
            actions_per_state=int(ACTIONS_PER_STATE_BY_DIFF.get(diff_name, 2)),
        )
        meta0 = {'task': 'task2_logic_survival', 'difficulty': diff_name, 'difficulty_cfg': diff, 'round': 0}
        save_dataset_npz(ds0, data0, meta0)
        print(f"Saved: {ds0} samples={data0['x_visible'].shape[0]}")

    ck0 = CKPT_DIR / f'task2_{diff_name}_{CACHE_TAG}_round0_logic.pt'
    if ck0.exists() and not bool(OVERWRITE):
        print(f"[bootstrap] Reusing checkpoint: {ck0}")
        p = torch.load(ck0, map_location=device)
        cfg = BoardValuePredictorConfig(**p['model_cfg'])
        actor_best = BoardValuePredictor(cfg).to(device)
        actor_best.load_state_dict(p['state_dict'])
        actor_best.eval()
    else:
        cfg = BoardValuePredictorConfig(height=int(diff['height']), width=int(diff['width']))
        actor_best = train_value_map_model(data=data0, cfg=cfg, epochs=30, batch_size=64, lr=3e-4, weight_decay=1e-2, val_frac=0.2, seed=0, patience=4, mine_loss_weight=1.0)
        torch.save({'task': 'task2_logic_survival', 'difficulty': diff_name, 'round': 0, 'model_cfg': asdict(cfg), 'state_dict': actor_best.state_dict()}, ck0)
        print(f"Saved logic-model checkpoint: {ck0}")

    base_stats = eval_policy(diff=diff, policy='logic', model=None, n_games=EVAL_GAMES, seed0=123)
    actor0_stats = eval_policy(diff=diff, policy='actor', model=actor_best, n_games=EVAL_GAMES, seed0=5000)
    print('LogicBot baseline:', base_stats)
    print('Actor (from logic model) stats:', actor0_stats)

    def _score(stats: dict) -> tuple[float, float]:
        # Primary: perfect wins (aligns with winning the game with 0 mines).
        # Secondary: avg_survival (how late the first mine happens relative to completion).
        return (
            float(stats.get('perfect_win_rate', 0.0) or 0.0),
            float(stats.get('avg_survival', 0.0) or 0.0),
        )

    best_score = _score(actor0_stats)
    best_round = 0
    best_payload = {
        'task': 'task2_bootstrap',
        'difficulty': diff_name,
        'difficulty_cfg': diff,
        'round': 0,
        'model_cfg': asdict(cfg),
        'state_dict': {k: v.detach().cpu() for k, v in actor_best.state_dict().items()},
    }

    results.append({'difficulty': diff_name, 'round': 0, 'kind': 'logic_model_actor', **actor0_stats})

    # ROUNDS 1+: critic bootstrapping on actor behavior.
    # Round 0 (above) counts toward num_rounds, so this loop runs rounds 1 .. num_rounds - 1.
    num_rounds = int(NUM_ROUNDS_BY_DIFF.get(diff_name, int(NUM_ROUNDS)))
    for r in range(1, int(num_rounds)):
        ds_path = DATA_DIR / f'task2_{diff_name}_{CACHE_TAG}_round{r}_actor.npz'
        if ds_path.exists() and not bool(OVERWRITE):
            data = load_dataset_npz(ds_path)
            meta = data['meta']
            print(f"Loaded: {ds_path} samples={data['x_visible'].shape[0]}")
        else:
            tgt = int(TARGET_ACTOR_SAMPLES_BY_DIFF.get(diff_name, 0) or 0) or None
            # If I’m targeting forced-guess samples, I allow more games so I can actually hit the target.
            # Otherwise I’m guaranteed to be “data-starved” and bootstrapping becomes noisy.
            max_games = None if tgt is None else int(games_per_round) * 10

            data = collect_dataset_for_policy(
                diff=diff,
                rollout_policy='actor',
                model=actor_best,
                num_games=int(games_per_round),
                max_games=max_games,
                seed=2000 + r,
                states_per_game=int(STATES_PER_GAME_BY_DIFF.get(diff_name, 24)),
                record_prob=float(RECORD_PROB_BY_DIFF.get(diff_name, 0.25)),
                actions_per_state=int(ACTIONS_PER_STATE_BY_DIFF.get(diff_name, 3)),
                target_samples=tgt,
            )
            meta = {'task': 'task2_actor_survival', 'difficulty': diff_name, 'difficulty_cfg': diff, 'round': int(r)}
            save_dataset_npz(ds_path, data, meta)
            print(f"Saved: {ds_path} samples={data['x_visible'].shape[0]}")

        n_samples = int(data['x_visible'].shape[0])
        if n_samples < int(min_samples_to_train):
            print(f"[bootstrap] Skipping round {r}: only {n_samples} samples (<{int(min_samples_to_train)})")
            continue

        ckpt_path = CKPT_DIR / f'task2_{diff_name}_{CACHE_TAG}_round{r}_critic.pt'
        if ckpt_path.exists() and not bool(OVERWRITE):
            print(f"[bootstrap] Reusing checkpoint: {ckpt_path}")
            p = torch.load(ckpt_path, map_location=device)
            ccfg = BoardValuePredictorConfig(**p['model_cfg'])
            critic = BoardValuePredictor(ccfg).to(device)
            critic.load_state_dict(p['state_dict'])
            critic.eval()
        else:
            ccfg = BoardValuePredictorConfig(height=int(diff['height']), width=int(diff['width']))
            critic = train_value_map_model(
                data=data, cfg=ccfg,
                epochs=30, batch_size=64, lr=3e-4, weight_decay=1e-2,
                val_frac=0.2, seed=0, patience=4, mine_loss_weight=1.0,
            )
            torch.save(
                {
                    'task': 'task2_actor_critic',
                    'difficulty': diff_name,
                    'round': int(r),
                    'model_cfg': asdict(ccfg),
                    'state_dict': critic.state_dict(),
                },
                ckpt_path,
            )
            print(f"Saved critic checkpoint: {ckpt_path}")

        actor_stats = eval_policy(diff=diff, policy='actor', model=critic, n_games=EVAL_GAMES, seed0=5000 + r)
        print(f"Round {r} actor stats:", actor_stats)
        results.append({'difficulty': diff_name, 'round': int(r), 'kind': 'critic_actor', **actor_stats})

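        # Tuples compare lexicographically: perfect_win_rate decides first,
        # avg_survival only breaks ties. A round replaces the current best only on a strict improvement.
        # Each round is evaluated with its own seed0 (5000 + r), so scores from
        # different rounds come from different evaluation seeds.
        cur_score = _score(actor_stats)
        if cur_score > best_score:
            best_score = cur_score
            best_round = int(r)
            actor_best = critic
            best_payload = {
                'task': 'task2_bootstrap',
                'difficulty': diff_name,
                'difficulty_cfg': diff,
                'round': int(r),
                'model_cfg': asdict(ccfg),
                'state_dict': {k: v.detach().cpu() for k, v in critic.state_dict().items()},
            }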

    final_path = CKPT_DIR / f'task2_{diff_name}.pt'
    if final_path.exists() and not bool(OVERWRITE):
        print(f"Final already exists: {final_path} (not overwriting; set OVERWRITE=True to regenerate)")
    else:
        torch.save(best_payload, final_path)
        print(f"Wrote final: {final_path} (from round {best_round})")

    all_results[diff_name] = results

all_results
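
The round-selection rule in the cell above relies on Python's lexicographic tuple comparison. The sketch below isolates that rule so it can be checked on its own. The stats dicts and their numbers are made up for illustration (they are not results from my runs), and `score` / `select_best` are hypothetical stand-ins for `_score` and the `if cur_score > best_score` update.

```python
# Minimal sketch of the (perfect_win_rate, avg_survival) selection rule.
# All numbers below are illustrative placeholders, not measured results.

def score(stats: dict) -> tuple:
    # Missing or None metrics fall back to 0.0, mirroring the `or 0.0` guard.
    return (
        float(stats.get('perfect_win_rate', 0.0) or 0.0),
        float(stats.get('avg_survival', 0.0) or 0.0),
    )

def select_best(rounds: list) -> dict:
    # Keep the first round as the incumbent; replace it only on a strict improvement.
    best = rounds[0]
    for cand in rounds[1:]:
        if score(cand) > score(best):
            best = cand
    return best

rounds = [
    {'round': 0, 'perfect_win_rate': 0.10, 'avg_survival': 0.80},
    {'round': 1, 'perfect_win_rate': 0.10, 'avg_survival': 0.85},  # tie on primary, wins on secondary
    {'round': 2, 'perfect_win_rate': 0.12, 'avg_survival': 0.70},  # wins on primary despite lower survival
    {'round': 3, 'perfect_win_rate': None, 'avg_survival': 0.99},  # None primary is treated as 0.0
]

best = select_best(rounds)
print(best['round'], score(best))
```

With these placeholder values, round 2 is selected: a higher `perfect_win_rate` always outranks a higher `avg_survival`, which is the trade-off I accept by making perfect wins the primary key.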
