# PPO Training v2 - Diverse Opponent Pool

> 本 Notebook 为 v2 版本，在不修改原始 `ppo_frozen_opponents.ipynb` 的前提下，增加“多样化对手池”机制，提升 PPO 的泛化与鲁棒性。


## 0. 环境准备
- 克隆仓库，进入根目录
- 安装必要依赖（kaggle 环境通常自带 torch，如缺少可按需安装）


In [None]:
# Clone the project repository and make it the working directory.
# NOTE(review): re-running this cell when the directory already exists makes `git clone` fail;
# the `%cd` still succeeds, so the existing checkout is used.
!git clone https://github.com/mogoo7zn/Kaggle-ConnectX.git
%cd Kaggle-ConnectX
# Install kaggle-environments quietly (torch is usually preinstalled on Kaggle).
!pip install -q kaggle-environments


## 1. 导入与路径设置


In [None]:
import os, sys, random, importlib
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F

# Locate the repository root: the first candidate directory that contains an
# `agents/` package. If none matches, fall back to the current directory; the
# project imports below will then fail loudly rather than silently.
candidates = [Path('/kaggle/working/Kaggle-ConnectX'), Path.cwd(), Path.cwd().parent]
repo_root = next((c for c in candidates if (c / 'agents').exists()), Path('.').resolve())
os.chdir(repo_root)
# Put the repo root first on sys.path so `agents.*` resolves to this checkout.
sys.path.insert(0, str(repo_root))
importlib.invalidate_caches()
print('Repo root:', repo_root)

from agents.ppo.ppo_agent import PPOAgent
from agents.ppo.ppo_config import ppo_config
from agents.base.utils import get_valid_moves, get_negamax_move, encode_state, make_move, is_terminal
from agents.dqn.dqn_agent import DQNAgent

# AlphaZero loader is optional: any import failure only disables the AlphaZero
# opponent, but the reason is printed so a broken module is not hidden.
try:
    from agents.alphazero.az_model import create_alphazero_model
except Exception as exc:
    create_alphazero_model = None
    print('AlphaZero loader unavailable:', exc)

# Device used for all tensors and models in this notebook.
DEVICE = ppo_config.DEVICE
print('Using device:', DEVICE)


## 2. 加载冻结对手（DQN / AlphaZero，可选）
支持从已下载的 checkpoint 加载。


In [None]:
from pathlib import Path

ckpt_dir = Path('/kaggle/working/checkpoints')
ckpt_dir.mkdir(exist_ok=True, parents=True)

# Paths to pre-downloaded opponent weights (if the files are already in the
# working directory, just point these at their filenames).
CKPT_DQN = ckpt_dir / 'dqn_frozen.pth'
CKPT_AZ = ckpt_dir / 'alphazero_frozen.pth'

# Frozen DQN opponent; stays None when no checkpoint is present.
frozen_dqn = None
if CKPT_DQN.exists():
    frozen_dqn = DQNAgent()
    frozen_dqn.load_model(str(CKPT_DQN))
    print('Loaded frozen DQN:', CKPT_DQN)
else:
    print('No DQN checkpoint found:', CKPT_DQN)

# Frozen AlphaZero opponent; stays None when there is no checkpoint, no loader,
# or loading fails.
frozen_az = None
if CKPT_AZ.exists() and create_alphazero_model is not None:
    try:
        frozen_az = create_alphazero_model('full')  # use 'light' if the weights are for the light variant
        state = torch.load(CKPT_AZ, map_location=DEVICE)
        if isinstance(state, dict) and 'model_state_dict' in state:
            state = state['model_state_dict']
        # strict=False tolerates key mismatches; report them so loading the wrong
        # variant ('full' vs 'light') does not silently leave layers at their
        # initial values.
        load_result = frozen_az.load_state_dict(state, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print('AlphaZero checkpoint key mismatch - missing:', len(load_result.missing_keys),
                  'unexpected:', len(load_result.unexpected_keys))
        # Inference inputs are moved to DEVICE, so the weights must live there too
        # (a no-op if the factory already placed the model on DEVICE).
        frozen_az.to(DEVICE)
        frozen_az.eval()
        print('Loaded frozen AlphaZero:', CKPT_AZ)
    except Exception as e:
        print('Failed to load AlphaZero:', e)
        frozen_az = None
else:
    print('No AlphaZero checkpoint found or loader unavailable:', CKPT_AZ)



## 3. 对手定义（模块化）
包含随机、negamax、冻结 DQN、冻结 AlphaZero（均需已加载）、当前 PPO（自博弈，随训练更新）以及 PPO 历史快照（冻结）；各类对手按 `OpponentPool` 中的权重采样。


In [None]:
def random_policy(board, mark):
    """Pick a uniformly random legal column; `mark` is unused.

    Returns column 0 when `get_valid_moves` reports no legal move.
    """
    legal_columns = get_valid_moves(board)
    if not legal_columns:
        return 0
    return random.choice(legal_columns)

def negamax_policy(board, mark, depth=4):
    """Return the move chosen by the project's `get_negamax_move` search at `depth`."""
    chosen_column = get_negamax_move(board, mark, depth=depth)
    return chosen_column

def frozen_dqn_policy(board, mark):
    """Return the frozen DQN opponent's move, queried with `epsilon=0.0`.

    Raises:
        RuntimeError: if no DQN checkpoint was loaded (`frozen_dqn` is None).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if frozen_dqn is None:
        raise RuntimeError("frozen_dqn not loaded")
    return frozen_dqn.select_action(board, mark, epsilon=0.0)

def frozen_alphazero_policy(board, mark):
    """Return the legal column with the highest policy probability from the frozen AlphaZero net.

    The network is called as `logits, _ = frozen_az(x)`; only the first output is
    used as policy logits. Illegal columns are excluded before the argmax.
    Returns column 0 when there is no legal move.

    Raises:
        RuntimeError: if the AlphaZero checkpoint was not loaded (`frozen_az` is None).
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if frozen_az is None:
        raise RuntimeError("frozen_az not loaded")
    valid = get_valid_moves(board)
    if not valid:
        return 0
    state_t = torch.from_numpy(encode_state(board, mark)).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits, _ = frozen_az(state_t)
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
    # Illegal columns get -inf so the argmax always lands on a legal column.
    masked = np.full_like(probs, -np.inf)
    masked[valid] = probs[valid]
    return int(masked.argmax())

# Policy of the PPO agent being trained (tracks its weights as they update).
def ppo_policy_live(agent):
    """Wrap `agent` as a greedy `(board, mark) -> column` policy.

    `agent.model` is read on every call, so the returned policy always uses the
    agent's latest weights (this is the self-play opponent).
    """
    def _policy(board, mark):
        batch = torch.from_numpy(encode_state(board, mark)).float()
        batch = batch.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits, _value = agent.model(batch)
            legal = get_valid_moves(board)
            # -inf on illegal columns so softmax assigns them zero probability.
            illegal_penalty = torch.full_like(logits, float('-inf'))
            illegal_penalty[0, legal] = 0
            probs = F.softmax(logits + illegal_penalty, dim=-1)[0].cpu().numpy()
        return int(probs.argmax())

    return _policy

# Historical PPO checkpoint policy (frozen).
def ppo_policy_from_state_dict(state_dict):
    """Build a frozen greedy `(board, mark) -> column` policy from a PPO `state_dict`.

    The weights are copied into a freshly created model when this function is
    called, so the returned policy is unaffected by later training of the source
    model.
    """
    from agents.ppo.ppo_model import make_model
    frozen_model = make_model()
    frozen_model.load_state_dict(state_dict)
    # Inputs are moved to DEVICE below, so the weights must be there too
    # (a no-op if make_model() already placed the model on DEVICE).
    frozen_model.to(DEVICE)
    frozen_model.eval()

    def _policy(board, mark):
        state = encode_state(board, mark)
        state_t = torch.from_numpy(state).float().unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits, _ = frozen_model(state_t)
            valid = get_valid_moves(board)
            # -inf on illegal columns so softmax assigns them zero probability.
            mask = torch.full_like(logits, float('-inf'))
            mask[0, valid] = 0
            probs = F.softmax(logits + mask, dim=-1)[0].cpu().numpy()
        return int(probs.argmax())

    return _policy



## 4. 对手池组件（OpponentPool）
- 支持按概率采样对手
- 支持定期加入 PPO 历史 checkpoint（冻结）
- 仅 PPO 主体更新，其他对手冻结


In [None]:
class OpponentPool:
    """Probability-weighted pool of opponents for PPO training.

    Categories and default (unnormalised) weights: random 0.10, negamax 0.20,
    frozen_dqn 0.20 and frozen_az 0.20 (0 when not loaded), ppo_self 0.20 (the
    live agent, i.e. self-play) and ppo_history 0.10. The `ppo_history` weight is
    split equally across all stored snapshots. Weights are normalised over the
    opponents that are actually available. Only the PPO agent is trained; every
    other opponent is frozen.
    """

    def __init__(self, base_prob=None, snapshot_limit=5, live_agent=None):
        """
        Args:
            base_prob: optional mapping of category name -> unnormalised weight.
                A per-snapshot key `ppo_hist_<i>` overrides that snapshot's share.
            snapshot_limit: maximum number of PPO snapshots kept (oldest dropped first).
            live_agent: agent whose `model` backs the `ppo_self` opponent; defaults
                to the notebook-level `agent`.
        """
        self.base_prob = base_prob or {
            'random': 0.10,
            'negamax': 0.20,
            'frozen_dqn': 0.20 if frozen_dqn is not None else 0.0,
            'frozen_az': 0.20 if frozen_az is not None else 0.0,
            'ppo_self': 0.20,      # current PPO (self-play)
            'ppo_history': 0.10,   # total weight shared by all historical snapshots
        }
        self.snapshots = []  # detached state_dict copies, oldest first
        self.snapshot_limit = snapshot_limit
        self.live_agent = live_agent
        self._rebuild()

    def _rebuild(self):
        """Recompute `self.entries` as `(prob, name, policy_fn)` tuples whose probs sum to 1."""
        entries = []

        def add(name, fn, default=0.0):
            weight = self.base_prob.get(name, default)
            if weight > 0:
                entries.append((weight, name, fn))

        add('random', random_policy)
        add('negamax', lambda b, m: negamax_policy(b, m, depth=4))
        if frozen_dqn is not None:
            add('frozen_dqn', frozen_dqn_policy)
        if frozen_az is not None:
            add('frozen_az', frozen_alphazero_policy)
        live = self.live_agent if self.live_agent is not None else agent
        add('ppo_self', ppo_policy_live(live))

        # Snapshot names (`ppo_hist_<i>`) are not keys of the default base_prob, so
        # they explicitly share the 'ppo_history' weight; looking them up directly
        # would give every snapshot weight 0 and they would never be sampled.
        if self.snapshots:
            share = self.base_prob.get('ppo_history', 0.0) / len(self.snapshots)
            for i, sd in enumerate(self.snapshots):
                add(f'ppo_hist_{i}', ppo_policy_from_state_dict(sd), default=share)

        total = sum(p for p, _, _ in entries)
        self.entries = [(p / total, n, f) for p, n, f in entries] if total > 0 else []

    def add_snapshot(self, state_dict):
        """Store a deep copy of `state_dict` as a frozen opponent and rebuild the pool.

        `nn.Module.state_dict()` returns references to the live parameters, so
        storing it without copying would let later training updates silently
        overwrite the "historical" snapshot.
        """
        import copy
        self.snapshots.append(copy.deepcopy(state_dict))
        while len(self.snapshots) > self.snapshot_limit:
            self.snapshots.pop(0)
        self._rebuild()

    def sample(self):
        """Draw one opponent according to the normalised weights.

        Returns:
            `(policy_fn, name)`. Falls back to `(random_policy, 'random')` when the
            pool is empty, so callers can always unpack two values.
        """
        if not self.entries:
            return random_policy, 'random'
        probs = [p for p, _, _ in self.entries]
        idx = np.random.choice(len(self.entries), p=probs)
        _, name, fn = self.entries[idx]
        return fn, name

# Create the opponent pool.
def build_pool():
    """Construct a default `OpponentPool` and print each entry's normalised weight."""
    pool = OpponentPool()
    summary = [f'  {name}: {prob:.2f}' for prob, name, _ in pool.entries]
    print('Opponent entries:')
    for line in summary:
        print(line)
    return pool



## 5. 训练循环（加入对手池采样 + 历史快照）
- 每个 update 随机从对手池按概率采样
- 每隔 SNAPSHOT_INTERVAL 将当前 PPO 参数加入历史快照（冻结）
- 只有 PPO 主体更新，其他对手冻结


In [None]:
import copy

# PPO agent being trained; it is the only model whose weights are updated.
agent = PPOAgent()

TOTAL_UPDATES = 200     # raise for longer training
ROLLOUT_STEPS = 512     # raise for larger rollouts
LOG_INTERVAL = 20
SNAPSHOT_INTERVAL = 50  # every N updates, freeze a copy of the PPO weights into the pool

# Built after `agent` exists: the pool's self-play opponent wraps the global `agent`.
opponent_pool = build_pool()
reward_log = []

for update in range(1, TOTAL_UPDATES + 1):
    # Draw a fresh opponent from the pool for every update.
    opp_fn, opp_name = opponent_pool.sample()
    batch = agent.generate_rollout(opp_fn, ROLLOUT_STEPS)
    loss = agent.update(batch)
    reward_log.append(batch.returns.mean().item())

    # Periodically add a frozen PPO snapshot to the opponent pool.
    if update % SNAPSHOT_INTERVAL == 0:
        # state_dict() holds references to the live parameters; deep-copy it so
        # later updates cannot rewrite the "historical" opponent.
        snapshot_sd = copy.deepcopy(agent.model.state_dict())
        opponent_pool.add_snapshot(snapshot_sd)
        print(f"[Snapshot] added at update {update}, snapshots={len(opponent_pool.snapshots)}")

    if update % LOG_INTERVAL == 0:
        avg_ret = float(np.mean(reward_log[-LOG_INTERVAL:]))
        # opp_name is only the opponent sampled for this particular update.
        print(f"Update {update}/{TOTAL_UPDATES} | opp={opp_name} | loss {loss:.3f} | avg_ret {avg_ret:.3f}")

# Save the trained PPO weights.
ppo_path = Path('/kaggle/working/ppo_pool_v2.pth')
torch.save(agent.model.state_dict(), ppo_path)
print("Saved", ppo_path)


## 6. 简单评估（交替先手 + 贪心）
- 交替先后手，避免先手偏置
- 贪心落子，减少随机波动


In [None]:
def play_one(policy_fn, opp_fn, games=20):
    """Return the fraction of `games` won by `policy_fn` against `opp_fn`.

    Seats alternate: `policy_fn` plays mark 1 (moves first) in even-numbered
    games and mark 2 in odd-numbered ones. Draws and losses both count as
    non-wins. Returns 0.0 when `games` is not positive.
    """
    if games <= 0:
        return 0.0
    wins = 0
    for g in range(games):
        board = [0] * (ppo_config.ROWS * ppo_config.COLUMNS)
        ppo_mark = 1 if g % 2 == 0 else 2
        current = 1  # mark 1 always moves first
        while True:
            mover = policy_fn if current == ppo_mark else opp_fn
            action = mover(board, current)
            board = make_move(board, action, current)
            done, winner = is_terminal(board)
            if done:
                if winner == ppo_mark:
                    wins += 1
                break
            current = 3 - current
    return wins / games

def ppo_policy_greedy(board, mark):
    """Greedy move of the global training `agent`, restricted to legal columns."""
    model_input = torch.from_numpy(encode_state(board, mark)).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits, _ = agent.model(model_input)
        legal_cols = get_valid_moves(board)
        # -inf on illegal columns so softmax assigns them zero probability.
        penalty = torch.full_like(logits, float('-inf'))
        penalty[0, legal_cols] = 0
        masked_logits = logits + penalty
        action_probs = F.softmax(masked_logits, dim=-1)[0].cpu().numpy()
    best_col = action_probs.argmax()
    return int(best_col)

# Greedy PPO vs every available opponent, 20 games each with alternating first move.
eval_opponents = [
    ("random", random_policy),
    ("negamax", lambda b, m: negamax_policy(b, m, depth=4)),
]
# Explicit None checks: truthiness of model objects is not a reliable "loaded" test.
if frozen_dqn is not None:
    eval_opponents.append(("frozen DQN", frozen_dqn_policy))
if frozen_az is not None:
    eval_opponents.append(("frozen AlphaZero", frozen_alphazero_policy))
for label, opp in eval_opponents:
    print(f"PPO vs {label}:", play_one(ppo_policy_greedy, opp, games=20))


## 7. 打包提交（与 v1 相同，只是文件名不同）
- 输出：`/kaggle/working/ppo_pool_v2.pth` + 生成 `agent.py` + `ppo_submission_v2.zip`
- 仍可按原流程下载 zip 上传竞赛


In [None]:
from textwrap import dedent
import zipfile
import shutil

# Package the trained weights and a standalone agent.py into a submission zip.
# Relies on `ppo_path` defined by the training cell.
submit_dir = Path('/kaggle/working/ppo_submit_v2')
submit_dir.mkdir(exist_ok=True, parents=True)
model_out = submit_dir / 'ppo_pool_v2.pth'
shutil.copy(ppo_path, model_out)

# Source code of the self-contained submission agent (it cannot import the repo).
# NOTE(review): `ActorCritic` below is a hand-written re-declaration, not the
# training model class. During training the model is called as
# `logits, _ = model(x)` (two outputs), whereas this forward() returns only
# logits and defines no value head. If the saved state_dict has keys or shapes
# that differ from these layers, the strict `_model.load_state_dict(sd)` inside
# agent() will raise at submission time. TODO: verify against the real PPO model
# definition before submitting.
# NOTE(review): `encode_state` below must reproduce the training-time
# agents.base.utils.encode_state exactly (channel order and meaning); TODO confirm.
# NOTE(review): the weights are loaded via a path relative to the process's
# working directory; confirm the Kaggle runner's cwd contains the unpacked files.
# The f-string has no placeholders; any literal `{` or `}` added to the embedded
# code would have to be doubled.
agent_code = dedent(f"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

ROWS, COLS, INAROW = 6, 7, 4
DEVICE = torch.device('cpu')

class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc = nn.Linear(ROWS * COLS * 128, 256)
        self.policy = nn.Linear(256, COLS)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        logits = self.policy(x)
        return logits

def encode_state(board, mark):
    board_2d = np.array(board).reshape(ROWS, COLS)
    opp = 3 - mark
    player = (board_2d == mark).astype(np.float32)
    opponent = (board_2d == opp).astype(np.float32)
    valid = np.zeros((ROWS, COLS), dtype=np.float32)
    for c in range(COLS):
        if board_2d[0, c] == 0:
            valid[:, c] = 1.0
    state = np.stack([player, opponent, valid], axis=0)
    return state

def get_valid_moves(board):
    return [c for c in range(COLS) if board[c] == 0]

def agent(obs, config):
    global _model
    if '_model' not in globals():
        _model = ActorCritic().to(DEVICE)
        sd = torch.load('ppo_pool_v2.pth', map_location=DEVICE)
        _model.load_state_dict(sd)
        _model.eval()
    board = obs.board
    mark = obs.mark
    valid = get_valid_moves(board)
    if not valid:
        return 0
    state = torch.from_numpy(encode_state(board, mark)).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = _model(state)
        mask = torch.full_like(logits, float('-inf'))
        mask[0, valid] = 0
        logits = logits + mask
        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
    return int(probs.argmax())
""")

# Write agent.py next to the copied weights.
with open(submit_dir / 'agent.py', 'w') as f:
    f.write(agent_code)

# Zip both files at the archive root.
zip_path = Path('/kaggle/working/ppo_submission_v2.zip')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    zf.write(submit_dir / 'agent.py', arcname='agent.py')
    zf.write(submit_dir / 'ppo_pool_v2.pth', arcname='ppo_pool_v2.pth')

print('Submission zip created:', zip_path)
!ls -lh /kaggle/working/ppo_submission_v2.zip
