In [1]:
import os, json
import numpy as np
import torch
from collections import OrderedDict

# Work from the repo root (one level above this notebook) so `utils.*` resolves.
# Remember the directory the kernel started in so re-running this cell does not
# keep climbing the tree (a bare os.chdir('..') is not idempotent).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

from utils.visualizer import SoccerVisualizer  # same import style as eval-viz.py
# NOTE(review): star import — later cells call load_first_shard / get_example,
# presumably provided here; switch to explicit names once confirmed.
from utils.train_utils import * 
# --- memmap helpers (adapted from eval-viz.py) ---
class MemmapShard:
    """Read-only memory-mapped view of one shard's features and targets.

    ``X`` is a float16 array shaped ``(n, C, H, W)`` and ``T`` is a float32
    array shaped ``(n, 3)``, both mapped from files under ``root_dir``.
    """

    def __init__(self, root_dir: str, x_name: str, t_name: str, n: int, C: int, H: int, W: int):
        count = int(n)
        chans, height, width = int(C), int(H), int(W)
        self.n = count
        self.C = chans
        self.H = height
        self.W = width

        x_file = os.path.join(root_dir, x_name)
        t_file = os.path.join(root_dir, t_name)
        self.X = np.memmap(x_file, mode="r", dtype=np.float16, shape=(count, chans, height, width))
        self.T = np.memmap(t_file, mode="r", dtype=np.float32, shape=(count, 3))

class MemmapManifest:
    """Index over a directory of memmap shards described by ``manifest.json``.

    The manifest must declare ``format == "memmap_v1"``, the per-sample feature
    shape ``C``/``H``/``W``, an optional ``channels`` list, and a ``shards``
    list whose entries carry ``n``, ``x_path`` and ``t_path``. Samples are
    addressed globally via ``locate`` or per shard via ``load_by_shard_local``.
    Opened shards are kept in a small LRU cache of ``cache_size`` entries.
    """

    def __init__(self, root_dir: str, cache_size: int = 2):
        self.root_dir = root_dir
        self.cache_size = int(cache_size)

        with open(os.path.join(root_dir, "manifest.json"), "r") as f:
            man = json.load(f)
        fmt = man.get("format")
        assert fmt == "memmap_v1", f"unsupported manifest format {fmt!r} in {root_dir}"

        self.C = int(man["C"]); self.H = int(man["H"]); self.W = int(man["W"])
        self.channels = list(man.get("channels", []))
        self.shards = list(man["shards"])

        # starts[i] is the global index of shard i's first sample.
        self.starts = []
        cur = 0
        for s in self.shards:
            self.starts.append(cur)
            cur += int(s["n"])
        self.total = cur

        # shard_id -> opened shard, ordered least- to most-recently used.
        self._cache = OrderedDict()

    def _open_shard(self, shard_id: int) -> "MemmapShard":
        """Return the (cached) shard object for ``shard_id``, evicting the LRU entry if full."""
        shard_id = int(shard_id)
        if shard_id in self._cache:
            self._cache.move_to_end(shard_id)
            return self._cache[shard_id]
        s = self.shards[shard_id]
        mm = MemmapShard(self.root_dir, s["x_path"], s["t_path"], int(s["n"]), self.C, self.H, self.W)
        self._cache[shard_id] = mm
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)
        return mm

    def locate(self, k: int):
        """Map global sample index ``k`` to ``(shard_id, local_index)``.

        Raises:
            RuntimeError: if ``k`` is outside ``[0, total)``.
        """
        import bisect

        k = int(k)
        if not 0 <= k < self.total:
            raise RuntimeError(f"locate failed: index {k} out of range [0, {self.total})")
        # Rightmost shard whose start <= k; this skips empty shards that share
        # the same start, so the chosen shard is guaranteed to contain k.
        shard_id = bisect.bisect_right(self.starts, k) - 1
        return shard_id, k - self.starts[shard_id]

    def load_by_shard_local(self, shard_id: int, local_i: int, swap_xy: bool = False):
        """Load one sample from shard ``shard_id`` at position ``local_i``.

        Returns:
            ``(x, dst_xy, y)`` where ``x`` is the float32 feature tensor for the
            sample, ``dst_xy`` is a length-2 long tensor from the first two target
            columns (swapped when ``swap_xy``), and ``y`` is a float32 scalar from
            the third target column. NaN/inf in features and targets become 0.

        Raises:
            IndexError: if ``local_i`` is not in ``[0, n)`` for that shard.
        """
        shard = self._open_shard(shard_id)
        local_i = int(local_i)
        if not 0 <= local_i < shard.n:
            # Negative indices would otherwise silently wrap to the shard's end.
            raise IndexError(f"local index {local_i} out of range for shard {shard_id} (n={shard.n})")
        x = torch.from_numpy(np.array(shard.X[local_i], copy=True)).float()  # (C,H,W)
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        # Sanitize targets *before* the integer cast: converting NaN/inf to long
        # does not raise and yields platform-dependent garbage coordinates.
        t = np.nan_to_num(np.asarray(shard.T[local_i], dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        dst_xy = torch.tensor(t[:2], dtype=torch.long)
        if swap_xy:
            dst_xy = dst_xy[[1, 0]]
        y = torch.tensor(float(t[2]), dtype=torch.float32)
        return x, dst_xy, y

# ---- choose dataset + load manifest once ----
# DATA_ROOT holds one sub-directory per split (train/ and val/).
DATA_ROOT = "data/finaldata-3meter"
SPLIT = "val"  # "train" or "val"
mm = MemmapManifest(
    os.path.join(DATA_ROOT, SPLIT),
    cache_size=2,
)
print(
    f"[data] split={SPLIT} N={mm.total} "
    f"C,H,W={mm.C},{mm.H},{mm.W}"
)


# ---- shard selection ----
# Lock SHARD_ID to an int to avoid re-opening new shards; leave it None to pick
# one at random. SEED (optional) seeds NumPy's global RNG so the random shard
# pick, and every later np.random draw in this notebook, is reproducible.
SEED = None      # set to int for reproducible sampling
SHARD_ID = None  # set to int to force a specific shard
if SEED is not None:
    np.random.seed(SEED)
if SHARD_ID is None:
    SHARD_ID = int(np.random.randint(0, len(mm.shards)))
elif not 0 <= int(SHARD_ID) < len(mm.shards):
    # A negative id would otherwise silently select a shard from the end.
    raise ValueError(f"SHARD_ID={SHARD_ID} out of range [0, {len(mm.shards)})")
SHARD_ID = int(SHARD_ID)
shard_n = int(mm.shards[SHARD_ID]["n"])
print(f"[data] using shard_id={SHARD_ID} n={shard_n}")

def sample_from_loaded_shard(local_i=None, swap_xy=False):
    """Load one sample from the shard selected above (``SHARD_ID``).

    When ``local_i`` is None a random position in ``[0, shard_n)`` is drawn
    with ``np.random``. Returns ``(local_i, x_chw, dst_xy, y)``.
    """
    idx = int(np.random.randint(0, shard_n)) if local_i is None else local_i
    x_chw, dst_xy, y = mm.load_by_shard_local(SHARD_ID, idx, swap_xy=swap_xy)
    return idx, x_chw, dst_xy, y



[data] split=val N=7873 C,H,W=13,105,68
[data] using shard_id=0 n=7873


# Helper Functions

## Load Shard and Definitions

In [None]:
# X: (C, H, W) tensor for one frame.
# Channel indices based on your manifest ordering.
CHANNEL_NAMES = ("DIST", "SINB", "COSB", "POS", "DEF", "SING", "COSG", "DISTG",
                 "BVX", "BVY", "AVX", "AVY", "DVX", "DVY")
(DIST, SINB, COSB, POS, DEF, SING, COSG, DISTG,
 BVX, BVY, AVX, AVY, DVX, DVY) = range(len(CHANNEL_NAMES))

# The manifest loaded in the first cell reports its own channel count (the val
# split printed C=13); warn if it disagrees with the indices defined here
# instead of silently reading the wrong channel.
# NOTE(review): feats below comes from load_first_shard(), which may not be the
# same data as `mm` — confirm which source the channel order refers to.
if mm.C != len(CHANNEL_NAMES):
    print(f"[warn] manifest C={mm.C} but {len(CHANNEL_NAMES)} channel indices are defined; "
          f"manifest channels: {mm.channels}")

feats, targs = load_first_shard()

feats.shape



## Regular State Plotting

In [None]:
# Draw a random frame, bounded by the number of loaded frames rather than a
# hard-coded 20000 (which can exceed the data size; the val split has N=7873).
# NOTE(review): assumes get_example indexes feats along axis 0 — confirm.
EXAMPLE_IDX = int(np.random.randint(0, feats.shape[0]))
print(f"EXAMPLE_IDX={EXAMPLE_IDX}")

X, dst_xy, y = get_example(feats, targs, EXAMPLE_IDX)


viz = SoccerVisualizer()

ball_dist = X[DIST]
in_pos = X[POS]          # (105, 68) one-hot
out_pos = X[DEF]
ball_vx = X[BVX]
ball_vy = X[BVY]
att_vx = X[AVX]
att_vy = X[AVY]
def_vx = X[DVX]
def_vy = X[DVY]

fig, ax, artists = viz.plot_state(
    in_possession=in_pos,
    out_possession=out_pos,
)
dst_x = float(dst_xy[0].item())
dst_y = float(dst_xy[1].item())

# Start ball location = grid cell with the smallest ball-distance value.
row, col = np.unravel_index(int(torch.argmin(ball_dist)), tuple(ball_dist.shape))
bx, by = float(row), float(col)

# End-marker colour encodes the label: cyan when y == 1, red otherwise.
end_color = "cyan" if y == 1 else "red"
ax.scatter([bx], [by], c="black", s=15, marker="o", zorder=5, linewidths=0.5, label='Start Ball Location')
ax.scatter([dst_x], [dst_y], c=end_color, s=15, marker="o", zorder=5, linewidths=0.5, label='End Ball Location')

legend = ax.legend(
    loc="upper center",
    frameon=True,
    facecolor="#aabb97",
    edgecolor="lightgray",
    fontsize=5,
    labelspacing=0.8,
    borderpad=0.8,
    handletextpad=0.8,
)

# round the legend box corners
legend.get_frame().set_boxstyle("round,pad=0.4")
legend.get_frame().set_alpha(0.95)

print(y)

## Velocity Fields

In [None]:



# 1) Just the ball velocity field (BVX/BVY channels). ball_vx / ball_vy are
#    extracted in the state-plotting cell above for exactly this plot.
fig, ax, q = viz.plot_velocity_quiver(ball_vx, ball_vy, color="black", step=16)

# 2) Full state + attacking glob velocity:
fig, ax, q = viz.plot_velocity_on_state(
    in_possession=in_pos,
    out_possession=out_pos,
    vx_map=att_vx,
    vy_map=att_vy,
    quiver_kwargs={"color": "red", "step": 16, "alpha": 0.9, "label": "Att glob vel"},
)
dst_x = float(dst_xy[0].item())
dst_y = float(dst_xy[1].item())

# Start ball location = grid cell with the smallest ball-distance value.
row, col = np.unravel_index(int(torch.argmin(ball_dist)), tuple(ball_dist.shape))
bx, by = float(row), float(col)

# End-marker colour encodes the label: cyan when y == 1, red otherwise.
end_color = "cyan" if y == 1 else "red"
ax.scatter([bx], [by], c="black", s=15, marker="o", zorder=5, linewidths=0.5, label='Start Ball Location')
ax.scatter([dst_x], [dst_y], c=end_color, s=15, marker="o", zorder=5, linewidths=0.5, label='End Ball Location')

legend = ax.legend(
    loc="upper center",
    frameon=True,
    facecolor="#aabb97",
    edgecolor="lightgray",
    fontsize=5,
    labelspacing=0.8,
    borderpad=0.8,
    handletextpad=0.8,
)

# round the legend box corners
legend.get_frame().set_boxstyle("round,pad=0.4")
legend.get_frame().set_alpha(0.95)



## Velocity Fields for a New Random Sample

In [None]:
# Draw a fresh random frame, bounded by the number of loaded frames rather than
# a hard-coded 20000 (which can exceed the data size).
# NOTE(review): assumes get_example indexes feats along axis 0 — confirm.
EXAMPLE_IDX = int(np.random.randint(0, feats.shape[0]))
print(f"EXAMPLE_IDX={EXAMPLE_IDX}")

X, dst_xy, y = get_example(feats, targs, EXAMPLE_IDX)


viz = SoccerVisualizer()

ball_dist = X[DIST]
in_pos = X[POS]          # (105, 68) one-hot
out_pos = X[DEF]
ball_vx = X[BVX]
ball_vy = X[BVY]
att_vx = X[AVX]
att_vy = X[AVY]
def_vx = X[DVX]
def_vy = X[DVY]


# 1) Just the ball velocity field (BVX/BVY channels):
fig, ax, q = viz.plot_velocity_quiver(ball_vx, ball_vy, color="black", step=16)

# 2) Full state + attacking glob velocity:
fig, ax, q = viz.plot_velocity_on_state(
    in_possession=in_pos,
    out_possession=out_pos,
    vx_map=att_vx,
    vy_map=att_vy,
    quiver_kwargs={"color": "red", "step": 16, "alpha": 0.9, "label": "Att glob vel"},
)
dst_x = float(dst_xy[0].item())
dst_y = float(dst_xy[1].item())

# Start ball location = grid cell with the smallest ball-distance value.
row, col = np.unravel_index(int(torch.argmin(ball_dist)), tuple(ball_dist.shape))
bx, by = float(row), float(col)

# End-marker colour encodes the label: cyan when y == 1, red otherwise.
end_color = "cyan" if y == 1 else "red"
ax.scatter([bx], [by], c="black", s=15, marker="o", zorder=5, linewidths=0.5, label='Start Ball Location')
ax.scatter([dst_x], [dst_y], c=end_color, s=15, marker="o", zorder=5, linewidths=0.5, label='End Ball Location')

legend = ax.legend(
    loc="upper center",
    frameon=True,
    facecolor="#aabb97",
    edgecolor="lightgray",
    fontsize=5,
    labelspacing=0.8,
    borderpad=0.8,
    handletextpad=0.8,
)

# round the legend box corners
legend.get_frame().set_boxstyle("round,pad=0.4")
legend.get_frame().set_alpha(0.95)

