# Train the PGExplainer model
T-GNNExplainer relies on a pretrained navigator, which is implemented as a dynamic adaptation of PGExplainer.
This navigator therefore has to be trained before T-GNNExplainer can be evaluated.

In [2]:
# Find and add `notebooks/src` to sys.path, no matter where the notebook lives.
from pathlib import Path
import sys, importlib
import os
import subprocess

def _add_notebooks_src_to_path():
    here = Path.cwd().resolve()
    for p in [here, *here.parents]:
        candidate = p / "notebooks" / "src"
        if candidate.is_dir():
            if str(candidate) not in sys.path:
                sys.path.insert(0, str(candidate))
            return candidate
    raise FileNotFoundError("Could not find 'notebooks/src' from current working directory.")

# 1) Put `notebooks/src` on sys.path so the helper modules `constants` and `device` import.
print("Using helpers from:", _add_notebooks_src_to_path())

from constants import (
    REPO_ROOT, PKG_DIR, RESOURCES_DIR, PROCESSED_DATA_DIR, MODELS_ROOT, TGN_SUBMODULE_ROOT, ensure_repo_importable, get_last_checkpoint
)
ensure_repo_importable()
from device import pick_device

# 2) Prepend the TGN submodule, the repo root and the package directory to sys.path.
#    Each entry is inserted at index 0 in turn, so the last one inserted (PKG_DIR)
#    ends up first in the search order.
for p in (str(TGN_SUBMODULE_ROOT), str(REPO_ROOT), str(PKG_DIR)):
    if p not in sys.path:
        sys.path.insert(0, p)

# 3) If your notebook already imported `utils`, remove it to avoid collision:
#    dropping the cached module lets the next `utils` import be resolved against
#    the sys.path entries added above.
if "utils" in sys.modules:
    del sys.modules["utils"]

importlib.invalidate_caches()

# 4) (Optional) sanity check that TGN's local packages resolve
import importlib.util as iu
print("utils.utils   ->", iu.find_spec("utils.utils"))
print("modules.memory->", iu.find_spec("modules.memory"))

# 5) Import the packaged training entry point.
#    NOTE(review): the recorded output of this cell shows this import failing with
#    an ImportError for `create_tgnn_wrapper` in `time_to_explain.models.wrapper`;
#    that has to be fixed in the package itself.
#    NOTE(review): a later cell defines its own `train_pgexplainer`, which rebinds
#    the name imported here.
from time_to_explain.explainer.train_pgexplainer import (
    train_pgexplainer
)

# Echo the resolved project locations so a wrong checkout/root is easy to spot.
print("REPO_ROOT        :", REPO_ROOT)
print("PKG_DIR          :", PKG_DIR)
print("RESOURCES_DIR    :", RESOURCES_DIR)
print("PROCESSED_DATA_DIR:", PROCESSED_DATA_DIR)
print("MODELS_ROOT      :", MODELS_ROOT)

Using helpers from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/notebooks/src
utils.utils   -> ModuleSpec(name='utils.utils', loader=<_frozen_importlib_external.SourceFileLoader object at 0x34377d7d0>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/utils/utils.py')
modules.memory-> ModuleSpec(name='modules.memory', loader=<_frozen_importlib_external.SourceFileLoader object at 0x34377e3d0>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/modules/memory.py')


ImportError: cannot import name 'create_tgnn_wrapper' from 'time_to_explain.models.wrapper' (/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/time_to_explain/models/wrapper.py)

## Setting

Set ``MODEL_TYPE`` to the type of the model you want to evaluate, e.g., 'TGAT' or 'TGN'.

Set ``DATASET_NAME`` to the name of the dataset on which you want to train the PGExplainer model, e.g., 'uci',
'wikipedia', etc.

Set ``BIPARTITE = True`` only if the underlying dataset is a bipartite graph (Wikipedia/UCI-Forums); otherwise
set ``BIPARTITE = False``.

In [None]:
# --- Experiment configuration: everything a reader might want to tune ---
MODEL_TYPE = "TGAT"        # base model to explain, e.g. "TGAT" or "TGN"
DATASET_NAME = "wikipedia"
BIPARTITE = True           # True only for bipartite datasets (Wikipedia/UCI-Forums)

DIRECTED = False
EPOCHS = 30
LEARNING_RATE = 0.0001
BATCH_SIZE = 16

# --- Derived locations ---
MODEL_PATH = MODELS_ROOT / DATASET_NAME
CHECKPOINT_PATH = MODEL_PATH / 'checkpoints/'
# parents=True also creates MODELS_ROOT/<dataset> on a fresh checkout (a plain
# os.mkdir fails when the parent is missing); exist_ok=True makes the cell
# safe to re-run.
Path(CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)

# NOTE(review): this call passes (checkpoint dir, model type, dataset) positionally,
# while the training cell calls get_last_checkpoint with the keywords
# models_root/dataset/model_type/tgn_checkpoint — confirm which form the helper expects.
LAST_CHECKPOINT = get_last_checkpoint(CHECKPOINT_PATH, MODEL_TYPE, DATASET_NAME)
DEVICE = pick_device("auto")
print(DEVICE)

mps


## Train PGExplainer

In [None]:
def train_pgexplainer(
    *,
    dataset_name: str,                 # e.g., "wikipedia"
    model_type: str,                   # "TGN" or "TGAT" (or "TTGN" if you kept that label)
    epochs: int = 100,
    learning_rate: float = 1e-4,
    batch_size: int = 16,
    directed: bool = False,
    bipartite: bool = False,
    device: str = "auto",              # "auto" | "cpu" | "cuda" | "mps"
    cuda: bool = True,                 # legacy flag still supported by your device selector
    candidates_size: int = 30,         # kept for compatibility (some wrappers read it)
    # Base model checkpoint to explain:
    tgn_checkpoint: str | Path | None = None,
    # Where to place the PGExplainer outputs:
    out_dir: str | Path | None = None,  # default: MODELS_ROOT/<dataset>/pg_explainer
    # Roots (normally from constants.py)
    models_root: str | Path = None,
    processed_root: str | Path = None,
):
    """
    Train PGExplainer in-notebook with variables only.

    Returns:
      explainer, wrapper, out_dir (Path)
    """
    dataset_dir = Path(processed_root) / dataset_name
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Processed dataset not found: {dataset_dir}")

    out_dir = Path(out_dir) if out_dir else Path(models_root) / dataset_name / "pg_explainer"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Resolve base model checkpoint (file)
    resume_file = get_last_checkpoint(
        models_root=models_root,
        dataset=dataset_name,
        model_type=model_type,
        tgn_checkpoint=tgn_checkpoint,
    )
    if resume_file is None:
        raise FileNotFoundError(
            "Could not resolve a base TGN/TGAT checkpoint to explain.\n"
            f"Searched under: {models_root}/{dataset_name}/checkpoints and "
            f"{models_root}/{dataset_name}/{model_type}-{dataset_name}.pth\n"
            "Pass `tgn_checkpoint` explicitly if needed."
        )

    # Build the wrapper (loads weights from resume_file)
    wrapper = create_tgnn_wrapper(
        model_type=model_type,
        dataset_dir=str(dataset_dir),
        directed=directed,
        bipartite=bipartite,
        device=device,
        cuda=cuda,
        update_memory_at_start=False,
        checkpoint_path=resume_file,   # <-- FILE to load
    )

    # Optional compatibility knob (some code reads this)
    try:
        setattr(wrapper, "explanation_candidates_size", candidates_size)
    except Exception:
        pass

    # Make sure eval neighbor finder exists
    _ensure_full_neighbor_finder(wrapper)

    # Build PGExplainer
    embedding = StaticEmbedding(wrapper.dataset, wrapper)
    explainer = TPGExplainer(wrapper, embedding, device=wrapper.device)

    # Train PGExplainer
    model_name = getattr(wrapper, "name", getattr(wrapper, "model_name", model_type))
    print(f"Training PGExplainer for base model '{model_name}' on dataset '{dataset_name}'")
    print("dataset dir   :", dataset_dir)
    print("output dir    :", out_dir)
    print("base checkpoint:", resume_file)
    print("device        :", wrapper.device)
    print("epochs        :", epochs, "| lr:", learning_rate, "| batch:", batch_size)

    explainer.train(
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        model_name=model_name,
        save_directory=str(out_dir),
    )

    return explainer, wrapper, out_dir


In [None]:
# Run the training with the settings from the configuration cell.
# The data/model roots are passed explicitly so the locations used for the
# processed dataset and the explainer outputs are visible here and do not
# depend on the function's defaults.
explainer, base_wrapper, pg_dir = train_pgexplainer(
    dataset_name=DATASET_NAME,
    model_type=MODEL_TYPE,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    directed=DIRECTED,
    bipartite=BIPARTITE,             # set if your dataset is bipartite
    device=DEVICE,                   # "cuda"/"cpu"/"mps" also fine
    models_root=MODELS_ROOT,
    processed_root=PROCESSED_DATA_DIR,
    # tgn_checkpoint="/path/to/TGAT-wikipedia-19.pth",  # optional explicit file
    # out_dir="/custom/output/path",                    # optional custom folder
)