# PPO Training with Adaptive Opponent Sampling (win-rate prioritized)

基于 `ppo_frozen_opponents.ipynb` 的变体：对手池按实时胜率动态加权，优先挑战当前胜率较低的对手，其余流程保持一致。


## 0. 环境与依赖
- 需 Kaggle GPU 环境
- 依赖：`kaggle-environments`, `huggingface_hub`（由下方 `%pip` 安装）；另外还会用到需由运行环境自带的 `torch`、`numpy`、`requests`



In [None]:
!git clone https://github.com/mogoo7zn/Kaggle-ConnectX.git
%cd Kaggle-ConnectX


In [None]:
%pip install -q kaggle-environments huggingface_hub


In [None]:
import os
import sys
import random
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
import importlib

# Make the project packages importable: probe the usual Kaggle locations
# first, then the current directory and its parent. The first candidate that
# contains an `agents` directory wins.
candidates = [
    Path('/kaggle/working/Kaggle-ConnectX'),
    Path('/kaggle/working'),
    Path.cwd(),
    Path.cwd().parent,
]
repo_root = next(
    (candidate for candidate in candidates if (candidate / 'agents').exists()),
    None,
)
if repo_root is None:
    # Nothing matched: fall back to the resolved current directory.
    repo_root = Path('.').resolve()

# Work from the repo root and put it first on the import path.
os.chdir(repo_root)
sys.path.insert(0, str(repo_root))
importlib.invalidate_caches()
print("Repo root:", repo_root)
print("CWD:", Path.cwd())

from agents.ppo.ppo_agent import PPOAgent
from agents.ppo.ppo_config import ppo_config
from agents.base.utils import get_valid_moves, get_negamax_move, encode_state, make_move, is_terminal
from agents.dqn.dqn_agent import DQNAgent

# Optional: AlphaZero loader (the architecture must match your teammate's
# checkpoint weights). If the import fails for any reason, the traceback is
# printed and the factory is set to None so later cells can skip AlphaZero.
try:
    from agents.alphazero.az_model import create_alphazero_model
except Exception:
    import traceback
    traceback.print_exc()
    create_alphazero_model = None

# Device used for every tensor/model in this notebook, taken from the PPO config.
DEVICE = ppo_config.DEVICE
print("Using device:", DEVICE)



In [None]:
# 可在 notebook 内直接定义/替换模型
import torch.nn as nn
from agents.ppo import ppo_model

class ActorCriticCustom(nn.Module):
    """Convolutional actor-critic network with a shared trunk and two heads.

    Two 3x3 convolutions (3 -> 128 -> 256 channels, padding 1) feed one
    512-unit fully connected layer. A policy head emits one logit per board
    column and a value head emits a single scalar. The fully connected layer
    is sized for inputs whose spatial shape is
    ``ppo_config.ROWS x ppo_config.COLUMNS``.
    """

    def __init__(self):
        super().__init__()
        board_cells = ppo_config.ROWS * ppo_config.COLUMNS
        # Shared convolutional trunk.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        # Padding keeps the spatial size, so the flattened width is cells * 256.
        self.fc = nn.Linear(board_cells * 256, 512)
        # Output heads.
        self.policy = nn.Linear(512, ppo_config.COLUMNS)
        self.value = nn.Linear(512, 1)

    def forward(self, x):
        """Return ``(logits, value)`` for a batch of encoded boards ``x``."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc(flat))
        return self.policy(hidden), self.value(hidden)

# Override the stock class and factory on the `ppo_model` module so the custom
# network is used (comment these out if the stock model is wanted).
# NOTE(review): this only affects callers that look up `ppo_model.ActorCritic`
# or `ppo_model.make_model` through the module at call time — confirm how
# PPOAgent obtains its model.
ppo_model.ActorCritic = ActorCriticCustom
ppo_model.make_model = lambda: ActorCriticCustom().to(ppo_config.DEVICE)

# Builds a throwaway instance just to print its layer summary.
print("Using custom ActorCritic with larger capacity:", ActorCriticCustom())


In [None]:
from pathlib import Path
import requests

# Direct `resolve` download links into the public repository.
DQN_URL = "https://huggingface.co/mogoo7zn/Kaggle-ConnectX/resolve/main/DQN-base.pth"
AZ_URL  = "https://huggingface.co/mogoo7zn/Kaggle-ConnectX/resolve/main/alpha-zero-medium.pth"

# Local directory for the downloaded checkpoints; created if it does not exist.
ckpt_dir = Path('/kaggle/working/checkpoints')
ckpt_dir.mkdir(parents=True, exist_ok=True)

def download(url, out_path, timeout=60, chunk_size=1 << 20):
    """Download `url` to `out_path` and return `out_path`.

    Args:
        url: Source URL. A falsy value (empty string / None) skips the download.
        out_path: Destination file path.
        timeout: Seconds passed to `requests.get` so a stalled connection
            raises instead of hanging the notebook forever.
        chunk_size: Number of bytes read per streamed chunk.

    Returns:
        `out_path` on success, or None when `url` is empty.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if not url:
        print(f"[skip] empty url for {out_path.name}")
        return None
    print(f"Downloading {url} -> {out_path}")
    target = Path(out_path)
    # Write to a sibling ".part" file first so an interrupted download never
    # leaves a truncated checkpoint at the final path.
    partial = target.with_name(target.name + ".part")
    try:
        # Stream the body instead of buffering the whole checkpoint in memory.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        partial.replace(target)
    finally:
        # Only still present when something failed before the rename.
        if partial.exists():
            partial.unlink()
    return out_path

# Fetch both frozen-opponent checkpoints; a name is None when its URL is empty.
ckpt_dqn = download(DQN_URL, ckpt_dir/'dqn_frozen.pth')
ckpt_az  = download(AZ_URL,  ckpt_dir/'alphazero_frozen.pth')


In [None]:
# Frozen DQN opponent: stays None when no checkpoint was downloaded.
frozen_dqn = None
if ckpt_dqn and ckpt_dqn.exists():
    frozen_dqn = DQNAgent()
    frozen_dqn.load_model(str(ckpt_dqn))
    print("Loaded frozen DQN.")
else:
    print("No DQN checkpoint provided.")

# Frozen AlphaZero opponent: stays None when the checkpoint or the model
# factory is unavailable, or when loading fails.
frozen_az = None
if ckpt_az and ckpt_az.exists() and create_alphazero_model:
    try:
        frozen_az = create_alphazero_model('full')  # use 'light' if the weights are the lightweight variant
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        state = torch.load(ckpt_az, map_location=DEVICE)
        # Training checkpoints may wrap the weights under 'model_state_dict'.
        if isinstance(state, dict) and 'model_state_dict' in state:
            state = state['model_state_dict']
        # strict=False tolerates key mismatches; report them so a model that
        # was only partially loaded is not mistaken for a properly loaded one.
        load_report = frozen_az.load_state_dict(state, strict=False)
        if load_report.missing_keys or load_report.unexpected_keys:
            print(
                "[warn] AlphaZero checkpoint does not fully match the model: "
                f"{len(load_report.missing_keys)} missing keys, "
                f"{len(load_report.unexpected_keys)} unexpected keys"
            )
        # The opponent policy moves its inputs to DEVICE, so the model must
        # live on the same device.
        frozen_az.to(DEVICE)
        frozen_az.eval()
        print("Loaded frozen AlphaZero model.")
    except Exception as e:
        print("Failed to load AlphaZero model:", e)
        frozen_az = None
else:
    print("No AlphaZero checkpoint provided or loader unavailable.")


In [None]:
from huggingface_hub import hf_hub_download

# Extra frozen PPO opponents (from a personal Hugging Face repository).
# Each entry is (opponent name, repo id, checkpoint filename).
hf_ppo_sources = [
    ("ppo_mix", "tigerxin024/connectx-ppo-frozen", "ppo_frozen_mix.pth"),
    ("ppo_mix_v16", "tigerxin024/connectx-ppo-frozen", "ppo_frozen_mix_v1.6.pth"),
]

# Opponent name -> loaded PPOAgent; sources that fail to load are skipped.
frozen_ppo_models = {}

for name, repo_id, filename in hf_ppo_sources:
    try:
        ckpt = hf_hub_download(repo_id=repo_id, filename=filename)
        model = PPOAgent()
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        state = torch.load(ckpt, map_location=DEVICE)
        # Unwrap training checkpoints; accept the same wrapper key the
        # AlphaZero loader uses as well as the plain "state_dict" one.
        if isinstance(state, dict):
            for wrapper_key in ("state_dict", "model_state_dict"):
                if wrapper_key in state:
                    state = state[wrapper_key]
                    break
        # strict=False tolerates key mismatches; report them, otherwise a
        # "frozen" opponent could silently play with untrained weights.
        load_report = model.model.load_state_dict(state, strict=False)
        if load_report.missing_keys or load_report.unexpected_keys:
            print(
                f"[warn] HF PPO {name} does not fully match the model: "
                f"{len(load_report.missing_keys)} missing keys, "
                f"{len(load_report.unexpected_keys)} unexpected keys"
            )
        model.model.eval()
        frozen_ppo_models[name] = model
        print(f"Loaded frozen PPO from HF: {name} ({repo_id}/{filename})")
    except Exception as e:
        # Best effort: a source that cannot be loaded is reported and skipped.
        print(f"Failed to load HF PPO {name}: {e}")


def make_frozen_ppo_policy(model: PPOAgent):
    """Wrap a frozen PPO agent as a ``(board, mark) -> column`` policy.

    The returned policy encodes the position, runs ``model.model`` without
    gradients, masks every column not returned by `get_valid_moves` with
    ``-inf``, and samples a column from the softmax over what remains.
    """
    def policy(board, mark):
        encoded = torch.from_numpy(encode_state(board, mark)).float()
        batch = encoded.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            raw_logits, _ = model.model(batch)
            legal = get_valid_moves(board)
            # 0 for legal columns, -inf elsewhere, so illegal moves get probability 0.
            penalty = torch.full_like(raw_logits, float('-inf'))
            penalty[0, legal] = 0
            move_probs = F.softmax(raw_logits + penalty, dim=-1)[0].cpu().numpy()
        chosen = np.random.choice(len(move_probs), p=move_probs)
        return int(chosen)
    return policy



In [None]:
from types import SimpleNamespace
from kaggle_environments.envs.connectx.connectx import negamax_agent

# Fixed-strategy opponents
def random_policy(board, mark):
    """Return a uniformly random legal column, or 0 when no column is legal.

    `mark` is accepted only to match the common policy signature.
    """
    legal = get_valid_moves(board)
    if not legal:
        return 0
    return random.choice(legal)

def negamax_simple_policy(board, mark, depth=4):
    """Opponent policy that delegates to the project's `get_negamax_move`.

    `board`, `mark` and `depth` are forwarded unchanged and the helper's
    result is returned as is.
    """
    return get_negamax_move(board, mark, depth=depth)

def negamax_kaggle_policy(board, mark, depth=4):
    """Opponent policy backed by kaggle-environments' `negamax_agent`.

    Builds the observation (board, mark) and configuration objects the Kaggle
    agent expects and returns its chosen column as an int.
    NOTE(review): `depth` is only forwarded as a config field; confirm that
    `negamax_agent` actually honours it.
    """
    observation = SimpleNamespace(board=board, mark=mark)
    settings = {
        "rows": ppo_config.ROWS,
        "columns": ppo_config.COLUMNS,
        "inarow": ppo_config.INAROW,
        "timeout": 1,
        "actTimeout": 1,
        "depth": depth,
    }
    column = negamax_agent(observation, SimpleNamespace(**settings))
    return int(column)

def frozen_dqn_policy(board, mark):
    """Opponent policy backed by the module-level frozen DQN agent.

    Asserts that `frozen_dqn` has been loaded, then returns its
    `select_action` result with ``epsilon=0.0`` (presumably disabling random
    exploration — verify against DQNAgent).
    """
    assert frozen_dqn is not None
    return frozen_dqn.select_action(board, mark, epsilon=0.0)

def frozen_alphazero_policy(board, mark):
    """Opponent policy that plays the frozen AlphaZero network's top legal move.

    Runs the network once on the encoded position (no tree search), converts
    its first output to probabilities, and returns the legal column with the
    highest probability. Asserts that `frozen_az` has been loaded.
    """
    assert frozen_az is not None
    features = torch.from_numpy(encode_state(board, mark)).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        policy_logits, _ = frozen_az(features)
        move_probs = F.softmax(policy_logits, dim=-1)[0].cpu().numpy()
    legal = get_valid_moves(board)
    # Illegal columns keep a large negative score so argmax never selects them.
    scores = np.full_like(move_probs, -1e9)
    scores[legal] = move_probs[legal]
    best = scores.argmax()
    return int(best)

# Opponent table plus per-opponent statistics. Every opponent starts at an
# assumed 50% win rate (wins=1, games=2), which also avoids division by zero.
opponent_registry = [
    ("negamax_simple", negamax_simple_policy),
    ("negamax_kaggle", negamax_kaggle_policy),
    ("random", random_policy),
]
# Learned opponents join the pool only when they were loaded successfully.
if frozen_dqn:
    opponent_registry += [("frozen_dqn", frozen_dqn_policy)]
if frozen_az:
    opponent_registry += [("frozen_az", frozen_alphazero_policy)]
for name, model in frozen_ppo_models.items():
    opponent_registry += [(f"hf_{name}", make_frozen_ppo_policy(model))]

opponent_stats = {
    entry[0]: {"wins": 1, "games": 2}
    for entry in opponent_registry
}


def sample_opponent(stats):
    """Sample an opponent name, favouring opponents with a low win rate.

    Args:
        stats: Mapping of opponent name to ``{"wins": ..., "games": ...}``.

    Returns:
        ``(name, probs)`` where `probs` is the sampling distribution over the
        opponents in `stats` order.
    """
    names = list(stats)
    win_rates = np.array(
        [record["wins"] / record["games"] for record in stats.values()],
        dtype=float,
    )
    # Harder opponents (lower win rate) weigh more; the small constant keeps
    # the weights from all being zero.
    weights = (1.0 - win_rates) + 1e-3
    probs = weights / weights.sum()
    pick = np.random.choice(len(names), p=probs)
    return names[pick], probs


def record_result(name, win_rate_batch):
    """Fold one evaluation batch into the running stats of opponent `name`.

    The whole batch counts as a single game whose "wins" contribution is the
    batch win rate. Mutates the module-level `opponent_stats` in place.
    """
    weight = 1  # each evaluation batch is recorded as one game
    entry = opponent_stats[name]
    entry["games"] = entry["games"] + weight
    entry["wins"] = entry["wins"] + win_rate_batch * weight


# Show which opponents ended up in the pool.
print("Opponents:", list(opponent_stats.keys()))


In [None]:
# Fresh PPO learner that is trained against the opponent pool.
agent = PPOAgent()

TOTAL_UPDATES = 100  # number of rollout + update iterations
ROLLOUT_STEPS = 1024  # passed to agent.generate_rollout on every iteration
LOG_INTERVAL = 10  # print a progress line every this many updates
EVAL_GAMES = 6  # games per post-update evaluation against the sampled opponent

reward_log = []  # mean rollout return, appended once per update

# Simple evaluation helper: PPO always plays mark 1 and moves first.
def eval_vs(opp_fn, games=EVAL_GAMES):
    """Return the fraction of `games` the current `agent` wins against `opp_fn`.

    The agent always plays mark 1 (first to move); the opponent plays mark 2.
    Draws and losses both count as "not a win".
    """
    cell_count = ppo_config.ROWS * ppo_config.COLUMNS

    def agent_wins_one_game():
        # Play a single game from the empty board until it terminates.
        board = [0] * cell_count
        mover = 1
        while True:
            if mover == 1:
                action, _, _ = agent.select_action(board, mover)
            else:
                action = opp_fn(board, mover)
            board = make_move(board, action, mover)
            finished, winner = is_terminal(board)
            if finished:
                return winner == 1
            mover = 3 - mover  # alternate between marks 1 and 2

    victories = sum(1 for _ in range(games) if agent_wins_one_game())
    return victories / games


# Name -> policy lookup, built once instead of on every iteration.
opponent_lookup = dict(opponent_registry)

# Main loop: sample an opponent by difficulty, train on one rollout against
# it, then re-estimate the win rate against that same opponent.
for update in range(1, TOTAL_UPDATES + 1):
    opp_name, probs = sample_opponent(opponent_stats)
    opp_fn = opponent_lookup[opp_name]

    batch = agent.generate_rollout(opp_fn, ROLLOUT_STEPS)
    metrics = agent.update(batch)
    reward_log.append(batch.returns.mean().item())

    # Right after training, run a small evaluation against this opponent and
    # update its win-rate statistics (this drives future sampling).
    win_rate_batch = eval_vs(opp_fn, games=EVAL_GAMES)
    record_result(opp_name, win_rate_batch)

    if update % LOG_INTERVAL == 0:
        avg_ret = float(np.mean(reward_log[-LOG_INTERVAL:]))
        # agent.update may return a metrics dict or a bare loss value.
        if isinstance(metrics, dict):
            loss_val = metrics.get("loss", metrics.get("total_loss", 0.0))
            policy = metrics.get("policy_loss", float("nan"))
            value = metrics.get("value_loss", float("nan"))
            ent = metrics.get("entropy", float("nan"))
            kl = metrics.get("approx_kl", float("nan"))
            clip = metrics.get("clip_frac", float("nan"))
        else:
            loss_val = float(metrics)
            policy = value = ent = kl = clip = float("nan")

        win_rates = {n: round(v["wins"] / v["games"], 3) for n, v in opponent_stats.items()}
        # `probs` is the sampling distribution used for THIS update, i.e.
        # computed before this update's result was recorded.
        print(
            f"Update {update}/{TOTAL_UPDATES} "
            f"| opp {opp_name} "
            f"| loss {loss_val:.3f} "
            f"| policy {policy:.3f} "
            f"| value {value:.3f} "
            f"| ent {ent:.3f} "
            f"| kl {kl:.3f} "
            f"| clip {clip:.3f} "
            f"| avg_ret {avg_ret:.3f} "
            f"| win_rates {win_rates} "
            f"| probs {[round(p,3) for p in probs.tolist()]}"
        )

# Save the trained model's weights (state_dict only) to the Kaggle working directory.
ppo_path = Path('/kaggle/working/ppo_frozen_mix_adaptive.pth')
torch.save(agent.model.state_dict(), ppo_path)
print("Saved", ppo_path)



In [None]:
# Evaluation: during training the model always moves first (satisfied above).
# When the model would have to move second, a fixed drawing-oriented policy is
# used instead to try to hold the draw.
# negamax_kaggle is chosen here as that second-player policy; swap as needed.
draw_policy = lambda board, mark: negamax_kaggle_policy(board, mark, depth=4)


def play_one(policy_first_fn, opp_fn, games=20, alt_first=True, draw_on_second=None):
    """Play `games` games of our side against `opp_fn`.

    Args:
        policy_first_fn: Our main policy, ``(board, mark) -> column``.
        opp_fn: The opponent policy, same signature. It always plays the
            opponent's moves.
        games: Number of games to play.
        alt_first: When True our side moves first (mark 1) in even-numbered
            games and second (mark 2) in odd-numbered games; when False our
            side always moves first.
        draw_on_second: Optional policy that plays OUR moves, in place of
            `policy_first_fn`, in the games where our side moves second.

    Returns:
        ``(win_rate, draw_rate)`` for our side over the `games` played.
    """
    wins = 0
    draws = 0
    board_size = ppo_config.ROWS * ppo_config.COLUMNS
    for g in range(games):
        board = [0] * board_size
        # Alternate who starts: even games we move first (mark 1), odd games
        # the opponent moves first and we play mark 2.
        ppo_first = (g % 2 == 0) if alt_first else True
        ppo_mark = 1 if ppo_first else 2
        # Choose who plays our side in this game. The fallback replaces our
        # own policy when we move second; it must never replace the opponent,
        # otherwise the opponent under test would not play at all.
        if draw_on_second is not None and ppo_mark == 2:
            our_policy = draw_on_second
        else:
            our_policy = policy_first_fn
        current = 1
        while True:
            if current == ppo_mark:
                action = our_policy(board, ppo_mark)
            else:
                action = opp_fn(board, current)
            board = make_move(board, action, current)
            done, winner = is_terminal(board)
            if done:
                if winner == ppo_mark:
                    wins += 1
                elif winner is None or winner == 0:
                    draws += 1
                break
            current = 3 - current  # alternate between marks 1 and 2
    return wins / games, draws / games


def ppo_policy(board, mark):
    """Sample a legal column from the trained agent's policy network.

    Encodes the position, runs ``agent.model`` without gradients, masks every
    column not returned by `get_valid_moves` with ``-inf``, and samples from
    the softmax over the remaining columns.
    """
    observation = torch.from_numpy(encode_state(board, mark)).float().unsqueeze(0).to(DEVICE)
    legal = get_valid_moves(board)
    with torch.no_grad():
        raw_logits, _ = agent.model(observation)
        # 0 for legal columns, -inf elsewhere, so illegal moves get probability 0.
        penalty = torch.full_like(raw_logits, float('-inf'))
        penalty[0, legal] = 0
        move_probs = F.softmax(raw_logits + penalty, dim=-1)[0].cpu().numpy()
    chosen = np.random.choice(len(move_probs), p=move_probs)
    return int(chosen)

# Matchups shared by both evaluation passes: (label, opponent policy).
# Learned opponents are included only when they were loaded.
matchups = [
    ("random", random_policy),
    ("negamax_simple", lambda b, m: negamax_simple_policy(b, m, depth=4)),
    ("negamax_kaggle", lambda b, m: negamax_kaggle_policy(b, m, depth=4)),
]
if frozen_dqn:
    matchups.append(("frozen DQN", frozen_dqn_policy))
if frozen_az:
    matchups.append(("frozen AlphaZero", frozen_alphazero_policy))

# Pass 1: the raw model, alternating who moves first.
for label, opponent in matchups:
    wr, dr = play_one(ppo_policy, opponent, games=20)
    print(f"PPO vs {label} | win {wr:.3f} | draw {dr:.3f}")

# Pass 2: same matchups, with draw_policy passed as `draw_on_second`
# (goal: hold the draw in the games where the model moves second).
for label, opponent in matchups:
    wr, dr = play_one(ppo_policy, opponent, games=20, draw_on_second=draw_policy)
    print(f"(1st model / 2nd draw_policy) vs {label} | win {wr:.3f} | draw {dr:.3f}")



In [None]:
from huggingface_hub import HfApi, create_repo

# Authentication must already be done, e.g. by running `notebook_login()` in an
# earlier cell.
# NOTE(review): no login cell is visible in this notebook — confirm it exists.
to_upload = ppo_path
hub_model_id = "your-username/connectx-ppo-adaptive"  # replace with your own repo ID
# Create the target repo if needed; exist_ok=True makes re-runs harmless.
create_repo(hub_model_id, exist_ok=True)

# Upload the saved weights under their own file name.
api = HfApi()
api.upload_file(
    path_or_fileobj=str(to_upload),
    path_in_repo=to_upload.name,
    repo_id=hub_model_id,
)
print("Uploaded .pth to HF repo:", hub_model_id)



In [None]:
from huggingface_hub import hf_hub_download

# Round-trip check: fetch the file that was just uploaded.
hf_repo_id = hub_model_id  # or set this to the repo you uploaded to
hf_ckpt = hf_hub_download(repo_id=hf_repo_id, filename=to_upload.name)
print("Downloaded from HF:", hf_ckpt)

