In [1]:
# --- imports / setup ---
import os
import itertools
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from C4.connect4_env import Connect4Env
from C4.fast_connect4_lookahead import Connect4Lookahead
import torch
from C4.CNet192 import CNet192, load_cnet192

pd.set_option("display.max_columns", 200)


# ENV reward hybrid sweep (practical)

This notebook evaluates **environment reward knob variants** by actually playing games in `Connect4Env`.

Two policy options:
- **greedy**: picks the move with the highest *immediate* env reward (fast sanity check).
- **ppo**: uses your PPO/CNet checkpoint (optional probe for reward alignment).

Metrics per opponent:
- `PPG_*` = points per game (win=1, draw=0.5, loss=0).
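- `RET_*` = per-game sum of env rewards collected **only on A's turns**, averaged over games.
- `NETRET_*` = A-centric net return = (A's return) - (B's return), averaged over games.
- `GAP_*` = average `NETRET` over A's wins minus average `NETRET` over A's losses (reported as 0 if A has only wins or only losses).
- `GAPDEF_*` = 1 if `GAP_*` is defined (A has at least one win and at least one loss), otherwise 0.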
- `PLIES_*` = average game length.


In [2]:
# --- base knobs (single source) ---
# Names of the Connect4Env reward knobs that should appear in the result
# tables. Add or remove names as needed.
KNOB_KEYS = ["VERT_BLOCK_MUL_3"]

# Baseline value of every listed knob, read from the Connect4Env class
# attributes. Names the class does not define are left out and reported below.
BASE_KNOBS = {
    name: getattr(Connect4Env, name)
    for name in KNOB_KEYS
    if hasattr(Connect4Env, name)
}
missing = [name for name in KNOB_KEYS if name not in BASE_KNOBS]
if missing:
    print("WARNING: Connect4Env missing keys:", missing)

BASE_KNOBS


{'VERT_BLOCK_MUL_3': 1.25}

In [3]:
# --- define sweep grid ---
# Only include keys you want to vary. Everything else uses BASE_KNOBS.
# Each key maps to the list of candidate values for that knob; make_variants
# builds one variant per element of the Cartesian product of these lists.
# Example: sweep one knob
GRID = {
    "VERT_BLOCK_MUL_3": [1.25, 1.5, 2, 3],
}

def make_variants(base: dict, grid: dict):
    """Expand a sweep grid into a list of ``(variant_id, config)`` pairs.

    Args:
        base: Baseline knob values; copied into every variant, never mutated.
        grid: Maps knob name -> iterable of values to try for that knob.

    Returns:
        One ``(vid, cfg)`` tuple per element of the Cartesian product of the
        grid values, in ``itertools.product`` order. ``vid`` is ``"V000"``,
        ``"V001"``, ... and ``cfg`` is ``base`` with the swept knobs overridden.
        An empty grid yields a single variant equal to ``base``.
    """
    knob_names = list(grid)
    value_lists = [list(grid[name]) for name in knob_names]

    def _build(combo):
        # Start from the baseline, then overlay this combination's values.
        cfg = dict(base)
        cfg.update(zip(knob_names, combo))
        return cfg

    return [
        (f"V{index:03d}", _build(combo))
        for index, combo in enumerate(itertools.product(*value_lists))
    ]

variants = make_variants(BASE_KNOBS, GRID)

# Compact lookup table: variant id -> value of each knob that is being swept.
_variant_map = [
    {"variant": vid, **{knob: cfg[knob] for knob in GRID}}
    for vid, cfg in variants
]

df_variant_map = pd.DataFrame(_variant_map)
df_variant_map


Unnamed: 0,variant,VERT_BLOCK_MUL_3
0,V000,1.25
1,V001,1.5
2,V002,2.0
3,V003,3.0


In [4]:
# --- helpers: per-variant env instantiation (avoid global/class bleed) ---

# Column preference used to break ties between equally scored moves:
# center column first, then alternating outward towards the edges.
CENTER_ORDER = [3, 4, 2, 5, 1, 6, 0]

def make_env_with_knobs(cfg: dict) -> Connect4Env:
    """Build a fresh, reset ``Connect4Env`` with ``cfg`` applied to the instance.

    Knobs are set as instance attributes so that the class-level defaults are
    left alone and variants cannot leak into each other.

    Args:
        cfg: Maps knob (attribute) name -> value to set on the new env.

    Returns:
        The configured env, after ``reset()`` has been called.
    """
    env = Connect4Env()

    for knob_name, knob_value in cfg.items():
        setattr(env, knob_name, knob_value)

    # MAX_REWARD follows WIN_REWARD (factor 0.35) unless the caller pinned
    # MAX_REWARD explicitly.
    win_overridden = "WIN_REWARD" in cfg
    max_overridden = "MAX_REWARD" in cfg
    if win_overridden and not max_overridden:
        env.MAX_REWARD = float(env.WIN_REWARD) * 0.35

    # Rebuild the cached array from CENTER_WEIGHTS (falling back to seven 1.0
    # weights) so it matches any override; done unconditionally.
    center_weights = getattr(env, "CENTER_WEIGHTS", [1.0] * 7)
    env._CENTER_WEIGHTS_ARR = np.asarray(center_weights, dtype=np.float32)

    env.reset()
    return env


# --- policies ---

def random_policy(env: Connect4Env, rng: np.random.Generator) -> int:
    """Return a uniformly random legal column, or 0 when no move is available."""
    columns = env.available_actions()
    if not columns:
        return 0
    return int(rng.choice(columns))

class LookaheadPolicy:
    """Opponent policy that delegates move selection to ``Connect4Lookahead``.

    Exposes ``act(env)`` rather than being callable, which is how the game
    loop tells it apart from the function-style policies.
    """

    def __init__(self, depth: int):
        """Store the search depth and create the lookahead engine."""
        self.depth = int(depth)
        self.la = Connect4Lookahead()

    def act(self, env: Connect4Env) -> int:
        """Return the column chosen by an n-step lookahead for the current mover."""
        chosen = self.la.n_step_lookahead(
            env.board,
            env.current_player,
            depth=self.depth,
        )
        return int(chosen)


def _snapshot_env(env: Connect4Env):
    """Capture the env fields that ``_restore_env`` needs to roll back a move.

    Returns:
        A 9-tuple ``(pos1, pos2, mask, heights, board, current_player, done,
        winner, ply)``. ``heights`` and ``board`` are copies, so later env
        mutation does not alter the snapshot.
    """
    bitboards = (int(env._pos1), int(env._pos2), int(env._mask))
    heights = env._heights.copy()
    board = env.board.copy()
    status = (int(env.current_player), bool(env.done), env.winner, int(env.ply))
    return (*bitboards, heights, board, *status)

def _restore_env(env: Connect4Env, snap):
    """Write a snapshot produced by ``_snapshot_env`` back onto ``env``.

    The arrays are copied again on the way in, so one snapshot can be restored
    any number of times.
    """
    (pos1, pos2, mask, heights, board, cur, done, winner, ply) = snap

    # The three bitboard fields are written back as np.uint64.
    for attr, raw in (("_pos1", pos1), ("_pos2", pos2), ("_mask", mask)):
        setattr(env, attr, np.uint64(raw))

    env._heights = heights.copy()
    env.board = board.copy()

    env.current_player = int(cur)
    env.done = bool(done)
    env.winner = winner
    env.ply = int(ply)

def greedy_reward_policy(env: Connect4Env, rng: np.random.Generator) -> int:
    """Pick the legal action with the highest immediate env reward.

    Every legal column is tried by actually stepping the env from a snapshot
    of the current position; the env is put back to that snapshot before
    returning, including when ``env.step`` raises.

    Args:
        env: Environment to act in. It is mutated while probing and restored
            afterwards.
        rng: Unused; accepted so the function has the same ``(env, rng)``
            signature as the other function-style policies.

    Returns:
        The chosen column. Rewards within 1e-12 of the best count as tied, and
        ties are resolved by ``CENTER_ORDER``. Returns 0 if there is no legal
        move.
    """
    legal = env.available_actions()
    if not legal:
        return 0

    best_r = -1e100
    best = []
    snap0 = _snapshot_env(env)

    try:
        for c in legal:
            # Every probe starts from the same position.
            _restore_env(env, snap0)
            _, r, _ = env.step(int(c))
            # Compare as a plain Python float so the tolerance arithmetic does
            # not depend on the reward's numeric type.
            r = float(r)
            if r > best_r + 1e-12:
                best_r = r
                best = [int(c)]
            elif abs(r - best_r) <= 1e-12:
                best.append(int(c))
    finally:
        # Never leave the caller's env sitting on a probed position, even if
        # a probe step failed.
        _restore_env(env, snap0)

    if len(best) == 1:
        return best[0]

    # tie-break by center preference order, then smallest index
    for c in CENTER_ORDER:
        if c in best:
            return int(c)
    return int(min(best))


In [5]:
# --- game loop + evaluation ---

def play_one_game(env: Connect4Env, policyA, policyB, A_mark: int = 1, seed: int = 0, max_plies: int = 200):
    """Play a single game between policy A and policy B on ``env``.

    The env is reset and ``current_player`` is then set to ``A_mark``, so A
    always makes the first move; ``A_mark`` only selects which mark A plays.

    Args:
        env: Environment to play on (reset here).
        policyA, policyB: Either a callable ``(env, rng) -> column`` or an
            object exposing ``act(env) -> column``.
        A_mark: Mark played by A (1 or -1); B plays ``-A_mark``.
        seed: Seed for the per-game RNG passed to callable policies.
        max_plies: Hard cap on the number of moves.

    Returns:
        dict with
        ``ppg``     - A's points: 1.0 win, 0.5 draw, 0.0 loss;
        ``ret``     - sum of rewards received on A's own moves;
        ``net_ret`` - A's return minus B's return;
        ``plies``   - number of moves played (float);
        ``winner``  - ``env.winner`` as int, or 0 if the game was cut off by
                      ``max_plies`` before the env reported ``done`` (such a
                      game therefore scores as a draw).
    """
    rng = np.random.default_rng(seed)

    env.reset()
    env.current_player = int(A_mark)  # A moves first

    def _choose(policy):
        # Function-style policies receive the RNG; object-style ones do not.
        if callable(policy):
            return policy(env, rng)
        return policy.act(env)

    # Reward collected by each mark, credited to whoever made the move.
    ret = {1: 0.0, -1: 0.0}
    plies = 0

    while (not env.done) and plies < max_plies:
        mover = int(env.current_player)
        action = _choose(policyA if mover == A_mark else policyB)

        _, reward, finished = env.step(int(action))
        ret[mover] += float(reward)
        plies += 1
        if finished:
            break

    # env.winner is expected to be 1, -1 or 0 (draw) once the game is done.
    winner = env.winner if env.done else 0

    ppg = 1.0 if winner == A_mark else (0.5 if winner == 0 else 0.0)

    own_return = float(ret[A_mark])
    net_return = float(ret[A_mark] - ret[-A_mark])

    return dict(
        ppg=ppg,
        ret=own_return,
        net_ret=net_return,
        plies=float(plies),
        winner=int(winner),
    )


def eval_matchup(cfg: dict, policy_kind, opponent, N: int = 50, seed: int = 1234):
    """Evaluate one knob configuration against one opponent over ``N`` games.

    A's mark alternates between games (1 on even games, -1 on odd games).
    Which side moves first is decided by ``play_one_game``, not here.

    Args:
        cfg: Knob name -> value, applied to a fresh env via
            ``make_env_with_knobs``.
        policy_kind: Policy A. Either the string ``"greedy"`` or ``"random"``,
            or a policy itself: a callable ``(env, rng) -> column`` or an
            object exposing ``act(env)``.
        opponent: Policy B, in either of the two policy forms above.
        N: Number of games; must be at least 1.
        seed: Seed from which the per-game seeds are drawn.

    Returns:
        dict with ``PPG``, ``RET``, ``NETRET`` and ``PLIES`` (means over the
        games), ``GAP`` (mean net return of A's wins minus mean net return of
        A's losses, or 0.0 when either group is empty) and ``GAPDEF`` (1 when
        ``GAP`` is defined, else 0).

    Raises:
        ValueError: If ``N < 1`` or ``policy_kind`` is neither a known name
            nor a usable policy.
    """
    # Without this, N == 0 fails later with an opaque KeyError on the empty frame.
    if N < 1:
        raise ValueError(f"N must be >= 1 (got {N!r})")

    env = make_env_with_knobs(cfg)

    rng = np.random.default_rng(seed)
    seeds = rng.integers(0, 2**31-1, size=N, dtype=np.int64)

    # choose policy A: by name, or take a ready-made policy as-is
    named_policies = {"greedy": greedy_reward_policy, "random": random_policy}
    if isinstance(policy_kind, str):
        if policy_kind not in named_policies:
            raise ValueError(f"Unknown policy_kind={policy_kind!r} (expected 'greedy' or 'random' here)")
        polA = named_policies[policy_kind]
    elif callable(policy_kind) or hasattr(policy_kind, "act"):
        polA = policy_kind
    else:
        raise ValueError(f"Unknown policy_kind={policy_kind!r} (expected 'greedy' or 'random' here)")

    polB = opponent

    rows = []
    for i in range(N):
        # alternate A's mark from game to game
        A_mark = 1 if (i % 2 == 0) else -1
        out = play_one_game(env, polA, polB, A_mark=A_mark, seed=int(seeds[i]))
        rows.append(out)

    df = pd.DataFrame(rows)

    ppg = float(df["ppg"].mean())
    ret = float(df["ret"].mean())
    netret = float(df["net_ret"].mean())
    plies = float(df["plies"].mean())

    # GAP: difference between avg netret in A-wins vs A-losses (0 if not both exist)
    net_w = df.loc[df["ppg"] == 1.0, "net_ret"].to_numpy()
    net_l = df.loc[df["ppg"] == 0.0, "net_ret"].to_numpy()
    if (len(net_w) > 0) and (len(net_l) > 0):
        gap = float(np.mean(net_w) - np.mean(net_l))
        gapdef = 1
    else:
        gap = 0.0
        gapdef = 0

    return dict(PPG=ppg, RET=ret, NETRET=netret, GAP=gap, GAPDEF=gapdef, PLIES=plies)


In [6]:
# --- run sweep ---

# Configure opponents here.
# Each entry is (name, policy). The name becomes the column suffix in the
# result table (e.g. PPG_L3). A policy is either a callable taking (env, rng)
# or an object exposing .act(env).
# NOTE: for max performance, use LookaheadPolicy (bitboard/numba).
opponents = [
    ("Random", random_policy),
    ("L1", LookaheadPolicy(1)),
    ("L3", LookaheadPolicy(3)),
    ("L5", LookaheadPolicy(5)),
    ("L7", LookaheadPolicy(7)),
]

# Games played per (variant, opponent) matchup.
N_GAMES = 100
# Base seed passed to every matchup; the per-game seeds are drawn from it.
SEED = 666

def run_sweep(policy_kind: str = "greedy") -> pd.DataFrame:
    """Evaluate every variant against every opponent and rank the variants.

    Args:
        policy_kind: Name of policy A, forwarded to ``eval_matchup``.

    Returns:
        One row per variant holding the variant id, the knob values, the
        per-opponent stats as ``<STAT>_<opponent>`` columns, and the
        cross-opponent means ``meta_ppg``, ``meta_netret`` and ``meta_gap``.
        Rows are sorted by those three meta columns, all descending.
    """
    opponent_names = [opp_name for opp_name, _ in opponents]
    meta_sources = (("meta_ppg", "PPG"), ("meta_netret", "NETRET"), ("meta_gap", "GAP"))

    records = []
    for vid, cfg in tqdm(variants, desc=f"sweep[{policy_kind}]"):
        record = {"variant": vid}

        # Full knob vector, so the table can be read without the variant map.
        for knob in BASE_KNOBS:
            record[knob] = cfg.get(knob, BASE_KNOBS[knob])

        for opp_name, opp_policy in opponents:
            stats = eval_matchup(cfg, policy_kind=policy_kind, opponent=opp_policy, N=N_GAMES, seed=SEED)
            record.update({f"{stat}_{opp_name}": value for stat, value in stats.items()})

        # Unweighted mean across opponents; NaN if there are no opponents.
        for meta_name, stat in meta_sources:
            per_opponent = [record[f"{stat}_{opp_name}"] for opp_name in opponent_names]
            record[meta_name] = float(np.mean(per_opponent)) if per_opponent else float("nan")

        records.append(record)

    ranked = pd.DataFrame(records).sort_values(
        ["meta_ppg", "meta_netret", "meta_gap"],
        ascending=False,
    )
    return ranked.reset_index(drop=True)


In [7]:
# --- run: fast screen (reward-greedy) ---
# Policy A is greedy_reward_policy; the table is ranked best-first.
# NOTE(review): greedy_reward_policy ignores its rng. If the lookahead
# opponents are deterministic too (not verifiable from this notebook), the
# N_GAMES games of such a matchup reduce to two distinct games, one per
# A_mark - confirm before reading the averages as 100 independent samples.
df_greedy = run_sweep(policy_kind="greedy")
df_greedy.head(20)


sweep[greedy]:   0%|          | 0/4 [00:00<?, ?it/s]

Unnamed: 0,variant,VERT_BLOCK_MUL_3,PPG_Random,RET_Random,NETRET_Random,GAP_Random,GAPDEF_Random,PLIES_Random,PPG_L1,RET_L1,NETRET_L1,GAP_L1,GAPDEF_L1,PLIES_L1,PPG_L3,RET_L3,NETRET_L3,GAP_L3,GAPDEF_L3,PLIES_L3,PPG_L5,RET_L5,NETRET_L5,GAP_L5,GAPDEF_L5,PLIES_L5,PPG_L7,RET_L7,NETRET_L7,GAP_L7,GAPDEF_L7,PLIES_L7,meta_ppg,meta_netret,meta_gap
0,V003,3.0,1.0,12108.455,13451.8835,0.0,0,9.4,0.5,6814.1,2545.925,22999.95,1,15.5,0.75,7914.75,7664.525,0.0,0,40.5,0.0,1702.275,-9042.775,0.0,0,34.0,0.0,1341.175,-8985.025,0.0,0,20.0,0.45,1126.9067,4599.99
1,V002,2.0,1.0,12107.855,13453.4835,0.0,0,9.4,0.5,6794.1,2525.925,23039.95,1,15.5,0.75,7864.75,7624.525,0.0,0,40.5,0.0,1702.275,-9042.775,0.0,0,34.0,0.0,1331.175,-8995.025,0.0,0,20.0,0.45,1113.2267,4607.99
2,V001,1.5,1.0,12107.555,13454.2835,0.0,0,9.4,0.5,6784.1,2515.925,23059.95,1,15.5,0.75,7839.75,7604.525,0.0,0,40.5,0.0,1702.275,-9042.775,0.0,0,34.0,0.0,1326.175,-9000.025,0.0,0,20.0,0.45,1106.3867,4611.99
3,V000,1.25,1.0,12107.405,13454.6835,0.0,0,9.4,0.5,6779.1,2510.925,23069.95,1,15.5,0.75,7827.25,7594.525,0.0,0,40.5,0.0,1702.275,-9042.775,0.0,0,34.0,0.0,1323.675,-9002.525,0.0,0,20.0,0.45,1102.9667,4613.99


In [8]:
# Checkpoint loaded by PpoPolicy for the PPO sweep; run_sweep_ppo raises
# ValueError when this is None.
PPO_CKPT_PATH = "PPO_Models/MIX_42b.pt"  # e.g. "PPO_Models/MIX_9.pt"


In [9]:
# --- PPO policy wrapper (uses your CNet192 loader) ---



def load_ppo_model_cnet192(ckpt_path: str):
    """Load a CNet192 checkpoint for inference.

    Uses CUDA when available, otherwise the CPU.

    Args:
        ckpt_path: Path of the checkpoint file handed to ``load_cnet192``.

    Returns:
        ``(model, device)`` with the model moved to ``device`` and switched to
        eval mode.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)

    # load_cnet192 returns a pair; only the model (first item) is used here.
    model, _extra = load_cnet192(ckpt_path, device=device, strict=True)
    model.to(device)
    model.eval()
    return model, device

class PpoPolicy:
    """Object-style policy that plays the highest-logit legal move of a CNet192 model."""

    def __init__(self, ckpt_path: str):
        """Load the checkpoint and remember the model and its device."""
        loaded_model, loaded_device = load_ppo_model_cnet192(ckpt_path)
        self.model = loaded_model
        self.device = loaded_device

    @torch.no_grad()
    def act(self, env: Connect4Env) -> int:
        """Return the legal column with the largest policy logit.

        Logits within 1e-12 of the maximum count as tied; ties are resolved by
        ``CENTER_ORDER``. Returns 0 when the env reports no legal action.
        """
        # Per the original author's note: env state is (1,6,7) mover-centric
        # and the network expects (B,C,6,7), hence the added leading axis.
        state = torch.from_numpy(env.get_state()).float().to(self.device)
        batch = state.unsqueeze(0)

        policy_out, _ = self.model(batch)
        scores = policy_out.squeeze(0).detach().cpu().numpy().astype(np.float64, copy=False)

        legal = env.available_actions()
        if not legal:
            return 0

        # Only legal columns compete, which masks out illegal moves.
        top_score = max(scores[col] for col in legal)
        tied = [col for col in legal if abs(scores[col] - top_score) <= 1e-12]

        # Prefer the most central tied column; otherwise the smallest index.
        preferred = next((col for col in CENTER_ORDER if col in tied), None)
        if preferred is None:
            return int(min(tied))
        return int(preferred)

def run_sweep_ppo() -> pd.DataFrame:
    """Run the variant x opponent sweep with the PPO checkpoint as policy A.

    Mirrors the greedy sweep: same opponents, ``N_GAMES``, ``SEED``, per-matchup
    statistics and ranking; only policy A differs.

    Returns:
        One row per variant with knob values, ``<STAT>_<opponent>`` columns and
        the ``meta_ppg`` / ``meta_netret`` / ``meta_gap`` means, sorted by those
        three columns in descending order.

    Raises:
        ValueError: If ``PPO_CKPT_PATH`` is None.
    """
    if PPO_CKPT_PATH is None:
        raise ValueError("Set PPO_CKPT_PATH first.")
    polA = PpoPolicy(PPO_CKPT_PATH)

    def _matchup_stats(cfg: dict, opponent, n_games: int, seed: int):
        # Same procedure as eval_matchup, with the PPO policy fixed as A.
        env = make_env_with_knobs(cfg)
        game_seeds = np.random.default_rng(seed).integers(0, 2**31-1, size=n_games, dtype=np.int64)

        # A's mark alternates: 1 on even games, -1 on odd games.
        games = pd.DataFrame([
            play_one_game(
                env,
                polA,
                opponent,
                A_mark=(1 if idx % 2 == 0 else -1),
                seed=int(game_seeds[idx]),
            )
            for idx in range(n_games)
        ])

        mean_ppg = float(games["ppg"].mean())
        mean_ret = float(games["ret"].mean())
        mean_net = float(games["net_ret"].mean())
        mean_plies = float(games["plies"].mean())

        won_net = games.loc[games["ppg"] == 1.0, "net_ret"].to_numpy()
        lost_net = games.loc[games["ppg"] == 0.0, "net_ret"].to_numpy()
        gap_defined = (len(won_net) > 0) and (len(lost_net) > 0)
        gap = float(np.mean(won_net) - np.mean(lost_net)) if gap_defined else 0.0

        return dict(
            PPG=mean_ppg,
            RET=mean_ret,
            NETRET=mean_net,
            GAP=gap,
            GAPDEF=1 if gap_defined else 0,
            PLIES=mean_plies,
        )

    opponent_names = [opp_name for opp_name, _ in opponents]
    meta_sources = (("meta_ppg", "PPG"), ("meta_netret", "NETRET"), ("meta_gap", "GAP"))

    records = []
    for vid, cfg in tqdm(variants, desc="sweep[PPO]"):
        record = {"variant": vid}
        for knob in BASE_KNOBS:
            record[knob] = cfg.get(knob, BASE_KNOBS[knob])

        for opp_name, opp_policy in opponents:
            stats = _matchup_stats(cfg, opp_policy, N_GAMES, SEED)
            record.update({f"{stat}_{opp_name}": value for stat, value in stats.items()})

        # Unweighted mean across opponents; NaN if there are no opponents.
        for meta_name, stat in meta_sources:
            per_opponent = [record[f"{stat}_{opp_name}"] for opp_name in opponent_names]
            record[meta_name] = float(np.mean(per_opponent)) if per_opponent else float("nan")

        records.append(record)

    ranked = pd.DataFrame(records).sort_values(
        ["meta_ppg", "meta_netret", "meta_gap"],
        ascending=False,
    )
    return ranked.reset_index(drop=True)

# Run the PPO sweep and show the best-ranked variants. This executes
# immediately and loads the checkpoint at PPO_CKPT_PATH.
df_ppo = run_sweep_ppo()
df_ppo.head(20)


sweep[PPO]:   0%|          | 0/4 [00:00<?, ?it/s]

Unnamed: 0,variant,VERT_BLOCK_MUL_3,PPG_Random,RET_Random,NETRET_Random,GAP_Random,GAPDEF_Random,PLIES_Random,PPG_L1,RET_L1,NETRET_L1,GAP_L1,GAPDEF_L1,PLIES_L1,PPG_L3,RET_L3,NETRET_L3,GAP_L3,GAPDEF_L3,PLIES_L3,PPG_L5,RET_L5,NETRET_L5,GAP_L5,GAPDEF_L5,PLIES_L5,PPG_L7,RET_L7,NETRET_L7,GAP_L7,GAPDEF_L7,PLIES_L7,meta_ppg,meta_netret,meta_gap
0,V000,1.25,1.0,12078.3425,13521.8115,0.0,0,9.64,0.5,6712.9,2561.8,22968.2,1,14.5,0.0,1290.075,-8979.875,0.0,0,20.0,0.25,2157.65,-3067.375,0.0,0,37.0,0.0,1416.65,-8959.475,0.0,0,24.0,0.35,-984.6227,4593.64
1,V001,1.5,1.0,12078.4425,13521.7115,0.0,0,9.64,0.5,6715.4,2561.8,22968.2,1,14.5,0.0,1290.075,-8982.375,0.0,0,20.0,0.25,2160.15,-3067.375,0.0,0,37.0,0.0,1416.65,-8959.475,0.0,0,24.0,0.35,-985.1427,4593.64
2,V002,2.0,1.0,12078.6425,13521.5115,0.0,0,9.64,0.5,6720.4,2561.8,22968.2,1,14.5,0.0,1290.075,-8987.375,0.0,0,20.0,0.25,2165.15,-3067.375,0.0,0,37.0,0.0,1416.65,-8959.475,0.0,0,24.0,0.35,-986.1827,4593.64
3,V003,3.0,1.0,12079.0425,13521.1115,0.0,0,9.64,0.5,6730.4,2561.8,22968.2,1,14.5,0.0,1290.075,-8997.375,0.0,0,20.0,0.25,2175.15,-3067.375,0.0,0,37.0,0.0,1416.65,-8959.475,0.0,0,24.0,0.35,-988.2627,4593.64
