# Train A Model

In [1]:
from __future__ import annotations

import json
import sys
from pathlib import Path


def _bootstrap_repo_root(start: Path | None = None) -> Path:
    here = (start or Path.cwd()).resolve()
    for candidate in (here, *here.parents):
        if (candidate / "time_to_explain").is_dir():
            return candidate
    raise RuntimeError(
        f"Could not locate the repository root from {here}. "
        "Set PROJECT_ROOT manually if your layout is unusual."
    )


# Locate the repository root (nearest ancestor of the CWD containing the
# `time_to_explain` directory) and put it at the front of sys.path.
PROJECT_ROOT = _bootstrap_repo_root()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Imported after the sys.path tweak above so the package can be found from PROJECT_ROOT.
from time_to_explain.utils.device import pick_device
# Optional notebook-wide settings; when the file is absent the config is empty
# and the defaults below apply (seed=42, device="auto").
# NOTE: CONFIG_PATH is rebound to the per-model config in the "Get Config" cell.
CONFIG_PATH = PROJECT_ROOT / "configs" / "notebooks" / "global.json"
NOTEBOOK_CFG = json.loads(CONFIG_PATH.read_text(encoding="utf-8")) if CONFIG_PATH.exists() else {}
SEED = int(NOTEBOOK_CFG.get("seed", 42))
DEVICE = pick_device(NOTEBOOK_CFG.get("device", "auto"))
print(f"Notebook config: seed={SEED}, device={DEVICE}")

from time_to_explain.models.utils import (
    build_cmd,
    ensure_tempme_processed,
    ensure_workdir,
    export_trained_models,
    prepare_env,
    run_cmd,
)
from time_to_explain.utils.cli import (
    args_dict_to_list,
    normalize_datasets,
    resolve_path,
    slugify,
)


Notebook config: seed=42, device=mps


### Set Config

In [2]:
# Dataset to train on; together with MODEL_TYPE it selects
# configs/models/train_<model>_<dataset>.json in the "Get Config" cell.
DATASET_NAME = "wikipedia"  # e.g. "reddit", "simulate_v1", ...
# Choose one of: "tgn", "tgat", "graphmixer", "dbgnn"
# (upper-cased in the "Get Config" cell; a "model_type" entry in the config overrides it)
MODEL_TYPE = "tgn"


### Get Config

In [3]:
# Build the per-model config path. MODEL_TYPE is upper-cased further down in
# this cell, so normalise it to lower case here: otherwise re-running the cell
# would look for "train_TGN_<dataset>.json" and miss the file on
# case-sensitive filesystems.
CONFIG_PATH = Path(f"configs/models/train_{str(MODEL_TYPE).lower()}_{DATASET_NAME}.json")

# Anchor the relative path at the repository root.
CONFIG_PATH = resolve_path(str(CONFIG_PATH), root=PROJECT_ROOT)

# A missing config is not fatal here: fall back to an empty dict so the
# `.get(..., default)` lookups below use the notebook selections.
if CONFIG_PATH is None or not CONFIG_PATH.exists():
    print(f"[warn] Config not found: {CONFIG_PATH}. Using notebook defaults.")
    CONFIG = {}
else:
    CONFIG = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))

# Prefer config override, otherwise keep the notebook selection above.
# MODEL_TYPE is normalised to upper case from here on.
MODEL_TYPE = str(CONFIG.get("model_type", MODEL_TYPE)).upper()
DATASET_LIST = normalize_datasets(CONFIG.get("datasets", [DATASET_NAME]))
PYTHON_BIN = str(CONFIG.get("python_bin", "python"))
DRY_RUN = bool(CONFIG.get("dry_run", False))
CUDA_VISIBLE_DEVICES = CONFIG.get("cuda_visible_devices")

# Per-model sections of the config, keyed by upper-cased model name.
MODEL_SPECS = {str(k).upper(): v for k, v in (CONFIG.get("models") or {}).items()}

# Only enforce model-spec existence for the script-driven baselines.
# NOTE: this also fires when no config file was found (CONFIG == {}).
if MODEL_TYPE in {"TGN", "TGAT", "GRAPHMIXER"} and MODEL_TYPE not in MODEL_SPECS:
    raise KeyError(
        f"Model spec for {MODEL_TYPE} missing in {CONFIG_PATH}. "
        f"Add it under CONFIG['models'] or switch MODEL_TYPE."
    )

# Directory layout; every entry can be overridden in the config and is resolved against PROJECT_ROOT.
RESOURCES_MODELS = resolve_path(CONFIG.get("resources_models_dir", "resources/models"), root=PROJECT_ROOT)
RUNS_ROOT = resolve_path(CONFIG.get("runs_root", "resources/models/runs"), root=PROJECT_ROOT)
RESOURCES_DATASETS = resolve_path(CONFIG.get("resources_datasets_dir", "resources/datasets/processed"), root=PROJECT_ROOT)

# Only used for the summary print below; assumes DATASET_LIST is non-empty.
DEFAULT_WORKDIR = RUNS_ROOT / f"{slugify(MODEL_TYPE)}_{slugify(DATASET_LIST[0])}"

# Missing model sections default to an empty spec.
TGN_SPEC = MODEL_SPECS.get("TGN", {})
TGAT_SPEC = MODEL_SPECS.get("TGAT", {})
GRAPHMIXER_SPEC = MODEL_SPECS.get("GRAPHMIXER", {})
DBGNN_SPEC = MODEL_SPECS.get("DBGNN", {})  # optional (not required for in-notebook trainer)

def _maybe_resolve(path_like):
    return resolve_path(path_like, root=PROJECT_ROOT) if path_like else None

# Training-script and data locations taken from the model specs.
# Each value is None when the corresponding key is absent or empty.
TGN_SCRIPT = _maybe_resolve(TGN_SPEC.get("script"))
TGAT_SCRIPT = _maybe_resolve(TGAT_SPEC.get("script"))
GRAPHMIXER_SCRIPT = _maybe_resolve(GRAPHMIXER_SPEC.get("script"))
GRAPHMIXER_PROCESSED_DIR = _maybe_resolve(GRAPHMIXER_SPEC.get("processed_dir"))
GRAPHMIXER_PARAMS_DIR = _maybe_resolve(GRAPHMIXER_SPEC.get("params_dir"))

DBGNN_SCRIPT = _maybe_resolve(DBGNN_SPEC.get("script"))  # optional

def get_tgn_args(dataset: str) -> list[str]:
    """Turn the ``args`` mapping of the TGN spec into a CLI argument list for ``dataset``."""
    spec_args = TGN_SPEC.get("args", {})
    return args_dict_to_list(spec_args, dataset)

def get_tgat_args(dataset: str) -> list[str]:
    """Turn the ``args`` mapping of the TGAT spec into a CLI argument list for ``dataset``."""
    spec_args = TGAT_SPEC.get("args", {})
    return args_dict_to_list(spec_args, dataset)

def get_graphmixer_args(dataset: str) -> list[str]:
    """Turn the ``args`` mapping of the GraphMixer spec into a CLI argument list for ``dataset``."""
    spec_args = GRAPHMIXER_SPEC.get("args", {})
    return args_dict_to_list(spec_args, dataset)

# Echo the effective settings so the run is self-describing in the notebook output.
print("PROJECT_ROOT:", PROJECT_ROOT)
print("MODEL_TYPE:", MODEL_TYPE)
print("DATASETS:", DATASET_LIST)
print("Artifacts root:", RUNS_ROOT)
print("Sample run directory:", DEFAULT_WORKDIR)
# CONFIG is an empty dict when the config file was not found (or was itself empty).
if CONFIG:
    print("Configuration loaded from:", CONFIG_PATH)
else:
    print("No config loaded (using notebook defaults).")


PROJECT_ROOT: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
MODEL_TYPE: TGN
DATASETS: ['wikipedia']
Artifacts root: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs
Sample run directory: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_wikipedia
Configuration loaded from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/configs/models/train_tgn_wikipedia.json


### DBGNN: train on Wikipedia (and other TGNN-style datasets)

This section adds an **in-notebook** trainer for the DBGNN model used in `01_train_dbgnn_and_node2vec.ipynb`.

It uses the TGNN processed files in `resources/datasets/processed/` (e.g. `ml_wikipedia.csv`, `ml_wikipedia_node.npy`).
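By default it trains a **node classification** task on Wikipedia where the label is the node role:
- `0` = appears as a source node `u` (user)
- `1` = appears as a destination node `i` (page)
- `2` = appears as both a source and a destination (only occurs when the graph is not strictly bipartite)

Nodes that never appear in an interaction are left unlabeled (`-1`) and are excluded from the train/val/test split.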

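Hyperparameters and split fractions can be changed via `DBGNN_DEFAULTS`. The `task` entry is currently informational only: the trainer in this notebook implements the node-type task and does not read it.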


In [4]:
# --- DBGNN utilities (used when MODEL_TYPE="DBGNN") ---

from pathlib import Path
from typing import Any

import numpy as np
import torch

# Make gnnbench (src/) importable (same trick as in 01_train_dbgnn_and_node2vec.ipynb).
SRC_ROOT = PROJECT_ROOT / "src"
if SRC_ROOT.exists() and str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

# Optional dependency (used only if you toggle use_node2vec_features=True)
# _HAS_NODE2VEC records whether the import worked; it is checked before use.
_HAS_NODE2VEC = True
try:
    from node2vec import Node2Vec  # noqa: F401
except Exception:
    _HAS_NODE2VEC = False

# torch_geometric is used for the Data container. We fall back to a lightweight object if missing.
# `Data is None` is the signal the trainer uses to pick the fallback container.
try:
    from torch_geometric.data import Data  # type: ignore
except Exception:
    Data = None

# Default settings for the in-notebook DBGNN trainer; the training cell copies
# this dict and applies optional overrides from the model config.
DBGNN_DEFAULTS: dict[str, Any] = dict(
    # Task definition ---------------------------------------------------
    task="node_type",  # "node_type" (default) | "link_pred" (experimental stub)
    # Train/val/test split ---------------------------------------------
    num_test=0.30,      # fraction of labeled nodes
    num_val=0.10,       # fraction of labeled nodes (taken from remaining after test split)
    seed=SEED,
    # Model/training hyperparams ---------------------------------------
    epochs=400,
    lr=1e-3,
    p_dropout=0.4,
    hidden_dims=(16, 32, 16),
    # Features ----------------------------------------------------------
    use_node2vec_features=False,
    # If None, will be set to hidden_dims[0] (keeps dimensionalities consistent).
    n2v_dim=None,
    # Higher-order graph construction ----------------------------------
    # "identity" means: HO graph == FO graph, HO nodes == FO nodes (fast and robust).
    # You can replace this with a real HO construction later.
    ho_mode="identity",
)

def _read_tgnn_processed(dataset: str, processed_dir: Path):
    """Read TGNN-style processed files: ml_{dataset}.csv and ml_{dataset}_node.npy."""
    import pandas as pd

    csv_path = processed_dir / f"ml_{dataset}.csv"
    node_path = processed_dir / f"ml_{dataset}_node.npy"

    # If processed files are missing, try to generate them using the repo helper.
    if not csv_path.exists() or not node_path.exists():
        try:
            from time_to_explain.data.tgnn_setup import setup_tgnn_data
            print(f"[DBGNN] Processed files missing for '{dataset}'. Running setup_tgnn_data(...)")
            setup_tgnn_data(root=PROJECT_ROOT, only=[dataset], force=False, do_process=True)
        except Exception as e:
            print(f"[DBGNN] Auto-processing failed: {e}")

    if not csv_path.exists():
        raise FileNotFoundError(f"Missing TGNN CSV: {csv_path}")
    if not node_path.exists():
        raise FileNotFoundError(f"Missing TGNN node features: {node_path}")

    df = pd.read_csv(csv_path)

    # Standard TGNN/TGN column names are: u, i, ts, label, idx
    # If your CSV has no header, try to repair it.
    expected = {"u", "i", "ts"}
    if not expected.issubset(set(df.columns)):
        # Common case: unnamed columns 0..4
        if set(df.columns) >= {0, 1, 2}:
            df = df.rename(columns={0: "u", 1: "i", 2: "ts"})
        else:
            raise ValueError(f"Unexpected columns in {csv_path}: {list(df.columns)}")

    node_feat = np.load(node_path)
    return df, node_feat

def _infer_num_nodes(df, node_feat: np.ndarray) -> int:
    max_id = int(max(df["u"].max(), df["i"].max()))
    # node_feat might be larger than max_id+1 (some preprocessors include unused ids)
    return int(max(max_id + 1, node_feat.shape[0]))

def _make_node_type_labels(df, num_nodes: int) -> np.ndarray:
    """0 for nodes that appear in df['u'], 1 for nodes that appear in df['i'], -1 otherwise."""
    y = -np.ones(int(num_nodes), dtype=np.int64)
    u_nodes = set(int(x) for x in df["u"].to_numpy())
    i_nodes = set(int(x) for x in df["i"].to_numpy())

    # default bipartite assumption: u != i
    for n in u_nodes:
        if 0 <= n < num_nodes:
            y[n] = 0
    for n in i_nodes:
        if 0 <= n < num_nodes:
            if y[n] == 0:
                # If overlap happens (rare), mark as a 3rd class.
                y[n] = 2
            else:
                y[n] = 1
    return y

def _make_stratified_masks(y: np.ndarray, *, num_test: float, num_val: float, seed: int):
    """Split the labeled nodes (``y >= 0``) into stratified train/val/test masks.

    Args:
        y: Label array; negative entries are treated as unlabeled.
        num_test: Fraction of labeled nodes assigned to the test split.
        num_val: Fraction of labeled nodes assigned to the validation split.
        seed: Random state used for both splits.

    Returns:
        ``(train_mask, val_mask, test_mask)`` as boolean arrays of ``len(y)``.

    Raises:
        ValueError: If there is no labeled node at all.
    """
    from sklearn.model_selection import train_test_split

    y = np.asarray(y)
    total = len(y)
    has_label = (y >= 0)
    labeled_idx = np.arange(total)[has_label]
    labeled_y = y[has_label]

    if labeled_idx.size == 0:
        raise ValueError("No labeled nodes found (y < 0 everywhere).")

    rng_seed = int(seed)

    # First carve off the test split ...
    rest_idx, test_idx = train_test_split(
        labeled_idx, test_size=float(num_test), random_state=rng_seed, stratify=labeled_y
    )
    # ... then take validation from the remainder. num_val is expressed relative
    # to all labeled nodes, so rescale it to a fraction of what is left.
    val_share = float(num_val) / float(1.0 - num_test)
    train_idx, val_idx = train_test_split(
        rest_idx, test_size=val_share, random_state=rng_seed, stratify=y[rest_idx]
    )

    def _as_mask(members):
        mask = np.zeros(total, dtype=bool)
        mask[members] = True
        return mask

    return _as_mask(train_idx), _as_mask(val_idx), _as_mask(test_idx)

def _build_fo_graph(df, *, num_nodes: int, undirected: bool = True, make_unique: bool = True):
    u = df["u"].to_numpy(dtype=np.int64)
    v = df["i"].to_numpy(dtype=np.int64)

    if undirected:
        src = np.concatenate([u, v], axis=0)
        dst = np.concatenate([v, u], axis=0)
    else:
        src, dst = u, v

    edge_index = np.stack([src, dst], axis=0)

    if make_unique:
        # Unique edges to reduce duplicates (Wikipedia has many repeated interactions)
        # Convert to structured array for fast unique.
        e = edge_index.T
        e = np.unique(e, axis=0)
        edge_index = e.T

    edge_index = torch.tensor(edge_index, dtype=torch.long)
    edge_weight = torch.ones(edge_index.size(1), dtype=torch.float32)
    return edge_index, edge_weight

def _build_identity_higher_order(
    *, num_nodes: int, x: torch.Tensor, edge_index: torch.Tensor, edge_weight: torch.Tensor
):
    """HO graph == FO graph; HO nodes == FO nodes; bipartite edges are identity."""
    x_h = x.clone()
    edge_index_ho = edge_index.clone()
    edge_weight_ho = edge_weight.clone()
    ids = torch.arange(int(num_nodes), dtype=torch.long)
    bipartite_edge_index = torch.stack([ids, ids], dim=0)  # [2, N]: [ho_id, fo_id]
    return x_h, edge_index_ho, edge_weight_ho, bipartite_edge_index

def _maybe_apply_node2vec_features(data, *, dim: int, seed: int):
    """Replace ``data.x`` (and, when sizes match, ``data.x_h``) with Node2Vec embeddings.

    Embeddings are learned on an undirected NetworkX graph built from
    ``data.edge_index``. Node ids that have no edge are not part of that graph
    and therefore get no embedding; their feature rows stay all-zero.

    Args:
        data: Graph container exposing ``edge_index``, ``num_nodes``, ``x`` and
            optionally ``x_h``.
        dim: Embedding dimensionality.
        seed: Seed passed to Node2Vec.

    Returns:
        The same ``data`` object, modified in place.

    Raises:
        RuntimeError: If the ``node2vec`` package could not be imported.
    """
    if not _HAS_NODE2VEC:
        raise RuntimeError(
            "node2vec is not installed. Install it (pip install node2vec) or set use_node2vec_features=False."
        )
    import networkx as nx

    # Build a NetworkX graph from FO edges
    ei = data.edge_index.detach().cpu().numpy()
    G = nx.Graph()
    G.add_edges_from(ei.T.tolist())

    # Learn embeddings (FO). We reuse the same for HO if ho_mode == identity.
    n2v = Node2Vec(
        G,
        dimensions=int(dim),
        walk_length=30,
        num_walks=10,
        p=1.0,
        q=1.0,
        workers=1,
        seed=int(seed),
    )
    w2v = n2v.fit(window=10, min_count=1, batch_words=128)

    emb = np.zeros((int(data.num_nodes), int(dim)), dtype=np.float32)
    for n in range(int(data.num_nodes)):
        key = str(n)
        # Ids without any edge never entered the graph, so the model has no
        # vector for them; leave the zero row instead of raising KeyError.
        if key in w2v.wv:
            emb[n] = w2v.wv[key]
    x_new = torch.tensor(emb, device=data.x.device, dtype=torch.float32)

    data.x = x_new
    # If HO nodes == FO nodes, we can reuse embeddings.
    if hasattr(data, "x_h") and data.x_h is not None and data.x_h.size(0) == data.x.size(0):
        data.x_h = x_new.clone()

    return data

# --- Metrics helpers (copied from 01_train_dbgnn_and_node2vec.ipynb) ---
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

@torch.no_grad()
def _predict_labels(model, data):
    model.eval()
    logits = model(data)
    return logits.argmax(dim=1)

def _macro_scores(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Compute accuracy plus macro-averaged precision, recall and F1.

    Macro averages run over the classes present in ``y_true``; classes with no
    predictions contribute 0 instead of raising a warning.
    """
    present_classes = np.unique(y_true)
    macro_kwargs = dict(average="macro", labels=present_classes, zero_division=0)

    scores = {"accuracy": float(accuracy_score(y_true, y_pred))}
    scores["precision_macro"] = float(precision_score(y_true, y_pred, **macro_kwargs))
    scores["recall_macro"] = float(recall_score(y_true, y_pred, **macro_kwargs))
    scores["f1_macro"] = float(f1_score(y_true, y_pred, **macro_kwargs))
    return scores

def _evaluate_macro_metrics(model, data):
    """Score the model on the labeled nodes of the train and test splits.

    Returns:
        ``(train_metrics, test_metrics)``, each a dict as produced by
        ``_macro_scores``.

    Raises:
        ValueError: If ``data`` lacks ``train_mask`` or ``test_mask``.
    """
    targets = data.y.detach().cpu().numpy()
    predictions = _predict_labels(model, data).detach().cpu().numpy()

    split_masks = [getattr(data, name, None) for name in ("train_mask", "test_mask")]
    if any(mask is None for mask in split_masks):
        raise ValueError("data.train_mask / data.test_mask are required for evaluation")

    train_sel, test_sel = (
        mask.detach().cpu().numpy().astype(bool) for mask in split_masks
    )

    # Unlabeled nodes (negative label) never count, whatever the mask says.
    has_label = (targets >= 0)
    train_sel = train_sel & has_label
    test_sel = test_sel & has_label

    scored_train = _macro_scores(targets[train_sel], predictions[train_sel])
    scored_test = _macro_scores(targets[test_sel], predictions[test_sel])
    return scored_train, scored_test

def _evaluate_balanced_accuracy(model, data):
    """Return ``(train, test)`` balanced accuracy, i.e. the macro recall of each split."""
    per_split = _evaluate_macro_metrics(model, data)
    # Balanced accuracy == macro recall for multi-class
    train_recall, test_recall = (metrics["recall_macro"] for metrics in per_split)
    return train_recall, test_recall

def _build_dbgnn_model(data, *, device, hidden_dims, p_dropout):
    """Instantiate a pathpyG ``DBGNN`` sized for ``data`` and move it to ``device``.

    Args:
        data: Graph container with ``y`` (labels, negative = unlabeled),
            ``x`` (first-order features) and ``x_h`` (higher-order features).
        device: Target device for the model.
        hidden_dims: Hidden layer sizes, passed on as a list.
        p_dropout: Dropout probability.

    Returns:
        The constructed model on ``device``.
    """
    from pathpyG.nn.dbgnn import DBGNN

    # Use only labeled nodes to determine number of classes.
    y_lab = data.y[data.y >= 0]
    if y_lab.numel() > 0:
        # Size the output by the largest label, not by the number of distinct
        # labels: label sets with gaps (e.g. {0, 2}) would otherwise yield a
        # head that is too small for the target indices used by the loss.
        num_classes = int(y_lab.max().item()) + 1
    else:
        num_classes = int(data.y.unique().numel())

    num_features = (int(data.x.size(1)), int(data.x_h.size(1)))
    model = DBGNN(
        num_features=num_features,
        num_classes=int(num_classes),
        hidden_dims=list(hidden_dims),
        p_dropout=float(p_dropout),
    ).to(device)
    return model

def train_dbgnn_node_type(
    dataset: str,
    *,
    workdir: Path,
    processed_dir: Path,
    device: torch.device,
    cfg: dict[str, Any] | None = None,
):
    """Train DBGNN on Wikipedia (or another TGNN-style dataset) using node-type labels.

    The function reads the processed dataset, derives node-role labels and
    stratified masks, builds a static first-order graph plus an identity
    higher-order graph, trains full-batch with Adam and cross-entropy, and
    writes a checkpoint and a metrics file.

    Args:
        dataset: Dataset name (selects ``ml_<dataset>.csv`` / ``ml_<dataset>_node.npy``).
        workdir: Directory that receives the checkpoint and metrics JSON.
        processed_dir: Directory holding the processed dataset files.
        device: Device used for all tensors and the model.
        cfg: Optional overrides; missing keys fall back to ``DBGNN_DEFAULTS``.

    Returns:
        ``(model, data, artifacts)`` where ``artifacts`` maps ``"ckpt"``,
        ``"metrics"`` and ``"exported_ckpt"`` to file paths (as strings).

    Notes:
        A validation mask is built and stored on ``data`` but the loop itself
        only tracks train and test balanced accuracy.
    """
    # Layer the caller's overrides on top of the defaults so that a partial
    # cfg (e.g. only {"epochs": 50}) still provides every key read below.
    merged_cfg = dict(DBGNN_DEFAULTS)
    merged_cfg.update(cfg or {})
    cfg = merged_cfg

    df, node_feat_np = _read_tgnn_processed(dataset, processed_dir)
    num_nodes = _infer_num_nodes(df, node_feat_np)

    # Build labels + masks
    y_np = _make_node_type_labels(df, num_nodes)
    train_mask_np, val_mask_np, test_mask_np = _make_stratified_masks(
        y_np, num_test=float(cfg["num_test"]), num_val=float(cfg["num_val"]), seed=int(cfg["seed"])
    )

    # Features (TGNN provides node features)
    x = torch.tensor(node_feat_np[:num_nodes], dtype=torch.float32)
    if x.size(0) < num_nodes:
        # The feature file has fewer rows than the highest node id requires;
        # pad with zero rows so x lines up with the labels and masks.
        padding = torch.zeros((num_nodes - x.size(0), *x.shape[1:]), dtype=x.dtype)
        x = torch.cat([x, padding], dim=0)

    # FO graph from interactions (static, deduplicated)
    edge_index, edge_weight = _build_fo_graph(df, num_nodes=num_nodes, undirected=True, make_unique=True)

    # HO graph (default: identity)
    ho_mode = str(cfg.get("ho_mode", "identity")).lower()
    if ho_mode != "identity":
        print(f"[DBGNN] ho_mode='{ho_mode}' not implemented in this notebook. Falling back to 'identity'.")
        ho_mode = "identity"

    x_h, edge_index_ho, edge_weight_ho, bip_ei = _build_identity_higher_order(
        num_nodes=num_nodes, x=x, edge_index=edge_index, edge_weight=edge_weight
    )

    # Assemble data object
    y = torch.tensor(y_np, dtype=torch.long)
    train_mask = torch.tensor(train_mask_np, dtype=torch.bool)
    val_mask = torch.tensor(val_mask_np, dtype=torch.bool)
    test_mask = torch.tensor(test_mask_np, dtype=torch.bool)

    if Data is not None:
        data = Data(
            x=x,
            y=y,
            edge_index=edge_index,
        )
        # Keep both common spellings for edge weights for maximum compatibility.
        data.edge_weights = edge_weight
        data.edge_weight = edge_weight

        # Extra attrs expected by DBGNN experiments
        data.x_h = x_h
        data.edge_index_higher_order = edge_index_ho
        data.edge_weights_higher_order = edge_weight_ho
        data.bipartite_edge_index = bip_ei
        data.train_mask = train_mask
        data.val_mask = val_mask
        data.test_mask = test_mask
        data.num_nodes = int(num_nodes)
    else:
        # torch_geometric is unavailable: use a plain attribute bag instead.
        from types import SimpleNamespace

        data = SimpleNamespace(
            x=x,
            y=y,
            edge_index=edge_index,
            edge_weights=edge_weight,
            edge_weight=edge_weight,
            x_h=x_h,
            edge_index_higher_order=edge_index_ho,
            edge_weights_higher_order=edge_weight_ho,
            bipartite_edge_index=bip_ei,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
            num_nodes=int(num_nodes),
        )

    # Move tensors to device
    def _to_dev(v):
        return v.to(device) if torch.is_tensor(v) else v

    # Both spellings of the first-order edge weights must be moved; leaving
    # "edge_weights" behind keeps it on the CPU while everything else sits on
    # the target device.
    for attr in [
        "x",
        "y",
        "edge_index",
        "edge_weight",
        "edge_weights",
        "x_h",
        "edge_index_higher_order",
        "edge_weights_higher_order",
        "bipartite_edge_index",
        "train_mask",
        "val_mask",
        "test_mask",
    ]:
        if hasattr(data, attr):
            setattr(data, attr, _to_dev(getattr(data, attr)))

    # Optional: overwrite features with Node2Vec embeddings
    if bool(cfg.get("use_node2vec_features", False)):
        dim = cfg.get("n2v_dim") or int(cfg["hidden_dims"][0])
        data = _maybe_apply_node2vec_features(data, dim=int(dim), seed=int(cfg["seed"]))

    # Build model
    model = _build_dbgnn_model(
        data, device=device, hidden_dims=cfg["hidden_dims"], p_dropout=cfg["p_dropout"]
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg["lr"]))
    loss_fn = torch.nn.CrossEntropyLoss()

    losses = []
    train_ba_hist = []
    test_ba_hist = []

    for epoch in range(int(cfg["epochs"])):
        model.train()
        logits = model(data)
        # Train mask only covers labeled nodes, so targets are valid class ids.
        loss = loss_fn(logits[data.train_mask], data.y[data.train_mask])

        loss.backward()
        optimizer.step()
        # Gradients are cleared after the step, i.e. before the next backward pass.
        optimizer.zero_grad()

        losses.append(float(loss.detach().cpu().item()))
        train_ba, test_ba = _evaluate_balanced_accuracy(model, data)
        train_ba_hist.append(float(train_ba))
        test_ba_hist.append(float(test_ba))

        if epoch % 20 == 0:
            print(
                f"[DBGNN:{dataset}] epoch={epoch:04d}  loss={loss.item():.4f}  "
                f"train_bal_acc={train_ba:.4f}  test_bal_acc={test_ba:.4f}"
            )

    # Save artifacts
    workdir = Path(workdir)
    workdir.mkdir(parents=True, exist_ok=True)

    ckpt_path = workdir / "dbgnn_node_type.pt"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "cfg": cfg,
            "dataset": dataset,
        },
        ckpt_path,
    )

    metrics_path = workdir / "dbgnn_train_metrics.json"
    train_metrics, test_metrics = _evaluate_macro_metrics(model, data)
    import json as _json

    metrics_path.write_text(
        _json.dumps(
            {
                "train": train_metrics,
                "test": test_metrics,
                "losses": losses,
                "train_bal_acc": train_ba_hist,
                "test_bal_acc": test_ba_hist,
            },
            indent=2,
        ),
        encoding="utf-8",
    )

    print(f"[DBGNN:{dataset}] Saved checkpoint -> {ckpt_path}")
    print(f"[DBGNN:{dataset}] Saved metrics    -> {metrics_path}")

    # Also export under resources/models/<dataset>/dbgnn for consistency with other baselines.
    export_dir = (RESOURCES_MODELS / slugify(dataset) / "dbgnn")
    export_dir.mkdir(parents=True, exist_ok=True)
    exported_ckpt = export_dir / ckpt_path.name
    exported_ckpt.write_bytes(ckpt_path.read_bytes())
    print(f"[DBGNN:{dataset}] Exported checkpoint -> {exported_ckpt}")

    return model, data, {
        "ckpt": str(ckpt_path),
        "metrics": str(metrics_path),
        "exported_ckpt": str(exported_ckpt),
    }

print("DBGNN utilities ready.")

DBGNN utilities ready.


### Train Model

In [5]:
# --- Launch training for each dataset ---
# Environment for the training subprocesses (built by the project helper from
# the repo root and the optional CUDA device selection).
env = prepare_env(project_root=PROJECT_ROOT, cuda_visible_devices=CUDA_VISIBLE_DEVICES)

for dataset in DATASET_LIST:
    workdir = ensure_workdir(RUNS_ROOT, MODEL_TYPE, dataset)
    print("=== Launching", MODEL_TYPE.upper(), "on", dataset, "===")
    print("Artifacts will be stored under:", workdir)

    if MODEL_TYPE.upper() == "TGN":
        cmd = build_cmd(PYTHON_BIN, TGN_SCRIPT, get_tgn_args(dataset))
        code = run_cmd(cmd, env=env, workdir=workdir, dry_run=DRY_RUN)
        if code == 0:
            export_trained_models(MODEL_TYPE, dataset, workdir, RESOURCES_MODELS)
            print("[TGN] Training completed (or started successfully) for", dataset)
        else:
            print("[TGN] Training failed for", dataset, "- see output above.")

    elif MODEL_TYPE.upper() == "TGAT":
        try:
            cmd = build_cmd(PYTHON_BIN, TGAT_SCRIPT, get_tgat_args(dataset))
        except FileNotFoundError as exc:
            # Chain the original error so its traceback is kept alongside the hint.
            raise FileNotFoundError(
                f"{exc} Update configs/models to point to your TGAT training script."
            ) from exc
        code = run_cmd(cmd, env=env, workdir=workdir, dry_run=DRY_RUN)
        if code == 0:
            export_trained_models(MODEL_TYPE, dataset, workdir, RESOURCES_MODELS)
            print("[TGAT] Training completed (or started successfully) for", dataset)
        else:
            print("[TGAT] Training failed for", dataset, "- see output above.")

    elif MODEL_TYPE.upper() == "GRAPHMIXER":
        if GRAPHMIXER_PROCESSED_DIR is None:
            raise ValueError("GRAPHMIXER processed_dir missing in config.")
        ensure_tempme_processed(
            dataset,
            processed_dir=GRAPHMIXER_PROCESSED_DIR,
            resources_datasets=RESOURCES_DATASETS,
        )
        if not GRAPHMIXER_SCRIPT or not GRAPHMIXER_SCRIPT.exists():
            raise FileNotFoundError(f"GraphMixer training script not found: {GRAPHMIXER_SCRIPT}")
        cmd = build_cmd(PYTHON_BIN, GRAPHMIXER_SCRIPT, get_graphmixer_args(dataset))
        # The GraphMixer script is run from its own directory rather than the run workdir.
        code = run_cmd(cmd, env=env, workdir=GRAPHMIXER_SCRIPT.parent, dry_run=DRY_RUN)
        if code == 0:
            exported = []
            if GRAPHMIXER_PARAMS_DIR and GRAPHMIXER_PARAMS_DIR.exists():
                dest_dir = RESOURCES_MODELS / slugify(dataset) / "graphmixer"
                dest_dir.mkdir(parents=True, exist_ok=True)
                # Prefer checkpoints named after this dataset ...
                for src_file in sorted(GRAPHMIXER_PARAMS_DIR.glob(f"graphmixer_{slugify(dataset)}*.pt")):
                    dest_path = dest_dir / src_file.name
                    dest_path.write_bytes(src_file.read_bytes())
                    exported.append(dest_path)
                # ... and fall back to every GraphMixer checkpoint if none matched.
                if not exported:
                    for src_file in sorted(GRAPHMIXER_PARAMS_DIR.glob("graphmixer_*.pt")):
                        dest_path = dest_dir / src_file.name
                        dest_path.write_bytes(src_file.read_bytes())
                        exported.append(dest_path)
            if exported:
                print("[GraphMixer] Exported:")
                for p in exported:
                    print(" -", p)
            else:
                print("[GraphMixer] No checkpoints found under", GRAPHMIXER_PARAMS_DIR)
            print("[GraphMixer] Training completed (or started successfully) for", dataset)
        else:
            print("[GraphMixer] Training failed for", dataset)

    elif MODEL_TYPE.upper() == "DBGNN":
        # In-notebook trainer (adapted from 01_train_dbgnn_and_node2vec.ipynb)
        if DRY_RUN:
            print("[DBGNN] DRY_RUN=True -> skip training.")
            continue

        # Allow overriding defaults from configs/models/*.json (optional).
        # Example:
        #   "models": { "DBGNN": { "cfg": { "epochs": 200, "hidden_dims": [16,32,16] } } }
        cfg = dict(DBGNN_DEFAULTS)
        try:
            cfg.update((MODEL_SPECS.get("DBGNN") or {}).get("cfg") or {})
        except Exception as exc:
            # Still best-effort, but say so: a malformed override block would
            # otherwise be dropped without any hint that defaults are in use.
            print(f"[DBGNN] Ignoring invalid 'cfg' override in config ({exc!r}); using DBGNN_DEFAULTS.")

        _device = torch.device(DEVICE) if not isinstance(DEVICE, torch.device) else DEVICE
        _, _, artifacts = train_dbgnn_node_type(
            dataset,
            workdir=workdir,
            processed_dir=Path(RESOURCES_DATASETS),
            device=_device,
            cfg=cfg,
        )
        print("[DBGNN] Done. Artifacts:", artifacts)

    else:
        raise ValueError("MODEL_TYPE must be 'TGN' or 'TGAT' or 'GRAPHMIXER' or 'DBGNN'")


=== Launching TGN on wikipedia ===
Artifacts will be stored under: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_wikipedia
$ (cwd=/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_wikipedia) python /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py --data wikipedia --use_memory --prefix tgn-attn --n_runs 10


INFO:root:Namespace(data='wikipedia', bs=200, prefix='tgn-attn', n_degree=10, n_head=2, n_epoch=50, n_layer=1, lr=0.0001, patience=5, n_runs=10, drop_out=0.1, gpu=0, node_dim=100, time_dim=100, backprop_every=1, use_memory=True, embedding_module='graph_attention', message_function='identity', memory_updater='gru', aggregator='last', memory_update_at_end=False, message_dim=100, memory_dim=172, different_new_nodes=False, uniform=False, randomize_features=False, use_destination_embedding_in_message=False, use_source_embedding_in_message=False, dyrep=False)
Traceback (most recent call last):
  File "/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py", line 181, in <module>
    ) = get_data(
        ^^^^^^^^^
  File "/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/utils/data_processing.py", line 100, in get_data
    graph_df = pd.read_csv(graph_df_file)
               ^^

/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
[ERROR] process exited with code 1
[TGN] Training failed for wikipedia - see output above.


### TempME pretraining (run once, outside evaluation)
Pre-train the TempME base model + explainer here so evaluation notebooks can set `train_if_missing=False` and avoid training inside the explainer pipeline.


In [6]:
# TempME pretraining: for every dataset, make sure a base-model checkpoint and
# an explainer checkpoint exist, training each one in a subprocess if missing.
if MODEL_TYPE.upper() == "DBGNN":
    print("Skip TempME pretraining for DBGNN (only relevant for TGNN baselines).")
else:
    # --- TempME pretraining (base model + explainer) ---
    # Runs in a clean process to avoid memory pressure during evaluation.
    TEMP_ME_BASE_TYPE = MODEL_TYPE.lower()
    # Any model type outside the three supported bases falls back to "tgn".
    if TEMP_ME_BASE_TYPE not in {"tgn", "tgat", "graphmixer"}:
        TEMP_ME_BASE_TYPE = "tgn"

    # Location of the TempME scripts inside the submodule.
    TEMP_ME_ROOT = PROJECT_ROOT / "submodules" / "explainer" / "tempme"
    TEMP_ME_LEARN_BASE = TEMP_ME_ROOT / "learn_base.py"
    TEMP_ME_EXP_MAIN = TEMP_ME_ROOT / "temp_exp_main.py"

    # Checkpoint location checked by this cell.
    TEMP_ME_CKPT_ROOT = PROJECT_ROOT / "resources" / "explainer" / "tempme"

    # Older location inside the submodule; checkpoints found there are linked/copied over.
    LEGACY_TEMP_ME_CKPT_ROOT = TEMP_ME_ROOT / "params"

    # Optional overrides to reduce memory usage. Leave empty to use TempME defaults.
    TEMP_ME_BASE_OVERRIDES = {
        # "bs": 128,
        # "n_epoch": 50,
    }
    TEMP_ME_EXP_OVERRIDES = {
        # "bs": 128,
        # "test_bs": 128,
        # "n_epoch": 80,
    }

    from time_to_explain.data.tgnn_setup import setup_tgnn_data
    import shutil

    # Datasets for which missing processed files may be regenerated via setup_tgnn_data.
    REAL_TGNN_DATASETS = {"wikipedia", "reddit", "simulate_v1", "simulate_v2", "multihost"}

    def _ensure_tempme_inputs(dataset: str) -> None:
        """Check the three processed input files for ``dataset`` and regenerate them if any is missing.

        Raises:
            ValueError: If files are missing and ``dataset`` is not in REAL_TGNN_DATASETS.
        """
        missing = []
        for fname in (f"ml_{dataset}.csv", f"ml_{dataset}.npy", f"ml_{dataset}_node.npy"):
            if not (RESOURCES_DATASETS / fname).exists():
                missing.append(fname)
        if not missing:
            return
        print("[TempME] Missing processed files, regenerating:", ", ".join(missing))
        if dataset not in REAL_TGNN_DATASETS:
            raise ValueError(f"TempME inputs missing for '{dataset}'. Prepare the dataset or add it to REAL_TGNN_DATASETS.")
        setup_tgnn_data(root=PROJECT_ROOT, only=[dataset], force=False, do_process=True)

    def _copy_if_missing(src: Path, dst: Path) -> None:
        """Make ``dst`` available from ``src`` unless ``src`` is absent or ``dst`` already exists."""
        if not src.exists() or dst.exists():
            return
        dst.parent.mkdir(parents=True, exist_ok=True)
        # Prefer a symlink; fall back to a real copy when linking fails.
        try:
            dst.symlink_to(src)
        except Exception:
            shutil.copy2(src, dst)

    def _maybe_migrate_tempme_ckpts(base_type: str, dataset: str) -> None:
        """Bring legacy checkpoints (submodule ``params/`` tree) into the new checkpoint tree."""
        legacy_base = LEGACY_TEMP_ME_CKPT_ROOT / "tgnn" / f"{base_type}_{dataset}.pt"
        legacy_expl = LEGACY_TEMP_ME_CKPT_ROOT / "explainer" / base_type / f"{dataset}.pt"
        # _tempme_ckpts is defined below; it exists by the time this function is called.
        new_base, new_expl = _tempme_ckpts(base_type, dataset)
        _copy_if_missing(legacy_base, new_base)
        _copy_if_missing(legacy_expl, new_expl)

    def _tempme_ckpts(base_type: str, dataset: str):
        """Return ``(base_ckpt, explainer_ckpt)`` paths under TEMP_ME_CKPT_ROOT."""
        base_ckpt = TEMP_ME_CKPT_ROOT / "params" / "tgnn" / f"{base_type}_{dataset}.pt"
        expl_ckpt = TEMP_ME_CKPT_ROOT / "params" / "explainer" / base_type / f"{dataset}.pt"
        return base_ckpt, expl_ckpt

    env = prepare_env(project_root=PROJECT_ROOT, cuda_visible_devices=CUDA_VISIBLE_DEVICES)

    for dataset in DATASET_LIST:
        # Reuse legacy checkpoints first so nothing is retrained unnecessarily.
        _maybe_migrate_tempme_ckpts(TEMP_ME_BASE_TYPE, dataset)

        _ensure_tempme_inputs(dataset)

        base_ckpt, expl_ckpt = _tempme_ckpts(TEMP_ME_BASE_TYPE, dataset)

        # Step 1: base model.
        if base_ckpt.exists():
            print("[TempME] Base checkpoint exists:", base_ckpt)
        else:
            base_args = {"base_type": TEMP_ME_BASE_TYPE, "data": dataset, **TEMP_ME_BASE_OVERRIDES}
            cmd = build_cmd(PYTHON_BIN, TEMP_ME_LEARN_BASE, args_dict_to_list(base_args, dataset))
            code = run_cmd(cmd, env=env, workdir=TEMP_ME_ROOT, dry_run=DRY_RUN)
            if code != 0:
                raise RuntimeError("TempME base training failed; see logs above.")

        # Step 2: explainer.
        if expl_ckpt.exists():
            print("[TempME] Explainer checkpoint exists:", expl_ckpt)
        else:
            expl_args = {"base_type": TEMP_ME_BASE_TYPE, "data": dataset, **TEMP_ME_EXP_OVERRIDES}
            cmd = build_cmd(PYTHON_BIN, TEMP_ME_EXP_MAIN, args_dict_to_list(expl_args, dataset))
            code = run_cmd(cmd, env=env, workdir=TEMP_ME_ROOT, dry_run=DRY_RUN)
            if code != 0:
                raise RuntimeError("TempME explainer training failed; see logs above.")

        print("[TempME] Ready for dataset:", dataset)


$ (cwd=/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/tempme) python /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/tempme/learn_base.py --base_type tgn --data wikipedia


Traceback (most recent call last):
  File "/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/tempme/learn_base.py", line 106, in <module>
    mask_node_set = set(random.sample(set(src_l[ts_l > val_time]).union(set(dst_l[ts_l > val_time])),
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/juliawenkmann/miniconda3/envs/ml/lib/python3.11/random.py", line 439, in sample
    raise TypeError("Population must be a sequence.  "
TypeError: Population must be a sequence.  For dicts or sets, use sorted(d).


[ERROR] process exited with code 1


RuntimeError: TempME base training failed; see logs above.

In [None]:
# --- Contact edge sanity check + visualization (stick_figure + sticky_hips + TGN) ---
if DRY_RUN:
    print("Skip contact-edge check (DRY_RUN=True).")
elif MODEL_TYPE.upper() != "TGN":
    print("Skip contact-edge check (MODEL_TYPE != TGN).")
else:
    import json
    import math
    import numpy as np
    import pandas as pd
    import torch
    import plotly.graph_objects as go

    from time_to_explain.data.synthetic_recipes.stick_figure import JOINTS_PER_PERSON
    from time_to_explain.visualization.utils import COLORS
    from submodules.models.tgn.model.tgn import TGN
    from submodules.models.tgn.tgn_utils.data_processing import get_data, compute_time_statistics
    from submodules.models.tgn.tgn_utils.utils import RandEdgeSampler, get_neighbor_finder

    def _find_checkpoint(models_root: Path, dataset_name: str, model_name: str) -> Path:
        model_name = model_name.lower()
        dataset_name = str(dataset_name)
        candidates = [
            models_root / dataset_name / model_name / f"{model_name}_{dataset_name}_best.pth",
            models_root / dataset_name / "checkpoints" / f"{model_name}_{dataset_name}_best.pth",
            models_root / "checkpoints" / f"{model_name}_{dataset_name}_best.pth",
        ]
        for cand in candidates:
            if cand.exists():
                return cand
        search_roots = [
            models_root / dataset_name / model_name,
            models_root / dataset_name,
            models_root / "checkpoints",
        ]
        for root in search_roots:
            if not root.exists():
                continue
            matches = sorted(root.rglob(f"{model_name}*{dataset_name}*.pth"))
            if not matches:
                matches = sorted(root.rglob("*.pth"))
            for match in matches:
                if "best" in match.name:
                    return match
            if matches:
                return matches[0]
        raise FileNotFoundError(
            f"Checkpoint not found under {models_root} for {model_name}_{dataset_name}."
        )

    def _build_tgn_args(train_args: dict) -> dict:
        return {
            "n_layers": int(train_args.get("n_layer", train_args.get("n_layers", 1))),
            "n_heads": int(train_args.get("n_head", train_args.get("n_heads", 2))),
            "dropout": float(train_args.get("drop_out", 0.1)),
            "use_memory": bool(train_args.get("use_memory", False)),
            "message_dimension": int(train_args.get("message_dim", 100)),
            "memory_dimension": int(train_args.get("memory_dim", 172)),
            "memory_update_at_start": not bool(train_args.get("memory_update_at_end", False)),
            "embedding_module_type": str(train_args.get("embedding_module", "graph_attention")),
            "message_function": str(train_args.get("message_function", "identity")),
            "aggregator_type": str(train_args.get("aggregator", "last")),
            "memory_updater_type": str(train_args.get("memory_updater", "gru")),
            "use_destination_embedding_in_message": bool(train_args.get("use_destination_embedding_in_message", False)),
            "use_source_embedding_in_message": bool(train_args.get("use_source_embedding_in_message", False)),
            "dyrep": bool(train_args.get("dyrep", False)),
        }

    def _load_processed_tables(dataset: str) -> tuple[pd.DataFrame, np.ndarray, dict]:
        """Load the processed edge table, edge-feature array and metadata of ``dataset``.

        Files named ``ml_<dataset>.{csv,npy,json}`` are read from
        ``RESOURCES_DATASETS`` when the CSV lives there directly, otherwise from the
        ``RESOURCES_DATASETS/<dataset>`` subfolder. The JSON file is optional; an
        empty dict is returned in its place when it does not exist.
        """
        uses_flat_layout = (RESOURCES_DATASETS / f"ml_{dataset}.csv").exists()
        data_dir = RESOURCES_DATASETS if uses_flat_layout else RESOURCES_DATASETS / dataset

        edge_table = pd.read_csv(data_dir / f"ml_{dataset}.csv")
        feature_array = np.load(data_dir / f"ml_{dataset}.npy")

        meta_path = data_dir / f"ml_{dataset}.json"
        if meta_path.exists():
            metadata = json.loads(meta_path.read_text(encoding="utf-8"))
        else:
            metadata = {}
        return edge_table, feature_array, metadata


    def _plot_frame(
        frame_id: int,
        *,
        title: str,
        probs: dict | None = None,
        show_gt: bool = True,
        pred_query_only: bool = True,
    ) -> go.Figure:
        """Draw the edges of one frame of the currently selected clip as line segments.

        Reads the notebook-level arrays ``clip_ids``, ``frame_idx``, ``is_query``,
        ``feat_map`` and ``idx_vals`` plus the selected ``clip_id``. The first four
        feature columns of an edge are used as ``x0, y0, x1, y1``.

        Args:
            frame_id: Frame index to draw.
            title: Figure title.
            probs: Optional mapping from row index to predicted probability. Without
                it every edge is drawn (query edges accented). With it only edges
                whose probability is at least 0.5 are drawn.
            show_gt: With ``probs`` given, also draw the non-query edges as a faint
                underlay.
            pred_query_only: With ``probs`` given, restrict predictions to query edges.

        Returns:
            The assembled plotly figure.
        """
        figure = go.Figure()
        rows_in_frame = np.where((clip_ids == clip_id) & (frame_idx == frame_id))[0]
        prediction_mode = probs is not None

        def _endpoints(row):
            # Segment endpoints come from the first four feature columns.
            xa, ya, xb, yb = (float(value) for value in feat_map[row][:4])
            return [xa, xb], [ya, yb]

        # Faint underlay of the non-query edges, only in prediction mode.
        if show_gt and prediction_mode:
            for row in rows_in_frame:
                if not is_query[row]:
                    xs, ys = _endpoints(row)
                    underlay = go.Scatter(
                        x=xs,
                        y=ys,
                        mode="lines",
                        line={"color": COLORS["base"], "width": 1.2},
                        opacity=0.35,
                        hoverinfo="skip",
                        showlegend=False,
                    )
                    figure.add_trace(underlay)

        for row in rows_in_frame:
            xs, ys = _endpoints(row)
            if not prediction_mode:
                # Plain view: every edge is drawn, query edges accented and thicker.
                if is_query[row]:
                    color, width = COLORS["accent"], 3.2
                else:
                    color, width = COLORS["base"], 2.0
            else:
                # Prediction view: keep only (query) edges predicted with p >= 0.5.
                if pred_query_only and not is_query[row]:
                    continue
                if probs.get(int(row), 0.0) < 0.5:
                    continue
                color, width = COLORS["user"], 2.8
            segment = go.Scatter(
                x=xs,
                y=ys,
                mode="lines",
                line={"color": color, "width": width},
                hovertemplate=f"edge_idx={idx_vals[row]}<extra></extra>",
                showlegend=False,
            )
            figure.add_trace(segment)

        figure.update_layout(
            title=title,
            template="simple_white",
            xaxis={"visible": False},
            yaxis={"visible": False, "scaleanchor": "x", "scaleratio": 1},
            margin={"l": 20, "r": 20, "t": 60, "b": 20},
        )
        return figure

    # Training arguments of the "TGN" entry in MODEL_SPECS (empty when absent or unset).
    tgn_spec = MODEL_SPECS.get("TGN", {})
    train_args = dict(tgn_spec.get("args") or {})
    # Neighbour count and batch size used for inference; default to 10 and 200.
    n_neighbors = int(train_args.get("n_degree", 10))
    batch_size = int(train_args.get("bs", 200))
    # Constructor keyword arguments for the TGN model, derived from the training args.
    tgn_args = _build_tgn_args(train_args)

    # NOTE(review): picks CUDA-or-CPU directly instead of reusing the notebook-level
    # DEVICE chosen by pick_device() in the first cell — confirm this is intended.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Contact-edge check device:", device)

    # Per dataset: locate a TGN checkpoint, rebuild the model, then run the checks.
    for dataset in DATASET_LIST:
        # The check is only configured for the two synthetic stick-figure datasets.
        if dataset not in {"stick_figure", "sticky_hips"}:
            print(f"Skip {dataset}: contact-edge check is configured for stick_figure/sticky_hips only.")
            continue

        # A missing checkpoint skips the dataset instead of aborting the whole cell.
        try:
            ckpt_path = _find_checkpoint(RESOURCES_MODELS, dataset, "tgn")
        except FileNotFoundError as exc:
            print(f"Skip {dataset}: {exc}")
            continue
        print("Using checkpoint:", ckpt_path)

        # Only the features, the full data and the test split are used below;
        # the underscore-prefixed results are ignored.
        node_features, edge_features, full_data, _train_data, _val_data, test_data, _nn_val, _nn_test = get_data(dataset)
        m_src, s_src, m_dst, s_dst = compute_time_statistics(
            full_data.sources, full_data.destinations, full_data.timestamps
        )
        # Neighbour finder built over the full data.
        full_ngh_finder = get_neighbor_finder(full_data, uniform=False)

        model = TGN(
            neighbor_finder=full_ngh_finder,
            node_features=node_features,
            edge_features=edge_features,
            device=device,
            n_neighbors=n_neighbors,
            mean_time_shift_src=m_src,
            std_time_shift_src=s_src,
            mean_time_shift_dst=m_dst,
            std_time_shift_dst=s_dst,
            **tgn_args,
        )
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        # Drop the stored memory entries ("memory." / "memory_updater.memory." keys);
        # the memory is re-initialised below instead of being restored.
        filtered_state = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("memory.") or k.startswith("memory_updater.memory."))
        }
        # strict=False tolerates the removed keys; the missing/unexpected-key report is discarded.
        _ = model.load_state_dict(filtered_state, strict=False)
        model = model.to(device).eval()
        # Start from freshly initialised memory when the model has one.
        if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
            model.memory.__init_memory__()

        # Processed edge table, edge features and optional metadata for this dataset.
        graph_df, edge_feat_full, meta = _load_processed_tables(dataset)
        cfg_meta = meta.get("config") if isinstance(meta.get("config"), dict) else {}
        # Frames per clip; falls back to 30 when the metadata does not say.
        frames = int(cfg_meta.get("frames", 30))

        # The edge-index column is named either "idx" or "e_idx".
        idx_col = "idx" if "idx" in graph_df.columns else ("e_idx" if "e_idx" in graph_df.columns else None)
        if idx_col is None:
            print("Paired contact check skipped: missing idx/e_idx.")
            continue
        idx_vals = graph_df[idx_col].astype(int).to_numpy()
        # Feature columns 7 and 8 are read below, so at least 9 columns are required.
        if edge_feat_full.ndim != 2 or edge_feat_full.shape[1] < 9:
            print("Paired contact check skipped: edge_features missing frame info.")
            continue
        # Align features with table rows: index by edge id when the array is long
        # enough for every id, otherwise accept a row-aligned array of equal length.
        if edge_feat_full.shape[0] > int(idx_vals.max()):
            feat_map = edge_feat_full[idx_vals]
        elif edge_feat_full.shape[0] == len(graph_df):
            feat_map = edge_feat_full
        else:
            print("Paired contact check skipped: edge_features length mismatch.")
            continue

        u_all = graph_df["u"].astype(int).to_numpy()
        v_all = graph_df["i"].astype(int).to_numpy()
        ts_all = graph_df["ts"].astype(float).to_numpy()
        # Column 8 is mapped to a frame index below — presumably a frame position
        # normalised to [0, 1]; TODO confirm against the dataset recipe.
        frame_norm = feat_map[:, 8]
        # Column 7 thresholded at 0.5 flags the "query" edges; these are the edges
        # counted as contacts in contact_present below.
        is_query = feat_map[:, 7] >= 0.5

        # Node ids may be 0- or 1-based; the clip id is derived from the source
        # node in blocks of JOINTS_PER_PERSON ids.
        node_min = int(min(u_all.min(), v_all.min()))
        node_base = 1 if node_min >= 1 else 0
        clip_ids = (u_all - node_base) // JOINTS_PER_PERSON
        num_clips = int(clip_ids.max() + 1)
        frame_idx = np.clip(np.rint(frame_norm * (frames - 1)), 0, frames - 1).astype(int)

        # Per (clip, frame): earliest timestamp seen (NaN = no edge in that frame)
        # and whether any query edge occurs.
        frame_time = np.full((num_clips, frames), np.nan)
        contact_present = np.zeros((num_clips, frames), dtype=bool)
        for i in range(len(graph_df)):
            c = int(clip_ids[i])
            f = int(frame_idx[i])
            if is_query[i]:
                contact_present[c, f] = True
            if np.isnan(frame_time[c, f]) or ts_all[i] < frame_time[c, f]:
                frame_time[c, f] = ts_all[i]

        # Edge id -> row position in graph_df, and the edge ids of the test-split
        # events labelled 1.
        idx_to_row = {int(v): i for i, v in enumerate(idx_vals.tolist())}
        pos_mask = np.asarray(test_data.labels) == 1
        pos_edge_idxs = np.asarray(test_data.edge_idxs)[pos_mask]

        # Accumulators for the paired (contact, non-contact) queries built below.
        pos_sources = []
        pos_destinations = []
        pos_times = []
        neg_sources = []
        neg_destinations = []
        neg_times = []

        def _find_non_contact_frame(c: int, f: int) -> int | None:
            """Return the frame of clip ``c`` nearest to ``f`` with a timestamp but no contact.

            Frames are probed at growing distance from ``f``; at equal distance the
            earlier frame is tried first. ``None`` means no such frame exists.
            """

            def _is_usable(frame: int) -> bool:
                # Usable = no contact recorded and a timestamp is known for the frame.
                return not (contact_present[c, frame] or np.isnan(frame_time[c, frame]))

            for distance in range(1, frames):
                earlier = f - distance
                later = f + distance
                if earlier >= 0 and _is_usable(earlier):
                    return earlier
                if later < frames and _is_usable(later):
                    return later
            return None

        for e_idx in pos_edge_idxs:
            row_idx = idx_to_row.get(int(e_idx))
            if row_idx is None or not is_query[row_idx]:
                continue
            c = int(clip_ids[row_idx])
            f = int(frame_idx[row_idx])
            neg_f = _find_non_contact_frame(c, f)
            if neg_f is None:
                continue
            pos_sources.append(int(u_all[row_idx]))
            pos_destinations.append(int(v_all[row_idx]))
            pos_times.append(float(ts_all[row_idx]))
            neg_sources.append(int(u_all[row_idx]))
            neg_destinations.append(int(v_all[row_idx]))
            neg_times.append(float(frame_time[c, neg_f]))

        if not pos_sources:
            print("Paired contact check: no usable contact/non-contact pairs found.")
            continue

        pos_sources = np.asarray(pos_sources)
        pos_destinations = np.asarray(pos_destinations)
        pos_times = np.asarray(pos_times)
        neg_sources = np.asarray(neg_sources)
        neg_destinations = np.asarray(neg_destinations)
        neg_times = np.asarray(neg_times)

        # Supplies the "negative" destinations that compute_edge_probabilities
        # requires; only the positive-branch output is used here.
        sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)

        def _batched_pos_probs(src: np.ndarray, dst: np.ndarray, ts: np.ndarray) -> np.ndarray:
            """Score the edges ``(src[i], dst[i])`` at times ``ts[i]`` in batches.

            Args:
                src: Source node ids.
                dst: Destination node ids.
                ts: Query timestamps, aligned with ``src``/``dst``.

            Returns:
                The positive-branch probabilities from ``model.compute_edge_probabilities``,
                concatenated over batches; an empty array for empty input.
            """
            chunks = []
            # Pure inference: no autograd graph is needed for these forward passes.
            with torch.no_grad():
                for start in range(0, len(src), batch_size):
                    stop = start + batch_size
                    s = src[start:stop]
                    d = dst[start:stop]
                    t = ts[start:stop]
                    _, neg = sampler.sample(len(s))
                    # Every query is scored with edge index 0 as a placeholder.
                    edge_idx_dummy = np.zeros(len(s), dtype=int)
                    pos_prob, _ = model.compute_edge_probabilities(
                        s,
                        d,
                        neg,
                        t,
                        edge_idx_dummy,
                        n_neighbors=n_neighbors,
                    )
                    chunks.append(pos_prob.detach().cpu().numpy())
            return np.concatenate(chunks) if chunks else np.array([])

        # Score both halves of every pair with forbidden_memory_update switched on
        # and freshly initialised memory; the flag is restored even if inference fails.
        prev_forbidden = getattr(model, "forbidden_memory_update", False)
        model.forbidden_memory_update = True
        try:
            if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
                model.memory.__init_memory__()
            pos_prob_pair = _batched_pos_probs(pos_sources, pos_destinations, pos_times)
            neg_prob_pair = _batched_pos_probs(neg_sources, neg_destinations, neg_times)
        finally:
            model.forbidden_memory_update = prev_forbidden

        # A pair counts as correct when the contact query scores >= 0.5 and the
        # non-contact query scores < 0.5; both rates share the same denominator.
        pos_hit = int(np.sum(pos_prob_pair >= 0.5))
        neg_reject = int(np.sum(neg_prob_pair < 0.5))
        total = int(pos_prob_pair.size)
        print(f"Paired contact hit-rate (pos >= 0.5): {100.0 * pos_hit / total:.2f}% ({pos_hit}/{total})")
        print(f"Paired non-contact reject (neg < 0.5): {100.0 * neg_reject / total:.2f}% ({neg_reject}/{total})")

        # Visualization: predictions at multiple frames
        row_idx = None
        for e_idx in pos_edge_idxs:
            idx = idx_to_row.get(int(e_idx))
            if idx is not None and is_query[idx]:
                row_idx = idx
                break
        if row_idx is None:
            print("Visualization skipped: no contact edge found in test split.")
            continue

        clip_id = int(clip_ids[row_idx])
        frame_contact = int(frame_idx[row_idx])
        pred_frames = 8
        start_frame = max(0, frame_contact - (pred_frames // 2))
        end_frame = min(frames - 1, start_frame + pred_frames - 1)
        start_frame = max(0, end_frame - (pred_frames - 1))
        display_frames = list(range(start_frame, end_frame + 1))

        # Replay the whole clip frame by frame with forbidden_memory_update switched
        # off, starting from freshly initialised memory, and keep the probabilities
        # of the frames selected for display. The flag is restored even on failure.
        prev_forbidden = getattr(model, "forbidden_memory_update", False)
        model.forbidden_memory_update = False
        frame_probs: dict[int, dict[int, float]] = {}
        try:
            if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
                model.memory.__init_memory__()

            # Pure inference: skip autograd bookkeeping across the replayed frames.
            with torch.no_grad():
                for f in range(frames):
                    indices_f = np.where((clip_ids == clip_id) & (frame_idx == f))[0]
                    if len(indices_f) == 0:
                        continue
                    s = u_all[indices_f]
                    d = v_all[indices_f]
                    t = ts_all[indices_f]
                    e = idx_vals[indices_f]
                    _, neg = sampler.sample(len(s))
                    # Every frame is passed through the model, not only the displayed ones.
                    pos_prob, _ = model.compute_edge_probabilities(
                        s,
                        d,
                        neg,
                        t,
                        e,
                        n_neighbors=n_neighbors,
                    )
                    if f in display_frames:
                        # Flatten first so each entry is a scalar whether the model
                        # returns shape (n,) or (n, 1).
                        flat_probs = pos_prob.detach().cpu().numpy().reshape(-1)
                        probs = {
                            int(idx): float(p)
                            for idx, p in zip(indices_f, flat_probs)
                        }
                        frame_probs[int(f)] = probs
        finally:
            model.forbidden_memory_update = prev_forbidden

        if not frame_probs:
            print("Visualization skipped: no frames selected.")
            continue

        for f in display_frames:
            probs = frame_probs.get(int(f))
            if not probs:
                continue
            contact_flag = "contact" if contact_present[clip_id, f] else "no contact"
            fig_pred = _plot_frame(
                f,
                title=f"{dataset}: predictions at frame {f} ({contact_flag})",
                probs=probs,
                show_gt=True,
                pred_query_only=True,
            )
            fig_pred.show()

Contact-edge check device: cpu
Skip wikipedia: contact-edge check is configured for stick_figure/sticky_hips only.
