In [82]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [102]:
import torch
import pandas as pd
from PopSynthesis.Methods.GNN_activity.model import convert_to_temporal_data, GraphAttentionEmbedding, LinkPredictor
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator, LastNeighborLoader
from torch_geometric.loader import TemporalDataLoader

In [106]:
# Load the pickled training graph (a PyG HeteroData; converted to TemporalData in the next cell).
# torch >= 2.6 defaults to weights_only=True, which refuses arbitrary pickled objects such as a
# HeteroData graph, so opt out explicitly.
# SECURITY: weights_only=False unpickles the file (arbitrary code execution) — only load trusted files.
data = torch.load("data/train_graph.pt", weights_only=False)

In [107]:
# Convert the loaded PyG HeteroData graph into a single TemporalData event stream
# (per-event src, dst, t and msg) — the format consumed by TGNMemory and TemporalDataLoader below.
train_data = convert_to_temporal_data(data)

In [108]:
# === Model Configuration ===
# Seed PyTorch's RNGs (CPU and CUDA) so weight initialisation and later random draws
# (e.g. negative sampling in the DataLoader cell) are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Memory, time-encoding and node-embedding widths share one size.
memory_dim = time_dim = embedding_dim = 100
raw_msg_dim = train_data.msg.size(-1)  # width of each event's message vector

memory = TGNMemory(
    num_nodes=train_data.num_nodes,
    raw_msg_dim=raw_msg_dim,
    memory_dim=memory_dim,
    time_dim=time_dim,
    message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# The GNN is handed the memory's own time encoder, so that module's parameters are shared
# (presumably registered under both `memory` and `gnn` — hence the de-duplication below).
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=raw_msg_dim,
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# === Optimizer & Loss ===
# Parameters must be de-duplicated but also kept in a deterministic order: torch.optim matches
# saved optimizer state to parameters by position, and a set's iteration order depends on
# object ids, so it differs from run to run. dict.fromkeys de-duplicates while keeping
# first-seen order.
params = list(dict.fromkeys([*memory.parameters(), *gnn.parameters(), *link_pred.parameters()]))
optimizer = torch.optim.Adam(params, lr=0.0001)
# One BCE-with-logits loss per prediction head: edge existence and joint activity.
criterion_edge = torch.nn.BCEWithLogitsLoss()
criterion_joint = torch.nn.BCEWithLogitsLoss()

In [109]:
# === DataLoader ===
# Chronological mini-batches of 200 events, with negative destinations sampled at a 1:1
# ratio to the positive events (exposed as batch.neg_dst in train()).
train_loader = TemporalDataLoader(train_data, batch_size=200, neg_sampling_ratio=1.0)

# Stores up to 10 most recent neighbours per node; emptied at the start of every train() call.
neighbor_loader = LastNeighborLoader(train_data.num_nodes, size=10, device=device)

# === Training State Variables ===
# Scratch lookup table global node id -> local row index, refilled for each batch in train().
assoc = torch.empty(train_data.num_nodes, dtype=torch.long, device=device)

In [116]:
def train() -> float:
    """Run one training epoch over ``train_loader`` and return the mean per-event loss.

    Relies on notebook-level globals: ``memory``, ``gnn``, ``link_pred``, ``optimizer``,
    ``criterion_edge``, ``criterion_joint``, ``train_loader``, ``neighbor_loader``,
    ``assoc``, ``train_data`` and ``device``. Memory and neighbour state are reset first,
    so each call replays the whole event stream starting from an empty graph.

    Returns:
        float: sum over batches of ``(loss_edge + loss_joint) * batch.num_events``,
        divided by ``train_data.num_events``.
    """
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Fresh memory
    neighbor_loader.reset_state()  # Empty graph at start

    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # Neighborhood sampling
        # Expand the batch's node ids with their stored recent neighbours, then point
        # assoc[global_id] at that node's row in the local n_id / z tensors.
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Read node memory and refine it with attention over the sampled neighbour edges.
        # The current batch's own events are only written later (memory.update_state below),
        # so these embeddings are computed without them.
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, train_data.t[e_id].to(device), train_data.msg[e_id].to(device))

        # Positive & Negative samples
        # link_pred yields two logits per (src, dst) pair: edge existence and joint activity.
        pos_edge_pred, pos_joint_pred = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_edge_pred, neg_joint_pred = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        # Labels: 1 for observed edges, 0 for sampled negative destinations.
        edge_labels = torch.cat([torch.ones_like(pos_edge_pred), torch.zeros_like(neg_edge_pred)])
        edge_preds = torch.cat([pos_edge_pred, neg_edge_pred])

        # Extract `joint_activity` from graph data
        # NOTE(review): assumes msg column 0 holds a target in [0, 1] — confirm against
        # convert_to_temporal_data. The same column is also part of the message fed to the
        # model for earlier events (update_state / neighbour msgs); confirm that is intended.
        joint_labels = batch.msg[:, 0].view(-1, 1) # Joint is first
        # Loss Calculation
        loss_edge = criterion_edge(edge_preds, edge_labels)
        loss_joint = criterion_joint(pos_joint_pred, joint_labels)  # Apply only on positive edges

        loss = loss_edge + loss_joint  # Total loss (unweighted sum of both heads)

        # Memory Update
        # Write this batch's events into memory and the neighbour store only after the
        # predictions above, so a batch never conditions on its own events.
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        # Backpropagation
        loss.backward()
        optimizer.step()
        memory.detach()  # detach memory from the autograd graph before the next batch
        total_loss += float(loss) * batch.num_events  # weight by batch size for a per-event mean

    return total_loss / train_data.num_events


In [118]:
# === Run Training ===
NUM_EPOCHS = 100  # full passes; each train() call replays the event stream from empty memory

# Keep every epoch's mean loss so the learning curve can be inspected/plotted afterwards
# instead of re-reading the printed log.
loss_history = []
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    loss_history.append(loss)
    print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}")

Epoch: 01, Loss: 1.1940
Epoch: 02, Loss: 1.1913
Epoch: 03, Loss: 1.1885
Epoch: 04, Loss: 1.1858
Epoch: 05, Loss: 1.1830
Epoch: 06, Loss: 1.1801
Epoch: 07, Loss: 1.1773
Epoch: 08, Loss: 1.1744
Epoch: 09, Loss: 1.1716
Epoch: 10, Loss: 1.1688
Epoch: 11, Loss: 1.1659
Epoch: 12, Loss: 1.1631
Epoch: 13, Loss: 1.1602
Epoch: 14, Loss: 1.1573
Epoch: 15, Loss: 1.1544
Epoch: 16, Loss: 1.1515
Epoch: 17, Loss: 1.1487
Epoch: 18, Loss: 1.1458
Epoch: 19, Loss: 1.1429
Epoch: 20, Loss: 1.1399
Epoch: 21, Loss: 1.1370
Epoch: 22, Loss: 1.1341
Epoch: 23, Loss: 1.1311
Epoch: 24, Loss: 1.1281
Epoch: 25, Loss: 1.1251
Epoch: 26, Loss: 1.1220
Epoch: 27, Loss: 1.1190
Epoch: 28, Loss: 1.1158
Epoch: 29, Loss: 1.1127
Epoch: 30, Loss: 1.1095
Epoch: 31, Loss: 1.1063
Epoch: 32, Loss: 1.1031
Epoch: 33, Loss: 1.0998
Epoch: 34, Loss: 1.0965
Epoch: 35, Loss: 1.0932
Epoch: 36, Loss: 1.0899
Epoch: 37, Loss: 1.0865
Epoch: 38, Loss: 1.0831
Epoch: 39, Loss: 1.0797
Epoch: 40, Loss: 1.0762
Epoch: 41, Loss: 1.0727
Epoch: 42, Loss: