### Edge classification 

In [None]:
# gnn_edge_clf.py
# --------------------------------------------------------------
# Dependencies: torch, torch_geometric, pandas, numpy, networkx
# --------------------------------------------------------------
import json, pathlib, math, random
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, List, Tuple

import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCN, GraphSAGE, GAT, GATv2Conv

# ---------------------------------------------------------------------
# 1. Configuration
# ---------------------------------------------------------------------
@dataclass
class GNNConfig:
    """Hyper-parameters and run settings for an edge-classification experiment.

    Every field has a default, so a config can be built from a partial dict or
    JSON document; unknown keys raise ``TypeError`` from the dataclass
    constructor.
    """

    # encoder architecture
    model_name: str = "gcn"         # key into ENCODER_FACTORY: {"gcn","graphsage","sage","gat"}
    hidden_dim: int = 128
    num_layers: int = 2
    heads: int = 8                  # attention heads, only used by the GAT encoder
    layer_dropout: float = 0.3

    # optimisation
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 100
    patience: int = 15              # epochs without validation improvement before early stop
    grad_clip: float = 1.0          # max gradient norm; a falsy value disables clipping

    # split
    split_mode: str = "date"        # {"date","ratio"}
    cutoff_date: Optional[str] = None   # last training date when split_mode == "date"
    val_ratio: float = 0.10
    test_ratio: float = 0.10
    shuffle_in_split: bool = False  # shuffle the training loader only

    # misc
    device: str = "cuda"
    seed: int = 42

    # edge imbalance override: one positive-class weight per edge
    edge_pos_weights: Optional[List[float]] = None

    @classmethod
    def load(cls, x):
        """Build a config from an existing config, a dict, or a JSON file path.

        Args:
            x: A ``GNNConfig`` (returned unchanged), a dict of field overrides,
                or a path (``str`` / ``pathlib.Path``) to a JSON file holding
                such a dict.

        Returns:
            The config instance. Dict and JSON inputs are built with ``cls``,
            so subclasses get instances of themselves.

        Raises:
            TypeError: If the dict / JSON contains keys that are not fields.
            OSError: If the path cannot be opened.
        """
        if isinstance(x, GNNConfig):
            return x
        if isinstance(x, dict):
            return cls(**x)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(pathlib.Path(x), encoding="utf-8") as f:
            return cls(**json.load(f))

# ---------------------------------------------------------------------
# 2. Mini Encoder Zoo  (node → z)
# ---------------------------------------------------------------------
class EncoderGCN(nn.Module):
    """Node encoder wrapping the PyG ``GCN`` model, followed by output dropout.

    Args:
        d_in: Number of input node features.
        h: Hidden channel width passed to ``GCN``.
        L: Number of layers passed to ``GCN``.
        p: Dropout probability applied to the encoder output.
    """

    def __init__(self, d_in, h, L, p):
        super().__init__()
        self.body = GCN(
            in_channels=d_in,
            hidden_channels=h,
            num_layers=L,
            norm="batch",
            act="relu",
        )
        self.drop = nn.Dropout(p=p)

    def forward(self, x, ei, ea):
        """Return node embeddings for features ``x`` and edge index ``ei``.

        ``ea`` (edge attributes) is accepted so all encoders share one call
        signature, but it is not used here.
        """
        embeddings = self.body(x, ei)
        return self.drop(embeddings)

class EncoderSAGE(nn.Module):
    """Node encoder wrapping the PyG ``GraphSAGE`` model, followed by output dropout.

    Args:
        d_in: Number of input node features.
        h: Hidden channel width passed to ``GraphSAGE``.
        L: Number of layers passed to ``GraphSAGE``.
        p: Dropout probability applied to the encoder output.
    """

    def __init__(self, d_in, h, L, p):
        super().__init__()
        self.body = GraphSAGE(
            in_channels=d_in,
            hidden_channels=h,
            num_layers=L,
            norm="batch",
            act="relu",
        )
        self.drop = nn.Dropout(p=p)

    def forward(self, x, ei, ea):
        """Return node embeddings for features ``x`` and edge index ``ei``.

        ``ea`` (edge attributes) is accepted so all encoders share one call
        signature, but it is not used here.
        """
        embeddings = self.body(x, ei)
        return self.drop(embeddings)

class EncoderGAT(nn.Module):
    """Node encoder wrapping the PyG ``GAT`` model, followed by output dropout.

    Args:
        d_in: Number of input node features.
        h: Hidden channel width passed to ``GAT``.
        L: Number of layers passed to ``GAT``.
        heads: Number of attention heads passed to ``GAT``.
        p: Dropout probability applied to the encoder output.
    """

    def __init__(self, d_in, h, L, heads, p):
        super().__init__()
        self.body = GAT(
            in_channels=d_in,
            hidden_channels=h,
            num_layers=L,
            heads=heads,
            norm="batch",
            act="relu",
        )
        self.drop = nn.Dropout(p=p)

    def forward(self, x, ei, ea):
        """Return node embeddings for features ``x`` and edge index ``ei``.

        ``ea`` (edge attributes) is accepted so all encoders share one call
        signature, but it is not used here.
        """
        embeddings = self.body(x, ei)
        return self.drop(embeddings)

# Lookup table from lower-cased model name to encoder class.
# "sage" and "graphsage" both resolve to the GraphSAGE encoder.
ENCODER_FACTORY = dict(
    gcn=EncoderGCN,
    graphsage=EncoderSAGE,
    sage=EncoderSAGE,
    gat=EncoderGAT,
)

# ---------------------------------------------------------------------
# 3. Edge decoder  (z_u , z_v) → logit
# ---------------------------------------------------------------------
class DotDecoder(nn.Module):
    """Parameter-free edge decoder: the logit is the dot product of the endpoint embeddings."""

    def forward(self, z_src, z_dst):
        """Return per-edge logits.

        The last dimension of the element-wise product is summed and kept as a
        singleton, so inputs of shape ``(E, d)`` give an output of shape ``(E, 1)``.
        """
        elementwise = torch.mul(z_src, z_dst)
        return torch.sum(elementwise, dim=-1, keepdim=True)

# ---------------------------------------------------------------------
# 4. Experiment wrapper
# ---------------------------------------------------------------------
class EdgeExperiment:
    """End-to-end edge-classification experiment over timestamped graph snapshots.

    Typical order of calls: ``prepare_snapshots`` -> ``build_loaders`` ->
    ``init_model`` -> ``compile`` -> ``train`` -> ``predict_loader``.  The four
    set-up methods return ``self`` so they can be chained.

    Args:
        node_frames: One DataFrame per graph node, keyed by the node identifier
            and indexed by timestamp; each row is that node's feature vector.
        edge_frames: One DataFrame per edge, keyed by ``f"{src}-{dst}".lower()``
            and indexed by timestamp; must contain a ``"target"`` column (the
            label), all other columns are edge features.
        graph: Directed graph defining the (static) topology.
        cfg: A ``GNNConfig``, a dict of overrides, a path to a JSON file, or
            ``None`` for a fresh default config.
    """

    def __init__(
        self,
        node_frames: Dict[str, pd.DataFrame],
        edge_frames: Dict[str, pd.DataFrame],
        graph: "nx.DiGraph",
        cfg: "GNNConfig | Dict[str, Any] | str | None" = None,
    ):
        # A fresh config per experiment: a ``GNNConfig()`` default argument would
        # be created once and shared (and mutated) across all instances.
        self.cfg = GNNConfig.load(GNNConfig() if cfg is None else cfg)
        torch.manual_seed(self.cfg.seed)
        np.random.seed(self.cfg.seed)
        random.seed(self.cfg.seed)

        self.node_frames = node_frames
        self.edge_frames = edge_frames
        self.graph = graph

        # Fall back to CPU when CUDA is requested but unavailable.
        use_cuda = self.cfg.device == "cuda" and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # placeholders, filled by prepare_snapshots()
        self.edge_order: List[Tuple[int, int]] = []
        self.snapshots: List["Data"] = []

    # -------------------------------------------- snapshot builder
    def prepare_snapshots(self):
        """Build one ``Data`` object per timestamp shared by all node and edge frames.

        Nodes are indexed in sorted order; edges keep the iteration order of
        ``graph.edges``, so ``edge_label[i]`` and ``edge_attr[i]`` belong to
        ``edge_index[:, i]``.

        Returns:
            ``self`` (for chaining); results are stored in ``self.snapshots``.
        """
        nodes = sorted(self.graph.nodes)
        node_to_idx = {n: i for i, n in enumerate(nodes)}
        self.edge_order = [(node_to_idx[u], node_to_idx[v]) for u, v in self.graph.edges]

        # The topology is static: build the edge index and the edge-frame keys
        # once instead of once per timestamp.
        edge_index = (
            torch.tensor(self.edge_order, dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        edge_keys = [f"{nodes[s]}-{nodes[d]}".lower() for s, d in self.edge_order]

        # Only timestamps present in every node frame and every edge frame are usable.
        indexes = [df.index for df in self.node_frames.values()]
        indexes += [df.index for df in self.edge_frames.values()]
        ts_common = sorted(set.intersection(*map(set, indexes)))

        snaps: List["Data"] = []
        for ts in ts_common:
            # node feats
            x = torch.tensor(
                np.vstack([self.node_frames[n].loc[ts].to_numpy(dtype=np.float32)
                           for n in nodes]),
                dtype=torch.float32,
            )
            # edge feats & labels, in edge_order
            rows = [self.edge_frames[key].loc[ts] for key in edge_keys]
            edge_label = torch.tensor([row["target"] for row in rows], dtype=torch.float32)
            edge_attr = torch.tensor(
                np.vstack([row.drop("target").to_numpy(dtype=np.float32) for row in rows]),
                dtype=torch.float32,
            )
            snaps.append(Data(
                x=x,
                edge_index=edge_index.clone(),   # each snapshot owns its tensor
                edge_attr=edge_attr,
                edge_label=edge_label,
                snap_time=torch.tensor([pd.Timestamp(ts).value]),  # ns since epoch
            ))
        self.snapshots = snaps
        return self

    # -------------------------------------------- train/val/test split
    def build_loaders(self):
        """Split the snapshots chronologically and create the three data loaders.

        * ``split_mode == "date"``: snapshots up to and including
          ``cfg.cutoff_date`` are training data; the remainder is halved into
          validation (earlier half) and test (later half).
        * any other mode: the last ``test_ratio`` fraction is test, the
          ``val_ratio`` fraction before it is validation, the rest is training.

        Also computes the per-edge positive-class weights from the training part.

        Returns:
            ``self`` (for chaining).

        Raises:
            ValueError: If ``split_mode == "date"`` but no ``cutoff_date`` is set.
        """
        if self.cfg.split_mode == "date" and self.cfg.cutoff_date is None:
            # pd.Timestamp(None) is NaT, whose integer value is the minimum
            # int64: every snapshot would silently land in the hold-out set.
            raise ValueError('split_mode="date" requires cfg.cutoff_date to be set')

        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())
        if self.cfg.split_mode == "date":
            cut = pd.Timestamp(self.cfg.cutoff_date).value
            train = [g for g in snaps_sorted if g.snap_time.item() <= cut]
            hold = [g for g in snaps_sorted if g.snap_time.item() > cut]
            mid = len(hold) // 2
            val, test = hold[:mid], hold[mid:]
        else:
            n = len(snaps_sorted)
            nv = int(n * self.cfg.val_ratio)
            nt = int(n * self.cfg.test_ratio)
            train = snaps_sorted[:n - nv - nt]
            val = snaps_sorted[n - nv - nt:n - nt]
            test = snaps_sorted[n - nt:]
        self._compute_edge_weights(train)                     # <- imbalance weights

        bs = self.cfg.batch_size
        self.train_dl = DataLoader(train, batch_size=bs, shuffle=self.cfg.shuffle_in_split)
        self.val_dl = DataLoader(val, batch_size=bs, shuffle=False)
        self.test_dl = DataLoader(test, batch_size=bs, shuffle=False)
        return self

    # -------------------------------------------- imbalance weights
    def _compute_edge_weights(self, train_snaps: List["Data"]):
        """Compute one positive-class weight per edge and attach it to every snapshot.

        The weight is ``negatives / positives`` counted over ``train_snaps``.
        Edges without any positive training label get the neutral weight 1
        instead of an unbounded ratio.  ``cfg.edge_pos_weights``, when given,
        replaces the computed weights.

        Raises:
            ValueError: If the override does not have one weight per edge.
        """
        num_edges = len(self.edge_order)
        if self.cfg.edge_pos_weights is not None:
            w = torch.tensor(self.cfg.edge_pos_weights, dtype=torch.float32)
            if tuple(w.shape) != (num_edges,):
                raise ValueError(
                    f"edge_pos_weights must contain {num_edges} values, got shape {tuple(w.shape)}"
                )
        else:
            pos = torch.zeros(num_edges)
            neg = torch.zeros(num_edges)
            for g in train_snaps:
                pos += g.edge_label
                neg += 1 - g.edge_label
            w = torch.where(pos > 0, neg / pos.clamp(min=1e-6), torch.ones_like(pos))
        # Attach to every snapshot (train, val and test); each gets its own copy.
        for g in self.snapshots:
            g.edge_weight = w.clone()
        self.edge_weight = w             # store for metrics

    # -------------------------------------------- model init
    def init_model(self):
        """Instantiate the encoder selected by ``cfg.model_name`` plus the dot decoder.

        Returns:
            ``self`` (for chaining); the modules are stored in ``self.model``
            under the keys ``"enc"`` and ``"dec"``.

        Raises:
            KeyError: If ``cfg.model_name`` is not a key of ``ENCODER_FACTORY``.
        """
        d_in = self.snapshots[0].x.size(1)
        name = self.cfg.model_name.lower()
        enc_cls = ENCODER_FACTORY[name]
        # Dropout is passed by keyword: the GAT encoder takes ``heads`` before
        # ``p``, so a positional dropout value would land in the ``heads`` slot.
        extra = {"heads": self.cfg.heads} if "gat" in name else {}
        enc = enc_cls(d_in, self.cfg.hidden_dim, self.cfg.num_layers,
                      p=self.cfg.layer_dropout, **extra)
        dec = DotDecoder()
        self.model = nn.ModuleDict({"enc": enc, "dec": dec}).to(self.device)
        return self

    # -------------------------------------------- optimiser / loss
    def compile(self):
        """Create the element-wise BCE loss and the Adam optimiser.

        Returns:
            ``self`` (for chaining, like the other set-up methods).
        """
        # reduction="none" keeps per-edge losses so they can be re-weighted.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="none")
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
        return self

    # -------------------------------------------- train helpers
    def _forward_edges(self, data: "Data"):
        """Encode the nodes and return one logit per edge of ``data.edge_index``."""
        z = self.model["enc"](data.x, data.edge_index, data.edge_attr)
        src, dst = data.edge_index        # edges are aligned with labels
        return self.model["dec"](z[src], z[dst]).squeeze(-1)

    def _run_epoch(self, loader, train: bool):
        """Run one pass over ``loader`` and return the mean batch loss.

        Args:
            loader: Iterable of batches exposing ``edge_label`` and ``edge_weight``.
            train: When true, gradients are computed and the optimiser is stepped.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        self.model.train(train)

        tot_loss, n_batches = 0.0, 0
        # Context-manager form: the previous global grad mode is restored on
        # exit, so an evaluation pass does not leave autograd switched off.
        with torch.set_grad_enabled(train):
            for data in loader:
                data = data.to(self.device)
                logits = self._forward_edges(data)
                target = data.edge_label
                loss_vec = self.loss_fn(logits, target)
                # The imbalance weight is a positive-class weight: it scales
                # the loss of positive-labelled edges only (labels assumed 0/1).
                pos_w = data.edge_weight.to(loss_vec)
                weights = torch.where(target > 0.5, pos_w, torch.ones_like(pos_w))
                loss = (loss_vec * weights).mean()

                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    if self.cfg.grad_clip:
                        nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
                    self.opt.step()
                tot_loss += loss.item()
                n_batches += 1
        if n_batches == 0:
            raise ValueError("loader yielded no batches; check the train/val split sizes")
        return tot_loss / n_batches

    def _snapshot_state(self):
        """Return an independent CPU copy of the model's current state dict."""
        # ``.cpu()`` alone returns the same tensor for CPU models, which would
        # alias the live parameters; ``clone()`` guarantees a real copy.
        return {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}

    # -------------------------------------------- fit
    def train(self):
        """Train with early stopping on validation loss, then restore the best weights.

        Stops after ``cfg.epochs`` epochs or once the validation loss has not
        improved for ``cfg.patience`` consecutive epochs.  If no epoch ever
        improved, the current weights are left untouched.
        """
        best = math.inf
        patience = 0
        self.best_state = None
        for epoch in range(1, self.cfg.epochs + 1):
            tr = self._run_epoch(self.train_dl, True)
            vl = self._run_epoch(self.val_dl, False)
            print(f"[{epoch:03d}] train {tr:.4f} | val {vl:.4f}")
            if vl < best:
                best, patience = vl, 0
                self.best_state = self._snapshot_state()
            else:
                patience += 1
            if patience >= self.cfg.patience:
                print("Early stop.")
                break
        if self.best_state is not None:
            self.model.load_state_dict(self.best_state)

    # -------------------------------------------- predict on list of snapshots (loader)
    @torch.no_grad()
    def predict_loader(self, loader: "DataLoader"):
        """Return the edge logits for every batch of ``loader`` as one 1-D CPU tensor.

        Logits are concatenated over batches/snapshots in loader order; an
        empty loader yields an empty tensor.
        """
        self.model.eval()
        all_logits = [self._forward_edges(data.to(self.device)).cpu() for data in loader]
        if not all_logits:
            return torch.empty(0)
        return torch.cat(all_logits)    # concatenated over batches/snapshots

# ---------------------------------------------------------------------
# Usage sketch (not executed here):
# ---------------------------------------------------------------------
#   exp = (EdgeExperiment(node_frames, edge_frames, graph,
#                         cfg={"split_mode": "ratio"})  # "date" mode also needs cutoff_date
#            .prepare_snapshots()
#            .build_loaders()
#            .init_model())
#   exp.compile()                     # sets up loss + optimiser
#   exp.train()
#   test_logits = exp.predict_loader(exp.test_dl)
