# Merging models: distilling two DQN teachers into a single student

In [1]:
import os, sys, math, time, random, pathlib
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from DQN.dqn_agent import DQNAgent
from DQN.dqn_model import DQN
from C4.connect4_env import Connect4Env
from C4.connect4_lookahead import Connect4Lookahead
from DQN.distill_helpers import *

# Environment report: confirms the imports above resolved and shows which
# accelerator (if any) this run will use.
print("All dependencies imported successfully.")
for _label, _value in (
    ("torch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("CUDA version:", torch.version.cuda),
):
    print(_label, _value)

_has_cuda = torch.cuda.is_available()
if not _has_cuda:
    print("CUDA not available. Using CPU.")
else:
    print("GPU name:", torch.cuda.get_device_name(0))

# Device handed to load_teacher / build_student / distill_train / sanity_eval below.
device = torch.device("cuda" if _has_cuda else "cpu")

All dependencies imported successfully.
torch version: 2.5.1
CUDA available: True
CUDA version: 11.8
GPU name: NVIDIA GeForce RTX 4090


In [2]:
# Wall-clock start of the notebook run, echoed as a local-time stamp.
begin_start_time = time.time()
_start_local = time.localtime(begin_start_time)
time_str = time.strftime('%Y-%m-%d %H-%M-%S', _start_local)
print(time_str)

2025-10-11 08-54-21


In [3]:
# Seed every RNG the pipeline can touch so state collection and training repeat.
SEED = 666

# Python, NumPy, PyTorch CPU, current CUDA device, all CUDA devices — in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Give up cuDNN autotuning in exchange for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects processes launched afterwards — not str hashing in this kernel.
os.environ["PYTHONHASHSEED"] = f"{SEED}"

In [4]:
# --- Checkpoints -----------------------------------------------------------------
TEACHER_A_CKPT = "LA1 II DQN model.pt"     # path to teacher A
TEACHER_B_CKPT = "LA2 I DQN model.pt"      # path to teacher B
HYBRID =         "DISTILLED MIXED III DQN model.pt"   # output path for the distilled student

# --- Target blending --------------------------------------------------------------
# Weight of teacher A in the blended target: Q* = α·Q_A + (1 - α)·Q_B.
# Defined explicitly because the training and sanity cells pass ALPHA through;
# leaving it commented out makes a fresh-kernel "Run All" depend on a leftover
# kernel variable (or on whatever the star import happens to export).
ALPHA = 0.85
Z_NORM = True         # per-state z-normalize teacher Q's before blending
# NOTE(review): distill_train is not given Z_NORM in the training cell — confirm
# its default normalization matches this flag.

# --- State collection -------------------------------------------------------------
TOTAL_STATES = 80_000

# Fraction of TOTAL_STATES collected from each bucket; fractions must sum to 1.
SPLIT = {
    "Random":       0.15,
    "Lookahead-1":  0.30,
    "Lookahead-2":  0.15,
    "Lookahead-3":  0.40,
}
assert all(frac >= 0 for frac in SPLIT.values()), "SPLIT fractions must be non-negative"
assert abs(sum(SPLIT.values()) - 1.0) < 1e-9, (
    f"SPLIT fractions must sum to 1, got {sum(SPLIT.values())}"
)

# --- Training ---------------------------------------------------------------------
LR = 2e-4             # previously 1.0e-4
BATCH_SIZE = 1024     # previously 512
EPOCHS = 4
# NOTE(review): GRAD_CLIP and WEIGHT_DECAY are not passed to distill_train in the
# training cell — confirm whether they take effect or distill_train uses its own defaults.
GRAD_CLIP = 5.0
WEIGHT_DECAY = 5e-6   # previously 1e-5

In [6]:
# Load both teachers with identical options (epsilon=0.0, guard_prob=0.0);
# A is loaded first, then B.
_teacher_opts = {"device": device, "epsilon": 0.0, "guard_prob": 0.0}
teacherA, teacherB = [
    load_teacher(ckpt_path, **_teacher_opts)
    for ckpt_path in (TEACHER_A_CKPT, TEACHER_B_CKPT)
]

[load] missing=0, unexpected=0
[load] matched params: 14/14
[load] missing=0, unexpected=0
[load] matched params: 14/14


In [7]:
# Restart the stopwatch so the elapsed time reported in the final cell covers
# state collection, distillation and saving — not imports or teacher loading.
_retrain_t0 = time.time()
begin_start_time = _retrain_t0

In [8]:
# Gather TOTAL_STATES positions, divided among the SPLIT buckets.
X_states = collect_states(
    TOTAL_STATES,
    SPLIT,
    device=None,              # NOTE(review): presumably leaves states on CPU — confirm in distill_helpers
    shuffle=True,
    max_plies_per_game=None,  # no explicit per-game ply cap passed
    enable_cache=True,
)

[collect] target_counts={'Random': 12000, 'Lookahead-1': 24000, 'Lookahead-2': 12000, 'Lookahead-3': 32000} (sum=80000)


Buckets:   0%|          | 0/4 [00:00<?, ?it/s]

Random states:   0%|          | 0/12000 [00:00<?, ?it/s]

Lookahead-1 states:   0%|          | 0/24000 [00:00<?, ?it/s]

Lookahead-2 states:   0%|          | 0/12000 [00:00<?, ?it/s]

Lookahead-3 states:   0%|          | 0/32000 [00:00<?, ?it/s]

[collect] Collected 80,000 states.  [cache=81]


In [9]:
# Build the student (init_from="A" — presumably initialised from teacher A's
# weights; confirm in distill_helpers) and distill it toward the ALPHA-blended
# teacher targets over the collected states.
student = build_student(
    init_from="A",
    teacherA=teacherA,
    teacherB=teacherB,
    device=device,
)
# NOTE(review): GRAD_CLIP, WEIGHT_DECAY and Z_NORM from the config cell are not
# forwarded here — confirm distill_train's defaults are the intended values.
distill_train(
    student, teacherA, teacherB, X_states,
    alpha=ALPHA,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    device=device,
)

Epochs:   0%|          | 0/4 [00:00<?, ?it/s]

epoch 1 / 4:   0%|          | 0/79 [00:00<?, ?it/s]

[distill] epoch 1: avg_loss=0.50198, batches=79


epoch 2 / 4:   0%|          | 0/79 [00:00<?, ?it/s]

[distill] epoch 2: avg_loss=0.05088, batches=79


epoch 3 / 4:   0%|          | 0/79 [00:00<?, ?it/s]

[distill] epoch 3: avg_loss=0.03083, batches=79


epoch 4 / 4:   0%|          | 0/79 [00:00<?, ?it/s]

[distill] epoch 4: avg_loss=0.02370, batches=79
[distill] training complete.


In [10]:
# Sanity-check the distilled student on a subsample of the collected states
# (reports mse / huber / agreement / illegal-move preference), then persist
# the student's weights to HYBRID.
SANITY_SAMPLES = 10_000   # upper bound on states used for the sanity check
san = sanity_eval(
    student, teacherA, teacherB, X_states,
    alpha=ALPHA,
    z_norm=Z_NORM,        # honour the config flag instead of a hard-coded True
    samples=min(SANITY_SAMPLES, X_states.shape[0]),
    batch_size=BATCH_SIZE,
    device=device,
)

save_student(student, HYBRID)

sanity batches:   0%|          | 0/10 [00:00<?, ?it/s]

[sanity] mse=0.194215 huber=0.088547 agree=0.694 illegal_pref=0.0060 (N=10000)
[save] student weights -> DISTILLED MIXED III DQN model.pt


In [11]:
# Wall-clock time since begin_start_time was last set, reported in hours.
total_end_time = time.time()
_elapsed_seconds = total_end_time - begin_start_time
total_elapsed = _elapsed_seconds / 3600
print(f"Retrain performed in {total_elapsed:.1f} hours")

Retrain performed in 0.3 hours
