In [31]:
from pathlib import Path
import sys

project_root = Path.cwd()
while project_root != project_root.parent and not (project_root / "code_lib").exists():
    project_root = project_root.parent

if not (project_root / "code_lib").exists():
    raise RuntimeError(f"Unable to locate 'code_lib' starting from {Path.cwd()}")

if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

print(f"Using project root: {project_root}")


Using project root: c:\Users\luket\Documents\Fork\graph_ml


In [76]:
from pathlib import Path
import sys
import optuna

import os, glob, re
import pandas as pd
import numpy as np

from pathlib import Path
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score


import torchmetrics.functional as tmf

from tqdm import tqdm
from pathlib import Path
from torch_geometric.data import Data
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn.models.tgn import LastNeighborLoader

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder, load_elliptic_data
)

from torch_geometric.nn.models.tgn import (
    TGNMemory,
    LastNeighborLoader,
    IdentityMessage,
    MeanAggregator,
)

from code_lib.temporal_edge_builder import TemporalEdgeBuilder

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)



cuda


In [57]:
from pathlib import Path
import sys
import torch

# 2. Library imports
from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
)
from code_lib.temporal_edge_builder import TemporalEdgeBuilder  # optional (see below)

# 3. Load nodes + raw edges
DATA_DIR = project_root / "elliptic_dataset"
nodes_df, raw_edges = load_elliptic_data(str(DATA_DIR), use_temporal_features=True)

# --- Option A: use raw edges (fastest, no extra weighting) --------------------
edges_df = raw_edges.copy()
use_edge_weights = False
edge_weight_col = None

# --- Option B (optional): build decayed edge weights for every timestep -------
# Uncomment if you really need weighted edges; this can take several minutes.
# edge_builder = TemporalEdgeBuilder(raw_edges, decay_lambda=0.1, verbose=True)
# edges_df = edge_builder.build_temporal_edge_sequence(
#     start_timestep=int(raw_edges["Time step"].min()),
#     end_timestep=int(raw_edges["Time step"].max()),
# )
# use_edge_weights = True
# edge_weight_col = "temporal_weight"

# 4. Instantiate the node builder with the full edge history
builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    include_class_as_feature=False,
    add_temporal_features=True,
    add_edge_weights=use_edge_weights,
    edge_weight_col=edge_weight_col,
    verbose=False,
)

# 5. Build snapshots + event stream
snapshots = builder.build_snapshot_sequence(return_node_metadata=True)
event_stream = builder.build_event_stream(dense=False, include_edge_attr=False)

print(f"Snapshots built: {len(snapshots)} graphs")
print(f"Event stream events total: {event_stream.src.numel()}")

# 6. Temporal splits that now span all timesteps
splits = builder.get_event_stream_split(
    train_timesteps=(5, 26),
    val_timesteps=(27, 31),
    test_timesteps=(32, 40),
    dense=False,
    include_edge_attr=False,
)

for name, data in splits.items():
    count = int(data.t.numel())
    min_t = int(data.t.min()) if count else "NA"
    max_t = int(data.t.max()) if count else "NA"
    print(f"{name.capitalize():5s} -> events: {count:7d}, t-range: [{min_t}, {max_t}]")


Snapshots built: 49 graphs
Event stream events total: 4170754
Train -> events: 1547601, t-range: [5, 26]
Val   -> events:  121047, t-range: [27, 31]
Test  -> events:  765003, t-range: [32, 40]


In [86]:
print(torch.unique(splits["train"].y))

tensor([0., 1., 2.])


In [59]:
def make_loader(data, batch_size, shuffle):
    return TemporalDataLoader(data, batch_size=batch_size, shuffle=shuffle)

In [61]:
class MLPMessage(nn.Module):
    def __init__(self, raw_msg_dim, hidden_dim, out_dim, dropout, memory_dim, time_dim):
        super().__init__()
        in_dim = 2 * memory_dim + raw_msg_dim + time_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )
        self.out_channels = out_dim

    def forward(self, z_src, z_dst, raw_msg, t_enc):
        h = torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)
        return self.net(h)


class LinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, z_src, z_dst):
        h = torch.cat([z_src, z_dst], dim=-1)
        h = self.dropout(F.relu(self.lin1(h)))
        return self.lin2(h).view(-1)

In [62]:
from dataclasses import dataclass

@dataclass
class TrialConfig:
    memory_dim: int
    time_dim: int
    msg_hidden: int
    msg_out: int
    decoder_hidden: int
    neighbor_size: int
    batch_size: int
    lr: float
    weight_decay: float
    dropout: float
    epochs: int

In [63]:
def build_modules(config: TrialConfig, raw_msg_dim: int, num_nodes: int):
    message_module = MLPMessage(
        raw_msg_dim=raw_msg_dim,
        hidden_dim=config.msg_hidden,
        out_dim=config.msg_out,
        dropout=config.dropout,
        memory_dim=config.memory_dim,
        time_dim=config.time_dim,
    ).to(DEVICE)


    memory = TGNMemory(
        num_nodes=num_nodes,
        raw_msg_dim=raw_msg_dim,  # event_stream.msg.size(-1)
        memory_dim=config.memory_dim,
        time_dim=config.time_dim,
        message_module=message_module,
        aggregator_module=MeanAggregator(),
    ).to(DEVICE)
    
    neighbor_loader = LastNeighborLoader(
        num_nodes=num_nodes,
        size=config.neighbor_size,
        device=DEVICE,
    )

    decoder = LinkPredictor(
        in_dim=config.memory_dim * 2,
        hidden_dim=config.decoder_hidden,
        dropout=config.dropout,
    ).to(DEVICE)

    params = list(memory.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=config.lr, weight_decay=config.weight_decay)

    return memory, neighbor_loader, decoder, optimizer

In [80]:
from sklearn.metrics import average_precision_score, f1_score

criterion = nn.BCEWithLogitsLoss()

def _gather_embeddings(memory, batch, assoc_buffer):
    n_id = torch.cat([batch.src, batch.dst]).unique()
    z_mem, _ = memory(n_id)  # shape [|n_id|, memory_dim]

    assoc_buffer[n_id] = torch.arange(n_id.size(0), device=assoc_buffer.device)
    src_idx = assoc_buffer[batch.src]
    dst_idx = assoc_buffer[batch.dst]
    assoc_buffer[n_id] = -1  # reset mapping for next batch

    z_src = z_mem[src_idx]
    z_dst = z_mem[dst_idx]
    return z_src, z_dst

def run_epoch(loader, memory, neighbor_loader, decoder, optimizer=None):
    is_train = optimizer is not None
    memory.train(is_train)
    decoder.train(is_train)

    memory.reset_state()
    neighbor_loader.reset_state()
    assoc_buffer = torch.full(
        (memory.num_nodes,), -1, device=DEVICE, dtype=torch.long
    )

    total_loss = 0.0
    total_events = 0

    for batch in loader:
        batch = batch.to(DEVICE)

        if is_train:
            optimizer.zero_grad()

        z_src, z_dst = _gather_embeddings(memory, batch, assoc_buffer)
        logits = decoder(z_src, z_dst)
        loss = criterion(logits, batch.y.float())

        if is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(memory.parameters(), 1.0)
            optimizer.step()
            memory.detach()

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        num_events = batch.src.size(0)
        total_loss += float(loss.item()) * num_events
        total_events += num_events

    return total_loss / max(total_events, 1)

@torch.no_grad()
def evaluate(loader, memory, neighbor_loader, decoder):
    memory.eval()
    decoder.eval()

    memory.reset_state()
    neighbor_loader.reset_state()
    assoc_buffer = torch.full(
        (memory.num_nodes,), -1, device=DEVICE, dtype=torch.long
    )

    preds, targets = [], []

    for batch in loader:
        batch = batch.to(DEVICE)

        z_src, z_dst = _gather_embeddings(memory, batch, assoc_buffer)
        logits = decoder(z_src, z_dst)
        preds.append(torch.sigmoid(logits).cpu())
        targets.append(batch.y.float().cpu())

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)
        memory.detach()

    if not preds:
        return 0.0, 0.0

    y_score = torch.cat(preds)
    y_true = torch.cat(targets).long()

    ap = tmf.average_precision(y_score, y_true, task="binary")
    f1 = tmf.f1_score((y_score > 0.5).long(), y_true, task="binary")
    return ap, f1


In [81]:
raw_msg_dim = event_stream.msg.size(-1)
num_nodes = builder.nodes_df["address"].nunique()

def suggest_config(trial) -> TrialConfig:
    return TrialConfig(
        memory_dim=trial.suggest_categorical("memory_dim", [32, 64, 128]),
        time_dim=trial.suggest_categorical("time_dim", [4, 8, 16]),
        msg_hidden=trial.suggest_categorical("msg_hidden", [32, 64, 128]),
        msg_out=trial.suggest_categorical("msg_out", [16, 32, 64]),
        decoder_hidden=trial.suggest_categorical("decoder_hidden", [32, 64, 128]),
        neighbor_size=trial.suggest_categorical("neighbor_size", [10, 20, 50]),
        batch_size=trial.suggest_categorical("batch_size", [2048, 4096, 8192]),
        lr=trial.suggest_float("lr", 1e-4, 5e-3, log=True),
        weight_decay=trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True),
        dropout=trial.suggest_float("dropout", 0.0, 0.4),
        epochs=trial.suggest_int("epochs", 5, 15),
    )

train_loader_cache: dict[int, TemporalDataLoader] = {}
val_loader_cache: dict[int, TemporalDataLoader] = {}

def get_loaders(batch_size):
    if batch_size not in train_loader_cache:
        train_loader_cache[batch_size] = make_loader(splits["train"], batch_size, shuffle=True)
        val_loader_cache[batch_size] = make_loader(splits["val"], batch_size, shuffle=False)
    return train_loader_cache[batch_size], val_loader_cache[batch_size]

def objective(trial):
    config = suggest_config(trial)
    train_loader, val_loader = get_loaders(config.batch_size)

    memory, neighbor_loader, decoder, optimizer = build_modules(
        config, raw_msg_dim, num_nodes
    )

    best_ap = 0.0
    for epoch in range(config.epochs):
        run_epoch(train_loader, memory, neighbor_loader, decoder, optimizer)
        val_ap, _ = evaluate(val_loader, memory, neighbor_loader, decoder)
        trial.report(val_ap, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        best_ap = max(best_ap, val_ap)

    return best_ap

In [82]:
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)  # adjust trial count as needed

print("Best trial:", study.best_trial.value)
print("Best params:", study.best_trial.params)

[I 2025-11-08 10:53:04,399] A new study created in memory with name: no-name-40bb3244-e757-47ca-afaa-82c073713ff4
[W 2025-11-08 10:55:22,716] Trial 0 failed with parameters: {'memory_dim': 32, 'time_dim': 8, 'msg_hidden': 32, 'msg_out': 16, 'decoder_hidden': 64, 'neighbor_size': 10, 'batch_size': 2048, 'lr': 0.001642281355380936, 'weight_decay': 0.00036853802741068496, 'dropout': 0.041180774022345594, 'epochs': 6} because of the following error: RuntimeError('Detected the following values in `target`: tensor([0, 1, 2]) but expected only the following values [0, 1].').
Traceback (most recent call last):
  File "c:\Users\luket\Documents\Fork\graph_ml\.venv\Lib\site-packages\optuna\study\_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "C:\Users\luket\AppData\Local\Temp\ipykernel_29292\253380575.py", line 39, in objective
    val_ap, _ = evaluate(val_loader, memory, neighbor_loader, decoder)
                ^^^^^^^^^^^^^^^^^

RuntimeError: Detected the following values in `target`: tensor([0, 1, 2]) but expected only the following values [0, 1].