# PPO Training with Frozen DQN / AlphaZero Opponents

在 Kaggle GPU 上运行：下载队友训练好的 DQN / AlphaZero 权重，将其作为冻结对手加入对手池，让 PPO 在训练中与之对战，最后保存 PPO 权重并（可选）上传到 Hugging Face。


## 克隆代码仓库（云端首次运行需要）


In [None]:
# Clone the project repository and switch into it (needed on the first cloud run).
!git clone https://github.com/mogoo7zn/Kaggle-ConnectX.git
%cd Kaggle-ConnectX


## 0. 环境与依赖
- 需 Kaggle GPU 环境
- 安装最小依赖：`kaggle-environments`（对战模拟）、`huggingface_hub`（可选上传）


In [None]:
# Minimal dependencies: kaggle-environments (game simulation) and huggingface_hub (optional upload).
!pip install -q kaggle-environments huggingface_hub


In [None]:
import os
import sys
import random
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
import importlib

# Make the project package importable. Try the usual Kaggle clone location first,
# then /kaggle/working itself, then the current directory and its parent; the first
# one containing an `agents/` directory wins, otherwise use the resolved CWD.
candidates = [
    Path('/kaggle/working/Kaggle-ConnectX'),  # default `git clone` location
    Path('/kaggle/working'),                  # only chosen if it contains `agents/`
    Path.cwd(),
    Path.cwd().parent,
]
repo_root = next((path for path in candidates if (path / 'agents').exists()), None)
if repo_root is None:
    repo_root = Path('.').resolve()

# Run from the repo root and put it first on the import path.
os.chdir(repo_root)
sys.path.insert(0, str(repo_root))
importlib.invalidate_caches()

agents_dir = repo_root / 'agents'
print("Repo root:", repo_root)
print("CWD:", Path.cwd())
print("agents exists:", agents_dir.exists())
print("mcts exists:", (agents_dir / 'alphazero' / 'mcts.py').exists())

from agents.ppo.ppo_agent import PPOAgent
from agents.ppo.ppo_config import ppo_config
from agents.base.utils import get_valid_moves, get_negamax_move, encode_state, make_move, is_terminal
from agents.dqn.dqn_agent import DQNAgent

# Optional AlphaZero model factory (its architecture must match the teammate's
# checkpoint). Stays None when the import fails; the traceback is printed instead.
create_alphazero_model = None
try:
    from agents.alphazero.az_model import create_alphazero_model
except Exception:
    import traceback
    traceback.print_exc()

# Device shared by every model in this notebook.
DEVICE = ppo_config.DEVICE
print("Using device:", DEVICE)


In [None]:
# 可在 notebook 内直接定义/替换模型，避免每次修改后重新 clone
# 调整下方通道/隐藏层即可放大模型规模
import types
import torch.nn as nn
import torch.nn.functional as F
from agents.ppo import ppo_model

class ActorCriticCustom(nn.Module):
    """Two-conv actor-critic network with configurable width.

    Layer attribute names (conv1, conv2, fc, policy, value) are fixed, so a saved
    state_dict loads as long as the same sizes are used.

    Args:
        in_channels: channels of the encoded board input (default 3).
        conv1_channels: output channels of the first 3x3 conv (padding=1).
        conv2_channels: output channels of the second 3x3 conv (padding=1).
        hidden_size: width of the shared fully connected layer.
        rows: board rows; defaults to ppo_config.ROWS.
        columns: board columns (also the number of actions); defaults to
            ppo_config.COLUMNS.
    """

    def __init__(self, in_channels=3, conv1_channels=128, conv2_channels=256,
                 hidden_size=512, rows=None, columns=None):
        super().__init__()
        rows = ppo_config.ROWS if rows is None else rows
        columns = ppo_config.COLUMNS if columns is None else columns
        self.conv1 = nn.Conv2d(in_channels, conv1_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, padding=1)
        # padding=1 keeps the spatial size, so the flattened conv output is rows*columns*C.
        self.fc = nn.Linear(rows * columns * conv2_channels, hidden_size)
        self.policy = nn.Linear(hidden_size, columns)
        self.value = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards.

        Assumes ``x`` is (batch, in_channels, rows, columns) — TODO confirm this
        matches ``encode_state``'s layout.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(1)
        x = F.relu(self.fc(x))
        return self.policy(x), self.value(x)

# Replace the package's model factory with the custom network defined above.
def _make_custom_model():
    """Build an ActorCriticCustom with default sizes on the configured device."""
    return ActorCriticCustom().to(ppo_config.DEVICE)


ppo_model.ActorCritic = ActorCriticCustom
ppo_model.make_model = _make_custom_model

# PPOAgent was imported in an earlier cell. If agents.ppo.ppo_agent bound these names
# itself (e.g. `from agents.ppo.ppo_model import make_model`), patching ppo_model alone
# would not reach it. Rebind any such module-level names there too (no-op otherwise).
from agents.ppo import ppo_agent as _ppo_agent_module

for _name, _replacement in (("ActorCritic", ActorCriticCustom), ("make_model", _make_custom_model)):
    if hasattr(_ppo_agent_module, _name):
        setattr(_ppo_agent_module, _name, _replacement)

print("Using custom ActorCritic with larger capacity:", ActorCriticCustom())


## 1. 下载队友的已训练模型
填写你队友提供的下载链接（例如 Hugging Face / Kaggle Dataset）。


In [None]:
from pathlib import Path
import requests

# Direct "resolve" links into the public Hugging Face repo.
DQN_URL = "https://huggingface.co/mogoo7zn/Kaggle-ConnectX/resolve/main/DQN-base.pth"
AZ_URL  = "https://huggingface.co/mogoo7zn/Kaggle-ConnectX/resolve/main/alpha-zero-medium.pth"
# To use the "high" variant instead, change the line above to alpha-zero-high.pth.

# Local directory for downloaded checkpoints (created if missing).
ckpt_dir = Path('/kaggle/working/checkpoints')
ckpt_dir.mkdir(parents=True, exist_ok=True)

def download(url, out_path, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``out_path``; return the path, or None if ``url`` is empty.

    The response is streamed to a sibling ``.part`` file, which is renamed onto
    ``out_path`` only after the transfer completes. An interrupted or failed download
    therefore never leaves a truncated checkpoint at ``out_path``.

    Args:
        url: HTTP(S) URL. A falsy value ("" or None) skips the download.
        out_path: destination file (str or Path).
        timeout: requests connect/read timeout in seconds. Without a timeout, a
            stalled server can hang the notebook indefinitely.
        chunk_size: bytes per streamed chunk.

    Returns:
        The destination as a ``Path``, or None when skipped.

    Raises:
        requests.HTTPError: on a non-2xx response (the partial file is removed).
    """
    out_path = Path(out_path)
    if not url:
        print(f"[skip] empty url for {out_path.name}")
        return None
    print(f"Downloading {url} -> {out_path}")
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        tmp_path.replace(out_path)
    except BaseException:
        # Drop the partial file so a retry starts clean, then propagate.
        tmp_path.unlink(missing_ok=True)
        raise
    return out_path

# Fetch both checkpoints. download() returns None for an empty URL, and the loading
# cell checks for that before using a checkpoint.
ckpt_dqn = download(DQN_URL, ckpt_dir/'dqn_frozen.pth')
ckpt_az  = download(AZ_URL,  ckpt_dir/'alphazero_frozen.pth')

## 2. 加载冻结对手（DQN / AlphaZero，可选）
如果未提供权重，则跳过对应对手。


In [None]:
# Load the frozen opponents. Each one is optional: a missing checkpoint (or, for
# AlphaZero, an unavailable factory or a failed load) leaves the variable as None
# so the opponent pool can skip it.
frozen_dqn = None
if ckpt_dqn and ckpt_dqn.exists():
    frozen_dqn = DQNAgent()
    frozen_dqn.load_model(str(ckpt_dqn))
    print("Loaded frozen DQN.")
else:
    print("No DQN checkpoint provided.")

frozen_az = None
if ckpt_az and ckpt_az.exists() and create_alphazero_model:
    try:
        # Use 'light' instead if the checkpoint came from the lightweight variant.
        frozen_az = create_alphazero_model('full')
        state = torch.load(ckpt_az, map_location=DEVICE)
        # Accept both a bare state_dict and a training checkpoint wrapping one.
        if isinstance(state, dict) and 'model_state_dict' in state:
            state = state['model_state_dict']
        # strict=False tolerates minor key drift, but it also silently accepts a
        # checkpoint for a different architecture. Report mismatches, and treat
        # "no key matched" as a failed load rather than a randomly initialised model.
        result = frozen_az.load_state_dict(state, strict=False)
        n_model_keys = len(frozen_az.state_dict())
        if result.missing_keys or result.unexpected_keys:
            print(
                f"[warn] AlphaZero checkpoint mismatch: {len(result.missing_keys)}/{n_model_keys} "
                f"model keys missing, {len(result.unexpected_keys)} unexpected keys "
                "(check 'full' vs 'light')."
            )
        if n_model_keys and len(result.missing_keys) == n_model_keys:
            raise RuntimeError("no checkpoint keys matched the AlphaZero model")
        # Put the model on the device that frozen_alphazero_policy sends inputs to.
        frozen_az = frozen_az.to(DEVICE)
        frozen_az.eval()
        print("Loaded frozen AlphaZero model.")
    except Exception as e:
        print("Failed to load AlphaZero model:", e)
        frozen_az = None
else:
    print("No AlphaZero checkpoint provided or loader unavailable.")


## 3. 定义对手池（随机 / negamax / 冻结DQN / 冻结AlphaZero）


In [None]:
from types import SimpleNamespace
from kaggle_environments.envs.connectx.connectx import negamax_agent

def random_policy(board, mark):
    """Return a uniformly random legal column; ``mark`` is unused.

    Falls back to column 0 when ``get_valid_moves`` reports no legal move.
    """
    legal_moves = get_valid_moves(board)
    if not legal_moves:
        return 0
    return random.choice(legal_moves)

def negamax_simple_policy(board, mark, depth=4):
    """Return the move chosen by the project's ``get_negamax_move`` helper.

    The original note describes it as a simplified search: take an immediate win or
    block, otherwise prefer the centre. NOTE(review): confirm how ``depth`` is used
    in agents.base.utils.
    """
    chosen_column = get_negamax_move(board, mark, depth=depth)
    return chosen_column

def negamax_kaggle_policy(board, mark, depth=4):
    """Return the move chosen by kaggle-environments' built-in ``negamax_agent``.

    The agent is called with a minimal observation/configuration built from
    ``ppo_config``. ``depth`` is forwarded as ``configuration.depth``.
    NOTE(review): the stock agent appears to use its own fixed search depth, so this
    value may be ignored. Confirm against the installed kaggle-environments version.
    """
    # Pass a fresh list. The kaggle agent appears to copy boards by slicing
    # (``board[:]``), which for a NumPy array is a view, so a NumPy board could be
    # modified during its search. A list copy leaves the caller's board untouched.
    obs = SimpleNamespace(board=list(board), mark=mark)
    cfg = SimpleNamespace(rows=ppo_config.ROWS, columns=ppo_config.COLUMNS, inarow=ppo_config.INAROW,
                          timeout=1, actTimeout=1, depth=depth)
    return int(negamax_agent(obs, cfg))

def frozen_dqn_policy(board, mark):
    """Return the frozen DQN agent's greedy move (exploration disabled).

    Guarded by an assert that the DQN checkpoint was actually loaded.
    """
    assert frozen_dqn is not None, "frozen_dqn not loaded; otherwise fallback would be random"
    greedy_epsilon = 0.0
    return frozen_dqn.select_action(board, mark, epsilon=greedy_epsilon)

def frozen_alphazero_policy(board, mark):
    """Return the greedy move from the frozen AlphaZero policy head (no MCTS).

    Illegal columns are masked out before the argmax. The encoded board is sent to
    the model's own device, so this works whether or not the model was moved to
    ``DEVICE``.

    NOTE(review): this reuses the project's ``encode_state`` encoding and assumes the
    model returns ``(policy_logits, value)``. Confirm both match the AlphaZero
    training code.
    """
    assert frozen_az is not None, "frozen_az not loaded; otherwise fallback would be random"
    model_device = next(frozen_az.parameters()).device
    state = encode_state(board, mark)
    state_t = torch.from_numpy(state).float().unsqueeze(0).to(model_device)
    with torch.no_grad():
        logits, _ = frozen_az(state_t)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
    valid = get_valid_moves(board)
    # Probabilities are >= 0, so -1e9 guarantees an illegal column never wins the argmax.
    masked = np.full_like(probs, -1e9)
    masked[valid] = probs[valid]
    return int(masked.argmax())

# Sampling mix: negamax-style opponents together get 50%, random play 25%, and the
# frozen teammate models share the remaining 25%. A missing frozen model's share goes
# to the other frozen model, or to random if neither loaded. This keeps the negamax
# share at 50% instead of letting it drift upward after normalisation.
NEGAMAX_SHARE, RANDOM_SHARE, FROZEN_SHARE = 0.5, 0.25, 0.25

frozen_candidates = []
if frozen_dqn is not None:
    frozen_candidates.append(("frozen_dqn", frozen_dqn_policy))
if frozen_az is not None:
    frozen_candidates.append(("frozen_az", frozen_alphazero_policy))

random_share = RANDOM_SHARE if frozen_candidates else RANDOM_SHARE + FROZEN_SHARE
opponent_candidates = [
    ("negamax_simple", negamax_simple_policy, NEGAMAX_SHARE / 2),
    ("negamax_kaggle", negamax_kaggle_policy, NEGAMAX_SHARE / 2),
    ("random", random_policy, random_share),
]
# Empty when no frozen model loaded, so the division below is never evaluated then.
opponent_candidates += [
    (name, fn, FROZEN_SHARE / len(frozen_candidates)) for name, fn in frozen_candidates
]

# Normalise defensively so the weights sum to exactly 1.
opponent_fns = [fn for _, fn, _ in opponent_candidates]
opponent_weights = np.array([w for _, _, w in opponent_candidates], dtype=float)
opponent_weights = opponent_weights / opponent_weights.sum()

opponent_pool = [name for name, _, _ in opponent_candidates]
print("Opponent pool:", opponent_pool)
print("Weights:", opponent_weights)


## 4. 训练循环（小规模示例）
- 从对手池随机采样对手
- 收集 rollout（玩家1视角）
- PPO 更新


In [None]:
def _unpack_metrics(metrics):
    """Normalise agent.update()'s result to (loss, policy, value, entropy, kl, clip).

    agent.update may return a bare float (older API) or a dict of metrics. For a
    dict, the loss falls back to "total_loss" and then 0.0, and the other entries
    fall back to NaN. For a float, everything except the loss is NaN.
    """
    nan = float("nan")
    if not isinstance(metrics, dict):
        return float(metrics), nan, nan, nan, nan, nan
    return (
        metrics.get("loss", metrics.get("total_loss", 0.0)),
        metrics.get("policy_loss", nan),
        metrics.get("value_loss", nan),
        metrics.get("entropy", nan),
        metrics.get("approx_kl", nan),
        metrics.get("clip_frac", nan),
    )


agent = PPOAgent()

TOTAL_UPDATES = 100   # increase for a stronger agent
ROLLOUT_STEPS = 512   # increase for larger rollouts
LOG_INTERVAL = 10

reward_log = []

for update in range(1, TOTAL_UPDATES + 1):
    # One opponent per update, drawn with the pool's sampling weights.
    (opp_fn,) = random.choices(opponent_fns, weights=opponent_weights, k=1)
    batch = agent.generate_rollout(opp_fn, ROLLOUT_STEPS)
    metrics = agent.update(batch)
    reward_log.append(batch.returns.mean().item())

    if update % LOG_INTERVAL != 0:
        continue

    # Mean of the per-update average returns over the last LOG_INTERVAL updates.
    avg_ret = float(np.mean(reward_log[-LOG_INTERVAL:]))
    loss_val, policy, value, ent, kl, clip = _unpack_metrics(metrics)
    print(
        f"Update {update}/{TOTAL_UPDATES} "
        f"| loss {loss_val:.3f} "
        f"| policy {policy:.3f} "
        f"| value {value:.3f} "
        f"| ent {ent:.3f} "
        f"| kl {kl:.4f} "
        f"| clip {clip:.3f} "
        f"| avg_ret {avg_ret:.3f}"
    )

# Persist the trained network weights.
ppo_path = Path('/kaggle/working/ppo_frozen_mix.pth')
torch.save(agent.model.state_dict(), ppo_path)
print("Saved", ppo_path)


## 5. 简单评估（可选）
用少量对局做冒烟测试（smoke test），以便及早发现训练是否发散。


In [None]:
def play_one(policy_fn, opp_fn, games=20):
    """Play a series of games and return ``policy_fn``'s win rate in [0, 1].

    Despite the name, ``games`` games are played, and who moves first alternates:
    - Even game index: ``policy_fn`` plays mark 1 and moves first.
    - Odd game index: ``opp_fn`` moves first and ``policy_fn`` plays mark 2.

    Draws and losses both count as non-wins.

    A column not in ``get_valid_moves(board)`` ends the game as a loss for the
    player who chose it, as in Kaggle's ConnectX environment. This also means a
    policy that keeps picking a full column cannot stall the loop.

    Args:
        policy_fn: callable ``(board, mark) -> column`` being evaluated.
        opp_fn: opponent callable with the same signature.
        games: number of games to play; must be >= 1.

    Raises:
        ValueError: if ``games`` < 1 (a win rate over zero games is undefined).
    """
    if games < 1:
        raise ValueError(f"games must be >= 1, got {games}")
    wins = 0
    n_cells = ppo_config.ROWS * ppo_config.COLUMNS
    for g in range(games):
        board = [0] * n_cells
        ppo_mark = 1 if g % 2 == 0 else 2
        current = 1  # mark 1 always moves first
        while True:
            mover = policy_fn if current == ppo_mark else opp_fn
            action = mover(board, current)
            if action not in get_valid_moves(board):
                # Illegal move: the mover forfeits.
                if current != ppo_mark:
                    wins += 1
                break
            board = make_move(board, action, current)
            done, winner = is_terminal(board)
            if done:
                if winner == ppo_mark:
                    wins += 1
                break
            current = 3 - current
    return wins / games

def ppo_policy(board, mark):
    """Sample a move from the trained PPO policy, restricted to legal columns.

    Illegal columns get an additive -inf penalty before the softmax, so they have
    zero probability. The draw uses NumPy's global RNG (sampling, not argmax).
    """
    encoded = torch.from_numpy(encode_state(board, mark)).float()
    model_input = encoded.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        raw_logits, _ = agent.model(model_input)
        legal = get_valid_moves(board)
        penalty = torch.full_like(raw_logits, float('-inf'))
        penalty[0, legal] = 0
        distribution = F.softmax(raw_logits + penalty, dim=-1)[0].cpu().numpy()
    return int(np.random.choice(distribution.shape[0], p=distribution))

# Smoke-test the trained policy against each available opponent (20 games each).
# Frozen opponents are included only when they were loaded (explicit None checks,
# matching how the loading cell marks a missing opponent).
matchups = [
    ("PPO vs random:", random_policy),
    ("PPO vs negamax_simple:", lambda b, m: negamax_simple_policy(b, m, depth=4)),
    ("PPO vs negamax_kaggle:", lambda b, m: negamax_kaggle_policy(b, m, depth=4)),
]
if frozen_dqn is not None:
    matchups.append(("PPO vs frozen DQN:", frozen_dqn_policy))
if frozen_az is not None:
    matchups.append(("PPO vs frozen AlphaZero:", frozen_alphazero_policy))

for label, opponent in matchups:
    print(label, play_one(ppo_policy, opponent, games=20))


## 7. 打包提交到 Kaggle（已停用）
- 原计划：使用训练好的 `ppo_frozen_mix.pth`，生成最小可提交的 `agent.py`，一起打包成 zip 上传到竞赛
- 当前下方单元格不再生成提交包，只打印提示；请改用 Hugging Face 上传单元格


In [None]:
# Kaggle competition submission packaging is no longer generated here; the trained
# .pth weights are uploaded to Hugging Face by the upload cell instead.
print("Skip Kaggle submission packaging; use HF upload cell below.")


## 6. 上传到 Hugging Face（可选）
需提供 Hugging Face 访问令牌：设置环境变量 `HF_TOKEN`（huggingface_hub 默认读取该变量），或先运行 `huggingface_hub.notebook_login()` 登录。


In [None]:
from huggingface_hub import HfApi, create_repo

# Authentication: use an explicit token if one is set in the environment. On its own,
# huggingface_hub reads HF_TOKEN (and the legacy HUGGING_FACE_HUB_TOKEN) or a cached
# login. It does not read HUGGINGFACE_HUB_TOKEN, so that name is checked here as well.
# None means "fall back to huggingface_hub's own token resolution".
hf_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
)

to_upload = ppo_path  # upload the trained .pth weights directly
hub_model_id = "your-username/connectx-ppo-frozen"  # replace with your repo id
if hub_model_id.startswith("your-username/"):
    raise ValueError("Set hub_model_id to your own Hugging Face repo id before uploading.")

create_repo(hub_model_id, exist_ok=True, token=hf_token)

api = HfApi(token=hf_token)
api.upload_file(
    path_or_fileobj=str(to_upload),
    path_in_repo="ppo_frozen_mix.pth",
    repo_id=hub_model_id,
)
print("Uploaded .pth to HF repo:", hub_model_id)


## 8. 从 Hugging Face 拉取刚上传的权重并做快速验证


In [None]:
from huggingface_hub import hf_hub_download

# Replace with the Hugging Face repo id you actually uploaded to.
hf_repo_id = "your-username/connectx-ppo-frozen"
if hf_repo_id.startswith("your-username/"):
    raise ValueError("Set hf_repo_id to the Hugging Face repo you uploaded to.")

# Download the weights into the local Hugging Face cache.
hf_ckpt = hf_hub_download(repo_id=hf_repo_id, filename="ppo_frozen_mix.pth")
print("Downloaded from HF:", hf_ckpt)

# Load into a fresh PPOAgent for a smoke test. The file is a plain state_dict, so
# weights_only=True refuses arbitrary pickled objects from the downloaded file. This
# is also torch.load's default from PyTorch 2.6 on.
agent_hf = PPOAgent()
agent_hf.model.load_state_dict(torch.load(hf_ckpt, map_location=DEVICE, weights_only=True))
agent_hf.model.eval()


def ppo_policy_hf(board, mark):
    """Sample a legal move from the PPO model re-loaded from Hugging Face.

    Same procedure as the in-memory evaluation policy, but uses ``agent_hf.model``.
    Illegal columns receive an additive -inf penalty before the softmax, and the
    move is drawn with NumPy's global RNG.
    """
    encoded = torch.from_numpy(encode_state(board, mark)).float()
    model_input = encoded.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        raw_logits, _ = agent_hf.model(model_input)
        legal = get_valid_moves(board)
        penalty = torch.full_like(raw_logits, float('-inf'))
        penalty[0, legal] = 0
        distribution = F.softmax(raw_logits + penalty, dim=-1)[0].cpu().numpy()
    return int(np.random.choice(distribution.shape[0], p=distribution))

# Quick sanity check of the re-downloaded weights (10 games per opponent).
for label, opponent in (
    ("HF model vs random:", random_policy),
    ("HF model vs negamax_simple:", lambda b, m: negamax_simple_policy(b, m, depth=4)),
):
    print(label, play_one(ppo_policy_hf, opponent, games=10))
