In [25]:
import io, tarfile, math
from pathlib import Path
from typing import Any, Dict, Iterator

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader, Subset
from scipy import sparse
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv
from sklearn.metrics import average_precision_score

datafile = "/vol/bitbucket/mjh24/IAEA-thesis/data/swde_HTMLgraphsNEWFEATURES.tar"

In [26]:
class TarGraphDataset(Dataset):
    """
    Each graph is stored under its own sub-directory *inside* one .tar:

        graphs.tar
        ├── 0001/X.npz
        ├── 0001/E.npz
        ├── 0001/edge_index.npy
        ├── 0001/labels.npz
        ├── 0001/label_index.npy
        ├── 0001/label_value.npy
        ├── 0002/…
        └── …

    The tar is opened once; __getitem__ streams the six files for graph *idx*
    straight into memory, converts them to native PyTorch tensors and returns.
    """

    def __init__(self, tar_path: str | Path):
        self.tar = tarfile.open(tar_path, mode="r:*")      # gzip/none/…
        self.index: dict[str, dict[str, tarfile.TarInfo]] = {}
        self.sublen = {}

        # Build a small lookup table in RAM  {gid: {filename: tarinfo}}
        for member in self.tar.getmembers():
            if not member.isfile():
                continue

            p     = Path(member.name)
            gid   = str(p.parent)   # '0007'
            fname = p.name          # 'X.npz'

            # keep only folders that really are 4-digit graph IDs
            if gid[-4:].isdigit():
                self.index.setdefault(gid, {})[fname] = member

        self.gids = sorted(self.index)

        # Remove those with no labels
        for gid, files in self.index.items():
            if not files.get("labels.npz"):
                self.gids.remove(gid)

        # Count
        name, counts = np.unique([Path(gid).parent.name for gid in self.gids], return_counts=True)

        # Get cumsum
        running = 0
        for lbl, cnt in zip(name, counts):
            self.sublen[lbl] = (running, running + cnt)
            running += cnt

    # ------------- helpers --------------------------------------------------
    @staticmethod
    def _npz_to_csr(buf: bytes, dtype=torch.float32):
        csr = sparse.load_npz(io.BytesIO(buf)).tocsr()
        crow = torch.from_numpy(csr.indptr.astype(np.int64))
        col  = torch.from_numpy(csr.indices.astype(np.int64))
        val  = torch.from_numpy(csr.data).to(dtype)
        return torch.sparse_csr_tensor(
            crow, col, val, size=csr.shape, dtype=dtype, requires_grad=False
        )

    @staticmethod
    def _npy_to_tensor(buf: bytes, dtype):
        arr = np.load(io.BytesIO(buf), allow_pickle=False)
        return torch.from_numpy(arr).to(dtype)

    def get_sublen(self, name):
        return self.sublen[name]

    # ------------- Dataset API ---------------------------------------------
    def __len__(self):
        return len(self.gids)

    def __getitem__(self, idx):
        gid   = self.gids[idx]
        files = self.index[gid]

        get = lambda name: self.tar.extractfile(files[name]).read()
        
        fileinfo = gid

        X   = self._npz_to_csr(get("X.npz"),       dtype=torch.float32)
        Aef = self._npz_to_csr(get("E.npz"),       dtype=torch.float32)
        Lef = self._npz_to_csr(get("labels.npz"),  dtype=torch.float32)

        Aei = self._npy_to_tensor(get("edge_index.npy"),  dtype=torch.int64)
        Lei = self._npy_to_tensor(get("label_index.npy"), dtype=torch.int64)
        y   = self._npy_to_tensor(get("label_value.npy"), dtype=torch.int64)

        return fileinfo, X, Aei.t().contiguous(), Aef, Lei.t().contiguous(), Lef, y


def concat_csr(blocks):
    """
    Vertically stack CSR matrices that all share the same n_cols.
    Keeps sparsity and returns a single torch.sparse_csr_tensor.
    """
    crow_bufs, col_bufs, val_bufs = [], [], []
    nnz_so_far, n_rows, n_cols = 0, 0, blocks[0].size(1)

    for k, csr in enumerate(blocks):
        crow = csr.crow_indices().clone()          # (n_rows_k + 1,)

        # 1) shift by *cumulative* nnz so far
        crow += nnz_so_far

        # 2) drop the leading 0 for every block after the first
        if k > 0:
            crow = crow[1:]

        crow_bufs.append(crow)
        col_bufs.append(csr.col_indices())
        val_bufs.append(csr.values())

        nnz_so_far += csr.values().numel()
        n_rows     += csr.size(0)

    crow_cat = torch.cat(crow_bufs)
    col_cat  = torch.cat(col_bufs)
    val_cat  = torch.cat(val_bufs)

    return torch.sparse_csr_tensor(
        crow_cat, col_cat, val_cat,
        size=(n_rows, n_cols),
        dtype=val_cat.dtype,
        device=val_cat.device,
        requires_grad=False
    )


def sparse_graph_collate(batch):
    # unpack each graph
    filenames, xs, aei, aef, lei, lef, ys = zip(*batch)

    # node-count prefix sum for shifting
    node_offsets = torch.cumsum(
        torch.tensor([0] + [x.size(0) for x in xs[:-1]]), 0)

    # ----- merge node features (CSR) -----------------------------
    X_batch = concat_csr(xs)

    # ----- merge structural edges --------------------------------
    Aei_shifted = []
    for off, ei in zip(node_offsets, aei):
        Aei_shifted.append(ei + off)   # shift both rows
    Aei_batch = torch.cat(Aei_shifted, dim=1)   # (2 , E_tot)

    Aef_batch = concat_csr(aef)

    # ----- merge label edges -------------------------------------
    Lei_shifted = []
    for off, ei in zip(node_offsets, lei):
        Lei_shifted.append(ei + off)
    Lei_batch = torch.cat(Lei_shifted, dim=1)

    Lef_batch = concat_csr(lef)
    y_batch   = torch.cat(ys)

    return filenames, X_batch, Aei_batch, Aef_batch, Lei_batch, Lef_batch, y_batch


def debug_collate(batch):
    _, xs, aei, aef, lei, lef, ys = zip(*batch)
    print("--- one mini-batch ---")
    for i, X in enumerate(xs):
        print(f"graph {i}:  nodes={X.size(0):4d}   "
              f"struct-edges={aei[i].shape[1]:4d}   "
              f"label-edges={lei[i].shape[1]:3d}")
    # then call the real collate to keep training code unchanged
    return sparse_graph_collate(batch)

# ───────────────────────────────────────────────────────── loader utilities
def identity_collate(batch):
    """batch == list of length 1 → return that single sample untouched."""
    return batch[0]

def make_loader(ds, batch_size=1, shuffle=False):
    return DataLoader(ds,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      collate_fn=sparse_graph_collate,
                      num_workers=0,
                      pin_memory=True)    # fast GPU transfer

In [27]:
def build_e2n_minibatch_from_collate(
    X_dense: torch.Tensor,           # (N, Dx) float32
    Aei: torch.Tensor,               # (2, EA) int64
    Aef_dense: torch.Tensor,         # (EA, De) float32
    Lei_chunk: torch.Tensor,         # (2, B) int64
    Lef_chunk: torch.Tensor,         # (B, De) float32
    y_chunk: torch.Tensor            # (B,) int64/float
) -> Data:
    """
    Returns a homogeneous PyG Data where:
      - DOM nodes are rows [0..N-1]
      - Each selected candidate edge becomes a dummy node appended after DOM nodes
      - Structural edges are preserved
      - We add edges t<->u and t<->v for each dummy node t
      - Loss is computed only on dummy nodes (we return y_dummy)
    """
    N, Dx = X_dense.size(0), X_dense.size(1)
    De    = Lef_chunk.size(1)
    B     = Lef_chunk.size(0)
    if B == 0:
        return None

    # Placeholder x to satisfy PyG; model projects per-type
    x_all = torch.zeros((N + B, max(Dx, De)), dtype=X_dense.dtype, device=X_dense.device)
    x_all[:N, :Dx] = X_dense
    x_all[N:, :De] = Lef_chunk

    # Structural edges (+ attrs)
    A_ei = Aei.to(torch.long)
    A_ea = Aef_dense

    # u/v endpoints for the selected candidate edges
    u = Lei_chunk[0].to(device=device, dtype=torch.long)
    v = Lei_chunk[1].to(device=device, dtype=torch.long)

    # Dummy node indices
    t = torch.arange(N, N + B, dtype=torch.long, device=X_dense.device)

    # Connect dummy to endpoints (both directions)
    e1 = torch.stack([t, u], dim=0)
    e2 = torch.stack([u, t], dim=0)
    e3 = torch.stack([t, v], dim=0)
    e4 = torch.stack([v, t], dim=0)
    E_ei = torch.cat([e1, e2, e3, e4], dim=1)    # (2, 4B)

    # Edge attrs for these links: repeat the candidate feature 4x
    E_ea = Lef_chunk.repeat_interleave(4, dim=0)   # (4B, De)

    # Concatenate structural + dummy edges
    edge_index_all = torch.cat([A_ei, E_ei], dim=1)                 # (2, EA + 4B)
    edge_attr_all  = torch.cat([A_ea, E_ea], dim=0)                 # (EA + 4B, De)

    # Mask & labels for dummy nodes
    mask_dummy = torch.zeros(N + B, dtype=torch.bool, device=X_dense.device)
    mask_dummy[N:] = True
    y_dummy = y_chunk.to(device=X_dense.device, dtype=torch.float32)

    return Data(
        x=x_all,
        edge_index=edge_index_all,
        edge_attr=edge_attr_all,
        mask_dummy=mask_dummy,
        y_dummy=y_dummy,
        N_dom=N, B_dummy=B, Dx=Dx, De=De
    )

In [28]:
class Edge2NodeTransformer(nn.Module):
    def __init__(self,
                 node_in_dim: int,
                 edge_in_dim: int,
                 edge_emb_dim: int = 64,    # project De → edge_emb_dim
                 hidden1: int = 128,
                 hidden2: int = 64,
                 heads: int = 4,
                 dropout: float = 0.2):
        super().__init__()
        self.node_in  = nn.Linear(node_in_dim,  hidden1 * heads)
        self.dummy_in = nn.Linear(edge_in_dim,  hidden1 * heads)
        self.edge_proj = nn.Linear(edge_in_dim, edge_emb_dim, bias=False)

        self.tr1 = TransformerConv(
            in_channels=hidden1 * heads,
            out_channels=hidden1,
            heads=heads,
            edge_dim=edge_emb_dim,
            dropout=dropout,
            beta=True
        )
        self.ln1 = nn.LayerNorm(hidden1 * heads)

        self.tr2 = TransformerConv(
            in_channels=hidden1 * heads,
            out_channels=hidden2,
            heads=1,
            edge_dim=edge_emb_dim,
            dropout=dropout,
            beta=True
        )
        self.ln2 = nn.LayerNorm(hidden2)

        self.cls = nn.Linear(hidden2, 1)

        # init β-gate ~0.5 so messages aren't suppressed at start
        for tr in (self.tr1, self.tr2):
            if getattr(tr, "lin_beta", None) is not None:
                nn.init.zeros_(tr.lin_beta.weight)
                if tr.lin_beta.bias is not None:
                    nn.init.zeros_(tr.lin_beta.bias)

    def forward(self, data: Data):
        x_all = data.x                     # (N+B, max(Dx,De))
        N, B, Dx, De = data.N_dom, data.B_dummy, data.Dx, data.De

        # Type-specific input adapters
        h = torch.zeros((N + B, self.node_in.out_features),
                        dtype=x_all.dtype, device=x_all.device)
        h[:N] = self.node_in(x_all[:N, :Dx])   # DOM
        h[N:] = self.dummy_in(x_all[N:, :De])  # dummy (edge-as-node)

        # Project edge attrs once
        eattr = self.edge_proj(data.edge_attr)

        # Two TransformerConv layers
        h = F.relu(self.tr1(h, data.edge_index, eattr)); h = self.ln1(h)
        h = F.relu(self.tr2(h, data.edge_index, eattr)); h = self.ln2(h)

        logits_all = self.cls(h).squeeze(-1)
        return logits_all[N:]   # only dummy nodes (candidate-edge nodes)

In [29]:
def build_e2n_minibatch_from_collate(
    X_dense: torch.Tensor,           # (N, Dx) float32 on device
    Aei: torch.Tensor,               # (2, EA) int64 on device
    Aef_dense: torch.Tensor,         # (EA, De) float32 on device
    Lei_chunk: torch.Tensor,         # (2, B) int64 (CPU or device)
    Lef_chunk: torch.Tensor,         # (B, De) float32 on device
    y_chunk: torch.Tensor            # (B,) int/float (CPU or device)
) -> Data:
    N, Dx = X_dense.size(0), X_dense.size(1)
    De    = Lef_chunk.size(1)
    B     = Lef_chunk.size(0)
    if B == 0:
        return None

    device = X_dense.device

    # Placeholder x; model will project per type
    x_all = torch.zeros((N + B, max(Dx, De)), dtype=X_dense.dtype, device=device)
    x_all[:N, :Dx] = X_dense
    x_all[N:, :De] = Lef_chunk

    # u/v endpoints → move to device
    u = Lei_chunk[0].to(device=device, dtype=torch.long)
    v = Lei_chunk[1].to(device=device, dtype=torch.long)

    # Dummy node indices
    t = torch.arange(N, N + B, dtype=torch.long, device=device)

    # Connect dummy ↔ endpoints (both directions)
    e1 = torch.stack([t, u], dim=0)
    e2 = torch.stack([u, t], dim=0)
    e3 = torch.stack([t, v], dim=0)
    e4 = torch.stack([v, t], dim=0)
    E_ei = torch.cat([e1, e2, e3, e4], dim=1)          # (2, 4B)

    # Edge attrs for these links: repeat candidate feature 4×
    E_ea = Lef_chunk.repeat_interleave(4, dim=0)       # (4B, De)

    # Concatenate structural + dummy edges
    edge_index_all = torch.cat([Aei, E_ei], dim=1)     # (2, EA+4B)
    edge_attr_all  = torch.cat([Aef_dense, E_ea], dim=0)

    # Dummy mask & labels
    mask_dummy = torch.zeros(N + B, dtype=torch.bool, device=device)
    mask_dummy[N:] = True
    y_dummy = y_chunk.to(device=device, dtype=torch.float32)

    return Data(
        x=x_all,
        edge_index=edge_index_all,
        edge_attr=edge_attr_all,
        mask_dummy=mask_dummy,
        y_dummy=y_dummy,
        N_dom=N, B_dummy=B, Dx=Dx, De=De
    )

def iterate_e2n_minibatches_from_loader_batch(batch, device, edge_bs=20000):
    # Your collate output:
    # (filenames, X_csr, Aei, Aef_csr, Lei, Lef_csr, y)
    _, X_csr, Aei, Aef_csr, Lei, Lef_csr, y = batch

    # Densify once per collated batch
    X_dense   = X_csr.to_dense().to(device, non_blocking=True)     # (N, Dx)
    Aei       = Aei.to(device, non_blocking=True)                  # (2, EA)
    Aef_dense = Aef_csr.to_dense().to(device, non_blocking=True)   # (EA, De)
    Lef_dense = Lef_csr.to_dense()                                 # (M, De) CPU ok
    # Lei, y stay on CPU; we’ll move slices as needed

    M = Lei.size(1)
    for s in range(0, M, edge_bs):
        e = min(s + edge_bs, M)
        Lei_chunk = Lei[:, s:e]                                    # CPU ok
        Lef_chunk = Lef_dense[s:e, :].to(device, non_blocking=True)
        y_chunk   = y[s:e]                                          # CPU ok
        data = build_e2n_minibatch_from_collate(
            X_dense, Aei, Aef_dense, Lei_chunk, Lef_chunk, y_chunk
        )
        if data is None:
            continue
        yield data  # already on device

def train_one_epoch(loader, model, optimizer, device,
                    edge_bs=20000, clip_norm=1.0, amp=True):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    bce = nn.BCEWithLogitsLoss()

    total_loss, total_B = 0.0, 0
    count=0
    for batch in loader:
        count+=1
        bs = edge_bs
        while True:
            try:
                for data in iterate_e2n_minibatches_from_loader_batch(batch, device, edge_bs=bs):
                    optimizer.zero_grad(set_to_none=True)
                    with torch.cuda.amp.autocast(enabled=amp):
                        logits = model(data)
                        loss = bce(logits, data.y_dummy)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    if clip_norm:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    scaler.step(optimizer); scaler.update()
                    total_loss += float(loss.detach()) * int(data.y_dummy.numel())
                    total_B    += int(data.y_dummy.numel())
                    if count % 500 == 0:
                        print(f"Epoch {count}/{len(loader)} loss: {loss}")
                break
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                bs = max(bs // 2, 2048)
                if bs <= 2048:
                    raise
    return total_loss / max(total_B, 1)

@torch.no_grad()
def eval_epoch(loader, model, device, edge_bs=20000):
    model.eval()
    bce = nn.BCEWithLogitsLoss(reduction='sum')

    total_loss, total_B = 0.0, 0
    all_p, all_y = [], []

    for batch in loader:
        for data in iterate_e2n_minibatches_from_loader_batch(batch, device, edge_bs=edge_bs):
            logits = model(data)
            loss = bce(logits, data.y_dummy)
            total_loss += float(loss.cpu())
            total_B    += int(data.y_dummy.numel())
            all_p.append(torch.sigmoid(logits).cpu())
            all_y.append(data.y_dummy.cpu())

    if total_B == 0:
        return {"loss": 0.0, "aupr": 0.0}
    p = torch.cat(all_p).numpy()
    y = torch.cat(all_y).numpy()
    from sklearn.metrics import average_precision_score
    aupr = float(average_precision_score(y, p))
    return {"loss": total_loss / total_B, "aupr": aupr}


In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATAFILE_TRAIN = "/vol/bitbucket/mjh24/IAEA-thesis/data/swde_HTMLgraphsNEWFEATURES.tar"
dataset = TarGraphDataset(DATAFILE_TRAIN)
N = len(dataset)

matchcollege_start, matchcollege_end = dataset.get_sublen('university-matchcollege(2000)')
allmovie_start, allmovie_end = dataset.get_sublen('movie-allmovie(2000)')
imdb_start, imdb_end = dataset.get_sublen('movie-imdb(2000)')
usatoday_start, usatoday_end = dataset.get_sublen('nbaplayer-usatoday(436)')
yahoo_start, yahoo_end = dataset.get_sublen('nbaplayer-yahoo(438)')
matchcollege_idx = list(range(matchcollege_start, matchcollege_end))
allmovie_idx = list(range(allmovie_start, allmovie_end))
imdb_idx = list(range(imdb_start, imdb_end))
usatoday_idx=list(range(usatoday_start, usatoday_end))
yahoo_idx=list(range(yahoo_start, yahoo_end))

val_idx = list(set(allmovie_idx))#list(set(matchcollege_idx[-10:])) + list(set(allmovie_idx[-10:]))#
train_idx = list(set(range(N)) - set(val_idx) - set(usatoday_idx) - set(yahoo_idx))#list(set(matchcollege_idx + allmovie_idx) - set(val_idx))#
train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, val_idx)

train_loader = make_loader(train_ds, batch_size=32, shuffle=True)
val_loader   = make_loader(val_ds,   batch_size=32, shuffle=False)

model = Edge2NodeTransformer(
    node_in_dim=114, edge_in_dim=200,
    edge_emb_dim=64, hidden1=128, hidden2=64, heads=4, dropout=0.2
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, 6):
    tl = train_one_epoch(train_loader, model, optimizer, device,
                            edge_bs=20000, clip_norm=1.0, amp=True)
    metrics = eval_epoch(val_loader, model, device, edge_bs=20000)
    print(f"epoch {epoch:02d} | train {tl:.4f} | val {metrics['loss']:.4f} | AUPR {metrics['aupr']:.4f}")

  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Epoch 500/774 loss: 0.05423086881637573
epoch 01 | train 0.1460 | val 2.0835 | AUPR 0.6435


  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Epoch 500/774 loss: 0.025969387963414192
epoch 02 | train 0.0344 | val 2.3053 | AUPR 0.6624


  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Epoch 500/774 loss: 0.014316810294985771
epoch 03 | train 0.0238 | val 2.1365 | AUPR 0.7099


  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Epoch 500/774 loss: 0.023809099569916725
epoch 04 | train 0.0214 | val 2.6550 | AUPR 0.6658


  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Epoch 500/774 loss: 0.017748890444636345
epoch 05 | train 0.0191 | val 2.2922 | AUPR 0.7027


In [31]:
torch.save(model.state_dict(), "firstAttempt.pt")