In [None]:
import sys
# Put the parent directory on the import path. The path is relative, so it is
# resolved against the kernel's current working directory.
sys.path.append('./..')

In [None]:
from torch_geometric.data import Data
import torch

def map_nodes_to_ids(g):
    """
    Assign consecutive 1-based integer IDs to the nodes of a graph and
    rewrite its edge list in terms of those IDs.

    IDs follow the iteration order of ``g.nodes``; the first node gets ID 1.

    Args:
        g: graph object exposing ``g.nodes`` (iterable of hashable nodes)
           and ``g.edges()`` (iterable of ``(u, v)`` pairs).

    Returns:
        mapped_edges (list of [int, int]): one ``[uid, vid]`` pair per edge
        node_to_id (dict): original node -> integer ID
        num_nodes (int): number of nodes that were numbered
    """
    ordered_nodes = list(g.nodes)
    node_to_id = dict(zip(ordered_nodes, range(1, len(ordered_nodes) + 1)))

    mapped_edges = [[node_to_id[src], node_to_id[dst]] for src, dst in g.edges()]

    return mapped_edges, node_to_id, len(ordered_nodes)

def convert_nx_to_pyg(graphs_nx):
    """
    Convert NetworkX graphs into PyG ``Data`` objects.

    For every graph:
      * ``x`` has one row per node holding ``[pitch, timestamp]``, preceded by
        an all-zero row at index 0 (so real nodes occupy rows 1..n).
      * ``edge_index`` refers to those rows. Endpoints are translated through
        the node's position in ``g.nodes`` rather than through the raw node
        label, so features and edges stay aligned even when labels are not
        the integers 0..n-1 in iteration order (e.g. tuples or shuffled ints).
      * ``edge_attr`` is a column of ones, one entry per edge.
      * ``y`` is ``g.graph['label']`` or ``-1`` when no label is stored.
      * ``pi`` is copied from ``g.graph['pi']`` when present.

    Graphs without edges produce an empty ``(2, 0)`` ``edge_index`` instead of
    raising, and graphs without nodes keep only the zero row in ``x``.

    Args:
        graphs_nx: iterable of NetworkX graphs whose nodes carry ``'pitch'``
            and ``'timestamp'`` attributes.

    Returns:
        list of PyG ``Data`` objects, in input order.
    """
    graphs_pyg = []

    for g in graphs_nx:
        node_to_row = {}
        node_features = []
        # Row 0 is reserved for the zero/padding node, so numbering starts at 1.
        for row, (node, attrs) in enumerate(g.nodes(data=True), start=1):
            node_to_row[node] = row
            node_features.append([attrs['pitch'], attrs['timestamp']])

        x = torch.zeros((len(node_features) + 1, 2))
        if node_features:
            x[1:] = torch.tensor(node_features, dtype=torch.float)

        edge_pairs = [[node_to_row[u], node_to_row[v]] for u, v in g.edges()]
        if edge_pairs:
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        else:
            # torch.tensor([]) would be 1-D and break the shape[1] lookup below.
            edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.ones(edge_index.shape[1], 1)

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.tensor([g.graph.get('label', -1)])
        )

        if 'pi' in g.graph:
            data.pi = torch.tensor(g.graph['pi'], dtype=torch.float)

        graphs_pyg.append(data)

    return graphs_pyg

In [None]:
import pickle
import os

# Load the two pickled graph collections (one per class) from disk.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
filepath = "../collection/graphs/networkx/concrete/graphs_concrete.pkl"


with open(filepath, 'rb') as f:
    graphs_concrete = pickle.load(f)
print(f"Loaded {len(graphs_concrete)} graphs from {filepath}")


# `filepath` is rebound here, so after this cell it points at the rocky file.
filepath= '../collection/graphs/networkx/rocky/graphs_rocky.pkl'
# NOTE(review): `filename` is not referenced anywhere else in the visible part
# of the notebook -- confirm whether it is still needed.
filename = "graphs_rocky.pkl"

with open(filepath, 'rb') as f:
    graphs_rocky = pickle.load(f)
print(f"Loaded {len(graphs_rocky)} graphs from {filepath}")

## TDA code

In [None]:
from gudhi.representations import PersistenceImage
import numpy as np
import gudhi as gd
import networkx as nx
import torch


def compute_average_filtration(G):
    """
    Derive one filtration value per edge from the timestamps of its endpoints.

    An edge's time is the smaller of its two endpoint timestamps. For every
    node, the absolute time differences between each pair of incident edges
    are added to both edges of the pair. An edge's accumulated total is then
    divided by the sum of the degrees of its two endpoints.

    Input:
    - G: networkx.Graph
         Each NODE has a 'timestamp' attribute.

    Output:
    - favg: dict mapping sorted edge tuples (u, v) to the averaged value (float)
    """

    def _key(edge):
        # Orientation-independent dictionary key for an edge.
        return tuple(sorted(edge))

    def _edge_time(key):
        first, second = key
        return min(G.nodes[first]['timestamp'], G.nodes[second]['timestamp'])

    accumulated = {_key(edge): 0 for edge in G.edges()}
    seen = dict.fromkeys(G.nodes(), False)
    favg = {}

    degree = {node: G.degree(node) for node in G.nodes()}

    for node in G.nodes():
        pending = list(G.edges(node))

        while pending:
            edge = pending.pop()
            key = _key(edge)
            edge_time = _edge_time(key)

            # Pair the popped edge with every edge still waiting at this node.
            for other in pending:
                other_key = _key(other)
                gap = abs(edge_time - _edge_time(other_key))

                accumulated[key] += gap
                accumulated[other_key] += gap

            neighbour = edge[0] if edge[0] != node else edge[1]
            # The far endpoint was already processed, so this edge's total
            # is complete and can be averaged now.
            if seen[neighbour]:
                favg[key] = accumulated[key] / (degree[node] + degree[neighbour])

        seen[node] = True

    # Any edge not finalised inside the sweep is averaged here.
    for edge in G.edges():
        key = _key(edge)
        if key in favg:
            continue
        first, second = key
        favg[key] = accumulated[key] / (degree[first] + degree[second])

    return favg



def build_simplex_tree_from_graph(G, favg, max_dim=3):
    """
    Build a GUDHI simplex tree from a graph G using average filtration values for edges.

    Filtration values:
      * vertex: the smallest value among its incident edges, or 0.0 when the
        vertex has no incident edge or none of them appears in ``favg``;
      * edge: its ``favg`` entry (0.0 when missing);
      * clique with 3..max_dim+1 nodes: the largest value among its edges
        (never below 0.0).

    Clique enumeration stops at the first clique larger than ``max_dim + 1``.
    ``networkx.enumerate_all_cliques`` yields cliques in order of increasing
    size, so nothing relevant can follow; this avoids enumerating every
    larger clique on dense graphs.

    Parameters:
    - G: networkx.Graph
    - favg: dict mapping edges (min(u,v), max(u,v)) to filtration values (floats)
    - max_dim: int, max simplex dimension (3 for up to tetrahedra)

    Returns:
    - st: gudhi.SimplexTree object with filtration
    """
    def edge_key(a, b):
        # Same orientation-independent key that favg is indexed by.
        return (min(a, b), max(a, b))

    st = gd.SimplexTree()
    node_to_id = {node: i for i, node in enumerate(G.nodes())}

    for node, idx in node_to_id.items():
        incident = [favg.get(edge_key(u, v), float('inf')) for u, v in G.edges(node)]
        vertex_filt = min(incident, default=0.0)
        if vertex_filt == float('inf'):
            vertex_filt = 0.0
        st.insert([idx], filtration=vertex_filt)

    for u, v in G.edges():
        st.insert([node_to_id[u], node_to_id[v]], filtration=favg.get(edge_key(u, v), 0.0))

    if max_dim >= 2:
        largest_clique = max_dim + 1
        for clique in nx.enumerate_all_cliques(G):
            if len(clique) > largest_clique:
                break  # sizes only grow from here on
            if len(clique) < 3:
                continue  # vertices and edges were inserted above
            simplex = [node_to_id[n] for n in clique]
            max_filt = 0.0
            for i in range(len(clique)):
                for j in range(i + 1, len(clique)):
                    edge_filt = favg.get(edge_key(clique[i], clique[j]), 0.0)
                    if edge_filt > max_filt:
                        max_filt = edge_filt
            st.insert(simplex, filtration=max_filt)

    # NOTE(review): GUDHI documents initialize_filtration() as deprecated
    # since 3.2.0 -- confirm it still exists in the installed version.
    st.initialize_filtration()
    return st

def compute_persistence_images_for_graphs(
    graphs,
    max_simplex_dim=3,
    pi_resolution=[20, 20],
    pi_bandwidth=0.1,
    pi_im_range=[0, 1, 0, 1]
):
    """
    Compute a persistence diagram and a persistence-image vector per graph.

    Each graph goes through ``compute_average_filtration`` and
    ``build_simplex_tree_from_graph``; features with infinite death are
    dropped and the rest are stored as ``[birth, death - birth]`` rows.
    The ``PersistenceImage`` transformer is fitted on the non-empty diagrams
    only. Graphs whose diagram is empty get an all-zero vector of length
    ``prod(pi_resolution)``.

    The list defaults are only read, never modified.

    Parameters:
    - graphs: iterable of networkx graphs with node 'timestamp' attributes
    - max_simplex_dim: max simplex dimension passed to the simplex-tree builder
    - pi_resolution, pi_bandwidth, pi_im_range: forwarded to PersistenceImage

    Returns:
    - diagrams: list of persistence diagrams (numpy arrays of [birth, persistence])
    - vectors: list of persistence image flattened vectors
    - pi_transformer: the PersistenceImage (fitted when any diagram is non-empty)
    """
    diagrams = []

    for G in graphs:
        favg = compute_average_filtration(G)
        st = build_simplex_tree_from_graph(G, favg, max_dim=max_simplex_dim)

        finite_points = [
            [birth, death - birth]
            for _dim, (birth, death) in st.persistence()
            if death != float('inf')
        ]
        # reshape guarantees shape (0, 2) when there are no finite points
        diagrams.append(np.array(finite_points).reshape(-1, 2))

    # One flag per diagram; avoids the repeated list-membership scan.
    has_points = [len(dgm) > 0 for dgm in diagrams]
    valid_diagrams = [dgm for dgm, keep in zip(diagrams, has_points) if keep]

    pi = PersistenceImage(
        bandwidth=pi_bandwidth,
        weight=lambda x: x[1],
        resolution=pi_resolution,
        im_range=pi_im_range
    )

    vec_dim = np.prod(pi_resolution)
    if valid_diagrams:
        pi.fit(valid_diagrams)

    # `keep` is only True when valid_diagrams is non-empty, i.e. pi is fitted.
    vectors = [
        pi.transform([dgm])[0] if keep else np.zeros(vec_dim)
        for dgm, keep in zip(diagrams, has_points)
    ]

    return diagrams, vectors, pi


def add_graph_labels(graphs, label):
    """
    Stamp every graph in the list with the same graph-level label.

    Each graph's ``y`` attribute is overwritten in place with a fresh
    one-element ``torch.long`` tensor holding ``label``.

    Args:
        graphs (list of PyG Data): graphs to label (modified in place)
        label (int): graph-level label

    Returns:
        the same ``graphs`` list that was passed in
    """
    for graph in graphs:
        setattr(graph, 'y', torch.tensor([label], dtype=torch.long))
    return graphs


def remap_node_indices(data_list, preserve_original_ids=False):
    """
    Re-number the nodes of each graph to a dense 0..k-1 range.

    Only nodes that appear in ``edge_index`` are kept; they are numbered by
    ascending original index. ``x`` is reduced to the rows of those nodes,
    in the same order. If ``edge_index`` refers to rows beyond ``x``, the
    missing rows are treated as zeros.

    The input objects are left untouched: padding is applied to a local copy
    of ``x`` instead of being written back onto ``data``. Graphs with no
    edges yield an empty ``(2, 0)`` ``edge_index`` and an ``x`` with zero
    rows instead of raising.

    NOTE(review): only ``x``, ``edge_index``, ``y`` and ``pi`` are carried
    over; other attributes such as ``edge_attr`` are not copied -- confirm
    this is intended.

    Args:
        data_list: iterable of PyG ``Data`` objects.
        preserve_original_ids (bool): when True, store the sorted original
            node indices on each result as ``original_ids``.

    Returns:
        list of new ``Data`` objects, in input order.
    """
    new_data_list = []

    for data in data_list:
        edge_index = data.edge_index

        if edge_index.numel() == 0:
            all_nodes = edge_index.new_zeros((0,), dtype=torch.long)
            new_edge_index = edge_index.new_zeros((2, 0), dtype=torch.long)
        else:
            # `inverse` holds, for every entry of edge_index, the position of
            # its value in the sorted unique list -- exactly the new dense id.
            all_nodes, new_edge_index = torch.unique(edge_index, return_inverse=True)

        x = getattr(data, 'x', None)
        if x is not None:
            # all_nodes is sorted, so its last entry is the largest index used.
            rows_needed = int(all_nodes[-1].item()) + 1 if all_nodes.numel() else 0
            if x.size(0) < rows_needed:
                padding = x.new_zeros((rows_needed - x.size(0), x.size(1)))
                x = torch.cat([x, padding], dim=0)
            x = x[all_nodes]

        new_data = Data(x=x, edge_index=new_edge_index, y=data.y)

        if hasattr(data, 'pi'):
            new_data.pi = data.pi

        if preserve_original_ids:
            new_data.original_ids = all_nodes

        new_data_list.append(new_data)

    return new_data_list

In [None]:
# Compute persistence images for each class separately (each call builds and
# fits its own PersistenceImage) and store each graph's vector on the graph
# itself under the 'pi' key. This mutates the NetworkX graphs in place.
diagrams_concrete, vectors_concrete, pi_concrete = compute_persistence_images_for_graphs(graphs_concrete)
for G, vec in zip(graphs_concrete, vectors_concrete):
    G.graph['pi'] = vec

diagrams_rocky, vectors_rocky, pi_rocky = compute_persistence_images_for_graphs(graphs_rocky)
for G, vec in zip(graphs_rocky, vectors_rocky):
    G.graph['pi'] = vec

## prepare data

In [None]:
# Step 1: NetworkX -> PyG Data (node features, edges, 'pi' vector).
graphs_concrete_pyg = convert_nx_to_pyg(graphs_concrete)
graphs_rocky_pyg = convert_nx_to_pyg(graphs_rocky)


# Step 2: re-number nodes densely; nodes that occur in no edge are dropped.
graphs_concrete_pyg = remap_node_indices(graphs_concrete_pyg)
graphs_rocky_pyg = remap_node_indices(graphs_rocky_pyg)


# Step 3: overwrite y with the class label (concrete = 0, rocky = 1).
graphs_concrete_pyg = add_graph_labels(graphs_concrete_pyg, label=0)
graphs_rocky_pyg = add_graph_labels(graphs_rocky_pyg, label=1)

def combine_graphs(graphs_label_dict):
    """
    Flatten a ``{label: [graphs]}`` mapping into one list.

    Graphs are concatenated in the dictionary's value order; the keys are
    not used.
    """
    return [graph for group in graphs_label_dict.values() for graph in group]

# Group the labelled PyG graphs by class and flatten them into one list
# (class 0 first, then class 1) for the train/val/test split.
graphs_dict_with_labels = {
    0: graphs_concrete_pyg,  # label 0
    1: graphs_rocky_pyg      # label 1
}

all_graphs = combine_graphs(graphs_dict_with_labels)

In [None]:
from sklearn.model_selection import train_test_split

def split_graphs(graphs, test_size=0.2, val_size=0.25, random_state=42):
    """
    Partition PyG graphs into train / validation / test subsets with two
    stratified splits, using the label stored in each graph's ``y``.

    The test set is carved out first; the validation set is then taken from
    what remains, so ``val_size`` is a fraction of the train+val portion,
    not of the whole dataset.

    Args:
        graphs (list): List of PyG Data objects. Each must have a `y` attribute (scalar label).
        test_size (float): Proportion of dataset to include in test split.
        val_size (float): Proportion of trainval set to include in validation split.
        random_state (int): Seed used for both splits.

    Returns:
        dict: {"train": [...], "val": [...], "test": [...]}
    """
    def _labels_of(items):
        return [int(item.y.item()) for item in items]

    remaining, held_out = train_test_split(
        graphs,
        test_size=test_size,
        stratify=_labels_of(graphs),
        random_state=random_state,
    )

    training, validation = train_test_split(
        remaining,
        test_size=val_size,
        stratify=_labels_of(remaining),
        random_state=random_state,
    )

    return dict(train=training, val=validation, test=held_out)

In [None]:
# 20% of all graphs go to test; 25% of the remaining 80% (i.e. 20% of the
# total) go to validation, leaving roughly 60% for training.
splits = split_graphs(all_graphs, test_size=0.2, val_size=0.25, random_state=42)

train_graphs = splits["train"]
val_graphs = splits["val"]
test_graphs = splits["test"]

# Report the size of each split.
print(f"Train: {len(train_graphs)}")
print(f"Validation: {len(val_graphs)}")
print(f"Test: {len(test_graphs)}")

# Quick sanity check on one training graph and on the first graph overall.
sample = train_graphs[0]
print(f"Node dim: {sample.x.shape}")
print(type(all_graphs[0]))
print(all_graphs[0])

In [None]:
# PyG 2.x moved DataLoader to torch_geometric.loader and deprecated the
# torch_geometric.data alias; fall back to the old location for older installs.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader

# Mini-batches of 4 graphs; only the training loader shuffles.
train_loader = DataLoader(train_graphs, batch_size=4, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=4)
test_loader = DataLoader(test_graphs, batch_size=4)

## TGAT

In [None]:
import numpy as np
import torch
import logging
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
import torch
from gudhi.representations import PersistenceImage


class StandardizeDimensions(nn.Module):
    """Ensures inputs are always [batch, seq_len, features] or [batch, features]"""
    def __init__(self):
        super().__init__()
    
    def forward(self, x, min_dims=2):
        if isinstance(x, (list, tuple)):
            return [self.forward(tensor, min_dims) for tensor in x]
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        while x.dim() < min_dims:
            x = x.unsqueeze(0)
        if x.dim() == 2:
            return x.unsqueeze(1) if min_dims >= 3 else x
        elif x.dim() == 3:  
            return x
        else:
            raise ValueError(f"Unexpected shape {x.shape}. Max 3D tensors supported")


class MergeLayer(nn.Module):
    """Two-layer MLP applied to the feature-wise concatenation of two inputs.

    ``dim1`` and ``dim2`` are the widths of the two inputs, ``dim3`` the
    hidden width and ``dim4`` the output width. Both weight matrices use
    Xavier-normal initialisation.
    """
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        joint_width = dim1 + dim2
        self.fc1 = nn.Linear(joint_width, dim3)
        self.fc2 = nn.Linear(dim3, dim4)
        self.act = nn.ReLU()
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along dim 1 and run fc1 -> ReLU -> fc2."""
        joined = torch.cat([x1, x2], dim=1)
        hidden = self.fc1(joined)
        return self.fc2(self.act(hidden))


class ScaledDotProductAttention(nn.Module):
    """Batched dot-product attention with temperature scaling.

    Scores are ``q @ k^T / temperature``; positions where ``mask`` is True are
    filled with ``-1e10`` before a softmax over the key axis, followed by
    dropout on the attention weights.
    """
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    @staticmethod
    def _with_seq_axis(tensor):
        # 2-D inputs get a length-1 axis inserted at position 1.
        return tensor.unsqueeze(1) if tensor.dim() == 2 else tensor

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attn)`` where ``output = attn @ v``."""
        q, k, v = (self._with_seq_axis(t) for t in (q, k, v))

        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(self._with_seq_axis(mask), -1e10)

        attn = self.dropout(self.softmax(scores))
        return torch.bmm(attn, v), attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention block with residual connection and layer norm.

    Queries/keys are projected to ``n_head * d_k`` features and values to
    ``n_head * d_v``; heads are folded into the batch axis for the scaled
    dot-product attention, then merged and projected back to ``d_model``.
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.std_dims = StandardizeDimensions()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        qk_std = np.sqrt(2.0 / (d_model + d_k))
        v_std = np.sqrt(2.0 / (d_model + d_v))
        nn.init.normal_(self.w_qs.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_ks.weight, mean=0, std=qk_std)
        nn.init.normal_(self.w_vs.weight, mean=0, std=v_std)

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, head_dim):
        # [batch, len, n_head * head_dim] -> [n_head * batch, len, head_dim]
        batch, length, _ = projected.size()
        per_head = projected.view(batch, length, self.n_head, head_dim)
        return per_head.permute(2, 0, 1, 3).contiguous().view(-1, length, head_dim)

    def forward(self, q, k, v, mask=None):
        """Return ``(layer_norm(attended + q), attention_weights)``.

        Inputs are first brought to 3-D via ``StandardizeDimensions``; the
        mask, when given, is tiled once per head along the batch axis.
        """
        q, k, v = self.std_dims([q, k, v], min_dims=3)
        if mask is not None:
            mask = self.std_dims(mask, min_dims=3)

        skip = q
        batch_size, len_q, _ = q.size()

        q_heads = self._split_heads(self.w_qs(q), self.d_k)
        k_heads = self._split_heads(self.w_ks(k), self.d_k)
        v_heads = self._split_heads(self.w_vs(v), self.d_v)

        if mask is not None:
            mask = mask.repeat(self.n_head, 1, 1)

        attended, attn = self.attention(q_heads, k_heads, v_heads, mask=mask)

        # [n_head * batch, len_q, d_v] -> [batch, len_q, n_head * d_v]
        attended = attended.view(self.n_head, batch_size, len_q, self.d_v)
        merged = attended.permute(1, 2, 0, 3).contiguous().view(batch_size, len_q, -1)

        projected = self.dropout(self.fc(merged))
        return self.layer_norm(projected + skip), attn


class TimeEncode(nn.Module):
    """Learnable cosine time encoding.

    Each time value ``t`` is mapped to ``cos(t * basis_freq + phase)`` with
    ``expand_dim`` frequencies, initialised to ``1 / 10 ** linspace(0, 9, expand_dim)``
    and zero phase. ``factor`` is stored but not used by ``forward``.
    """
    def __init__(self, expand_dim, factor=5):
        super().__init__()
        self.factor = factor
        initial_freq = 1 / 10 ** np.linspace(0, 9, expand_dim)
        self.basis_freq = nn.Parameter((torch.from_numpy(initial_freq)).float())
        self.phase = nn.Parameter(torch.zeros(expand_dim).float())

    def forward(self, ts):
        """Encode ``ts`` ([batch] or [batch, len]) into [batch, len, expand_dim]."""
        if ts.dim() == 1:
            ts = ts.unsqueeze(-1)
        batch, length = ts.size(0), ts.size(1)
        scaled = ts.view(batch, length, 1) * self.basis_freq.view(1, 1, -1)
        return torch.cos(scaled + self.phase.view(1, 1, -1))


class PosEncode(nn.Module):
    """Positional encoding: looks up an embedding for each entry of the
    argsort of the timestamps along dim 1."""
    def __init__(self, expand_dim, seq_len):
        super().__init__()
        self.pos_embeddings = nn.Embedding(seq_len, expand_dim)

    def forward(self, ts):
        """Return embeddings indexed by ``ts.argsort(dim=1)``; 1-D input is
        treated as [batch, 1]."""
        timestamps = ts.unsqueeze(-1) if ts.dim() == 1 else ts
        ranks = timestamps.argsort(dim=1)
        return self.pos_embeddings(ranks)


class LSTMPool(nn.Module):
    """Aggregate a neighbour sequence with an LSTM and merge it with the source.

    The LSTM consumes ``feat_dim + edge_dim + time_dim`` features per step,
    so ``forward`` concatenates node, edge and time features. Previously the
    edge features were left out of the concatenation, which made the LSTM
    reject its input for any ``edge_dim > 0``.
    """
    def __init__(self, feat_dim, edge_dim, time_dim):
        super().__init__()
        self.feat_dim = feat_dim
        self.time_dim = time_dim
        self.edge_dim = edge_dim
        self.att_dim = feat_dim + edge_dim + time_dim
        self.act = nn.ReLU()
        self.lstm = nn.LSTM(input_size=self.att_dim, hidden_size=feat_dim, num_layers=1, batch_first=True)
        self.merger = MergeLayer(feat_dim, feat_dim, feat_dim, feat_dim)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """
        Args:
            src: source features, [batch, feat_dim] or [batch, 1, feat_dim].
            src_t: unused.
            seq: neighbour features, [batch, len, feat_dim] (2-D inputs get a
                length-1 sequence axis).
            seq_t: neighbour time features, same leading shape as ``seq``.
            seq_e: neighbour edge features, same leading shape as ``seq``;
                may be ``None`` or zero-width when ``edge_dim`` is 0.
            mask: unused.

        Returns:
            (merged, None): ``merged`` comes from the merge layer applied to
            the last LSTM hidden state and ``src``.
        """
        def with_seq_axis(tensor):
            return tensor.unsqueeze(1) if tensor.dim() == 2 else tensor

        # The merge layer concatenates along dim 1 and needs a 2-D source.
        if src.dim() == 3:
            src = src.squeeze(1)

        parts = [with_seq_axis(seq)]
        if seq_e is not None:
            parts.append(with_seq_axis(seq_e))
        parts.append(with_seq_axis(seq_t))

        seq_x = torch.cat(parts, dim=2)
        _, (hn, _) = self.lstm(seq_x)
        return self.merger(hn[-1], src), None


class MeanPool(nn.Module):
    """Aggregate neighbours by averaging their node+edge features over the
    sequence axis, then merge the average with the source features."""
    def __init__(self, feat_dim, edge_dim):
        super().__init__()
        self.edge_dim = edge_dim
        self.feat_dim = feat_dim
        self.act = nn.ReLU()
        self.merger = MergeLayer(edge_dim + feat_dim, feat_dim, feat_dim, feat_dim)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """Return ``(merged, None)``; ``src_t``, ``seq_t`` and ``mask`` are
        accepted for interface parity but not used."""
        seq, seq_e = (
            tensor.unsqueeze(1) if tensor.dim() == 2 else tensor
            for tensor in (seq, seq_e)
        )
        pooled = torch.cat([seq, seq_e], dim=2).mean(dim=1)
        return self.merger(pooled, src), None


class AttnModel(nn.Module):
    """Attention-based neighbour aggregation.

    The attention width is ``feat_dim + edge_dim + time_dim``. Keys/values are
    the concatenation of neighbour node, time and edge features. The query is
    the source node and time features followed by a zero block of width
    ``edge_dim``, so it has the same width and layout as the keys. Without
    that block the query was only ``feat_dim + time_dim`` wide and the query
    projection failed for any ``edge_dim > 0``; for ``edge_dim == 0`` the
    block is empty and the computation is unchanged.
    """
    def __init__(self, feat_dim, edge_dim, time_dim, n_head=2, drop_out=0.1):
        super().__init__()
        self.feat_dim = feat_dim
        self.edge_dim = edge_dim
        self.time_dim = time_dim
        self.edge_in_dim = feat_dim + edge_dim + time_dim
        self.model_dim = self.edge_in_dim
        self.std_dims = StandardizeDimensions()
        self.merger = MergeLayer(self.model_dim, feat_dim, feat_dim, feat_dim)
        # Each head gets model_dim // n_head features, so the split must be exact.
        assert self.model_dim % n_head == 0

        self.multi_head_target = MultiHeadAttention(
            n_head=n_head,
            d_model=self.model_dim,
            d_k=self.model_dim // n_head,
            d_v=self.model_dim // n_head,
            dropout=drop_out
        )
        self.dropout = nn.Dropout(drop_out)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """
        Args:
            src: source node features.
            src_t: source time encoding.
            seq: neighbour node features.
            seq_t: neighbour time encodings.
            seq_e: neighbour edge features (zero-width when ``edge_dim`` is 0).
            mask: boolean mask forwarded to the attention (True = ignore).

        All tensor inputs except ``mask`` are brought to 3-D first.

        Returns:
            (merged, attn): merge-layer output for the attended features and
            ``src``, plus the attention weights.
        """
        src = self.std_dims(src, min_dims=3)
        src_t = self.std_dims(src_t, min_dims=3)
        seq = self.std_dims(seq, min_dims=3)
        seq_t = self.std_dims(seq_t, min_dims=3)
        seq_e = self.std_dims(seq_e, min_dims=3)

        seq_len = seq.size(1)
        assert seq_t.size(1) == seq_len
        assert seq_e.size(1) == seq_len

        # Zero placeholder for the (non-existent) edge features of the source.
        src_e_placeholder = src.new_zeros(src.size(0), src.size(1), self.edge_dim)
        q = torch.cat([src, src_t, src_e_placeholder], dim=-1)
        k = torch.cat([seq, seq_t, seq_e], dim=-1)

        output, attn = self.multi_head_target(q=q, k=k, v=k, mask=mask)
        return self.merger(output.squeeze(1), src.squeeze(1)), attn


class TGAN(torch.nn.Module):
    """Temporal graph attention network with a graph-level classification head.

    Node features come from a frozen embedding table built from ``n_feat``
    and are projected to ``feat_dim = max(64, raw_feat_dim)``. ``tem_conv``
    recursively aggregates sampled temporal neighbours; the classification
    head concatenates mean-pooled node embeddings with a per-graph
    persistence-image vector.

    Args:
        ngh_finder: object providing ``get_temporal_neighbor``.
        n_feat: 2-D array of raw node features, one row per node id.
        e_feat: optional edge feature array; only its width is recorded.
        use_time: ``'time'`` (cosine encoding) or ``'pos'`` (positional).
        agg_method: ``'attn'``, ``'lstm'`` or ``'mean'``.
        node_dim, time_dim, num_heads: accepted but not used.
        num_layers: number of aggregation layers.
        n_head: attention heads per layer.
        null_idx: stored padding index.
        drop_out: dropout probability.
        seq_len: table size for the positional encoder.
        pi_dim: width of the persistence-image vector fed to the classifier
            (default 100, the previously hard-coded value).
    """
    def __init__(self, ngh_finder,n_feat, e_feat, use_time='time', agg_method='attn',
                 node_dim=None, time_dim=None, num_layers=3, n_head=1,
                 null_idx=0, num_heads=1, drop_out=0.1, seq_len=None, pi_dim=100):
        super(TGAN, self).__init__()

        self.raw_feat_dim = n_feat.shape[1]
        self.n_feat = torch.nn.Parameter(torch.tensor(n_feat, dtype=torch.float32))

        # Row 0 of the table is the padding entry.
        self.node_raw_embed = torch.nn.Embedding.from_pretrained(
        self.n_feat, padding_idx=0, freeze=True)
        self.feat_dim = max(64, self.raw_feat_dim)

        if self.raw_feat_dim != self.feat_dim:
            self.proj = nn.Sequential(
                nn.Linear(self.raw_feat_dim, self.feat_dim),
                nn.ReLU()
            )
        else:
            self.proj = torch.nn.Identity()

        self.num_layers = num_layers
        self.ngh_finder = ngh_finder
        self.null_idx = null_idx
        self.logger = logging.getLogger(__name__)
        self.n_feat_dim = self.feat_dim
        self.e_feat_dim = 0 if e_feat is None else e_feat.shape[1]
        self.model_dim = self.feat_dim + 1
        self.use_time = use_time
        self.pi_dim = pi_dim
        self.merge_layer = MergeLayer(self.feat_dim, self.feat_dim, self.feat_dim, self.feat_dim)

        if agg_method == 'attn':
            self.attn_model_list = nn.ModuleList([
                AttnModel(self.feat_dim, 0, self.feat_dim, n_head=n_head, drop_out=drop_out)
                for _ in range(num_layers)
            ])
        elif agg_method == 'lstm':
            self.attn_model_list = nn.ModuleList([
                LSTMPool(self.feat_dim, self.feat_dim, self.feat_dim)
                for _ in range(num_layers)
            ])
        elif agg_method == 'mean':
            self.attn_model_list = nn.ModuleList([
                MeanPool(self.feat_dim, self.feat_dim)
                for _ in range(num_layers)
            ])

        if use_time == 'time':
            self.time_encoder = TimeEncode(expand_dim=self.feat_dim)
        elif use_time == 'pos':
            self.time_encoder = PosEncode(expand_dim=self.feat_dim, seq_len=seq_len)

        self.pi = PersistenceImage(bandwidth=0.1, weight=lambda x: x[1], im_range=[0, 1, 0, 1],
                                   resolution=[10, 10])

        self.classifier = nn.Sequential(
            nn.Linear(self.feat_dim + pi_dim, self.feat_dim // 2),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(self.feat_dim // 2, 1)
        )

    def forward(self, *args):
        """Alias for ``forward_graph_classification``."""
        return self.forward_graph_classification(*args)

    def forward_graph_classification(self, batch):
        """Return one logit per graph in ``batch``.

        Node embeddings from ``tem_conv`` are mean-pooled per graph and
        concatenated with ``batch.pi``. ``batch.pi`` may be flat (per-graph
        vectors concatenated end to end) or already ``[num_graphs, pi_dim]``;
        it is reshaped to one row per graph either way.
        """
        device = self.n_feat.device
        src_idx_l = torch.arange(batch.num_nodes, device=device, dtype=torch.long)
        cut_time_l = torch.zeros(batch.num_nodes, device=device, dtype=torch.float32)

        # tem_conv already applies self.proj; projecting again would feed
        # feat_dim-wide vectors into a layer that expects raw_feat_dim.
        node_embeddings = self.tem_conv(src_idx_l, cut_time_l, self.num_layers)
        graph_embeddings = global_mean_pool(node_embeddings, batch.batch)

        pi_tensor = batch.pi.to(device).float().reshape(graph_embeddings.size(0), -1)

        combined = torch.cat([graph_embeddings, pi_tensor], dim=1)
        return self.classifier(combined)


    def tem_conv(self, src_idx_l, cut_time_l, curr_layers, num_neighbors=20):
        """Recursively compute node embeddings with ``curr_layers`` rounds of
        temporal neighbour aggregation.

        Args:
            src_idx_l: 1-D long tensor of node ids (rows of the embedding table).
            cut_time_l: per-node cut times, [N] or [N, 1].
            curr_layers: remaining aggregation depth; 0 returns the projected
                raw features.
            num_neighbors: neighbours sampled per node.

        Returns:
            Tensor with one embedding row per entry of ``src_idx_l``.
        """
        if cut_time_l.dim() == 1:
            cut_time_l = cut_time_l.unsqueeze(-1)

        raw_embed = self.node_raw_embed(src_idx_l)
        src_node_feat = self.proj(raw_embed)
        if curr_layers == 0:
            return src_node_feat

        src_node_conv_feat = self.tem_conv(
            src_idx_l,
            cut_time_l,
            curr_layers - 1,
            num_neighbors
        )

        # NOTE(review): the neighbour timestamps returned here are discarded;
        # the time encoding below is derived from cut_time_l only -- confirm
        # this is intended.
        src_ngh_node_batch, _ = self.ngh_finder.get_temporal_neighbor(
            src_idx_l,
            cut_time_l.squeeze(-1),
            num_neighbors=num_neighbors
        )

        src_ngh_feat = self.tem_conv(
            src_ngh_node_batch.flatten(),
            cut_time_l.repeat_interleave(num_neighbors, dim=0),
            curr_layers - 1,
            num_neighbors
        ).view(src_idx_l.size(0), num_neighbors, -1)

        src_ngh_t_embed = self.time_encoder(cut_time_l - cut_time_l.mean())

        if src_ngh_t_embed.dim() == 2:
            src_ngh_t_embed = src_ngh_t_embed.unsqueeze(1)
        src_ngh_t_embed = src_ngh_t_embed.expand(-1, num_neighbors, -1)

        # Slots left at index 0 by the neighbour finder are padding.
        mask = src_ngh_node_batch == 0

        local, _ = self.attn_model_list[curr_layers - 1](
            src_node_conv_feat.unsqueeze(1),
            self.time_encoder(torch.zeros_like(cut_time_l)),
            src_ngh_feat,
            src_ngh_t_embed,
            torch.zeros_like(src_ngh_feat[..., :0]),  # zero-width edge features
            mask.unsqueeze(1)
        )

        return local.squeeze(1)


class NeighborFinder:
    """Flat, offset-indexed store of each node's temporal neighbours.

    ``adj_list[i]`` is a list of ``(neighbour, edge_idx, timestamp)`` entries
    for node ``i``. All entries are packed into flat tensors; ``off_set_l[i]``
    to ``off_set_l[i + 1]`` delimits node ``i``'s slice.
    """
    def __init__(self, adj_list, uniform=False):
        packed = self.init_off_set(adj_list)
        self.node_idx_l, self.node_ts_l, self.edge_idx_l, self.off_set_l = packed
        self.uniform = uniform
        self.pi = PersistenceImage(bandwidth=0.1, weight=lambda x: x[1], im_range=[0, 1, 0, 1],
                                   resolution=[10, 10])

    def init_off_set(self, adj_list):
        """Pack ``adj_list`` into ``(neighbour ids, timestamps, edge ids, offsets)``
        tensors; each node's entries are ordered by edge index."""
        neighbour_ids, timestamps, edge_ids = [], [], []
        offsets = [0]
        for entries in adj_list:
            for entry in sorted(entries, key=lambda item: item[1]):
                neighbour_ids.append(entry[0])
                edge_ids.append(entry[1])
                timestamps.append(entry[2])
            offsets.append(len(neighbour_ids))
        return (
            torch.tensor(neighbour_ids),
            torch.tensor(timestamps),
            torch.tensor(edge_ids),
            torch.tensor(offsets),
        )

    def find_before(self, src_idx, cut_time):
        """Return ``(neighbour ids, edge ids, timestamps)`` of ``src_idx``'s
        entries whose timestamp is strictly below ``cut_time``."""
        window = slice(self.off_set_l[src_idx], self.off_set_l[src_idx + 1])
        window_ts = self.node_ts_l[window]
        earlier = window_ts < cut_time
        return self.node_idx_l[window][earlier], self.edge_idx_l[window][earlier], window_ts[earlier]

    def get_temporal_neighbor(self, src_idx_l, cut_time_l, num_neighbors=20):
        """For each source node, collect up to ``num_neighbors`` neighbours
        seen before its cut time.

        Returns two ``[len(src_idx_l), num_neighbors]`` tensors (neighbour ids
        and timestamps). Rows are right-aligned; unused leading slots stay 0.
        """
        on_gpu = src_idx_l.is_cuda or cut_time_l.is_cuda
        device = src_idx_l.device if on_gpu else torch.device('cpu')

        n_src = len(src_idx_l)
        out_ngh_node = torch.zeros((n_src, num_neighbors), dtype=torch.long, device=device)
        out_ngh_t = torch.zeros((n_src, num_neighbors), dtype=torch.float, device=device)

        for row in range(n_src):
            ngh_idx, _, ngh_ts = self.find_before(src_idx_l[row].item(), cut_time_l[row].item())
            if len(ngh_idx) == 0:
                continue
            # Keep the first num_neighbors entries of the stored order.
            kept_idx = ngh_idx[:num_neighbors]
            kept_ts = ngh_ts[:num_neighbors]
            out_ngh_node[row, -len(kept_idx):] = kept_idx
            out_ngh_t[row, -len(kept_ts):] = kept_ts
        return out_ngh_node, out_ngh_t


class GraphClassifier(torch.nn.Module):
    """Three-layer MLP head (dim -> 80 -> 10 -> 1) producing one logit per row.

    ReLU and dropout follow each of the first two layers.
    """
    def __init__(self, dim, drop=0.3):
        super().__init__()
        hidden_wide, hidden_narrow = 80, 10
        self.fc_1 = torch.nn.Linear(dim, hidden_wide)
        self.fc_2 = torch.nn.Linear(hidden_wide, hidden_narrow)
        self.fc_3 = torch.nn.Linear(hidden_narrow, 1)
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=drop)

    def forward(self, x):
        """Return a flat tensor of logits; a 1-D input is treated as a single row."""
        hidden = x.unsqueeze(0) if x.dim() == 1 else x
        for layer in (self.fc_1, self.fc_2):
            hidden = self.dropout(self.act(layer(hidden)))
        return self.fc_3(hidden).flatten()


def build_neighbor_finder(dataset):
    """
    Build a ``NeighborFinder`` over all graphs of a dataset.

    Graphs are laid out one after another in a single global node-id space
    (graph k's nodes are shifted by the number of nodes in graphs 0..k-1),
    and edges get consecutive global edge ids. An edge's timestamp is the
    smaller of its two endpoint timestamps, read from column 1 of ``x``.
    Every edge is registered in both directions.

    The adjacency list has one slot per node of the dataset, so nodes
    without edges (including trailing ones) can be looked up and simply have
    no neighbours. Previously the list was sized by the largest node id that
    appeared in an edge, and an empty dataset failed inside
    ``np.concatenate``.

    Args:
        dataset: iterable of PyG ``Data`` objects with ``num_nodes``,
            ``edge_index`` and ``x`` (timestamp in column 1).

    Returns:
        NeighborFinder built from the combined adjacency list.
    """
    src_parts, dst_parts, eidx_parts, ts_parts = [], [], [], []
    edge_offset = 0
    node_offset = 0

    for data in dataset:
        edges = data.edge_index.cpu().numpy()
        node_timestamps = data.x[:, 1].cpu().numpy()
        n_edges = edges.shape[1]

        src_parts.append(edges[0] + node_offset)
        dst_parts.append(edges[1] + node_offset)
        eidx_parts.append(np.arange(edge_offset, edge_offset + n_edges))
        ts_parts.append(np.minimum(node_timestamps[edges[0]], node_timestamps[edges[1]]))

        edge_offset += n_edges
        node_offset += data.num_nodes

    if src_parts:
        src_all = np.concatenate(src_parts)
        dst_all = np.concatenate(dst_parts)
        eidx_all = np.concatenate(eidx_parts)
        ts_all = np.concatenate(ts_parts)
    else:
        src_all = dst_all = eidx_all = ts_all = np.zeros(0, dtype=np.int64)

    # Cover every node of the dataset, and also any edge endpoint that lies
    # beyond the declared node counts.
    highest_in_edges = int(max(np.max(src_all), np.max(dst_all))) + 1 if len(src_all) > 0 else 0
    adjacency_list = [[] for _ in range(max(node_offset, highest_in_edges))]

    for s, d, eidx, ts in zip(src_all, dst_all, eidx_all, ts_all):
        adjacency_list[s].append((d, eidx, ts))
        adjacency_list[d].append((s, eidx, ts))

    nf = NeighborFinder(adjacency_list, uniform=False)
    nf.pi = PersistenceImage(bandwidth=0.1, weight=lambda x: x[1],
                             im_range=[0, 1, 0, 1], resolution=[10, 10])
    return nf

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
from torch_geometric.nn import global_mean_pool



def eval_epoch_metrics(tgan, classifier, loader, device, n_layer):
    """
    Evaluate the TGAN encoder + classifier on every batch of ``loader``.

    For each batch, node embeddings from ``tgan.tem_conv`` (node ids
    ``0..num_nodes-1``, all cut times zero) are mean-pooled per graph,
    concatenated with each graph's ``pi`` vector and passed to
    ``classifier``. Predictions are thresholded at logit 0.

    Args:
        tgan: encoder exposing ``tem_conv``.
        classifier: module mapping combined embeddings to logits.
        loader: iterable of PyG batches with ``y`` and per-graph ``pi``.
        device: device the batches are moved to.
        n_layer: depth passed to ``tem_conv`` as ``curr_layers``.

    Returns:
        (accuracy, f1, roc_auc) over the whole loader.
    """
    for module in (tgan, classifier):
        module.eval()

    targets, scores = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            n_nodes = batch.num_nodes
            node_ids = torch.arange(n_nodes, device=device).long()
            cut_times = torch.zeros(n_nodes, device=device).float()

            node_emb = tgan.tem_conv(node_ids, cut_times, curr_layers=n_layer)
            graph_emb = global_mean_pool(node_emb, batch.batch)

            # One persistence-image vector per graph of the batch.
            pi_vecs = torch.stack([graph.pi for graph in batch.to_data_list()]).to(device)
            logits = classifier(torch.cat([graph_emb, pi_vecs], dim=1))

            targets.append(batch.y.cpu().numpy())
            scores.append(logits.cpu().numpy())

    y_true = np.concatenate(targets)
    y_score = np.concatenate(scores)
    y_pred = (y_score > 0).astype(int)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_score)

    return acc, f1, auc

In [None]:
import torch
import numpy as np
import random
import os

def run_training_multiple_runs_tgat(train_loader, val_loader, test_loader,
                                    n_epoch=100, n_layer=2, lr=3e-4,
                                    drop_out=0.1,
                                    num_runs=5, seed_base=42, patience=20,
                                    checkpoint_dir=None, min_delta=0.001):
    """Train the TGAT + persistence-image classifier over several seeded runs.

    For every run a ``TGAN`` encoder and a ``GraphClassifier`` are built. Only
    the classifier is optimised (the encoder stays in eval mode and is not
    handed to the optimiser). Training uses BCE-with-logits loss with early
    stopping on validation AUC; the test metrics measured at the best
    validation epoch are what each run reports.

    Args:
        train_loader, val_loader, test_loader: PyG data loaders. The neighbour
            finder and the node-feature table are built from
            ``train_loader.dataset`` only.
        n_epoch: maximum number of epochs per run.
        n_layer: number of temporal-attention layers built and applied.
        lr: Adam learning rate for the classifier.
        drop_out: dropout probability of the classifier.
        num_runs: number of runs; run ``i`` uses seed ``seed_base + i``.
        seed_base: first seed.
        patience: epochs without validation improvement before stopping.
        checkpoint_dir: directory for checkpoints (default ``"tgat_checkpoints"``).
        min_delta: minimum validation-AUC gain that counts as an improvement.

    Returns:
        dict with ``mean_test_*`` / ``std_test_*`` entries for acc, f1 and auc.

    Checkpoints per run:
        ``run_{run}_best.pth`` - classifier/optimizer state at the best
        validation epoch; ``run_{run}_last.pth`` - full state after the most
        recent epoch (crash recovery). They are separate files so the per-epoch
        save can no longer overwrite the best-model file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if checkpoint_dir is None:
        checkpoint_dir = "tgat_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    all_test_metrics = []
    all_val_metrics = []

    # Precompute neighbor finder & node features
    train_ngh_finder = build_neighbor_finder(train_loader.dataset)
    node_features = torch.cat([data.x for data in train_loader.dataset], dim=0)
    node_features = node_features.detach().clone().float()  # safe for Parameter

    for run in range(num_runs):
        seed = seed_base + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        tgat = TGAN(
            ngh_finder=train_ngh_finder,
            n_feat=node_features.numpy(),
            e_feat=None,
            use_time='time',
            agg_method='attn',
            # Must match the depth passed to tem_conv below; a fixed 2 made
            # any n_layer > 2 index past the end of the layer list.
            num_layers=n_layer,
            n_head=4,
            drop_out=0.1
        ).to(device)

        # 400 = length of the per-graph ``pi`` vector appended to the pooled
        # embedding (assumed from the data; TODO confirm against the PI config).
        classifier = GraphClassifier(dim=tgat.feat_dim + 400, drop=drop_out).to(device)
        optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
        criterion = torch.nn.BCEWithLogitsLoss()

        print(f"\n--- Run {run+1}/{num_runs} (Seed: {seed}) ---")

        # Warm-up forward pass through the encoder.
        dummy_src = torch.arange(10, device=device)
        dummy_time = torch.zeros(10, device=device)
        with torch.no_grad():
            _ = tgat.tem_conv(dummy_src, dummy_time, curr_layers=n_layer)

        # --- Initialize best metrics ---
        best_val_auc = 0.0
        best_test_metrics = {'acc': 0.0, 'f1': 0.0, 'auc': 0.0}
        epochs_no_improve = 0

        # --- Check for existing checkpoint to resume ---
        checkpoint_path = os.path.join(checkpoint_dir, f'run_{run}_best.pth')
        last_checkpoint_path = os.path.join(checkpoint_dir, f'run_{run}_last.pth')
        # NOTE(review): only the best *metrics* are restored here, not the
        # weights or the epoch counter, so a stale file from an earlier notebook
        # execution raises the bar for this run. Confirm this is intended.
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=device)
            best_val_auc = float(checkpoint.get('best_val_auc', 0))
            stored_metrics = checkpoint.get('best_test_metrics', {'acc': 0, 'f1': 0, 'auc': 0})
            best_test_metrics = {key: float(value) for key, value in stored_metrics.items()}
            print(f"Resuming run {run} from checkpoint: best_val_auc={best_val_auc:.4f}")

        for epoch in range(n_epoch):
            # Training
            total_loss = 0
            tgat.eval()
            classifier.train()
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                # NOTE(review): the encoder only receives positions
                # 0..num_nodes-1 and zero cut times, not the batch's own node
                # ids - confirm this lookup is what the experiment intends.
                src_idx = torch.arange(batch.num_nodes, device=device)
                cut_time = torch.zeros(batch.num_nodes, device=device)
                # The encoder is not optimised, so skip building its autograd
                # graph; classifier gradients are unaffected.
                with torch.no_grad():
                    node_emb = tgat.tem_conv(src_idx, cut_time, curr_layers=n_layer)
                graph_emb = global_mean_pool(node_emb, batch.batch)
                pi_vecs = torch.stack([g.pi for g in batch.to_data_list()]).to(device)
                combined_emb = torch.cat([graph_emb, pi_vecs], dim=1)

                logits = classifier(combined_emb)
                loss = criterion(logits, batch.y.float())
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * getattr(batch, "num_graphs", batch.y.size(0))

            train_loss = total_loss / len(train_loader.dataset)

            # Validation & Test. Cast to plain floats: numpy scalars inside a
            # checkpoint are rejected by torch.load when weights_only=True.
            val_acc, val_f1, val_auc = map(
                float, eval_epoch_metrics(tgat, classifier, val_loader, device, n_layer))
            test_acc, test_f1, test_auc = map(
                float, eval_epoch_metrics(tgat, classifier, test_loader, device, n_layer))
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f} | Val AUC: {val_auc:.4f} | Test AUC: {test_auc:.4f}")

            # --- Save best model ---
            if val_auc > best_val_auc + min_delta:
                best_val_auc = val_auc
                best_test_metrics = {'acc': test_acc, 'f1': test_f1, 'auc': test_auc}
                epochs_no_improve = 0
                torch.save({
                    'classifier_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_auc': best_val_auc,
                    'best_test_metrics': best_test_metrics,
                    'epoch': epoch
                }, checkpoint_path)
            else:
                epochs_no_improve += 1

            # --- Save every epoch (for crash recovery) ---
            # Written to its own file so it never clobbers the best checkpoint.
            torch.save({
                'tgat_state_dict': tgat.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_auc': best_val_auc,
                'best_test_metrics': best_test_metrics,
                'epochs_no_improve': epochs_no_improve,
                'epoch': epoch
            }, last_checkpoint_path)

            # Early stopping
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} for run {run+1}")
                break

        # The in-memory best metrics equal what the best checkpoint stores, so
        # no reload is needed (and nothing breaks when no checkpoint exists).
        all_val_metrics.append(best_val_auc)
        all_test_metrics.append(best_test_metrics)

        del tgat, classifier, optimizer
        torch.cuda.empty_cache()

    # Aggregate metrics
    test_accs = np.array([m['acc'] for m in all_test_metrics])
    test_f1s = np.array([m['f1'] for m in all_test_metrics])
    test_aucs = np.array([m['auc'] for m in all_test_metrics])

    results = {
        'mean_test_acc': float(test_accs.mean()),
        'std_test_acc': float(test_accs.std()),
        'mean_test_f1': float(test_f1s.mean()),
        'std_test_f1': float(test_f1s.std()),
        'mean_test_auc': float(test_aucs.mean()),
        'std_test_auc': float(test_aucs.std())
    }

    print(f"\n{'='*60}")
    print(f"FINAL SUMMARY OVER {num_runs} RUNS")
    print(f"{'='*60}")
    print(f"Accuracy: {results['mean_test_acc']*100:.2f}% ± {results['std_test_acc']*100:.2f}%")
    print(f"F1 Score: {results['mean_test_f1']*100:.2f}% ± {results['std_test_f1']*100:.2f}%")
    print(f"AUC:      {results['mean_test_auc']*100:.2f}% ± {results['std_test_auc']*100:.2f}%")

    return results



In [None]:
# Run with all features
# Ten seeded runs (seeds 42..51) of the PI-augmented model; checkpoint_dir=None
# makes the training function fall back to its default "tgat_checkpoints" folder.
results = run_training_multiple_runs_tgat(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    n_epoch=100,
    n_layer=2,
    lr=1e-3,
    drop_out=0.1,
    num_runs=10,
    seed_base=42,
    patience=20,
    min_delta=0.001,
    checkpoint_dir=None
)

# Mean ± standard deviation of the test AUC across runs, as fractions in [0, 1].
print(f"Final Results: {results['mean_test_auc']:.4f} ± {results['std_test_auc']:.4f}")


## Ablation

In [None]:
import pickle
import os

# Reload both pickled NetworkX graph collections from disk so the ablation
# works on its own copies rather than on objects touched by earlier cells.
# NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
filepath = "../collection/graphs/networkx/concrete/graphs_concrete.pkl"


with open(filepath, 'rb') as f:
    graphs_concrete_ablation = pickle.load(f)
print(f"Loaded {len(graphs_concrete_ablation)} graphs from {filepath}")


filepath = '../collection/graphs/networkx/rocky/graphs_rocky.pkl'


with open(filepath, 'rb') as f:
    graphs_rocky_ablation = pickle.load(f)
print(f"Loaded {len(graphs_rocky_ablation)} graphs from {filepath}")

# NetworkX -> PyG Data objects (node features are [pitch, timestamp]).
graphs_concrete_pyg_ablation = convert_nx_to_pyg(graphs_concrete_ablation)
graphs_rocky_pyg_ablation = convert_nx_to_pyg(graphs_rocky_ablation)

# remap_node_indices is defined in an earlier cell; presumably it renumbers
# node ids per graph - verify against its definition.
graphs_concrete_pyg_ablation = remap_node_indices(graphs_concrete_pyg_ablation)
graphs_rocky_pyg_ablation = remap_node_indices(graphs_rocky_pyg_ablation)

# Binary class labels: concrete -> 0, rocky -> 1.
graphs_concrete_pyg_ablation = add_graph_labels(graphs_concrete_pyg_ablation, label=0)
graphs_rocky_pyg_ablation = add_graph_labels(graphs_rocky_pyg_ablation, label=1)

def combine_graphs(graphs_label_dict):
    """Flatten a ``{label: graphs}`` mapping into one new list.

    Graphs appear in the mapping's iteration order; the input containers are
    not modified.
    """
    return [
        graph
        for graph_group in graphs_label_dict.values()
        for graph in graph_group
    ]

# Class label -> graphs of that class (0 = concrete, 1 = rocky).
graphs_dict_with_labels_ablation = dict([
    (0, graphs_concrete_pyg_ablation),
    (1, graphs_rocky_pyg_ablation),
])

# Single flat list holding every labelled graph, concrete first.
all_graphs_ablation = combine_graphs(graphs_dict_with_labels_ablation)

In [None]:
from sklearn.model_selection import train_test_split

# Split the ablation graphs with the project's split_graphs helper (fixed seed).
splits_ablation = split_graphs(
    all_graphs_ablation,
    test_size=0.2,
    val_size=0.25,
    random_state=42,
)

train_graphs_ablation, val_graphs_ablation, test_graphs_ablation = (
    splits_ablation[part] for part in ("train", "val", "test")
)

# Split sizes.
print(f"Train: {len(train_graphs_ablation)}")
print(f"Validation: {len(val_graphs_ablation)}")
print(f"Test: {len(test_graphs_ablation)}")

# Quick sanity check on one training graph and on the first graph overall.
sample_ablation = train_graphs_ablation[0]
print(f"Node dim: {sample_ablation.x.shape}")
print(type(all_graphs_ablation[0]))
print(all_graphs_ablation[0])

In [None]:
# torch_geometric.data.DataLoader is the deprecated location of this class;
# prefer torch_geometric.loader and fall back only on old PyG versions.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader

# Mini-batches of 4 graphs; only the training loader is shuffled.
train_loader_ablation = DataLoader(train_graphs_ablation, batch_size=4, shuffle=True)
val_loader_ablation = DataLoader(val_graphs_ablation, batch_size=4)
test_loader_ablation = DataLoader(test_graphs_ablation, batch_size=4)

In [None]:
# Inspect which attributes the first training graph carries.
# NOTE(review): `keys` is called as a method here; some PyG versions expose it
# as a property instead - confirm against the installed version.
graph_ablation = train_loader_ablation.dataset[0]
print("Standard attributes via keys():", graph_ablation.keys())

## TGAT without PI

In [None]:
import numpy as np
import torch
import logging
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from sklearn.metrics import roc_auc_score


class StandardizeDimensions(nn.Module):
    """Ensures inputs are always [batch, seq_len, features] or [batch, features].

    Lists and tuples are processed element-wise and returned as a list.
    Non-tensor inputs are converted with ``torch.tensor``.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, min_dims=2):
        """Return ``x`` with at least ``min_dims`` dimensions (2-D or 3-D only).

        A 2-D ``[batch, features]`` tensor requested as 3-D gets a length-1
        sequence axis in the middle: ``[batch, 1, features]``.

        Raises:
            ValueError: if the result is neither 2-D nor 3-D.
        """
        if isinstance(x, (list, tuple)):
            return [self.forward(tensor, min_dims) for tensor in x]
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)

        # Promote scalars / vectors to (at most) 2-D with leading axes first.
        while x.dim() < min(min_dims, 2):
            x = x.unsqueeze(0)
        # Insert the sequence axis *after* the batch axis. Padding at dim 0
        # first (as before) produced [1, batch, features] and left this
        # branch unreachable.
        if x.dim() == 2 and min_dims >= 3:
            x = x.unsqueeze(1)
        # Any remaining deficit (min_dims > 3) is padded in front and then
        # rejected below, as it always was.
        while x.dim() < min_dims:
            x = x.unsqueeze(0)

        if x.dim() in (2, 3):
            return x
        raise ValueError(f"Unexpected shape {x.shape}. Max 3D tensors supported")


class MergeLayer(nn.Module):
    """Two-layer MLP applied to the feature-wise concatenation of two inputs.

    ``dim1`` / ``dim2`` are the widths of the two inputs, ``dim3`` the hidden
    width and ``dim4`` the output width. Both weight matrices use Xavier-normal
    initialisation.
    """
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        joint_width = dim1 + dim2
        self.fc1 = nn.Linear(joint_width, dim3)
        self.fc2 = nn.Linear(dim3, dim4)
        self.act = nn.ReLU()
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along dim 1 and run the MLP."""
        joined = torch.cat((x1, x2), dim=1)
        hidden = self.fc1(joined)
        return self.fc2(self.act(hidden))


class ScaledDotProductAttention(nn.Module):
    """Dot-product attention with a fixed temperature and dropout on the weights."""
    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Return ``(output, attention_weights)``.

        2-D ``q`` / ``k`` / ``v`` / ``mask`` get a length-1 axis inserted at
        dim 1. Positions where ``mask`` is True are filled with -1e10 before
        the softmax over the key axis.
        """
        q, k, v = (t.unsqueeze(1) if t.dim() == 2 else t for t in (q, k, v))

        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            mask_3d = mask.unsqueeze(1) if mask.dim() == 2 else mask
            scores = scores.masked_fill(mask_3d, -1e10)

        attn = self.dropout(self.softmax(scores))
        return torch.bmm(attn, v), attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer normalisation.

    Queries, keys and values are projected to ``n_head`` heads of width
    ``d_k`` / ``d_k`` / ``d_v``; head outputs are concatenated, projected back
    to ``d_model``, passed through dropout and added to the original query.
    """
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.std_dims = StandardizeDimensions()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        # Normal init with std sqrt(2 / (fan_in + per-head width)).
        for proj, width in ((self.w_qs, d_k), (self.w_ks, d_k), (self.w_vs, d_v)):
            nn.init.normal_(proj.weight, mean=0, std=np.sqrt(2.0 / (d_model + width)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Return ``(layer_norm(attended + q), attention_weights)``."""
        q, k, v = self.std_dims([q, k, v], min_dims=3)
        if mask is not None:
            mask = self.std_dims(mask, min_dims=3)

        skip = q
        n_batch, n_query, _ = q.size()
        n_key = k.size(1)
        heads = self.n_head

        def to_heads(projected, length, width):
            # [batch, length, heads * width] -> [heads * batch, length, width]
            split = projected.view(n_batch, length, heads, width)
            return split.permute(2, 0, 1, 3).contiguous().view(-1, length, width)

        q = to_heads(self.w_qs(q), n_query, self.d_k)
        k = to_heads(self.w_ks(k), n_key, self.d_k)
        v = to_heads(self.w_vs(v), n_key, self.d_v)

        if mask is not None:
            # Same mask for every head, tiled along the fused head*batch axis.
            mask = mask.repeat(heads, 1, 1)

        context, attn = self.attention(q, k, v, mask=mask)

        # [heads * batch, n_query, d_v] -> [batch, n_query, heads * d_v]
        context = context.view(heads, n_batch, n_query, self.d_v)
        context = context.permute(1, 2, 0, 3).contiguous().view(n_batch, n_query, -1)

        context = self.dropout(self.fc(context))
        return self.layer_norm(context + skip), attn


class TimeEncode(nn.Module):
    """Learnable cosine time encoding: ``cos(t * basis_freq + phase)``.

    ``basis_freq`` starts at ``1 / 10 ** linspace(0, 9, expand_dim)`` and
    ``phase`` at zero; both are trainable parameters.
    """
    def __init__(self, expand_dim, factor=5):
        super().__init__()
        self.factor = factor
        init_freq = 1 / 10 ** np.linspace(0, 9, expand_dim)
        self.basis_freq = nn.Parameter(torch.from_numpy(init_freq).float())
        self.phase = nn.Parameter(torch.zeros(expand_dim).float())

    def forward(self, ts):
        """Map ``ts`` of shape [N] or [N, L] to encodings of shape [N, L, expand_dim]."""
        if ts.dim() == 1:
            ts = ts.unsqueeze(-1)
        rows, cols = ts.size(0), ts.size(1)
        scaled = ts.view(rows, cols, 1) * self.basis_freq.view(1, 1, -1)
        return torch.cos(scaled + self.phase.view(1, 1, -1))


class PosEncode(nn.Module):
    """Positional alternative to ``TimeEncode``.

    Looks up a learned embedding for each entry's index in the per-row
    ``argsort`` of the timestamps.
    """
    def __init__(self, expand_dim, seq_len):
        super().__init__()
        self.pos_embeddings = nn.Embedding(seq_len, expand_dim)

    def forward(self, ts):
        """Return embeddings of ``ts.argsort(dim=1)``; 1-D input is treated as [N, 1]."""
        ts_2d = ts.unsqueeze(-1) if ts.dim() == 1 else ts
        return self.pos_embeddings(ts_2d.argsort(dim=1))


class LSTMPool(nn.Module):
    """Aggregate a neighbour sequence with an LSTM, then merge with the source.

    The LSTM input width is ``feat_dim + edge_dim + time_dim``; its final
    hidden state is combined with the source features through a ``MergeLayer``.
    """
    def __init__(self, feat_dim, edge_dim, time_dim):
        super().__init__()
        self.feat_dim = feat_dim
        self.time_dim = time_dim
        self.edge_dim = edge_dim
        self.att_dim = feat_dim + edge_dim + time_dim
        self.act = nn.ReLU()
        self.lstm = nn.LSTM(input_size=self.att_dim, hidden_size=feat_dim, num_layers=1, batch_first=True)
        self.merger = MergeLayer(feat_dim, feat_dim, feat_dim, feat_dim)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """Return ``(merged_embedding, None)``.

        ``src_t`` and ``mask`` are accepted for interface parity with the other
        aggregators and are not used.
        """
        if seq.dim() == 2:
            seq = seq.unsqueeze(1)
        if seq_t.dim() == 2:
            seq_t = seq_t.unsqueeze(1)

        parts = [seq]
        # The LSTM is sized for feat + edge + time, so edge features must be
        # part of its input whenever edge_dim > 0 (they used to be dropped,
        # which made every such configuration fail with a size mismatch).
        if self.edge_dim > 0:
            if seq_e.dim() == 2:
                seq_e = seq_e.unsqueeze(1)
            parts.append(seq_e)
        parts.append(seq_t)
        seq_x = torch.cat(parts, dim=2)

        _, (hn, _) = self.lstm(seq_x)

        # The final hidden state is 2-D; accept a [batch, 1, feat] source too.
        if src.dim() == 3:
            src = src.squeeze(1)
        return self.merger(hn[-1], src), None


class MeanPool(nn.Module):
    """Aggregate neighbours by averaging their node + edge features.

    The mean over the neighbour axis is merged with the source features by a
    ``MergeLayer``.
    """
    def __init__(self, feat_dim, edge_dim):
        super().__init__()
        self.edge_dim = edge_dim
        self.feat_dim = feat_dim
        self.act = nn.ReLU()
        self.merger = MergeLayer(edge_dim + feat_dim, feat_dim, feat_dim, feat_dim)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """Return ``(merged_embedding, None)``; time inputs and ``mask`` are ignored."""
        seq, seq_e = (t.unsqueeze(1) if t.dim() == 2 else t for t in (seq, seq_e))
        neighbour_mean = torch.cat([seq, seq_e], dim=2).mean(dim=1)
        return self.merger(neighbour_mean, src), None


class AttnModel(nn.Module):
    """Attention aggregator: the source node attends over its neighbours.

    The query is ``[src, src_t]`` and keys/values are ``[seq, seq_t, seq_e]``;
    the attended vector is merged with the source features.
    """
    def __init__(self, feat_dim, edge_dim, time_dim, n_head=2, drop_out=0.1):
        super().__init__()
        self.feat_dim = feat_dim
        self.time_dim = time_dim
        self.edge_in_dim = feat_dim + edge_dim + time_dim
        self.model_dim = self.edge_in_dim
        self.std_dims = StandardizeDimensions()
        self.merger = MergeLayer(self.model_dim, feat_dim, feat_dim, feat_dim)
        # Each head gets an equal slice of the model width.
        assert self.model_dim % n_head == 0

        head_width = self.model_dim // n_head
        self.multi_head_target = MultiHeadAttention(
            n_head=n_head,
            d_model=self.model_dim,
            d_k=head_width,
            d_v=head_width,
            dropout=drop_out
        )
        self.dropout = nn.Dropout(drop_out)

    def forward(self, src, src_t, seq, seq_t, seq_e, mask):
        """Return ``(merged_embedding, attention_weights)``."""
        src, src_t, seq, seq_t, seq_e = self.std_dims(
            [src, src_t, seq, seq_t, seq_e], min_dims=3
        )

        n_neighbours = seq.size(1)
        assert seq_t.size(1) == n_neighbours
        assert seq_e.size(1) == n_neighbours

        query = torch.cat([src, src_t], dim=-1)
        keys = torch.cat([seq, seq_t, seq_e], dim=-1)

        attended, weights = self.multi_head_target(q=query, k=keys, v=keys, mask=mask)
        return self.merger(attended.squeeze(1), src.squeeze(1)), weights


class TGAN(torch.nn.Module):
    """Temporal graph attention encoder (version used in the no-PI ablation).

    Raw node features are stored in a frozen embedding table, projected to
    ``feat_dim = max(64, raw_feat_dim)`` and refined by ``num_layers``
    neighbourhood-aggregation layers (``tem_conv``).

    Args:
        ngh_finder: object providing ``get_temporal_neighbor``.
        n_feat: 2-D array-like of raw node features (one row per node id).
        e_feat: edge features or None; only its width is recorded.
        use_time: ``'time'`` (cosine encoding) or ``'pos'`` (positional).
        agg_method: ``'attn'``, ``'lstm'`` or ``'mean'``.
        num_layers: number of aggregation layers.
        n_head: attention heads for ``agg_method='attn'``.
        null_idx: stored but not otherwise used in this class.
        drop_out: dropout for the attention layers and the built-in head.
        seq_len: table size for ``use_time='pos'``.
        node_dim, time_dim, num_heads: accepted but unused.
    """
    def __init__(self, ngh_finder, n_feat, e_feat, use_time='time', agg_method='attn',
                 node_dim=None, time_dim=None, num_layers=3, n_head=1,
                 null_idx=0, num_heads=1, drop_out=0.1, seq_len=None):
        super(TGAN, self).__init__()

        self.raw_feat_dim = n_feat.shape[1]
        self.n_feat = torch.nn.Parameter(torch.tensor(n_feat, dtype=torch.float32))

        # Frozen lookup table over the raw features; row 0 is the padding row.
        self.node_raw_embed = torch.nn.Embedding.from_pretrained(
            self.n_feat, padding_idx=0, freeze=True)
        self.feat_dim = max(64, self.raw_feat_dim)

        # Lift narrow raw features to the working width.
        if self.raw_feat_dim != self.feat_dim:
            self.proj = nn.Sequential(
                nn.Linear(self.raw_feat_dim, self.feat_dim),
                nn.ReLU()
            )
        else:
            self.proj = torch.nn.Identity()

        self.num_layers = num_layers
        self.ngh_finder = ngh_finder
        self.null_idx = null_idx
        self.logger = logging.getLogger(__name__)
        self.n_feat_dim = self.feat_dim
        self.e_feat_dim = 0 if e_feat is None else e_feat.shape[1]
        self.model_dim = self.feat_dim + 1
        self.use_time = use_time
        self.merge_layer = MergeLayer(self.feat_dim, self.feat_dim, self.feat_dim, self.feat_dim)

        # NOTE(review): the 'lstm' and 'mean' aggregators are built with
        # edge_dim = feat_dim, while tem_conv hands them zero-width edge
        # features - those two paths look untested; verify before use.
        if agg_method == 'attn':
            self.attn_model_list = nn.ModuleList([
                AttnModel(self.feat_dim, 0, self.feat_dim, n_head=n_head, drop_out=drop_out)
                for _ in range(num_layers)
            ])
        elif agg_method == 'lstm':
            self.attn_model_list = nn.ModuleList([
                LSTMPool(self.feat_dim, self.feat_dim, self.feat_dim)
                for _ in range(num_layers)
            ])
        elif agg_method == 'mean':
            self.attn_model_list = nn.ModuleList([
                MeanPool(self.feat_dim, self.feat_dim)
                for _ in range(num_layers)
            ])

        if use_time == 'time':
            self.time_encoder = TimeEncode(expand_dim=self.feat_dim)
        elif use_time == 'pos':
            self.time_encoder = PosEncode(expand_dim=self.feat_dim, seq_len=seq_len)

        # Built-in binary head over a pooled graph embedding (feat_dim inputs).
        self.classifier = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim // 2),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(self.feat_dim // 2, 1)
        )

    def forward(self, *args):
        """Alias for ``forward_graph_classification``."""
        return self.forward_graph_classification(*args)

    def forward_graph_classification(self, batch):
        """Return one logit per graph in ``batch`` (shape [num_graphs, 1]).

        Node embeddings from ``tem_conv`` are mean-pooled per graph and passed
        to the built-in head. No persistence-image (or substitute noise)
        features are appended: the head is declared with ``feat_dim`` inputs,
        so the former 400-wide noise concatenation could never run.
        """
        device = self.n_feat.device
        src_idx_l = torch.arange(batch.num_nodes, device=device, dtype=torch.long)
        cut_time_l = torch.zeros(batch.num_nodes, device=device, dtype=torch.float32)

        # tem_conv already returns feat_dim-wide embeddings; projecting them a
        # second time fed feat_dim features into a raw_feat_dim-wide Linear.
        node_embeddings = self.tem_conv(src_idx_l, cut_time_l, self.num_layers)
        graph_embeddings = global_mean_pool(node_embeddings, batch.batch)

        return self.classifier(graph_embeddings)

    def tem_conv(self, src_idx_l, cut_time_l, curr_layers, num_neighbors=20):
        """Recursively compute ``curr_layers``-hop embeddings for ``src_idx_l``.

        Layer 0 is the projected raw feature. Each further layer aggregates
        the previous-layer embeddings of up to ``num_neighbors`` neighbours
        returned by the neighbour finder. Returns a [N, feat_dim] tensor.
        """
        if cut_time_l.dim() == 1:
            cut_time_l = cut_time_l.unsqueeze(-1)

        raw_embed = self.node_raw_embed(src_idx_l)
        src_node_feat = self.proj(raw_embed)
        if curr_layers == 0:
            return src_node_feat

        # Previous-layer embedding of the source nodes themselves.
        src_node_conv_feat = self.tem_conv(
            src_idx_l,
            cut_time_l,
            curr_layers - 1,
            num_neighbors
        )

        # NOTE(review): the neighbour timestamps returned here are discarded;
        # the neighbour time embedding below is derived from the cut times
        # only. Confirm this is intended.
        src_ngh_node_batch, _ = self.ngh_finder.get_temporal_neighbor(
            src_idx_l,
            cut_time_l.squeeze(-1),
            num_neighbors=num_neighbors
        )

        # Previous-layer embeddings of all neighbours: [N, num_neighbors, feat].
        src_ngh_feat = self.tem_conv(
            src_ngh_node_batch.flatten(),
            cut_time_l.repeat_interleave(num_neighbors, dim=0),
            curr_layers - 1,
            num_neighbors
        ).view(src_idx_l.size(0), num_neighbors, -1)

        src_ngh_t_embed = self.time_encoder(cut_time_l - cut_time_l.mean())

        if src_ngh_t_embed.dim() == 2:
            src_ngh_t_embed = src_ngh_t_embed.unsqueeze(1)
        src_ngh_t_embed = src_ngh_t_embed.expand(-1, num_neighbors, -1)

        # Index 0 marks an empty (padding) neighbour slot.
        mask = src_ngh_node_batch == 0

        local, _ = self.attn_model_list[curr_layers - 1](
            src_node_conv_feat.unsqueeze(1),
            self.time_encoder(torch.zeros_like(cut_time_l)),
            src_ngh_feat,
            src_ngh_t_embed,
            torch.zeros_like(src_ngh_feat[..., :0]),  # zero-width edge features
            mask.unsqueeze(1)
        )

        return local.squeeze(1)
    
class NeighborFinder:
    """Flat (CSR-style) storage of a temporal adjacency list with lookups.

    ``adj_list[i]`` holds ``(neighbour, edge_idx, timestamp)`` entries for
    node ``i``. Entries of node ``i`` occupy
    ``off_set_l[i]:off_set_l[i + 1]`` in the flat tensors.
    """
    def __init__(self, adj_list, uniform=False):
        flat = self.init_off_set(adj_list)
        self.node_idx_l, self.node_ts_l, self.edge_idx_l, self.off_set_l = flat
        self.uniform = uniform

    def init_off_set(self, adj_list):
        """Flatten ``adj_list`` into (neighbour ids, timestamps, edge ids, offsets).

        Each node's entries are sorted by edge index before flattening.
        """
        n_idx_l, n_ts_l, e_idx_l = [], [], []
        off_set_l = [0]
        for node in range(len(adj_list)):
            ordered = sorted(adj_list[node], key=lambda entry: entry[1])
            for entry in ordered:
                n_idx_l.append(entry[0])
                e_idx_l.append(entry[1])
                n_ts_l.append(entry[2])
            off_set_l.append(len(n_idx_l))
        return torch.tensor(n_idx_l), torch.tensor(n_ts_l), torch.tensor(e_idx_l), torch.tensor(off_set_l)

    def find_before(self, src_idx, cut_time):
        """Return (neighbour ids, edge ids, timestamps) of ``src_idx`` with ts < ``cut_time``."""
        window = slice(self.off_set_l[src_idx], self.off_set_l[src_idx + 1])
        window_ts = self.node_ts_l[window]
        earlier = window_ts < cut_time
        return self.node_idx_l[window][earlier], self.edge_idx_l[window][earlier], window_ts[earlier]

    def get_temporal_neighbor(self, src_idx_l, cut_time_l, num_neighbors=20):
        """Collect up to ``num_neighbors`` earlier neighbours per source node.

        Returns two [N, num_neighbors] tensors (node ids as long, timestamps as
        float). Rows are right-aligned; unused slots stay 0.
        """
        on_gpu = src_idx_l.is_cuda or cut_time_l.is_cuda
        device = src_idx_l.device if on_gpu else torch.device('cpu')

        n_src = len(src_idx_l)
        out_ngh_node = torch.zeros((n_src, num_neighbors), dtype=torch.long, device=device)
        out_ngh_t = torch.zeros((n_src, num_neighbors), dtype=torch.float, device=device)

        for row in range(n_src):
            ngh_idx, _, ngh_ts = self.find_before(src_idx_l[row].item(), cut_time_l[row].item())
            if len(ngh_idx) == 0:
                continue
            picked = ngh_idx[:num_neighbors]
            picked_ts = ngh_ts[:num_neighbors]
            out_ngh_node[row, -len(picked):] = picked
            out_ngh_t[row, -len(picked_ts):] = picked_ts
        return out_ngh_node, out_ngh_t


class GraphClassifier(torch.nn.Module):
    """MLP head ``dim -> 80 -> 10 -> 1`` producing one logit per graph.

    ReLU and dropout follow each of the first two layers.
    """
    def __init__(self, dim, drop=0.3):
        super().__init__()
        self.fc_1 = torch.nn.Linear(dim, 80)
        self.fc_2 = torch.nn.Linear(80, 10)
        self.fc_3 = torch.nn.Linear(10, 1)
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=drop)

    def forward(self, x):
        """Return a flat tensor of logits; a 1-D input is treated as a batch of one."""
        hidden = x.unsqueeze(0) if x.dim() == 1 else x
        for layer in (self.fc_1, self.fc_2):
            hidden = self.dropout(self.act(layer(hidden)))
        return self.fc_3(hidden).flatten()


def build_neighbor_finder(dataset):
    """Build one ``NeighborFinder`` over all graphs in ``dataset``.

    Graphs are laid out back to back: node ids of each graph are shifted by the
    number of nodes seen so far and edge ids are numbered consecutively across
    graphs. An edge's timestamp is the smaller of its two endpoint timestamps,
    read from column 1 of ``data.x``. Every edge is stored in both directions.

    Args:
        dataset: iterable of graph objects exposing ``num_nodes``,
            ``edge_index`` (2 x E) and ``x``.

    Returns:
        NeighborFinder with one adjacency slot per node of the dataset
        (at least one slot, even for an empty dataset).
    """
    src_all, dst_all, eidx_all, ts_all = [], [], [], []
    offset = 0
    node_offset = 0

    for data in dataset:
        n_nodes = data.num_nodes
        edges = data.edge_index.cpu().numpy()
        node_timestamps = data.x[:, 1].cpu().numpy()
        edge_timestamps = np.minimum(node_timestamps[edges[0]], node_timestamps[edges[1]])
        edge_indices = np.arange(offset, offset + edges.shape[1])
        src_all.append(edges[0] + node_offset)
        dst_all.append(edges[1] + node_offset)
        eidx_all.append(edge_indices)
        ts_all.append(edge_timestamps)
        offset += edges.shape[1]
        node_offset += n_nodes

    if src_all:
        src_all = np.concatenate(src_all)
        dst_all = np.concatenate(dst_all)
        eidx_all = np.concatenate(eidx_all)
        ts_all = np.concatenate(ts_all)
    else:
        # np.concatenate rejects an empty list, so an empty dataset needs
        # explicit empty arrays.
        src_all = np.empty(0, dtype=np.int64)
        dst_all = np.empty(0, dtype=np.int64)
        eidx_all = np.empty(0, dtype=np.int64)
        ts_all = np.empty(0, dtype=np.float32)

    max_node = int(max(np.max(src_all), np.max(dst_all))) if len(src_all) > 0 else 0
    # One slot per node in the dataset, not just up to the largest edge
    # endpoint: trailing isolated nodes must still be valid lookups.
    n_slots = max(int(node_offset), max_node + 1)
    adjacency_list = [[] for _ in range(n_slots)]

    # tolist() converts to plain Python numbers once instead of pushing numpy
    # scalars one at a time.
    for s, d, eidx, ts in zip(src_all.tolist(), dst_all.tolist(),
                              eidx_all.tolist(), ts_all.tolist()):
        adjacency_list[s].append((d, eidx, ts))
        adjacency_list[d].append((s, eidx, ts))

    nf = NeighborFinder(adjacency_list, uniform=False)
    return nf

In [None]:
from tqdm import trange, tqdm
import torch
import numpy as np
import random
import os
from datetime import datetime


def eval_epoch_metrics(tgan, classifier, loader, device, n_layer):
    """Ablation variant: score the model on ``loader`` without PI features.

    Both modules are switched to eval mode; per-graph mean-pooled node
    embeddings from ``tgan.tem_conv`` go straight into ``classifier``.

    Returns:
        (accuracy, f1, roc_auc). Accuracy and F1 threshold the logits at 0;
        AUC uses the raw logits as scores.
    """
    tgan.eval()
    classifier.eval()
    labels, scores = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            n_nodes = batch.num_nodes
            node_ids = torch.arange(n_nodes, device=device).long()
            zero_times = torch.zeros(n_nodes, device=device).float()

            pooled = global_mean_pool(
                tgan.tem_conv(node_ids, zero_times, curr_layers=n_layer),
                batch.batch,
            )
            # No persistence-image vector is appended in this variant.
            logits = classifier(pooled)

            labels.append(batch.y.cpu().numpy())
            scores.append(logits.cpu().numpy())

    y_true = np.concatenate(labels)
    y_score = np.concatenate(scores)
    y_hat = (y_score > 0).astype(int)  # logit > 0  <=>  probability > 0.5

    acc = accuracy_score(y_true, y_hat)
    f1 = f1_score(y_true, y_hat)
    auc = roc_auc_score(y_true, y_score)
    return acc, f1, auc

def _save_checkpoint_atomic(state, path):
    """Write ``state`` to ``path`` through a temp file so a crash mid-write cannot corrupt it."""
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def run_training_multiple_runs_tgat_ablation(train_loader, val_loader, test_loader,
                                    n_epoch=100, n_layer=2, lr=3e-4,
                                    drop_out=0.1,
                                    num_runs=5, seed_base=42, patience=20,
                                    checkpoint_dir=None, min_delta=0.001):
    """
    Train the PI-free (ablation) graph classifier over several seeded runs.

    For each run a TGAN encoder and a ``GraphClassifier`` head are built. Only
    the classifier's parameters are given to the optimizer and the encoder is
    kept in eval mode, so the encoder acts as a fixed feature extractor.
    Model selection uses validation AUC with early stopping; the test metrics
    recorded for a run are those of the epoch with the best validation AUC.

    Checkpoints (per run, inside ``checkpoint_dir``):
        * ``run_{run}_last.pth`` - written after every epoch; holds everything
          needed to continue an interrupted run (model/optimizer state, epoch,
          best-so-far metrics, early-stopping counter).
        * ``run_{run}_best.pth`` - written only when validation AUC improves.

    If ``run_{run}_last.pth`` exists the run is continued from the epoch after
    the stored one (a run that already finished is not retrained and its stored
    metrics are reused). Use a fresh ``checkpoint_dir`` to start from scratch,
    e.g. after changing hyper-parameters. RNG state is not checkpointed, so a
    resumed run is not bit-identical to an uninterrupted one.

    Args:
        train_loader, val_loader, test_loader: batch loaders; ``train_loader.dataset``
            is also used to build the neighbor finder and the node-feature table.
        n_epoch: maximum number of epochs per run.
        n_layer: number of TGAN layers, used both to build the encoder and as
            ``curr_layers`` when embedding.
        lr: Adam learning rate for the classifier head.
        drop_out: dropout passed to ``GraphClassifier``.
        num_runs: number of independent runs (seeds ``seed_base + run``).
        seed_base: first seed.
        patience: epochs without validation improvement before stopping.
        checkpoint_dir: checkpoint directory; ``None`` means ``"tgat_ablation"``.
        min_delta: minimum validation-AUC gain that counts as an improvement.

    Returns:
        dict with ``mean_``/``std_`` entries for ``test_acc``, ``test_f1``,
        ``test_auc`` and ``val_auc`` aggregated over the runs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if checkpoint_dir is None:
        checkpoint_dir = "tgat_ablation"
    os.makedirs(checkpoint_dir, exist_ok=True)

    all_test_metrics = []
    all_val_metrics = []

    # Precompute neighbor finder & node features (shared by every run).
    train_ngh_finder = build_neighbor_finder(train_loader.dataset)
    node_features = torch.cat([data.x for data in train_loader.dataset], dim=0)
    node_features = node_features.detach().clone().float()  # safe for Parameter

    for run in range(num_runs):
        seed = seed_base + run
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        # num_layers follows n_layer so it always matches the curr_layers used below
        # (it was hard-coded to 2, which only agreed with the default n_layer).
        # NOTE(review): the encoder's drop_out stays at 0.1 regardless of `drop_out`;
        # the encoder is always in eval mode here, so this presumably has no effect - confirm.
        tgat = TGAN(
            ngh_finder=train_ngh_finder,
            n_feat=node_features.numpy(),
            e_feat=None,
            use_time='time',
            agg_method='attn',
            num_layers=n_layer,
            n_head=4,
            drop_out=0.1
        ).to(device)

        # No +400 on dim: the PI vector is not concatenated in this ablation.
        classifier = GraphClassifier(dim=tgat.feat_dim, drop=drop_out).to(device)
        optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
        criterion = torch.nn.BCEWithLogitsLoss()

        print(f"\n--- Run {run+1}/{num_runs} (Seed: {seed}) ---")

        # Prewarm GPU
        dummy_src = torch.arange(10, device=device)
        dummy_time = torch.zeros(10, device=device)
        with torch.no_grad():
            _ = tgat.tem_conv(dummy_src, dummy_time, curr_layers=n_layer)

        # --- Initialize best metrics ---
        best_val_auc = 0.0
        best_test_metrics = {'acc': 0.0, 'f1': 0.0, 'auc': 0.0}
        epochs_no_improve = 0
        start_epoch = 0

        # Separate files: the crash-recovery snapshot must never clobber the best model.
        best_path = os.path.join(checkpoint_dir, f'run_{run}_best.pth')
        last_path = os.path.join(checkpoint_dir, f'run_{run}_last.pth')

        # --- Resume: restore the full training state, not just the metrics ---
        if os.path.exists(last_path):
            state = torch.load(last_path, map_location=device)
            tgat.load_state_dict(state['tgat_state_dict'])
            classifier.load_state_dict(state['classifier_state_dict'])
            optimizer.load_state_dict(state['optimizer_state_dict'])
            best_val_auc = state.get('best_val_auc', 0.0)
            best_test_metrics = state.get('best_test_metrics', {'acc': 0.0, 'f1': 0.0, 'auc': 0.0})
            epochs_no_improve = state.get('epochs_no_improve', 0)
            start_epoch = state.get('epoch', -1) + 1
            if epochs_no_improve >= patience:
                # The stored run had already hit early stopping; do not train further.
                start_epoch = n_epoch
            print(f"Resuming run {run} from checkpoint: best_val_auc={best_val_auc:.4f} "
                  f"(next epoch: {start_epoch})")

        for epoch in range(start_epoch, n_epoch):
            # Training: encoder frozen in eval mode, only the classifier head is updated.
            total_loss = 0
            tgat.eval()
            classifier.train()
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()

                # NOTE(review): node ids are batch-local (0..num_nodes-1) while the
                # neighbor finder / feature table cover the concatenated training set -
                # confirm these index spaces are meant to coincide.
                src_idx = torch.arange(batch.num_nodes, device=device)
                cut_time = torch.zeros(batch.num_nodes, device=device)
                node_emb = tgat.tem_conv(src_idx, cut_time, curr_layers=n_layer)
                graph_emb = global_mean_pool(node_emb, batch.batch)

                # PI vector is not concatenated in this ablation.
                logits = classifier(graph_emb)
                loss = criterion(logits, batch.y.float())
                loss.backward()
                optimizer.step()

                # Weight by graphs in the batch so train_loss is a per-graph mean.
                total_loss += loss.item() * getattr(batch, "num_graphs", batch.y.size(0))

            train_loss = total_loss / len(train_loader.dataset)

            # Validation & Test
            val_acc, val_f1, val_auc = eval_epoch_metrics(tgat, classifier, val_loader, device, n_layer)
            test_acc, test_f1, test_auc = eval_epoch_metrics(tgat, classifier, test_loader, device, n_layer)
            print(f"Epoch {epoch}: Train Loss: {train_loss:.4f} | Val AUC: {val_auc:.4f} | Test AUC: {test_auc:.4f}")

            # --- Track / save best model ---
            if val_auc > best_val_auc + min_delta:
                # Plain Python floats keep the checkpoint loadable under
                # torch.load's restricted (weights_only) unpickler.
                best_val_auc = float(val_auc)
                best_test_metrics = {
                    'acc': float(test_acc),
                    'f1': float(test_f1),
                    'auc': float(test_auc),
                }
                epochs_no_improve = 0
                _save_checkpoint_atomic({
                    'tgat_state_dict': tgat.state_dict(),
                    'classifier_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_auc': best_val_auc,
                    'best_test_metrics': best_test_metrics,
                    'epoch': epoch
                }, best_path)
            else:
                epochs_no_improve += 1

            # --- Save every epoch (for crash recovery), after the trackers are updated ---
            _save_checkpoint_atomic({
                'tgat_state_dict': tgat.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_auc': best_val_auc,
                'best_test_metrics': best_test_metrics,
                'epochs_no_improve': epochs_no_improve,
                'epoch': epoch
            }, last_path)

            # Early stopping
            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch} for run {run+1}")
                break

        # The best metrics are already in memory; no need to re-read a checkpoint
        # (which may not exist, e.g. when n_epoch == 0).
        all_val_metrics.append(best_val_auc)
        all_test_metrics.append(best_test_metrics)

        del tgat, classifier, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Aggregate metrics
    test_accs = np.array([m['acc'] for m in all_test_metrics])
    test_f1s = np.array([m['f1'] for m in all_test_metrics])
    test_aucs = np.array([m['auc'] for m in all_test_metrics])
    val_aucs = np.array(all_val_metrics, dtype=float)

    results = {
        'mean_test_acc': float(test_accs.mean()),
        'std_test_acc': float(test_accs.std()),
        'mean_test_f1': float(test_f1s.mean()),
        'std_test_f1': float(test_f1s.std()),
        'mean_test_auc': float(test_aucs.mean()),
        'std_test_auc': float(test_aucs.std()),
        'mean_val_auc': float(val_aucs.mean()),
        'std_val_auc': float(val_aucs.std())
    }

    print(f"\n{'='*60}")
    print(f"FINAL SUMMARY OVER {num_runs} RUNS")
    print(f"{'='*60}")
    print(f"Accuracy: {results['mean_test_acc']*100:.2f}% ± {results['std_test_acc']*100:.2f}%")
    print(f"F1 Score: {results['mean_test_f1']*100:.2f}% ± {results['std_test_f1']*100:.2f}%")
    print(f"AUC:      {results['mean_test_auc']*100:.2f}% ± {results['std_test_auc']*100:.2f}%")

    return results

In [None]:

# Ablation experiment (no PI vector): 10 seeded runs, up to 100 epochs each.
results = run_training_multiple_runs_tgat_ablation(
    train_loader_ablation,
    val_loader_ablation,
    test_loader_ablation,
    # model / optimisation
    n_layer=2,
    drop_out=0.1,
    lr=1e-3,
    # schedule and early stopping
    n_epoch=100,
    patience=15,
    min_delta=0.001,
    # repetition
    num_runs=10,
    seed_base=42,
    # None -> the function falls back to its "tgat_ablation" directory
    checkpoint_dir=None,
)

# Mean ± std of test AUC across runs, as fractions (the in-function summary prints percentages).
print(f"Final Results: {results['mean_test_auc']:.4f} ± {results['std_test_auc']:.4f}")