# UNION (single-channel) — policy distillation to CNet192

This notebook distills an **ensemble of CNet192 teachers**  into a **single CNet192** student;



In [1]:
# --- Imports / setup ---
import os, time, math, random
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional
from PPO.ppo_hall_of_fame import PPOHallOfFame
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from C4.CNet192 import CNet192, load_cnet192, save_cnet192
from C4.fast_connect4_lookahead import Connect4Lookahead
from C4.connect4_env import Connect4Env

SEED = 666
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = (DEVICE.type == "cuda")
print("DEVICE:", DEVICE, " AMP:", USE_AMP)


DEVICE: cuda  AMP: True


In [2]:
EPISODES     = 1000

In [3]:
HOF_METASCORES = {
    "PPO_Models/MIX_9.pt": 0.758167,
    "PPO_Models/MIX_1.pt": 0.517667,
    "PPO_Models/MIX_8b.pt": 0.517500,
    "PPO_Models/MIX_7.pt": 0.515833,
    "PPO_Models/MIX_5.pt": 0.443333,
    "PPO_Models/MIX_8a.pt": 0.381833,
    "PPO_Models/MIX_6.pt": 0.365667,
}

In [4]:
# --- Inputs: teacher ensemble weights (meta-scores) ---

# Distillation data config
N_DISTILL_POSITIONS = 120_000          # base samples before symmetry augmentation
MAX_RANDOM_PLIES = 18        # random plies from empty board (position diversity knob)
ADD_HFLIP = True             # symmetry augmentation (recommended)

# Teacher voting config
N_ACTIONS = 7
TEACHER_TEMP = 1.0           # temperature used for argmax voting (kept for compatibility)
# If you ever want soft targets instead of vote targets, you can add it, but UNION uses votes.

# Mentor (lookahead) override config (kept small for speed)
USE_MENTOR = True
MENTOR_DEPTH = 9
MENTOR_PROB = 0.99           # apply mentor on ~6% of generated samples
MENTOR_COEF = 0.40           # blend: (1-COEF)*vote + COEF*mentor_onehot
MENTOR_MIN_CONF = 0.70       # only mentor if vote-top probability < this

# Student training config
LR = 1.0e-4
WEIGHT_DECAY = 5e-5
EPOCHS = 64
BATCH_SIZE = 512
NUM_WORKERS = 0              # set >0 if you like (Windows often prefers 0)
GRAD_CLIP = 1.0
TEMPERATURE = 1.0           # logits / temp before softmax in loss

# Output
OUT_DIR = Path("PPO_Models")
OUT_TAG = "UNION_3"

# Output filenames (configured here, not hidden later)
OUT_FILE = f"{OUT_TAG}.pt"
OUT_PATH = OUT_DIR / OUT_FILE
OUT_LATEST_PATH = OUT_DIR / f"{OUT_TAG}_LATEST.pt"


In [5]:
# --- Load teachers (CNet192 checkpoints) ---------------------------------
TEACHER_PATHS = list(HOF_METASCORES.keys())
TEACHER_NAMES = [Path(p).stem for p in TEACHER_PATHS]
HOF_SPECS = dict(zip(TEACHER_NAMES, TEACHER_PATHS))

hof = PPOHallOfFame(device=DEVICE)

for name, path in HOF_SPECS.items():
    hof.add_member(
        name=name,
        ckpt_path=path,
        metascore=float(HOF_METASCORES[path]),
    )

# Load all POP members as frozen policies (wrapped)
TEACHERS = {name: hof.ensure_loaded(name) for name in tqdm(TEACHER_NAMES, desc="Loading teachers")}
print("Loaded teachers:", TEACHER_NAMES)

# Normalized meta-score weights (used for majority-vote distillation)
raw_weights = np.array([HOF_METASCORES[HOF_SPECS[n]] for n in TEACHER_NAMES], dtype=np.float32)
raw_weights = np.maximum(raw_weights, 1e-9)
META_WEIGHTS = raw_weights / raw_weights.sum()
print("Meta weights:", {n: float(w) for n, w in zip(TEACHER_NAMES, META_WEIGHTS)})

Loading teachers:   0%|          | 0/7 [00:00<?, ?it/s]

Loaded teachers: ['MIX_9', 'MIX_1', 'MIX_8b', 'MIX_7', 'MIX_5', 'MIX_8a', 'MIX_6']
Meta weights: {'MIX_9': 0.21661914885044098, 'MIX_1': 0.14790485799312592, 'MIX_8b': 0.14785714447498322, 'MIX_7': 0.14738085865974426, 'MIX_5': 0.1266665756702423, 'MIX_8a': 0.10909514129161835, 'MIX_6': 0.10447628051042557}


In [6]:
# Mentor (Numba/CPU)
MENTOR = Connect4Lookahead() if USE_MENTOR else None

# --- Student init ---------------------------------------------------------
student = CNet192(in_channels=1, use_mid_3x3=True).to(DEVICE)
student.train()


CNet192(
  (conv1): Conv2d(1, 192, kernel_size=(4, 4), stride=(1, 1))
  (conv_mid): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(192, 192, kernel_size=(2, 2), stride=(1, 1))
  (fc): Linear(in_features=1152, out_features=192, bias=True)
  (policy_fc): Linear(in_features=192, out_features=192, bias=True)
  (policy_out): Linear(in_features=192, out_features=7, bias=True)
  (value_fc): Linear(in_features=192, out_features=192, bias=True)
  (value_out): Linear(in_features=192, out_features=1, bias=True)
)

In [7]:
def weighted_majority_vote(
    actions: List[int],
    weights: np.ndarray,
    n_actions: int,
) -> np.ndarray:
    """
    Score-weighted majority vote over teacher actions.

    actions: list of ints (one per teacher)
    weights: np.array of same length (meta-score weights)
    """
    votes = np.zeros(n_actions, dtype=np.float32)
    for a, w in zip(actions, weights):
        if 0 <= int(a) < n_actions:
            votes[int(a)] += float(w)

    total = votes.sum()
    return votes / total

def mentor_override(
    env: Connect4Env,
    target_probs: np.ndarray,
    legal_actions: List[int],
    depth: int = 3,
    max_strength: float = 0.35,
) -> np.ndarray:
    """
    Soft mentor:
      - NEVER hard-overrides POP.
      - Computes an effective lambda in [0, max_strength] based on:
          * POP entropy (uncertainty)
          * POP probability on mentor's move
      - Then nudges target_probs toward mentor's move via convex combination.

      new_p = (1 - λ) * POP + λ * one_hot(a_mentor)
    """
    if not legal_actions: return target_probs

    p = np.asarray(target_probs, dtype=np.float32)
    if p.shape[0] != N_ACTIONS: return target_probs

    # Ask mentor for best move for current mover
    try:
        a_mentor = MENTOR.n_step_lookahead(env.board, player=env.current_player, depth=depth,)
    except Exception: return p

    if a_mentor not in legal_actions: return p

    # --- POP confidence: high entropy -> low confidence, low entropy -> high confidence
    eps = 1e-8
    p_clipped = np.clip(p, eps, 1.0)
    entropy = -float(np.sum(p_clipped * np.log(p_clipped)))
    max_entropy = math.log(len(p))
    conf_pop = 1.0 - entropy / max_entropy  # 0 = very unsure, 1 = very sure

    # POP's belief in mentor move
    p_mentor_pop = float(p[a_mentor])

    # Mentor strongest when:
    #   - POP is uncertain (low conf_pop)
    #   - POP doesn't already like mentor move (low p_mentor_pop)
    lambda_eff = max_strength * (1.0 - conf_pop) * (1.0 - p_mentor_pop)

    if lambda_eff <= 0.0:
        return p

    # Convex combo: gently push mass toward mentor move
    one_hot = np.zeros_like(p)
    one_hot[a_mentor] = 1.0

    new_p = (1.0 - lambda_eff) * p + lambda_eff * one_hot
    new_p = np.clip(new_p, 1e-6, 1.0)
    new_p /= new_p.sum()
    return new_p


class DistillDataset(Dataset):
    """
    (state, target_probs) dataset for offline policy distillation.

    states:  (N, C, 6, 7) float32  — C is 4 or 6 in your setup
    targets: (N, N_ACTIONS) float32 — probability distribution over columns
    """
    def __init__(self, states: np.ndarray, targets: np.ndarray):
        assert states.ndim == 4 and states.shape[-2:] == (6, 7), \
            f"states shape should be (N, C, 6, 7), got {states.shape}"
        assert targets.ndim == 2 and targets.shape[1] == N_ACTIONS, \
            f"targets shape should be (N, {N_ACTIONS}), got {targets.shape}"

        self.states = states.astype(np.float32)
        self.targets = targets.astype(np.float32)

    def __len__(self) -> int:
        return self.states.shape[0]

    def __getitem__(self, idx: int):
        x = torch.from_numpy(self.states[idx])   # (C, 6, 7)
        y = torch.from_numpy(self.targets[idx])  # (7,)
        return x, y

In [8]:
# --- Teacher forward helpers (HOF-wrapped CNet192 / ensemble) ------------

def _module_device(m: torch.nn.Module) -> torch.device:
    try:
        return next(m.parameters()).device
    except StopIteration:
        return DEVICE  # fallback

def _state_to_1ch_batch(state: np.ndarray) -> np.ndarray:
    """
    Normalize numpy state to (B,1,6,7) float32.
    Accepts: (6,7), (1,6,7), (2,6,7), (B,6,7), (B,1,6,7), (B,2,6,7)
    """
    s = np.asarray(state)

    if s.ndim == 2:  # (6,7)
        return s.astype(np.float32)[None, None, :, :]

    if s.ndim == 3:
        # (1,6,7) or (2,6,7) or (C,6,7) OR (B,6,7)
        if s.shape[-2:] != (6, 7):
            raise ValueError(f"Bad state shape: {s.shape}")

        C0 = int(s.shape[0])

        # channel-first single sample
        if C0 in (1, 2, 4, 6) and s.shape[1:] == (6, 7):
            if C0 == 1:
                return s.astype(np.float32)[None, :, :, :]  # (1,1,6,7)
            # use first two planes as (me - opp) -> POV scalar
            scalar = (s[0] - s[1]).astype(np.float32)
            return scalar[None, None, :, :]

        # otherwise treat as batch: (B,6,7)
        return s.astype(np.float32)[:, None, :, :]

    if s.ndim == 4:
        # (B,C,6,7)
        if s.shape[-2:] != (6, 7):
            raise ValueError(f"Bad state shape: {s.shape}")
        C = int(s.shape[1])
        if C == 1:
            return s.astype(np.float32)
        # (me,opp,...) -> POV scalar
        scalar = (s[:, 0] - s[:, 1]).astype(np.float32)
        return scalar[:, None, :, :]

    raise ValueError(f"Unsupported state shape: {s.shape}")

@torch.inference_mode()
def _logits_np(policy: torch.nn.Module, s_1ch: np.ndarray) -> np.ndarray:
    """
    policy: HOF-wrapped CNet192 policy (or ensemble_teacher)
    s_1ch : state from env.get_state(), typically (1,6,7) float32 POV
    returns: (7,) numpy logits
    """
    dev = _module_device(policy)

    x_np = np.ascontiguousarray(_state_to_1ch_batch(s_1ch), dtype=np.float32)  # (1,1,6,7)
    x = torch.from_numpy(x_np).to(dev)

    out = policy.forward(x)
    if isinstance(out, (tuple, list)):
        logits = out[0]
    else:
        logits = out

    return logits[0].detach().float().cpu().numpy()

def _argmax_legal_center_tiebreak(
    logits,
    legal: List[int],
    center: int = 3,
    tol: float = 1e-12,
) -> int:
    """Deterministic argmax over legal actions; ties prefer center, then smaller index."""
    if not legal:
        raise ValueError("No legal actions")

    # accept torch or numpy
    if isinstance(logits, torch.Tensor):
        vals = logits.detach().float().cpu().numpy().astype(np.float64, copy=False)
    else:
        vals = np.asarray(logits, dtype=np.float64)

    best = max(vals[c] for c in legal)

    tied = [c for c in legal if abs(vals[c] - best) <= tol]
    if len(tied) == 1:
        return int(tied[0])

    return int(min(tied, key=lambda c: (abs(c - center), c)))



In [9]:
# Generate distillation dataset from POP 

def generate_distill_data(
    n_episodes: int = 200,
    max_moves: int = 42,
    max_samples: int = 80_000,
    seed: int = 666,
    use_mentor: bool = True,
    mentor_depth: int = 3,
) -> DistillDataset:
    """
    Let the POP ensemble play games in Connect4Env and record:
      - s_t  := env.get_state(perspective=env.current_player)
      - y_t  := score-weighted majority vote over teacher greedy actions

    Progress:
      - Outer loop is wrapped in tqdm so you see episode progress.
    """

    rng = np.random.default_rng(seed)
    import random as _random
    _random.seed(seed)
    torch.manual_seed(seed)

    states: List[np.ndarray] = []
    targets: List[np.ndarray] = []

    total_samples = 0

    for ep in tqdm(range(n_episodes), desc="Generating POP rollouts"):
        env = Connect4Env()
        env.reset()
        done = False
        moves = 0

        while (not done) and moves < max_moves and total_samples < max_samples:
            legal = env.available_actions()
            if not legal:
                break

            # Mover-centric state; typically (4,6,7) in your setup
            s = env.get_state(perspective=env.current_player)
            s = np.asarray(s, dtype=np.float32)

            # Each teacher proposes a greedy action
            teacher_actions: List[int] = []
            for name in TEACHER_NAMES:
                policy = TEACHERS[name]
                logits = _logits_np(policy, s)              # (7,)
                a = _argmax_legal_center_tiebreak(logits, legal)
                teacher_actions.append(int(a))

            # Score-weighted majority vote
            target_probs = weighted_majority_vote(
                teacher_actions,
                META_WEIGHTS,
                n_actions=N_ACTIONS,
            )

            # Optional: mentor tweak / override
            if use_mentor:
                target_probs = mentor_override(
                    env=env,
                    target_probs=target_probs,
                    legal_actions=legal,
                    depth=MENTOR_DEPTH,
                )

            # Snapshot (state, label)
            states.append(s)
            targets.append(target_probs.astype(np.float32))
            total_samples += 1

            # Use POP majority action to drive env
            majority_action = int(np.argmax(target_probs))
            if majority_action not in legal:
                majority_action = int(rng.choice(legal))

            _, _, done = env.step(majority_action)
            moves += 1

        if total_samples >= max_samples:
            break

    states_arr = np.stack(states, axis=0)   # (N, C, 6, 7)
    targets_arr = np.stack(targets, axis=0) # (N, 7)

    print(f"Final distill dataset: states {states_arr.shape}, targets {targets_arr.shape}")
    return DistillDataset(states_arr, targets_arr)


# Kick off dataset generation
distill_dataset = generate_distill_data(
    n_episodes=EPISODES,
    max_moves=42,
    max_samples=100_000,
    seed=666,
    use_mentor=True,  
    mentor_depth=MENTOR_DEPTH,
)

len(distill_dataset)


Generating POP rollouts:   0%|          | 0/1000 [00:00<?, ?it/s]

Final distill dataset: states (31000, 1, 6, 7), targets (31000, 7)


31000

In [10]:
# Train student via POP distillation 
train_loader = DataLoader(distill_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,)
optimizer = torch.optim.Adam(student.parameters(), lr=LR)


def distill_loss(
    student_logits: torch.Tensor,
    target_probs: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    KL(target || student) with optional temperature scaling on student logits.
    """
    if temperature != 1.0:
        student_logits = student_logits / temperature

    log_student = F.log_softmax(student_logits, dim=-1)
    loss = F.kl_div(log_student, target_probs, reduction="batchmean")
    return loss


for epoch in range(1, EPOCHS + 1):
    student.train()
    running_loss = 0.0
    n_batches = 0

    for x_batch, y_batch in tqdm(
        train_loader,
        desc=f"Epoch {epoch}/{EPOCHS}",
        leave=False,
    ):
        # x_batch: (B, C, 6, 7), y_batch: (B, 7)
        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        optimizer.zero_grad()
        logits, _ = student.forward(x_batch)
        loss = distill_loss(logits, y_batch, temperature=TEMPERATURE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=0.5)
        optimizer.step()

        running_loss += float(loss.item())
        n_batches += 1

    avg_loss = running_loss / max(1, n_batches)
    print(f"[Epoch {epoch}/{EPOCHS}] distill_loss = {avg_loss:.4f}")


Epoch 1/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 1/64] distill_loss = 1.5827


Epoch 2/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 2/64] distill_loss = 0.9403


Epoch 3/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 3/64] distill_loss = 0.5117


Epoch 4/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 4/64] distill_loss = 0.2367


Epoch 5/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 5/64] distill_loss = 0.0754


Epoch 6/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 6/64] distill_loss = 0.0222


Epoch 7/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 7/64] distill_loss = 0.0148


Epoch 8/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 8/64] distill_loss = 0.0121


Epoch 9/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 9/64] distill_loss = 0.0070


Epoch 10/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 10/64] distill_loss = 0.0071


Epoch 11/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 11/64] distill_loss = 0.0057


Epoch 12/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 12/64] distill_loss = 0.0032


Epoch 13/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 13/64] distill_loss = 0.0027


Epoch 14/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 14/64] distill_loss = 0.0024


Epoch 15/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 15/64] distill_loss = 0.0016


Epoch 16/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 16/64] distill_loss = 0.0026


Epoch 17/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 17/64] distill_loss = 0.0018


Epoch 18/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 18/64] distill_loss = 0.0022


Epoch 19/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 19/64] distill_loss = 0.0016


Epoch 20/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 20/64] distill_loss = 0.0034


Epoch 21/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 21/64] distill_loss = 0.0047


Epoch 22/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 22/64] distill_loss = 0.0037


Epoch 23/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 23/64] distill_loss = 0.0018


Epoch 24/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 24/64] distill_loss = 0.0008


Epoch 25/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 25/64] distill_loss = 0.0012


Epoch 26/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 26/64] distill_loss = 0.0005


Epoch 27/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 27/64] distill_loss = 0.0008


Epoch 28/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 28/64] distill_loss = 0.0010


Epoch 29/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 29/64] distill_loss = 0.0018


Epoch 30/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 30/64] distill_loss = 0.0014


Epoch 31/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 31/64] distill_loss = 0.0005


Epoch 32/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 32/64] distill_loss = 0.0015


Epoch 33/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 33/64] distill_loss = 0.0013


Epoch 34/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 34/64] distill_loss = 0.0012


Epoch 35/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 35/64] distill_loss = 0.0016


Epoch 36/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 36/64] distill_loss = 0.0006


Epoch 37/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 37/64] distill_loss = 0.0008


Epoch 38/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 38/64] distill_loss = 0.0013


Epoch 39/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 39/64] distill_loss = 0.0012


Epoch 40/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 40/64] distill_loss = 0.0010


Epoch 41/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 41/64] distill_loss = 0.0013


Epoch 42/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 42/64] distill_loss = 0.0011


Epoch 43/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 43/64] distill_loss = 0.0006


Epoch 44/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 44/64] distill_loss = 0.0011


Epoch 45/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 45/64] distill_loss = 0.0006


Epoch 46/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 46/64] distill_loss = 0.0002


Epoch 47/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 47/64] distill_loss = 0.0003


Epoch 48/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 48/64] distill_loss = 0.0008


Epoch 49/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 49/64] distill_loss = 0.0006


Epoch 50/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 50/64] distill_loss = 0.0008


Epoch 51/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 51/64] distill_loss = 0.0012


Epoch 52/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 52/64] distill_loss = 0.0013


Epoch 53/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 53/64] distill_loss = 0.0011


Epoch 54/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 54/64] distill_loss = 0.0012


Epoch 55/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 55/64] distill_loss = 0.0006


Epoch 56/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 56/64] distill_loss = 0.0002


Epoch 57/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 57/64] distill_loss = 0.0001


Epoch 58/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 58/64] distill_loss = 0.0001


Epoch 59/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 59/64] distill_loss = 0.0001


Epoch 60/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 60/64] distill_loss = 0.0003


Epoch 61/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 61/64] distill_loss = 0.0009


Epoch 62/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 62/64] distill_loss = 0.0022


Epoch 63/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 63/64] distill_loss = 0.0017


Epoch 64/64:   0%|          | 0/60 [00:00<?, ?it/s]

[Epoch 64/64] distill_loss = 0.0010


In [11]:
# --- Save distilled student (CNet192 checkpoint) ---

meta = {
    "tag": OUT_TAG,
    "distill_cfg": {
        "N_DISTILL_POSITIONS": int(N_DISTILL_POSITIONS),
        "USE_MENTOR": bool(USE_MENTOR),
        "MENTOR_DEPTH": int(MENTOR_DEPTH),
        "MENTOR_PROB": float(MENTOR_PROB),
        "MENTOR_COEF": float(MENTOR_COEF),
        "MENTOR_MIN_CONF": float(MENTOR_MIN_CONF),
        "LR": float(LR),
        "WEIGHT_DECAY": float(WEIGHT_DECAY),
        "EPOCHS": int(EPOCHS),
        "BATCH_SIZE": int(BATCH_SIZE),
    },
}

save_cnet192(
    path=OUT_PATH,
    model=student,
    cfg={"input_channels": 1, "use_mid_3x3": bool(getattr(student, "use_mid_3x3", True))},
    **meta,
)

print("Saved:", OUT_PATH)


Saved: PPO_Models\UNION_3.pt
