# Teacher's Union
Checkpoint distillation

In [1]:
import math
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PPO.actor_critic import ActorCritic
from C4.connect4_env import Connect4Env
from C4.fast_connect4_lookahead import Connect4Lookahead  
from PPO.ppo_hall_of_fame import PPOHallOfFame
from PPO.ppo_agent_eval import _logits_np, _argmax_legal_center_tiebreak
from tqdm.auto import tqdm
import pprint

# ⚙ device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)
MENTOR_DEPTH = 7
OUT_PATH = Path("DIST_XXXIX PPO model.pt")

BATCH_SIZE   = 1024
EPOCHS       = 64
LR           = 625e-7
TEMPERATURE  = 0.99  
EPISODES     = 1000

Using device: cuda


In [2]:
HOF_METASCORES = {
    "RND_71": 0.670,
    "RND_58a": 0.554,
    "RND_56c": 0.530,
    "DIST_XXXIV": 0.521,
    "RND_69": 0.517,
    "RND_77b": 0.517,
    "RND_74": 0.488,
    "DIAG_4": 0.476,
    "INT_5": 0.458,
    "RND_76": 0.456,
}


In [3]:
hof = PPOHallOfFame(device=DEVICE)

HOF_SPECS = {
    name: f"{name} PPO model.pt"
    for name in HOF_METASCORES.keys()
}

In [4]:
N_ACTIONS = 7  # Connect-4
POP_MEMBER_NAMES = list(HOF_SPECS.keys())
MENTOR = Connect4Lookahead()

In [5]:
# --- Build Hall-of-Fame and load individual teachers --------------------

hof = PPOHallOfFame(device=DEVICE)

for name, path in HOF_SPECS.items():
    hof.add_member(
        name=name,
        ckpt_path=path,
        metascore=HOF_METASCORES.get(name, 1.0),
    )


# Load all POP members as frozen ActorCritic policies
TEACHER_NAMES = POP_MEMBER_NAMES
TEACHERS = {name: hof.ensure_loaded(name) for name in TEACHER_NAMES}

print("Loaded teachers:", TEACHER_NAMES)

# Normalized meta-score weights (used for majority-vote distillation)
raw_weights = np.array([HOF_METASCORES.get(n, 1.0) for n in TEACHER_NAMES], dtype=np.float32)
raw_weights = np.maximum(raw_weights, 1e-6)
META_WEIGHTS = raw_weights / raw_weights.sum()
print("Meta weights:", {n: float(w) for n, w in zip(TEACHER_NAMES, META_WEIGHTS)})

Loaded teachers: ['RND_71', 'RND_58a', 'RND_56c', 'DIST_XXXIV', 'RND_69', 'RND_77b', 'RND_74', 'DIAG_4', 'INT_5', 'RND_76']
Meta weights: {'RND_71': 0.12916907668113708, 'RND_58a': 0.1068054735660553, 'RND_56c': 0.10217851400375366, 'DIST_XXXIV': 0.10044341534376144, 'RND_69': 0.09967225790023804, 'RND_77b': 0.09967225790023804, 'RND_74': 0.094081349670887, 'DIAG_4': 0.09176787734031677, 'INT_5': 0.08829766511917114, 'RND_76': 0.08791208267211914}


In [6]:
# build a score-weighted ensemble teacher (for online distill) ----
ensemble_teacher = hof.build_ensemble(names=POP_MEMBER_NAMES, use_metascore_weights=True)
ensemble_teacher.eval()

print("Ensemble teacher built with", len(POP_MEMBER_NAMES), "members.")

Ensemble teacher built with 10 members.


In [7]:
# --- Student policy to distil into --------------------------------------

student = ActorCritic().to(DEVICE)
student.train()  

print("Student ActorCritic params:", sum(p.numel() for p in student.parameters()))

Student ActorCritic params: 112120


In [8]:
def weighted_majority_vote(
    actions: List[int],
    weights: np.ndarray,
    n_actions: int,
) -> np.ndarray:
    """
    Score-weighted majority vote over teacher actions.

    actions: list of ints (one per teacher)
    weights: np.array of same length (meta-score weights)
    """
    votes = np.zeros(n_actions, dtype=np.float32)
    for a, w in zip(actions, weights):
        if 0 <= int(a) < n_actions:
            votes[int(a)] += float(w)

    total = votes.sum()
    return votes / total

def mentor_override(
    env: Connect4Env,
    target_probs: np.ndarray,
    legal_actions: List[int],
    depth: int = 3,
    max_strength: float = 0.35,
) -> np.ndarray:
    """
    Soft mentor:
      - NEVER hard-overrides POP.
      - Computes an effective lambda in [0, max_strength] based on:
          * POP entropy (uncertainty)
          * POP probability on mentor's move
      - Then nudges target_probs toward mentor's move via convex combination.

      new_p = (1 - λ) * POP + λ * one_hot(a_mentor)
    """
    if not legal_actions: return target_probs

    p = np.asarray(target_probs, dtype=np.float32)
    if p.shape[0] != N_ACTIONS: return target_probs

    # Ask mentor for best move for current mover
    try:
        a_mentor = MENTOR.n_step_lookahead(env.board, player=env.current_player, depth=depth,)
    except Exception: return p

    if a_mentor not in legal_actions: return p

    # --- POP confidence: high entropy -> low confidence, low entropy -> high confidence
    eps = 1e-8
    p_clipped = np.clip(p, eps, 1.0)
    entropy = -float(np.sum(p_clipped * np.log(p_clipped)))
    max_entropy = math.log(len(p))
    conf_pop = 1.0 - entropy / max_entropy  # 0 = very unsure, 1 = very sure

    # POP's belief in mentor move
    p_mentor_pop = float(p[a_mentor])

    # Mentor strongest when:
    #   - POP is uncertain (low conf_pop)
    #   - POP doesn't already like mentor move (low p_mentor_pop)
    lambda_eff = max_strength * (1.0 - conf_pop) * (1.0 - p_mentor_pop)

    if lambda_eff <= 0.0:
        return p

    # Convex combo: gently push mass toward mentor move
    one_hot = np.zeros_like(p)
    one_hot[a_mentor] = 1.0

    new_p = (1.0 - lambda_eff) * p + lambda_eff * one_hot
    new_p = np.clip(new_p, 1e-6, 1.0)
    new_p /= new_p.sum()
    return new_p


class DistillDataset(Dataset):
    """
    (state, target_probs) dataset for offline policy distillation.

    states:  (N, C, 6, 7) float32  — C is 4 or 6 in your setup
    targets: (N, N_ACTIONS) float32 — probability distribution over columns
    """
    def __init__(self, states: np.ndarray, targets: np.ndarray):
        assert states.ndim == 4 and states.shape[-2:] == (6, 7), \
            f"states shape should be (N, C, 6, 7), got {states.shape}"
        assert targets.ndim == 2 and targets.shape[1] == N_ACTIONS, \
            f"targets shape should be (N, {N_ACTIONS}), got {targets.shape}"

        self.states = states.astype(np.float32)
        self.targets = targets.astype(np.float32)

    def __len__(self) -> int:
        return self.states.shape[0]

    def __getitem__(self, idx: int):
        x = torch.from_numpy(self.states[idx])   # (C, 6, 7)
        y = torch.from_numpy(self.targets[idx])  # (7,)
        return x, y

In [9]:
# Generate distillation dataset from POP 

def generate_distill_data(
    n_episodes: int = 200,
    max_moves: int = 42,
    max_samples: int = 80_000,
    seed: int = 666,
    use_mentor: bool = True,
    mentor_depth: int = 3,
) -> DistillDataset:
    """
    Let the POP ensemble play games in Connect4Env and record:
      - s_t  := env.get_state(perspective=env.current_player)
      - y_t  := score-weighted majority vote over teacher greedy actions

    Progress:
      - Outer loop is wrapped in tqdm so you see episode progress.
    """

    rng = np.random.default_rng(seed)
    import random as _random
    _random.seed(seed)
    torch.manual_seed(seed)

    states: List[np.ndarray] = []
    targets: List[np.ndarray] = []

    total_samples = 0

    for ep in tqdm(range(n_episodes), desc="Generating POP rollouts"):
        env = Connect4Env()
        env.reset()
        done = False
        moves = 0

        while (not done) and moves < max_moves and total_samples < max_samples:
            legal = env.available_actions()
            if not legal:
                break

            # Mover-centric state; typically (4,6,7) in your setup
            s = env.get_state(perspective=env.current_player)
            s = np.asarray(s, dtype=np.float32)

            # Each teacher proposes a greedy action
            teacher_actions: List[int] = []
            for name in TEACHER_NAMES:
                policy = TEACHERS[name]
                logits = _logits_np(policy, s)              # (7,)
                a = _argmax_legal_center_tiebreak(logits, legal)
                teacher_actions.append(int(a))

            # Score-weighted majority vote
            target_probs = weighted_majority_vote(
                teacher_actions,
                META_WEIGHTS,
                n_actions=N_ACTIONS,
            )

            # Optional: mentor tweak / override
            if use_mentor:
                target_probs = mentor_override(
                    env=env,
                    target_probs=target_probs,
                    legal_actions=legal,
                    depth=MENTOR_DEPTH,
                )

            # Snapshot (state, label)
            states.append(s)
            targets.append(target_probs.astype(np.float32))
            total_samples += 1

            # Use POP majority action to drive env
            majority_action = int(np.argmax(target_probs))
            if majority_action not in legal:
                majority_action = int(rng.choice(legal))

            _, _, done = env.step(majority_action)
            moves += 1

        if total_samples >= max_samples:
            break

    states_arr = np.stack(states, axis=0)   # (N, C, 6, 7)
    targets_arr = np.stack(targets, axis=0) # (N, 7)

    print(f"Final distill dataset: states {states_arr.shape}, targets {targets_arr.shape}")
    return DistillDataset(states_arr, targets_arr)


# Kick off dataset generation
distill_dataset = generate_distill_data(
    n_episodes=EPISODES,
    max_moves=42,
    max_samples=100_000,
    seed=666,
    use_mentor=True,  
    mentor_depth=MENTOR_DEPTH,
)

len(distill_dataset)


Generating POP rollouts:   0%|          | 0/1000 [00:00<?, ?it/s]

Final distill dataset: states (28000, 4, 6, 7), targets (28000, 7)


28000

In [10]:
# Train student via POP distillation 
train_loader = DataLoader(distill_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,)
optimizer = torch.optim.Adam(student.parameters(), lr=LR)


def distill_loss(
    student_logits: torch.Tensor,
    target_probs: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """
    KL(target || student) with optional temperature scaling on student logits.
    """
    if temperature != 1.0:
        student_logits = student_logits / temperature

    log_student = F.log_softmax(student_logits, dim=-1)
    loss = F.kl_div(log_student, target_probs, reduction="batchmean")
    return loss


for epoch in range(1, EPOCHS + 1):
    student.train()
    running_loss = 0.0
    n_batches = 0

    for x_batch, y_batch in tqdm(
        train_loader,
        desc=f"Epoch {epoch}/{EPOCHS}",
        leave=False,
    ):
        # x_batch: (B, C, 6, 7), y_batch: (B, 7)
        x_batch = x_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)

        optimizer.zero_grad()
        logits, _ = student.forward(x_batch)
        loss = distill_loss(logits, y_batch, temperature=TEMPERATURE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=0.5)
        optimizer.step()

        running_loss += float(loss.item())
        n_batches += 1

    avg_loss = running_loss / max(1, n_batches)
    print(f"[Epoch {epoch}/{EPOCHS}] distill_loss = {avg_loss:.4f}")


Epoch 1/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 1/64] distill_loss = 1.0915


Epoch 2/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 2/64] distill_loss = 0.9923


Epoch 3/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 3/64] distill_loss = 0.9473


Epoch 4/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 4/64] distill_loss = 0.9099


Epoch 5/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 5/64] distill_loss = 0.8029


Epoch 6/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 6/64] distill_loss = 0.6622


Epoch 7/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 7/64] distill_loss = 0.5644


Epoch 8/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 8/64] distill_loss = 0.4907


Epoch 9/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 9/64] distill_loss = 0.4268


Epoch 10/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 10/64] distill_loss = 0.3829


Epoch 11/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 11/64] distill_loss = 0.3525


Epoch 12/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 12/64] distill_loss = 0.3283


Epoch 13/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 13/64] distill_loss = 0.3033


Epoch 14/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 14/64] distill_loss = 0.2767


Epoch 15/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 15/64] distill_loss = 0.2525


Epoch 16/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 16/64] distill_loss = 0.2310


Epoch 17/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 17/64] distill_loss = 0.2165


Epoch 18/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 18/64] distill_loss = 0.2002


Epoch 19/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 19/64] distill_loss = 0.1890


Epoch 20/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 20/64] distill_loss = 0.1784


Epoch 21/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 21/64] distill_loss = 0.1685


Epoch 22/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 22/64] distill_loss = 0.1599


Epoch 23/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 23/64] distill_loss = 0.1518


Epoch 24/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 24/64] distill_loss = 0.1462


Epoch 25/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 25/64] distill_loss = 0.1402


Epoch 26/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 26/64] distill_loss = 0.1346


Epoch 27/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 27/64] distill_loss = 0.1296


Epoch 28/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 28/64] distill_loss = 0.1261


Epoch 29/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 29/64] distill_loss = 0.1214


Epoch 30/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 30/64] distill_loss = 0.1171


Epoch 31/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 31/64] distill_loss = 0.1132


Epoch 32/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 32/64] distill_loss = 0.1094


Epoch 33/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 33/64] distill_loss = 0.1077


Epoch 34/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 34/64] distill_loss = 0.1045


Epoch 35/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 35/64] distill_loss = 0.1012


Epoch 36/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 36/64] distill_loss = 0.0988


Epoch 37/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 37/64] distill_loss = 0.0954


Epoch 38/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 38/64] distill_loss = 0.0928


Epoch 39/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 39/64] distill_loss = 0.0908


Epoch 40/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 40/64] distill_loss = 0.0892


Epoch 41/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 41/64] distill_loss = 0.0872


Epoch 42/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 42/64] distill_loss = 0.0853


Epoch 43/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 43/64] distill_loss = 0.0840


Epoch 44/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 44/64] distill_loss = 0.0819


Epoch 45/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 45/64] distill_loss = 0.0807


Epoch 46/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 46/64] distill_loss = 0.0799


Epoch 47/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 47/64] distill_loss = 0.0786


Epoch 48/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 48/64] distill_loss = 0.0774


Epoch 49/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 49/64] distill_loss = 0.0761


Epoch 50/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 50/64] distill_loss = 0.0752


Epoch 51/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 51/64] distill_loss = 0.0746


Epoch 52/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 52/64] distill_loss = 0.0734


Epoch 53/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 53/64] distill_loss = 0.0721


Epoch 54/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 54/64] distill_loss = 0.0725


Epoch 55/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 55/64] distill_loss = 0.0707


Epoch 56/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 56/64] distill_loss = 0.0712


Epoch 57/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 57/64] distill_loss = 0.0692


Epoch 58/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 58/64] distill_loss = 0.0681


Epoch 59/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 59/64] distill_loss = 0.0670


Epoch 60/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 60/64] distill_loss = 0.0662


Epoch 61/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 61/64] distill_loss = 0.0656


Epoch 62/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 62/64] distill_loss = 0.0654


Epoch 63/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 63/64] distill_loss = 0.0676


Epoch 64/64:   0%|          | 0/27 [00:00<?, ?it/s]

[Epoch 64/64] distill_loss = 0.0635


In [11]:
# Save distilled model

from pathlib import Path
OUT_PATH.parent.mkdir(parents=True, exist_ok=True)

torch.save(
    {
        "model_state_dict": student.state_dict(),
        "meta": {
            "source": "POP_distillation",
            "teachers": TEACHER_NAMES,
            "teacher_weights": META_WEIGHTS.tolist(),
        },
    },
    OUT_PATH,
)

print("Saved distilled model to:", OUT_PATH)



Saved distilled model to: DIST_XXXIX PPO model.pt
