In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
import random
from scipy.stats import norm
import matplotlib.pyplot as plt
import json
import torch.optim as optim
import copy
from torch.nn.functional import binary_cross_entropy
from sklearn.metrics import precision_recall_fscore_support
from torch.optim import lr_scheduler
from scipy import sparse
import networkx as nx
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

SEED = 16

def set_seed(seed=42):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU,
    plus all CUDA devices when CUDA is available), then switches cuDNN to
    deterministic mode with autotuning (benchmark) disabled.

    seed: integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

set_seed(SEED)

In [18]:
class LazyGraphDataset(Dataset):
    """
    Map-style dataset in which every graph lives in its own folder on disk.

    Only the folder paths are kept in memory; ``__getitem__`` reads the files
    of a single graph just-in-time.
    """

    def __init__(self, folderpaths):
        """
        folderpaths: iterable of per-graph folders (``str`` or ``Path``).
            Each folder must contain the files read by ``__getitem__``:
            X.npz, E.npz, edge_index.npy, labels.npz, label_index.npy and
            label_value.npy.
        """
        # Coerce to Path so plain strings also work with the "/" joins in
        # __getitem__ (str / str raises TypeError).
        self.folderpaths = [Path(p) for p in folderpaths]

    def _import_tensor(self, filename: str, dtype: torch.dtype, is_sparse: bool = False):
        """
        Load a SciPy sparse matrix saved as .npz and return either
        • a torch.sparse_csr_tensor              (if is_sparse=True)
        • a torch.Tensor (dense)                 (otherwise)

        filename:  path of the .npz file written by ``scipy.sparse.save_npz``.
        dtype:     dtype of the returned tensor's values.
        is_sparse: keep the CSR layout instead of densifying.
        """
        csr = sparse.load_npz(filename).tocsr()

        if is_sparse:
            # Index arrays are widened to int64 before being handed to torch.
            crow = torch.from_numpy(csr.indptr.astype(np.int64))
            col  = torch.from_numpy(csr.indices.astype(np.int64))
            val  = torch.from_numpy(csr.data).to(dtype)
            return torch.sparse_csr_tensor(crow, col, val, size=csr.shape, dtype=dtype, requires_grad=False)
        # — otherwise densify —
        return torch.from_numpy(csr.toarray()).to(dtype)

    def __getitem__(self, idx):
        """
        Load graph ``idx`` from disk.

        Returns the tuple ``(X, Aei, Aef, Lei, Lef, y)``:
            X    dense torch.long tensor            (X.npz)
            Aei  NumPy array as stored on disk      (edge_index.npy)
            Aef  sparse-CSR torch.float32 tensor    (E.npz)
            Lei  NumPy array as stored on disk      (label_index.npy)
            Lef  sparse-CSR torch.float32 tensor    (labels.npz)
            y    NumPy array as stored on disk      (label_value.npy)

        The three ``.npy`` arrays are returned untouched (no tensor conversion).
        """
        folder_path = self.folderpaths[idx]

        X = self._import_tensor(folder_path / "X.npz", torch.long, is_sparse=False)
        Aef = self._import_tensor(folder_path / "E.npz", torch.float32, is_sparse=True)
        Aei = np.load(folder_path / "edge_index.npy")
        Lef = self._import_tensor(folder_path / "labels.npz", torch.float32, is_sparse=True)
        Lei = np.load(folder_path / "label_index.npy")
        y = np.load(folder_path / "label_value.npy")

        return X, Aei, Aef, Lei, Lef, y

    def __len__(self):
        """Number of graph folders in the dataset."""
        return len(self.folderpaths)

def graph_collate(batch):
    """
    Collate per-graph samples without stacking them.

    batch: list of 6-field samples ``(X, Aei, Aef, Lei, Lef, y)``.

    Returns a tuple of six lists, one per field, each of length
    ``len(batch)``. Graphs are kept as separate list entries rather than
    being merged into a single tensor.
    """
    # Transpose "list of samples" into "one sequence per field".
    node_feats, adj_index, adj_feats, label_index, label_feats, targets = zip(*batch)

    columns = (node_feats, adj_index, adj_feats, label_index, label_feats, targets)
    return tuple(list(column) for column in columns)

def make_loader(dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0):
    """
    Wrap ``dataset`` in a DataLoader that collates with ``graph_collate``.

    dataset:     dataset whose items are 6-field graph samples.
    batch_size:  number of graphs per batch.
    shuffle:     reshuffle the graphs every epoch.
    pin_memory:  forwarded to DataLoader (default True, as before); pass
                 False on CPU-only machines to skip page-locking.
    num_workers: forwarded to DataLoader (default 0 = load in the main
                 process, as before).
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=graph_collate,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )

In [32]:
src = Path("../../data/swde_HTMLgraphs/movie/movie")
# One graph per four-digit folder. rglob yields entries in filesystem order,
# so sort them to keep the graph order (and hence every batch) reproducible,
# and keep directories only because the dataset joins file names onto each entry.
batchFiles = sorted(p for p in src.rglob("[0-9][0-9][0-9][0-9]") if p.is_dir())
dataset = LazyGraphDataset(batchFiles)
dataloader = make_loader(dataset, batch_size=2, shuffle=False)

In [None]:
# Sanity check: print the shape of every field of the first graph in the
# first batch, then stop iterating.
for Xs, Aeis, Aefs, Leis, Lefs, ys in dataloader:
    first_graph = (Xs[0], Aeis[0], Aefs[0], Leis[0], Lefs[0], ys[0])
    print(*(field.shape for field in first_graph))
    break

torch.Size([621, 96]) [[  0   1]
 [  1   0]
 [  1   2]
 ...
 [618 585]
 [619 586]
 [620 587]] tensor(crow_indices=tensor([    0,     9,    18,  ..., 18552, 18559, 18566]),
       col_indices=tensor([ 37,  95, 105,  ..., 192, 195, 196]),
       values=tensor([ 1.0000,  1.0000,  1.0000,  ...,  1.0000, 12.0312,
                       1.0000]), size=(2758, 197), nnz=18566,
       layout=torch.sparse_csr) [[546 509]
 [509 546]
 [516 548]
 [548 516]
 [516 549]
 [549 516]
 [521 470]
 [470 521]
 [538 616]
 [616 538]
 [522 403]
 [403 522]
 [518 588]
 [588 518]
 [523 473]
 [473 523]
 [528 555]
 [555 528]
 [528 556]
 [556 528]
 [528 557]
 [557 528]
 [528 558]
 [558 528]
 [524 589]
 [589 524]
 [530 593]
 [593 530]
 [530 594]
 [594 530]
 [530 595]
 [595 530]
 [530 596]
 [596 530]
 [530 597]
 [597 530]
 [530 598]
 [598 530]
 [530 599]
 [599 530]
 [530 600]
 [600 530]
 [530 601]
 [601 530]
 [530 602]
 [602 530]
 [530 603]
 [603 530]
 [519 398]
 [398 519]
 [542 584]
 [584 542]
 [536 615]
 [615 536]
 [

In [37]:
# Helper function to normalise the A matrix
def symmetric_normalize(A_tilde):
    """
    Performs symmetric normalization of A_tilde (Adj. matrix with self loops):
      A_norm = D^{-1/2} * A_tilde * D^{-1/2}
    Where D_{ii} = (sum of row i in A_tilde) + eps, with eps = 1e-5 guarding
    against division by zero for all-zero rows.

    A_tilde (N, N): dense adj. matrix with self loops (float or integer dtype)
    Returns:
      A_norm : (N, N) torch.float32
    """

    eps = 1e-5
    d = A_tilde.sum(dim=1) + eps
    d_inv_sqrt = torch.pow(d, -0.5)
    # D^{-1/2} is diagonal, so the two matrix products reduce to scaling row i
    # by d_i^{-1/2} and column j by d_j^{-1/2}. Broadcasting avoids building
    # the dense (N, N) diagonal matrix and also accepts integer adjacency,
    # which a matmul against a float matrix would reject.
    A_norm = d_inv_sqrt.unsqueeze(1) * A_tilde * d_inv_sqrt.unsqueeze(0)
    return A_norm.to(torch.float32)

In [None]:
# To advance the model, use the methods in https://arxiv.org/pdf/2311.02921

class GraphAttentionNetwork(nn.Module):
    """
    HTML-graph edge classifier.

        X  ─╮
            │  GATv2( in_dim → hidden1 )        edge-aware, uses φ(e_ij)
            │  ReLU
            │  GATv2( hidden1·heads → hidden2 ) edge-aware, uses φ(e_ij)
            │  ReLU
            └─ Edge-feature constructor
                      [h_i ‖ h_j ‖ φ(e_ij)] ─► MLP(2·hidden2 + edge_emb_dim → 1)

    φ(e) = e @ W_edge is one learned projection shared by the graph edges
    and the candidate edges. With the default sizes the MLP input width is
    2·32 + 8 = 72.

    Parameters
    ----------
    in_dim          : node-feature size   (= 96)
    edge_in_dim     : raw edge-feature size (= 197)
    edge_emb_dim    : width of the projected edge features φ(e)
    hidden1         : per-head output width of the first GAT layer
    hidden2         : output width of the second GAT layer
    heads           : number of attention heads in the first GAT layer
    """
    def __init__(self,
                 in_dim: int        = 96,
                 edge_in_dim: int   = 197,
                 edge_emb_dim: int  = 8,
                 hidden1: int       = 64,
                 hidden2: int       = 32,
                 heads:  int        = 1):
        super().__init__()

        # ── Node-level encoder (edge-aware) ────────────────────────────
        # fill_value sets the edge features given to the self-loops that
        # GATv2Conv adds. A string is interpreted as a reduce operation
        # ("add", "mean", "min", "max", "mul"), so "zeros" is rejected when
        # the layer runs; the float 0.0 gives the intended all-zero features.
        self.gat1 = GATv2Conv(in_dim,
                              hidden1,
                              heads=heads,
                              concat=True,
                              edge_dim=edge_emb_dim,
                              fill_value=0.0)

        self.gat2 = GATv2Conv(hidden1 * heads,
                              hidden2,
                              heads=1,
                              concat=True,
                              edge_dim=edge_emb_dim,
                              fill_value=0.0)

        # ── Edge feature projector ──────────────
        # Kept as a bare weight matrix rather than nn.Linear because it is
        # applied to sparse edge-feature matrices via torch.sparse.mm.
        self.W_edge = nn.Parameter(torch.empty(edge_in_dim, edge_emb_dim))
        nn.init.xavier_uniform_(self.W_edge)

        # ── Edge-level MLP decoder ────────────────────────────────────
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden2 * 2 + edge_emb_dim, hidden2),
            nn.ReLU(),
            nn.Linear(hidden2, 1)
        )

    # ---------------------------------------------------------------------

    def forward(
        self,
        x_dense: torch.Tensor,        # (N_nodes, in_dim)        dense node features
        A_edge_index: torch.Tensor,   # (2, nnz_A)               COO  (from A)
        A_edge_attr: torch.Tensor,    # (nnz_A, edge_in_dim)     sparse features
        E_edge_index: torch.Tensor,   # (2, N_E)                 candidates
        E_edge_attr: torch.Tensor     # (N_E, edge_in_dim)       sparse features
    ):
        """
        Score every candidate edge.

        Returns a 1-D tensor of N_E raw logits (no sigmoid applied), one per
        column of ``E_edge_index``.
        """
        # 1) node features: the GAT layers need floating-point input, while
        #    the dataset in this notebook loads X as torch.long, so cast to
        #    the parameter dtype (a no-op when x is already that dtype).
        x = x_dense.to(self.W_edge.dtype)
        A_edge_emb = torch.sparse.mm(A_edge_attr, self.W_edge)     # (nnz_A , edge_emb_dim)

        # 2) edge-aware GATv2 layers
        h = F.relu(self.gat1(x, A_edge_index, A_edge_emb))
        h = F.relu(self.gat2(h, A_edge_index, A_edge_emb))   # (N_nodes , hidden2)

        # 3) candidate-edge projection  φ(E) = E @ W_edge
        E_edge_emb = torch.sparse.mm(E_edge_attr, self.W_edge)     # (N_E , edge_emb_dim)

        # 4) gather node embeddings and classify
        src, dst = E_edge_index
        z = torch.cat([h[src], h[dst], E_edge_emb], dim=1)      # (N_E , 2*hidden2 + edge_emb_dim)
        return self.edge_mlp(z).squeeze(-1)                   # (N_E ,)

In [39]:
import copy
import torch
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
from torch.nn.functional import binary_cross_entropy_with_logits as BCEwLogits

IGNORE_LABEL = -1      # change if you use another sentinel for “no label”

# ---------- 1. One training epoch -------------------------------------------
def train_epoch(
    model,
    dataloader,               # see docstring for the two accepted batch layouts
    optimizer,
    criterion=BCEwLogits,
    device="cpu"
):
    """
    Run one optimisation pass over ``dataloader``.

    Two batch layouts are accepted:

    * legacy: a single sample ``(*model_inputs, y)``, e.g.
      ``(X, A, E, edge_index, y)``; the model is called as
      ``model(*model_inputs)`` with the tensors as given.
    * ``graph_collate``: six parallel per-graph lists
      ``(Xs, Aeis, Aefs, Leis, Lefs, ys)``; every graph is fed to the model
      as ``model(X, A_edge_index, A_edge_attr, E_edge_index, E_edge_attr)``.

    One optimiser step (with gradient-norm clipping at 1.0) is taken per
    graph. Entries of ``y`` equal to ``IGNORE_LABEL`` are excluded from the
    loss; graphs without any labelled edge are skipped.

    Returns the loss averaged over all labelled edges (0.0 if there were none).
    """

    def _samples(batch):
        # Yield (model_inputs, y) for every graph contained in `batch`.
        if len(batch) == 6:
            for X, Aei, Aef, Lei, Lef, y in zip(*batch):
                # Edge lists are transposed to the (2, E) layout the model
                # indexes with. NOTE(review): assumes the .npy edge lists are
                # stored as (E, 2) row pairs, as the loader cell's printed
                # output suggests — confirm against the preprocessing code.
                yield (
                    torch.as_tensor(X).float(),     # dataset stores X as long
                    torch.as_tensor(Aei, dtype=torch.long).t().contiguous(),
                    Aef,
                    torch.as_tensor(Lei, dtype=torch.long).t().contiguous(),
                    Lef,
                ), torch.as_tensor(y)
        else:
            *inputs, y = batch
            yield tuple(inputs), y

    model.train()
    total_loss, total_edges = 0.0, 0

    for batch in dataloader:
        for inputs, y in _samples(batch):
            # Move to device --------------------------------------------------
            inputs = [t.to(device) for t in inputs]
            y      = y.to(device)

            mask = (y != IGNORE_LABEL)          # only supervise labelled edges
            n_labelled = int(mask.sum().item())
            if n_labelled == 0:                 # nothing to learn in this sample
                continue

            optimizer.zero_grad()

            logits = model(*inputs)             # (N_edges,)

            loss = criterion(logits[mask], y[mask].float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # criterion returns a mean, so weight it by the edge count
            total_loss  += loss.item() * n_labelled
            total_edges += n_labelled

    return total_loss / max(total_edges, 1)   # average over labelled edges


# ---------- 2. Full training loop -------------------------------------------
def train_model(
    model,
    train_loader,                 # edge‑level dataloader
    val_loader,                   # edge‑level dataloader
    num_epochs       = 100,
    lr               = 1e-3,
    validate_every   = 10,
    patience         = 10,
    device           = "cpu"
):
    """
    Train `model` to predict whether an edge exists.

    Optimises with AdamW and halves the learning rate through
    ReduceLROnPlateau when the validation F1 stops improving. The scheduler
    is only stepped on validation rounds, so ``patience`` counts validations
    (every ``validate_every`` epochs), not epochs.

    Training stops early once the learning rate drops below 1e-5. On exit
    the model is restored to the checkpoint with the best validation F1, if
    any validation scored above 0.

    Returns ``(loss_history, best_state)``: the per-epoch training loss and
    the best ``state_dict`` (``None`` when no validation F1 exceeded 0).
    """
    model.to(device)

    optimizer  = optim.AdamW(model.parameters(), lr=lr)
    # `verbose` is intentionally not passed: it is deprecated in recent
    # PyTorch releases and silent operation is already the default.
    scheduler  = lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", patience=patience, factor=0.5
    )

    best_val_f1, best_state = 0.0, None
    loss_history = []

    for epoch in range(1, num_epochs + 1):
        loss = train_epoch(model, train_loader, optimizer, device=device)
        loss_history.append(loss)

        # ---- validation -----------------------------------------------------
        # Always validate on the final epoch so the last model is scored too.
        if epoch % validate_every == 0 or epoch == num_epochs:
            val_prec, val_rec, val_f1 = evaluate_edge_model(
                model, val_loader, device=device
            )
            scheduler.step(val_f1)

            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"Epoch {epoch:03d}/{num_epochs}  "
                f"loss={loss:.4f}  "
                f"P={val_prec:.3f}  R={val_rec:.3f}  F1={val_f1:.3f}  "
                f"lr={current_lr:.2e}"
            )

            # keep the best‑F1 checkpoint (deep copy: state_dict tensors are
            # live references that later optimiser steps would overwrite)
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_state  = copy.deepcopy(model.state_dict())

            # early stop if LR too small
            if current_lr < 1e-5:
                print("LR below 1e‑5 → stopping.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return loss_history, best_state


# ---------- 3. Helper: evaluation (precision / recall / F1) -----------------
@torch.no_grad()
def evaluate_edge_model(model, dataloader, device="cpu", thr=0.5):
    """
    Compute edge-level precision, recall and F1 over ``dataloader``.

    Two batch layouts are accepted:

    * legacy: a single sample ``(*model_inputs, y)``, e.g.
      ``(X, A, E, edge_index, y)``; the model is called as
      ``model(*model_inputs)`` with the tensors as given.
    * ``graph_collate``: six parallel per-graph lists
      ``(Xs, Aeis, Aefs, Leis, Lefs, ys)``; every graph is fed to the model
      as ``model(X, A_edge_index, A_edge_attr, E_edge_index, E_edge_attr)``.

    An edge is predicted positive when ``sigmoid(logit) >= thr``. Entries of
    ``y`` equal to ``IGNORE_LABEL`` are excluded from the counts. Counts are
    pooled over all graphs before the metrics are formed (micro-average).

    Returns ``(precision, recall, f1)`` as Python floats; each is ~0 when its
    denominator is empty.
    """

    def _samples(batch):
        # Yield (model_inputs, y) for every graph contained in `batch`.
        if len(batch) == 6:
            for X, Aei, Aef, Lei, Lef, y in zip(*batch):
                # Edge lists are transposed to the (2, E) layout the model
                # indexes with. NOTE(review): assumes the .npy edge lists are
                # stored as (E, 2) row pairs, as the loader cell's printed
                # output suggests — confirm against the preprocessing code.
                yield (
                    torch.as_tensor(X).float(),     # dataset stores X as long
                    torch.as_tensor(Aei, dtype=torch.long).t().contiguous(),
                    Aef,
                    torch.as_tensor(Lei, dtype=torch.long).t().contiguous(),
                    Lef,
                ), torch.as_tensor(y)
        else:
            *inputs, y = batch
            yield tuple(inputs), y

    model.eval()
    tp = fp = fn = 0

    for batch in dataloader:
        for inputs, y in _samples(batch):
            inputs = [t.to(device) for t in inputs]
            y = y.to(device)

            mask = (y != IGNORE_LABEL)
            if mask.sum() == 0:
                continue

            logits = model(*inputs)
            probs  = torch.sigmoid(logits)

            pred = (probs >= thr).long()
            tp  += ((pred == 1) & (y == 1) & mask).sum().item()
            fp  += ((pred == 1) & (y == 0) & mask).sum().item()
            fn  += ((pred == 0) & (y == 1) & mask).sum().item()

    # The 1e-9 terms keep the ratios finite when a denominator is zero.
    precision = tp / (tp + fp + 1e-9)
    recall    = tp / (tp + fn + 1e-9)
    f1        = 2 * precision * recall / (precision + recall + 1e-9)
    return precision, recall, f1
