In [None]:
!pip install torch_geometric torch_geometric_temporal

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
# METRLA dataset loader using torch_geometric_temporal
from torch_geometric_temporal.dataset import METRLADatasetLoader
from torch_geometric.data import TemporalData

def load_metrla_temporal_data():
    """Convert the METR-LA sensor graph into a ``TemporalData`` event list.

    Only the first snapshot of the torch_geometric_temporal dataset is used.
    Every edge of its ``edge_index`` becomes one event, and all events share
    the float timestamp 0.

    Edge features are taken from ``snapshot.edge_weight`` if present, otherwise
    from ``snapshot.edge_attr``. If neither exists, every event gets a constant
    1.0 feature. ``msg`` is always 2-D: ``(num_edges, feat_dim)``.

    Returns:
        (TemporalData, int): the event data and the number of nodes
        (``snapshot.x.shape[0]``).
    """
    # Imported locally so this cell does not depend on a later cell having
    # imported torch into the notebook namespace.
    import torch

    loader = METRLADatasetLoader()
    dataset = loader.get_dataset(num_timesteps_in=1, num_timesteps_out=1)
    snapshot = next(iter(dataset))

    edge_index = snapshot.edge_index.clone().detach().long()
    src, dst = edge_index[0], edge_index[1]
    num_events = src.size(0)
    # Only the first snapshot's edges are used, so there is no ordering in time.
    t = torch.zeros(num_events)

    edge_weight = getattr(snapshot, 'edge_weight', None)
    edge_attr = getattr(snapshot, 'edge_attr', None)
    if edge_weight is not None:
        msg = edge_weight.clone().detach().float().view(-1, 1)
    elif edge_attr is not None:
        msg = edge_attr.clone().detach().float()
        if msg.dim() == 1:
            msg = msg.view(-1, 1)
    else:
        msg = torch.ones(num_events, 1, dtype=torch.float32)

    data = TemporalData(src=src, dst=dst, t=t, msg=msg)
    num_nodes = snapshot.x.shape[0]
    return data, num_nodes


In [None]:
# HTGN++: Hierarchical Temporal Graph Network with Bayesian Embeddings and Learnable Time Kernels
# Dataset: general TemporalData input (e.g., METRLA)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.data import TemporalData, DataLoader

# === Learnable Temporal Encoding ===
class TemporalEncoding(nn.Module):
    """Learnable sinusoidal time encoding reduced to one scalar per timestamp.

    Returns ``sum_k weights[k] * sin(freqs[k] * delta_t)``, which has the same
    shape as ``delta_t``. Both vectors are trainable and initialised from N(0, 1).
    """

    def __init__(self, num_kernels=8):
        super().__init__()
        self.freqs = nn.Parameter(torch.randn(num_kernels))
        self.weights = nn.Parameter(torch.randn(num_kernels))

    def forward(self, delta_t):
        # (...,) -> (..., 1) so that each timestamp is paired with every kernel.
        delta_t = delta_t.unsqueeze(-1)
        return (self.weights * torch.sin(self.freqs * delta_t)).sum(-1)

# === HTGN++ Model with PyG Integration ===
class HTGN(nn.Module):
    """HTGN++ link predictor with two-timescale node memory and Bayesian embeddings.

    Each node keeps two memories:
    - a short-term memory, updated by a GRU from event messages;
    - a long-term memory, an exponential moving average of the short-term one.

    Node embeddings are sampled with the reparameterisation trick from a
    Gaussian. Its mean and log-variance are predicted from the concatenated
    memories. An MLP then scores (source, destination) embedding pairs.

    Args:
        num_nodes: number of nodes (rows of each memory table).
        node_dim: size of each memory vector.
        msg_dim: hidden width of the message MLP.
        embed_dim: size of the sampled embedding.
        raw_msg_dim: width of each event feature vector ``data.msg[i]``.
            Default 4.
        long_decay: EMA factor of the long-term memory (default 0.99).
        kl_weight: weight of the KL term in the loss (default 1e-3).
    """

    def __init__(self, num_nodes, node_dim=32, msg_dim=64, embed_dim=16,
                 raw_msg_dim=4, long_decay=0.99, kl_weight=1e-3):
        super().__init__()
        self.num_nodes = num_nodes
        self.long_decay = long_decay
        self.kl_weight = kl_weight

        # Memory is running state, not learned weights. Keeping it in buffers
        # keeps it out of the optimizer and allows plain in-place writes.
        self.register_buffer("memory_short", torch.zeros(num_nodes, node_dim))
        self.register_buffer("memory_long", torch.zeros(num_nodes, node_dim))

        self.temporal_enc = TemporalEncoding()
        # Input is the raw event features plus the 1 scalar time encoding.
        self.msg_net = nn.Sequential(
            nn.Linear(raw_msg_dim + 1, msg_dim), nn.ReLU(), nn.Linear(msg_dim, node_dim)
        )
        self.update_gru = nn.GRUCell(node_dim, node_dim)

        self.mu_net = nn.Linear(node_dim * 2, embed_dim)
        self.logvar_net = nn.Linear(node_dim * 2, embed_dim)
        self.pred_net = nn.Sequential(nn.Linear(embed_dim * 2, 32), nn.ReLU(), nn.Linear(32, 1))

    def reset_memory(self):
        """Zero both memory tables (call before each pass over the events)."""
        self.memory_short.zero_()
        self.memory_long.zero_()

    def _memory_of(self, node):
        """Return the concatenated [short || long] memory of one node as a new tensor."""
        return torch.cat([self.memory_short[node], self.memory_long[node]])

    def _sample(self, memory):
        """Draw a reparameterised Gaussian sample; returns (z, mu, logvar)."""
        mu = self.mu_net(memory)
        logvar = self.logvar_net(memory)
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(std) * std
        return z, mu, logvar

    def forward(self, data, return_scores=False):
        """Process the events in order, updating memory and scoring each link.

        For each event ``i``:
        1. Update the source memory from ``[data.msg[i] || time_enc(data.t[i])]``.
        2. Score the source embedding against the true destination (label 1).
        3. Score it against one uniformly drawn node (label 0). The drawn node
           may coincide with the true destination.

        Args:
            data: object with 1-D ``src``, ``dst``, ``t`` and a 2-D ``msg``
                of shape (num_events, raw_msg_dim).
            return_scores: if True, also return the labels and sigmoid scores.

        Returns:
            ``(loss, mus, logvars)``: the mean per-event loss, plus the
            source-embedding means and log-variances as arrays of shape
            (num_events, embed_dim).
            With ``return_scores=True``, additionally ``labels`` and
            ``scores``, arrays of shape (2 * num_events,) that alternate
            positive/negative.

        Raises:
            ValueError: if ``data`` contains no events.
        """
        num_events = data.t.size(0)
        if num_events == 0:
            raise ValueError("HTGN.forward received no events")

        losses, mus, logvars, labels, scores = [], [], [], [], []
        for i in range(num_events):
            src, tgt = data.src[i], data.dst[i]
            time_embed = self.temporal_enc(data.t[i].unsqueeze(0))
            msg = self.msg_net(torch.cat([data.msg[i], time_embed]))

            # Update the source memory inside the autograd graph so that
            # msg_net, temporal_enc and update_gru receive gradients. Reads are
            # cloned because the buffers are overwritten in place just below.
            h_short = self.memory_short[src].clone()
            h_long = self.memory_long[src].clone()
            new_short = self.update_gru(msg.unsqueeze(0), h_short.unsqueeze(0)).squeeze(0)
            new_long = self.long_decay * h_long + (1.0 - self.long_decay) * new_short
            # Persist detached copies so gradients never flow across events.
            with torch.no_grad():
                self.memory_short[src] = new_short.detach()
                self.memory_long[src] = new_long.detach()

            z_src, mu, logvar = self._sample(torch.cat([new_short, new_long]))
            z_tgt, _, _ = self._sample(self._memory_of(tgt))
            neg = torch.randint(0, self.num_nodes, (), device=self.memory_short.device)
            z_neg, _, _ = self._sample(self._memory_of(neg))

            pos_logit = self.pred_net(torch.cat([z_src, z_tgt]))
            neg_logit = self.pred_net(torch.cat([z_src, z_neg]))
            loss_recon = (
                F.binary_cross_entropy_with_logits(pos_logit, torch.ones_like(pos_logit))
                + F.binary_cross_entropy_with_logits(neg_logit, torch.zeros_like(neg_logit))
            )
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            losses.append(loss_recon + self.kl_weight * kl_loss)

            mus.append(mu.detach().cpu().numpy())
            logvars.append(logvar.detach().cpu().numpy())
            if return_scores:
                labels.extend((1, 0))
                scores.extend((torch.sigmoid(pos_logit).item(),
                               torch.sigmoid(neg_logit).item()))

        loss = torch.stack(losses).mean()
        if return_scores:
            return loss, np.array(mus), np.array(logvars), np.array(labels), np.array(scores)
        return loss, np.array(mus), np.array(logvars)

# === Training Setup ===
def train_htgn(data, num_nodes, epochs=10, lr=0.001):
    """Train HTGN++ full-batch on ``data``, plot diagnostics and return link metrics.

    Each epoch does three things:
    1. reset the node memory;
    2. run one forward pass over all events;
    3. take one Adam step on the mean event loss.

    ROC-AUC and F1 are computed from the positive and sampled-negative scores
    of the final epoch. These are in-sample metrics on the training events.

    Args:
        data: TemporalData-like object with ``src``, ``dst``, ``t`` and a 2-D ``msg``.
        num_nodes: number of nodes (size of the memory tables).
        epochs: number of full passes over the events.
        lr: Adam learning rate.

    Returns:
        (fpr, tpr, auc, f1) — the same tuple as the TGN/TGAT trainers.

    Raises:
        ValueError: if ``epochs < 1`` (there would be no scores to evaluate).
    """
    # Local import: this cell's own imports do not include sklearn.metrics.
    from sklearn.metrics import roc_auc_score, f1_score, roc_curve

    if epochs < 1:
        raise ValueError("train_htgn needs at least one epoch")

    model = HTGN(num_nodes, raw_msg_dim=data.msg.size(-1))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(epochs):
        # Start every pass from empty memory so epochs are comparable.
        model.reset_memory()
        optimizer.zero_grad()
        loss, mus, logvars, y_true, y_score = model(data, return_scores=True)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

    # Visualization
    plt.plot(losses)
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    # TSNE requires perplexity < n_samples; `mus` has one row per event.
    perplexity = min(30.0, max(1.0, mus.shape[0] - 1.0))
    embeddings_2d = TSNE(n_components=2, perplexity=perplexity).fit_transform(mus)
    plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.6)
    plt.title("t-SNE of Source Embeddings (one point per event, last epoch)")
    plt.show()

    plt.hist(np.exp(logvars).flatten(), bins=30)
    plt.title("Bayesian Embedding Variance")
    plt.xlabel("Variance")
    plt.ylabel("Count")
    plt.show()

    auc = roc_auc_score(y_true, y_score)
    f1 = f1_score(y_true, (y_score > 0.5).astype(int))
    fpr, tpr, _ = roc_curve(y_true, y_score)
    print(f"HTGN ROC-AUC: {auc:.4f}, F1-score: {f1:.4f}")
    return fpr, tpr, auc, f1




In [None]:
# TGN baseline model for temporal link prediction

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, roc_curve
from torch_geometric.nn import TGNMemory
from torch_geometric.data import TemporalData
from torch_geometric.nn.models.tgn import LastAggregator
from torch_geometric.utils import scatter_argmax

# === Message Module Wrapper with correct TGNMemory input signature ===
class MessageNet(nn.Module):
    """Message module for PyG's ``TGNMemory``.

    ``TGNMemory`` calls its message module positionally as
    ``module(z_src, z_dst, raw_msg, t_enc)``; PyG's ``IdentityMessage`` uses
    the same signature. This module ignores the two memory tensors. The
    message is ``MLP([raw_msg || t_enc])``, so ``input_dim`` must equal
    ``raw_msg_dim + time_dim``.

    Args:
        input_dim: ``raw_msg_dim + time_dim``.
        output_dim: message size; also exposed as ``out_channels``, which
            ``TGNMemory`` reads to size its GRU.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        )
        self.out_channels = output_dim

    def forward(self, z_src, z_dst, raw_msg, t_enc):
        """Return one ``out_channels``-dim message per row of ``raw_msg``.

        Args:
            z_src, z_dst: memories of the message endpoints (unused).
            raw_msg: (num_msgs, raw_msg_dim) event features. A 1-D tensor is
                treated as raw_msg_dim == 1.
            t_enc: (num_msgs, time_dim) encoded relative time. A 1-D tensor is
                treated as time_dim == 1.

        Raises:
            ValueError: if ``raw_msg`` and ``t_enc`` disagree on num_msgs.
        """
        if raw_msg.size(0) == 0:
            return raw_msg.new_zeros((0, self.out_channels))

        if raw_msg.dim() == 1:
            raw_msg = raw_msg.unsqueeze(-1)
        if t_enc.dim() == 1:
            t_enc = t_enc.unsqueeze(-1)
        if t_enc.size(0) != raw_msg.size(0):
            raise ValueError(
                f"raw_msg has {raw_msg.size(0)} rows but t_enc has {t_enc.size(0)}"
            )

        x = torch.cat([raw_msg, t_enc.to(raw_msg.dtype)], dim=-1)
        return self.net(x)

# === TGN Model ===

class SafeLastAggregator(nn.Module):
    """Aggregate messages per node by keeping the one with the latest timestamp.

    This uses the same masking as PyG's ``LastAggregator``. Argmax positions
    that fall outside ``msg`` are treated as "no message for this node", and
    those rows stay zero.
    """

    def forward(self, msg, index, t, dim_size):
        # For every target node: the position (into msg) of its newest message.
        newest = scatter_argmax(t, index, dim=0, dim_size=dim_size)
        valid = newest < msg.size(0)
        aggregated = msg.new_zeros((dim_size, msg.size(-1)))
        aggregated[valid] = msg[newest[valid]]
        return aggregated

class TGNLinkPredictor(nn.Module):
    """TGN baseline: a ``TGNMemory`` per node plus an MLP edge scorer.

    Args:
        num_nodes: number of nodes tracked by the memory.
        msg_dim: width of the raw event features.
        emb_dim: memory size, also used as the scorer's hidden width.
    """

    def __init__(self, num_nodes, msg_dim, emb_dim=32):
        super().__init__()
        # The message module sees raw features plus a 1-dim time encoding
        # (time_dim=1 below), hence msg_dim + 1 inputs.
        message_module = MessageNet(msg_dim + 1, emb_dim)
        aggregator_module = SafeLastAggregator()
        self.memory = TGNMemory(
            num_nodes=num_nodes,
            raw_msg_dim=msg_dim,
            memory_dim=emb_dim,
            time_dim=1,
            message_module=message_module,
            aggregator_module=aggregator_module,
        )

        hidden = emb_dim
        self.edge_predictor = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, src, dst):
        """Score (src, dst) pairs from the stored memory; returns raw logits."""
        state = self.memory.memory
        pair = torch.cat((state[src], state[dst]), dim=-1)
        return self.edge_predictor(pair)

    def update_memory(self, src, dst, t, msg):
        """Pass a batch of events to ``TGNMemory.update_state``."""
        self.memory.update_state(src, dst, t, msg)

# === Training and Evaluation ===
def train_tgn(data, num_nodes, epochs=10):
    """Train the TGN baseline event by event and report link-prediction metrics.

    Follows the training loop of PyG's TGN example. For each event:
    1. Read the memories of (src, dst, negative) through ``TGNMemory.forward``,
       which keeps them differentiable in training mode.
    2. Score the link against the true destination (label 1) and against a
       uniformly sampled node (label 0).
    3. Write the true event to memory.
    4. Take the optimizer step, then detach the memory.

    ROC-AUC and F1 use the final epoch's scores only. These are in-sample
    metrics on the training events.

    Args:
        data: TemporalData with ``src``, ``dst``, ``t`` and a 2-D ``msg``.
        num_nodes: number of nodes.
        epochs: passes over the event stream; memory is reset at each pass.

    Returns:
        (fpr, tpr, auc, f1)

    Raises:
        ValueError: if ``data`` has no events or ``epochs < 1``.
    """
    num_events = data.t.size(0)
    if num_events == 0 or epochs < 1:
        raise ValueError("train_tgn needs at least one event and one epoch")

    model = TGNLinkPredictor(num_nodes=num_nodes, msg_dim=data.msg.size(-1))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    model.train()

    for epoch in range(epochs):
        model.memory.reset_state()
        loss_total = 0
        # Only the last epoch's scores are kept for evaluation.
        y_true, y_score = [], []

        for i in range(num_events):
            src = data.src[i:i + 1]
            pos_dst = data.dst[i:i + 1]
            # One negative per event. It may occasionally equal the true dst.
            neg_dst = torch.randint(0, num_nodes, (1,), dtype=src.dtype, device=src.device)

            optimizer.zero_grad()
            # Differentiable memories of every node involved in this event.
            n_id, local = torch.cat([src, pos_dst, neg_dst]).unique(return_inverse=True)
            z, _ = model.memory(n_id)
            z_src, z_pos, z_neg = z[local[0:1]], z[local[1:2]], z[local[2:3]]
            pos_out = model.edge_predictor(torch.cat([z_src, z_pos], dim=-1)).view(-1)
            neg_out = model.edge_predictor(torch.cat([z_src, z_neg], dim=-1)).view(-1)
            loss = (
                F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out))
                + F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out))
            )

            # Store the ground-truth event only after it has been scored.
            model.update_memory(src, pos_dst, data.t[i:i + 1], data.msg[i:i + 1])
            loss.backward()
            optimizer.step()
            # Cut the graph so the next event does not backprop into this one.
            model.memory.detach()

            loss_total += loss.item()
            y_true.extend((1, 0))
            y_score.extend((torch.sigmoid(pos_out).item(), torch.sigmoid(neg_out).item()))

        print(f"Epoch {epoch+1}, Loss: {loss_total:.4f}")

    # Evaluation (both classes are present thanks to the sampled negatives)
    auc = roc_auc_score(y_true, y_score)
    y_pred_bin = [1 if p > 0.5 else 0 for p in y_score]
    f1 = f1_score(y_true, y_pred_bin)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    print(f"TGN ROC-AUC: {auc:.4f}, F1-score: {f1:.4f}")

    return fpr, tpr, auc, f1

# Example: fpr, tpr, auc, f1 = train_tgn(data, num_nodes)


In [None]:
# TGAT baseline for temporal graph data
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, f1_score, roc_curve
from torch_geometric.data import TemporalData

# === TGAT Time Encoding ===
class Time2Vec(nn.Module):
    """Time encoding with ``dim`` periodic and ``dim`` linear features.

    For ``t`` of shape (batch, 1), returns ``[sin(freqs * t) || linear(t)]``
    with shape (batch, 2 * dim).
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(1, dim)
        self.freqs = nn.Parameter(torch.randn(dim))

    def forward(self, t):
        # t: (batch, 1); freqs broadcasts over the last axis.
        return torch.cat([torch.sin(self.freqs * t), self.linear(t)], dim=-1)

# === TGAT Simplified Architecture ===
class TGATLinkPredictor(nn.Module):
    """Simplified TGAT-style link scorer (no neighbourhood attention).

    Each endpoint representation is
    ``fc_msg([msg || Time2Vec(t)]) + node_embed(node)``. The event term is
    shared by both endpoints. The concatenated pair is scored by an MLP that
    outputs a logit.

    Args:
        num_nodes: number of nodes (rows of the node-embedding table).
        feat_dim: width of the event features ``msg``.
        time_dim: ``Time2Vec`` size; the encoding is ``2 * time_dim`` wide.
        embed_dim: hidden / embedding size.
    """

    def __init__(self, num_nodes, feat_dim, time_dim=8, embed_dim=32):
        super().__init__()
        self.node_embed = nn.Embedding(num_nodes, embed_dim)
        self.time_enc = Time2Vec(time_dim)
        self.fc_msg = nn.Linear(feat_dim + time_dim * 2, embed_dim)

        self.pred_net = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 1)
        )

    def forward(self, src, dst, t, msg):
        """Return link logits for one event or for a batch of events.

        - Single event: ``src``/``dst`` are 0-dim, ``t`` has one element and
          ``msg`` is 1-D; the output has shape (1,).
        - Batch: ``src``/``dst``/``t`` have length B and ``msg`` is
          (B, feat_dim); the output has shape (B, 1).

        Integer timestamps are cast to the encoder's float dtype.
        """
        t_feat = self.time_enc(t.reshape(-1, 1).to(self.time_enc.freqs.dtype))
        if msg.dim() == 1:
            # Single event: drop the batch axis to match the 1-D message.
            t_feat = t_feat.squeeze(0)
        # The event term is identical for both endpoints, so compute it once.
        event_feat = self.fc_msg(torch.cat([msg, t_feat], dim=-1))
        feat_src = event_feat + self.node_embed(src)
        feat_dst = event_feat + self.node_embed(dst)
        return self.pred_net(torch.cat([feat_src, feat_dst], dim=-1))

# === Training and Evaluation ===
def train_tgat(data, num_nodes, epochs=10):
    """Train the TGAT baseline event by event and report link-prediction metrics.

    Each event is scored twice:
    - against its true destination (label 1);
    - against a uniformly sampled node (label 0).

    The negative reuses the event's ``msg`` and ``t``, so only the destination
    node embedding differs. ROC-AUC and F1 use the final epoch's scores only.
    These are in-sample metrics on the training events.

    Args:
        data: TemporalData with ``src``, ``dst``, ``t`` and a 2-D ``msg``.
        num_nodes: number of nodes.
        epochs: passes over the event stream.

    Returns:
        (fpr, tpr, auc, f1)

    Raises:
        ValueError: if ``data`` has no events or ``epochs < 1``.
    """
    num_events = data.t.size(0)
    if num_events == 0 or epochs < 1:
        raise ValueError("train_tgat needs at least one event and one epoch")

    model = TGATLinkPredictor(num_nodes=num_nodes, feat_dim=data.msg.size(-1))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)

    for epoch in range(epochs):
        total_loss = 0
        # Only the last epoch's scores are kept for evaluation.
        y_true, y_score = [], []

        for i in range(num_events):
            src, dst = data.src[i], data.dst[i]
            # One negative per event. It may occasionally equal the true dst.
            neg_dst = torch.randint(0, num_nodes, (), dtype=dst.dtype, device=dst.device)
            t = data.t[i].unsqueeze(0)
            msg = data.msg[i]

            pos_out = model(src, dst, t, msg).view(-1)
            neg_out = model(src, neg_dst, t, msg).view(-1)
            loss = (
                F.binary_cross_entropy_with_logits(pos_out, torch.ones_like(pos_out))
                + F.binary_cross_entropy_with_logits(neg_out, torch.zeros_like(neg_out))
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            y_true.extend((1, 0))
            y_score.extend((torch.sigmoid(pos_out).item(), torch.sigmoid(neg_out).item()))

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    # Evaluation (both classes are present thanks to the sampled negatives)
    auc = roc_auc_score(y_true, y_score)
    y_pred_bin = [1 if p > 0.5 else 0 for p in y_score]
    f1 = f1_score(y_true, y_pred_bin)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    print(f"TGAT ROC-AUC: {auc:.4f}, F1-score: {f1:.4f}")

    return fpr, tpr, auc, f1

# Example: fpr, tpr, auc, f1 = train_tgat(data, num_nodes)


In [None]:
# Train models on METRLA dataset
def train_tgn_on_metrla(data, num_nodes, epochs=10):
    """Train and evaluate the TGN baseline on the METR-LA events.

    Delegates to ``train_tgn`` rather than keeping a second copy of its
    training loop, so both entry points always run the same procedure.

    Returns:
        (fpr, tpr, auc, f1) as returned by ``train_tgn``.
    """
    return train_tgn(data, num_nodes, epochs=epochs)

def train_tgat_on_metrla(data, num_nodes, epochs=10):
    """Train and evaluate the TGAT baseline on the METR-LA events.

    Delegates to ``train_tgat`` rather than keeping a second copy of its
    training loop, so both entry points always run the same procedure.

    Returns:
        (fpr, tpr, auc, f1) as returned by ``train_tgat``.
    """
    return train_tgat(data, num_nodes, epochs=epochs)

def train_htgn_on_metrla(data, num_nodes):
    """Train HTGN++ on the METR-LA events.

    A thin alias for ``train_htgn``: it forwards both arguments and returns
    whatever ``train_htgn`` returns.
    """
    result = train_htgn(data=data, num_nodes=num_nodes)
    return result

def run_metrla_experiment(seed=0):
    """Train TGN, TGAT and HTGN++ on METR-LA and plot their ROC curves together.

    Args:
        seed: seed for torch and numpy, so that weight initialisation,
            negative sampling and reparameterisation noise are reproducible.
            Pass ``None`` to skip seeding.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    data, num_nodes = load_metrla_temporal_data()
    runs = [
        ("TGN", train_tgn_on_metrla),
        ("TGAT", train_tgat_on_metrla),
        ("HTGN", train_htgn_on_metrla),
    ]
    # Train everything before opening the comparison figure: some trainers
    # draw their own diagnostic plots through the pyplot state machine.
    results = [(name, train_fn(data, num_nodes)) for name, train_fn in runs]

    fig, ax = plt.subplots(figsize=(8, 6))
    for name, (fpr, tpr, auc, _) in results:
        ax.plot(fpr, tpr, label=f'{name} (AUC = {auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve Comparison on METRLA')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()


run_metrla_experiment()
