# Train and evaluate Explainers

With all these prerequisites out of the way you can now run the experiments themselves. The experiments are run for each
explanation method (T-GNNExplainer, GreDyCF, CoDy), for each dataset, for each correct/incorrect setting 
(correct predictions only/incorrect predictions only), and for each selection policy (random, temporal, spatio-temporal, 
local-gradient) separately. For convenience, all selection strategies can be automatically evaluated in parallel from a 
single script. An additional feature of the evaluation is that it can be interrupted by Keyboard Interruption or by the
maximum processing time. When the evaluation is interrupted before it is finished, the intermediary results are saved. 
The evaluation automatically resumes from intermediary results.

In [1]:
# Find and add `notebooks/src` to sys.path, no matter where the notebook lives.
from pathlib import Path
import sys, importlib
import os
import subprocess

def _add_notebooks_src_to_path():
    here = Path.cwd().resolve()
    for p in [here, *here.parents]:
        candidate = p / "notebooks" / "src"
        if candidate.is_dir():
            if str(candidate) not in sys.path:
                sys.path.insert(0, str(candidate))
            return candidate
    raise FileNotFoundError("Could not find 'notebooks/src' from current working directory.")

print("Using helpers from:", _add_notebooks_src_to_path())

from constants import (
    REPO_ROOT, PKG_DIR, RESOURCES_DIR, PROCESSED_DATA_DIR, MODELS_ROOT, TGN_SUBMODULE_ROOT, ensure_repo_importable, get_last_checkpoint
)
ensure_repo_importable()
from device import pick_device

for p in (str(TGN_SUBMODULE_ROOT), str(REPO_ROOT), str(PKG_DIR)):
    if p not in sys.path:
        sys.path.insert(0, p)

# 2) If your notebook already imported `utils`, remove it to avoid collision
if "utils" in sys.modules:
    del sys.modules["utils"]

importlib.invalidate_caches()

# 4) (Optional) sanity check that TGN's local packages resolve
import importlib.util as iu
print("utils.utils   ->", iu.find_spec("utils.utils"))
print("modules.memory->", iu.find_spec("modules.memory"))

# 5) Now this import should work without the previous error
from time_to_explain.adapters import*

print("REPO_ROOT        :", REPO_ROOT)
print("PKG_DIR          :", PKG_DIR)
print("RESOURCES_DIR    :", RESOURCES_DIR)
print("PROCESSED_DATA_DIR:", PROCESSED_DATA_DIR)
print("MODELS_ROOT      :", MODELS_ROOT)

Using helpers from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/notebooks/src
utils.utils   -> ModuleSpec(name='utils.utils', loader=<_frozen_importlib_external.SourceFileLoader object at 0x307d6fa90>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/time_to_explain/utils/utils.py')
modules.memory-> ModuleSpec(name='modules.memory', loader=<_frozen_importlib_external.SourceFileLoader object at 0x10699f190>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/modules/memory.py')
REPO_ROOT        : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
PKG_DIR          : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/time_to_explain
RESOURCES_DIR    : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources
PROCESSED_DATA_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_

## Setting:

Replace ``MODEL-TYPE`` with the type of the model you want to evaluate, e.g., 'TGAT' or 'TGN'.

Replace ``DATASET-NAME`` with the name of the dataset on which you want to train the PGExplainer model, e.g., 'uci', 
'wikipedia', etc.

Replace ``EXPLAINER-NAME`` with the explainer you want to evaluate. Options are ``tgnnexplainer``, ``greedy``, ``cody``.

Replace ``SELECTION-NAME`` with the selection policy that you want to evaluate. The options are ``random``, 
``temporal``, ``spatio-temporal``, ``local-gradient``, and ``all``. Use the ``all`` option to efficiently evaluate the
different selection strategies with caching between selection strategies.
**Do not provide a `SELECTION-NAME`` argument when evaluating T-GNNExplainer**

Replace ``TIME-LIMIT`` with an integer number that sets a limit on the maximum time that the evaluation runs before 
concluding in minutes. The evaluation can be resumed from that state at a later time.

Only set ``bipartite = True``  if the underlying dataset is a bipartite graph (Wikipedia/UCI-Forums).

As an example, to run the evaluation of CoDy for all selection strategies, with a time limit of 240 minutes and the
bipartite wikipedia dataset, the following command is used:


In [2]:
MODEL_TYPE = "TGAT"
DATASET_NAME = "wikipedia"
EXPLAINER = "cody"
SELECTION_NAME = "all"
TIME_LIMIT = 240
BIPARTITE = True

DIRECTED = False
EPOCHS = 10

MODEL_PATH = MODELS_ROOT / DATASET_NAME
CHECKPOINT_PATH = MODEL_PATH / 'checkpoints/'
if not os.path.exists(CHECKPOINT_PATH):
    os.mkdir(CHECKPOINT_PATH)
LAST_CHECKPOINT = get_last_checkpoint(CHECKPOINT_PATH,MODEL_TYPE, DATASET_NAME)    
DEVICE = pick_device("auto")
print(DEVICE)

mps


## Load Dataset and Model

In [3]:
# --- core imports from your project ---
import torch
from time_to_explain.data.legacy.tg_dataset import load_tg_dataset, load_explain_idx
from submodules.explainer.tgnnexplainer.tgnnexplainer.xgraph.dataset.utils_dataset import construct_tgat_neighbor_finder
from submodules.models.tgat.module import TGAN
from submodules.models.tgn.model.tgn import TGN
from submodules.explainer.tgnnexplainer.tgnnexplainer.xgraph.models.ext.tgn.utils.data_processing import compute_time_statistics

# --- unified framework imports ---
from time_to_explain.core.runner import EvaluationRunner, EvalConfig
from time_to_explain.adapters.subgraphx_tg_adapter import (
    SubgraphXTGAdapter, SubgraphXTGAdapterConfig
)
from time_to_explain.adapters.tg_model_adapter import TemporalGNNModelAdapter
from time_to_explain.extractors.tg_event_candidates_extractor import TGEventCandidatesExtractor
import time_to_explain.metrics.sparsity
import time_to_explain.metrics.fidelity
# --- utilities ---
import os, pandas as pd

# ---- your notebook knobs ----
dataset_name = "wikipedia"          # e.g., "wikipedia", "reddit", ...
model_name   = "tgn"                # "tgn" or "tgat"
use_gpu      = torch.cuda.is_available()
device_id    = 0

# event list file (1-based indices)
explain_idx_csv = str(
    RESOURCES_DIR / "explainer"  / "explain_index" / f"{dataset_name}.csv"
)

# backbone checkpoint path (same pattern as your Hydra pipeline)
ckpt_path = str(
    RESOURCES_DIR / "models" /  f"{dataset_name}" / "checkpoints" / f"{model_name}_{dataset_name}_best.pth"
)
print(ckpt_path)

Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
Using REPO_ROOT / ROOT_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/wikipedia/checkpoints/tgn_wikipedia_best.pth


In [4]:
# Load data
events, edge_feats, node_feats = load_tg_dataset(dataset_name)

# Build the backbone model
if model_name == "tgat":
    ngh_finder = construct_tgat_neighbor_finder(events)
    backbone = TGAN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        attn_mode="prod",
        use_time="time",
        agg_method="attn",
        num_layers=2,
        n_head=4,
        null_idx=0,
        num_neighbors=20,
        drop_out=0.1,
    )
elif model_name == "tgn":
    m_src, s_src, m_dst, s_dst = compute_time_statistics(events.u.values, events.i.values, events.ts.values)
    ngh_finder = construct_tgat_neighbor_finder(events)   # your utils often reuse the same NF
    backbone = TGN(
        ngh_finder,
        node_feats,
        edge_feats,
        device=DEVICE,
        n_layers=2,
        n_heads=2,
        dropout=0.1,
        use_memory=False,
        forbidden_memory_update=False,
        memory_update_at_start=True,
        message_dimension=100,
        memory_dimension=500,
        embedding_module_type="graph_attention",
        message_function="mlp",
        mean_time_shift_src=0,
        std_time_shift_src=1,
        mean_time_shift_dst=0,
        std_time_shift_dst=1,
        n_neighbors=None,
        aggregator_type="last",
        memory_updater_type="gru",
        use_destination_embedding_in_message=False,
        use_source_embedding_in_message=False,
        dyrep=False,
    )
else:
    raise NotImplementedError(model_name)

# Load backbone weights
state_dict = torch.load(ckpt_path, map_location="cpu")
_ = backbone.load_state_dict(state_dict, strict=False)
_ = backbone.to(DEVICE).eval()
print("Backbone ready on", DEVICE)
# Wrap backbone with ModelProtocol adapter (adds predict_proba & masking)
model = TemporalGNNModelAdapter(backbone, events, device=DEVICE)



#Dataset: wikipedia, #Users: 8227, #Items: 1000, #Interactions: 157474, #Timestamps: 152757
#node feats shape: (9228, 172), #edge feats shape: (157475, 172)

Backbone ready on mps


  state_dict = torch.load(ckpt_path, map_location="cpu")


## Build Extractor and Explainer

In [5]:
# Extractor creates a stable candidate edge order for each event (needed for fair metrics)
extractor = TGEventCandidatesExtractor(
    model=model,
    events=events,
    threshold_num=50,               # same as your SubgraphX config
    keep_order="last-N-then-sort",  # matches SubgraphX-TG's pattern
)

# Adapter wraps your existing SubgraphX‑TG with the same knobs you use in Hydra
adapter_cfg = SubgraphXTGAdapterConfig(
    model_name=model_name,
    dataset_name=dataset_name,
    explanation_level="event",
    results_dir=str(RESOURCES_DIR / "results"),
    debug_mode=False,
    threshold_num=50,
    save_results=True,
    mcts_saved_dir=str(RESOURCES_DIR / "results" / "tgnnexplainer_subgraphx" / "mcts_saved_dir" ),
    load_results=False,
    rollout=30,
    min_atoms=2,
    c_puct=10.0,
    use_navigator=False,            # set True + navigator_type ("pg"|"mlp"|"dot") if you use a navigator
    # navigator_type="pg",
    # navigator_params={"train_epochs": 10, "lr": 1e-3, "batch_size": 64, "explainer_ckpt_dir": str(ROOT_DIR / "xgraph" / "explainer_ckpts")},
    cache=True,
)

explainer = SubgraphXTGAdapter(adapter_cfg)


## Run Evaluation

In [6]:
# Load 1-based event indices you want to explain (same file the Hydra pipeline uses)
target_event_idxs = load_explain_idx(explain_idx_csv, start=0)

# Build anchors (we keep "target_kind" for compatibility; the crucial piece is event_idx)
N = 5  # take a small batch first
anchors = [{"target_kind": "edge", "event_idx": int(e)} for e in target_event_idxs[:N]]

cfg = EvalConfig(
    out_dir="runs",
    metrics={
        "sparsity": {
            "eps": 1e-6,
            "components": ["edges", "nodes"]
        },
        # fidelity-minus at multiple K (drop top-k edges)
        "fidelity_minus": {
            "topk": [6, 12, 18],
            "result_as_logit": True,     # set False if your model returns probabilities
            "normalize": "minmax",       # for ranking importances
            "by": "value"                # or "abs" if you prefer |w|
        },
        # (optional) sufficiency variant
        "fidelity_keep": {"topk": [6, 12, 18], "result_as_logit": True}
    },
    seed=42
)

# Runner config — 'sparsity' works out-of-the-box; (optional) add 'fidelity' once your model adapter supports masks
runner = EvaluationRunner(
    model=model,
    dataset={"events": events, "dataset_name": dataset_name},
    extractor=extractor,
    explainers=[explainer],
    config=cfg
)

100 events to explain


#### Run

In [7]:

def _ensure_int_attr(obj, primary, alt_names, fallback):
    val = getattr(obj, primary, None)
    if val is None:
        for n in alt_names:
            v2 = getattr(obj, n, None)
            if v2 is not None:
                val = v2
                break
    if val is None:
        val = fallback
    try:
        setattr(obj, primary, int(val))
    except Exception:
        setattr(obj, primary, fallback)

# For layers: prefer .num_layers then .n_layers, else fallback k_hop (=2 default below)
_ensure_int_attr(model, "num_layers", ["n_layers"], 2)

# For neighbors: prefer .num_neighbors then .n_neighbors, else fallback (e.g., 20)
_ensure_int_attr(model, "num_neighbors", ["n_neighbors"], 20)

# Also make the runner pass explicit values (so extractor can use arguments if needed)
out = runner.run(
    anchors,
    k_hop=getattr(model, "num_layers", 2) or 2,
    num_neighbors=getattr(model, "num_neighbors", 20) or 20,
    run_id=f"{model_name}_{dataset_name}_tgnn_explainer_subgraphx"
)
out


explain 0-th: 110314

The nodes in graph is 146


mcts simulating: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s, states=838]


mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_110314_mcts_recorder_pg_false_th50.csv
results saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/tgnnexplainer_subgraphx/mcts_saved_dir/tgn_wikipedia_110314_mcts_node_info_pg_false_th50.pt

explain 1-th: 110832

The nodes in graph is 488


mcts simulating: 100%|██████████| 30/30 [00:17<00:00,  1.70it/s, states=838]


mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_110832_mcts_recorder_pg_false_th50.csv
results saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/tgnnexplainer_subgraphx/mcts_saved_dir/tgn_wikipedia_110832_mcts_node_info_pg_false_th50.pt

explain 2-th: 111397

The nodes in graph is 447


mcts simulating: 100%|██████████| 30/30 [00:17<00:00,  1.73it/s, states=838]


mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_111397_mcts_recorder_pg_false_th50.csv
results saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/tgnnexplainer_subgraphx/mcts_saved_dir/tgn_wikipedia_111397_mcts_node_info_pg_false_th50.pt

explain 3-th: 111915

The nodes in graph is 732


mcts simulating: 100%|██████████| 30/30 [00:13<00:00,  2.20it/s, states=705]


mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_111915_mcts_recorder_pg_false_th50.csv
results saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/tgnnexplainer_subgraphx/mcts_saved_dir/tgn_wikipedia_111915_mcts_node_info_pg_false_th50.pt

explain 4-th: 112473

The nodes in graph is 2


mcts simulating: 100%|██████████| 30/30 [00:14<00:00,  2.11it/s, states=705]

mcts recorder saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/candidate_scores/tgn_wikipedia_112473_mcts_recorder_pg_false_th50.csv
results saved at /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/results/tgnnexplainer_subgraphx/mcts_saved_dir/tgn_wikipedia_112473_mcts_node_info_pg_false_th50.pt





{'out_dir': 'runs/tgn_wikipedia_tgnn_explainer_subgraphx',
 'jsonl': 'runs/tgn_wikipedia_tgnn_explainer_subgraphx/results.jsonl',
 'csv': 'runs/tgn_wikipedia_tgnn_explainer_subgraphx/metrics.csv'}

## Inspect Results

In [8]:
# Metrics (CSV): one row per (anchor, explainer)
metrics_df = pd.read_csv(out["csv"])
metrics_df.head()

Unnamed: 0,run_id,anchor_idx,explainer,elapsed_sec,sparsity.edges.n,sparsity.edges.l0,sparsity.edges.zero_frac,sparsity.edges.density,sparsity.edges.gini,sparsity.edges.entropy,...,fidelity_drop.@12,fidelity_drop.prediction_drop.@18,fidelity_drop.@18,fidelity_keep.prediction_full,fidelity_keep.prediction_keep.@6,fidelity_keep.@6,fidelity_keep.prediction_keep.@12,fidelity_keep.@12,fidelity_keep.prediction_keep.@18,fidelity_keep.@18
0,tgn_wikipedia_tgnn_explainer_subgraphx,0,subgraphx_tg_tgn,22.35813,50,48,0.96,0.04,0.94,0.693147,...,0.011043,0.986843,0.003759,0.983084,0.346144,0.63694,0.813248,0.169836,0.80367,0.179414
1,tgn_wikipedia_tgnn_explainer_subgraphx,1,subgraphx_tg_tgn,17.662668,50,48,0.96,0.04,0.94,0.693147,...,0.00399,0.997401,0.003992,0.993409,0.371453,0.621956,0.226732,0.766677,0.195011,0.798398
2,tgn_wikipedia_tgnn_explainer_subgraphx,2,subgraphx_tg_tgn,17.354573,50,48,0.96,0.04,0.94,0.693147,...,0.00103,0.9962,0.000158,0.996358,0.890276,0.106083,0.786143,0.210216,0.631153,0.365205
3,tgn_wikipedia_tgnn_explainer_subgraphx,3,subgraphx_tg_tgn,13.690124,40,38,0.95,0.05,0.925,0.693147,...,0.00123,0.991354,0.007672,0.999026,0.996413,0.002613,0.948619,0.050407,0.989406,0.00962
4,tgn_wikipedia_tgnn_explainer_subgraphx,4,subgraphx_tg_tgn,14.275819,40,38,0.95,0.05,0.925,0.693147,...,0.001276,0.995579,0.001555,0.997134,0.99514,0.001994,0.992922,0.004212,0.997057,7.7e-05


In [9]:
# Basic sanity checks to make sure metric outputs look reasonable
import numpy as np
import re

if metrics_df.empty:
    raise RuntimeError("metrics_df is empty; nothing to evaluate.")

fidelity_cols = [c for c in metrics_df.columns if c.startswith("fidelity_")]
sparsity_cols = [c for c in metrics_df.columns if "sparsity" in c]

summary = {}

# Fidelity checks
fid_arrays = []
for col in fidelity_cols:
    vals = metrics_df[col].to_numpy(dtype=float)
    finite = vals[np.isfinite(vals)]
    if finite.size == 0:
        raise RuntimeError(f"{col} has no finite values (all NaN/inf).")
    fid_arrays.append(finite)
    frac_nan = 1.0 - (finite.size / max(1, vals.size))
    if frac_nan > 0:
        print(f"⚠️  {col}: {frac_nan:.1%} of values were NaN and were dropped from sanity checks.")
    neg_ratio = float((finite < 0).mean())
    if neg_ratio > 0.5:
        print(f"⚠️  {col}: more than half of the finite values are negative ({neg_ratio:.1%}).")

if fid_arrays:
    all_fid = np.concatenate(fid_arrays)
    summary["fidelity"] = {
        "mean": float(all_fid.mean()),
        "std": float(all_fid.std()),
        "min": float(all_fid.min()),
        "max": float(all_fid.max()),
    }
else:
    print("No fidelity_* columns present; skipping fidelity sanity checks.")

# Sparsity checks (should live in [0, 1])
ratio_suffixes = (".zero_frac", ".density")
count_suffixes = (".n", ".l0")

for col in sparsity_cols:
    vals = metrics_df[col].to_numpy(dtype=float)
    finite_mask = np.isfinite(vals)
    finite = vals[finite_mask]
    if finite.size == 0:
        print(f"⚠️  {col}: all values are NaN/inf; skipping.")
        continue
    suffix = next((s for s in ratio_suffixes + count_suffixes if col.endswith(s)), None)
    if suffix in count_suffixes:
        if (finite < -1e-6).any():
            raise RuntimeError(f"{col} contains negative counts: min={finite.min():.3f}")
        summary.setdefault("sparsity_counts", {})[col] = float(finite.mean())
        continue
    if suffix in ratio_suffixes:
        if (finite < -1e-6).any() or (finite > 1 + 1e-6).any():
            raise RuntimeError(
                f"{col} contains values outside [0, 1]: min={finite.min():.3f}, max={finite.max():.3f}"
            )
        summary.setdefault("sparsity", {})[col] = float(finite.mean())
        continue
    summary.setdefault("sparsity_misc", {})[col] = float(finite.mean())
    missing_frac = 1.0 - (finite.size / max(1, vals.size))
    if missing_frac > 0:
        print(f"⚠️  {col}: skipped {missing_frac:.1%} missing values.")

# Monotonicity checks for fidelity@k (per row)
trend_expectations = {
    "fidelity_minus": ("non_decreasing", "Dropping more edges should not reduce the drop in score."),
    "fidelity_keep": ("non_increasing", "Keeping more edges should not increase the drop in score."),
}
pattern = re.compile(r"^(fidelity_(?:minus|keep))\.@(\d+)")
violations = []

for prefix, (expect, desc) in trend_expectations.items():
    prefix_cols = [c for c in fidelity_cols if c.startswith(prefix + ".@")]
    if not prefix_cols:
        continue
    # sort columns by the numeric k
    col_info = []
    for col in prefix_cols:
        m = pattern.match(col)
        if m:
            col_info.append((int(m.group(2)), col))
    col_info.sort()
    if not col_info:
        continue

    for row_idx in metrics_df.index:
        values = []
        for k, col in col_info:
            val = metrics_df.at[row_idx, col]
            if val is None or (isinstance(val, float) and not np.isfinite(val)):
                continue
            val_f = float(val)
            if not np.isfinite(val_f):
                continue
            values.append((k, val_f))
        if len(values) < 2:
            continue
        values.sort()
        ks = [k for k, _ in values]
        seq = np.array([v for _, v in values], dtype=float)
        diffs = np.diff(seq)
        tol = 1e-6
        if expect == "non_decreasing":
            if np.any(diffs < -tol):
                violations.append((prefix, row_idx, ks, seq.tolist(), desc))
        else:  # non_increasing
            if np.any(diffs > tol):
                violations.append((prefix, row_idx, ks, seq.tolist(), desc))

if violations:
    lines = ["Fidelity monotonicity check failed for the following rows:"]
    for prefix, row_idx, ks, seq, desc in violations[:5]:
        anchor = metrics_df.at[row_idx, "anchor_idx"] if "anchor_idx" in metrics_df.columns else row_idx
        expl = metrics_df.at[row_idx, "explainer"] if "explainer" in metrics_df.columns else "?"
        lines.append(
            f"  [{prefix}] row={row_idx} (anchor={anchor}, explainer={expl}): ks={ks} values={seq} :: {desc}"
        )
    raise RuntimeError("\n".join(lines))

print("Metric sanity summary:")
for key, val in summary.items():
    print(f"  {key}: {val}")

RuntimeError: Fidelity monotonicity check failed for the following rows:
  [fidelity_keep] row=0 (anchor=0, explainer=subgraphx_tg_tgn): ks=[6, 12, 18] values=[0.6369396716478828, 0.1698360988847915, 0.1794140770786234] :: Keeping more edges should not increase the drop in score.
  [fidelity_keep] row=1 (anchor=1, explainer=subgraphx_tg_tgn): ks=[6, 12, 18] values=[0.6219558113796256, 0.7666771465682881, 0.7983978901267827] :: Keeping more edges should not increase the drop in score.
  [fidelity_keep] row=2 (anchor=2, explainer=subgraphx_tg_tgn): ks=[6, 12, 18] values=[0.1060826576091397, 0.2102158365633641, 0.3652049407212189] :: Keeping more edges should not increase the drop in score.
  [fidelity_keep] row=3 (anchor=3, explainer=subgraphx_tg_tgn): ks=[6, 12, 18] values=[0.0026126441836936, 0.0504068653314142, 0.0096203419281067] :: Keeping more edges should not increase the drop in score.
  [fidelity_keep] row=4 (anchor=4, explainer=subgraphx_tg_tgn): ks=[6, 12, 18] values=[0.001994034958752, 0.0042117738600718, 7.688482328316315e-05] :: Keeping more edges should not increase the drop in score.

In [None]:
# Full JSONL (per-anchor payloads, including coalition and candidate list in "extras")
import json
with open(out["jsonl"]) as f:
    lines = [json.loads(line) for line in f]
print("Num explanations:", len(lines))
lines[0].keys(), lines[0]["result"].keys()

Num explanations: 5


(dict_keys(['context', 'result', 'metrics']),
 dict_keys(['explainer', 'elapsed_sec', 'importance_edges', 'importance_nodes', 'importance_time', 'extras']))

In [None]:
first = lines[0]
first["result"]["extras"]["event_idx"], first["result"]["extras"]["coalition_eidx"][:10], first["result"]["extras"]["candidate_eidx"][:10]

(110314,
 [94457, 109720],
 [89019, 89030, 89679, 93888, 94342, 94345, 94351, 94353, 94386, 94450])