In [None]:
from pathlib import Path
from scripts.build_rym_npz import pgn_to_npz  # adjust import if needed

# Input games and the NPZ file the converter should write.
PGN_PATH = Path("gte.pgn")      # <-- adjust path to your gte.pgn
NPZ_PATH = Path("gte_rym.npz")  # where you want the npz

# Build the NPZ. Ratings span [min_rating, max_rating] split into num_bins bands.
# max_games=None presumably means "no cap"; pass 1 to be explicit for a single-game PGN.
pgn_to_npz(
    pgn_path=PGN_PATH,
    out_path=NPZ_PATH,
    min_rating=400,
    max_rating=2400,
    num_bins=10,
    max_games=None,
)

NPZ_PATH

In [None]:
import numpy as np
from pathlib import Path

# A larger reference NPZ shard, used to compare its keys/shapes with the small gte NPZ.
# Skipped (rather than crashing Restart & Run All) when the shard is not present locally.
BIG_NPZ_PATH = Path("data/rym_2025_jan_apr_tc300+0_bin200_test_shard000.npz")

if BIG_NPZ_PATH.exists():
    # The context manager closes the archive's file handle once inspection is done.
    with np.load(BIG_NPZ_PATH) as npz_big:
        print("Keys:", npz_big.files)

        for k in npz_big.files:
            arr = npz_big[k]
            print(f"{k:8s} shape={arr.shape}, dtype={arr.dtype}")
else:
    print(f"Skipping: {BIG_NPZ_PATH} not found")

In [None]:
import numpy as np

# Inspect the freshly built NPZ; the context manager closes the archive's file handle.
with np.load(NPZ_PATH) as npz:
    print("Keys:", npz.files)

    for k in npz.files:
        arr = npz[k]
        print(f"{k:8s} shape={arr.shape}, dtype={arr.dtype}")

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

from scripts.ply_features import NUM_PLANES, plane_label_for_index  # your code


def npz_to_debug_dataframe(npz_path: str | Path, max_plies: int | None = None) -> pd.DataFrame:
    """
    Debug view of an RYM NPZ.

    Each row = one ply.
    Columns:
      - 'i', 'game_id', 'ply_idx', 'y_bin', 'y_elo'
      - one column per plane, named by plane_label_for_index(idx),
        where each cell is a length-64 list (flattened 8×8 vector of 0/1).

    WARNING: only use this for small NPZs (e.g. a single game),
    or with max_plies, because this blows up fast.
    """
    npz_path = Path(npz_path)
    data = np.load(npz_path)

    X = data["X"]        # (N, NUM_PLANES, 8, 8)
    y_bin = data["y_bin"]
    y_elo = data["y_elo"]
    game_id = data["game_id"]
    ply_idx = data["ply_idx"]

    N, C, H, W = X.shape
    assert C == NUM_PLANES, f"Expected {NUM_PLANES} planes, got {C}"

    if max_plies is not None:
        N = min(N, max_plies)

    # Precompute plane labels
    plane_labels = [plane_label_for_index(p) for p in range(NUM_PLANES)]

    rows: list[dict] = []

    for i in range(N):
        planes = X[i]  # (NUM_PLANES, 8, 8)

        row: dict = {
            "i": int(i),
            "game_id": int(game_id[i]),
            "ply_idx": int(ply_idx[i]),
            "y_bin": int(y_bin[i]),
            "y_elo": float(y_elo[i]),
        }

        for p in range(NUM_PLANES):
            label = plane_labels[p]
            # Flatten 8×8 to 64-length vector of ints
            row[label] = planes[p].reshape(-1).astype(int).tolist()

        rows.append(row)

    df = pd.DataFrame(rows)
    return df

In [None]:
# gte.pgn is small (likely a single game), so convert every ply without a max_plies cap.
df_debug = npz_to_debug_dataframe(npz_path=NPZ_PATH, max_plies=None)
df_debug.head()

In [None]:
from pathlib import Path
import chess
import chess.pgn

def build_ply_metadata_from_pgn(pgn_path: str | Path) -> dict[tuple[int, int], dict]:
    """
    Index every mainline ply of every game in a PGN file.

    Games are numbered 0, 1, 2, ... in file order, and plies 0, 1, 2, ...
    within each game. The result maps (game_idx, ply_idx) to a dict with:
        'fen_pre'  : FEN of the position before the move
        'fen_post' : FEN of the position after the move
        'move_uci' : the move in UCI notation

    NOTE(review): callers key this by the NPZ's (game_id, ply_idx); that only
    lines up if pgn_to_npz numbers games/plies the same way (no skipped games) — confirm.
    """
    meta: dict[tuple[int, int], dict] = {}

    with Path(pgn_path).open("r", encoding="utf-8") as handle:
        # read_game returns None at end of file, which terminates this iterator.
        games = iter(lambda: chess.pgn.read_game(handle), None)
        for game_idx, game in enumerate(games):
            board = game.board()
            for ply_idx, move in enumerate(game.mainline_moves()):
                before = board.fen()
                board.push(move)
                meta[(game_idx, ply_idx)] = {
                    "fen_pre": before,
                    "fen_post": board.fen(),
                    "move_uci": move.uci(),
                }

    return meta

In [None]:
# Map every (game_idx, ply_idx) in the PGN to its pre/post FEN and UCI move.
ply_meta = build_ply_metadata_from_pgn(pgn_path=PGN_PATH)
len(ply_meta)  # number of plies indexed

In [None]:
import numpy as np
import chess

def plane_squares_from_npz_plane(plane_2d: np.ndarray) -> list[chess.Square]:
    """
    Return the chess.Square values for every non-zero cell of an (8, 8) NPZ plane.

    The plane is flattened row-major and flat index k is looked up as
    chess.SQUARES[k], so row 0 maps to a1..h1 and row 7 to a8..h8.
    NOTE(review): that assumes the feature encoder uses python-chess square
    order for its 8x8 layout — confirm against scripts.ply_features.
    """
    assert plane_2d.shape == (8, 8)
    return [chess.SQUARES[k] for k in np.flatnonzero(plane_2d)]

In [None]:
from IPython.display import display, HTML
import chess.svg
import numpy as np

from scripts.ply_features import NUM_PLANES, plane_label_for_index  # your code


def show_npz_planes_for_row(
    npz_path: str | Path,
    ply_meta: dict[tuple[int, int], dict],
    row_index: int,
    size: int = 120,
    show_labels: bool = True,
    post_move_planes: set[int] | None = None,
):
    """
    Visualize one NPZ row (a single ply) as inline HTML.

    - Left: main board after the move, with last move highlighted.
    - Right: a grid with 8 columns and one cell per plane (NUM_PLANES cells),
      each highlighting the squares where X[row_index, plane] is non-zero.

    Parameters
    ----------
    npz_path : path to .npz file (e.g. gte_rym.npz); must contain X, game_id, ply_idx.
    ply_meta : mapping (game_id, ply_idx) -> {fen_pre, fen_post, move_uci}.
    row_index : index into X (0..N-1).
    size : pixel size of each plane board; the main board is drawn at 2 * size.
    show_labels : whether to show "<idx>: <label>" above each plane board.
    post_move_planes : plane indices whose squares are drawn over the post-move
        position; every other plane uses the pre-move position. Any iterable of
        ints is accepted. Defaults to planes 42-47 and 58-63.

    Raises
    ------
    IndexError : if row_index is outside [0, N).
    KeyError : if ply_meta has no entry for this row's (game_id, ply_idx).

    Note: X is read in full on every call, so looping over many rows re-reads
    the file each time.
    """
    import html

    if post_move_planes is None:
        # post-move aggregates & rule planes (per the current feature layout)
        post_move_planes = set(range(42, 48)) | set(range(58, 64))
    else:
        post_move_planes = set(post_move_planes)

    npz_path = Path(npz_path)

    # Close the archive as soon as this row's data has been extracted.
    with np.load(npz_path) as data:
        X = data["X"]          # expected (N, NUM_PLANES, 8, 8)
        N = X.shape[0]
        if not (0 <= row_index < N):
            raise IndexError(f"row_index {row_index} out of range [0, {N-1}]")

        planes = X[row_index]  # (NUM_PLANES, 8, 8)
        game_id = int(data["game_id"][row_index])
        ply_idx = int(data["ply_idx"][row_index])

    meta_key = (game_id, ply_idx)
    if meta_key not in ply_meta:
        raise KeyError(f"No PGN metadata for (game_id={game_id}, ply_idx={ply_idx})")

    meta = ply_meta[meta_key]
    fen_pre = meta["fen_pre"]
    fen_post = meta["fen_post"]
    move_uci = meta["move_uci"]

    board_pre = chess.Board(fen_pre)
    board_post = chess.Board(fen_post)
    move = chess.Move.from_uci(move_uci)

    # Main board: position AFTER the move, with last move highlighted
    main_svg = chess.svg.board(board=board_post, lastmove=move, size=size * 2)

    # Build the grid of plane boards
    cell_html_pieces = []

    for plane_idx in range(NUM_PLANES):
        plane = planes[plane_idx]  # (8, 8)
        sqs = plane_squares_from_npz_plane(plane)

        # Decide which board context to use
        ctx_board = board_post if plane_idx in post_move_planes else board_pre

        svg = chess.svg.board(board=ctx_board, squares=sqs, size=size)

        if show_labels:
            # Escape the label so characters like '<' or '&' cannot break the markup.
            label = html.escape(str(plane_label_for_index(plane_idx)))
            cell_html = f"""
            <div style="border: 1px solid #aaa; padding: 2px; box-sizing: border-box;">
              <div style="font-size: 9px; text-align: center; margin-bottom: 2px;
                          white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">
                {plane_idx}: {label}
              </div>
              {svg}
            </div>
            """
        else:
            cell_html = f"""
            <div style="border: 1px solid #aaa; padding: 2px; box-sizing: border-box;">
              {svg}
            </div>
            """

        cell_html_pieces.append(cell_html)

    grid_html = f"""
    <div style="
        display: grid;
        grid-template-columns: repeat(8, {size + 12}px);
        grid-auto-rows: auto;
        gap: 4px;
    ">
      {''.join(cell_html_pieces)}
    </div>
    """

    # Wrap everything: main board on the left, planes grid on the right
    full_html = f"""
    <div style="display: flex; flex-direction: row; gap: 16px; align-items: flex-start;">
      <div>
        <div style="font-weight: bold; margin-bottom: 4px;">
          Game {game_id}, ply {ply_idx} — move {move_uci}
        </div>
        {main_svg}
      </div>
      <div>
        {grid_html}
      </div>
    </div>
    """

    display(HTML(full_html))

In [None]:
# Rebuilds the same mapping as the earlier ply_meta cell (same call, same PGN).
# Redundant when running top to bottom; kept so the rendering cell below can be re-run on its own.
ply_meta = build_ply_metadata_from_pgn(PGN_PATH)

In [None]:
import numpy as np

# Render only the first few rows: each row draws 1 + NUM_PLANES SVG boards,
# which gets heavy quickly. Set MAX_ROWS_TO_SHOW = None to render every row.
MAX_ROWS_TO_SHOW = 6

with np.load(NPZ_PATH) as data:
    N = data["X"].shape[0]

rows_to_show = N if MAX_ROWS_TO_SHOW is None else min(N, MAX_ROWS_TO_SHOW)
for i in range(rows_to_show):
    show_npz_planes_for_row(NPZ_PATH, ply_meta, row_index=i, size=120, show_labels=False)


In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from scripts.ply_features import NUM_PLANES
from scripts.rym_models import get_model


In [None]:
def load_resnet_from_checkpoint(
    ckpt_path: str | Path,
    num_planes: int | None = None,
    num_bins: int | None = None,
    config_id: int | None = None,
    device: str | torch.device = "cuda",
) -> torch.nn.Module:
    """
    Load a ResNetRatingModel from a checkpoint that looks like:

        {
            "model_state": <state_dict>,
            "model_type": "resnet",
            "config_id": 0,
            "num_planes": 64,
            "num_bins": 10,
            ...
        }

    Also supports:
        {"model_state_dict": <state_dict>, ...}
    or a bare state_dict.
    """
    from scripts.rym_models import get_model  # ensure import here or at top

    device = torch.device(device)
    ckpt_path = Path(ckpt_path)
    state = torch.load(ckpt_path, map_location=device)

    # --- Case 1: your style with "model_state" + metadata ---
    if isinstance(state, dict) and "model_state" in state:
        ckpt_model_type = state.get("model_type", "resnet")
        ckpt_num_planes = state.get("num_planes", num_planes)
        ckpt_num_bins = state.get("num_bins", num_bins)
        ckpt_config_id = state.get("config_id", config_id if config_id is not None else 0)

        if ckpt_num_planes is None or ckpt_num_bins is None:
            raise ValueError("Checkpoint is missing num_planes/num_bins and none were provided.")

        # Build architecture exactly as in training
        model = get_model(
            model_type=ckpt_model_type,
            num_planes=ckpt_num_planes,
            num_bins=ckpt_num_bins,
            config_id=ckpt_config_id,
        ).to(device)

        model.load_state_dict(state["model_state"])
        model.eval()
        return model

    # --- Case 2: "model_state_dict" style ---
    if isinstance(state, dict) and "model_state_dict" in state:
        state_dict = state["model_state_dict"]
        if num_planes is None or num_bins is None:
            raise ValueError("Need num_planes and num_bins to rebuild model for this checkpoint.")

        model = get_model(
            model_type="resnet",
            num_planes=num_planes,
            num_bins=num_bins,
            config_id=config_id if config_id is not None else 0,
        ).to(device)

        model.load_state_dict(state_dict)
        model.eval()
        return model

    # --- Case 3: bare state_dict ---
    # Assume `state` is itself a state_dict
    if num_planes is None or num_bins is None:
        raise ValueError("Need num_planes and num_bins to rebuild model for bare state_dict.")

    model = get_model(
        model_type="resnet",
        num_planes=num_planes,
        num_bins=num_bins,
        config_id=config_id if config_id is not None else 0,
    ).to(device)

    model.load_state_dict(state)
    model.eval()
    return model


In [None]:
from torch.utils.data import DataLoader, TensorDataset
def predict_probs_for_npz(
    npz_path: str | Path,
    ckpt_path: str | Path,
    batch_size: int = 256,
    device: str | torch.device = None,
) -> pd.DataFrame:
    npz_path = Path(npz_path)
    data = np.load(npz_path)

    X = data["X"]            # (N, NUM_PLANES, 8, 8)
    y_bin = data["y_bin"]
    y_elo = data["y_elo"]
    game_id = data["game_id"]
    ply_idx = data["ply_idx"]

    num_bins_npz = int(data["num_bins"])

    N, C, H, W = X.shape
    assert C == NUM_PLANES, f"Expected {NUM_PLANES} planes, got {C}"

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Build and load model; loader will mostly use checkpoint metadata
    model = load_resnet_from_checkpoint(
        ckpt_path=ckpt_path,
        num_planes=NUM_PLANES,
        num_bins=num_bins_npz,
        config_id=0,
        device=device,
    )

    # Dataset + loader
    X_tensor = torch.from_numpy(X.astype(np.float32))
    ds = TensorDataset(X_tensor)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    all_probs = []
    all_rating_preds = []

    with torch.no_grad():
        for (xb,) in loader:
            xb = xb.to(device)
            logits, rating_pred = model(xb)  # logits: (B, num_bins)
            probs = torch.softmax(logits, dim=-1)

            all_probs.append(probs.cpu().numpy())
            rating_pred_np = rating_pred.squeeze(-1).cpu().numpy()
            all_rating_preds.append(rating_pred_np)

    probs_full = np.vstack(all_probs)
    rating_pred_full = np.concatenate(all_rating_preds)

    assert probs_full.shape[0] == N

    num_bins = probs_full.shape[1]  # from model, just in case

    rows = []
    for i in range(N):
        row = {
            "i": int(i),
            "game_id": int(game_id[i]),
            "ply_idx": int(ply_idx[i]),
            "y_bin": int(y_bin[i]),
            "y_elo": float(y_elo[i]),
            "pred_bin": int(probs_full[i].argmax()),
            "pred_rating": float(rating_pred_full[i]),
        }
        for b in range(num_bins):
            row[f"prob_band_{b}"] = float(probs_full[i, b])
        rows.append(row)

    df = pd.DataFrame(rows)
    return df

In [None]:
# Keep paths as Path objects, consistent with the first cell (the helpers accept str or Path).
NPZ_PATH = Path("gte_rym.npz")
CKPT_PATH = Path("models/rym_2017-04_baselines/rym_resnet_cfg0.pt")

df_probs = predict_probs_for_npz(NPZ_PATH, CKPT_PATH, batch_size=64)
df_probs.head()


In [None]:
# Persist the per-ply predictions (no index column) and show where they went.
OUT_CSV = "gte_rym_resnet_probs.csv"
df_probs.to_csv(path_or_buf=OUT_CSV, index=False)
OUT_CSV

In [None]:
# Assuming you already have:
#   df_debug  = npz_to_debug_dataframe(NPZ_PATH)
#   df_probs  = predict_probs_for_npz(NPZ_PATH, CKPT_PATH)

# Attach model predictions to the per-plane debug view. y_bin / y_elo are taken
# from df_debug only, so they are not duplicated with _x/_y suffixes.
prob_band_cols = [c for c in df_probs.columns if c.startswith("prob_band_")]

df_merged = df_debug.merge(
    df_probs[["i", "game_id", "ply_idx", "pred_bin", "pred_rating"] + prob_band_cols],
    on=["i", "game_id", "ply_idx"],
    how="left",
    validate="one_to_one",  # each ply appears once on both sides
)

assert df_merged["pred_bin"].notna().all(), "some debug rows have no model prediction"


In [None]:
import numpy as np
from pathlib import Path

NPZ_PATH = Path("gte_rym.npz")   # same one you used for predict_probs_for_npz

# Read the rating-band layout stored in the NPZ; the context manager closes the file.
with np.load(NPZ_PATH) as meta:
    min_rating = int(meta["min_rating"])
    max_rating = int(meta["max_rating"])
    num_bins = int(meta["num_bins"])

if num_bins <= 0:
    raise ValueError(f"num_bins must be positive, got {num_bins}")

# Equal-width bands over [min_rating, max_rating]; edges are truncated to ints for display.
width = (max_rating - min_rating) / num_bins

print("Bands:")
for b in range(num_bins):
    lo = int(min_rating + b * width)
    hi = int(min_rating + (b + 1) * width)
    print(f"band {b}: {lo}-{hi}")

In [None]:
import pandas as pd

# 1) Band index of every prob_band_<b> column; order the columns by that index
band_of = {
    col: int(col.rsplit("_", 1)[1])
    for col in df_merged.columns
    if col.startswith("prob_band_")
}
prob_cols = sorted(band_of, key=band_of.__getitem__)

# 2) Friendlier names from each band's Elo edges, e.g. prob_band_0 -> prob_400_600
rename_map = {}
for col in prob_cols:
    band = band_of[col]
    lo_edge = int(min_rating + band * width)
    hi_edge = int(min_rating + (band + 1) * width)
    rename_map[col] = f"prob_{lo_edge}_{hi_edge}"

rename_map


In [None]:
# 3) Core per-ply columns followed by the band probabilities, in band order
base_cols = ["game_id", "ply_idx", "y_bin", "y_elo", "pred_bin", "pred_rating"]

# 4) Select and relabel the bands with their Elo ranges in one chain
df_out = (
    df_merged.loc[:, base_cols + prob_cols]
    .rename(columns=rename_map)
)

# 5) Round every floating-point column to 4 decimals (integer columns untouched)
float_cols = df_out.select_dtypes(include="float").columns
df_out = df_out.round({col: 4 for col in float_cols})

df_out

In [None]:
# %%
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
from matplotlib.patches import Patch

# Stacked per-ply probability bars for each rating band, with the model's
# predicted Elo overlaid on a twin axis.
# df_out, min_rating, max_rating assumed to exist from previous cells.

if df_out.empty:
    raise ValueError("df_out has no rows; nothing to plot")

# Sort by game + ply; if multiple games are present, restrict to the first
df_plot = df_out.sort_values(["game_id", "ply_idx"]).copy()
first_game_id = df_plot["game_id"].iloc[0]
if df_plot["game_id"].nunique() > 1:
    print(f"Multiple games found, plotting only game_id={first_game_id}")
    df_plot = df_plot[df_plot["game_id"] == first_game_id]

plies       = df_plot["ply_idx"].to_numpy()
pred_rating = df_plot["pred_rating"].to_numpy()
pred_bin    = df_plot["pred_bin"].to_numpy().astype(int)

# All probability columns, e.g. prob_400_600, prob_600_800, ...
prob_cols = [c for c in df_plot.columns if c.startswith("prob_")]
if not prob_cols:
    raise ValueError("df_out has no prob_* columns to plot")

# Sort bands by their lower Elo edge so index order matches bin index
prob_cols   = sorted(prob_cols, key=lambda c: int(c.split("_")[1]))
band_labels = [f"{c.split('_')[1]}–{c.split('_')[2]}" for c in prob_cols]

# probs shape: (num_bands, num_plies)
probs = df_plot[prob_cols].to_numpy().T
num_bands, num_plies = probs.shape

# Stacked bottoms: band b starts at the cumulative probability of bands 0..b-1
cumulative = np.cumsum(probs, axis=0)
bottoms = np.vstack([np.zeros((1, num_plies), dtype=probs.dtype), cumulative[:-1]])

fig, ax_prob = plt.subplots(figsize=(14, 6))

cmap   = plt.get_cmap("tab20")
colors = [cmap(i % cmap.N) for i in range(num_bands)]
band_handles = []

bar_width = 0.9  # a bit narrower than 1 ply so there are tiny gaps

# --- Stacked probability bars with alpha depending on predicted band ---
for b in range(num_bands):
    bars = ax_prob.bar(
        plies,
        probs[b],
        bottom=bottoms[b],
        width=bar_width,
        color=colors[b],
        edgecolor="none",
        align="center",
        label=band_labels[b],
    )

    # Per-ply alpha: brighter if this band is the predicted bin, dimmer otherwise
    for patch, pb in zip(bars.patches, pred_bin):
        patch.set_alpha(0.9 if pb == b else 0.25)

    band_handles.append(bars)

# Axes labels / limits for probabilities
ax_prob.set_title(f"Rating-band probabilities per ply (game_id={first_game_id})")
ax_prob.set_xlabel("Ply index")
ax_prob.set_ylabel("Probability")
ax_prob.set_xlim(plies.min() - 0.5, plies.max() + 0.5)
ax_prob.set_ylim(0.0, 1.0)
ax_prob.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

# Minor ticks every 2 plies; keep default major ticks/labels
ax_prob.set_xticks(np.arange(plies.min(), plies.max() + 1, 2), minor=True)

# --- Predicted Elo line on twin axis ---
ax_elo = ax_prob.twinx()
elo_line, = ax_elo.plot(
    plies,
    pred_rating,
    color="black",
    linewidth=1.5,
    label="Predicted Elo",
)
ax_elo.set_ylabel("Predicted Elo")
ax_elo.set_ylim(min_rating - 50, max_rating + 50)

# --- Legends ---
# Make room at bottom for the band legend
fig.subplots_adjust(bottom=0.25)

# Bands legend at the bottom. Use one solid proxy patch per band: a BarContainer's
# legend swatch copies its first bar, whose alpha varies with the predicted band.
band_proxies = [Patch(facecolor=colors[b], label=band_labels[b]) for b in range(num_bands)]
legend_bands = ax_prob.legend(
    handles=band_proxies,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.18),
    ncol=min(num_bands, 5),
    title="Rating bands (Elo)",
    frameon=False,
)
ax_prob.add_artist(legend_bands)

# Elo line legend in a corner
ax_elo.legend(loc="upper left")

plt.tight_layout()
plt.show()