# In [None]:
"""
gnn_experiment_refactored.py
—————————————————————————————————
End-to-end GNN experimentation wrapper for **binary rare-event
classification** on temporal node snapshots.

Key additions vs. original version
• Per-node positive-class weights built into BCE loss
• Edge-dropout (train-time only)
• Configurable LR scheduler with LR print at each epoch
"""

# ----------------------------------------------------------------------
# 0. Imports
# ----------------------------------------------------------------------
import json
import pathlib
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, Union, List, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GCNConv,
    GATv2Conv,
    GCN,
    GraphSAGE,
    GAT,
)


# ----------------------------------------------------------------------
# 1. Configuration object
# ----------------------------------------------------------------------
@dataclass
class GNNConfig:
    """Hyper-parameters and data settings for ``GNNExperiment``.

    Build an instance directly, from a dict of field values, or from a
    ``.json`` file via :meth:`load`; :meth:`to_dict` converts back to a dict.
    """

    # --- model ---------------------------------------------------------
    task: str = "node_clf"            # {"node_clf", "edge_clf"}; NOTE(review): not read by GNNExperiment
    model_name: str = "gcn"           # {"gcn","graphsage","gat","gatv2"} (+ "simple"/"baseline")
    num_layers: int = 2
    hidden_dim: int = 128
    heads: int = 8                    # for attention models
    norm: Optional[str] = "batch"     # {"batch","layer",None}; NOTE(review): the model classes hard-code norm="batch"

    # --- data & label --------------------------------------------------
    target_col: str = "target"        # label column inside every node frame

    # --- split ---------------------------------------------------------
    split_mode: str = "date"          # {"date","ratio"}
    cutoff_date: Optional[str] = None # required when split_mode == "date"
    val_ratio: float = 0.10
    test_ratio: float = 0.10
    shuffle_in_split: bool = False    # shuffle the *training* loader only

    # --- optimisation --------------------------------------------------
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 100
    patience: int = 10                # early-stopping patience (epochs)
    grad_clip: float = 1.0            # max gradient norm

    optimiser: str = "adam"           # {"adam","adamw","sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- LR scheduler --------------------------------------------------
    lr_scheduler: Optional[str] = None         # {"step","plateau","cosine","onecycle"}
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
    print_lr_each_epoch: bool = True

    # --- loss / weights -----------------------------------------------
    loss_fn: str = "bce"              # {"bce","cross_entropy"}
    class_weights: Optional[List[float]] = None     # global BCE/CE weights; NOTE(review): not read by GNNExperiment
    node_pos_weights: Optional[List[float]] = None  # per-node list (len == #nodes)

    # --- regularisation -----------------------------------------------
    in_dropout: float = 0.0           # feature-level dropout
    layer_dropout: float = 0.5        # model internal dropout
    edge_dropout: float = 0.0         # probability of dropping each edge at train-time

    # --- misc ----------------------------------------------------------
    device: str = "cuda"              # falls back to CPU when CUDA is unavailable
    run_name: str = "default_run"
    seed: int = 42

    # ------------------------------------------------------------------
    @staticmethod
    def load(cfg: Union["GNNConfig", str, pathlib.Path, Dict[str, Any]]) -> "GNNConfig":
        """Coerce *cfg* into a :class:`GNNConfig`.

        Args:
            cfg: an existing config (returned unchanged, not copied), a dict
                of field values, or a path (``str``/``pathlib.Path``) to a
                ``.json`` file (extension matched case-insensitively).

        Raises:
            ValueError: the path does not have a ``.json`` extension.
            TypeError: *cfg* is of an unsupported type, or the dict/JSON
                contains unknown field names.
            FileNotFoundError: the ``.json`` path does not exist.
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)
        if isinstance(cfg, (str, pathlib.Path)):
            path = pathlib.Path(cfg)
            # Validate the format before touching the filesystem so an
            # unsupported file type is reported as such, not as a missing file.
            if path.suffix.lower() != ".json":
                raise ValueError(
                    f"Unsupported file type for config path: {path.suffix!r}"
                )
            with path.open("r", encoding="utf-8") as f:
                data = json.load(f)
            return GNNConfig(**data)
        raise TypeError(f"Unsupported cfg type: {type(cfg)}")

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain (deep-copied) dict of its fields."""
        return asdict(self)


# ----------------------------------------------------------------------
# 2. Utilities
# ----------------------------------------------------------------------
def has_nan(t: torch.Tensor) -> bool:
    """Return ``True`` if *t* contains any NaN or ±Inf element.

    Despite the name, infinities count too. Returns a plain Python ``bool``
    (not a 0-dim tensor) so the result matches the annotation and can be
    used directly in ``if`` / ``or`` expressions or logged.
    """
    return not bool(torch.isfinite(t).all())


# ----------------------------------------------------------------------
# 3. Mini model zoo
# ----------------------------------------------------------------------
class GNNNodeSimple(torch.nn.Module):
    """Baseline: two ``GCNConv`` layers with ReLU, dropout, then a linear head.

    Args:
        d_in: number of input node features.
        hidden: width of both graph-convolution layers.
        d_out: number of output values per node.
        p_drop: dropout probability applied before the linear head.
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, p_drop: float):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = GCNConv(d_in, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Return per-node outputs of shape ``(num_nodes, d_out)``."""
        h = data.x
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h, data.edge_index))
        return self.fc(self.dropout(h))


class GNNGCN(torch.nn.Module):
    """PyG ``GCN`` stack (batch norm, ReLU) followed by dropout and a linear head.

    Args:
        d_in: number of input node features.
        hidden: hidden width of the GCN stack (also the head's input width).
        d_out: number of output values per node.
        layers: number of GCN layers.
        p_drop: dropout probability applied before the linear head.

    NOTE: normalisation is fixed to ``"batch"`` here.
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, p_drop: float):
        super().__init__()
        self.body = GCN(
            in_channels=d_in,
            hidden_channels=hidden,
            num_layers=layers,
            norm="batch",
            act="relu",
        )
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Return per-node outputs of shape ``(num_nodes, d_out)``."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNSage(torch.nn.Module):
    """PyG ``GraphSAGE`` stack (batch norm, ReLU) + dropout + linear head.

    Args:
        d_in: number of input node features.
        hidden: hidden width of the GraphSAGE stack (also the head's input width).
        d_out: number of output values per node.
        layers: number of GraphSAGE layers.
        p_drop: dropout probability applied before the linear head.

    NOTE: normalisation is fixed to ``"batch"`` here.
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, p_drop: float):
        super().__init__()
        self.body = GraphSAGE(
            in_channels=d_in,
            hidden_channels=hidden,
            num_layers=layers,
            norm="batch",
            act="relu",
        )
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Return per-node outputs of shape ``(num_nodes, d_out)``."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNGAT(torch.nn.Module):
    """PyG ``GAT`` stack (multi-head, batch norm, ReLU) + dropout + linear head.

    Args:
        d_in: number of input node features.
        hidden: hidden width of the GAT stack (also the head's input width).
        d_out: number of output values per node.
        layers: number of GAT layers.
        heads: number of attention heads passed to PyG ``GAT``.
        p_drop: dropout probability applied before the linear head.

    NOTE: edge features are not used by this model; normalisation is fixed
    to ``"batch"``.
    """

    def __init__(self, d_in: int, hidden: int, d_out: int, layers: int, heads: int, p_drop: float):
        super().__init__()
        self.body = GAT(
            in_channels=d_in,
            hidden_channels=hidden,
            num_layers=layers,
            heads=heads,
            norm="batch",
            act="relu",
        )
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Return per-node outputs of shape ``(num_nodes, d_out)``."""
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(self.dropout(embeddings))


class GNNGAT2(torch.nn.Module):
    """Edge-feature-aware GATv2 encoder followed by dropout and a linear head.

    Architecture (ReLU after every layer):

    * ``conv1``: ``d_in -> hidden`` per head, ``heads`` heads, concatenated
      (output width ``hidden * heads``);
    * ``mid_convs``: ``layers - 2`` extra multi-head layers of the same width
      (empty for the default ``layers=2``, so the parameter set and
      ``state_dict`` keys are identical to the original two-layer model);
    * ``conv2``: single head collapsing to width ``hidden``.

    Args:
        d_in: number of input node features.
        hidden: per-head hidden width.
        d_out: number of output values per node.
        layers: total number of GATv2 layers. Previously ignored (the model
            was always two layers); values below 2 still yield two layers.
        heads: attention heads in every layer except the last.
        edge_dim: edge-feature width forwarded to every ``GATv2Conv``.
        p_drop: dropout probability applied before the linear head.
    """

    def __init__(
        self,
        d_in: int,
        hidden: int,
        d_out: int,
        layers: int,
        heads: int,
        edge_dim: int,
        p_drop: float,
    ):
        super().__init__()
        self.conv1 = GATv2Conv(d_in, hidden, heads=heads, concat=True, edge_dim=edge_dim)
        # Registered between conv1 and conv2; when empty it contributes no
        # parameters, so parameter order matches the original for layers <= 2.
        self.mid_convs = torch.nn.ModuleList(
            GATv2Conv(hidden * heads, hidden, heads=heads, concat=True, edge_dim=edge_dim)
            for _ in range(max(layers - 2, 0))
        )
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1, concat=False, edge_dim=edge_dim)
        self.dropout = torch.nn.Dropout(p_drop)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data: Data):
        """Return per-node outputs of shape ``(num_nodes, d_out)``."""
        ei, ea = data.edge_index, data.edge_attr
        x = torch.relu(self.conv1(data.x, ei, ea))
        for conv in self.mid_convs:
            x = torch.relu(conv(x, ei, ea))
        x = torch.relu(self.conv2(x, ei, ea))
        x = self.dropout(x)
        return self.fc(x)


# ----------------------------------------------------------------------
# 4. Experiment wrapper
# ----------------------------------------------------------------------
class GNNExperiment:
    """
    Raw node/edge time-series → PyG snapshots → DataLoaders → Model.

    Intended call order (each step returns ``self`` so calls can be chained)::

        exp = (
            GNNExperiment(node_frames, edge_frames, graph, cfg)
            .prepare_snapshots()
            .build_loaders()
            .init_model()
            .compile()
            .train()
        )
        exp.evaluate("test")

    Parameters
    ----------
    node_frames
        Node name -> DataFrame indexed by timestamp. Every column except
        ``cfg.target_col`` is a node feature; ``cfg.target_col`` is the label.
    edge_frames
        ``"<src>-<dst>"`` key (lower-cased, see ``_edge_key``) -> DataFrame
        indexed by timestamp; every column is an edge feature.
    graph
        Directed graph whose nodes key ``node_frames`` and whose edges key
        ``edge_frames``.
    cfg
        A ``GNNConfig``, a dict of its fields, a path to a ``.json`` file, or
        ``None`` for the defaults.
    """

    # ------------------------------------------------------------ constructor
    def __init__(
        self,
        node_frames: Dict[str, pd.DataFrame],
        edge_frames: Dict[str, pd.DataFrame],
        graph: nx.DiGraph,
        cfg: Union[GNNConfig, str, Dict[str, Any], None] = None,
    ):
        # raw inputs
        self.node_frames = node_frames
        self.edge_frames = edge_frames
        self.graph = graph

        # configuration – a fresh default per instance. A ``GNNConfig()``
        # default argument is created once at definition time and, because
        # GNNConfig.load returns config instances unchanged, would be shared
        # (and mutated) by every experiment constructed without ``cfg``.
        self.cfg = GNNConfig() if cfg is None else GNNConfig.load(cfg)

        # runtime placeholders (filled by the pipeline steps)
        self.reg_order: List[str] = []               # sorted node names
        self.edge_order: List[Tuple[int, int]] = []  # (src_pos, dst_pos) into reg_order
        self.snapshots: Optional[List[Data]] = None

        self.train_dl = self.val_dl = self.test_dl = None

        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.loss_fn = None
        self.metric_fn = None
        self._last_lr: Optional[float] = None

        self.history: Dict[str, list] = {"train_loss": [], "val_loss": [], "val_metric": []}
        self.best_val_loss = float("inf")
        self.best_ckpt = None
        self.patience_counter = 0

        # per-node positive-class weight lookup (len == number of nodes)
        self._node_weight_lookup: Optional[torch.Tensor] = None

        # fall back to CPU when CUDA is requested but unavailable
        self.device = torch.device(
            "cuda" if (self.cfg.device == "cuda" and torch.cuda.is_available()) else "cpu"
        )

    # ---------------------------------------------------------------- helpers
    @staticmethod
    def _edge_key(src: str, dst: str) -> str:
        """Key of the edge ``src -> dst`` in the edge-frame dicts."""
        return f"{src}-{dst}".lower()

    @staticmethod
    def _ensure_same_index(idxs: List[pd.DatetimeIndex]) -> List[pd.Timestamp]:
        """Return the sorted timestamps present in *every* index in *idxs*.

        Raises:
            ValueError: *idxs* is empty or the indexes share no timestamp.
        """
        if not idxs:
            raise ValueError("No DataFrames supplied")
        common = sorted(set.intersection(*(set(ix) for ix in idxs)))
        if not common:
            raise ValueError("No common timestamps across supplied DataFrames")
        return common

    def _edge_index_tensor(self) -> torch.Tensor:
        """``self.edge_order`` as a ``(2, num_edges)`` long tensor."""
        return torch.tensor(self.edge_order, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _node_tensors(
        self,
        frames: Dict[str, pd.DataFrame],
        ts: Any,
        *,
        with_labels: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Build node features ``x`` (num_nodes, F) and, optionally, labels ``y``.

        Rows follow ``self.reg_order``. ``cfg.target_col`` is excluded from the
        features when present (it may be absent at prediction time); when
        ``with_labels`` it must be present and becomes ``y`` (float32).
        """
        target = self.cfg.target_col
        feats, labels = [], []
        for region in self.reg_order:
            row = frames[region].loc[ts]
            if with_labels:
                labels.append(row[target])
            feats.append(row.drop(target, errors="ignore").to_numpy(dtype=np.float32))
        x = torch.tensor(np.vstack(feats), dtype=torch.float32)
        y = torch.tensor(labels, dtype=torch.float32) if with_labels else None
        return x, y

    def _edge_attr_tensor(self, frames: Dict[str, pd.DataFrame], ts: Any) -> torch.Tensor:
        """Build edge features (num_edges, F_e), rows following ``self.edge_order``."""
        rows = [
            frames[self._edge_key(self.reg_order[s_idx], self.reg_order[d_idx])]
            .loc[ts]
            .to_numpy(dtype=np.float32)
            for s_idx, d_idx in self.edge_order
        ]
        return torch.tensor(np.vstack(rows), dtype=torch.float32)

    def _weighted_loss(self, fn, logits: torch.Tensor, data: Data) -> torch.Tensor:
        """Apply ``fn`` and average it with the per-node weights.

        ``node_weight`` holds per-node *positive-class* weights: for BCE they
        scale only the positive-target term (``1 + (w - 1) * y``, which for
        hard 0/1 labels equals BCEWithLogitsLoss's ``pos_weight`` per node).
        Negatives keep weight 1, so the default all-ones weights leave the
        loss unchanged. Other loss functions get the weights as-is.
        """
        vec = fn(logits, data.y)
        w = data.node_weight.to(vec.device)
        if isinstance(fn, torch.nn.BCEWithLogitsLoss):
            w = 1.0 + (w - 1.0) * data.y
        return (vec * w).mean()

    # ------------------------------------------------------- 1. snapshots
    def prepare_snapshots(self) -> "GNNExperiment":
        """Build one PyG ``Data`` per timestamp shared by all node & edge frames.

        Each snapshot holds ``x`` (num_nodes, F), ``edge_index`` (2, E),
        ``edge_attr`` (E, F_e), ``y`` (num_nodes,), ``node_weight``
        (num_nodes,) and ``snap_time`` (the timestamp in ns since epoch).

        Raises:
            ValueError: ``node_pos_weights`` length differs from the node
                count, or the frames share no timestamp.
        """
        # node & edge order
        self.reg_order = sorted(self.graph.nodes)
        node_pos = {n: i for i, n in enumerate(self.reg_order)}
        self.edge_order = [(node_pos[s], node_pos[d]) for (s, d) in self.graph.edges]

        # node-weight lookup
        if self.cfg.node_pos_weights is not None:
            if len(self.cfg.node_pos_weights) != len(self.reg_order):
                raise ValueError("len(node_pos_weights) must equal number of nodes")
            self._node_weight_lookup = torch.tensor(self.cfg.node_pos_weights, dtype=torch.float32)
        else:
            self._node_weight_lookup = None

        # timestamps intersection
        ts_common = self._ensure_same_index(
            [df.index for df in self.node_frames.values()]
            + [df.index for df in self.edge_frames.values()]
        )

        # topology is identical for every snapshot: build it once
        edge_index = self._edge_index_tensor()
        n_nodes = len(self.reg_order)

        snapshots: List[Data] = []
        for ts in ts_common:
            x, y = self._node_tensors(self.node_frames, ts, with_labels=True)
            edge_attr = self._edge_attr_tensor(self.edge_frames, ts)
            node_weight = (
                self._node_weight_lookup.clone()
                if self._node_weight_lookup is not None
                else torch.ones(n_nodes)
            )
            snapshots.append(
                Data(
                    x=x,
                    edge_index=edge_index.clone(),
                    edge_attr=edge_attr,
                    y=y,
                    node_weight=node_weight,
                    snap_time=torch.tensor([pd.Timestamp(ts).value]),
                )
            )

        self.snapshots = snapshots
        return self

    # ------------------------------------------------------- 2. loaders
    def build_loaders(self) -> "GNNExperiment":
        """Split snapshots chronologically into train/val/test DataLoaders.

        ``split_mode == "date"``: snapshots at or before ``cutoff_date`` train;
        later ones are halved into val (first half, at least one snapshot)
        and test. Any other mode: chronological split by ``val_ratio`` and
        ``test_ratio`` (counts truncated with ``int``). Only the training
        loader honours ``shuffle_in_split``.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() before build_loaders()")

        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())

        if self.cfg.split_mode == "date":
            if self.cfg.cutoff_date is None:
                raise ValueError("cutoff_date must be set for date split")
            cutoff_int = pd.Timestamp(self.cfg.cutoff_date).value
            train_set = [g for g in snaps_sorted if g.snap_time.item() <= cutoff_int]
            holdout = [g for g in snaps_sorted if g.snap_time.item() > cutoff_int]
            if not holdout:
                raise ValueError("No snapshots after cutoff_date for val/test split")
            # With a single held-out snapshot ``len // 2`` is 0, which would
            # leave validation empty and break train(); give val at least one.
            mid = max(1, len(holdout) // 2)
            val_set, test_set = holdout[:mid], holdout[mid:]
        else:
            n_total = len(snaps_sorted)
            n_test = int(n_total * self.cfg.test_ratio)
            n_val = int(n_total * self.cfg.val_ratio)
            n_train = n_total - n_val - n_test
            train_set = snaps_sorted[:n_train]
            val_set = snaps_sorted[n_train : n_train + n_val]
            test_set = snaps_sorted[n_train + n_val :]

        self.train_dl = DataLoader(
            train_set, batch_size=self.cfg.batch_size, shuffle=self.cfg.shuffle_in_split
        )
        self.val_dl = DataLoader(val_set, batch_size=self.cfg.batch_size, shuffle=False)
        self.test_dl = DataLoader(test_set, batch_size=self.cfg.batch_size, shuffle=False)
        return self

    # ------------------------------------------------------- 3. model
    def init_model(self) -> "GNNExperiment":
        """Instantiate ``cfg.model_name`` with one output logit per node.

        Input and edge-feature widths are read from the first snapshot.
        NOTE(review): ``cfg.norm`` is not forwarded – the model classes fix
        their own normalisation.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() first.")

        d_in = self.snapshots[0].x.size(1)
        d_out = 1  # binary logit
        edge_dim = self.snapshots[0].edge_attr.size(1)
        p_drop = self.cfg.layer_dropout

        name = self.cfg.model_name.lower()
        if name in {"simple", "baseline"}:
            model = GNNNodeSimple(d_in, self.cfg.hidden_dim, d_out, p_drop)
        elif name == "gcn":
            model = GNNGCN(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, p_drop)
        elif name in {"graphsage", "sage"}:
            model = GNNSage(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, p_drop)
        elif name == "gat":
            model = GNNGAT(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers, self.cfg.heads, p_drop)
        elif name in {"gatv2", "gat2"}:
            model = GNNGAT2(
                d_in,
                self.cfg.hidden_dim,
                d_out,
                self.cfg.num_layers,
                self.cfg.heads,
                edge_dim,
                p_drop,
            )
        else:
            raise ValueError(f"Unknown model_name '{self.cfg.model_name}'")

        self.model = model.to(self.device)
        return self

    # ------------------------------------------------------- 4. compile
    def compile(self) -> "GNNExperiment":
        """Create the loss/metric functions, optimiser and optional LR scheduler."""
        if self.model is None:
            raise RuntimeError("Call init_model() before compile()")

        # ---- loss ------------------------------------------------------
        loss_name = self.cfg.loss_fn.lower()
        if loss_name in {"bce", "binary"}:
            # reduction="none": per-node losses are weighted in _weighted_loss
            self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
            self.metric_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
        elif loss_name == "cross_entropy":
            # NOTE(review): models emit one logit per node, squeezed to shape
            # (num_nodes,), so CrossEntropyLoss sees a single unbatched sample
            # over num_nodes classes – likely not intended; confirm before use.
            self.loss_fn = torch.nn.CrossEntropyLoss()
            self.metric_fn = self.loss_fn
        else:
            raise ValueError(f"Unsupported loss_fn '{self.cfg.loss_fn}'")

        # ---- optimiser -------------------------------------------------
        # Dict merge (not ``dict(lr=..., **kw)``) so an explicit "lr" in
        # optimiser_kwargs overrides cfg.lr instead of raising TypeError.
        opt_kw = {"lr": self.cfg.lr, **self.cfg.optimiser_kwargs}
        optimisers = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        opt_cls = optimisers.get(self.cfg.optimiser.lower())
        if opt_cls is None:
            raise ValueError(f"Unknown optimiser '{self.cfg.optimiser}'")
        self.optimizer = opt_cls(self.model.parameters(), **opt_kw)

        # ---- scheduler -------------------------------------------------
        self.scheduler = None
        if self.cfg.lr_scheduler is not None:
            schedulers = {
                "step": torch.optim.lr_scheduler.StepLR,
                "cosine": torch.optim.lr_scheduler.CosineAnnealingLR,
                "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
                # NOTE: train() steps the scheduler once per *epoch*, so size
                # OneCycleLR's total_steps in epochs.
                "onecycle": torch.optim.lr_scheduler.OneCycleLR,
            }
            sched_cls = schedulers.get(self.cfg.lr_scheduler.lower())
            if sched_cls is None:
                raise ValueError(f"Unknown lr_scheduler '{self.cfg.lr_scheduler}'")
            self.scheduler = sched_cls(self.optimizer, **self.cfg.lr_scheduler_kwargs)

        # track last LR for pretty print
        self._last_lr = self.optimizer.param_groups[0]["lr"]
        return self

    # ------------------------------------------------ edge-dropout helper
    def _apply_edge_dropout(self, data: Data) -> Data:
        """Drop each edge (and its features) with prob ``cfg.edge_dropout``, train mode only."""
        p = getattr(self.cfg, "edge_dropout", 0.0)
        if p <= 0.0 or not self.model.training:
            return data
        E = data.edge_index.size(1)
        keep_mask = torch.rand(E, device=data.edge_index.device) >= p
        if keep_mask.all():
            return data
        data.edge_index = data.edge_index[:, keep_mask]
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr[keep_mask]
        return data

    # ------------------------------------------------ input-dropout helper
    def _apply_input_dropout(self, data: Data) -> Data:
        """Feature-level dropout on ``data.x`` with prob ``cfg.in_dropout``, train mode only."""
        p = getattr(self.cfg, "in_dropout", 0.0)
        if p <= 0.0 or not self.model.training:
            return data
        data.x = torch.nn.functional.dropout(data.x, p=p, training=True)
        return data

    # ------------------------------------------------ train-epoch helper
    def _train_epoch(self, debug: bool = False) -> float:
        """Run one optimisation pass over ``train_dl``; return the mean batch loss."""
        n_batches = len(self.train_dl)
        if n_batches == 0:
            raise RuntimeError("Training loader is empty – check the split configuration")

        self.model.train()
        total = 0.0
        for step, data in enumerate(self.train_dl):
            data = data.to(self.device)
            data = self._apply_edge_dropout(data)

            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()
            data = self._apply_input_dropout(data)

            self.optimizer.zero_grad()
            logits = self.model(data).squeeze(-1)
            loss = self._weighted_loss(self.loss_fn, logits, data)

            if debug and (has_nan(loss) or has_nan(logits)):
                raise RuntimeError(f"NaN detected at step {step}")

            loss.backward()
            # None or a non-positive value disables clipping (clipping to a
            # norm of 0 would zero every gradient).
            if self.cfg.grad_clip is not None and self.cfg.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()

            total += loss.item()
        return total / n_batches

    # ------------------------------------------------ eval helper
    @torch.no_grad()
    def _eval_loader(self, loader: DataLoader) -> Tuple[float, float]:
        """Return ``(mean weighted loss, mean weighted metric)`` over *loader*'s batches."""
        n = len(loader)
        if n == 0:
            raise RuntimeError("Cannot evaluate an empty loader – check the split configuration")

        self.model.eval()
        total_loss = total_metric = 0.0
        for data in loader:
            data = data.to(self.device)
            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()

            logits = self.model(data).squeeze(-1)
            total_loss += self._weighted_loss(self.loss_fn, logits, data).item()
            total_metric += self._weighted_loss(self.metric_fn, logits, data).item()
        return total_loss / n, total_metric / n

    # ------------------------------------------------ main training loop
    def train(self, debug: bool = False, restore_best: bool = True):
        """Run the epoch loop with LR scheduling and early stopping.

        Each epoch trains on ``train_dl``, evaluates on ``val_dl``, steps the
        scheduler (``ReduceLROnPlateau`` receives the validation loss),
        appends to ``history`` and prints a summary line. Stops after
        ``cfg.patience`` consecutive epochs without a lower validation loss.

        Args:
            debug: raise as soon as a NaN/Inf loss or logit appears.
            restore_best: after the loop, load the weights from the epoch with
                the lowest validation loss back into the model, so
                evaluate()/predict() use the best rather than the last weights.

        Returns:
            ``self`` for chaining.
        """
        if any(v is None for v in (self.model, self.optimizer, self.loss_fn)):
            raise RuntimeError("Call init_model() and compile() before train()")
        if self.train_dl is None or self.val_dl is None:
            raise RuntimeError("Call build_loaders() before train()")

        for epoch in range(1, self.cfg.epochs + 1):
            tr_loss = self._train_epoch(debug)
            val_loss, val_metric = self._eval_loader(self.val_dl)

            # scheduler
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # LR print
            if self.cfg.print_lr_each_epoch:
                curr_lr = self.optimizer.param_groups[0]["lr"]
                lr_msg = f"LR changed → {curr_lr:.2e}" if curr_lr != self._last_lr else f"LR={curr_lr:.2e}"
                self._last_lr = curr_lr
            else:
                lr_msg = ""

            # store history
            self.history["train_loss"].append(tr_loss)
            self.history["val_loss"].append(val_loss)
            self.history["val_metric"].append(val_metric)

            print(
                f"[Epoch {epoch:03d}/{self.cfg.epochs}] "
                f"train={tr_loss:.4f}  val={val_loss:.4f}  "
                f"metric={val_metric:.4f}  {lr_msg}"
            )

            # early stopping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                # state_dict() tensors share storage with the live parameters
                # and ``.cpu()`` is a no-op on CPU, so force a real copy –
                # otherwise the "best" checkpoint keeps tracking training.
                self.best_ckpt = {
                    k: v.to("cpu", copy=True) for k, v in self.model.state_dict().items()
                }
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.cfg.patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        if restore_best and self.best_ckpt is not None:
            self.model.load_state_dict(self.best_ckpt)
        return self

    # ------------------------------------------------ evaluate
    def evaluate(self, split: str = "test") -> Dict[str, float]:
        """Return ``{"loss": ..., "metric": ...}`` for the ``train``/``val``/``test`` split."""
        loaders = {"train": self.train_dl, "val": self.val_dl, "test": self.test_dl}
        if split not in loaders:
            raise ValueError("split must be 'train', 'val' or 'test'")
        loader = loaders[split]
        if loader is None:
            raise RuntimeError("Call build_loaders() before evaluate()")
        if self.model is None or self.loss_fn is None:
            raise RuntimeError("Call init_model() and compile() before evaluate()")
        loss, metric = self._eval_loader(loader)
        return {"loss": loss, "metric": metric}

    # ------------------------------------------------ predict
    @torch.no_grad()
    def predict(
        self,
        node_frames_new: Dict[str, pd.DataFrame],
        edge_frames_new: Dict[str, pd.DataFrame],
        timestamps: Optional[List[pd.Timestamp]] = None,
        return_df: bool = True,
    ):
        """Return raw per-node logits (not probabilities) for new frames.

        Frames use the same layout as the constructor's inputs; the target
        column may be absent. Feature columns are taken in each frame's own
        column order. NOTE(review): that order presumably must match the
        training frames – not checked here.

        Args:
            node_frames_new / edge_frames_new: new node / edge frames.
            timestamps: timestamps to score; default = timestamps common to
                every supplied frame.
            return_df: return a DataFrame (index = timestamps, columns =
                ``reg_order``) instead of a ``(len(timestamps), num_nodes)``
                tensor.
        """
        if self.model is None:
            raise RuntimeError("Train or load a model before calling predict()")
        if not self.reg_order:
            raise RuntimeError("Call prepare_snapshots() first so the node/edge order is known")

        # decide timestamps
        if timestamps is None:
            timestamps = self._ensure_same_index(
                [df.index for df in node_frames_new.values()]
                + [df.index for df in edge_frames_new.values()]
            )
        else:
            timestamps = [pd.Timestamp(ts) for ts in timestamps]
            if not timestamps:
                raise ValueError("timestamps must not be empty")

        edge_index = self._edge_index_tensor()
        snaps = []
        for ts in timestamps:
            x, _ = self._node_tensors(node_frames_new, ts, with_labels=False)
            edge_attr = self._edge_attr_tensor(edge_frames_new, ts)
            snaps.append(Data(x=x, edge_index=edge_index.clone(), edge_attr=edge_attr))

        loader = DataLoader(snaps, batch_size=self.cfg.batch_size, shuffle=False)
        self.model.eval()
        preds = []
        for data in loader:
            data = data.to(self.device)
            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()
            preds.append(self.model(data).squeeze(-1).cpu())

        # batches concatenate snapshots node-block by node-block, in order
        y_hat = torch.cat(preds, dim=0).reshape(len(timestamps), len(self.reg_order))

        if return_df:
            return pd.DataFrame(
                y_hat.numpy(),
                index=pd.to_datetime(timestamps),
                columns=self.reg_order,
            )
        return y_hat

    # ------------------------------------------------ plot history
    def plot_history(self, metric: str = "loss"):
        """Plot train/val loss (``metric="loss"``) or the validation metric curve."""
        if not self.history["train_loss"]:
            raise RuntimeError("Nothing in history – did you call train()?")

        if metric == "loss":
            plt.figure(figsize=(6, 4))
            plt.plot(self.history["train_loss"], label="Train")
            plt.plot(self.history["val_loss"], label="Val")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.title("Training / Validation Loss")
            plt.legend()
            plt.grid(True)
            plt.show()

        elif metric in {"metric", "mae", "accuracy"}:
            if not self.history["val_metric"]:
                raise ValueError("Metric history empty; choose 'loss' instead.")
            plt.figure(figsize=(6, 4))
            plt.plot(self.history["val_metric"], label=f"Val {metric}")
            plt.xlabel("Epoch")
            plt.ylabel(metric.upper())
            plt.title(f"Validation {metric.upper()} Curve")
            plt.legend()
            plt.grid(True)
            plt.show()
        else:
            raise ValueError("metric must be 'loss' or 'metric'")
