In [None]:
# ---- ablation keep list (paste from TensorBoard text) ----
# Comma-separated channel names the ablated model was trained with. Dynamic names must
# exist in the memmap manifest's "channels"; static names must appear in STATIC_NAMES.
KEEP_CSV = (
    "dist_to_ball_norm,sin_to_ball,cos_to_ball,atk_team_1hot,def_team_1hot,"
    "att_glob_vx_norm,att_glob_vy_norm,def_glob_vx_norm,def_glob_vy_norm,"
    "gaussian_control,nearest_def_dist_norm,"
    "boundary_dist_norm,centerline_dist_norm,goal_sin,goal_cos,goal_dist_norm"
)

# Whitespace-trimmed names, empty entries (e.g. from a trailing comma) dropped.
KEEP_NAMES = list(filter(None, map(str.strip, KEEP_CSV.split(","))))

# Static channel names MUST match PitchStaticChannels.forward() stacking order in your code
STATIC_NAMES = [
    "boundary_dist_norm",
    "centerline_dist_norm",
    "goal_sin",
    "goal_cos",
    "goal_dist_norm",
]
STATIC_SET = set(STATIC_NAMES)

def build_keep_indices(mm_channels: list[str], keep_names: list[str], static_names=None):
    """Split an ablation keep-list into dynamic (memmap) and static channel selections.

    Each name in ``keep_names`` is classified as static if it appears in
    ``static_names`` (checked first), otherwise as dynamic if it appears in
    ``mm_channels``. Output order follows the order of ``keep_names``.

    Args:
        mm_channels: dynamic channel names in memmap channel order (manifest "channels").
        keep_names: channel names to keep.
        static_names: static channel names in PitchStaticChannels stacking order;
            defaults to the module-level ``STATIC_NAMES``.

    Returns:
        (dyn_keep, dyn_keep_idxs, static_keep, static_keep_idxs): kept names of each
        kind and their indices into ``mm_channels`` / ``static_names``.

    Raises:
        ValueError: if a name is in neither list, or a name is listed more than once
            (which would feed the same channel to the model twice).
    """
    if static_names is None:
        static_names = STATIC_NAMES
    static_names = list(static_names)
    static_set = set(static_names)
    dyn_name_to_idx = {n: i for i, n in enumerate(mm_channels)}

    dyn_keep, static_keep, unknown = [], [], []
    for n in keep_names:
        if n in static_set:
            static_keep.append(n)
        elif n in dyn_name_to_idx:
            dyn_keep.append(n)
        else:
            unknown.append(n)

    if unknown:
        raise ValueError(f"Keep list contains unknown channels not in manifest or static set: {unknown}")

    # A repeated name would silently duplicate an input channel and shift the rest.
    seen, dupes = set(), []
    for n in keep_names:
        if n in seen and n not in dupes:
            dupes.append(n)
        seen.add(n)
    if dupes:
        raise ValueError(f"Keep list contains duplicate channels: {dupes}")

    dyn_keep_idxs = [dyn_name_to_idx[n] for n in dyn_keep]
    static_keep_idxs = [static_names.index(n) for n in static_keep]
    return dyn_keep, dyn_keep_idxs, static_keep, static_keep_idxs


In [None]:
import os, json
import numpy as np
import torch
from collections import OrderedDict
# Step up from the notebook directory to the project root so `utils.*` / `models.*`
# imports and the relative data/checkpoint paths resolve. Guarded so that re-running
# this cell in the same kernel does not keep climbing directories.
if not os.path.isdir("utils"):
    os.chdir('..')

from utils.visualizer import SoccerVisualizer  # same import style as eval-viz.py
from utils.train_utils import * 
# --- memmap helpers (adapted from eval-viz.py) ---
class MemmapShard:
    """Read-only numpy memmap views over one shard's inputs and targets.

    ``X``: float16 array of shape (n, C, H, W) mapped from ``root_dir/x_name``.
    ``T``: float32 array of shape (n, 3) mapped from ``root_dir/t_name``; its columns are
    consumed by ``MemmapManifest.load_by_shard_local`` as (dst_0, dst_1, label).
    """

    def __init__(self, root_dir: str, x_name: str, t_name: str, n: int, C: int, H: int, W: int):
        self.n = int(n)
        self.C, self.H, self.W = int(C), int(H), int(W)
        self.X = np.memmap(os.path.join(root_dir, x_name), mode="r", dtype=np.float16,
                           shape=(self.n, self.C, self.H, self.W))
        self.T = np.memmap(os.path.join(root_dir, t_name), mode="r", dtype=np.float32,
                           shape=(self.n, 3))


class MemmapManifest:
    """Index over a directory of memmap shards described by ``manifest.json``.

    The manifest must have ``format == "memmap_v1"``, integer ``C``/``H``/``W``, a
    ``shards`` list whose entries carry ``x_path``, ``t_path`` and ``n``, and an optional
    ``channels`` list of dynamic-channel names. Shards are opened lazily and kept in a
    small LRU cache of at most ``cache_size`` entries.
    """

    def __init__(self, root_dir: str, cache_size: int = 2):
        self.root_dir = root_dir
        self.cache_size = int(cache_size)

        with open(os.path.join(root_dir, "manifest.json"), "r") as f:
            man = json.load(f)
        # Explicit check rather than `assert`, so it still fires under `python -O`.
        fmt = man.get("format")
        if fmt != "memmap_v1":
            raise ValueError(f"unsupported manifest format {fmt!r} in {root_dir!r} (expected 'memmap_v1')")

        self.C = int(man["C"]); self.H = int(man["H"]); self.W = int(man["W"])
        self.channels = list(man.get("channels", []))
        self.shards = list(man["shards"])

        # starts[i] = global index of shard i's first sample (prefix sum of shard sizes).
        self.starts = []
        cur = 0
        for s in self.shards:
            self.starts.append(cur)
            cur += int(s["n"])
        self.total = cur

        self._cache = OrderedDict()

    def _open_shard(self, shard_id: int) -> MemmapShard:
        """Return the MemmapShard for ``shard_id``, opening it and updating the LRU cache.

        Raises:
            IndexError: if ``shard_id`` is not in ``[0, len(self.shards))``.
        """
        shard_id = int(shard_id)
        # Reject negatives too: list indexing would wrap them and cache one shard under two keys.
        if not 0 <= shard_id < len(self.shards):
            raise IndexError(f"shard_id {shard_id} out of range [0, {len(self.shards)})")
        if shard_id in self._cache:
            self._cache.move_to_end(shard_id)
            return self._cache[shard_id]
        s = self.shards[shard_id]
        mm = MemmapShard(self.root_dir, s["x_path"], s["t_path"], int(s["n"]), self.C, self.H, self.W)
        self._cache[shard_id] = mm
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)  # evict least-recently used
        return mm

    def locate(self, k: int):
        """Map a global sample index ``k`` to ``(shard_id, local_index)``.

        Raises:
            RuntimeError: if ``k`` is outside ``[0, self.total)``.
        """
        from bisect import bisect_right

        k = int(k)
        if not 0 <= k < self.total:
            raise RuntimeError(f"locate failed: index {k} out of range [0, {self.total})")
        # Last shard whose start <= k; bisect_right steps past zero-length shards sharing that start.
        shard_id = bisect_right(self.starts, k) - 1
        return shard_id, k - self.starts[shard_id]

    def load_by_shard_local(self, shard_id: int, local_i: int, swap_xy: bool = False):
        """Load one sample from a shard.

        Returns:
            x: float32 tensor (C, H, W) with NaN/+-inf replaced by 0.
            dst_xy: int64 tensor of shape (2,) built from ``T[local_i, :2]``; its two
                entries are reversed when ``swap_xy`` is True.
            y: float32 scalar tensor from ``T[local_i, 2]`` with NaN/+-inf replaced by 0.

        Raises:
            IndexError: if ``shard_id`` or ``local_i`` is out of range (negatives included).
        """
        shard = self._open_shard(shard_id)
        local_i = int(local_i)
        # numpy would silently wrap a negative index to the end of the shard.
        if not 0 <= local_i < shard.n:
            raise IndexError(f"local_i {local_i} out of range [0, {shard.n}) for shard {shard_id}")
        x = torch.from_numpy(np.array(shard.X[local_i], copy=True)).float()  # (C,H,W)
        t = shard.T[local_i]
        dst_xy = torch.tensor(t[:2], dtype=torch.long)
        if swap_xy:
            dst_xy = dst_xy[[1, 0]]
        y = torch.tensor(float(t[2]), dtype=torch.float32)
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        y = torch.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
        return x, dst_xy, y

# ---- choose dataset + load manifest once ----
DATA_ROOT = "data/finaldata-3meter"  # has train/ and val/
SPLIT = "val"                                # "train" or "val"
# Optional seed for the np.random draws used for shard/sample selection.
# None keeps a fresh random draw per run; an int makes Restart & Run All reproducible.
SEED = None
if SEED is not None:
    np.random.seed(SEED)

mm = MemmapManifest(os.path.join(DATA_ROOT, SPLIT), cache_size=2)
print(f"[data] split={SPLIT} N={mm.total} C,H,W={mm.C},{mm.H},{mm.W}")


# ---- shard selection ----
# You can lock to a specific shard_id to avoid re-opening new ones.
# NOTE: the plotting helpers take shard-relative indices, so a given local_i names a
# different sample on every shard -- pin SHARD_ID to revisit a saved example.
SHARD_ID = None  # set to int to force a specific shard
if not mm.shards:
    raise ValueError(f"manifest under {os.path.join(DATA_ROOT, SPLIT)!r} lists no shards")
if SHARD_ID is None:
    SHARD_ID = int(np.random.randint(0, len(mm.shards)))
if not 0 <= SHARD_ID < len(mm.shards):
    raise IndexError(f"SHARD_ID={SHARD_ID} out of range [0, {len(mm.shards)})")
shard_n = int(mm.shards[SHARD_ID]["n"])
print(f"[data] using shard_id={SHARD_ID} n={shard_n}")

def sample_from_loaded_shard(local_i=None, swap_xy=False):
    """Load one sample from the pinned shard ``SHARD_ID``.

    When ``local_i`` is None, an index in ``[0, shard_n)`` is drawn with ``np.random``.
    Returns ``(local_i, x_chw, dst_xy, y)`` where the last three come from
    ``mm.load_by_shard_local``.
    """
    idx = int(np.random.randint(0, shard_n)) if local_i is None else local_i
    sample = mm.load_by_shard_local(SHARD_ID, idx, swap_xy=swap_xy)
    return (idx, *sample)

# Resolve the keep list against this split's manifest channels and the static schema.
(
    dyn_keep_names,
    dyn_keep_idxs,
    static_keep_names,
    static_keep_idxs,
) = build_keep_indices(mm.channels, KEEP_NAMES)

# Echo the resolved ablation so it can be checked against the training run's keep list.
for _label, _value in (
    ("[ablation] dyn keep:", dyn_keep_names),
    ("[ablation] static keep:", static_keep_names),
    ("[ablation] dyn_keep_idxs:", dyn_keep_idxs),
    ("[ablation] static_keep_idxs:", static_keep_idxs),
):
    print(_label, _value)

In [None]:
import torch
from utils.static_maps import PitchStaticChannels, PitchDims

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- static channels (same as training) ---
static = PitchStaticChannels(dims=PitchDims(H=mm.H, W=mm.W)).to(DEVICE)

# infer full static count (from PitchStaticChannels)
with torch.no_grad():
    C_static_full = int(static.forward().shape[0])

# static_keep_idxs index into STATIC_NAMES and dyn_keep_idxs into the manifest channel
# list. Fail loudly if either name list disagrees with the tensors it describes;
# otherwise the slicing at inference time would silently pick the wrong channels.
if C_static_full != len(STATIC_NAMES):
    raise ValueError(
        f"PitchStaticChannels produces {C_static_full} channels but STATIC_NAMES lists "
        f"{len(STATIC_NAMES)}: {STATIC_NAMES}"
    )
if dyn_keep_idxs and max(dyn_keep_idxs) >= mm.C:
    raise ValueError(
        f"dyn_keep_idxs {dyn_keep_idxs} exceed the memmap channel count C={mm.C}; "
        f"manifest 'channels' ({len(mm.channels)} names) does not match C"
    )

# Effective channels after ablation slicing
C_dyn_eff = len(dyn_keep_idxs)
C_static_eff = len(static_keep_idxs)
IN_CHANNELS = C_dyn_eff + C_static_eff

print(f"[static] C_dyn_full={mm.C} C_static_full={C_static_full}")
print(f"[static] C_dyn_eff={C_dyn_eff} C_static_eff={C_static_eff} -> in_channels={IN_CHANNELS}")


# --- model (replace with your actual constructor) ---
from models.footballmap import PassMap
#model = BetterSoccerMap2Head(in_channels=IN_CHANNELS, base=64, blocks_per_stage=2, dropout=0.0).to(DEVICE).float()
model = PassMap(in_channels=IN_CHANNELS, base=64, blocks_per_stage=4).to(DEVICE).float()
#model =  PitchVisionNet(C_dyn+C_static, base=64, blocks_per_stage=3).to(DEVICE).float()
# --- checkpoint loading (adapt to your ckpt format) ---
CKPT_PATH = "overnight-training-runs/20251217-015059_PassMap_drop_ball_vel/best_ckpt.pt"
# NOTE(security): weights_only=False unpickles arbitrary objects stored in the file
# (presumably the saved `args`); only load checkpoints you produced or trust.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=False)

# common patterns: first matching key wins, else the object is the state dict itself
sd = ckpt
if isinstance(ckpt, dict):
    for _key in ("model_state", "model", "state_dict"):
        if _key in ckpt:
            sd = ckpt[_key]
            break

missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"[ckpt] loaded {CKPT_PATH}")
# strict=False tolerates key mismatches; show some names so a systematic mismatch
# (e.g. a "module." prefix) is visible instead of silently running untrained weights.
if missing: print(f"[warn] missing keys: {len(missing)} e.g. {list(missing)[:5]}")
if unexpected: print(f"[warn] unexpected keys: {len(unexpected)} e.g. {list(unexpected)[:5]}")

# Inference only: put the model in eval mode so any BatchNorm/Dropout layers use
# inference behaviour (the plotting cells call the model with batch size 1).
model.eval()

if isinstance(ckpt, dict) and "args" in ckpt:
    print(ckpt["args"])
else:
    print("[ckpt] no 'args' entry in checkpoint")


In [None]:
import numpy as np
import torch

def compute_twohead_maps(out: dict, index: int = 0):
    """Turn two-head model logits into probability maps for one batch element.

    Args:
        out: model output dict with "dest_logits" and "succ_logits", each shaped
            (B, 1, H, W) or (B, H, W).
        index: batch element to return (default 0).

    Returns:
        (dest_probs, succ_probs, comp_map), each (H, W):
            dest_probs -- softmax over all H*W cells (sums to 1),
            succ_probs -- elementwise sigmoid of the success logits,
            comp_map   -- their elementwise product.
    """
    dest = out["dest_logits"]
    succ = out["succ_logits"]

    # Drop the singleton channel dimension if present.
    if dest.dim() == 4:
        dest = dest[:, 0]
    if succ.dim() == 4:
        succ = succ[:, 0]

    B, H, W = dest.shape
    # reshape rather than view: model outputs are not guaranteed to be contiguous.
    dest_probs = torch.softmax(dest.reshape(B, -1), dim=1).reshape(B, H, W)
    succ_probs = torch.sigmoid(succ)
    comp_map = dest_probs * succ_probs
    return dest_probs[index], succ_probs[index], comp_map[index]

def infer_and_plot(local_i=None, swap_xy=False, coords_are_centers=False):
    """Run the model on one sample of the pinned shard and plot its three output maps
    (destination, success, completed-pass) as separate titled figures.

    Args:
        local_i: index within shard ``SHARD_ID``; None draws one at random.
        swap_xy: forwarded to the loader; reverses the stored destination coordinates.
        coords_are_centers: if False, +0.5 is added to the destination coords so the
            marker sits in the middle of its cell.

    Returns:
        The local index that was plotted.
    """
    local_i, x_chw, dst_xy, y = sample_from_loaded_shard(local_i=local_i, swap_xy=swap_xy)

    # Build the input the ablated model expects: kept dynamic channels, then kept static
    # channels (IN_CHANNELS = len(dyn_keep_idxs) + len(static_keep_idxs)). Appending the
    # full static stack to the full dynamic stack gives a different channel count/order.
    X = x_chw.unsqueeze(0).to(DEVICE)[:, dyn_keep_idxs]
    st = static.expand_to_batch(X.size(0)).to(device=X.device, dtype=X.dtype)
    X = torch.cat([X, st[:, static_keep_idxs]], dim=1)
    X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    with torch.no_grad():
        out = model(X)
        dest_map, succ_map, comp_map = compute_twohead_maps(out)

    # destination marker coords
    dst_x = float(dst_xy[0].item())
    dst_y = float(dst_xy[1].item())
    if not coords_are_centers:
        dst_x += 0.5
        dst_y += 0.5

    # start location from dist_to_ball argmin (channel 0 in your current schema)
    ball_dist = x_chw[0]
    flat_idx = torch.argmin(ball_dist)
    x_idx = (flat_idx // ball_dist.shape[1]).item()
    y_idx = (flat_idx % ball_dist.shape[1]).item()
    bx, by = float(x_idx) + 0.5, float(y_idx) + 0.5

    # occupancy maps for overlay (channel 3/4 in your current schema)
    in_pos = (x_chw[3] > 0).float()
    out_pos = (x_chw[4] > 0).float()

    vis = SoccerVisualizer(pitch_length=mm.H, pitch_width=mm.W, layout="x_rows")

    ok = "✓" if int(y.item()) == 1 else "✗"

    def _plot(heat_t: torch.Tensor, title: str):
        # One figure: possession overlay + heatmap + start/end markers.
        fig, ax, _ = vis.plot_state(
            in_possession=in_pos,
            out_possession=out_pos,
            heatmap=heat_t.detach().cpu(),
            cmap="Blues",
            heatmap_kwargs=dict(alpha=0.9),
            add_colorbar=True,
        )
        ax.scatter([bx], [by], c="black", s=30, marker="o", zorder=6, linewidths=0.5, label="Start")
        ax.scatter([dst_x], [dst_y], c="red", s=30, marker="o", zorder=6, linewidths=0.5, label="End")
        ax.set_title(title)
        fig.tight_layout()
        fig.legend()

    _plot(dest_map, f"Destination P(dest=cell | s) | pass {ok} | local_i={local_i}")
    _plot(succ_map, f"Success P(complete | s, cell) | pass {ok} | local_i={local_i}")
    _plot(comp_map, f"Completed-pass surface P(dest & complete | s) | pass {ok} | local_i={local_i}")

    return local_i

# Run a random example from the currently loaded shard
#infer_and_plot(local_i=None)


In [None]:
import matplotlib.pyplot as plt
import torch

def infer_and_plot_three(local_i=None, swap_xy=False, coords_are_centers=False):
    """Plot the destination, success and completed-pass maps for one sample as three
    separate untitled figures, calling ``plt.show()`` after each.

    Args:
        local_i: index within shard ``SHARD_ID``; None draws one at random.
        swap_xy: forwarded to the loader; reverses the stored destination coordinates.
        coords_are_centers: if False, +0.5 is added to the destination coords.

    Returns:
        The local index that was plotted.
    """
    local_i, x_chw, dst_xy, y = sample_from_loaded_shard(local_i=local_i, swap_xy=swap_xy)

    # Model input: kept dynamic channels followed by kept static channels.
    dyn = x_chw.unsqueeze(0).to(DEVICE)[:, dyn_keep_idxs]
    static_batch = static.expand_to_batch(dyn.size(0)).to(device=dyn.device, dtype=dyn.dtype)
    model_in = torch.nan_to_num(
        torch.cat([dyn, static_batch[:, static_keep_idxs]], dim=1),
        nan=0.0, posinf=0.0, neginf=0.0,
    )

    with torch.no_grad():
        dest_map, succ_map, comp_map = compute_twohead_maps(model(model_in))

    # Destination marker, optionally shifted to the cell centre.
    offset = 0.0 if coords_are_centers else 0.5
    end_x = float(dst_xy[0].item()) + offset
    end_y = float(dst_xy[1].item()) + offset

    # Approximate ball cell = argmin of channel 0 (assumed dist_to_ball in the current schema).
    ball_dist = x_chw[0]
    row, col = divmod(int(torch.argmin(ball_dist)), ball_dist.shape[1])
    start_x, start_y = row + 0.5, col + 0.5

    # Possession overlays from channels 3/4 (current schema).
    in_pos = (x_chw[3] > 0).float()
    out_pos = (x_chw[4] > 0).float()

    vis = SoccerVisualizer(pitch_length=mm.H, pitch_width=mm.W, layout="x_rows")
    markers = ((start_x, start_y, "black"), (end_x, end_y, "red"))

    for heat in (dest_map, succ_map, comp_map):
        fig, ax, _artists = vis.plot_state(
            in_possession=in_pos,
            out_possession=out_pos,
            heatmap=heat.detach().cpu(),
            add_colorbar=True,
            colorbar_kwargs={},
            plain=True,
        )
        for px, py, colour in markers:
            ax.scatter([px], [py], c=colour, s=25, marker="o", zorder=25)

        # Figures are deliberately left untitled.
        ax.set_title("")
        fig.suptitle("")
        plt.show()

    return local_i

# Example


In [None]:
import matplotlib.pyplot as plt
import torch

def infer_and_plot_triptych(local_i=None, swap_xy=False, coords_are_centers=False, figsize=(18, 6)):
    """Plot destination, success and completed-pass maps for one sample side by side
    in a single untitled 1x3 figure.

    Args:
        local_i: index within shard ``SHARD_ID``; None draws one at random.
        swap_xy: forwarded to the loader; reverses the stored destination coordinates.
        coords_are_centers: if False, +0.5 is added to the destination coords so the
            marker sits in the middle of its cell.
        figsize: matplotlib figure size of the 1x3 panel (default (18, 6)).

    Returns:
        The local index that was plotted.
    """
    local_i, x_chw, dst_xy, y = sample_from_loaded_shard(
        local_i=local_i, swap_xy=swap_xy
    )

    # The panels are deliberately untitled, so report which sample / label is shown.
    print(f"[triptych] shard_id={SHARD_ID} local_i={local_i} y={float(y.item()):g}")

    X = x_chw.unsqueeze(0).to(DEVICE)  # (1, C_dyn_full, H, W)

    # ---- slice dynamic channels ----
    X = X[:, dyn_keep_idxs]

    # ---- build + slice static channels ----
    st = static.expand_to_batch(X.size(0)).to(
        device=X.device, dtype=X.dtype
    )
    st = st[:, static_keep_idxs]

    # ---- concat -> model input ----
    X = torch.cat([X, st], dim=1)
    X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    with torch.no_grad():
        out = model(X)
        dest_map, succ_map, comp_map = compute_twohead_maps(out)

    # ---- destination marker ----
    dst_x = float(dst_xy[0].item())
    dst_y = float(dst_xy[1].item())
    if not coords_are_centers:
        dst_x += 0.5
        dst_y += 0.5

    # ---- approximate ball location (argmin of channel 0, dist_to_ball in the current schema) ----
    ball_dist = x_chw[0]
    flat_idx = torch.argmin(ball_dist)
    bx = float((flat_idx // ball_dist.shape[1]).item()) + 0.5
    by = float((flat_idx % ball_dist.shape[1]).item()) + 0.5

    # ---- possession overlays (channels 3/4 in the current schema) ----
    in_pos = (x_chw[3] > 0).float()
    out_pos = (x_chw[4] > 0).float()

    vis = SoccerVisualizer(
        pitch_length=mm.H,
        pitch_width=mm.W,
        layout="x_rows",
    )

    maps = [dest_map, succ_map, comp_map]

    # ---- single figure, 1x3 ----
    fig, axes = plt.subplots(
        1, 3, figsize=figsize, constrained_layout=True
    )

    for ax, heat in zip(axes, maps):
        vis.plot_state(
            in_possession=in_pos,
            out_possession=out_pos,
            heatmap=heat.detach().cpu(),
            ax=ax,
            add_colorbar=True,      # colorbar per subplot
            plain=True,
        )

        ax.scatter([bx], [by], c="black", s=25, zorder=25)
        ax.scatter([dst_x], [dst_y], c="red", s=25, zorder=25)

        # absolutely no titles
        ax.set_title("")

    fig.suptitle("")  # explicitly blank
    plt.show()

    return local_i


In [None]:
# Shard-relative sample indices to render (they refer to shard SHARD_ID).
for example_i in (6801, 481, 555, 1348, 3151, 4657):
    infer_and_plot_triptych(example_i)


Notable validation examples (shard-relative `local_i`): 885 in shard 0, 762 in shard 1, 6801 in shard 0.