In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
import random
from scipy.stats import norm
import matplotlib.pyplot as plt
import json
import torch.optim as optim
import copy
from torch.nn.functional import binary_cross_entropy
from sklearn.metrics import precision_recall_fscore_support
from torch.optim import lr_scheduler
from scipy import sparse
import networkx as nx
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

SEED = 16


def set_seed(seed=42):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global generator and PyTorch
    (all CUDA devices too, when CUDA is available), then switches cuDNN to
    deterministic mode with autotuning (benchmark) disabled.

    seed : integer seed handed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for repeatable kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(SEED)

ModuleNotFoundError: No module named 'torch_sparse'

In [None]:
class LazyGraphDataset(Dataset):
    """
    Map-style dataset in which every graph lives in its own folder on disk.

    Only the folder paths stay in RAM; `__getitem__` reads the files of one
    graph just-in-time.  Each folder is read through these file names:

        X.npz, E.npz, labels.npz                          (scipy sparse .npz)
        edge_index.npy, label_index.npy, label_value.npy  (NumPy .npy)
    """

    def __init__(self, folderpaths):
        """
        folderpaths : iterable of graph folders, given as `str` or path-like
                      objects.  The iterable is consumed once and stored as a
                      list of `pathlib.Path`.
        """
        # Coerce to Path so plain strings also work with the `/` joins in __getitem__.
        self.folderpaths = [Path(p) for p in folderpaths]

    def _import_tensor(self, filename: str, dtype: torch.dtype, is_sparse: bool = False):
        """
        Load a scipy sparse matrix stored as .npz and return either
        • a torch.sparse_csr_tensor              (if is_sparse=True)
        • a torch.Tensor (dense)                 (otherwise)

        filename  : path of the .npz file.
        dtype     : dtype of the values of the returned tensor.
        is_sparse : keep the CSR layout instead of densifying.
        """
        csr = sparse.load_npz(filename).tocsr()

        if is_sparse:
            # torch expects int64 row pointers / column indices.
            crow = torch.from_numpy(csr.indptr.astype(np.int64))
            col  = torch.from_numpy(csr.indices.astype(np.int64))
            val  = torch.from_numpy(csr.data).to(dtype)
            return torch.sparse_csr_tensor(crow, col, val, size=csr.shape, dtype=dtype, requires_grad=False)
        # — otherwise densify —
        return torch.from_numpy(csr.toarray()).to(dtype)

    def __getitem__(self, idx):
        """
        Load graph number `idx` from disk.

        Returns the tuple (X, Aei, Aef, Lei, Lef, y):
            X   : dense torch.long tensor read from X.npz
            Aei : NumPy array read from edge_index.npy
            Aef : float32 sparse CSR tensor read from E.npz
            Lei : NumPy array read from label_index.npy
            Lef : float32 sparse CSR tensor read from labels.npz
            y   : NumPy array read from label_value.npy
        """
        folder_path = self.folderpaths[idx]

        X = self._import_tensor(folder_path / "X.npz", torch.long, is_sparse=False)
        Aef = self._import_tensor(folder_path / "E.npz", torch.float32, is_sparse=True)
        Aei = np.load(folder_path / "edge_index.npy")
        Lef = self._import_tensor(folder_path / "labels.npz", torch.float32, is_sparse=True)
        Lei = np.load(folder_path / "label_index.npy")
        y = np.load(folder_path / "label_value.npy")

        return X, Aei, Aef, Lei, Lef, y

    def __len__(self):
        """Number of graph folders known to the dataset."""
        return len(self.folderpaths)

def graph_collate(batch):
    """
    Collate function that keeps graphs separate instead of stacking them.

    batch : list of B per-graph tuples (X, Aei, Aef, Lei, Lef, y).
    Returns a 6-tuple of lists, each of length B, in the same field order:
    (Xs, Aeis, Aefs, Leis, Lefs, ys).
    """
    # Transpose "list of 6-tuples" into "6 tuples of length B".
    node_feats, edge_indices, edge_feats, label_indices, label_feats, targets = zip(*batch)

    per_field = (node_feats, edge_indices, edge_feats, label_indices, label_feats, targets)
    return tuple(list(field) for field in per_field)

def make_loader(dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=None):
    """
    Wrap `dataset` in a DataLoader that uses `graph_collate`.

    dataset     : dataset yielding per-graph tuples.
    batch_size  : number of graphs per batch.
    shuffle     : reshuffle the graphs every epoch.
    num_workers : worker processes used for loading (0 = load in the main process).
    pin_memory  : whether to pin host memory; `None` (default) enables it only
                  when CUDA is available, which avoids the DataLoader warning
                  on CPU-only machines.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      collate_fn=graph_collate,
                      num_workers=num_workers,
                      pin_memory=pin_memory)

In [11]:
src = Path("../../data/swde_HTMLgraphs/movie/movie")
# rglob yields entries in filesystem order, which differs between machines and runs;
# sort so that dataset index i always refers to the same graph. Only directories
# are kept because LazyGraphDataset joins file names onto each entry.
batchFiles = sorted(p for p in src.rglob("[0-9][0-9][0-9][0-9]") if p.is_dir())
dataset = LazyGraphDataset(batchFiles)
dataloader = make_loader(dataset, batch_size=1, shuffle=False)

In [None]:
# Sanity check: fetch only the first batch and show the shape of every
# component of its first graph.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    Xs, Aeis, Aefs, Leis, Lefs, ys = first_batch
    print(*(field[0].shape for field in (Xs, Aeis, Aefs, Leis, Lefs, ys)))

In [37]:
# Helper function to normalise the A matrix
def symmetric_normalize(A_tilde):
    """
    Performs symmetric normalization of A_tilde (Adj. matrix with self loops):
      A_norm = D^{-1/2} * A_tilde * D^{-1/2}
    Where D_{ii} = sum of row i in A_tilde (plus a small epsilon).

    A_tilde (N, N): dense Adj. matrix with self loops; integer dtypes are
                    accepted and promoted to float32.
    Returns:
      A_norm : (N, N) float32 tensor
    """

    eps = 1e-5  # keeps the inverse square root finite for rows that sum to zero

    # Integer adjacency matrices cannot be multiplied by the float scaling factors.
    A = A_tilde if A_tilde.is_floating_point() else A_tilde.to(torch.float32)

    d_inv_sqrt = torch.pow(A.sum(dim=1) + eps, -0.5)

    # D^{-1/2} A D^{-1/2} only rescales row i and column j of A_ij, so broadcast
    # the scaling instead of building a dense N×N diagonal matrix and doing two matmuls.
    A_norm = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)
    return A_norm.to(torch.float32)

In [None]:
# To advance the model, use the methods in https://arxiv.org/pdf/2311.02921

class GraphAttentionNetwork(nn.Module):
    """
    HTML-graph model: two edge-aware GATv2 layers encode the nodes, then an
    MLP scores every candidate edge.

        X  ─╮
            │  GATv2( in_dim → hidden1 * heads )
            │  ReLU
            │  GATv2( hidden1 * heads → hidden2 )
            │  ReLU
            └─ Edge-feature constructor
                      [h_i ‖ h_j ‖ φ(e_ij)] ─► MLP(2 * hidden2 + edge_emb_dim → 1)

    With the default sizes this is 96 → 64 → 32 for the nodes and an MLP
    input of 32 + 32 + 8 = 72.

    Parameters
    ----------
    in_dim          : node-feature size   (default 96)
    edge_in_dim     : raw edge-feature size (default 197)
    edge_emb_dim    : size of the projected edge features φ(e)
    hidden1         : output channels per head of the first GATv2 layer
    hidden2         : output channels of the second GATv2 layer
    heads           : attention heads of the first GATv2 layer (concatenated)
    """
    def __init__(self,
                 in_dim: int        = 96,
                 edge_in_dim: int   = 197,
                 edge_emb_dim: int  = 8,
                 hidden1: int       = 64,
                 hidden2: int       = 32,
                 heads:  int        = 1):
        super().__init__()

        # ── Node-level encoder (edge-aware) ────────────────────────────
        # fill_value=0.0: the self-loops GATv2Conv adds get all-zero edge
        # features.  A string fill_value is interpreted as a reduce operation
        # ("add", "mean", "min", "max", "mul"), so "zeros" is not accepted.
        self.gat1 = GATv2Conv(in_dim,
                              hidden1,
                              heads=heads,
                              concat=True,
                              edge_dim=edge_emb_dim,
                              fill_value=0.0)

        self.gat2 = GATv2Conv(hidden1 * heads,
                              hidden2,
                              heads=1,
                              concat=True,
                              edge_dim=edge_emb_dim,
                              fill_value=0.0)

        # ── Edge feature projector ──────────────
        # Kept as a bare weight matrix (not nn.Linear) because it is applied
        # to sparse inputs through torch.sparse.mm.
        self.W_edge = nn.Parameter(torch.empty(edge_in_dim, edge_emb_dim))
        nn.init.xavier_uniform_(self.W_edge)

        # ── Edge-level MLP decoder ────────────────────────────────────
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden2 * 2 + edge_emb_dim, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, 1)
        )

    # ---------------------------------------------------------------------

    def forward(
        self,
        x_dense: torch.Tensor,        # (N_nodes, in_dim)        dense node features
        A_edge_index: torch.Tensor,   # (2, nnz_A)               COO indices of graph edges
        A_edge_attr: torch.Tensor,    # (nnz_A, edge_in_dim)     sparse edge features
        E_edge_index: torch.Tensor,   # (2, N_E)                 candidate edges to score
        E_edge_attr: torch.Tensor     # (N_E, edge_in_dim)       sparse candidate features
    ):
        """
        Score every candidate edge.

        Returns a (N_E,) tensor of logits (no sigmoid applied).
        """
        # 1) node features: integer-typed inputs are cast to the parameter
        #    dtype because the GATv2 linear layers only accept floating point.
        x = x_dense.to(self.W_edge.dtype)
        A_edge_emb = torch.sparse.mm(A_edge_attr, self.W_edge)     # (nnz_A , edge_emb_dim)

        # 2) edge-aware GATv2 layers (both layers share the projected edge features)
        h = F.relu(self.gat1(x, A_edge_index, A_edge_emb))
        h = F.relu(self.gat2(h, A_edge_index, A_edge_emb))         # (N_nodes , hidden2)

        # 3) candidate-edge projection  φ(E) = E @ W_edge
        E_edge_emb = torch.sparse.mm(E_edge_attr, self.W_edge)     # (N_E , edge_emb_dim)

        # 4) gather node embeddings and classify
        src, dst = E_edge_index
        z = torch.cat([h[src], h[dst], E_edge_emb], dim=1)         # (N_E , 2 * hidden2 + edge_emb_dim)
        return self.edge_mlp(z).squeeze(-1)                        # (N_E ,) logits

In [None]:
import copy, torch
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
from torch.nn.functional import binary_cross_entropy_with_logits as BCEwLogits

CLIP_NORM = 1.0           # maximum global gradient norm applied before each optimizer step


def _iter_graphs(batch, device):
    """
    Yield one (X, Aei, Aef, Lei, Lef, y) tuple of tensors on `device` for
    every graph contained in `batch`.

    Two batch layouts are accepted:
      * six parallel lists/tuples with one entry per graph (what a loader
        built with a list-returning collate function delivers), or
      * six tensors/arrays describing a single graph.
    NumPy arrays are converted to tensors and both index arrays are cast to
    int64, which is what tensor indexing and message passing expect.
    """
    fields = tuple(batch)
    if isinstance(fields[0], (list, tuple)):
        graphs = zip(*fields)
    else:
        graphs = (fields,)

    for graph in graphs:
        X, Aei, Aef, Lei, Lef, y = (
            (part if isinstance(part, torch.Tensor) else torch.as_tensor(part)).to(device)
            for part in graph
        )
        yield X, Aei.long(), Aef, Lei.long(), Lef, y


# ---------- one epoch --------------------------------------------------------
def train_epoch(model, loader, optimizer,
                criterion=BCEwLogits, device="cpu"):
    """
    Train `model` for one pass over `loader`.

    One optimizer step is taken per loader batch.  When a batch holds several
    graphs, their gradients are accumulated, each graph weighted by its share
    of the labelled edges in the batch (this assumes `criterion` returns a
    mean over edges, as the default does).  Graphs without labelled edges are
    skipped.

    Returns the mean loss per labelled edge, or 0.0 if no labelled edge was seen.
    """
    model.train()
    running_loss, running_edges = 0.0, 0

    for batch in loader:
        graphs = list(_iter_graphs(batch, device))
        batch_edges = sum(graph[-1].numel() for graph in graphs)
        if batch_edges == 0:
            continue          # nothing to learn from; a mean over zero edges would be NaN

        optimizer.zero_grad()

        for X, Aei, Aef, Lei, Lef, y in graphs:
            n_edges = y.numel()
            if n_edges == 0:
                continue

            logits = model(X, Aei, Aef, Lei, Lef)          # (N_label,)
            loss   = criterion(logits, y.float())

            # For a single-graph batch the weight is exactly 1.0.
            (loss * (n_edges / batch_edges)).backward()
            running_loss += loss.item() * n_edges

        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        optimizer.step()
        running_edges += batch_edges

    if running_edges == 0:
        return 0.0
    return running_loss / running_edges


# ---------- evaluation -------------------------------------------------------
@torch.no_grad()
def eval_edge_model(model, loader, device="cpu", thr=0.5):
    """
    Evaluate `model` on every graph delivered by `loader`.

    An edge is predicted positive when sigmoid(logit) >= thr.  True/false
    positives and false negatives are pooled over all graphs.

    Returns (precision, recall, f1); a 1e-9 term in each denominator keeps
    the ratios defined when a count is zero.
    """
    model.eval()
    TP = FP = FN = 0

    for batch in loader:
        for X, Aei, Aef, Lei, Lef, y in _iter_graphs(batch, device):
            logits = model(X, Aei, Aef, Lei, Lef)
            pred_pos = torch.sigmoid(logits) >= thr

            is_pos = y == 1
            is_neg = y == 0
            TP += (pred_pos & is_pos).sum().item()
            FP += (pred_pos & is_neg).sum().item()
            FN += (~pred_pos & is_pos).sum().item()

    prec = TP / (TP + FP + 1e-9)
    rec  = TP / (TP + FN + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    return prec, rec, f1

In [None]:
def train_model(model,
                train_loader,
                val_loader,
                num_epochs     = 100,
                lr             = 1e-3,
                validate_every = 10,
                patience       = 10,
                device         = "cpu"):
    """
    Train `model` with AdamW and keep the weights of the best validation F1.

    The model is validated every `validate_every` epochs and after the final
    epoch.  Each validation F1 is fed to a ReduceLROnPlateau scheduler
    (mode "max", factor 0.5, `patience` counted in validations).  Training
    stops early once the learning rate has dropped below 1e-5.  When a
    validation F1 above 0 was observed, the corresponding weights are loaded
    back into `model` before returning.

    Returns (hist, best_state): the list of per-epoch training losses and the
    best state dict (None if no validation F1 exceeded 0).
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                               patience=patience, factor=0.5)

    loss_history = []
    top_f1 = 0.0
    top_state = None

    for epoch in range(1, num_epochs + 1):
        epoch_loss = train_epoch(model, train_loader, optimizer, device=device)
        loss_history.append(epoch_loss)

        # Skip validation unless this is a scheduled epoch or the last one.
        if epoch % validate_every != 0 and epoch != num_epochs:
            continue

        precision, recall, f1 = eval_edge_model(model, val_loader, device=device)
        scheduler.step(f1)

        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Epoch {epoch:03d}/{num_epochs} "
              f"loss={epoch_loss:.4f}  P={precision:.3f} R={recall:.3f} F1={f1:.3f}  lr={current_lr:.2e}")

        if f1 > top_f1:
            # Deep copy: state_dict() returns references to the live parameters.
            top_f1 = f1
            top_state = copy.deepcopy(model.state_dict())

        if current_lr < 1e-5:
            print("Stop: LR < 1e-5")
            break

    if top_state is not None:
        model.load_state_dict(top_state)

    return loss_history, top_state


In [None]:
model = GraphAttentionNetwork(in_dim = 96, edge_in_dim = 197, edge_emb_dim = 8, hidden1 = 64, hidden2 = 32, heads = 1)

# NOTE(review): `train_loader` and `val_loader` are not created in any cell
# shown above (only `dataloader` is) — confirm where the train/validation
# split is built before running this cell on a fresh kernel.
#
# Keep the results: as a bare last expression the returned (history, state dict)
# tuple would only be echoed as a very large cell output and then lost.
loss_history, best_state = train_model(model,
                                       train_loader,
                                       val_loader,
                                       num_epochs     = 100,
                                       lr             = 1e-3,
                                       validate_every = 10,
                                       patience       = 10,
                                       device         = "cpu")