# Training models

In [1]:
# Find and add `notebooks/src` to sys.path, no matter where the notebook lives.
from pathlib import Path
import sys, importlib
import os
import subprocess

def _add_notebooks_src_to_path():
    here = Path.cwd().resolve()
    for p in [here, *here.parents]:
        candidate = p / "notebooks" / "src"
        if candidate.is_dir():
            if str(candidate) not in sys.path:
                sys.path.insert(0, str(candidate))
            return candidate
    raise FileNotFoundError("Could not find 'notebooks/src' from current working directory.")

# Step 1: put `notebooks/src` on sys.path so the helper modules imported
# below (`constants`, `device`) can be found.
print("Using helpers from:", _add_notebooks_src_to_path())

from constants import (
    REPO_ROOT, PKG_DIR, RESOURCES_DIR, PROCESSED_DATA_DIR, MODELS_ROOT, TGN_SUBMODULE_ROOT, ensure_repo_importable, get_last_checkpoint
)
# NOTE(review): presumably adds the repository to sys.path; confirm in constants.py.
ensure_repo_importable()
from device import pick_device

# Step 2: make sure the TGN submodule, the repo root and the package directory
# are all on sys.path. Each missing entry is inserted at index 0, so entries
# later in the tuple end up ahead of earlier ones.
for p in (str(TGN_SUBMODULE_ROOT), str(REPO_ROOT), str(PKG_DIR)):
    if p not in sys.path:
        sys.path.insert(0, p)

# Step 3: if a module named `utils` is already cached from an earlier import,
# drop it so the next import of `utils` is resolved against the updated
# sys.path instead of reusing the cached module (avoids a name collision).
if "utils" in sys.modules:
    del sys.modules["utils"]

# Clear the import finders' caches after changing sys.path.
importlib.invalidate_caches()

# Step 4 (optional): sanity check that TGN's local packages resolve; the
# printed ModuleSpec shows which file each name maps to.
import importlib.util as iu
print("utils.utils   ->", iu.find_spec("utils.utils"))
print("modules.memory->", iu.find_spec("modules.memory"))

# Step 5: with the paths in place, import the wrapper factories from the
# project package.
from time_to_explain.models.wrapper import (
    create_dataset, create_tgn_wrapper, create_wrapper, create_tgat_wrapper
)

# Echo the resolved locations so a wrong checkout or working directory is
# obvious.
print("REPO_ROOT        :", REPO_ROOT)
print("PKG_DIR          :", PKG_DIR)
print("RESOURCES_DIR    :", RESOURCES_DIR)
print("PROCESSED_DATA_DIR:", PROCESSED_DATA_DIR)
print("MODELS_ROOT      :", MODELS_ROOT)

Using helpers from: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/notebooks/src
utils.utils   -> ModuleSpec(name='utils.utils', loader=<_frozen_importlib_external.SourceFileLoader object at 0x105446d50>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/utils/utils.py')
modules.memory-> ModuleSpec(name='modules.memory', loader=<_frozen_importlib_external.SourceFileLoader object at 0x10545f110>, origin='/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/modules/memory.py')
REPO_ROOT        : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
PKG_DIR          : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/time_to_explain
RESOURCES_DIR    : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources
PROCESSED_DATA_DIR: /Users/juliawenkmann/Documents/CodingProjects/master_thesis

## Settings

In [2]:
# ---- What to train
MODEL_TYPE = "TGAT"          # "TGN" or "TGAT"
DATASET_NAME = "wikipedia"
BIPARTITE = True
DIRECTED = False
EPOCHS = 30

# ---- Where artifacts are written
MODEL_PATH = MODELS_ROOT / DATASET_NAME
CHECKPOINT_PATH = MODEL_PATH / "checkpoints"
# parents=True also creates MODELS_ROOT/<dataset> on a fresh checkout (plain
# os.mkdir would raise FileNotFoundError there); exist_ok=True makes the cell
# safe to re-run.
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

# NOTE(review): presumably the most recent checkpoint for this model/dataset,
# or a falsy value when none exists -- confirm in constants.get_last_checkpoint.
LAST_CHECKPOINT = get_last_checkpoint(CHECKPOINT_PATH, MODEL_TYPE, DATASET_NAME)

DEVICE = pick_device("auto")
print(DEVICE)

mps


In [3]:
def train_tgnn_variables(
    *,
    dataset_name: str,
    model_type: str = "TGN",        # "TGN" or "TGAT"
    epochs: int = 30,
    directed: bool = False,
    bipartite: bool = False,
    device: str = "auto",           # "auto" | "cpu" | "cuda" | "mps"
    cuda: bool = True,              # legacy flag; only echoed in the log below
    update_memory_at_start: bool = False,
    model_dir: str | Path = MODEL_PATH,
    checkpoint_path: str | Path | None = None,  # forwarded to train_model
    last_checkpoint: str | Path | None = None,  # checkpoint handed to create_wrapper
    out_root: str | Path | None = None,         # accepted but not used (see docstring)
):
    """
    Train a TGN/TGAT model using only Python variables (no argparse).

    Args:
      dataset_name: Name of the sub-directory of PROCESSED_DATA_DIR that holds
        the processed dataset.
      model_type: Passed to ``create_wrapper`` ("TGN" or "TGAT").
      epochs: Number of epochs passed to ``wrapper.train_model``.
      directed, bipartite, device, update_memory_at_start: Forwarded unchanged
        to ``create_wrapper``.
      cuda: Legacy flag. It is only printed; nothing else in this function
        reads it.
      model_dir: Directory for this run's artifacts; ``results.pkl`` is
        written to ``<model_dir>/results.pkl``. ``str`` and ``Path`` are both
        accepted.
      checkpoint_path: Value forwarded to ``wrapper.train_model`` as
        ``checkpoint_path``. When omitted, ``<model_dir>/checkpoints`` is
        created and used.
      last_checkpoint: Forwarded to ``create_wrapper`` as ``checkpoint_path``
        (the checkpoint the wrapper is built from), or None.
      out_root: Accepted for backward compatibility but not used; outputs are
        placed according to ``model_dir`` and ``checkpoint_path``.

    Returns:
      wrapper (so you can evaluate/inspect in the notebook)
      model_dir (Path to the run's artifacts)
      results_path (Path to results.pkl)

    Raises:
      FileNotFoundError: if ``PROCESSED_DATA_DIR/<dataset_name>`` does not exist.
    """
    # ---- Resolve paths
    dataset_dir = Path(PROCESSED_DATA_DIR) / dataset_name
    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Processed dataset directory not found: {dataset_dir}\n"
            f"Set PROCESSED_DATA_DIR in constants.py or fix dataset_name."
        )

    # Coerce so that a plain string model_dir works with the `/` operator below.
    model_dir = Path(model_dir)
    results_path = model_dir / "results.pkl"

    # Without this default, str(None) would hand the literal text "None" to
    # train_model as its checkpoint location.
    if checkpoint_path is None:
        checkpoint_path = model_dir / "checkpoints"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

    # ---- Build wrapper (variable-based API, no parser)
    wrapper = create_wrapper(
        model_type=model_type,
        dataset_dir=str(dataset_dir),
        directed=directed,
        bipartite=bipartite,
        device=device,
        update_memory_at_start=update_memory_at_start,
        checkpoint_path=last_checkpoint,
    )

    # ---- Train
    print(f"Training {model_type} on '{dataset_name}'"
          f"{' (bipartite)' if bipartite else ''}"
          f"{' (directed)' if directed else ''} …")
    print("repo root     :", REPO_ROOT)
    print("dataset dir   :", dataset_dir)
    print("model out dir :", model_dir)
    print("checkpoint dir:", checkpoint_path)
    print("epochs        :", epochs)
    print("device        :", device, "(cuda flag:", cuda, ")")
    # Report the checkpoint the wrapper was actually built from, not the
    # output location.
    if last_checkpoint:
        print("resume from   :", last_checkpoint)

    wrapper.train_model(
        epochs,
        checkpoint_path=str(checkpoint_path),
        results_path=str(results_path),
    )

    return wrapper, model_dir, results_path


In [4]:
# Launch training with the values chosen in the settings cell.
wrapper, model_dir, results_path = train_tgnn_variables(
    # -- what to train
    model_type=MODEL_TYPE,
    dataset_name=DATASET_NAME,
    bipartite=BIPARTITE,
    directed=DIRECTED,
    epochs=EPOCHS,
    device=DEVICE,  # or "cuda"/"cpu"/"mps"
    # -- where artifacts go and which checkpoint to build from
    model_dir=MODEL_PATH,
    checkpoint_path=CHECKPOINT_PATH,  # forwarded to wrapper.train_model
    last_checkpoint=LAST_CHECKPOINT,  # forwarded to create_wrapper
    # out_root="/custom/output/root",  # accepted by the function, but its body never reads it
)

Creating TGATWrapper...
Training TGAT on 'wikipedia' (bipartite) …
repo root     : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
dataset dir   : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/datasets/processed/wikipedia
model out dir : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/wikipedia
epochs        : 30
device        : mps (cuda flag: True )
resume from   : /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/resources/models/wikipedia/checkpoints
The dataset has 157474 interactions, involving 9227 different nodes
The training dataset has 79750 interactions, involving 5992 different nodes
The validation dataset has 23621 interactions, involving 3256 different nodes
The test dataset has 23621 interactions, involving 3564 different nodes
The new node validation dataset has 11738 interactions, involving 2115 different nodes
The new node test dat

INFO:TGNNWrapper:num of training instances: 79750
INFO:TGNNWrapper:num of batches per epoch: 2493
INFO:TGNNWrapper:start 0 epoch


Epoch 0:   0%|          | 0/2493 [00:00<?, ?it/s]

AttributeError: 'TGATWrapper' object has no attribute 'full_ngh_finder'