In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from scipy.stats import norm
import matplotlib.pyplot as plt
import json
import torch.optim as optim
import copy
from torch.nn.functional import binary_cross_entropy
from sklearn.metrics import precision_recall_fscore_support
from torch.optim import lr_scheduler
from scipy import sparse
import networkx as nx
from pathlib import Path

SEED = 16


def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus all CUDA devices when CUDA is available), then switches
    cuDNN to deterministic mode with autotuning (``benchmark``) disabled.

    Args:
        seed: Integer seed handed to each generator. Defaults to 42.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(SEED)

In [2]:
def import_tensor(filename, dtype, is_sparse=False):
    """Load a SciPy sparse matrix from an ``.npz`` file as a PyTorch tensor.

    Args:
        filename: Path of a file readable by ``scipy.sparse.load_npz``.
        dtype: ``torch.dtype`` of the returned tensor. ``None`` keeps the
            dtype stored in the file.
        is_sparse: When True return a coalesced ``torch.sparse_coo`` tensor,
            otherwise return a dense tensor.

    Returns:
        A 2-D tensor with the matrix contents, in the requested layout and
        dtype.
    """
    matrix = sparse.load_npz(filename)

    if not is_sparse:
        dense = torch.from_numpy(matrix.toarray())
        # Honour `dtype` for dense output as well; previously it was silently
        # ignored here, so dense and sparse loads could disagree on dtype.
        return dense if dtype is None else dense.to(dtype)

    coo = matrix.tocoo()

    # COO indices must be int64: row ids on the first row, column ids on the second.
    indices = torch.vstack([
        torch.from_numpy(coo.row.astype(np.int64)),
        torch.from_numpy(coo.col.astype(np.int64)),
    ])
    values = torch.from_numpy(coo.data)

    # coalesce() merges duplicate coordinates and sorts the indices.
    return torch.sparse_coo_tensor(indices, values, coo.shape, dtype=dtype).coalesce()

def read_data(folder_paths, has_label=True, is_sparse=False):
    """Load one graph per folder.

    Each folder is expected to contain ``X.npz``, ``A.npz``, ``E.npz`` and
    ``edge_index.npy``, plus ``labels.npz`` when ``has_label`` is True.

    Args:
        folder_paths: Iterable of folders, given as ``pathlib.Path`` objects
            or plain path strings.
        has_label: When True also load ``labels.npz``; otherwise the label
            slot of each tuple is ``None``.
        is_sparse: Forwarded to ``import_tensor`` for every ``.npz`` file.

    Returns:
        A list with one ``(X, A, E, edge_index, y)`` tuple per folder.
        ``X``, ``A`` and ``y`` are requested as ``torch.long``, ``E`` as
        ``torch.float32``; ``edge_index`` is the NumPy array returned by
        ``np.load``.
    """
    graph_data = []
    for folder_path in folder_paths:
        # Coerce so that string paths work with the `/` operator too.
        folder = Path(folder_path)

        X = import_tensor(folder / "X.npz", torch.long, is_sparse)
        A = import_tensor(folder / "A.npz", torch.long, is_sparse)
        E = import_tensor(folder / "E.npz", torch.float32, is_sparse)
        edge_index = np.load(folder / "edge_index.npy")
        y = import_tensor(folder / "labels.npz", torch.long, is_sparse) if has_label else None

        graph_data.append((X, A, E, edge_index, y))

    return graph_data

In [4]:
src = Path("../../data/swde_HTMLgraphs/movie/movie")
# Sort the matches: directory traversal order depends on the filesystem, so an
# unsorted listing would make the 400-folder subset below differ between
# machines and runs even though the RNG seeds are fixed.
batchFiles = sorted(src.rglob("[0-9][0-9][0-9][0-9]"))
data = read_data(batchFiles[0:400], True, True)
data[0]

(tensor(indices=tensor([[  0,   0,   1,  ..., 619, 620, 620],
                        [ 37,  95,   9,  ...,  95,   0,  95]]),
        values=tensor([1, 1, 1,  ..., 1, 1, 1]),
        size=(621, 96), nnz=1242, layout=torch.sparse_coo),
 tensor(indices=tensor([[  0,   1,   1,  ..., 618, 619, 620],
                        [  1,   0,   2,  ..., 585, 586, 587]]),
        values=tensor([1, 1, 1,  ..., 1, 1, 1]),
        size=(621, 621), nnz=2758, layout=torch.sparse_coo),
 tensor(indices=tensor([[   0,    0,    0,  ..., 2757, 2757, 2757],
                        [  37,   95,  105,  ...,  192,  195,  196]]),
        values=tensor([ 1.0000,  1.0000,  1.0000,  ...,  1.0000, 12.0312,
                        1.0000]),
        size=(2758, 197), nnz=18566, layout=torch.sparse_coo),
 array([[  0,   1],
        [  1,   0],
        [  1,   2],
        ...,
        [618, 585],
        [619, 586],
        [620, 587]], shape=(2758, 2)),
 tensor(indices=tensor([[257, 257, 257, 257, 257, 257, 257, 257, 257

In [37]:
# Helper function to normalise the A matrix
def symmetric_normalize(A_tilde, eps=1e-5):
    """
    Performs symmetric normalization of A_tilde (Adj. matrix with self loops):
      A_norm = D^{-1/2} * A_tilde * D^{-1/2}
    Where D_{ii} = (sum of row i in A_tilde) + eps.

    A_tilde (N, N): Adj. matrix with self loops. Integer and non-strided
        (e.g. sparse COO) inputs are accepted and converted to a dense
        floating-point tensor first.
    eps: Small constant added to every row sum so that rows summing to zero
        do not produce a division by zero.
    Returns:
      A_norm : (N, N) dense float32 tensor
    """
    if A_tilde.layout != torch.strided:
        A_tilde = A_tilde.to_dense()
    if not A_tilde.is_floating_point():
        # An integer adjacency matrix cannot be multiplied with the float
        # scaling factors, so promote it before doing any arithmetic.
        A_tilde = A_tilde.to(torch.float32)

    inv_sqrt_degree = torch.pow(A_tilde.sum(dim=1) + eps, -0.5)

    # Scaling row i and column j by d_i^{-1/2} and d_j^{-1/2} is exactly
    # D^{-1/2} A D^{-1/2}; broadcasting avoids building two dense diagonal
    # matrices and running two N x N matrix products.
    normalized = inv_sqrt_degree.unsqueeze(1) * A_tilde * inv_sqrt_degree.unsqueeze(0)
    return normalized.to(torch.float32)

In [None]:
class AttentionGCNLayer(nn.Module):
    """Graph layer that aggregates projected node features with masked attention.

    The layer projects the inputs with an affine map (``omega_k``, ``B_k``),
    scores every node pair from a second linear projection (``phi_k``),
    masks pairs that are not connected in ``A`` (self-loops are always kept)
    and mixes the projected features with the softmax-normalised scores.

    Args:
        input_dim: Size of each input node feature vector.
        output_dim: Size of each output node feature vector.
        use_nonlinearity: When True a ReLU is applied to the result.
    """

    def __init__(self, input_dim, output_dim, use_nonlinearity=True):
        super().__init__()
        self.use_nonlinearity = use_nonlinearity

        # Learnable parameters. Bias and projection are drawn from a standard
        # normal; the creation order below fixes the order of RNG draws.
        self.B_k = nn.Parameter(torch.randn(output_dim))
        self.omega_k = nn.Parameter(torch.randn(input_dim, output_dim))
        self.phi_k = nn.Linear(output_dim, output_dim, bias=False)

    def forward(self, H_k, A):
        """Compute the next node representation.

        Args:
            H_k: Node features; rows are nodes, and the column count must
                match ``input_dim``.
            A: Square adjacency matrix over the same nodes; entries equal to
                zero (after adding the identity) are masked out.

        Returns:
            Tensor with one row per node and ``output_dim`` columns.
        """
        node_count = H_k.shape[0]

        # Step 1: affine projection; the bias broadcasts over the node axis.
        projected = self.B_k + H_k @ self.omega_k

        # Step 2: pairwise similarity scores squashed into (0, 1).
        keys = self.phi_k(projected)
        scores = torch.sigmoid(keys @ keys.T)

        # Step 3: keep only connected pairs (plus self-loops), then normalise.
        with_self_loops = A + torch.eye(node_count, device=A.device)
        not_connected = with_self_loops == 0
        masked_scores = scores.masked_fill(not_connected, float('-inf'))
        # NOTE(review): dim=0 makes every *column* sum to one, while the product
        # below mixes features along each *row*; confirm this is intended.
        weights = torch.softmax(masked_scores, dim=0)
        aggregated = weights @ projected

        if self.use_nonlinearity:
            return F.relu(aggregated)
        return aggregated
    
# class FCLayer(nn.Module):
#     def __init__(self, input_dim, output_dim, use_nonlinearity=True):
# #        super().__init__(input_dim, output_dim)
#         super(FCLayer, self).__init__()
#         self.use_nonlinearity = use_nonlinearity

#         self.b = nn.Parameter(torch.randn(output_dim))
#         self.W = nn.Parameter(torch.randn(input_dim, output_dim))

#     def forward(self, X):
#         y = F.linear(X, self.W, self.b)
#         if self.use_nl:
#             y = F.leaky_relu(y)
#         return y

class GraphAttentionNetwork(nn.Module):
    """Two stacked ``AttentionGCNLayer`` blocks followed by a sigmoid read-out.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Output width of the second graph layer; the first graph
            layer is twice as wide.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        wide_dim = hidden_dim * 2
        # Both graph layers are built without the ReLU non-linearity.
        self.gcn1 = AttentionGCNLayer(input_dim, wide_dim, False)
        self.gcn2 = AttentionGCNLayer(wide_dim, hidden_dim, False)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, A, X, **kwargs):
        """Score every row of ``X``.

        Note the argument order: the adjacency matrix comes first here, while
        the graph layers themselves are called as ``layer(features, A)``.

        Args:
            A: Adjacency matrix passed to both graph layers.
            X: Input features passed to the first graph layer.
            **kwargs: ``return_embeddings`` (truthy) additionally returns the
                outputs of both graph layers.

        Returns:
            The sigmoid of the linear read-out, with NaN entries replaced by
            zero; or ``(output, H1, H2)`` when ``return_embeddings`` is truthy.
        """
        H1 = self.gcn1(X, A)
        H2 = self.gcn2(H1, A)
        output = torch.sigmoid(self.linear(H2))

        # NaNs are overwritten with 0 rather than raised.
        nan_positions = torch.isnan(output)
        if nan_positions.any():
            output = torch.where(nan_positions, torch.zeros_like(output), output)

        if not kwargs.get("return_embeddings", None):
            return output
        return output, H1, H2

In [39]:
import copy
import torch
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
from torch.nn.functional import binary_cross_entropy_with_logits as BCEwLogits

IGNORE_LABEL = -1      # sentinel marking "no label"; change it if your data uses another value

# ---------- 1. One training epoch -------------------------------------------
def train_epoch(
    model,
    dataloader,               # iterable that yields (X, A, E, edge_index, y)
    optimizer,
    criterion=BCEwLogits,
    device="cpu"
):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(X, A, E, edge_index)``; its output is
            indexed with the label mask and passed to ``criterion`` as logits.
        dataloader: Iterable yielding ``(X, A, E, edge_index, y)`` tuples.
            ``edge_index`` may be a tensor or a NumPy array. Samples whose
            ``y`` is ``None`` or holds only ``IGNORE_LABEL`` are skipped.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, targets)``. It is
            assumed to be mean-reduced (true for the default), because each
            sample's loss is re-weighted by its number of labelled entries.
        device: Device every tensor of a sample is moved to.

    Returns:
        The loss averaged over all labelled entries seen in the epoch, or
        ``0.0`` when nothing was labelled.
    """
    model.train()
    total_loss, total_edges = 0.0, 0

    for X, A, E, edge_index, y in dataloader:
        if y is None:                       # unlabelled sample: nothing to supervise
            continue

        # Move to device ------------------------------------------------------
        X = X.to(device)
        A = A.to(device)
        E = E.to(device)
        # np.load yields a NumPy array, which has no .to(); as_tensor accepts
        # both arrays and tensors without copying when it can.
        edge_index = torch.as_tensor(edge_index).to(device)
        y = y.to(device)

        mask = (y != IGNORE_LABEL)          # only supervise labelled entries
        labelled = int(mask.sum().item())   # one host sync, reused below
        if labelled == 0:                   # nothing to learn in this sample
            continue

        optimizer.zero_grad()

        logits = model(X, A, E, edge_index)

        loss = criterion(logits[mask], y[mask].float())
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item() * labelled
        total_edges += labelled

    return total_loss / max(total_edges, 1)   # average over labelled entries


# ---------- 2. Full training loop -------------------------------------------
def train_model(
    model,
    train_loader,                 # edge‑level dataloader
    val_loader,                   # edge‑level dataloader
    num_epochs       = 100,
    lr               = 1e-3,
    validate_every   = 10,
    patience         = 10,
    device           = "cpu"
):
    """Train `model` to predict whether an edge exists.

    Uses AdamW and halves the learning rate when the validation F1 stops
    improving. The model is left loaded with the best-F1 weights.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Iterable passed to ``train_epoch`` every epoch.
        val_loader: Iterable passed to ``evaluate_edge_model``.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate.
        validate_every: Validate every this many epochs (and always on the
            last epoch).
        patience: Scheduler patience, counted in *validation rounds* (the
            scheduler is only stepped when validation runs), not in epochs.
        device: Device used for training and evaluation.

    Returns:
        ``(loss_history, best_state)``: the per-epoch training losses and a
        deep copy of the state dict with the highest validation F1.
        ``best_state`` is ``None`` only if no validation ran at all.
    """
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # `verbose` is deprecated since PyTorch 2.2; leaving it out keeps the same
    # silent behaviour and stays compatible with releases that drop the argument.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=patience, factor=0.5
    )

    # Start below any attainable F1 so the first validation always records a
    # checkpoint, even when every validation F1 is exactly 0.
    best_val_f1, best_state = float("-inf"), None
    loss_history = []

    for epoch in range(1, num_epochs + 1):
        loss = train_epoch(model, train_loader, optimizer, device=device)
        loss_history.append(loss)

        # ---- validation -----------------------------------------------------
        if epoch % validate_every != 0 and epoch != num_epochs:
            continue

        val_prec, val_rec, val_f1 = evaluate_edge_model(
            model, val_loader, device=device
        )
        scheduler.step(val_f1)

        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch {epoch:03d}/{num_epochs}  "
            f"loss={loss:.4f}  "
            f"P={val_prec:.3f}  R={val_rec:.3f}  F1={val_f1:.3f}  "
            f"lr={current_lr:.2e}"
        )

        # keep the best‑F1 checkpoint (first one wins on ties)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_state = copy.deepcopy(model.state_dict())

        # early stop if LR too small
        if current_lr < 1e-5:
            print("LR below 1e‑5 → stopping.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return loss_history, best_state


# ---------- 3. Helper: evaluation (precision / recall / F1) -----------------
@torch.no_grad()
def evaluate_edge_model(model, dataloader, device="cpu", thr=0.5):
    """Compute precision, recall and F1 of ``model`` over ``dataloader``.

    Args:
        model: Module called as ``model(X, A, E, edge_index)`` and returning
            logits; it is switched to eval mode.
        dataloader: Iterable yielding ``(X, A, E, edge_index, y)`` tuples.
            ``edge_index`` may be a tensor or a NumPy array. Samples whose
            ``y`` is ``None`` or holds only ``IGNORE_LABEL`` are skipped.
        device: Device every tensor of a sample is moved to.
        thr: Probability threshold at or above which an entry is predicted
            positive.

    Returns:
        ``(precision, recall, f1)`` as floats, accumulated over all labelled
        entries. Each is ``0.0`` when there is nothing to count.
    """
    model.eval()
    tp = fp = fn = 0

    for X, A, E, edge_index, y in dataloader:
        if y is None:                       # unlabelled sample: nothing to score
            continue

        X, A, E = X.to(device), A.to(device), E.to(device)
        # np.load yields a NumPy array, which has no .to(); as_tensor accepts
        # both arrays and tensors.
        edge_index = torch.as_tensor(edge_index).to(device)
        y = y.to(device)

        mask = (y != IGNORE_LABEL)
        if not mask.any():
            continue

        logits = model(X, A, E, edge_index)
        probs = torch.sigmoid(logits)

        pred = (probs >= thr).long()
        tp += ((pred == 1) & (y == 1) & mask).sum().item()
        fp += ((pred == 1) & (y == 0) & mask).sum().item()
        fn += ((pred == 0) & (y == 1) & mask).sum().item()

    # The small constants keep the ratios defined when a denominator is zero.
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    return precision, recall, f1
