# Run TempME for Wikipedia + TGN

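This notebook runs the TempME submodule for `data=wikipedia` and `base_type=tgn`: `learn_base.py` trains the base TGNN model, then `temp_exp_main.py` trains the TempME explainer on top of it. Each stage is skipped when its checkpoint already exists (unless the matching `FORCE_RERUN_*` flag is set). The metrics parsed from the captured logs are saved, together with the logs, under `notebooks/runs/`.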


In [1]:
from __future__ import annotations

import os
import sys
from pathlib import Path


def _bootstrap_repo_root(start: Path | None = None) -> Path:
    here = (start or Path.cwd()).resolve()
    for candidate in (here, *here.parents):
        if (candidate / "time_to_explain").is_dir() and (candidate / "notebooks").is_dir():
            return candidate
    raise RuntimeError(f"Could not locate repository root from {here}")


PROJECT_ROOT = _bootstrap_repo_root()
if str(PROJECT_ROOT) not in sys.path:
    # Make the local `time_to_explain` package importable from this kernel.
    sys.path.insert(0, str(PROJECT_ROOT))

# Accept either capitalization of the submodule directory; the first existing one wins.
TEMP_ME_ROOT_CANDIDATES = [
    PROJECT_ROOT / "submodules" / "explainer" / "TempME",
    PROJECT_ROOT / "submodules" / "explainer" / "tempme",
]
TEMP_ME_ROOT = next((p for p in TEMP_ME_ROOT_CANDIDATES if p.exists()), None)
if TEMP_ME_ROOT is None:
    raise FileNotFoundError(
        "Could not find TempME submodule under submodules/explainer; looked in: "
        + ", ".join(str(p) for p in TEMP_ME_ROOT_CANDIDATES)
    )

DATASET_NAME = "wikipedia"
BASE_TYPE = "tgn"

# Quick mode is practical on CPU. Set False for full official epochs on GPU.
QUICK_RUN = True
FORCE_RERUN_BASE = False  # retrain the base model even if its checkpoint exists
FORCE_RERUN_EXPLAINER = False  # retrain the explainer even if its checkpoint exists

# Interpreter used to launch the TempME scripts; the first match wins:
#   1. $TEMPME_PYTHON, so other machines/environments can run this notebook unchanged;
#   2. the "graphs" conda env under the user's home directory, if it exists;
#   3. the interpreter running this kernel.
graphs_python = Path.home() / "miniconda3" / "envs" / "graphs" / "bin" / "python"
_env_python = os.environ.get("TEMPME_PYTHON")
if _env_python:
    PYTHON_BIN = _env_python
else:
    PYTHON_BIN = str(graphs_python if graphs_python.exists() else Path(sys.executable))

# CLI overrides forwarded as `--key value` to learn_base.py / temp_exp_main.py.
if QUICK_RUN:
    BASE_OVERRIDES = {"gpu": -1, "n_epoch": 1, "bs": 1024}
    EXPLAINER_OVERRIDES = {"gpu": -1, "n_epoch": 1, "bs": 512, "test_bs": 512, "verbose": 1}
else:
    BASE_OVERRIDES = {"gpu": 0}
    EXPLAINER_OVERRIDES = {"gpu": 0, "verbose": 1}

print("PROJECT_ROOT:", PROJECT_ROOT)
print("TEMP_ME_ROOT:", TEMP_ME_ROOT)
print("PYTHON_BIN:", PYTHON_BIN)
print("QUICK_RUN:", QUICK_RUN)
print("BASE_OVERRIDES:", BASE_OVERRIDES)
print("EXPLAINER_OVERRIDES:", EXPLAINER_OVERRIDES)


PROJECT_ROOT: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
TEMP_ME_ROOT: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/TempME
PYTHON_BIN: /Users/juliawenkmann/miniconda3/envs/graphs/bin/python
QUICK_RUN: True
BASE_OVERRIDES: {'gpu': -1, 'n_epoch': 1, 'bs': 1024}
EXPLAINER_OVERRIDES: {'gpu': -1, 'n_epoch': 1, 'bs': 512, 'test_bs': 512, 'verbose': 1}


In [2]:
from shutil import copy2

RESOURCES_PROCESSED = PROJECT_ROOT / "resources" / "datasets" / "processed"
TEMP_ME_PROCESSED = TEMP_ME_ROOT / "processed"
TEMP_ME_PROCESSED.mkdir(parents=True, exist_ok=True)

# Mirror the ml_<dataset> files into the submodule's processed/ dir.
# Existing copies are never overwritten; missing sources are silently skipped.
for name in [f"ml_{DATASET_NAME}{suffix}" for suffix in (".csv", ".npy", "_node.npy")]:
    src = RESOURCES_PROCESSED / name
    dst = TEMP_ME_PROCESSED / name
    if dst.exists() or not src.exists():
        continue
    copy2(src, dst)
    print(f"Copied {src} -> {dst}")

# Pre-packed train/test files the TempME scripts expect; build them only if any is missing.
required_pack = [
    TEMP_ME_PROCESSED / f"{DATASET_NAME}_{stem}"
    for stem in ("train_cat.h5", "test_cat.h5", "train_edge.npy", "test_edge.npy")
]
if all(p.exists() for p in required_pack):
    print("TempME preprocessed packs already present.")
else:
    from time_to_explain.data.tempme_preprocess import TempMEPreprocessConfig, prepare_tempme_dataset

    cfg = TempMEPreprocessConfig(
        dataset_name=DATASET_NAME,
        processed_dir=RESOURCES_PROCESSED,
        output_dir=TEMP_ME_PROCESSED,
        overwrite=False,
        validate_existing=False,
    )
    out = prepare_tempme_dataset(cfg)
    print("Prepared TempME packs:", out)

# Checkpoint locations used to decide whether each training stage can be skipped.
PARAMS_ROOT = TEMP_ME_ROOT / "params"
BASE_CKPT = PARAMS_ROOT / "tgnn" / f"{BASE_TYPE}_{DATASET_NAME}.pt"
EXPL_CKPT = PARAMS_ROOT / "explainer" / BASE_TYPE / f"{DATASET_NAME}.pt"

print("Base checkpoint:", BASE_CKPT)
print("Explainer checkpoint:", EXPL_CKPT)


TempME preprocessed packs already present.
Base checkpoint: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/TempME/params/tgnn/tgn_wikipedia.pt
Explainer checkpoint: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/TempME/params/explainer/tgn/wikipedia.pt


In [3]:
import json
import re
import shlex
import subprocess
from datetime import datetime

import pandas as pd


def _cli_args(args: dict[str, object]) -> list[str]:
    out: list[str] = []
    for key, value in args.items():
        out.extend([f"--{key}", str(value)])
    return out


def run_and_capture(cmd: list[str], cwd: Path) -> tuple[int, list[str]]:
    """Run ``cmd`` in ``cwd``, echoing its merged stdout/stderr live while capturing it.

    ``PROJECT_ROOT`` is prepended to ``PYTHONPATH`` so the child can import the
    local ``time_to_explain`` package.

    Args:
        cmd: argv list, executed without a shell.
        cwd: working directory for the child process.

    Returns:
        ``(exit_code, lines)`` where ``lines`` is the captured output, one entry per
        line, with trailing newlines stripped.

    If streaming is interrupted (e.g. the notebook's "Interrupt kernel") or fails,
    the child is killed and reaped before the exception propagates, so a long
    training run does not keep going in the background.
    """
    env = os.environ.copy()
    pythonpath = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = f"{PROJECT_ROOT}{os.pathsep}{pythonpath}" if pythonpath else str(PROJECT_ROOT)

    print("$", shlex.join(cmd))
    lines: list[str] = []
    # The context manager closes the stdout pipe and waits for the child on exit.
    with subprocess.Popen(
        cmd,
        cwd=str(cwd),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Progress bars emit non-ASCII glyphs; never crash mid-run on an undecodable byte.
        errors="replace",
        bufsize=1,
    ) as proc:
        assert proc.stdout is not None
        try:
            for raw in proc.stdout:
                print(raw, end="")
                lines.append(raw.rstrip("\n"))
        except BaseException:
            proc.kill()
            proc.wait()
            raise
        rc = proc.wait()
    return rc, lines


def parse_learn_base_metrics(lines: list[str]) -> dict[str, float]:
    """Extract train/test acc, ap and auc from learn_base.py output lines.

    Looks for lines containing ``train <m>: <num>, test <m>: <num>`` for
    ``<m>`` in acc/ap/auc. When a metric appears several times, the last
    occurrence wins. Metrics that never appear are omitted.

    Returns:
        A dict with keys such as ``train_acc`` / ``test_acc``.
    """
    number = r"([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)"
    patterns = {
        name: re.compile(rf"train {name}:\s*{number},\s*test {name}:\s*{number}")
        for name in ("acc", "ap", "auc")
    }

    metrics: dict[str, float] = {}
    for line in lines:
        for name, pattern in patterns.items():
            hit = pattern.search(line)
            if hit is None:
                continue
            metrics[f"train_{name}"] = float(hit.group(1))
            metrics[f"test_{name}"] = float(hit.group(2))
    return metrics


def parse_testing_epoch_metrics(lines: list[str]) -> dict[str, float]:
    """Parse the last ``Testing Epoch:`` line of temp_exp_main.py output.

    The line is split on ``|`` into ``key: value`` segments. Keys are lowercased
    with spaces turned into underscores. Segments with no colon or a non-numeric
    value are dropped.

    Returns:
        The numeric fields of that line, or ``{}`` if no such line exists.
    """
    testing_line = next(
        (line for line in reversed(lines) if line.strip().startswith("Testing Epoch:")),
        None,
    )
    if testing_line is None:
        return {}

    metrics: dict[str, float] = {}
    for segment in testing_line.split("|"):
        key, sep, raw_value = segment.partition(":")
        if not sep:
            continue
        try:
            number = float(raw_value.strip())
        except ValueError:
            continue
        metrics[key.strip().lower().replace(" ", "_")] = number
    return metrics


def save_metrics_and_logs(
    metrics_rows: list[dict],
    logs: dict[str, list[str]],
    out_dir: Path | None = None,
) -> tuple[Path, Path]:
    """Write metrics as CSV and raw logs as JSON under a shared timestamp.

    Args:
        metrics_rows: One dict per stage; each becomes one CSV row.
        logs: Stage name mapped to its captured output lines.
        out_dir: Destination directory (created if needed). By default this is
            ``notebooks/runs/tempme_<DATASET_NAME>_<BASE_TYPE>`` under
            ``PROJECT_ROOT``, derived from the config so that runs for a
            different dataset or base model are not filed under the wrong name.

    Returns:
        ``(metrics_csv_path, logs_json_path)``.
    """
    if out_dir is None:
        out_dir = PROJECT_ROOT / "notebooks" / "runs" / f"tempme_{DATASET_NAME}_{BASE_TYPE}"
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Second resolution: two saves within the same second overwrite each other.
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    metrics_path = out_dir / f"metrics_{ts}.csv"
    pd.DataFrame(metrics_rows).to_csv(metrics_path, index=False)

    logs_path = out_dir / f"logs_{ts}.json"
    logs_path.write_text(json.dumps(logs, indent=2), encoding="utf-8")

    return metrics_path, logs_path


In [None]:
base_logs: list[str] = []
base_metrics: dict[str, float] = {}

base_cmd = [
    PYTHON_BIN,
    str(TEMP_ME_ROOT / "learn_base.py"),
    "--base_type",
    BASE_TYPE,
    "--data",
    DATASET_NAME,
]
base_cmd.extend(_cli_args(BASE_OVERRIDES))

# Train the base model only when its checkpoint is missing or a rerun is forced;
# when skipped, base_logs stays empty and so no base metrics are parsed.
if not BASE_CKPT.exists() or FORCE_RERUN_BASE:
    rc, base_logs = run_and_capture(base_cmd, TEMP_ME_ROOT)
    if rc != 0:
        raise RuntimeError(f"learn_base.py failed with exit code {rc}")
else:
    print(f"Skipping learn_base.py (checkpoint exists): {BASE_CKPT}")

base_metrics = parse_learn_base_metrics(base_logs)
base_metrics


$ /Users/juliawenkmann/miniconda3/envs/graphs/bin/python /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/explainer/TempME/learn_base.py --base_type tgn --data wikipedia --gpu -1 --n_epoch 1 --bs 1024
dataset:wikipedia, base_type model:tgn
num of training instances: 79376
num of batches per epoch: 78
start 0 epoch

  0%|          | 0/78 [00:00<?, ?it/s]
  1%|▏         | 1/78 [00:22<28:35, 22.28s/it]
  3%|▎         | 2/78 [00:42<26:32, 20.95s/it]
  4%|▍         | 3/78 [00:58<23:45, 19.01s/it]
  5%|▌         | 4/78 [01:16<22:28, 18.23s/it]
  6%|▋         | 5/78 [01:49<28:45, 23.63s/it]
  8%|▊         | 6/78 [02:38<38:40, 32.22s/it]
  9%|▉         | 7/78 [03:21<42:35, 36.00s/it]
 10%|█         | 8/78 [04:14<48:10, 41.30s/it]
 12%|█▏        | 9/78 [04:58<48:24, 42.10s/it]


In [None]:
expl_logs: list[str] = []
expl_metrics: dict[str, float] = {}

exp_cmd = [
    PYTHON_BIN,
    str(TEMP_ME_ROOT / "temp_exp_main.py"),
    "--base_type",
    BASE_TYPE,
    "--data",
    DATASET_NAME,
]
exp_cmd += _cli_args(EXPLAINER_OVERRIDES)

# Train the explainer only when its checkpoint is missing or a rerun is forced;
# when skipped, expl_logs stays empty and so no explainer metrics are parsed.
if not EXPL_CKPT.exists() or FORCE_RERUN_EXPLAINER:
    rc, expl_logs = run_and_capture(exp_cmd, TEMP_ME_ROOT)
    if rc != 0:
        raise RuntimeError(f"temp_exp_main.py failed with exit code {rc}")
else:
    print(f"Skipping temp_exp_main.py (checkpoint exists): {EXPL_CKPT}")

expl_metrics = parse_testing_epoch_metrics(expl_logs)
expl_metrics


In [None]:
# One row per stage that produced metrics, in pipeline order.
rows: list[dict[str, object]] = [
    {"stage": stage, **stage_values}
    for stage, stage_values in (("learn_base", base_metrics), ("temp_exp_main", expl_metrics))
    if stage_values
]

metrics_df = pd.DataFrame(rows)
display(metrics_df)

# Raw captured output per stage (empty list for a skipped stage).
logs_payload = {"learn_base": base_logs, "temp_exp_main": expl_logs}
metrics_path, logs_path = save_metrics_and_logs(rows, logs_payload)
print("Saved metrics to:", metrics_path)
print("Saved logs to:", logs_path)
