# Explain a Single Event

Load a pretrained temporal GNN, generate one explanation, and visualise both the local graph context and the explainer's subgraph.

In [None]:
# Environment setup
from pathlib import Path
import sys, importlib


def _add_notebooks_src_to_path():
    """Locate ``notebooks/src`` above the working directory and make it importable.

    Starting at the resolved current working directory and walking up through
    its parents, the first directory containing ``notebooks/src`` wins. That
    path is prepended to ``sys.path`` (unless already present) and returned.

    Raises:
        FileNotFoundError: if no directory on the way up contains ``notebooks/src``.
    """
    start = Path.cwd().resolve()
    found = next(
        (
            folder / "notebooks" / "src"
            for folder in (start, *start.parents)
            if (folder / "notebooks" / "src").is_dir()
        ),
        None,
    )
    if found is None:
        raise FileNotFoundError("Could not find 'notebooks/src' from current working directory.")
    found_str = str(found)
    if found_str not in sys.path:
        sys.path.insert(0, found_str)
    return found

# Make the notebook helper package importable before pulling in its modules.
NOTEBOOKS_SRC = _add_notebooks_src_to_path()
print(f"Using helpers from: {NOTEBOOKS_SRC}")

from constants import (
    REPO_ROOT, RESOURCES_DIR, PROCESSED_DATA_DIR, MODELS_ROOT, ensure_repo_importable, load_notebook_config
)
ensure_repo_importable()
from device import pick_device

# Notebook-wide settings: seed (default 42) and device preference (default "auto").
NOTEBOOK_CFG = load_notebook_config()
SEED = int(NOTEBOOK_CFG.get("seed", 42))
DEVICE = pick_device(NOTEBOOK_CFG.get("device", "auto"))
print(f"Notebook config: seed={SEED}, device={DEVICE}")


# Vendored model and explainer code lives under submodules/. Each existing
# directory is prepended to sys.path in turn (tgn, then tgat, then the explainer),
# skipping entries that are already present.
SUBMODULE_MODELS_DIR = REPO_ROOT / "submodules" / "models"
EXPLAINER_ROOT = REPO_ROOT / "submodules" / "explainer"
for extra_dir in (SUBMODULE_MODELS_DIR / "tgn", SUBMODULE_MODELS_DIR / "tgat", EXPLAINER_ROOT):
    if extra_dir.is_dir() and str(extra_dir) not in sys.path:
        sys.path.insert(0, str(extra_dir))


SyntaxError: invalid syntax (1284727354.py, line 21)

In [None]:
# Core imports and notebook knobs
import json
import torch
import numpy as np
import pandas as pd

from time_to_explain.data.legacy.tg_dataset import load_explain_idx
from time_to_explain.data.workflows import load_processed_dataset_safe
from time_to_explain.data.synthetic_recipes.stick_figure import (
    write_stick_figure_explain_index,
    JOINTS_PER_PERSON,
    J_T,
    J_LW,
    J_RW,
    J_LA,
    J_RA,
)
from time_to_explain.utils.graph import NeighborFinder
from submodules.models.tgat.module import TGAN
from submodules.models.tgn.model.tgn import TGN
from submodules.explainer.tgnnexplainer.tgnnexplainer.xgraph.models.ext.tgn.utils.data_processing import compute_time_statistics

from time_to_explain.adapters.subgraphx_adapter import SubgraphXTGAdapter, SubgraphXTGAdapterConfig
from time_to_explain.extractors.base_extractor import BaseExtractor
from time_to_explain.core.types import ExplanationContext

from IPython.display import display

from time_to_explain.visualization import (
    plot_bipartite_graph,
    plot_explain_timeline,
    plot_force_directed_graph,
    plot_ground_truth_subgraph,
    visualize_dataset,
)

# --- What to explain -------------------------------------------------------
DATASET_NAME = "sticky_hips"     # others: "stick_figure", "triadic_closure", "wikipedia", "reddit", "simulate_v1"
MODEL_NAME = "tgn"               # "tgn" or "tgat"
EXPLAINER_NAME = "subgraphx_tg"  # config key under configs/explainer
EVENT_INDEX = None               # set an integer (1-based) to override the default pick

# --- Plot limits -----------------------------------------------------------
GRAPH_MAX_USERS = 40
GRAPH_MAX_ITEMS = 40
TIMELINE_WINDOW = 200            # events shown before/after the anchor in the timeline

# --- Output options --------------------------------------------------------
SHOW = True
WRITE_FILES = False
SAVE_FORMAT = "pdf"
SAVE_DIR = (Path("viz_out") / DATASET_NAME) if WRITE_FILES else None

# Fall back to choosing a device here when the setup cell did not define DEVICE;
# the notebook config's "device" entry is used if that config was loaded.
if "DEVICE" not in globals():
    DEVICE = pick_device(globals().get("NOTEBOOK_CFG", {}).get("device", "auto"))
print(f"Using device: {DEVICE}")



# The explain index may live in either resource folder; the first existing file
# wins, otherwise the first candidate path is used as the default.
EXPLAIN_IDX_CANDIDATES = [
    RESOURCES_DIR / "datasets" / "explain_index" / f"{DATASET_NAME}.csv",
    RESOURCES_DIR / "explainer" / "explain_index" / f"{DATASET_NAME}.csv",
]
EXPLAIN_IDX_CSV = next(filter(Path.exists, EXPLAIN_IDX_CANDIDATES), EXPLAIN_IDX_CANDIDATES[0])
MODELS_ROOT = RESOURCES_DIR / "models"

def _find_checkpoint(models_root: Path, dataset_name: str, model_name: str) -> Path:
    """Return the checkpoint file for ``model_name`` trained on ``dataset_name``.

    Lookup order:

    1. The exact name ``<model>_<dataset>_best.pth`` in
       ``<root>/<dataset>/<model>``, ``<root>/<dataset>/checkpoints`` and
       ``<root>/checkpoints``.
    2. A recursive glob for ``<model>*<dataset>*.pth`` under
       ``<root>/<dataset>/<model>``, ``<root>/<dataset>`` and ``<root>/checkpoints``
       (in that order). Within a directory, a match whose file name contains
       ``"best"`` is preferred, otherwise the first match in sorted order.
    3. Any ``*.pth`` file, but only inside the model-specific directory
       ``<root>/<dataset>/<model>``.

    The name-agnostic ``*.pth`` fallback is limited to the model-specific
    directory because the shared directories can hold checkpoints of other
    backbones. This notebook loads weights with ``strict=False``, so picking up
    another model's file there would not fail loudly.

    Args:
        models_root: Root directory holding trained model checkpoints.
        dataset_name: Dataset name as used in directory and file names.
        model_name: Backbone name; matched case-insensitively (lower-cased).

    Returns:
        Path to the selected ``.pth`` file.

    Raises:
        FileNotFoundError: if no suitable checkpoint exists.
    """
    model_name = model_name.lower()
    dataset_name = str(dataset_name)
    best_name = f"{model_name}_{dataset_name}_best.pth"
    model_dir = models_root / dataset_name / model_name

    exact_candidates = [
        model_dir / best_name,
        models_root / dataset_name / "checkpoints" / best_name,
        models_root / "checkpoints" / best_name,
    ]
    for cand in exact_candidates:
        if cand.exists():
            return cand

    search_roots = [model_dir, models_root / dataset_name, models_root / "checkpoints"]
    for root in search_roots:
        if not root.is_dir():
            continue
        matches = sorted(root.rglob(f"{model_name}*{dataset_name}*.pth"))
        if not matches and root == model_dir:
            # Only the model-specific directory may contribute arbitrarily named files.
            matches = sorted(root.rglob("*.pth"))
        if matches:
            return next((m for m in matches if "best" in m.name), matches[0])

    raise FileNotFoundError(
        f"Checkpoint not found under {models_root} for {model_name}_{dataset_name}."
    )

# Resolve the backbone checkpoint for the configured dataset/model pair.
CKPT_PATH = _find_checkpoint(models_root=MODELS_ROOT, dataset_name=DATASET_NAME, model_name=MODEL_NAME)
print("Using checkpoint:", CKPT_PATH)




Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using device: mps
Using checkpoint: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/sticky_hips/tgn/tgn-attn-sticky_hips-sticky_hips.pth


In [None]:
# Dataset overview (same style as prepare_datasets)
# Best-effort preview of the first five event ids in the explain index. A failure is
# reported instead of being silently swallowed, so a missing or malformed CSV is
# visible here already; the notebook keeps running either way.
explain_preview = None
try:
    explain_df = pd.read_csv(EXPLAIN_IDX_CSV)
    explain_preview = explain_df["event_idx"].head(5).astype(int).tolist()
except Exception as exc:  # preview only: report, do not abort the notebook
    explain_preview = None
    print(f"Explain preview unavailable ({type(exc).__name__}: {exc})")

if explain_preview:
    print(f"Explain preview: {explain_preview}")


Explain preview: [289043, 289100, 289157, 289214, 289271]


In [None]:
# Load dataset and backbone model
print(f"Loading temporal dataset: {DATASET_NAME}")
bundle = load_processed_dataset_safe(DATASET_NAME, verbose=True)
# Interactions are required; features and metadata are optional in the bundle.
events = bundle["interactions"]
edge_feats, node_feats = bundle.get("edge_features"), bundle.get("node_features")
metadata = bundle.get("metadata") or {}
# Only a dict-valued "config" entry is used; anything else becomes an empty dict.
config_meta = metadata.get("config")
if not isinstance(config_meta, dict):
    config_meta = {}
def _infer_bipartite_from_events(df: pd.DataFrame) -> bool:
    """Guess bipartiteness from the node-id ranges of an interaction table.

    Returns ``True`` when the id range of column ``"u"`` lies strictly below or
    strictly above the id range of column ``"i"`` (disjoint ranges), and
    ``False`` for an empty table or overlapping ranges.
    """
    if not len(df):
        return False
    src_lo, src_hi = int(df["u"].min()), int(df["u"].max())
    dst_lo, dst_hi = int(df["i"].min()), int(df["i"].max())
    sources_below = src_hi < dst_lo
    sources_above = dst_hi < src_lo
    return sources_below or sources_above

# An explicit "bipartite" flag wins (top-level metadata first, then the nested
# config); only when neither is set is it inferred from the node-id ranges.
is_bipartite = metadata.get("bipartite", config_meta.get("bipartite"))
is_bipartite = _infer_bipartite_from_events(events) if is_bipartite is None else bool(is_bipartite)
# Stick-figure datasets are always treated as general (non-bipartite) graphs.
if DATASET_NAME in {"stick_figure", "sticky_hips"}:
    is_bipartite = False


def _build_neighbor_finder(df: pd.DataFrame) -> NeighborFinder:
    """Build an undirected ``NeighborFinder`` from an interaction table.

    Every row (``u``, ``i``, ``ts``) is registered in both endpoints' adjacency
    lists as a ``(neighbor, edge_id, timestamp)`` tuple. Edge ids come from the
    ``e_idx`` column if present, else ``idx``, else a 1-based running counter.
    The adjacency list has one slot per node id up to the largest id seen
    (a single empty slot for an empty table).
    """
    src = df["u"].to_numpy(dtype=int)
    dst = df["i"].to_numpy(dtype=int)
    times = df["ts"].to_numpy(dtype=float)
    id_column = next((col for col in ("e_idx", "idx") if col in df.columns), None)
    edge_ids = (
        df[id_column].to_numpy(dtype=int)
        if id_column is not None
        else np.arange(1, len(df) + 1, dtype=int)
    )
    n_slots = int(max(src.max(), dst.max())) + 1 if len(df) else 1
    adjacency = [[] for _ in range(n_slots)]
    # tolist() yields plain Python ints/floats, matching int()/float() per element.
    for a, b, t, e in zip(src.tolist(), dst.tolist(), times.tolist(), edge_ids.tolist()):
        adjacency[a].append((b, e, t))
        adjacency[b].append((a, e, t))
    return NeighborFinder(adjacency, uniform=False)
print(f"Loaded {len(events):,} interactions")

# Backbone hyper-parameters: a dataset-specific config takes precedence over the
# generic per-model one.
configs_dir = REPO_ROOT / "configs" / "models"
config_candidates = [
    configs_dir / f"infer_{MODEL_NAME.lower()}_{DATASET_NAME}.json",
    configs_dir / f"infer_{MODEL_NAME.lower()}.json",
]
model_config_path = next((p for p in config_candidates if p.exists()), None)
if model_config_path is None:
    raise FileNotFoundError(
        "Model config not found. Expected one of: "
        + ", ".join(str(p) for p in config_candidates)
    )
model_config = json.loads(model_config_path.read_text())
config_model = str(model_config.get("model", "")).lower()
if config_model and config_model != MODEL_NAME.lower():
    raise ValueError(
        f"Config model {config_model!r} does not match MODEL_NAME={MODEL_NAME!r}"
    )
model_args = dict(model_config.get("args") or {})
print(f"Using model config: {model_config_path}")

model_name = MODEL_NAME.lower()
if model_name == "tgat":
    if not is_bipartite:
        raise ValueError("TGAT expects bipartite datasets; set MODEL_NAME='tgn' for stick_figure.")
    ngh_finder = _build_neighbor_finder(events)
    model = TGAN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        **model_args,
    )
elif model_name == "tgn":
    # Mean/std time-shift statistics for sources and destinations, computed from
    # the full event table and passed to the TGN constructor.
    m_src, s_src, m_dst, s_dst = compute_time_statistics(events.u.values, events.i.values, events.ts.values)
    ngh_finder = _build_neighbor_finder(events)
    model = TGN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        mean_time_shift_src=m_src,
        std_time_shift_src=s_src,
        mean_time_shift_dst=m_dst,
        std_time_shift_dst=s_dst,
        **model_args,
    )
else:
    raise NotImplementedError(f"Unsupported MODEL_NAME={MODEL_NAME}")

state_dict = torch.load(CKPT_PATH, map_location="cpu")
# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of discarding the result: otherwise a checkpoint that does not fit the
# model leaves the affected parameters at their initial values without any notice.
load_report = model.load_state_dict(state_dict, strict=False)
for mismatch_kind, mismatch_keys in (
    ("missing from checkpoint", list(load_report.missing_keys)),
    ("unexpected in checkpoint", list(load_report.unexpected_keys)),
):
    if mismatch_keys:
        shown = ", ".join(mismatch_keys[:10]) + (" ..." if len(mismatch_keys) > 10 else "")
        print(f"Warning: {len(mismatch_keys)} state_dict keys {mismatch_kind}: {shown}")
model = model.to(DEVICE).eval()

# Neighborhood size used for extraction/explanation: keep the model's own values
# when set, otherwise fall back to the config (or 2 layers / 20 neighbors).
default_layers = model_args.get("num_layers") or model_args.get("n_layers") or 2
default_neighbors = model_args.get("num_neighbors")
if default_neighbors is None:
    default_neighbors = model_args.get("n_neighbors")
if default_neighbors is None:
    default_neighbors = 20
model.num_layers = getattr(model, "num_layers", None) or default_layers
model.num_neighbors = getattr(model, "num_neighbors", None) or default_neighbors
print("Backbone ready on", DEVICE)


Loading temporal dataset: sticky_hips
Loaded flat processed files for 'sticky_hips'.
Loaded 339,560 interactions
Using model config: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/configs/models/infer_tgn.json
Backbone ready on mps


  state_dict = torch.load(CKPT_PATH, map_location="cpu")


In [None]:
# Build extractor and explainer
configs_dir = REPO_ROOT / "configs" / "explainer"
# A dataset-specific explainer config takes precedence over the generic one.
explainer_candidates = [
    configs_dir / f"{EXPLAINER_NAME.lower()}_{DATASET_NAME}.json",
    configs_dir / f"{EXPLAINER_NAME.lower()}.json",
]
explainer_config_path = next(filter(Path.exists, explainer_candidates), None)
if explainer_config_path is None:
    raise FileNotFoundError(
        "Explainer config not found. Expected one of: "
        + ", ".join(map(str, explainer_candidates))
    )
explainer_config = json.loads(explainer_config_path.read_text())
explainer_args = dict(explainer_config.get("args") or {})

# A relative results_dir is interpreted relative to the repository root.
if explainer_args.get("results_dir"):
    results_dir = Path(explainer_args["results_dir"])
    results_dir = results_dir if results_dir.is_absolute() else REPO_ROOT / results_dir
    explainer_args["results_dir"] = str(results_dir)

# Supply run identity only where the config does not already set it.
explainer_args = {
    "model_name": model_name,
    "dataset_name": DATASET_NAME,
    "device": DEVICE,
    **explainer_args,
}

# Normalise both size knobs to ints and write them back into the args.
CANDIDATE_LIMIT = int(explainer_args.get("threshold_num", 50))
MAX_COALITION_SIZE = int(explainer_args.get("min_atoms", 6))
explainer_args.update(threshold_num=CANDIDATE_LIMIT, min_atoms=MAX_COALITION_SIZE)
print(f"Using explainer config: {explainer_config_path}")

extractor = BaseExtractor(
    model=model,
    events=events,
    threshold_num=CANDIDATE_LIMIT,
    keep_order="last-N-then-sort",
)

adapter_cfg = SubgraphXTGAdapterConfig(**explainer_args)
explainer = SubgraphXTGAdapter(adapter_cfg)
dataset_pack = {"events": events, "dataset_name": DATASET_NAME}
explainer.prepare(model=model, dataset=dataset_pack)
print(f"Explainer prepared with {EXPLAINER_NAME}")
print(f"Explainer prepared with {EXPLAINER_NAME}")


Using explainer config: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/configs/explainer/subgraphx_tg.json
Explainer prepared with subgraphx_tg


In [None]:
# Select a single event and generate an explanation
all_event_indices = load_explain_idx(EXPLAIN_IDX_CSV)

# Stick-figure style datasets (the same set the loading cell treats as non-bipartite)
# map node ids to joints via ``(id - node_base) % JOINTS_PER_PERSON``. Their explain
# index is narrowed to "contact" edges between specific joint types.
STICK_FIGURE_DATASETS = {"stick_figure", "sticky_hips"}
dataset_key = DATASET_NAME.lower()
is_stick_figure = dataset_key in STICK_FIGURE_DATASETS

if is_stick_figure:
    target_label = int(config_meta.get("target_label", 1)) if isinstance(config_meta, dict) else 1
    node_min = int(min(events["u"].min(), events["i"].min()))
    node_base = 1 if node_min >= 1 else 0  # node ids may start at 0 or 1
    if dataset_key == "sticky_hips":
        contact_pairs = ({J_RW, J_T}, {J_LW, J_T})
        contact_desc = "a wrist and the torso"
    else:
        contact_pairs = ({J_LW, J_RW}, {J_LA, J_RA})
        contact_desc = "wrists or ankles"

    def _is_contact_edge(idx_1based: int) -> bool:
        """Return True if 1-based event ``idx_1based`` joins one of ``contact_pairs``.

        Events whose ``label`` differs from ``target_label`` are rejected first.
        """
        row = events.iloc[int(idx_1based) - 1]
        if "label" in row and int(row["label"]) != target_label:
            return False
        ju = (int(row["u"]) - node_base) % JOINTS_PER_PERSON
        jv = (int(row["i"]) - node_base) % JOINTS_PER_PERSON
        return {ju, jv} in contact_pairs

    filtered = [idx for idx in all_event_indices if _is_contact_edge(idx)]
    if filtered:
        all_event_indices = filtered
    else:
        print("Warning: explain index has no stick-figure contact edges; falling back to label filter.")
        if "label" in events.columns:
            label_hits = events["label"].astype(int).to_numpy() == target_label
            # Positional (not index-label based) so this holds for any index type.
            all_event_indices = (np.flatnonzero(label_hits) + 1).tolist()

if EVENT_INDEX is None:
    # len() rather than truthiness: the explain index may be a NumPy array.
    if len(all_event_indices) == 0:
        raise ValueError(
            f"No candidate events to explain for {DATASET_NAME!r} (explain index: {EXPLAIN_IDX_CSV})."
        )
    target_event_idx = int(all_event_indices[0])
else:
    target_event_idx = int(EVENT_INDEX)

# Range-check before any row lookup so an out-of-range EVENT_INDEX gets this message.
if not (1 <= target_event_idx <= len(events)):
    raise ValueError(f"EVENT_INDEX {target_event_idx} must be within [1, {len(events)}]")
if EVENT_INDEX is not None and is_stick_figure and not _is_contact_edge(target_event_idx):
    raise ValueError(f"EVENT_INDEX must reference a stick-figure contact edge between {contact_desc}.")

anchor = {"target_kind": "edge", "event_idx": target_event_idx}
subgraph = extractor.extract(dataset_pack, anchor, k_hop=model.num_layers, num_neighbors=model.num_neighbors)

context = ExplanationContext(
    run_id=f"{MODEL_NAME}_{DATASET_NAME}_single",
    target_kind="edge",
    target={"event_idx": target_event_idx},
    k_hop=model.num_layers,
    num_neighbors=model.num_neighbors,
    subgraph=subgraph,
)

result = explainer.explain(context)
coalition = result.extras.get("coalition_eidx", []) or []
print(f"Explained event #{target_event_idx} -> coalition size {len(coalition)}")

gt = metadata.get("ground_truth") or {}
if gt:
    raw_rationales = gt.get("rationales") or {}
    # Rationales are keyed by 0-based event position; the key may be an int or its
    # string form, so try both.
    target_idx0 = target_event_idx - 1
    gt_support = raw_rationales.get(target_idx0, raw_rationales.get(str(target_idx0), [])) or []
    gt_set = {int(i) for i in gt_support if isinstance(i, (int, np.integer))}
    # Coalition ids are 1-based; de-duplicate so repeats do not skew the ratios.
    expl_set = {int(idx) - 1 for idx in coalition}
    overlap = sorted(gt_set & expl_set)
    if gt_set:
        recall = len(overlap) / len(gt_set)
        precision = len(overlap) / len(expl_set) if expl_set else 0.0
        iou = len(overlap) / len(gt_set | expl_set) if expl_set else 0.0
        print(
            f"GT overlap: {len(overlap)}/{len(gt_set)} support edges matched "
            f"(recall {recall:.1%}, precision {precision:.1%}, IoU {iou:.1%})"
        )
    else:
        print("GT overlap: no ground-truth support edges available for this event.")
else:
    print("GT overlap: ground-truth metadata missing.")




552 events to explain

explain 0-th: 289043
The nodes in graph is 24


mcts simulating: 100%|██████████| 30/30 [00:01<00:00, 27.97it/s, states=71]

mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_sticky_hips_289043_mcts_recorder_pg_false_th50.csv
Explained event #289043 -> coalition size 3
GT overlap: 2/2 support edges matched (recall 100.0%, precision 66.7%, IoU 66.7%)





In [None]:
# Prepare dataframes for visualisation
anchor_pos = target_event_idx - 1  # 0-based row position of the explained event
anchor_row = events.iloc[anchor_pos].copy()
anchor_row.name = target_event_idx  # displayed with its 1-based event number
anchor_user = int(anchor_row["u"])
anchor_item = int(anchor_row["i"])
anchor_edge = (anchor_user, anchor_item)
print("Event to be explained:")
display(anchor_row.to_frame().T)

# Ensure the global snapshot includes the anchor interaction. The anchor is appended
# as a one-row slice of ``events`` so it keeps its own index label and column dtypes.
# ``anchor_row`` carries the 1-based display label (off by one against the events
# index), and ``Series.to_frame().T`` would also upcast integer columns to float.
max_global_edges = GRAPH_MAX_USERS * GRAPH_MAX_ITEMS
head_df = events.head(max_global_edges)
if anchor_pos >= len(head_df):
    global_view_df = pd.concat([head_df, events.iloc[[anchor_pos]]]).sort_index()
else:
    global_view_df = head_df.copy()

global_view_df = global_view_df.drop_duplicates().copy()

# Candidate and explanation subsets (explainer ids are 1-based event indices)
candidate_eidx = result.extras.get("candidate_eidx", []) or []
candidate_mask = [int(idx) for idx in candidate_eidx if 1 <= int(idx) <= len(events)]
if candidate_mask:
    candidate_df = events.iloc[[idx - 1 for idx in candidate_mask]].copy()
else:
    candidate_df = pd.DataFrame(columns=events.columns)

if not candidate_df.empty:
    global_view_df = pd.concat([global_view_df, candidate_df], ignore_index=False).drop_duplicates().copy()

coalition = [int(idx) for idx in (coalition or []) if 1 <= int(idx) <= len(events)]
if coalition:
    explanation_df = events.iloc[[idx - 1 for idx in coalition]].copy()
    global_view_df = pd.concat([global_view_df, explanation_df], ignore_index=False).drop_duplicates().copy()
else:
    explanation_df = pd.DataFrame(columns=events.columns)

# Unique (u, i) pairs of the explanation, in first-seen (deterministic) order.
coalition_edges = list(
    dict.fromkeys((int(u), int(v)) for u, v in zip(explanation_df["u"], explanation_df["i"]))
)

print(f"Candidate events considered ({len(candidate_df)})")
display(candidate_df.head())
print(f"Explanation edges returned ({len(explanation_df)})")
display(explanation_df.head())


Event to be explained:


Unnamed: 0,u,i,ts,label,idx,e_idx
289043,9542.0,9535.0,289042.0,1.0,289043.0,289043.0


Candidate events considered (7)


Unnamed: 0,u,i,ts,label,idx,e_idx
289028,9535,9536,289028.0,0,289029,289029
289029,9535,9537,289029.0,0,289030,289030
289030,9535,9538,289030.0,0,289031,289031
289034,9538,9540,289034.0,0,289035,289035
289035,9540,9542,289035.0,0,289036,289036


Explanation edges returned (3)


Unnamed: 0,u,i,ts,label,idx,e_idx
289028,9535,9536,289028.0,0,289029,289029
289034,9538,9540,289034.0,0,289035,289035
289035,9540,9542,289035.0,0,289036,289036


In [None]:
# Visualise dataset, context, and explanation
# Graph views are drawn with plot_bipartite_graph only when is_bipartite is set;
# otherwise each view prints a "Skipping ..." message instead. The ground-truth
# overlays and the timeline below are drawn regardless of bipartiteness.
# NOTE(review): plot_force_directed_graph is imported in the imports cell but never
# called here, although the skip messages refer to force-directed views.
edges_global = [anchor_edge]
# NOTE(review): global_view_df also contains the appended anchor, candidate and
# explanation rows, so "first N interactions" can overstate the head slice.
print(f"Dataset overview (first {len(global_view_df)} interactions; explained event marked)")
if is_bipartite:
    plot_bipartite_graph(
        global_view_df,
        max_users=GRAPH_MAX_USERS,
        max_items=GRAPH_MAX_ITEMS,
        highlight_users=[anchor_user],
        highlight_items=[anchor_item],
        highlight_edges=edges_global,
    )
else:
    print("Skipping force-directed overview for non-bipartite datasets.")

# Candidate edges handed to the explainer, with the anchor highlighted.
if not candidate_df.empty:
    print(f"Candidate subgraph supplied to SubgraphX (|C|={len(candidate_df)})")
    if is_bipartite:
        plot_bipartite_graph(
            candidate_df,
            max_users=GRAPH_MAX_USERS,
            max_items=GRAPH_MAX_ITEMS,
            highlight_users=[anchor_user],
            highlight_items=[anchor_item],
            highlight_edges=[anchor_edge],
        )
    else:
        print("Skipping force-directed candidate view for non-bipartite datasets.")
else:
    print("No candidate edges available to plot.")

# Explanation coalition: its (u, i) pairs are highlighted, plus the anchor's endpoints.
if not explanation_df.empty:
    print(f"SubgraphX explanation coalition (|S|={len(explanation_df)}) with explained event marked")
    # NOTE(review): coalition_edges is built from explanation_df rows, so it is
    # non-empty in this branch and the [anchor_edge] fallback never applies; the
    # anchor edge itself is therefore not in highlight_edges here.
    edges_to_explain = list(dict.fromkeys(coalition_edges or [anchor_edge]))
    if is_bipartite:
        plot_bipartite_graph(
            explanation_df,
            max_users=GRAPH_MAX_USERS,
            max_items=GRAPH_MAX_ITEMS,
            highlight_users=[anchor_user],
            highlight_items=[anchor_item],
            highlight_edges=edges_to_explain,
        )
    else:
        print("Skipping force-directed explanation view; see ground-truth overlay below.")
else:
    print("Explainer returned an empty coalition.")

# Ground-truth overlay: plot_ground_truth_subgraph receives 0-based positions for
# both the explained event and the explanation edges (coalition ids minus one).
if metadata.get("ground_truth"):
    print("Ground-truth neighborhood with explanation edges")
    explanation_indices = [idx - 1 for idx in coalition] if coalition else []
    try:
        plot_ground_truth_subgraph(
            events,
            event_idx=target_event_idx - 1,
            metadata=metadata,
            event_window=min(60, TIMELINE_WINDOW),
            max_context_edges=40,
            explanation_indices=explanation_indices,
        )
    except ValueError as exc:
        print(f"Skipping ground-truth plot: {exc}")
else:
    print("No ground-truth metadata available to overlay explanation.")

# Second overlay: explanation rows are mapped to 0-based row positions through the
# "idx" (or "e_idx") column values instead of the coalition ids. If no row can be
# mapped that way, the explanation_df index labels are used as positions.
# NOTE(review): this repeats the plot above whenever the id column equals
# position + 1 for the explanation rows.
if metadata.get("ground_truth") and not explanation_df.empty:
    print("Ground-truth neighborhood + explainer edges (from explanation_df)")
    explanation_indices_from_df = []
    idx_col = "idx" if "idx" in events.columns else ("e_idx" if "e_idx" in events.columns else None)
    if idx_col and idx_col in explanation_df.columns:
        # event id value -> 0-based row position in events
        idx_to_row = {int(v): i for i, v in enumerate(events[idx_col].astype(int).tolist())}
        explanation_indices_from_df = [
            idx_to_row[int(v)]
            for v in explanation_df[idx_col].astype(int).tolist()
            if int(v) in idx_to_row
        ]
    if not explanation_indices_from_df:
        explanation_indices_from_df = [
            int(i) for i in explanation_df.index if 0 <= int(i) < len(events)
        ]
    # Drop repeats while keeping first-seen order.
    explanation_indices_from_df = list(dict.fromkeys(explanation_indices_from_df))
    if explanation_indices_from_df:
        try:
            plot_ground_truth_subgraph(
                events,
                event_idx=target_event_idx - 1,
                metadata=metadata,
                event_window=min(60, TIMELINE_WINDOW),
                max_context_edges=40,
                explanation_indices=explanation_indices_from_df,
            )
        except ValueError as exc:
            print(f"Skipping GT+explainer plot: {exc}")
    else:
        print("No explanation edges could be mapped to events for GT overlay.")

# Timeline around the explained event; highlights the anchor and coalition ids.
# NOTE(review): these ids are 1-based (unlike the GT overlays above); presumably
# plot_explain_timeline expects 1-based event indices — confirm.
highlight_indices = list(dict.fromkeys([target_event_idx, *coalition]))
print("Temporal context around the explained event")
plot_explain_timeline(
    events,
    event_indices=highlight_indices,
    window=TIMELINE_WINDOW,
    max_base_points=20_000,
)



Dataset overview (first 1608 interactions; explained event marked)
Skipping force-directed overview for non-bipartite datasets.
Candidate subgraph supplied to SubgraphX (|C|=7)
Skipping force-directed candidate view for non-bipartite datasets.
SubgraphX explanation coalition (|S|=3) with explained event marked
Skipping force-directed explanation view; see ground-truth overlay below.
Ground-truth neighborhood with explanation edges


Ground-truth neighborhood + explainer edges (from explanation_df)


Temporal context around the explained event
