# Explain a Single Event

Load a pretrained temporal GNN, generate one explanation, and visualise both the local graph context and the explainer's subgraph.

In [1]:
# Environment setup
from pathlib import Path
import sys, importlib


def _add_notebooks_src_to_path():
    here = Path.cwd().resolve()
    for p in [here, *here.parents]:
        candidate = p / "notebooks" / "src"
        if candidate.is_dir():
            if str(candidate) not in sys.path:
                sys.path.insert(0, str(candidate))
            return candidate
    raise FileNotFoundError("Could not find 'notebooks/src' from current working directory.")

NOTEBOOKS_SRC = _add_notebooks_src_to_path()
print(f"Using helpers from: {NOTEBOOKS_SRC}")

from constants import (
    REPO_ROOT, RESOURCES_DIR, PROCESSED_DATA_DIR, MODELS_ROOT, ensure_repo_importable
)
ensure_repo_importable()
from device import pick_device

SUBMODULE_MODELS_DIR = REPO_ROOT / "submodules" / "models"
for sub in ("tgn", "tgat"):
    sub_path = SUBMODULE_MODELS_DIR / sub
    if sub_path.is_dir() and str(sub_path) not in sys.path:
        sys.path.insert(0, str(sub_path))

EXPLAINER_ROOT = REPO_ROOT / "submodules" / "explainer"
if EXPLAINER_ROOT.is_dir() and str(EXPLAINER_ROOT) not in sys.path:
    sys.path.insert(0, str(EXPLAINER_ROOT))


Using helpers from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/notebooks/src


In [2]:
# Core imports and notebook knobs
import torch
import pandas as pd

from time_to_explain.data.legacy.tg_dataset import load_tg_dataset, load_explain_idx
from submodules.explainer.tgnnexplainer.tgnnexplainer.xgraph.dataset.utils_dataset import construct_tgat_neighbor_finder
from submodules.models.tgat.module import TGAN
from submodules.models.tgn.model.tgn import TGN
from submodules.explainer.tgnnexplainer.tgnnexplainer.xgraph.models.ext.tgn.utils.data_processing import compute_time_statistics

from time_to_explain.adapters.subgraphx_tg_adapter import SubgraphXTGAdapter, SubgraphXTGAdapterConfig
from time_to_explain.extractors.tg_event_candidates_extractor import TGEventCandidatesExtractor
from time_to_explain.core.types import ExplanationContext

from IPython.display import display

from time_to_explain.utils.visualization import (
    plot_bipartite_graph,
    plot_explain_timeline,
)

DATASET_NAME = "wikipedia"     # e.g. "wikipedia", "reddit", "simulate_v1"
MODEL_NAME = "tgn"             # "tgn" or "tgat"
EVENT_INDEX = None              # set an integer (1-based) to override the default pick
CANDIDATE_LIMIT = 50            # number of candidate events passed to SubgraphX-TG
GRAPH_MAX_USERS = 40
GRAPH_MAX_ITEMS = 40
TIMELINE_WINDOW = 200           # events shown before/after the anchor in the timeline
MAX_COALITION_SIZE = 6            # upper bound on explanation size returned by SubgraphX

DEVICE = pick_device("auto")
print(f"Using device: {DEVICE}")

EXPLAIN_IDX_CSV = RESOURCES_DIR / "explainer" / "explain_index" / f"{DATASET_NAME}.csv"
CKPT_PATH = RESOURCES_DIR / "models" / DATASET_NAME / "checkpoints" / f"{MODEL_NAME.lower()}_{DATASET_NAME}_best.pth"

if not EXPLAIN_IDX_CSV.exists():
    raise FileNotFoundError(f"Explain-index file not found: {EXPLAIN_IDX_CSV}")
if not CKPT_PATH.exists():
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}")


Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using device: mps


In [3]:
# Load dataset and backbone model
print(f"Loading temporal dataset: {DATASET_NAME}")
events, edge_feats, node_feats = load_tg_dataset(DATASET_NAME)
print(f"Loaded {len(events):,} interactions")

model_name = MODEL_NAME.lower()
if model_name == "tgat":
    ngh_finder = construct_tgat_neighbor_finder(events)
    model = TGAN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        attn_mode="prod",
        use_time="time",
        agg_method="attn",
        num_layers=2,
        n_head=4,
        null_idx=0,
        num_neighbors=20,
        drop_out=0.1,
    )
elif model_name == "tgn":
    m_src, s_src, m_dst, s_dst = compute_time_statistics(events.u.values, events.i.values, events.ts.values)
    ngh_finder = construct_tgat_neighbor_finder(events)
    model = TGN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        n_layers=2,
        n_heads=2,
        dropout=0.1,
        use_memory=False,
        forbidden_memory_update=False,
        memory_update_at_start=True,
        message_dimension=100,
        memory_dimension=500,
        embedding_module_type="graph_attention",
        message_function="mlp",
        mean_time_shift_src=m_src,
        std_time_shift_src=s_src,
        mean_time_shift_dst=m_dst,
        std_time_shift_dst=s_dst,
        n_neighbors=None,
        aggregator_type="last",
        memory_updater_type="gru",
        use_destination_embedding_in_message=False,
        use_source_embedding_in_message=False,
        dyrep=False,
    )
else:
    raise NotImplementedError(f"Unsupported MODEL_NAME={MODEL_NAME}")

state_dict = torch.load(CKPT_PATH, map_location="cpu")
_ = model.load_state_dict(state_dict, strict=False)
model = model.to(DEVICE).eval()
model.num_layers = getattr(model, "num_layers", 2) or 2
model.num_neighbors = getattr(model, "num_neighbors", 20) or 20
print("Backbone ready on", DEVICE)


Loading temporal dataset: wikipedia

#Dataset: wikipedia, #Users: 8227, #Items: 1000, #Interactions: 157474, #Timestamps: 152757
#node feats shape: (9228, 172), #edge feats shape: (157475, 172)
Loaded 157,474 interactions

Backbone ready on mps


  state_dict = torch.load(CKPT_PATH, map_location="cpu")


In [4]:
# Build extractor and explainer
extractor = TGEventCandidatesExtractor(
    model=model,
    events=events,
    threshold_num=CANDIDATE_LIMIT,
    keep_order="last-N-then-sort",
)

adapter_cfg = SubgraphXTGAdapterConfig(
    model_name=model_name,
    dataset_name=DATASET_NAME,
    explanation_level="event",
    results_dir=str(RESOURCES_DIR / "results"),
    debug_mode=False,
    threshold_num=CANDIDATE_LIMIT,
    save_results=False,
    load_results=False,
    rollout=30,
    min_atoms=MAX_COALITION_SIZE,
    c_puct=10.0,
    cache=True,
    device=DEVICE,
)

explainer = SubgraphXTGAdapter(adapter_cfg)
dataset_pack = {"events": events, "dataset_name": DATASET_NAME}
explainer.prepare(model=model, dataset=dataset_pack)
print("Explainer prepared with SubgraphX-TG")


Explainer prepared with SubgraphX-TG


In [5]:
# Select a single event and generate an explanation
all_event_indices = load_explain_idx(EXPLAIN_IDX_CSV)
if EVENT_INDEX is None:
    target_event_idx = int(all_event_indices[0])
else:
    target_event_idx = int(EVENT_INDEX)

if not (1 <= target_event_idx <= len(events)):
    raise ValueError(f"EVENT_INDEX {target_event_idx} must be within [1, {len(events)}]")

anchor = {"target_kind": "edge", "event_idx": target_event_idx}
subgraph = extractor.extract(dataset_pack, anchor, k_hop=model.num_layers, num_neighbors=model.num_neighbors)

context = ExplanationContext(
    run_id=f"{MODEL_NAME}_{DATASET_NAME}_single",
    target_kind="edge",
    target={"event_idx": target_event_idx},
    k_hop=model.num_layers,
    num_neighbors=model.num_neighbors,
    subgraph=subgraph,
)

result = explainer.explain(context)
coalition = result.extras.get("coalition_eidx", [])
print(f"Explained event #{target_event_idx} -> coalition size {len(coalition)}")


100 events to explain

explain 0-th: 110314

The nodes in graph is 146


mcts simulating: 100%|██████████| 30/30 [00:13<00:00,  2.29it/s, states=838]

mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_110314_mcts_recorder_pg_false_th50.csv
Explained event #110314 -> coalition size 6





In [6]:
# Prepare dataframes for visualisation
anchor_row = events.iloc[target_event_idx - 1].copy()
anchor_row.name = target_event_idx
anchor_user = int(anchor_row["u"])
anchor_item = int(anchor_row["i"])
anchor_edge = (anchor_user, anchor_item)
print("Event to be explained:")
display(anchor_row.to_frame().T)

# Ensure the global snapshot includes the anchor interaction
max_global_edges = GRAPH_MAX_USERS * GRAPH_MAX_ITEMS
head_df = events.head(max_global_edges)
if target_event_idx - 1 not in head_df.index:
    global_view_df = pd.concat([head_df, anchor_row.to_frame().T], ignore_index=False).sort_index()
else:
    global_view_df = head_df.copy()

global_view_df = global_view_df.drop_duplicates().copy()

# Candidate and explanation subsets
candidate_eidx = result.extras.get("candidate_eidx", []) or []
candidate_mask = [idx for idx in candidate_eidx if 1 <= idx <= len(events)]
if candidate_mask:
    candidate_df = events.iloc[[idx - 1 for idx in candidate_mask]].copy()
else:
    candidate_df = pd.DataFrame(columns=events.columns)

if not candidate_df.empty:
    global_view_df = pd.concat([global_view_df, candidate_df], ignore_index=False).drop_duplicates().copy()

coalition = [int(idx) for idx in coalition if 1 <= int(idx) <= len(events)]
if coalition:
    explanation_df = events.iloc[[idx - 1 for idx in coalition]].copy()
    global_view_df = pd.concat([global_view_df, explanation_df], ignore_index=False).drop_duplicates().copy()
else:
    explanation_df = pd.DataFrame(columns=events.columns)

coalition_edges = list({(int(row.u), int(row.i)) for _, row in explanation_df.iterrows()})

print(f"Candidate events considered ({len(candidate_df)})")
display(candidate_df.head())
print(f"Explanation edges returned ({len(explanation_df)})")
display(explanation_df.head())


Event to be explained:


Unnamed: 0,u,i,ts,label,idx,e_idx
110314,1601.0,8266.0,1864050.0,0.0,110314.0,110314.0


Candidate events considered (50)


Unnamed: 0,u,i,ts,label,idx,e_idx
89018,40,8339,1504134.0,0,89019,89019
89029,40,8339,1504254.0,0,89030,89030
89678,1601,8266,1515315.0,0,89679,89679
93887,1601,8266,1579303.0,0,93888,93888
94341,40,8441,1587604.0,0,94342,94342


Explanation edges returned (6)


Unnamed: 0,u,i,ts,label,idx,e_idx
94456,40,8441,1590442.0,0,94457,94457
94891,40,8458,1599510.0,0,94892,94892
99414,40,8487,1673638.0,0,99415,99415
99571,40,8441,1676327.0,0,99572,99572
103992,1601,8266,1750986.0,0,103993,103993


In [7]:
# Visualise dataset, context, and explanation
edges_global = [anchor_edge]
print(f"Dataset overview (first {len(global_view_df)} interactions; explained event marked)")
plot_bipartite_graph(
    global_view_df,
    max_users=GRAPH_MAX_USERS,
    max_items=GRAPH_MAX_ITEMS,
    highlight_users=[anchor_user],
    highlight_items=[anchor_item],
    highlight_edges=edges_global,
)

if not candidate_df.empty:
    print(f"Candidate subgraph supplied to SubgraphX (|C|={len(candidate_df)})")
    plot_bipartite_graph(
        candidate_df,
        max_users=GRAPH_MAX_USERS,
        max_items=GRAPH_MAX_ITEMS,
        highlight_users=[anchor_user],
        highlight_items=[anchor_item],
        highlight_edges=[anchor_edge],
    )
else:
    print("No candidate edges available to plot.")

if not explanation_df.empty:
    print(f"SubgraphX explanation coalition (|S|={len(explanation_df)}) with explained event marked")
    edges_to_explain = list(dict.fromkeys(coalition_edges or [anchor_edge]))
    plot_bipartite_graph(
        explanation_df,
        max_users=GRAPH_MAX_USERS,
        max_items=GRAPH_MAX_ITEMS,
        highlight_users=[anchor_user],
        highlight_items=[anchor_item],
        highlight_edges=edges_to_explain,
    )
else:
    print("Explainer returned an empty coalition.")

highlight_indices = list(dict.fromkeys([target_event_idx, *coalition]))
print("Temporal context around the explained event")
plot_explain_timeline(
    events,
    event_indices=highlight_indices,
    window=TIMELINE_WINDOW,
    max_base_points=20_000,
)


Dataset overview (first 1651 interactions; explained event marked)


Candidate subgraph supplied to SubgraphX (|C|=50)


SubgraphX explanation coalition (|S|=6) with explained event marked


Temporal context around the explained event
