# Prepare datasets

Set `DATASET_NAME` and run the pipeline. Outputs are written under `resources/datasets`. Synthetic configs live in `configs/datasets` (e.g. `triadic_closure.json`).
Enable `WRITE_FILES` to save figures (PDF recommended for paper-ready outputs).


In [1]:
from pathlib import Path
import sys
import json
from typing import Optional

def _find_repo_root(start: Optional[Path] = None) -> Path:
    start = (start or Path.cwd()).resolve()
    for p in [start, *start.parents]:
        if (p / "time_to_explain").is_dir():
            return p
    raise FileNotFoundError("Could not find repo root containing 'time_to_explain'.")

REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from time_to_explain.utils.device import pick_device
CONFIG_PATH = REPO_ROOT / "configs" / "notebooks" / "global.json"
NOTEBOOK_CFG = json.loads(CONFIG_PATH.read_text(encoding="utf-8")) if CONFIG_PATH.exists() else {}
SEED = int(NOTEBOOK_CFG.get("seed", 42))
DEVICE = pick_device(NOTEBOOK_CFG.get("device", "auto"))
print(f"Notebook config: seed={SEED}, device={DEVICE}")

print("Repo root:", REPO_ROOT)


Notebook config: seed=42, device=mps
Repo root: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain


In [2]:
from pathlib import Path
import pandas as pd
# Examples: "nicolaus", "triadic_closure", "stick_figure", "erdos_small", "hawkes_small", "wikipedia", "reddit"
DATASET_NAME = "wikipedia"
RECIPE = None  # Optional: override the registry recipe for custom names
FROM_CACHE = False
OVERWRITE = True

ENSURE_REAL = True  # Downloads/processing for wikipedia/reddit/simulate_v1/v2
FORCE_DOWNLOAD = False
INDEX_SIZE = 500

SHOW = True
WRITE_FILES = True
SAVE_FORMAT = "pdf"
SAVE_DIR = Path("viz_out") / DATASET_NAME if WRITE_FILES else None


In [3]:
from time_to_explain.data.tgnn_prepare import prepare_tgnn_dataset
import json
import numpy as np
from time_to_explain.data.io import load_processed_dataset
from time_to_explain.data.synthetic_recipes.stick_figure import (
    write_stick_figure_explain_index,
    write_sticky_hips_explain_index,
)
from time_to_explain.visualization import visualize_dataset
import shutil

paths, summary = prepare_tgnn_dataset(
    DATASET_NAME,
    recipe=RECIPE,
    from_cache=FROM_CACHE,
    overwrite=OVERWRITE,
    ensure_real=ENSURE_REAL,
    force_download=FORCE_DOWNLOAD,
    index_size=INDEX_SIZE,
    verbose=True,
)

if DATASET_NAME in {"stick_figure", "sticky_hips"}:
    bundle = load_processed_dataset(paths.ml_csv)
    cfg = (bundle.get("metadata") or {}).get("config") or {}
    if not cfg:
        raise ValueError("Missing stick-figure config in metadata.")
    gt_raw_path = paths.processed_dir / f"{DATASET_NAME}_gt_raw.json"
    gt_path = paths.processed_dir / f"{DATASET_NAME}_gt.json"
    if DATASET_NAME == "stick_figure":
        write_stick_figure_explain_index(
            paths.ml_csv,
            paths.explain_idx,
            config=cfg,
            last_k_frames=1,
            test_split=0.85,
            overwrite=OVERWRITE,
            gt_raw_out=gt_raw_path,
            gt_out=gt_path,
        )
    else:
        write_sticky_hips_explain_index(
            paths.ml_csv,
            paths.explain_idx,
            config=cfg,
            last_k_frames=1,
            test_split=0.85,
            overwrite=OVERWRITE,
            gt_raw_out=gt_raw_path,
            gt_out=gt_path,
        )
    legacy_dir = paths.root_dir / "resources" / "explainer" / "explain_index"
    legacy_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(paths.explain_idx, legacy_dir / f"{DATASET_NAME}.csv")



if DATASET_NAME in {"triadic_closure", "tri_closure"}:
    tri_df = pd.read_csv(paths.ml_csv)
    ts = tri_df["ts"].to_numpy()
    t_train = np.quantile(ts, 0.70)
    t_val = np.quantile(ts, 0.85)

    gt_path = paths.processed_dir / f"{DATASET_NAME}_gt.json"
    if gt_path.exists():
        gt_data = json.loads(gt_path.read_text(encoding="utf-8"))
        explain_idx_all = np.array(sorted(int(k) for k in gt_data.keys()), dtype=int)
    else:
        bundle = load_processed_dataset(paths.ml_csv)
        gt = (bundle.get("metadata") or {}).get("ground_truth") or {}
        raw_targets = [int(t) for t in gt.get("targets", [])]
        if not raw_targets:
            raise ValueError("No ground_truth targets found in metadata.")
        explain_idx_all = np.array(
            sorted(int(tri_df.iloc[t]["idx"]) for t in raw_targets),
            dtype=int,
        )

    idx_to_ts = dict(zip(tri_df["idx"].astype(int).values, tri_df["ts"].values))
    explain_idx_test = [i for i in explain_idx_all if idx_to_ts.get(i, -np.inf) >= t_val]

    out_path = paths.explain_dir / f"{DATASET_NAME}_explain_indexes_test.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(explain_idx_test, indent=2), encoding="utf-8")
    print("test explain targets:", len(explain_idx_test))
    print("saved:", out_path)
explain_indices = summary.get("explain_idx_preview")
if paths.explain_idx.exists():
    try:
        explain_df = pd.read_csv(paths.explain_idx)
        preview = explain_df.get("event_idx")
        if preview is not None:
            explain_indices = preview.head(5).astype(int).tolist()
    except Exception:
        pass

if explain_indices:
    print(f"Explain preview: {explain_indices}")

visuals = visualize_dataset(
    paths.ml_csv,
    explain_indices=explain_indices,
    show=SHOW,
    save_dir=SAVE_DIR,
    export_format=SAVE_FORMAT,
)

summary




Dataset: wikipedia
Layout : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/datasets/processed
TGNN setup: download + process + index
â†“ Downloading http://snap.stanford.edu/jodie/wikipedia.csv -> /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/datasets/raw/wikipedia.csv


ValueError: wikipedia: raw CSV missing columns ['u', 'i', 'ts', 'label']. Columns=['user_id', 'item_id', 'timestamp', 'state_label', 'comma_separated_list_of_features']

In [None]:
# Render a few stick-figure clips from the real processed data.
if DATASET_NAME in {"stick_figure", "sticky_hips"}:
    from time_to_explain.visualization import animate_stick_figure
    from time_to_explain.data.io import load_processed_dataset

    bundle = load_processed_dataset(paths.ml_csv if "paths" in globals() else DATASET_NAME)
    clip_ids = [0, 1, 2]
    for clip_id in clip_ids:
        try:
            animate_stick_figure(bundle, clip_id=clip_id, show=SHOW)
        except ValueError as exc:
            print(f"Skip clip {clip_id}: {exc}")
else:
    print("Set DATASET_NAME = 'stick_figure' or 'sticky_hips' to render stick-figure animations.")



In [None]:
import random

from time_to_explain.data.io import load_processed_dataset

bundle = load_processed_dataset(paths.ml_csv if "paths" in globals() else DATASET_NAME)
df = bundle["interactions"]
meta = bundle.get("metadata") or {}

gt = meta.get("ground_truth") or {}
targets = [int(t) for t in gt.get("targets", [])]
rationales = gt.get("rationales") or {}

if not targets or not rationales:
    raise ValueError("No ground_truth targets/rationales found in metadata.")

target = int(random.choice(targets))
support = rationales.get(str(target), rationales.get(target))
if not support or len(support) < 2:
    raise ValueError(f"Not enough support edges for target {target}.")

s1, s2 = (int(support[0]), int(support[1]))

row_t = df.iloc[target]
row_s1 = df.iloc[s1]
row_s2 = df.iloc[s2]

u, v, ts = int(row_t["u"]), int(row_t["i"]), float(row_t["ts"])
u1, w1, ts1 = int(row_s1["u"]), int(row_s1["i"]), float(row_s1["ts"])
w2, v2, ts2 = int(row_s2["u"]), int(row_s2["i"]), float(row_s2["ts"])

print("TARGET:", target, (u, v), ts)
print("SUP1:  ", s1, (u1, w1), ts1)
print("SUP2:  ", s2, (w2, v2), ts2)
print(
    "Checks:",
    "u matches:", u1 == u,
    "v matches:", v2 == v,
    "w matches:", w1 == w2,
    "time order:", ts1 < ts2 < ts,
)


TARGET: 104075 (3409, 3410) 104075.0
SUP1:   104008 (3405, 3407) 104008.0
SUP2:   104009 (3407, 3409) 104009.0
Checks: u matches: False v matches: False w matches: True time order: True
