# Train A Model

In [1]:
from __future__ import annotations

import json
import sys
from pathlib import Path


def _bootstrap_repo_root(start: Path | None = None) -> Path:
    here = (start or Path.cwd()).resolve()
    for candidate in (here, *here.parents):
        if (candidate / "time_to_explain").is_dir():
            return candidate
    raise RuntimeError(
        f"Could not locate the repository root from {here}. "
        "Set PROJECT_ROOT manually if your layout is unusual."
    )


# Locate the repository root (searching upward from the current working
# directory) and put it on sys.path so the project imports below can resolve.
PROJECT_ROOT = _bootstrap_repo_root()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Kept after the sys.path update above so the package can be found via PROJECT_ROOT.
from time_to_explain.models.device import pick_device
# Notebook-wide settings. A missing file yields an empty dict, so the defaults
# below (seed 42, device "auto") apply.
# NOTE: CONFIG_PATH is reassigned in the "Get Config" cell to the model training config.
CONFIG_PATH = PROJECT_ROOT / "configs" / "notebooks" / "global.json"
NOTEBOOK_CFG = json.loads(CONFIG_PATH.read_text(encoding="utf-8")) if CONFIG_PATH.exists() else {}
SEED = int(NOTEBOOK_CFG.get("seed", 42))
DEVICE = pick_device(NOTEBOOK_CFG.get("device", "auto"))
print(f"Notebook config: seed={SEED}, device={DEVICE}")

from time_to_explain.models.utils import (
    build_cmd,
    ensure_tempme_processed,
    ensure_workdir,
    export_trained_models,
    prepare_env,
    run_cmd,
)
from time_to_explain.utils.cli import (
    args_dict_to_list,
    normalize_datasets,
    resolve_path,
    slugify,
)


Notebook config: seed=42, device=mps


### Set Config

In [2]:
# Dataset whose training config is loaded in the next cell.
# Alternatives noted by the author: "triadic_closure", "nicolaus".
DATASET_NAME = "stick_figure"
# Model family; combined with DATASET_NAME to form configs/models/train_<model>_<dataset>.json.
MODEL_TYPE = "tgn"

### Get Config

In [3]:
# Locate the training config selected in the previous cell.
# MODEL_TYPE is lower-cased for the file name because this cell overwrites it
# further down with the upper-cased value from the config. Without the
# lower(), re-running this cell would look for "train_TGN_..." and fail on
# case-sensitive filesystems.
CONFIG_PATH = Path(f"configs/models/train_{MODEL_TYPE.lower()}_{DATASET_NAME}.json")

CONFIG_PATH = resolve_path(str(CONFIG_PATH), root=PROJECT_ROOT)
if CONFIG_PATH is None or not CONFIG_PATH.exists():
    raise FileNotFoundError(f"Config not found: {CONFIG_PATH}")
CONFIG = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))

# Run-level settings; every key is optional and falls back to the default shown.
MODEL_TYPE = str(CONFIG.get("model_type", "TGN")).upper()
DATASET_LIST = normalize_datasets(CONFIG.get("datasets", []))
if not DATASET_LIST:
    # Fail here with a readable message instead of an IndexError on DATASET_LIST[0] below.
    raise ValueError(f"No datasets configured in {CONFIG_PATH}; add a 'datasets' entry.")
PYTHON_BIN = str(CONFIG.get("python_bin", "python"))
DRY_RUN = bool(CONFIG.get("dry_run", False))
CUDA_VISIBLE_DEVICES = CONFIG.get("cuda_visible_devices")

# Per-model specs keyed by upper-cased model name so lookups match MODEL_TYPE.
MODEL_SPECS = {str(k).upper(): v for k, v in (CONFIG.get("models") or {}).items()}
if MODEL_TYPE not in MODEL_SPECS:
    raise KeyError(f"Model spec for {MODEL_TYPE} missing in {CONFIG_PATH}")

# Output/input locations, resolved against PROJECT_ROOT.
RESOURCES_MODELS = resolve_path(CONFIG.get("resources_models_dir", "resources/models"), root=PROJECT_ROOT)
RUNS_ROOT = resolve_path(CONFIG.get("runs_root", "resources/models/runs"), root=PROJECT_ROOT)
RESOURCES_DATASETS = resolve_path(CONFIG.get("resources_datasets_dir", "resources/datasets/processed"), root=PROJECT_ROOT)

# Only used for the summary printout; the training cell derives its own workdir per dataset.
DEFAULT_WORKDIR = RUNS_ROOT / f"{slugify(MODEL_TYPE)}_{slugify(DATASET_LIST[0])}"

# Specs for models that are absent from the config default to empty dicts.
TGN_SPEC = MODEL_SPECS.get("TGN", {})
TGAT_SPEC = MODEL_SPECS.get("TGAT", {})
GRAPHMIXER_SPEC = MODEL_SPECS.get("GRAPHMIXER", {})

TGN_SCRIPT = resolve_path(TGN_SPEC.get("script"), root=PROJECT_ROOT)
TGAT_SCRIPT = resolve_path(TGAT_SPEC.get("script"), root=PROJECT_ROOT)
GRAPHMIXER_SCRIPT = resolve_path(GRAPHMIXER_SPEC.get("script"), root=PROJECT_ROOT)
GRAPHMIXER_PROCESSED_DIR = resolve_path(GRAPHMIXER_SPEC.get("processed_dir"), root=PROJECT_ROOT)
GRAPHMIXER_PARAMS_DIR = resolve_path(GRAPHMIXER_SPEC.get("params_dir"), root=PROJECT_ROOT)

def get_tgn_args(dataset: str) -> list[str]:
    """Return the TGN argument list for ``dataset``.

    The ``args`` mapping of the TGN spec (an empty mapping when absent) is
    handed to ``args_dict_to_list`` together with the dataset name.
    """
    arg_template = TGN_SPEC.get("args", {})
    return args_dict_to_list(arg_template, dataset)

def get_tgat_args(dataset: str) -> list[str]:
    """Return the TGAT argument list for ``dataset``.

    The ``args`` mapping of the TGAT spec (an empty mapping when absent) is
    handed to ``args_dict_to_list`` together with the dataset name.
    """
    arg_template = TGAT_SPEC.get("args", {})
    return args_dict_to_list(arg_template, dataset)

def get_graphmixer_args(dataset: str) -> list[str]:
    """Return the GraphMixer argument list for ``dataset``.

    The ``args`` mapping of the GraphMixer spec (an empty mapping when absent)
    is handed to ``args_dict_to_list`` together with the dataset name.
    """
    arg_template = GRAPHMIXER_SPEC.get("args", {})
    return args_dict_to_list(arg_template, dataset)

# One-shot summary of the resolved configuration, one item per line.
print(
    f"Configuration loaded from: {CONFIG_PATH}",
    f"PROJECT_ROOT: {PROJECT_ROOT}",
    f"MODEL_TYPE: {MODEL_TYPE}",
    f"DATASETS: {DATASET_LIST}",
    f"Artifacts root: {RUNS_ROOT}",
    f"Sample run directory: {DEFAULT_WORKDIR}",
    sep="\n",
)


Configuration loaded from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/configs/models/train_tgn_stick_figure.json
PROJECT_ROOT: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
MODEL_TYPE: TGN
DATASETS: ['stick_figure']
Artifacts root: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs
Sample run directory: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_stick_figure


### Train Model

In [4]:
# --- Launch training for each dataset ---
# For every configured dataset: create a run directory, launch the model's
# training script as a subprocess, and on exit code 0 copy the resulting
# checkpoints into RESOURCES_MODELS. A non-zero exit code is reported and the
# loop moves on to the next dataset.
env = prepare_env(project_root=PROJECT_ROOT, cuda_visible_devices=CUDA_VISIBLE_DEVICES)

# TGN and TGAT share one launch path; only the script and the argument builder differ.
SUBPROCESS_TRAINERS = {
    "TGN": (TGN_SCRIPT, get_tgn_args),
    "TGAT": (TGAT_SCRIPT, get_tgat_args),
}
model_type = MODEL_TYPE.upper()

for dataset in DATASET_LIST:
    workdir = ensure_workdir(RUNS_ROOT, MODEL_TYPE, dataset)
    print("=== Launching", model_type, "on", dataset, "===")
    print("Artifacts will be stored under:", workdir)

    if model_type in SUBPROCESS_TRAINERS:
        script, build_args = SUBPROCESS_TRAINERS[model_type]
        try:
            cmd = build_cmd(PYTHON_BIN, script, build_args(dataset))
        except FileNotFoundError as exc:
            # Same exception type, plus a hint on where to fix it; chained so the
            # original traceback stays attached.
            raise FileNotFoundError(
                f"{exc} Update configs/models to point to your {model_type} training script."
            ) from exc
        code = run_cmd(cmd, env=env, workdir=workdir, dry_run=DRY_RUN)
        if code == 0:
            export_trained_models(MODEL_TYPE, dataset, workdir, RESOURCES_MODELS)
            print(f"[{model_type}] Training completed (or started successfully) for", dataset)
        else:
            print(f"[{model_type}] Training failed for", dataset, "- see output above.")

    elif model_type == "GRAPHMIXER":
        if GRAPHMIXER_PROCESSED_DIR is None:
            raise ValueError("GRAPHMIXER processed_dir missing in config.")
        ensure_tempme_processed(
            dataset,
            processed_dir=GRAPHMIXER_PROCESSED_DIR,
            resources_datasets=RESOURCES_DATASETS,
        )
        if not GRAPHMIXER_SCRIPT or not GRAPHMIXER_SCRIPT.exists():
            raise FileNotFoundError(f"GraphMixer training script not found: {GRAPHMIXER_SCRIPT}")
        cmd = build_cmd(PYTHON_BIN, GRAPHMIXER_SCRIPT, get_graphmixer_args(dataset))
        # GraphMixer is launched from the script's own directory rather than the run directory.
        code = run_cmd(cmd, env=env, workdir=GRAPHMIXER_SCRIPT.parent, dry_run=DRY_RUN)
        if code == 0:
            exported = []
            if GRAPHMIXER_PARAMS_DIR and GRAPHMIXER_PARAMS_DIR.exists():
                dest_dir = RESOURCES_MODELS / slugify(dataset) / "graphmixer"
                dest_dir.mkdir(parents=True, exist_ok=True)
                # Prefer checkpoints named after this dataset; only when none exist
                # fall back to every graphmixer_*.pt file in the params directory.
                for pattern in (f"graphmixer_{slugify(dataset)}*.pt", "graphmixer_*.pt"):
                    for src_file in sorted(GRAPHMIXER_PARAMS_DIR.glob(pattern)):
                        dest_path = dest_dir / src_file.name
                        dest_path.write_bytes(src_file.read_bytes())
                        exported.append(dest_path)
                    if exported:
                        break
            if exported:
                print("[GraphMixer] Exported:")
                for p in exported:
                    print(" -", p)
            else:
                print("[GraphMixer] No checkpoints found under", GRAPHMIXER_PARAMS_DIR)
            print("[GraphMixer] Training completed (or started successfully) for", dataset)
        else:
            print("[GraphMixer] Training failed for", dataset)

    else:
        raise ValueError("MODEL_TYPE must be 'TGN' or 'TGAT' or 'GRAPHMIXER'")


=== Launching TGN on stick_figure ===
Artifacts will be stored under: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_stick_figure
$ (cwd=/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/runs/tgn_stick_figure) python /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py --data stick_figure --use_memory --prefix tgn-attn-stick_figure --n_runs 2 --n_epoch 10 --patience 3 --memory_dim 8


INFO:root:Namespace(data='stick_figure', bs=200, prefix='tgn-attn-stick_figure', n_degree=10, n_head=2, n_epoch=10, n_layer=1, lr=0.0001, patience=3, n_runs=2, drop_out=0.1, gpu=0, node_dim=100, time_dim=100, backprop_every=1, use_memory=True, embedding_module='graph_attention', message_function='identity', memory_updater='gru', aggregator='last', memory_update_at_end=False, message_dim=100, memory_dim=8, different_new_nodes=False, uniform=False, randomize_features=False, use_destination_embedding_in_message=False, use_source_embedding_in_message=False, dyrep=False)


/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
The dataset has 342400 interactions, involving 11200 different nodes
The training dataset has 239680 interactions, involving 7840 different nodes
The validation dataset has 51360 interactions, involving 1680 different nodes
The test dataset has 51360 interactions, involving 1680 different nodes
The new node validation dataset has 51360 interactions, involving 1680 different nodes
The new node test dataset has 51360 interactions, involving 1680 different nodes
1120 nodes were used for the inductive testing, i.e. are never seen during training


INFO:root:num of training instances: 239680
INFO:root:num of batches per epoch: 1199
INFO:root:start 0 epoch
INFO:root:epoch: 0 took 178.68s
INFO:root:Epoch mean loss: 0.8515856625128528
INFO:root:val auc: 0.9960248054474707, new node val auc: 0.9960248054474707
INFO:root:val ap: 0.9954147237773058, new node val ap: 0.9954147237773058
INFO:root:val acc: 0.9648808365758755, new node val acc: 0.9648808365758755
INFO:root:start 1 epoch
INFO:root:epoch: 1 took 149.10s
INFO:root:Epoch mean loss: 0.22953668798555352
INFO:root:val auc: 0.9955454766536965, new node val auc: 0.9955454766536965
INFO:root:val ap: 0.9951435179269146, new node val ap: 0.9951435179269146
INFO:root:val acc: 0.972534046692607, new node val acc: 0.972534046692607
INFO:root:start 2 epoch
INFO:root:epoch: 2 took 145.70s
INFO:root:Epoch mean loss: 0.2202950767167627
INFO:root:val auc: 0.994548686770428, new node val auc: 0.994548686770428
INFO:root:val ap: 0.9940648184028066, new node val ap: 0.9940648184028066
INFO:root:

Copied trained model(s) to:
 - /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/stick_figure/tgn/tgn-attn-stick_figure-stick_figure_5.pth
[TGN] Training completed (or started successfully) for stick_figure


In [5]:
# --- Contact edge sanity check + visualization (stick_figure + sticky_hips + TGN) ---
if DRY_RUN:
    print("Skip contact-edge check (DRY_RUN=True).")
elif MODEL_TYPE.upper() != "TGN":
    print("Skip contact-edge check (MODEL_TYPE != TGN).")
else:
    import json
    import math
    import numpy as np
    import pandas as pd
    import torch
    import plotly.graph_objects as go

    from time_to_explain.data.synthetic_recipes.stick_figure import JOINTS_PER_PERSON
    from time_to_explain.visualization.utils import COLORS
    from submodules.models.tgn.model.tgn import TGN
    from submodules.models.tgn.tgn_utils.data_processing import get_data, compute_time_statistics
    from submodules.models.tgn.tgn_utils.utils import RandEdgeSampler, get_neighbor_finder

    def _find_checkpoint(models_root: Path, dataset_name: str, model_name: str) -> Path:
        model_name = model_name.lower()
        dataset_name = str(dataset_name)
        candidates = [
            models_root / dataset_name / model_name / f"{model_name}_{dataset_name}_best.pth",
            models_root / dataset_name / "checkpoints" / f"{model_name}_{dataset_name}_best.pth",
            models_root / "checkpoints" / f"{model_name}_{dataset_name}_best.pth",
        ]
        for cand in candidates:
            if cand.exists():
                return cand
        search_roots = [
            models_root / dataset_name / model_name,
            models_root / dataset_name,
            models_root / "checkpoints",
        ]
        for root in search_roots:
            if not root.exists():
                continue
            matches = sorted(root.rglob(f"{model_name}*{dataset_name}*.pth"))
            if not matches:
                matches = sorted(root.rglob("*.pth"))
            for match in matches:
                if "best" in match.name:
                    return match
            if matches:
                return matches[0]
        raise FileNotFoundError(
            f"Checkpoint not found under {models_root} for {model_name}_{dataset_name}."
        )

    def _build_tgn_args(train_args: dict) -> dict:
        return {
            "n_layers": int(train_args.get("n_layer", train_args.get("n_layers", 1))),
            "n_heads": int(train_args.get("n_head", train_args.get("n_heads", 2))),
            "dropout": float(train_args.get("drop_out", 0.1)),
            "use_memory": bool(train_args.get("use_memory", False)),
            "message_dimension": int(train_args.get("message_dim", 100)),
            "memory_dimension": int(train_args.get("memory_dim", 172)),
            "memory_update_at_start": not bool(train_args.get("memory_update_at_end", False)),
            "embedding_module_type": str(train_args.get("embedding_module", "graph_attention")),
            "message_function": str(train_args.get("message_function", "identity")),
            "aggregator_type": str(train_args.get("aggregator", "last")),
            "memory_updater_type": str(train_args.get("memory_updater", "gru")),
            "use_destination_embedding_in_message": bool(train_args.get("use_destination_embedding_in_message", False)),
            "use_source_embedding_in_message": bool(train_args.get("use_source_embedding_in_message", False)),
            "dyrep": bool(train_args.get("dyrep", False)),
        }

    def _load_processed_tables(dataset: str) -> tuple[pd.DataFrame, np.ndarray, dict]:
        """Load the processed event table, edge-feature array and metadata of ``dataset``.

        Files are named ``ml_<dataset>.csv`` / ``.npy`` / ``.json``. They are
        read from ``RESOURCES_DATASETS`` directly when the CSV exists there,
        otherwise from ``RESOURCES_DATASETS/<dataset>/``.

        Returns:
            ``(events, edge_features, metadata)``; ``metadata`` is an empty dict
            when the JSON file does not exist.
        """
        stem = f"ml_{dataset}"
        flat_layout = (RESOURCES_DATASETS / f"{stem}.csv").exists()
        base_dir = RESOURCES_DATASETS if flat_layout else RESOURCES_DATASETS / dataset

        events = pd.read_csv(base_dir / f"{stem}.csv")
        features = np.load(base_dir / f"{stem}.npy")

        sidecar = base_dir / f"{stem}.json"
        if sidecar.exists():
            metadata = json.loads(sidecar.read_text(encoding="utf-8"))
        else:
            metadata = {}
        return events, features, metadata


    def _plot_frame(
        frame_id: int,
        *,
        title: str,
        probs: dict | None = None,
        show_gt: bool = True,
        pred_query_only: bool = True,
    ) -> go.Figure:
        """Draw the edges of one frame of the currently selected clip as line segments.

        Reads notebook-level state set by the dataset loop that runs after this
        definition: ``clip_ids``, ``clip_id``, ``frame_idx``, ``is_query``,
        ``feat_map`` and ``idx_vals``. The first four feature columns of an
        edge are used as ``x0, y0, x1, y1``.

        Args:
            frame_id: Frame to draw.
            title: Figure title.
            probs: Optional mapping from row index to predicted probability.
                Without it every edge is drawn (query edges highlighted). With
                it only edges with probability >= 0.5 are drawn.
            show_gt: When ``probs`` is given, additionally draw the non-query
                edges as a faint background layer.
            pred_query_only: When ``probs`` is given, restrict the prediction
                layer to query edges.

        Returns:
            The assembled plotly figure (not shown).
        """
        fig = go.Figure()
        rows = np.flatnonzero((clip_ids == clip_id) & (frame_idx == frame_id))
        overlay_predictions = probs is not None

        def _segment(row):
            # Endpoints of the edge as plotly-ready x and y lists.
            x0, y0, x1, y1 = (float(value) for value in feat_map[row][:4])
            return [x0, x1], [y0, y1]

        # Background layer: faint non-query edges underneath the predictions.
        if show_gt and overlay_predictions:
            for row in rows:
                if is_query[row]:
                    continue
                xs, ys = _segment(row)
                backdrop = go.Scatter(
                    x=xs,
                    y=ys,
                    mode="lines",
                    line=dict(color=COLORS["base"], width=1.2),
                    opacity=0.35,
                    hoverinfo="skip",
                    showlegend=False,
                )
                fig.add_trace(backdrop)

        # Foreground layer: either all edges or the confident predictions.
        for row in rows:
            if not overlay_predictions:
                if is_query[row]:
                    color, width = COLORS["accent"], 3.2
                else:
                    color, width = COLORS["base"], 2.0
            else:
                if pred_query_only and not is_query[row]:
                    continue
                if probs.get(int(row), 0.0) < 0.5:
                    continue
                color, width = COLORS["user"], 2.8
            xs, ys = _segment(row)
            edge_trace = go.Scatter(
                x=xs,
                y=ys,
                mode="lines",
                line=dict(color=color, width=width),
                hovertemplate=f"edge_idx={idx_vals[row]}<extra></extra>",
                showlegend=False,
            )
            fig.add_trace(edge_trace)

        fig.update_layout(
            title=title,
            template="simple_white",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False, scaleanchor="x", scaleratio=1),
            margin=dict(l=20, r=20, t=60, b=20),
        )
        return fig

    # Reuse the TGN "args" from the loaded training config so the model rebuilt
    # in the loop below is configured from the same option values.
    tgn_spec = MODEL_SPECS.get("TGN", {})
    train_args = dict(tgn_spec.get("args") or {})
    # Fallbacks (10 neighbours, batch size 200) apply only when the config omits these keys.
    n_neighbors = int(train_args.get("n_degree", 10))
    batch_size = int(train_args.get("bs", 200))
    tgn_args = _build_tgn_args(train_args)

    # NOTE(review): the device is picked here independently of the notebook-level
    # DEVICE from the first cell; only CUDA or CPU is considered.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Contact-edge check device:", device)

    # Per dataset: rebuild TGN from a checkpoint, score contact edges against
    # the same node pair at a nearby non-contact frame, then plot predictions
    # for a window of frames around one contact edge.
    for dataset in DATASET_LIST:
        if dataset not in {"stick_figure", "sticky_hips"}:
            print(f"Skip {dataset}: contact-edge check is configured for stick_figure/sticky_hips only.")
            continue

        try:
            ckpt_path = _find_checkpoint(RESOURCES_MODELS, dataset, "tgn")
        except FileNotFoundError as exc:
            print(f"Skip {dataset}: {exc}")
            continue
        print("Using checkpoint:", ckpt_path)

        # --- Rebuild the model -------------------------------------------------
        node_features, edge_features, full_data, _train_data, _val_data, test_data, _nn_val, _nn_test = get_data(dataset)
        m_src, s_src, m_dst, s_dst = compute_time_statistics(
            full_data.sources, full_data.destinations, full_data.timestamps
        )
        full_ngh_finder = get_neighbor_finder(full_data, uniform=False)

        model = TGN(
            neighbor_finder=full_ngh_finder,
            node_features=node_features,
            edge_features=edge_features,
            device=device,
            n_neighbors=n_neighbors,
            mean_time_shift_src=m_src,
            std_time_shift_src=s_src,
            mean_time_shift_dst=m_dst,
            std_time_shift_dst=s_dst,
            **tgn_args,
        )
        # weights_only=True limits unpickling to tensors and plain containers.
        # This assumes the checkpoint is a plain state_dict, which the
        # .items() filtering below already relies on.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        # Memory buffers are dropped from the checkpoint; memory is re-initialised below.
        filtered_state = {
            k: v
            for k, v in state_dict.items()
            if not (k.startswith("memory.") or k.startswith("memory_updater.memory."))
        }
        _ = model.load_state_dict(filtered_state, strict=False)
        model = model.to(device).eval()
        if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
            model.memory.__init_memory__()

        # --- Load and index the processed event table --------------------------
        graph_df, edge_feat_full, meta = _load_processed_tables(dataset)
        cfg_meta = meta.get("config") if isinstance(meta.get("config"), dict) else {}
        frames = int(cfg_meta.get("frames", 30))

        idx_col = "idx" if "idx" in graph_df.columns else ("e_idx" if "e_idx" in graph_df.columns else None)
        if idx_col is None:
            print("Paired contact check skipped: missing idx/e_idx.")
            continue
        idx_vals = graph_df[idx_col].astype(int).to_numpy()
        if edge_feat_full.ndim != 2 or edge_feat_full.shape[1] < 9:
            print("Paired contact check skipped: edge_features missing frame info.")
            continue
        # Align features with table rows: index by edge id when the array is
        # long enough, otherwise require a row-for-row match.
        if edge_feat_full.shape[0] > int(idx_vals.max()):
            feat_map = edge_feat_full[idx_vals]
        elif edge_feat_full.shape[0] == len(graph_df):
            feat_map = edge_feat_full
        else:
            print("Paired contact check skipped: edge_features length mismatch.")
            continue

        u_all = graph_df["u"].astype(int).to_numpy()
        v_all = graph_df["i"].astype(int).to_numpy()
        ts_all = graph_df["ts"].astype(float).to_numpy()
        # Feature column 8 is read as a normalised frame position and column 7
        # as a query flag — presumably marking contact edges; confirm against
        # the dataset recipe.
        frame_norm = feat_map[:, 8]
        is_query = feat_map[:, 7] >= 0.5

        # Node ids may be 0- or 1-based; clips are consecutive groups of JOINTS_PER_PERSON source ids.
        node_min = int(min(u_all.min(), v_all.min()))
        node_base = 1 if node_min >= 1 else 0
        clip_ids = (u_all - node_base) // JOINTS_PER_PERSON
        num_clips = int(clip_ids.max() + 1)
        frame_idx = np.clip(np.rint(frame_norm * (frames - 1)), 0, frames - 1).astype(int)

        # Per (clip, frame): earliest timestamp seen and whether any query edge occurs.
        frame_time = np.full((num_clips, frames), np.nan)
        contact_present = np.zeros((num_clips, frames), dtype=bool)
        for i in range(len(graph_df)):
            c = int(clip_ids[i])
            f = int(frame_idx[i])
            if is_query[i]:
                contact_present[c, f] = True
            if np.isnan(frame_time[c, f]) or ts_all[i] < frame_time[c, f]:
                frame_time[c, f] = ts_all[i]

        idx_to_row = {int(v): i for i, v in enumerate(idx_vals.tolist())}
        pos_mask = np.asarray(test_data.labels) == 1
        pos_edge_idxs = np.asarray(test_data.edge_idxs)[pos_mask]

        # --- Build (contact, non-contact) pairs --------------------------------
        pos_sources = []
        pos_destinations = []
        pos_times = []
        neg_sources = []
        neg_destinations = []
        neg_times = []

        def _find_non_contact_frame(c: int, f: int) -> int | None:
            """Return the frame of clip ``c`` closest to ``f`` that has events but no query edge.

            Earlier frames win ties; ``None`` means no such frame exists.
            """
            for delta in range(1, frames):
                lo = f - delta
                hi = f + delta
                if lo >= 0 and not contact_present[c, lo] and not np.isnan(frame_time[c, lo]):
                    return lo
                if hi < frames and not contact_present[c, hi] and not np.isnan(frame_time[c, hi]):
                    return hi
            return None

        for e_idx in pos_edge_idxs:
            row_idx = idx_to_row.get(int(e_idx))
            if row_idx is None or not is_query[row_idx]:
                continue
            c = int(clip_ids[row_idx])
            f = int(frame_idx[row_idx])
            neg_f = _find_non_contact_frame(c, f)
            if neg_f is None:
                continue
            pos_sources.append(int(u_all[row_idx]))
            pos_destinations.append(int(v_all[row_idx]))
            pos_times.append(float(ts_all[row_idx]))
            # The negative uses the same node pair, timed at the non-contact frame.
            neg_sources.append(int(u_all[row_idx]))
            neg_destinations.append(int(v_all[row_idx]))
            neg_times.append(float(frame_time[c, neg_f]))

        if not pos_sources:
            print("Paired contact check: no usable contact/non-contact pairs found.")
            continue

        pos_sources = np.asarray(pos_sources)
        pos_destinations = np.asarray(pos_destinations)
        pos_times = np.asarray(pos_times)
        neg_sources = np.asarray(neg_sources)
        neg_destinations = np.asarray(neg_destinations)
        neg_times = np.asarray(neg_times)

        sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)

        def _batched_pos_probs(src: np.ndarray, dst: np.ndarray, ts: np.ndarray) -> np.ndarray:
            """Score ``(src, dst, ts)`` triples in chunks of ``batch_size``.

            Returns a flat array with one probability per triple. The sampled
            negatives and the zero edge indices only satisfy the
            ``compute_edge_probabilities`` signature; their scores are discarded.
            """
            out = []
            for i in range(0, len(src), batch_size):
                s = src[i:i + batch_size]
                d = dst[i:i + batch_size]
                t = ts[i:i + batch_size]
                _, neg = sampler.sample(len(s))
                edge_idx_dummy = np.zeros(len(s), dtype=int)
                pos_prob, _ = model.compute_edge_probabilities(
                    s,
                    d,
                    neg,
                    t,
                    edge_idx_dummy,
                    n_neighbors=n_neighbors,
                )
                # Flatten so callers get one scalar per edge regardless of a trailing unit axis.
                out.append(pos_prob.detach().cpu().numpy().reshape(-1))
            return np.concatenate(out) if out else np.array([])

        # --- Paired scoring with memory updates disabled -----------------------
        prev_forbidden = getattr(model, "forbidden_memory_update", False)
        model.forbidden_memory_update = True
        try:
            if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
                model.memory.__init_memory__()
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                pos_prob_pair = _batched_pos_probs(pos_sources, pos_destinations, pos_times)
                neg_prob_pair = _batched_pos_probs(neg_sources, neg_destinations, neg_times)
        finally:
            # Restore the flag even if scoring fails, so the model is not left modified.
            model.forbidden_memory_update = prev_forbidden

        pos_hit = int(np.sum(pos_prob_pair >= 0.5))
        neg_reject = int(np.sum(neg_prob_pair < 0.5))
        total = int(pos_prob_pair.size)
        print(f"Paired contact hit-rate (pos >= 0.5): {100.0 * pos_hit / total:.2f}% ({pos_hit}/{total})")
        print(f"Paired non-contact reject (neg < 0.5): {100.0 * neg_reject / total:.2f}% ({neg_reject}/{total})")

        # Visualization: predictions at multiple frames
        # Anchor on the first positive test edge that is a query edge.
        row_idx = None
        for e_idx in pos_edge_idxs:
            idx = idx_to_row.get(int(e_idx))
            if idx is not None and is_query[idx]:
                row_idx = idx
                break
        if row_idx is None:
            print("Visualization skipped: no contact edge found in test split.")
            continue

        clip_id = int(clip_ids[row_idx])
        frame_contact = int(frame_idx[row_idx])
        # Window of up to `pred_frames` frames around the contact, clamped to the clip.
        pred_frames = 8
        start_frame = max(0, frame_contact - (pred_frames // 2))
        end_frame = min(frames - 1, start_frame + pred_frames - 1)
        start_frame = max(0, end_frame - (pred_frames - 1))
        display_frames = list(range(start_frame, end_frame + 1))

        # Replay the whole clip frame by frame with memory updates enabled,
        # keeping probabilities only for the displayed frames.
        prev_forbidden = getattr(model, "forbidden_memory_update", False)
        model.forbidden_memory_update = False
        frame_probs: dict[int, dict[int, float]] = {}
        try:
            if getattr(model, "use_memory", False) and getattr(model, "memory", None) is not None:
                model.memory.__init_memory__()
            with torch.no_grad():
                for f in range(frames):
                    indices_f = np.where((clip_ids == clip_id) & (frame_idx == f))[0]
                    if len(indices_f) == 0:
                        continue
                    s = u_all[indices_f]
                    d = v_all[indices_f]
                    t = ts_all[indices_f]
                    e = idx_vals[indices_f]
                    _, neg = sampler.sample(len(s))
                    pos_prob, _ = model.compute_edge_probabilities(
                        s,
                        d,
                        neg,
                        t,
                        e,
                        n_neighbors=n_neighbors,
                    )
                    if f in display_frames:
                        # Flatten first: float() on a 1-element array is deprecated in recent NumPy.
                        flat_probs = pos_prob.detach().cpu().numpy().reshape(-1)
                        probs = {
                            int(idx): float(p)
                            for idx, p in zip(indices_f, flat_probs)
                        }
                        frame_probs[int(f)] = probs
        finally:
            model.forbidden_memory_update = prev_forbidden

        if not frame_probs:
            print("Visualization skipped: no frames selected.")
            continue

        for f in display_frames:
            probs = frame_probs.get(int(f))
            if not probs:
                continue
            contact_flag = "contact" if contact_present[clip_id, f] else "no contact"
            fig_pred = _plot_frame(
                f,
                title=f"{dataset}: predictions at frame {f} ({contact_flag})",
                probs=probs,
                show_gt=True,
                pred_query_only=True,
            )
            fig_pred.show()

Contact-edge check device: cpu
Using checkpoint: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/stick_figure/checkpoints/tgn_stick_figure_best.pth
/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
The dataset has 342400 interactions, involving 11200 different nodes
The training dataset has 239680 interactions, involving 7840 different nodes
The validation dataset has 51360 interactions, involving 1680 different nodes
The test dataset has 51360 interactions, involving 1680 different nodes
The new node validation dataset has 51360 interactions, involving 1680 different nodes
The new node test dataset has 51360 interactions, involving 1680 different nodes
1120 nodes were used for the inductive testing, i.e. are never seen during training


  state_dict = torch.load(ckpt_path, map_location="cpu")


Paired contact hit-rate (pos >= 0.5): 84.27% (809/960)
Paired non-contact reject (neg < 0.5): 6.98% (67/960)


  int(idx): float(p)
