# Training Launcher (TGN / TGAT)

These cells add a simple switch for launching **TGN** or **TGAT** training from the notebook as a subprocess.

- **TGN**: runs the script  
  `/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py`  
  with the arguments `--data <DATA_TYPE> --use_memory --prefix tgn-attn --n_runs 10` (`DATA_TYPE` is set in the configuration cell).

- **TGAT**: included as a template. Set `TGAT_SCRIPT` and `TGAT_ARGS` to match your repo (see the comments in the code cell).

> Tip: set `DRY_RUN = True` first to print the command without running it, then set it to `False` to launch training.
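The switch described above boils down to a lookup from a model name to a script path and an argument list. The sketch below shows that lookup with a hypothetical `LaunchSpec` type and `resolve_launch` helper (neither exists in the repo). It normalizes the name and fails early on an unknown model or a missing script, before any subprocess is started.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class LaunchSpec:
    """Hypothetical container for one model's training entry point."""

    script: Path
    args: tuple[str, ...] = ()


def resolve_launch(
    model_type: str,
    registry: dict[str, LaunchSpec],
    python_bin: str = "python",
) -> list[str]:
    """Return the argv list for `model_type`, failing early on bad config."""
    key = model_type.strip().upper()  # accept "tgn", " TGN", etc.
    if key not in registry:
        raise ValueError(
            f"MODEL_TYPE must be one of {sorted(registry)}, got {model_type!r}"
        )
    spec = registry[key]
    if not spec.script.exists():
        raise FileNotFoundError(f"Training script not found: {spec.script}")
    return [python_bin, str(spec.script), *spec.args]


# Example wiring with the notebook's own config names (illustrative only):
# REGISTRY = {
#     "TGN": LaunchSpec(TGN_SCRIPT, tuple(TGN_ARGS)),
#     "TGAT": LaunchSpec(TGAT_SCRIPT, tuple(TGAT_ARGS)),
# }
# cmd = resolve_launch(MODEL_TYPE, REGISTRY, PYTHON_BIN)
```

Keeping every model in one mapping means adding a third model is a one-line change, and a typo in `MODEL_TYPE` or a stale template path fails immediately instead of after the launch cell has started.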

In [1]:
from pathlib import Path
import os, sys, subprocess, shlex
from datetime import datetime

# === User configuration ===
MODEL_TYPE = "TGN"   # "TGN" or "TGAT"
DATA_TYPE = "wikipedia"
DRY_RUN = False      # True -> print command only, False -> actually run it
PYTHON_BIN = "python"  # or an absolute path to your env's python

# Your repository root (as on your machine)
PROJECT_ROOT = Path("/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain").resolve()

# Optional: restrict GPUs (set None to leave unchanged)
CUDA_VISIBLE_DEVICES = None  # e.g., "0" or "0,1"

# ---- TGN config ----
TGN_SCRIPT = PROJECT_ROOT / "submodules" / "models" / "tgn" / "train_self_supervised.py"
TGN_ARGS = ["--data", DATA_TYPE, "--use_memory", "--prefix", "tgn-attn", "--n_runs", "10"]

# ---- TGAT config (TEMPLATE — update to your actual script/args) ----
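# TGAT implementations name their training entry point differently; check your submodule.
# Adjust TGAT_SCRIPT and TGAT_ARGS to match your setup. Note that DATA_TYPE is not passed here yet;
# add whatever dataset flag your TGAT script expects.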
TGAT_SCRIPT = PROJECT_ROOT / "submodules" / "models" / "tgat" / "train_supervised.py"  # <-- change if different
TGAT_ARGS = ["--prefix", "tgat-attn", "--n_runs", "10"]  # <-- add/remove flags as your script expects

# === Derived / utility ===
def build_cmd(python_bin, script_path, extra_args):
    if not script_path.exists():
        raise FileNotFoundError(f"Training script not found: {script_path}")
    return [python_bin, str(script_path), *extra_args]

def run_cmd(cmd, env=None):
    print("$", " ".join(shlex.quote(c) for c in cmd))
    if DRY_RUN:
        print("[DRY_RUN] Skipping execution.")
        return 0
    proc = subprocess.run(cmd, env=env, check=False)
    if proc.returncode != 0:
        print(f"[ERROR] process exited with code {proc.returncode}")
    return proc.returncode

def prepare_env():
    env = os.environ.copy()
    if CUDA_VISIBLE_DEVICES is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(CUDA_VISIBLE_DEVICES)
        print("Set CUDA_VISIBLE_DEVICES =", env["CUDA_VISIBLE_DEVICES"])
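    # Prepend the repo root to PYTHONPATH for intra-repo imports.
    # Only add a separator if PYTHONPATH was already set, so no empty entry is created.
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = str(PROJECT_ROOT) + (os.pathsep + existing if existing else "")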
    return env

print("Configuration loaded.")
print("PROJECT_ROOT:", PROJECT_ROOT)
print("MODEL_TYPE:", MODEL_TYPE)
print("DATA_TYPE:", DATA_TYPE)
print("TGN script:", TGN_SCRIPT)
print("TGAT script:", TGAT_SCRIPT)

Configuration loaded.
PROJECT_ROOT: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain
MODEL_TYPE: TGN
DATA_TYPE: wikipedia
TGN script: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py
TGAT script: /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgat/train_supervised.py

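`run_cmd` lets the child process write straight into the notebook, so the training log disappears once the cell output is cleared or the kernel restarts. A sketch of a variant that also copies every line into a timestamped log file follows; the helper name `run_cmd_logged` and the log directory are hypothetical and not part of the original notebook.

```python
from __future__ import annotations

import subprocess
import sys
from datetime import datetime
from pathlib import Path


def run_cmd_logged(
    cmd: list[str],
    log_dir: Path,
    prefix: str,
    env: dict[str, str] | None = None,
) -> tuple[int, Path]:
    """Run `cmd`, echo its merged stdout/stderr, and copy every line to a log file."""
    log_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = log_dir / f"{prefix}-{stamp}.log"
    with open(log_path, "w", encoding="utf-8") as log, subprocess.Popen(
        cmd,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so tracebacks land in the log too
        text=True,
        bufsize=1,  # line-buffered reads on the notebook side
    ) as proc:
        for line in proc.stdout:  # yields lines until the child closes its output
            sys.stdout.write(line)  # keep the live notebook output
            log.write(line)
    # Popen's context manager waits for the child, so returncode is set here.
    return proc.returncode, log_path
```

A Python child buffers its stdout when it is piped, so lines may arrive in blocks. Passing `-u` to the interpreter (for example `[PYTHON_BIN, "-u", str(TGN_SCRIPT), *TGN_ARGS]`) makes the output appear as it is printed.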

In [2]:
# --- Launch training based on MODEL_TYPE ---
env = prepare_env()

if MODEL_TYPE.upper() == "TGN":
    cmd = build_cmd(PYTHON_BIN, TGN_SCRIPT, TGN_ARGS)
    code = run_cmd(cmd, env=env)
    if code == 0:
        print("[TGN] Done (exit code 0).")
    else:
        print("[TGN] Training failed — see output above.")

elif MODEL_TYPE.upper() == "TGAT":
    try:
        cmd = build_cmd(PYTHON_BIN, TGAT_SCRIPT, TGAT_ARGS)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"{e}\n\n"
            "TGAT template provided — please set TGAT_SCRIPT to your actual training entry point "
            "and adjust TGAT_ARGS to match your script."
        ) from e
    code = run_cmd(cmd, env=env)
    if code == 0:
        print("[TGAT] Done (exit code 0).")
    else:
        print("[TGAT] Training failed — see output above.")

else:
    raise ValueError("MODEL_TYPE must be 'TGN' or 'TGAT'")

$ python /Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py --data wikipedia --use_memory --prefix tgn-attn --n_runs 10


INFO:root:Namespace(data='wikipedia', bs=200, prefix='tgn-attn', n_degree=10, n_head=2, n_epoch=50, n_layer=1, lr=0.0001, patience=5, n_runs=10, drop_out=0.1, gpu=0, node_dim=100, time_dim=100, backprop_every=1, use_memory=True, embedding_module='graph_attention', message_function='identity', memory_updater='gru', aggregator='last', memory_update_at_end=False, message_dim=100, memory_dim=172, different_new_nodes=False, uniform=False, randomize_features=False, use_destination_embedding_in_message=False, use_source_embedding_in_message=False, dyrep=False)


The dataset has 157474 interactions, involving 9227 different nodes
The training dataset has 79202 interactions, involving 5904 different nodes
The validation dataset has 23621 interactions, involving 3256 different nodes
The test dataset has 23621 interactions, involving 3564 different nodes
The new node validation dataset has 11742 interactions, involving 2134 different nodes
The new node test dataset has 11765 interactions, involving 2482 different nodes
922 nodes were used for the inductive testing, i.e. are never seen during training


INFO:root:num of training instances: 79202
INFO:root:num of batches per epoch: 397
INFO:root:start 0 epoch
INFO:root:epoch: 0 took 128.07s
INFO:root:Epoch mean loss: 0.891759562402288
INFO:root:val auc: 0.8724800994207206, new node val auc: 0.8589361676204278
INFO:root:val ap: 0.8652861989403586, new node val ap: 0.8524048537663556
Traceback (most recent call last):
  File "/Users/juliawenkmann/Documents/CodingProjects/master_thesis/time_to_explain/submodules/models/tgn/train_self_supervised.py", line 302, in <module>
    torch.save(tgn.state_dict(), get_checkpoint_path(epoch))
  File "/Users/juliawenkmann/miniconda3/envs/graphs/lib/python3.11/site-packages/torch/serialization.py", line 849, in save
    with _open_zipfile_writer(f) as opened_zipfile:
         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/juliawenkmann/miniconda3/envs/graphs/lib/python3.11/site-packages/torch/serialization.py", line 716, in _open_zipfile_writer
    return container(name_or_buffer)
           ^^^^^^^^^^^^^^^^^^

[ERROR] process exited with code 1
[TGN] Training failed — see output above.
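The traceback above is cut off before the exception message, so it does not say why `torch.save` failed. It only shows that opening the checkpoint file for writing did not succeed. Before restarting a long run (the log shows about 128 s for one epoch, with up to 50 epochs per run and 10 runs), a cheap check is to confirm that the checkpoint directory can be created and written to from the working directory the subprocess will use. `subprocess.run` without `cwd=` inherits the notebook's working directory, so relative paths in the script resolve there. The sketch assumes checkpoints go to a relative `saved_checkpoints/` directory, as in the upstream TGN script; this is an assumption, so verify the path in the submodule. The helpers `check_writable_dir` and `preflight` are hypothetical.

```python
import tempfile
from pathlib import Path


def check_writable_dir(directory: Path) -> None:
    """Create `directory` if needed and prove a file can be written inside it.

    Raises OSError (e.g. PermissionError) with the underlying reason if not.
    """
    directory.mkdir(parents=True, exist_ok=True)
    # NamedTemporaryFile creates a real file in `directory` and deletes it on close.
    with tempfile.NamedTemporaryFile(dir=directory, prefix=".write-test-"):
        pass


def preflight(workdir: Path, checkpoint_subdir: str = "saved_checkpoints") -> Path:
    """Resolve the checkpoint dir relative to the child's working dir and check it."""
    target = (workdir / checkpoint_subdir).resolve()
    check_writable_dir(target)
    return target


# Usage: check the directory the child will actually see, e.g.
# preflight(Path.cwd())
# or pin the child's working directory explicitly and check that one:
# subprocess.run(cmd, env=env, cwd=workdir, check=False)
```

If the check passes but saving still fails, the missing exception message (rerun the script directly in a terminal to see the full traceback) is the next thing to look at.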
