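# Merging models

Distill two teacher DQN checkpoints (A and B) into a single student network by training it on a blend of the teachers' Q-values, then save the student as the hybrid model.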

In [1]:
import os, sys, math, time, random, pathlib
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from DQN.dqn_agent import DQNAgent
from DQN.dqn_model import DQN
from C4.connect4_env import Connect4Env
from C4.fast_connect4_lookahead import Connect4Lookahead
from DQN.distill_helpers import *

# Resolve the compute device once; every later cell passes `device` explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment report, so the saved notebook records what it was run on.
print("All dependencies imported successfully.")
print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)

if device.type == "cuda":
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    print("CUDA not available. Using CPU.")

All dependencies imported successfully.
torch version: 2.5.1
CUDA available: True
CUDA version: 11.8
GPU name: NVIDIA GeForce RTX 4090


In [2]:
# Record when this run started and print it as a human-readable local timestamp.
begin_start_time = time.time()
time_str = time.strftime(
    '%Y-%m-%d %H-%M-%S',
    time.localtime(begin_start_time),
)
print(time_str)

2025-10-30 16-05-18


In [3]:
SEED = 666

# Exported for child processes; the hash seed of the already-running
# interpreter was fixed at startup and is not changed by this assignment.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Python and NumPy global RNGs.
random.seed(SEED)
np.random.seed(SEED)

# PyTorch RNGs: CPU generator, current CUDA device, then all CUDA devices.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and turn off its autotuner, which can
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

In [4]:
# Checkpoint files of the two teacher networks that are loaded with load_teacher.
TEACHER_A_CKPT = "RND_3 DQN model.pt"     # path to teacher A
TEACHER_B_CKPT = "RND_1 DQN model.pt"     # path to teacher B

# Output file that save_student writes the distilled student weights to.
HYBRID =         "DIST_IV DQN model.pt"

In [5]:
# Blending of the two teachers.
ALPHA = 0.65          # A weight in blended target: Q* = α QA + (1-α) QB
# NOTE(review): the training call in this notebook does not pass this flag
# explicitly — confirm it actually reaches distill_train.
Z_NORM = True         # per-state z-normalize teacher Q's before blending

# State collection
TOTAL_STATES = 150_000

# Passed to collect_states_cached together with TOTAL_STATES; the values sum to 1.0.
# Presumably the share of states gathered against each opponent type — TODO confirm
# against collect_states_cached.
SPLIT = {           
    "Random":       0.10,
    "Lookahead-1":  0.30,
    "Lookahead-2":  0.30,
    "Lookahead-3":  0.30,
}

# Training
# The trailing "#..." values look like previously tried settings.
LR = 2e-4 #1.0e-4
BATCH_SIZE = 1024 #1024
EPOCHS = 4
# NOTE(review): GRAD_CLIP and WEIGHT_DECAY are not passed explicitly to distill_train
# in this notebook — verify whether the helper uses its own defaults instead.
GRAD_CLIP = 5.0
WEIGHT_DECAY = 5e-6 #1e-5

In [6]:
# Load both teachers with identical settings (epsilon and guard_prob set to 0.0),
# teacher A first, then teacher B.
teacherA, teacherB = (
    load_teacher(ckpt_path, device=device, epsilon=0.0, guard_prob=0.0)
    for ckpt_path in (TEACHER_A_CKPT, TEACHER_B_CKPT)
)

In [7]:
# Restart the wall-clock timer. This overwrites the start time recorded when the
# run timestamp was printed, so the elapsed-time report at the end of the notebook
# counts from here and excludes the teacher loading above.
begin_start_time = time.time()

In [8]:
# Obtain the training states (served from ./state_cache when caching is enabled
# and a matching cache file exists). Only X_states is used by the later cells
# visible in this notebook; mask_all and meta are kept for inspection.
X_states, mask_all, meta = collect_states_cached(
    TOTAL_STATES,
    SPLIT,
    enable_cache=True,
    cache_path="./state_cache",
    overwrite=False,                 # set True to rebuild
    shuffle=True,
    max_plies_per_game=None,
    # presumably the opening / mid-game / end-game proportions — TODO confirm
    phase_split=(0.35, 0.45, 0.20),
)

[cache] loaded 150,000 states from state_cache\states_150000_fb39ef0f14.pt


In [9]:
# Create the student network. init_from="A" presumably starts it from
# teacher A's weights — TODO confirm in build_student.
student = build_student(
    init_from="A",
    teacherA=teacherA,
    teacherB=teacherB,
    device=device,
)

# Train the student on the collected states against both teachers.
# NOTE(review): Z_NORM, GRAD_CLIP and WEIGHT_DECAY from the config cell are not
# passed here, so distill_train uses whatever defaults it defines for them
# (if any) — confirm that this is intended.
distill_train(
    student, teacherA, teacherB, X_states,
    alpha=ALPHA,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    device=device,
)

Epochs:   0%|          | 0/4 [00:00<?, ?it/s]

epoch 1 / 4:   0%|          | 0/147 [00:00<?, ?it/s]

[distill] epoch 1: avg_loss=0.08872, batches=147


epoch 2 / 4:   0%|          | 0/147 [00:00<?, ?it/s]

[distill] epoch 2: avg_loss=0.02189, batches=147


epoch 3 / 4:   0%|          | 0/147 [00:00<?, ?it/s]

[distill] epoch 3: avg_loss=0.01656, batches=147


epoch 4 / 4:   0%|          | 0/147 [00:00<?, ?it/s]

[distill] epoch 4: avg_loss=0.01426, batches=147
[distill] training complete.


In [10]:
# Check the trained student against the two teachers on at most 10,000 of the
# collected states, using the same ALPHA as training.
san = sanity_eval(
    student, teacherA, teacherB, X_states,
    alpha=ALPHA,
    z_norm=Z_NORM,  # follow the config flag instead of a hard-coded True
    samples=min(10_000, X_states.shape[0]),
    batch_size=BATCH_SIZE,
    device=device,
)

# Persist the student weights to the HYBRID path.
save_student(student, HYBRID)

sanity batches:   0%|          | 0/10 [00:00<?, ?it/s]

[sanity] mse=0.174126 huber=0.086249 agree=0.724 illegal_pref=0.0158 (N=10000)
[save] student weights -> DIST_IV DQN model.pt


In [11]:
# Report wall-clock time since begin_start_time was last set.
total_end_time = time.time()
total_elapsed_s = total_end_time - begin_start_time
total_elapsed = total_elapsed_s / 3600  # hours, same meaning as before
# Hours with one decimal printed "0.0" for a run of a few minutes, so report
# minutes as the primary unit and keep hours alongside it.
print(f"Retrain performed in {total_elapsed_s / 60:.1f} minutes ({total_elapsed:.2f} hours)")

Retrain performed in 0.0 hours
