# UNION (single-channel) â€” policy distillation to CNet192

This notebook distills an **ensemble of CNet192 teachers**  into a **single CNet192** student;



In [1]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Sequence, Tuple
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
import random
from tqdm.auto import tqdm

from C4.connect4_env import Connect4Env
from C4.CNet192 import CNet192, load_cnet192, save_cnet192
from PPO.ppo_hall_of_fame import PPOHallOfFame, _state_to_1ch_tensor, _argmax_legal_center_tiebreak

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

N_ACTIONS = 7

# --- global seeds ---
SEED = 666
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE.type == "cuda":
    torch.cuda.manual_seed_all(SEED)


Using device: cuda


In [2]:
EPISODES = 7500
MAX_SAMPLES = 300_000
BATCH_SIZE = 512 
LR = 3e-4 
WD = 1e-5
NUM_EPOCHS = 20
GRAD_CLIP = 1.0

In [3]:
HOF_METASCORES = {
    "PPO_Models/MIX_12a.pt": 0.586667,
    #"PPO_Models/MIX_12b.pt": 0.507500,
}


In [4]:
# Derive teacher names/weights from HOF_METASCORES and load policies

TEACHER_CKPT_PATHS = list(HOF_METASCORES.keys())
TEACHER_NAMES = [Path(p).stem for p in TEACHER_CKPT_PATHS]

# Convenience maps: name -> path, name -> metascore
HOF_PATHS_BY_NAME  = {Path(p).stem: p for p in HOF_METASCORES.keys()}
HOF_SCORES_BY_NAME = {Path(p).stem: score for p, score in HOF_METASCORES.items()}

# Normalized weights in TEACHER_NAMES order
TEACHER_WEIGHTS = np.array(
    [HOF_SCORES_BY_NAME[name] for name in TEACHER_NAMES],
    dtype=np.float32,
)
TEACHER_WEIGHTS /= TEACHER_WEIGHTS.sum()

print("Teachers:", TEACHER_NAMES)
print("Weights:", dict(zip(TEACHER_NAMES, TEACHER_WEIGHTS)))

# Instantiate Hall of Fame and register members
hof = PPOHallOfFame(device=DEVICE)

for name in TEACHER_NAMES:
    ckpt_path = HOF_PATHS_BY_NAME[name]
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"HOF checkpoint not found: {ckpt_path}")
    hof.add_member(
        name=name,
        ckpt_path=str(ckpt_path),
        metascore=HOF_SCORES_BY_NAME[name],
    )

# Materialize wrapped, frozen policies
TEACHERS: Dict[str, nn.Module] = {}
for name in TEACHER_NAMES:
    pol = hof.get_policy(name)
    pol.eval()
    TEACHERS[name] = pol
    print(f"Loaded teacher {name} from {hof.get_member(name).ckpt_path}")


Teachers: ['MIX_12a']
Weights: {'MIX_12a': np.float32(1.0)}
Loaded teacher MIX_12a from PPO_Models/MIX_12a.pt


In [5]:
# Distillation dataset + logits-mixture helpers

NEG_INF = -1e9  # masking value for illegal actions


class DistillDataset(Dataset):
    """
    states: (N,1,6,7) float32
    targets: (N,7) float32  (teacher mixture probs)
    legal_mask: (N,7) bool  (True for legal moves)
    """
    def __init__(self, states: np.ndarray, targets: np.ndarray, legal_mask: np.ndarray):
        assert states.ndim == 4 and states.shape[1:] == (1, 6, 7), f"states must be (N,1,6,7), got {states.shape}"
        assert targets.ndim == 2 and targets.shape[1] == 7, f"targets must be (N,7), got {targets.shape}"
        assert legal_mask.ndim == 2 and legal_mask.shape[1] == 7, f"legal_mask must be (N,7), got {legal_mask.shape}"

        self.states = torch.from_numpy(states.astype(np.float32, copy=False))
        self.targets = torch.from_numpy(targets.astype(np.float32, copy=False))
        self.legal_mask = torch.from_numpy(legal_mask.astype(np.bool_, copy=False))

    def __len__(self):
        return int(self.states.shape[0])

    def __getitem__(self, i):
        return self.states[i], self.targets[i], self.legal_mask[i]


def _compute_legal_mask_from_state_1ch(states_1ch: np.ndarray) -> np.ndarray:
    """
    states_1ch: (N,1,6,7) mover POV scalar board (board * player)
    Legal iff top cell in that column is zero => states[:,0,0,c] == 0
    """
    top_row = states_1ch[:, 0, 0, :]  # (N,7)
    return (top_row == 0.0)

def _softmax_masked(logits: np.ndarray, legal_mask: np.ndarray) -> np.ndarray:
    z = logits.astype(np.float32, copy=True)
    z[~legal_mask] = NEG_INF
    m = float(np.max(z[legal_mask])) if legal_mask.any() else 0.0
    e = np.zeros_like(z, dtype=np.float32)
    e[legal_mask] = np.exp(z[legal_mask] - m)
    s = float(e.sum())
    if s <= 0.0:
        # fallback uniform over legal
        p = np.zeros(7, dtype=np.float32)
        if legal_mask.any():
            p[legal_mask] = 1.0 / float(legal_mask.sum())
        return p
    return (e / s).astype(np.float32)

def ensemble_mixture_probs(state_1ch: np.ndarray, legal_actions: List[int]) -> np.ndarray:
    legal_mask = np.zeros(7, dtype=np.bool_)
    legal_mask[np.asarray(legal_actions, dtype=np.int64)] = True

    # mixture over teachers: sum_i w_i * softmax_i(masked)
    mix = np.zeros(7, dtype=np.float32)
    for name, w in zip(TEACHER_NAMES, TEACHER_WEIGHTS):
        teacher = TEACHERS[name]
        logits = _teacher_logits_np(teacher, state_1ch)  # (7,)
        p = _softmax_masked(logits, legal_mask)          # (7,)
        mix += float(w) * p

    # renormalize (just in case)
    mix[~legal_mask] = 0.0
    s = float(mix.sum())
    if s <= 0.0:
        if legal_mask.any():
            mix[legal_mask] = 1.0 / float(legal_mask.sum())
            return mix
        return np.zeros(7, dtype=np.float32)
    return (mix / s).astype(np.float32)

In [6]:
@torch.inference_mode()
def _teacher_logits_np(teacher: nn.Module, state_1ch: np.ndarray) -> np.ndarray:
    # state_1ch: (1,6,7) float32
    x = torch.from_numpy(state_1ch[None]).to(DEVICE)  # (1,1,6,7)
    out = teacher(x)
    logits = out[0] if isinstance(out, (tuple, list)) else out
    return logits[0].detach().float().cpu().numpy()   # (7,)


In [7]:
def generate_distill_data(
    n_episodes: int = 200,
    max_moves: int = 42,
    max_samples: int = 80_000,
    seed: int = 666,
) -> DistillDataset:
    rng = np.random.default_rng(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    states: List[np.ndarray] = []
    targets: List[np.ndarray] = []
    legal_masks: List[np.ndarray] = []

    total_samples = 0

    pbar = tqdm(range(n_episodes), desc="Generating distill rollouts", leave=True)
    for _ in pbar:
        env = Connect4Env()
        env.reset()
        done = False
        moves = 0

        while (not done) and moves < max_moves and total_samples < max_samples:
            legal = env.available_actions()
            if not legal:
                break

            # env already returns mover-centric POV (1,6,7)
            s = env.get_state(perspective=env.current_player).astype(np.float32, copy=False)

            lm = np.zeros(7, dtype=np.bool_)
            lm[np.asarray(legal, dtype=np.int64)] = True

            target_probs = ensemble_mixture_probs(state_1ch=s, legal_actions=legal)

            states.append(s)                  # (1,6,7)
            targets.append(target_probs)      # (7,)
            legal_masks.append(lm)            # (7,)
            total_samples += 1

            # drive env with teacher-mixture argmax (legal-safe)
            a = int(target_probs.argmax())
            if a not in legal:
                a = int(rng.choice(legal))
            _, _, done = env.step(a)
            moves += 1

        pbar.set_postfix(samples=total_samples)

        if total_samples >= max_samples:
            break

    states_arr = np.stack(states, axis=0)        # (N,1,6,7)
    targets_arr = np.stack(targets, axis=0)      # (N,7)
    legal_arr = np.stack(legal_masks, axis=0)    # (N,7)

    print(f"Final distill dataset: states {states_arr.shape}, targets {targets_arr.shape}, legal {legal_arr.shape}")
    return DistillDataset(states_arr, targets_arr, legal_arr)

In [8]:
# Kick off dataset generation 
distill_dataset = generate_distill_data(
    n_episodes=EPISODES,
    max_moves=42,
    max_samples=MAX_SAMPLES,
    seed=SEED,
)

len(distill_dataset)

Generating distill rollouts:   0%|          | 0/7500 [00:00<?, ?it/s]

Final distill dataset: states (127500, 1, 6, 7), targets (127500, 7), legal (127500, 7)


127500

In [9]:
# ---- assume you already have these from dataset generation:
# states_arr: (N,1,6,7)
# targets_arr: (N,7)
assert "states_arr" in globals() and "targets_arr" in globals(), "Run dataset generation cell first (states_arr/targets_arr)."

legal_mask_arr = _compute_legal_mask_from_state_1ch(states_arr)

distill_dataset = DistillDataset(states_arr, targets_arr, legal_mask_arr)

# ---- split + loaders
VAL_FRAC = 0.10
N = len(distill_dataset)
n_val = int(round(N * VAL_FRAC))
n_train = N - n_val

g = torch.Generator().manual_seed(SEED)
train_ds, val_ds = random_split(distill_dataset, [n_train, n_val], generator=g)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=(DEVICE.type == "cuda"),
    drop_last=False,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=(DEVICE.type == "cuda"),
    drop_last=False,
)

print(f"Loaders ready: train={len(train_ds)}  val={len(val_ds)}  batch={BATCH_SIZE}")

AssertionError: Run dataset generation cell first (states_arr/targets_arr).

In [None]:
def run_epoch(loader, train: bool = True, desc: str | None = None) -> float:
    student.train(train)

    total_loss = 0.0
    total_items = 0

    pbar = tqdm(loader, desc=desc or ("Train" if train else "Val"), leave=False)

    for states, targets, legal_mask in pbar:
        states = states.to(DEVICE, non_blocking=True)            # (B,1,6,7)
        targets = targets.to(DEVICE, non_blocking=True)          # (B,7)
        legal_mask = legal_mask.to(DEVICE, non_blocking=True)    # (B,7) bool

        if train:
            optimizer.zero_grad(set_to_none=True)

        out = student(states)
        logits = out[0] if isinstance(out, (tuple, list)) else out  # (B,7)

        # CRITICAL: mask illegal actions before log_softmax
        logits = logits.masked_fill(~legal_mask, -1e9)

        log_probs = F.log_softmax(logits, dim=-1)
        loss = criterion(log_probs, targets)  # KL(student || teacher-mixture)

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), GRAD_CLIP)
            optimizer.step()

        bs = int(states.size(0))
        total_loss += float(loss.item()) * bs
        total_items += bs

        pbar.set_postfix(loss=f"{(total_loss / max(1,total_items)):.4f}")

    return total_loss / max(1, total_items)

## Training loop

In [None]:
def run_epoch(loader, train: bool = True) -> float:
    student.train(train)
    total_loss = 0.0
    total_items = 0

    it = tqdm(loader, desc=("Train" if train else "Val"), leave=False)
    for states, targets, legal_mask in it:
        states = states.to(DEVICE, non_blocking=True)            # (B,1,6,7)
        targets = targets.to(DEVICE, non_blocking=True)          # (B,7)
        legal_mask = legal_mask.to(DEVICE, non_blocking=True)    # (B,7) bool

        if train:
            optimizer.zero_grad(set_to_none=True)

        out = student(states)
        logits = out[0] if isinstance(out, (tuple, list)) else out  # (B,7)

        # CRITICAL: prevent illegal-action leakage
        logits = logits.masked_fill(~legal_mask, -1e9)

        log_probs = F.log_softmax(logits, dim=-1)
        loss = criterion(log_probs, targets)

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), GRAD_CLIP)
            optimizer.step()

        bs = int(states.size(0))
        total_loss += float(loss.item()) * bs
        total_items += bs
        it.set_postfix(loss=f"{(total_loss/max(1,total_items)):.4f}")

    return total_loss / max(1, total_items)

In [None]:
# ---------- main training loop ----------

meta_base = {
    "source": "UNION_PPO_distill",
    "teachers": list(HOF_METASCORES.keys()),
    "metascores": HOF_METASCORES,
    "seed": SEED,
    "num_samples": len(distill_dataset),
}

best_val = math.inf
best_path = Path("CNet192_UNION_PPO_best.pt")

# loss histories for plotting later
train_loss_history = []
val_loss_history   = []

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = run_epoch(train_loader, train=True,  desc=f"Train {epoch}/{NUM_EPOCHS}")
    val_loss   = run_epoch(val_loader,   train=False, desc=f"Val   {epoch}/{NUM_EPOCHS}")

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)

    print(f"[Epoch {epoch:03d}]  train_loss={train_loss:.6f}   val_loss={val_loss:.6f}")

    if val_loss < best_val:
        best_val = val_loss
        meta = dict(meta_base)
        meta["tag"] = "best_val"
        meta["epoch"] = epoch
        meta["best_val_loss"] = float(best_val)
        # NOTE: save_cnet192 expects path as first positional arg
        save_cnet192(str(best_path), student, meta)
        print(f"  -> New best model saved to: {best_path}")

# also save final model (last epoch)
final_path = Path("CNet192_UNION_PPO_last.pt")
final_meta = dict(meta_base)
final_meta["tag"] = "final"
final_meta["epoch"] = NUM_EPOCHS
final_meta["best_val_loss"] = float(best_val)
save_cnet192(str(final_path), student, final_meta)

print("Training complete.")
print(" Best model:", best_path)
print(" Final model:", final_path)

In [None]:
epochs = np.arange(1, len(train_loss_history) + 1)

plt.figure(figsize=(7, 4))
plt.plot(epochs, train_loss_history, marker="o", label="Train KL")
plt.plot(epochs, val_loss_history, marker="o", label="Val KL")

plt.xlabel("Epoch")
plt.ylabel("KL(student || teacher)")
plt.title("CNet192 UNION distillation losses")
plt.grid(True, alpha=0.35)
plt.legend()
plt.tight_layout()
plt.show()