In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from scipy.stats import norm
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.nn.functional import binary_cross_entropy
from sklearn.metrics import precision_recall_fscore_support
from torch.optim import lr_scheduler
from scipy import sparse
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import io, tarfile, os
import copy
from torch.nn.functional import binary_cross_entropy_with_logits as BCEwLogits
from GATModel import GraphAttentionNetwork
if torch.cuda.is_available():
    torch.cuda.current_device()
%env CUDA_LAUNCH_BLOCKING=1

datafile = "../../data/swde_HTMLgraphsNEWFEATURES.tar"

plt.ion()

SEED = 0

def set_seed(seed=42):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU, plus
    all CUDA devices when CUDA is available), then switches cuDNN to
    deterministic mode with autotuning (benchmark) disabled.

    Parameters
    ----------
    seed : int
        Value passed to every seeding call.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

# Seed all RNGs once, in the first cell, before any later cell builds data
# splits or models.
set_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


env: CUDA_LAUNCH_BLOCKING=1


***BELOW***
If data loading takes less than roughly 5–10 % of the total epoch time with `num_workers=0`, stick with the simple single-process path used below.
Otherwise, parallel loading (`num_workers > 0`) with a share-friendly `torch_sparse.SparseTensor`
almost always pays off.

In [2]:
# ───────────────────────────────────────────────────────── Tar-reader dataset
class TarGraphDataset(Dataset):
    """
    Each graph is stored under its own sub-directory *inside* one .tar:

        graphs.tar
        ├── 0001/X.npz
        ├── 0001/E.npz
        ├── 0001/edge_index.npy
        ├── 0001/labels.npz
        ├── 0001/label_index.npy
        ├── 0001/label_value.npy
        ├── 0002/…
        └── …

    The tar is opened once; __getitem__ streams the six files for graph *idx*
    straight into memory, converts them to native PyTorch tensors and returns.
    """

    def __init__(self, tar_path: str | Path):
        self.tar = tarfile.open(tar_path, mode="r:*")      # gzip/none/…
        self.index: dict[str, dict[str, tarfile.TarInfo]] = {}
        self.sublen = {}

        # Build a small lookup table in RAM  {gid: {filename: tarinfo}}
        for member in self.tar.getmembers():
            if not member.isfile():
                continue

            p     = Path(member.name)
            gid   = str(p.parent)   # '0007'
            fname = p.name          # 'X.npz'

            # keep only folders that really are 4-digit graph IDs
            if gid[-4:].isdigit():
                self.index.setdefault(gid, {})[fname] = member

        self.gids = sorted(self.index)

        # Remove thos with no labels
        for gid, files in self.index.items():
            if not files.get("labels.npz"):
                self.gids.remove(gid)

        # Count
        name, counts = np.unique([Path(gid).parent.name for gid in self.gids], return_counts=True)

        # Get cumsum
        running = 0
        for lbl, cnt in zip(name, counts):
            self.sublen[lbl] = (running, running + cnt)
            running += cnt

    # ------------- helpers --------------------------------------------------
    @staticmethod
    def _npz_to_csr(buf: bytes, dtype=torch.float32):
        csr = sparse.load_npz(io.BytesIO(buf)).tocsr()
        crow = torch.from_numpy(csr.indptr.astype(np.int64))
        col  = torch.from_numpy(csr.indices.astype(np.int64))
        val  = torch.from_numpy(csr.data).to(dtype)
        return torch.sparse_csr_tensor(
            crow, col, val, size=csr.shape, dtype=dtype, requires_grad=False
        )

    @staticmethod
    def _npy_to_tensor(buf: bytes, dtype):
        arr = np.load(io.BytesIO(buf), allow_pickle=False)
        return torch.from_numpy(arr).to(dtype)

    def get_sublen(self, name):
        return self.sublen[name]

    # ------------- Dataset API ---------------------------------------------
    def __len__(self):
        return len(self.gids)

    def __getitem__(self, idx):
        gid   = self.gids[idx]
        files = self.index[gid]

        get = lambda name: self.tar.extractfile(files[name]).read()
        
        fileinfo = gid

        X   = self._npz_to_csr(get("X.npz"),       dtype=torch.float32)
        Aef = self._npz_to_csr(get("E.npz"),       dtype=torch.float32)
        Lef = self._npz_to_csr(get("labels.npz"),  dtype=torch.float32)

        Aei = self._npy_to_tensor(get("edge_index.npy"),  dtype=torch.int64)
        Lei = self._npy_to_tensor(get("label_index.npy"), dtype=torch.int64)
        y   = self._npy_to_tensor(get("label_value.npy"), dtype=torch.int64)

        return fileinfo, X, Aei.t().contiguous(), Aef, Lei.t().contiguous(), Lef, y


In [3]:
def concat_csr(blocks):
    """
    Vertically stack CSR matrices that all share the same n_cols.
    Keeps sparsity and returns a single torch.sparse_csr_tensor.

    Parameters
    ----------
    blocks : sequence of torch sparse CSR tensors
        At least one block; every block must have the same number of columns.

    Returns
    -------
    torch.Tensor (sparse CSR)
        Shape ``(sum of rows, n_cols)``, with the dtype and device of the
        concatenated values, ``requires_grad=False``.

    Raises
    ------
    ValueError
        If *blocks* is empty or the column counts differ (previously the
        latter silently produced a tensor whose column indices could exceed
        its declared width).
    """
    if len(blocks) == 0:
        raise ValueError("concat_csr needs at least one CSR block")

    n_cols = blocks[0].size(1)
    crow_bufs, col_bufs, val_bufs = [], [], []
    nnz_so_far, n_rows = 0, 0

    for k, csr in enumerate(blocks):
        if csr.size(1) != n_cols:
            raise ValueError(
                f"concat_csr: block {k} has {csr.size(1)} columns, expected {n_cols}"
            )

        # Row pointers are relative to the block; shift them by the number of
        # non-zeros already emitted (the addition yields a new tensor, so the
        # input block is left untouched).
        crow = csr.crow_indices() + nnz_so_far

        # Every block's pointer array starts with its own leading entry, which
        # duplicates the previous block's last entry — keep it only once.
        crow_bufs.append(crow if k == 0 else crow[1:])
        col_bufs.append(csr.col_indices())
        val_bufs.append(csr.values())

        nnz_so_far += csr.values().numel()
        n_rows     += csr.size(0)

    crow_cat = torch.cat(crow_bufs)
    col_cat  = torch.cat(col_bufs)
    val_cat  = torch.cat(val_bufs)

    return torch.sparse_csr_tensor(
        crow_cat, col_cat, val_cat,
        size=(n_rows, n_cols),
        dtype=val_cat.dtype,
        device=val_cat.device,
        requires_grad=False
    )


def sparse_graph_collate(batch):
    """
    Merge a list of per-graph samples into one disjoint batch graph.

    Each sample is the 7-tuple produced by the dataset:
    ``(filename, X, Aei, Aef, Lei, Lef, y)``.  Node features and both edge
    feature matrices are stacked row-wise with ``concat_csr``; both edge-index
    tensors are shifted by the number of nodes in the preceding graphs and
    concatenated along dim 1; label values are concatenated.

    Returns
    -------
    (filenames, X_batch, Aei_batch, Aef_batch, Lei_batch, Lef_batch, y_batch, batch)
        ``filenames`` is the tuple of per-graph names and ``batch`` is a long
        vector holding, for every node, the position of its graph in the
        mini-batch.
    """
    filenames, xs, aei, aef, lei, lef, ys = zip(*batch)

    # Number of nodes per graph and the exclusive prefix sum used as offsets.
    node_counts = [x.size(0) for x in xs]
    node_offsets = torch.cumsum(torch.tensor([0] + node_counts[:-1]), 0)

    def _shift_and_join(edge_indices):
        # Add each graph's node offset to both rows of its index tensor.
        shifted = [ei + off for off, ei in zip(node_offsets, edge_indices)]
        return torch.cat(shifted, dim=1)

    # ----- node features and structural edges --------------------
    X_batch   = concat_csr(xs)
    Aei_batch = _shift_and_join(aei)            # (2 , E_tot)
    Aef_batch = concat_csr(aef)

    # ----- label edges --------------------------------------------
    Lei_batch = _shift_and_join(lei)
    Lef_batch = concat_csr(lef)
    y_batch   = torch.cat(ys)

    # ----- node -> graph assignment -------------------------------
    graph_of_node = torch.cat(
        [torch.full((n_nodes,), g, dtype=torch.long)
         for g, n_nodes in enumerate(node_counts)],
        dim=0,
    )

    return (filenames, X_batch, Aei_batch, Aef_batch,
            Lei_batch, Lef_batch, y_batch, graph_of_node)

def debug_collate(batch):
    """
    Drop-in replacement for ``sparse_graph_collate`` that first prints, for
    every graph in the mini-batch, its node count and its structural / label
    edge counts, then returns the regular collated batch.
    """
    _, xs, aei, _aef, lei, _lef, _ys = zip(*batch)

    print("--- one mini-batch ---")
    for i, (X, struct_ei, label_ei) in enumerate(zip(xs, aei, lei)):
        print(f"graph {i}:  nodes={X.size(0):4d}   "
              f"struct-edges={struct_ei.shape[1]:4d}   "
              f"label-edges={label_ei.shape[1]:3d}")

    # Delegate to the real collate so training code stays unchanged.
    return sparse_graph_collate(batch)

# ───────────────────────────────────────────────────────── loader utilities
def identity_collate(batch):
    """
    Collate function for ``batch_size=1`` loaders.

    Returns the first (expected to be the only) sample of *batch* as-is,
    without stacking or wrapping it.
    """
    only_sample = batch[0]
    return only_sample

def make_loader(ds, batch_size=1, shuffle=False):
    """
    Wrap *ds* in a single-process ``DataLoader`` that batches graphs with
    ``sparse_graph_collate``.

    Parameters
    ----------
    ds : Dataset
        Dataset yielding the per-graph tuples expected by the collate function.
    batch_size : int
        Number of graphs merged into one mini-batch.
    shuffle : bool
        Whether the loader reshuffles the samples.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=sparse_graph_collate,
        num_workers=0,        # load in the main process
        pin_memory=True,      # fast GPU transfer
    )
    return DataLoader(ds, **loader_options)

In [4]:
# dataset   = TarGraphDataset("../../data/swde_HTMLgraphs.tar")
# loader    = make_loader(dataset, batch_size=8, shuffle=False)

# next(iter(loader))

# count = 0
# for fileinfo, X, Aei, Aef, Lei, Lef, y, batch in loader:
#     print(fileinfo)
#     count +=1
# print(count)

In [5]:
# This is a lazy loader for individual files

# class LazyGraphDataset(Dataset):
#     """
#     Each graph lives in its own .npz / .pt / whatever on disk.
#     __getitem__ loads it just-in-time.
#     """

#     def __init__(self, folderpaths):
#         """
#         meta_csv: a CSV (or list of dicts) with columns:
#             path_X, path_A_index, path_A_feat, path_L_index, path_L_feat, path_y
#         Only these tiny strings stay in RAM.
#         """
#         self.folderpaths = list(folderpaths)

#     def _import_tensor(self, filename: str, dtype: torch.dtype, is_sparse: bool = False):
#         """
#         Load a .npz CSR matrix and return either
#         • a torch.sparse_csr_tensor              (if is_sparse=True)
#         • a torch.Tensor (dense)                 (otherwise)
#         """
#         csr = sparse.load_npz(filename).tocsr()

#         if is_sparse:
#             crow = torch.from_numpy(csr.indptr.astype(np.int64))
#             col  = torch.from_numpy(csr.indices.astype(np.int64))
#             val  = torch.from_numpy(csr.data).to(dtype)
#             return torch.sparse_csr_tensor(crow, col, val,size=csr.shape, dtype=dtype, requires_grad=False)
#         # — otherwise densify —
#         return torch.from_numpy(csr.toarray()).to(dtype)

#     def __getitem__(self, idx):
#         folder_path = self.folderpaths[idx]

#         X = self._import_tensor((folder_path/"X.npz"), torch.float32, is_sparse=False)
#         #A = self._import_tensor(folder_path/"A.npz", torch.long, True)
#         Aef = self._import_tensor((folder_path/"E.npz"), torch.float32, is_sparse=True)
#         Aei = torch.from_numpy(np.load((folder_path/"edge_index.npy")))
#         Lef = self._import_tensor((folder_path/"labels.npz"), torch.float32, is_sparse=True)
#         Lei = torch.from_numpy(np.load((folder_path/"label_index.npy")))
#         y = torch.from_numpy(np.load((folder_path/"label_value.npy")))

#         return X, Aei, Aef, Lei, Lef, y

#     def __len__(self):
#         return len(self.folderpaths)

# def graph_collate(batch):
#     # batch is a list of tuples
#     xs, aei, aef, lei, lef, ys = zip(*batch)   # tuples of length B

#     return (list(xs),                          # list of sparse X
#             list(aei),                         # list of edge_index
#             list(aef),                         # list of sparse A_edge_feat
#             list(lei),
#             list(lef),
#             list(ys))                          # dense y can still be list/stack

# def make_loader(dataset, batch_size=1, shuffle=False):
#     return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=graph_collate, num_workers=0, pin_memory=True)

In [6]:
# def walk_limited(root: Path, max_depth: int, pat: str):
#     root_depth = len(root.parts)
#     for dirpath, dirnames, _ in os.walk(root):
#         depth = len(Path(dirpath).parts) - root_depth
#         if depth > max_depth:
#             # prune traversal
#             dirnames[:] = []
#             continue
#         for d in dirnames:
#             p = Path(dirpath, d)
#             if p.match(pat):
#                 yield p

# src = Path("/vol/bitbucket/mjh24/IAEA-thesis/data/swde_HTMLgraphs/movie/movie")
# batch_dirs = list(walk_limited(src, max_depth=2, pat='[0-9][0-9][0-9][0-9]'))
# print(src.exists())
# batchFiles = list(src.rglob("[0-9][0-9][0-9][0-9]"))
# print(len(batchFiles))
# dataset = LazyGraphDataset(batchFiles)
# dataloader = make_loader(dataset, batch_size=1, shuffle=False)

In [7]:
# for Xs, Aeis, Aefs, Leis, Lefs, ys in dataloader:
#     print(Xs[0].shape, Aeis[0].shape, Aefs[0].shape, Leis[0].shape, Lefs[0].shape, ys[0].shape)
#     break

In [8]:
# # Helper function to normalise the A matrix
# def symmetric_normalize(A_tilde):
#     """
#     Performs symmetric normalization of A_tilde (Adj. matrix with self loops):
#       A_norm = D^{-1/2} * A_tilde * D^{-1/2}
#     Where D_{ii} = sum of row i in A_tilde.

#     A_tilde (N, N): Adj. matrix with self loops
#     Returns:
#       A_norm : (N, N)
#     """

#     eps = 1e-5
#     d = A_tilde.sum(dim=1) + eps
#     D_inv = torch.diag(torch.pow(d, -0.5))
#     return (D_inv @ A_tilde @ D_inv).to(torch.float32)

In [9]:
def plot_metrics_live(
    train_vals,
    val_vals,
    p, r, f1,                 # new metric lists (same length as train/val)
    save_path,
    xlabel="Epoch",
    ylabel_left="Loss / Accuracy",
    ylabel_right="P · R · F1",
    title="Training progress",
    fig_ax=None
):
    """
    Redraw the dual-axis training plot and save it to disk.

    Meant to be called repeatedly with growing metric lists; every call clears
    both axes, replots everything and overwrites the saved image.

    Parameters
    ----------
    train_vals, val_vals : list[float]
        Curves for the left axis (e.g. loss or accuracy).
    p, r, f1 : list[float]
        Precision, recall and F1 curves for the right axis.  All five lists
        are plotted against ``1 .. len(train_vals)``, so they must share that
        length.
    save_path : str or Path
        Target passed to ``fig.savefig`` (written at 150 dpi).
    xlabel, ylabel_left, ylabel_right, title : str
        Axis labels and figure title.
    fig_ax : tuple(fig, (ax_left, ax_right)) | None
        Handles returned by the previous call; ``None`` creates a new 8x5
        figure with a twin right axis.

    Returns
    -------
    (fig, (ax_left, ax_right))
        Pass this back in as ``fig_ax`` on the next call.
    """
    # ---------- reuse or create the figure ----------
    if fig_ax is not None:
        fig, (ax_left, ax_right) = fig_ax
    else:
        fig, ax_left = plt.subplots(figsize=(8, 5))
        ax_right = ax_left.twinx()

    # ---------- wipe the previous frame ----------
    for axis in (ax_left, ax_right):
        axis.cla()

    epochs = range(1, len(train_vals) + 1)

    # ---------- left axis: train / validation ----------
    left_curves = (
        (train_vals, "-o", "Train"),
        (val_vals,   "-s", "Val"),
    )
    for series, fmt, name in left_curves:
        ax_left.plot(epochs, series, fmt, label=name, markersize=4)
    ax_left.set_xlabel(xlabel)
    ax_left.set_ylabel(ylabel_left)
    ax_left.grid(True, axis="both")

    # ---------- right axis: precision / recall / F1 ----------
    right_curves = (
        (p,  "--d", "Precision"),
        (r,  "--^", "Recall"),
        (f1, "--*", "F1"),
    )
    for series, fmt, name in right_curves:
        ax_right.plot(epochs, series, fmt, label=name, markersize=4)
    ax_right.set_ylabel(ylabel_right)

    # ---------- single legend covering both axes ----------
    handles, names = [], []
    for axis in (ax_left, ax_right):
        axis_handles, axis_names = axis.get_legend_handles_labels()
        handles = handles + axis_handles
        names = names + axis_names
    ax_left.legend(handles, names, loc="upper center", ncol=5)

    ax_left.set_title(title)
    fig.tight_layout()
    fig.savefig(Path(save_path), dpi=150)

    return fig, (ax_left, ax_right)

In [10]:
# # To advance the model, use the methods in https://arxiv.org/pdf/2311.02921

# class GraphAttentionNetwork(nn.Module):
#     """
#     HTML‑graph model

#         X  ─╮
#             │  GAT( 96 → 64 )
#             │  ReLU
#             │  GAT( 64 → 32 )
#             │  ReLU
#             └─ Edge‑feature constructor
#                       [h_i ‖ h_j ‖ φ(e_ij)] ─► MLP(69 → 1)

#     Parameters
#     ----------
#     in_dim          : node‑feature size   (= 96)
#     edge_in_dim     : raw edge‑feature size (= 197)
#     edge_emb_dim    : Edge-feature MLP output dims
#     """
#     def __init__(self,
#                  in_dim: int        = 96,
#                  edge_in_dim: int   = 197,
#                  edge_emb_dim: int  = 8,
#                  hidden1: int       = 128,
#                  hidden2: int       = 64,
#                  hidden3: int       = 32,
#                  heads:  int        = 4):
#         super().__init__()

#         # ── Node-level encoder (edge-aware) ────────────────────────────
#         self.tr1 = TransformerConv(
#             in_channels      = in_dim,
#             out_channels     = hidden1,
#             heads            = heads,
#             edge_dim         = edge_emb_dim,
#             dropout          = 0.1,
#             beta             = False         # learnable α in α·x + (1-α)·attn
#         )
#         self.ln1 = nn.LayerNorm(hidden1 * heads)

#         self.tr2 = TransformerConv(
#             in_channels      = hidden1 * heads,
#             out_channels     = hidden2,
#             heads            = 1,
#             edge_dim         = edge_emb_dim,
#             dropout          = 0.1,
#             beta             = False
#         )
#         self.ln2 = nn.LayerNorm(hidden2)
#         self.tr3 = TransformerConv(
#             in_channels      = hidden2,
#             out_channels     = hidden3,
#             heads            = 1,
#             edge_dim         = edge_emb_dim,
#             dropout          = 0.1,
#             beta             = False
#         )

#         # ── Edge feature projector ────────────── (It is not an explicit linear layer as it works on a sparse matrix)
#         self.AW_edge = nn.Parameter(torch.empty(edge_in_dim, edge_emb_dim))
#         nn.init.xavier_uniform_(self.AW_edge)
#         # self.EW_edge = nn.Parameter(torch.empty(edge_in_dim, edge_emb_dim))
#         # nn.init.xavier_uniform_(self.EW_edge)

#         # ── Edge-level MLP decoder (unchanged) ────────────────────────
#         self.edge_mlp = nn.Sequential(
#             nn.Linear(hidden3 * 2 + edge_emb_dim, hidden3),
#             nn.ReLU(),
#             #nn.Dropout(0.2),
#             nn.Linear(hidden3, 1)
#         )

#         # init beta gate around 0.5 to avoid identity lock
#         for tr in (self.tr1, self.tr2):
#             if getattr(tr, "lin_beta", None) is not None:
#                 nn.init.zeros_(tr.lin_beta.weight)
#                 if tr.lin_beta.bias is not None:
#                     nn.init.zeros_(tr.lin_beta.bias)

#     # ---------------------------------------------------------------------

#     def forward(
#         self,
#         x_sparse: torch.Tensor,        # (N_nodes, 96)          sparse
#         A_edge_index: torch.Tensor,   # (2, nnz_A)             COO  (from A)
#         A_edge_attr: torch.Tensor,    # (nnz_A, 197)           dense / sparse.mm
#         E_edge_index: torch.Tensor,   # (2, N_E)               candidates
#         E_edge_attr: torch.Tensor,    # (N_E, 197)             sparse features
#         E_attr_dropout=0.0,            # Probability of dropping out a whole edge_attr when training
#         E_attr_include=True,
#         A_attr_include=True
#     ):
#         # 1) node features
#         x_dense = x_sparse.to_dense()
#         A_edge_emb = torch.sparse.mm(A_edge_attr, self.AW_edge)     # (nnz_A , 8)
#         #A_edge_emb = A_edge_attr.to_dense()
#         #E_edge_emb = E_edge_attr.to_dense()

#         if not A_attr_include:
#             A_edge_emb = torch.zeros_like(A_edge_emb)

#         # 2) edge-aware GATv2 layers
#         h = F.relu( self.tr1(x_dense, A_edge_index, A_edge_emb) )
#         #h = self.ln1(h)
#         #Try a linlayer here to condense heads
#         h = F.relu( self.tr2(h,        A_edge_index, A_edge_emb) )
#         #h = self.ln2(h)
#         h = F.relu( self.tr3(h,        A_edge_index, A_edge_emb) )

#         # 3) candidate-edge projection  φ(E) = E @ W_edge
#         E_edge_emb = torch.sparse.mm(E_edge_attr, self.AW_edge)     # (N_E , 8)
        
#         if self.training:
#             mask = torch.rand(E_edge_emb.size(0), 1,
#                             device=E_edge_emb.device) > E_attr_dropout   # (N_E,1)
#             E_edge_emb = E_edge_emb * mask

#         if not E_attr_include:
#             E_edge_emb = torch.zeros_like(E_edge_emb)

#         # 4) gather node embeddings and classify
#         src, dst = E_edge_index
#         z = torch.cat([h[src], h[dst], E_edge_emb], dim=1)      # (N_E , 72)
#         return self.edge_mlp(z).squeeze(-1)                   # (N_E ,) returns the logits

In [11]:
def focal_loss(logits, targets, alpha = 0.25, gamma = 2.0, reduction: str = "mean"):
    """
    Binary focal loss computed from raw logits.

    Per element: ``alpha * (1 - p_t)**gamma * BCE``, where ``p_t`` is the
    probability the model assigns to the true class.  Note that ``alpha`` is
    applied as one constant factor to every element, positives and negatives
    alike.

    Parameters
    ----------
    logits : float tensor of raw scores
    targets : float tensor of the same shape (as required by
        ``F.binary_cross_entropy_with_logits``)
    alpha : float, constant scaling factor
    gamma : float, focusing exponent
    reduction : {"mean", "sum", "none"}
        How to reduce the per-element losses.

    Raises
    ------
    ValueError
        For an unknown *reduction* (previously anything other than "mean",
        including "none" and typos, silently returned the sum).
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)          # = σ(z) if y==1 else 1-σ(z)
    loss = alpha * (1.0 - p_t).pow(gamma) * bce

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(
        f"focal_loss: reduction must be 'mean', 'sum' or 'none', got {reduction!r}"
    )


# ---------- 2. Pair-wise AUC (logistic ranking) loss -------------------------
def pairwise_auc_loss(logits, targets, sample_k: int | None = None):
    """
    logits   : float tensor (B,)
    targets  : {0,1} tensor (B,)
    sample_k : optional int – #negatives to sample per positive.  If None,
               uses *all* positives × negatives (can be heavy for big batches).
    """
    pos_logits = logits[targets == 1]      # shape (P,)
    neg_logits = logits[targets == 0]      # shape (N,)

    if pos_logits.numel() == 0 or neg_logits.numel() == 0:
        # No valid pairs (edge cases in small batches) – return 0 so it
        # doesn't break the graph.
        return logits.new_tensor(0.0, requires_grad=True)

    # --- optional negative subsampling to save memory ---
    if sample_k is not None and neg_logits.numel() > sample_k:
        idx = torch.randperm(neg_logits.numel(), device=logits.device)[:sample_k]
        neg_logits = neg_logits[idx]

    # Broadcast positives against negatives: diff = s_pos - s_neg
    diff = pos_logits[:, None] - neg_logits[None, :]        # (P, N) or (P, k)
    loss = F.softplus(-diff)                                # log(1+e^(-diff))

    return loss.mean()


# ---------- 3. Combined wrapper ----------------------------------------------
class PairwiseAUCFocalLoss(nn.Module):
    """
    Convex blend of the pairwise ranking loss and the focal loss:

        total_loss = (1 - lambda_focal) * pairwise_auc_loss
                     + lambda_focal * focal_loss

    so ``lambda_focal = 0`` is the pure ranking loss and ``lambda_focal = 1``
    is the pure focal loss.

    Parameters
    ----------
    gamma, alpha : float
        Forwarded to ``focal_loss``.
    lambda_focal : float
        Weight of the focal term; the ranking term gets ``1 - lambda_focal``.
    sample_k : int | None
        Forwarded to ``pairwise_auc_loss`` (negative subsampling).
    """
    def __init__(self,
                 gamma: float = 2.0,
                 alpha: float = 0.25,
                 lambda_focal: float = 0.5,
                 sample_k: int | None = None):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha
        self.lambda_focal, self.sample_k = lambda_focal, sample_k

    def forward(self, logits, targets):
        """Return the blended scalar loss for one batch of logits/targets."""
        focal_weight = self.lambda_focal
        rank_term = pairwise_auc_loss(logits, targets, sample_k=self.sample_k)
        focal_term = focal_loss(logits, targets, alpha=self.alpha, gamma=self.gamma)
        return rank_term * (1 - focal_weight) + focal_weight * focal_term


In [None]:
CLIP_NORM = 2.0           # max global gradient norm used for clipping

# ---------- one epoch --------------------------------------------------------
def train_epoch(model, loader, optimizer,
                criterion, sched, epoch, totalEpoch, device="cpu", **kwargs):
    """
    Run one training pass over *loader* and return the edge-weighted mean loss.

    For every mini-batch: move the tensors to *device*, run the model, compute
    ``criterion(logits, y.float())``, back-propagate, clip the global gradient
    norm to ``CLIP_NORM``, then step the optimizer and the scheduler (the
    scheduler is stepped once per batch, not once per epoch).

    Parameters
    ----------
    model : nn.Module
        Called as ``model(X, batch, Aei, Aef, Lei, Lef, p_Lef_drop,
        use_E_attr, use_A_attr)`` and expected to return one logit per label
        edge.
    loader : iterable of 8-tuples ``(f, X, Aei, Aef, Lei, Lef, y, batch)``
        Must support ``len()``.
    optimizer, criterion, sched
        Optimizer, loss callable and LR scheduler.
    epoch, totalEpoch : int
        Only used in the progress message.
    device : str or torch.device
    **kwargs
        ``p_Lef_drop`` (default 0), ``use_E_attr`` (default True) and
        ``use_A_attr`` (default True), forwarded to the model.

    Returns
    -------
    float
        Sum of ``loss * n_label_edges`` over batches divided by the total
        number of label edges; 0.0 if the loader yielded no label edges.
    """
    p_Lef_drop = kwargs.get("p_Lef_drop", 0)
    use_E_attr = kwargs.get("use_E_attr", True)
    use_A_attr = kwargs.get("use_A_attr", True)

    model.train()

    running_loss, running_edges = 0.0, 0
    n_batches = len(loader)

    for step, (_, X_sparse, Aei, Aef, Lei, Lef, y, batch) in enumerate(loader, start=1):
        X_sparse, Aei, Aef = X_sparse.to(device), Aei.to(device), Aef.to(device)
        Lei, Lef, y = Lei.to(device), Lef.to(device), y.to(device)
        batch = batch.to(device)

        optimizer.zero_grad()

        logits = model(X_sparse, batch, Aei, Aef, Lei, Lef,
                       p_Lef_drop, use_E_attr, use_A_attr)          # (N_label,)
        loss   = criterion(logits, y.float())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        optimizer.step()
        sched.step()

        batch_loss = loss.item()
        running_loss  += batch_loss * y.numel()
        running_edges += y.numel()

        if step % 20 == 0:
            # The counter is the batch index within this epoch, so label it
            # as such (it used to be printed as "epoch").
            print(f"epoch {epoch}/{totalEpoch} batch {step}/{n_batches} "
                  f"loss={batch_loss:.4f}")

    if running_edges == 0:
        # Empty loader (or no label edges at all): avoid dividing by zero.
        return 0.0
    return running_loss / running_edges


# ---------- evaluation -------------------------------------------------------
@torch.no_grad()
def eval_edge_model(model, loader, criterion, device="cpu", thr_type="median", **kwargs):
    """
    Evaluate *model* on *loader* and return ``(loss, precision, recall, f1)``.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(X, batch, Aei, Aef, Lei, Lef, 0, use_E_attr,
        use_A_attr)`` (label-feature dropout is always 0 here).
    loader : iterable of 8-tuples ``(f, X, Aei, Aef, Lei, Lef, y, batch)``
        ``f`` is the collection of graph names in the batch.
    criterion : callable ``(logits, float targets) -> scalar loss``
    device : str or torch.device
    thr_type : "median" or a number
        ``"median"`` thresholds each mini-batch at the median of its predicted
        probabilities (so the result depends on how graphs are batched).  A
        number is used as a fixed probability threshold.
    **kwargs
        ``use_E_attr`` / ``use_A_attr`` (both default True), forwarded to the
        model.

    Returns
    -------
    (float, float, float, float)
        Edge-weighted mean loss, precision, recall and F1 accumulated over all
        batches.

    Raises
    ------
    ValueError
        If *thr_type* is neither "median" nor a number (this used to surface
        as an unbound ``thr`` variable inside the loop).
    """
    use_median = isinstance(thr_type, str) and thr_type == "median"
    if not use_median and (isinstance(thr_type, bool)
                           or not isinstance(thr_type, (int, float))):
        raise ValueError(
            f"eval_edge_model: thr_type must be 'median' or a number, got {thr_type!r}"
        )

    use_E_attr = kwargs.get("use_E_attr", True)
    use_A_attr = kwargs.get("use_A_attr", True)

    model.eval()
    TP = FP = FN = 0
    running_loss, running_edges = 0.0, 0

    filenames = []
    for f, X_sparse, Aei, Aef, Lei, Lef, y, batch in loader:
        filenames.extend(f)
        if y.numel() == 0:
            # No label edges in this batch: nothing to score, and a mean loss
            # or median over an empty tensor is undefined.
            continue

        X_sparse, Aei, Aef = X_sparse.to(device), Aei.to(device), Aef.to(device)
        Lei, Lef, y = Lei.to(device), Lef.to(device), y.to(device)
        batch = batch.to(device)

        # Complete model, no label-feature dropout
        logits = model(X_sparse, batch, Aei, Aef, Lei, Lef, 0, use_E_attr, use_A_attr)
        loss   = criterion(logits, y.float())
        running_loss  += loss.item() * y.numel()
        running_edges += y.numel()

        probs = torch.sigmoid(logits)
        thr   = torch.median(probs) if use_median else float(thr_type)

        pred = (probs >= thr).long()
        TP  += ((pred == 1) & (y == 1)).sum().item()
        FP  += ((pred == 1) & (y == 0)).sum().item()
        FN  += ((pred == 0) & (y == 1)).sum().item()

    # Dropping the last 5 characters of each name leaves the containing
    # directory — presumably a trailing "/NNNN" graph id; verify against the
    # dataset layout.
    print(f"Validating {np.unique([filename[:-5] for filename in filenames])} website type")

    # The small epsilons keep the ratios finite when a denominator is zero.
    prec = TP / (TP + FP + 1e-9)
    rec  = TP / (TP + FN + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    mean_loss = running_loss / running_edges if running_edges else 0.0
    return mean_loss, prec, rec, f1

In [19]:
def train_model(model,
                train_loader,
                val_loader,
                load_checkpoint,
                num_epochs     = 100,
                lr             = 1e-3,
                validate_every = 10,
                patience       = 10,
                device         = "cpu",
                max_lr         = 4e-4):
    """
    Train *model* with AdamW + a one-cycle LR schedule and the focal loss,
    validating, plotting and checkpointing every ``validate_every`` epochs.

    Parameters
    ----------
    model : nn.Module
    train_loader, val_loader : DataLoader
    load_checkpoint : bool
        If True and ``./model_in_training.pt`` exists, its weights are loaded
        before training starts.
    num_epochs : int
    lr : float
        Learning rate handed to AdamW.  The one-cycle scheduler takes over the
        learning rate as soon as it is constructed, so the schedule is
        governed by ``max_lr``; see the OneCycleLR documentation.
    validate_every : int
        Validation period in epochs; the last epoch is always validated.
    patience : int
        Accepted for interface compatibility; not used (there is no early
        stopping in this loop).
    device : str or torch.device
    max_lr : float
        Peak learning rate of the one-cycle schedule (the previously
        hard-coded value is the default).

    Returns
    -------
    (best_state, train_loss, val_loss, fig_ax)
        ``best_state`` is a deep copy of the state dict with the highest
        validation F1 (ties go to the later epoch) and is also loaded back
        into *model*; ``train_loss`` has one entry per epoch, ``val_loss`` one
        per validation; ``fig_ax`` are the live-plot handles.

    Side effects: prints the model and progress lines, writes the plot to
    ``CurrentRun`` and the checkpoint to ``./model_in_training.pt`` after each
    validation.
    """
    print(model)

    model_path = "./model_in_training.pt"
    if os.path.exists(model_path) and load_checkpoint:
        print("loading existing model...")
        # map_location lets a checkpoint written on GPU be resumed on a
        # CPU-only machine; the model is moved to `device` right after.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

    model.to(device)

    opt   = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # train_epoch steps the scheduler once per batch, hence steps_per_epoch.
    sched = lr_scheduler.OneCycleLR(opt, max_lr=max_lr, epochs=num_epochs, steps_per_epoch=len(train_loader),
                   pct_start=0.1, anneal_strategy='cos', div_factor=25, final_div_factor=1e4, cycle_momentum=False)
    criterion = focal_loss

    best_f1, fig_ax, best_state = 0.0, None, None
    train_loss, val_loss, precision, recall, f1score = [], [], [], [], []
    val_epochs = []      # epochs at which validation metrics were recorded

    for epoch in range(1, num_epochs + 1):
        # Label-feature dropout and feature ablation are currently disabled:
        # every epoch trains with both edge-feature sets and no dropout.
        p_Lef_drop = 0
        use_E_attr, use_A_attr = True, True

        loss = train_epoch(model, train_loader, opt, criterion, sched, epoch, num_epochs, device=device, use_E_attr=use_E_attr, use_A_attr = use_A_attr, p_Lef_drop = p_Lef_drop)
        train_loss.append(loss)

        if epoch % validate_every == 0 or epoch == num_epochs:
            loss, p, r, f1 = eval_edge_model(model, val_loader, criterion, device=device, use_E_attr=use_E_attr, use_A_attr = use_A_attr)
            val_loss.append(loss)
            val_epochs.append(epoch)

            lr_now = opt.param_groups[0]["lr"]
            print(f"Epoch {epoch:03d}/{num_epochs} "
                  f"loss={loss:.4f}  P={p:.3f} R={r:.3f} F1={f1:.3f}  lr={lr_now:.2e}  E_features={use_E_attr} A_features={use_A_attr}")
            precision.append(p)
            recall.append(r)
            f1score.append(f1)

            if f1 >= best_f1:
                best_f1, best_state = f1, copy.deepcopy(model.state_dict())

            # The plot needs all five series to have equal length, but the
            # training loss is recorded every epoch and the validation metrics
            # only every `validate_every` epochs.  Plot the training loss at
            # the validated epochs so the lengths always match (identical to
            # the old behaviour when validate_every == 1, which is the only
            # setting that did not crash).
            fig_ax = plot_metrics_live(
                [train_loss[e - 1] for e in val_epochs],
                val_loss,
                precision,recall,f1score,
                "CurrentRun",
                xlabel="Epoch" if validate_every == 1 else "Validation checkpoint",
                ylabel_left="Loss",
                ylabel_right="P · R · F1",
                title="Model Performance",
                fig_ax=fig_ax
            )

            torch.save(model.state_dict(), model_path)

    if best_state is not None:
        model.load_state_dict(best_state)

    return best_state, train_loss, val_loss, fig_ax


In [20]:
dataset = TarGraphDataset(datafile)
N = len(dataset)


def site_indices(site):
    """Return the list of dataset indices belonging to website *site*."""
    start, end = dataset.get_sublen(site)
    return list(range(start, end))


matchcollege_idx = site_indices('university-matchcollege(2000)')
allmovie_idx     = site_indices('movie-allmovie(2000)')
imdb_idx         = site_indices('movie-imdb(2000)')
usatoday_idx     = site_indices('nbaplayer-usatoday(436)')
yahoo_idx        = site_indices('nbaplayer-yahoo(438)')

# Leave-one-website-out split: validate on every allmovie graph and train on
# everything else except the two NBA sites, which are kept out of both splits.
# (Alternatives tried earlier: a random 95/5 split, or the last 10 graphs of
# matchcollege + allmovie as validation.)
# Sorted lists / ordered comprehension instead of list(set(...)) so the index
# order never depends on set iteration order.
val_idx   = sorted(set(allmovie_idx))
excluded  = set(val_idx) | set(usatoday_idx) | set(yahoo_idx)
train_idx = [i for i in range(N) if i not in excluded]
train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, val_idx)

train_loader = make_loader(train_ds, batch_size=8, shuffle=True)
# No shuffling for validation: eval_edge_model thresholds each batch at its
# own median, so random batch composition made the metrics vary between runs.
val_loader = make_loader(val_ds, batch_size=8, shuffle=False)

model = GraphAttentionNetwork(in_dim = 114, pe_dim=12, edge_in_dim = 200, edge_emb_dim = 32, heads = 4)#16,32,4 was the winner

# Fall back to CPU when no GPU is visible instead of crashing on "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

load_checkpoint = False
_, trainloss, valloss, fig_ax = train_model(model,
            train_loader,
            val_loader,
            load_checkpoint,
            num_epochs     = 12,
            lr             = 1e-3,
            validate_every = 1,
            patience       = 1,
            device         = device)

GraphAttentionNetwork(
  (pe_lin): Linear(in_features=18, out_features=12, bias=True)
  (pe_norm): BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (convs): ModuleList(
    (0-5): 6 x GPSConv(108, conv=GINEConv(nn=Sequential(
      (0): Linear(in_features=108, out_features=108, bias=True)
      (1): ReLU()
      (2): Linear(in_features=108, out_features=108, bias=True)
    )), heads=4, attn_type=multihead)
  )
  (edge_mlp): Sequential(
    (0): Linear(in_features=248, out_features=108, bias=True)
    (1): ReLU()
    (2): Linear(in_features=108, out_features=1, bias=True)
  )
)
epoch 1/3095 loss=0.0441


KeyboardInterrupt: 

In [None]:
# Save the current weights of `model` (state dict only) under a descriptive
# file name in the working directory.
torch.save(model.state_dict(), "FULLTRAINEDALLDATAModelf1-74-learning-bettersched.pt")

In [None]:
# model_path = "./FULLTRAINEDALLDATAModelf1-74-learning.pt"
# if os.path.exists(model_path) and load_checkpoint:
#     print("loading existing model...")
#     model.load_state_dict(torch.load(model_path))



# Evaluate the in-memory model on the validation loader with the focal loss;
# the cell displays the returned (loss, precision, recall, F1) tuple.
eval_edge_model(model, val_loader, focal_loss, device="cuda", use_E_attr=True, use_A_attr=True)
# TODO before submitting to the A100:
# - Experiment with the comparison loss
# - Do self layers myself
# - Graphs of also without edges

# - 32 32 layers
# - Just train everything from the start

Validating ['swde_HTMLgraphs/movie/movie/movie-allmovie(2000)'] website type


(0.04156546794519418,
 0.7552735881680096,
 0.7551079869696979,
 0.7551907779904338)