#### Config.py file 

In [None]:
"""
Centralised run-configuration object for the GNN project.
(Backwards-compatible; adds optional edge-classification settings.)
"""

import json
import pathlib
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Optional, Union, List

try:
    import yaml  # type: ignore
except ImportError:  # pragma: no cover
    yaml = None  # YAML support is optional


# ---------------------------------------------------------------------------
# Dataclass
# ---------------------------------------------------------------------------

@dataclass
class GNNConfig:
    """
    Centralised run-configuration for node and edge GNN tasks.

    All fields have defaults, so ``GNNConfig()`` is a valid node-classification
    configuration. Edge-classification settings are optional additions; the
    edge pipeline is selected through :meth:`is_edge_task`.

    No value validation is performed here: the option sets listed in the field
    comments are documentation only.
    """

    # -------- task / model --------------------------------------------------
    task: str = "node_clf"          # {"node_clf","node_reg","edge_clf"}  (new: "edge_clf")
    model_name: str = "gatv2"       # {"gcn","graphsage","gat","gatv2", ...}
    num_layers: int = 2
    hidden_dim: int = 64
    heads: int = 4
    dropout: float = 0.5            # layer dropout
    norm: str = "batch"             # {"batch","layer", None}

    # -------- data ----------------------------------------------------------
    target_col: str = "target"      # node target (unchanged)
    # (optional) multi-target for nodes:
    target_cols: Optional[List[str]] = None

    split_mode: str = "date"        # {"date","ratio"}
    cutoff_date: Optional[str] = None
    val_ratio: float = 0.10
    test_ratio: float = 0.10
    shuffle_in_split: bool = False

    # -------- optimisation --------------------------------------------------
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 100
    patience: int = 10
    grad_clip: Optional[float] = 1.0

    optimiser: str = "adam"         # {"adam","adamw","sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # -------- scheduler -----------------------------------------------------
    lr_scheduler: Optional[str] = None      # {"step","plateau","cosine","onecycle"}
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
    print_lr_each_epoch: bool = True

    # -------- loss / metric -------------------------------------------------
    loss_fn: str = "bce"            # {"bce","focal_bce","cross_entropy","mse","mae","huber"}
    class_weights: Optional[Union[float, List[float], str]] = None  # node pos_weight; allow "auto"
    node_pos_weights: Optional[List[float]] = None                  # per-node weighting (node tasks)
    node_reduction: str = "mean"
    metric: str = "acc"

    # -------- regularisation -----------------------------------------------
    in_dropout: float = 0.0
    edge_dropout: float = 0.0

    # -------- misc ----------------------------------------------------------
    device: str = "cuda"
    run_name: str = "default_run"
    seed: int = 42

    # =======================================================================
    # Edge-classification additions (all optional; keep default node behavior)
    # =======================================================================

    # When task == "edge_clf" (or decoder is set), the trainer will use
    # an encoder-only graph NN + an edge decoder to classify edges.
    decoder: Optional[str] = None   # {"dot","concat_mlp","hadamard_mlp","bilinear"}; None -> node tasks
    decoder_kwargs: Dict[str, Any] = field(default_factory=dict)

    # Edge targets
    edge_target_col: str = "edge_target"        # or:
    edge_target_cols: Optional[List[str]] = None  # multi-target edges

    # Edge class imbalance handling
    edge_class_weights: Optional[Union[float, List[float], str]] = None
    # same semantics as class_weights: None | scalar | list | "auto"
    weight_smooth: float = 1.0  # optional smoothing exponent for per-edge weights

    # -----------------------------------------------------------------------
    # utils
    # -----------------------------------------------------------------------
    def is_edge_task(self) -> bool:
        """
        Single place to decide if we run the edge pipeline.

        Returns True when ``task`` is a string starting with "edge"
        (case-insensitive) or when a ``decoder`` has been set explicitly.
        """
        if isinstance(self.task, str) and self.task.lower().startswith("edge"):
            return True
        return self.decoder is not None  # explicit override

    @staticmethod
    def load(cfg: Union["GNNConfig", str, Dict[str, Any], pathlib.Path]) -> "GNNConfig":
        """
        Flexible loader.

        * existing `GNNConfig` -> returned untouched (same object)
        * `dict` -> expanded into the constructor as keyword arguments
        * `str` or `pathlib.Path` file path (YAML / JSON) -> parsed, then
          expanded into the constructor

        An empty config file (one that parses to ``None``) yields a
        configuration with all defaults.

        Raises
        ------
        FileNotFoundError : the path does not exist
        ValueError        : the file suffix is not .yaml / .yml / .json
        RuntimeError      : a YAML file was given but PyYAML is not installed
        TypeError         : the file's top level is not a mapping, or a key
                            does not match any field of the dataclass
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)

        # path-like
        path = pathlib.Path(cfg)
        if not path.exists():
            raise FileNotFoundError(path)

        # Pick the parser before opening the file so an unsupported suffix
        # fails without touching the file contents.
        suffix = path.suffix.lower()
        if suffix in {".yaml", ".yml"}:
            if yaml is None:  # pragma: no cover
                raise RuntimeError("PyYAML not installed: `pip install pyyaml`")
            parse = yaml.safe_load
        elif suffix == ".json":
            parse = json.load
        else:
            raise ValueError(f"Unsupported config file type: {path.suffix}")

        with open(path, "r", encoding="utf-8") as fh:
            data = parse(fh)

        # An empty YAML document (or a JSON `null`) parses to None; treat it
        # as "no overrides" instead of failing inside `GNNConfig(**None)`.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file {path} must contain a mapping at top level, "
                f"got {type(data).__name__}"
            )

        return GNNConfig(**data)

    # ----------------------------- to-* helpers -----------------------------
    def to_dict(self) -> Dict[str, Any]:
        """Return a deep dict (handy for logging)."""
        return asdict(self)

    def to_yaml(self) -> str:
        """Return a YAML string representation (requires PyYAML)."""
        if yaml is None:  # pragma: no cover
            raise RuntimeError("PyYAML not installed: `pip install pyyaml`")
        return yaml.safe_dump(self.to_dict(), sort_keys=False)


# ---------------------------------------------------------------------------
# Convenience free-function loader
# ---------------------------------------------------------------------------
def load_config(cfg_like: Union[str, pathlib.Path, Dict[str, Any], GNNConfig]) -> GNNConfig:
    """
    Module-level shortcut for :func:`GNNConfig.load`.

    Lets callers write ::

        cfg = load_config("configs/my_run.yaml")

    The argument is handed to ``GNNConfig.load`` unchanged and its result is
    returned as-is.
    """
    loaded = GNNConfig.load(cfg_like)
    return loaded


#### Data loading 

In [None]:
"""
Utilities that convert raw Pandas DataFrames -> PyTorch Geometric
snapshots -> DataLoaders.

Node tasks:
  1) make_snapshots        : raw frames  -> List[Data] (node labels in .y)
  2) split_snapshots       : date/ratio split
  3) build_dataloaders     : returns (train_dl, val_dl, test_dl)

Edge tasks (cfg.is_edge_task() == True):
  1) make_edge_snapshots   : raw frames  -> List[Data] (edge labels in .edge_label)
  2) split_snapshots       : same
  3) build_dataloaders     : same API
"""

from typing import Dict, List, Tuple, Optional, Sequence
import pathlib

import numpy as np
import pandas as pd
import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from jp_da_imb.gnn.config import GNNConfig


# --------------------------------------------------------------------------- helpers

def _edge_key(src: str, dst: str) -> str:
    """Return the lower-cased ``"src->dst"`` key used to look up edge frames."""
    arrow = "{}->{}".format(src, dst)
    return arrow.lower()


def _common_timestamps(idxs: Sequence[pd.DatetimeIndex]) -> List[pd.Timestamp]:
    """
    Return the sorted intersection of all supplied DatetimeIndexes.

    Parameters
    ----------
    idxs : the indexes to intersect (one per node / edge frame)

    Returns
    -------
    Ascending list of the timestamps present in every index.

    Raises
    ------
    ValueError : no index was supplied, or the indexes share no timestamp
    """
    index_sets = [set(ix) for ix in idxs]
    # `set.intersection()` with no operands raises an unrelated TypeError, so
    # report "nothing supplied" through the same ValueError as "no overlap".
    if not index_sets:
        raise ValueError("No common timestamps across supplied DataFrames")

    common = sorted(set.intersection(*index_sets))
    if not common:
        raise ValueError("No common timestamps across supplied DataFrames")
    return [pd.Timestamp(ts) for ts in common]


# --------------------------------------------------------------------------- NODE snapshots (existing)

def make_snapshots(
    *,
    node_frames: Dict[str, pd.DataFrame],
    edge_frames: Dict[str, pd.DataFrame],
    graph: nx.DiGraph,
    cfg: GNNConfig,
) -> List[Data]:
    """
    Build one `torch_geometric.data.Data` snapshot per shared timestamp for
    NODE tasks.

    Nodes are ordered as ``sorted(graph.nodes)``; edges follow the iteration
    order of ``graph.edges``. Each snapshot carries:

      - x           : node features (every non-target column)
      - y           : node targets (``cfg.target_cols`` or ``[cfg.target_col]``)
      - edge_index  : integer edge list derived from the graph
      - edge_attr   : one row per edge taken from ``edge_frames``
      - node_weight : copy of ``cfg.node_pos_weights`` as a tensor, or None
      - snap_time   : the timestamp's integer ``.value``

    Raises ValueError when ``cfg.node_pos_weights`` is set but its length does
    not match the number of nodes.
    """
    # ----- fixed node ordering and the integer edge list built from it -----
    ordered_nodes: List[str] = sorted(graph.nodes)
    index_of: Dict[str, int] = {name: pos for pos, name in enumerate(ordered_nodes)}
    edge_pairs: List[Tuple[int, int]] = [
        (index_of[src], index_of[dst]) for src, dst in graph.edges
    ]

    # ----- optional per-node weights ---------------------------------------
    weights = cfg.node_pos_weights
    node_weight_lookup = None
    if weights is not None:
        n_nodes = len(ordered_nodes)
        if len(weights) != n_nodes:
            raise ValueError(
                "len(node_pos_weights) must equal number of nodes "
                f"({n_nodes})"
            )
        node_weight_lookup = torch.tensor(weights, dtype=torch.float32)

    # ----- target columns: multi-target list wins over the single column ---
    configured = getattr(cfg, "target_cols", None)
    target_cols: List[str] = [cfg.target_col] if configured is None else list(configured)

    # ----- timestamps present in every node and edge frame ----------------
    indexes = [frame.index for frame in node_frames.values()]
    indexes += [frame.index for frame in edge_frames.values()]
    ts_common = _common_timestamps(indexes)

    def _labels_and_features(row):
        """Split one node row into (target values, remaining feature values)."""
        targets = row[target_cols].to_numpy(dtype=np.float32, copy=False)
        inputs = row.drop(labels=target_cols).to_numpy(dtype=np.float32, copy=False)
        return targets, inputs

    def _edge_features(src_idx, dst_idx, ts):
        """Fetch the feature row of one edge at timestamp `ts`."""
        key = _edge_key(ordered_nodes[src_idx], ordered_nodes[dst_idx])
        return edge_frames[key].loc[ts].to_numpy(dtype=np.float32, copy=False)

    snapshots: List[Data] = []
    for ts in ts_common:
        # node matrix and labels, in the fixed node order
        per_node = [
            _labels_and_features(node_frames[name].loc[ts])
            for name in ordered_nodes
        ]
        x = torch.tensor(
            np.vstack([inputs for _, inputs in per_node]), dtype=torch.float32
        )
        y = torch.tensor(
            np.vstack([targets for targets, _ in per_node]), dtype=torch.float32
        )

        # edge features, in the same order as edge_index columns
        edge_attr = torch.tensor(
            np.vstack([_edge_features(s, d, ts) for s, d in edge_pairs]),
            dtype=torch.float32,
        )
        edge_index = torch.tensor(np.array(edge_pairs).T, dtype=torch.long)

        # every snapshot gets its own copy of the weight vector
        node_weight = (
            None if node_weight_lookup is None else node_weight_lookup.clone()
        )

        snapshots.append(
            Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                node_weight=node_weight,
                snap_time=torch.tensor([ts.value]),
            )
        )

    # order by timestamp just in case
    return sorted(snapshots, key=lambda snap: snap.snap_time.item())


# --------------------------------------------------------------------------- EDGE snapshots (new)

def _resolve_edge_target_cols(df_any: pd.DataFrame, cfg: GNNConfig) -> List[str]:
    """
    Work out which column(s) of an edge frame hold the target(s).

    Resolution order:
      1) ``cfg.edge_target_cols`` when non-empty; every listed column must
         exist in the frame, otherwise ValueError
      2) ``cfg.edge_target_col`` when set and present in the frame
      3) a column literally named "target"

    Raises ValueError when none of the above applies.
    """
    available = list(df_any.columns)

    # explicit multi-target list
    multi = getattr(cfg, "edge_target_cols", None)
    if multi:
        absent = [col for col in multi if col not in available]
        if absent:
            raise ValueError(f"Edge target cols missing in edge frame: {absent}")
        return list(multi)

    # single configured column
    single = getattr(cfg, "edge_target_col", None)
    if single and single in available:
        return [single]

    # conventional fallback name
    if "target" not in available:
        raise ValueError(
            "Could not resolve edge target column(s). "
            "Set cfg.edge_target_col / edge_target_cols or include 'target'."
        )
    return ["target"]


def make_edge_snapshots(
    *,
    node_frames: Dict[str, pd.DataFrame],
    edge_frames: Dict[str, pd.DataFrame],
    graph: nx.DiGraph,
    cfg: GNNConfig,
) -> List[Data]:
    """
    Build one snapshot per shared timestamp for EDGE tasks.

    Each snapshot carries:
      - x          : all node columns (no node targets are removed)
      - edge_index : integer edge list from ``graph.edges``
      - edge_label : the resolved target column(s) of every edge
      - edge_attr  : remaining edge columns; omitted when there are none
      - snap_time  : the timestamp's integer ``.value``

    Target and feature columns are decided once, from the first edge frame in
    ``edge_frames``, and applied to every edge.
    """
    # fixed node order and the integer edge list built from it
    ordered_nodes: List[str] = sorted(graph.nodes)
    index_of: Dict[str, int] = {name: pos for pos, name in enumerate(ordered_nodes)}
    edge_pairs: List[Tuple[int, int]] = [
        (index_of[src], index_of[dst]) for src, dst in graph.edges
    ]

    # timestamps present in every node and edge frame
    indexes = [frame.index for frame in node_frames.values()]
    indexes += [frame.index for frame in edge_frames.values()]
    ts_common = _common_timestamps(indexes)

    # column roles are inferred from a single (the first) edge frame
    template = edge_frames[next(iter(edge_frames))]
    target_cols = _resolve_edge_target_cols(template, cfg)
    feat_cols = [col for col in template.columns if col not in target_cols]
    has_edge_feats = bool(feat_cols)

    snapshots: List[Data] = []
    for ts in ts_common:
        # node features: every column is kept
        x = torch.tensor(
            np.vstack(
                [
                    node_frames[name].loc[ts].to_numpy(dtype=np.float32, copy=False)
                    for name in ordered_nodes
                ]
            ),
            dtype=torch.float32,
        )

        # labels and (optional) attributes, row i matching edge_index column i
        label_rows: List[np.ndarray] = []
        attr_rows: List[np.ndarray] = []
        for src_idx, dst_idx in edge_pairs:
            key = _edge_key(ordered_nodes[src_idx], ordered_nodes[dst_idx])
            edge_row = edge_frames[key].loc[ts]

            targets = edge_row[target_cols].to_numpy(dtype=np.float32, copy=False)
            if targets.ndim == 0:
                # promote a scalar to a length-1 vector so vstack yields (E, 1)
                targets = np.array([targets], dtype=np.float32)
            label_rows.append(targets)

            if has_edge_feats:
                attr_rows.append(
                    edge_row[feat_cols].to_numpy(dtype=np.float32, copy=False)
                )

        fields = {
            "x": x,
            "edge_index": torch.tensor(np.array(edge_pairs).T, dtype=torch.long),
            "edge_label": torch.tensor(np.vstack(label_rows), dtype=torch.float32),
            "snap_time": torch.tensor([ts.value]),
        }
        if has_edge_feats:
            fields["edge_attr"] = torch.tensor(np.vstack(attr_rows), dtype=torch.float32)

        snapshots.append(Data(**fields))

    return sorted(snapshots, key=lambda snap: snap.snap_time.item())


# --------------------------------------------------------------------------- split (shared)

def split_snapshots(
    snapshots: List[Data],
    cfg: GNNConfig,
) -> Tuple[List[Data], List[Data], List[Data]]:
    """
    Split snapshots into (train, val, test).

    ``cfg.split_mode == "date"``
        Snapshots at or before ``cfg.cutoff_date`` form the training set; the
        later ones are cut in half (first half -> val, second half -> test).
        Raises ValueError when ``cutoff_date`` is unset or nothing lies after
        it.

    any other ``split_mode`` (ratio)
        The list is sliced in its given order: the last
        ``int(n * test_ratio)`` items are test, the ``int(n * val_ratio)``
        items before them are val, and the rest is train.
    """
    if cfg.split_mode != "date":
        # ratio split: contiguous slices, no shuffling
        total = len(snapshots)
        test_count = int(total * cfg.test_ratio)
        val_count = int(total * cfg.val_ratio)
        val_start = total - val_count - test_count
        test_start = val_start + val_count
        return (
            snapshots[:val_start],
            snapshots[val_start:test_start],
            snapshots[test_start:],
        )

    if cfg.cutoff_date is None:
        raise ValueError("cutoff_date must be set when split_mode == 'date'")
    cutoff_ns = pd.Timestamp(cfg.cutoff_date).value

    train_set: List[Data] = []
    holdout: List[Data] = []
    for snap in snapshots:
        bucket = train_set if snap.snap_time.item() <= cutoff_ns else holdout
        bucket.append(snap)

    if not holdout:
        raise ValueError("No snapshots after cutoff_date for val/test split")

    half = len(holdout) // 2
    return train_set, holdout[:half], holdout[half:]


# --------------------------------------------------------------------------- dataloaders (dispatch)

def build_dataloaders(
    *,
    node_frames: Dict[str, pd.DataFrame],
    edge_frames: Dict[str, pd.DataFrame],
    graph: nx.DiGraph,
    c

#### model 

In [None]:
"""
Core model zoo — every network lives in this single file. Register new
builders with the decorator from ``gnn.models.registry``.

Backwards-compatible:
- Node tasks: builders return logits by default (unchanged).
- Edge tasks: pass as_encoder=True to get node embeddings (no final head),
  then pair with a decoder from the small zoo below.
"""

from __future__ import annotations
from typing import Optional

import torch
import torch.nn as nn
from torch_geometric.nn import (
    GCN,
    GraphSAGE,
    GAT,
    GATv2Conv,
    GCNConv,
)

from jp_da_imb.gnn.models.registry import register


# ======================================================================
# Encoders / Predictors
# ======================================================================

# ------------------------------------------------------------------ GCN

@register("gcn")
def build_gcn(
    *,
    d_in: int,
    d_out: int = 1,              # kept for signature compatibility
    hidden_dim: int,
    num_layers: int,
    dropout: float,
    norm: str = "batch",
    as_encoder: bool = False,     # when True, return embeddings
    **_,
) -> nn.Module:
    """
    Build a torch_geometric ``GCN`` with ReLU activation.

    With ``as_encoder=True`` the network outputs ``hidden_dim`` features per
    node (no prediction head); otherwise it outputs ``d_out`` values.

    Any keyword not named in the signature is forwarded to ``GCN`` unchanged.
    NOTE(review): options GCN does not understand (e.g. ``heads``) are
    presumably rejected inside torch_geometric — confirm against callers that
    pass one shared kwarg set to every builder.
    """
    # encoder and predictor differ only in the width of the output layer
    out_channels = hidden_dim if as_encoder else d_out
    return GCN(
        in_channels=d_in,
        hidden_channels=hidden_dim,
        num_layers=num_layers,
        out_channels=out_channels,
        dropout=dropout,
        norm=norm,
        act="relu",
        **_,
    )


# ------------------------------------------------------------------ GraphSAGE

class _GraphSageData(nn.Module):
    """
    Wrapper around torch_geometric's ``GraphSAGE`` whose ``forward`` takes a
    single ``data`` object and reads ``data.x`` / ``data.edge_index`` from it.

    With ``as_encoder=True`` the wrapped network outputs ``hidden_dim``
    features; otherwise it outputs ``d_out``. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(
        self,
        *,
        d_in: int,
        d_out: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        norm: str,
        as_encoder: bool = False,
        **_,
    ):
        super().__init__()
        self.body = GraphSAGE(
            in_channels=d_in,
            hidden_channels=hidden_dim,
            num_layers=num_layers,
            out_channels=(hidden_dim if as_encoder else d_out),
            dropout=dropout,
            norm=norm,
            act="relu",
        )

    def forward(self, data):
        """Run the wrapped GraphSAGE on the node features and edges of `data`."""
        features, connectivity = data.x, data.edge_index
        return self.body(features, connectivity)


@register("graphsage")
@register("sage")  # alias
def build_sage(**builder_kwargs) -> nn.Module:
    """Build the GraphSAGE adapter, passing every keyword through to it."""
    model = _GraphSageData(**builder_kwargs)
    return model


# ------------------------------------------------------------------ GAT

@register("gat")
def build_gat(
    *,
    d_in: int,
    d_out: int = 1,
    hidden_dim: int,
    num_layers: int,
    heads: int,
    dropout: float,
    norm: str = "batch",
    as_encoder: bool = False,
    **_,
) -> nn.Module:
    """
    Build a torch_geometric ``GAT`` with ReLU activation.

    ``out_channels`` is ``hidden_dim`` when ``as_encoder`` is True and
    ``d_out`` otherwise. Keywords not named in the signature are forwarded to
    ``GAT`` unchanged.
    """
    return GAT(
        in_channels=d_in,
        hidden_channels=hidden_dim,
        num_layers=num_layers,
        out_channels=(hidden_dim if as_encoder else d_out),
        heads=heads,
        dropout=dropout,
        norm=norm,
        act="relu",
        **_,
    )


# ------------------------------------------------------------------ GATv2 (custom stack)

class _GATv2Net(nn.Module):
    """
    Stack of `torch_geometric.nn.GATv2Conv` layers with an optional linear head.

    Every layer uses ``heads`` attention heads of width ``hidden_dim`` with
    concatenated outputs, so each layer emits ``hidden_dim * heads`` features.
    Each layer is followed by the optional norm, ReLU, then dropout.

    Parameters
    ----------
    d_in       : input feature dimension
    d_out      : output dimension of the linear head (head is skipped in
                 ``forward`` when as_encoder=True, but is still created)
    hidden_dim : feature dimension of each attention head
    num_layers : number of GATv2Conv layers
    heads      : attention heads per layer (same for all layers)
    dropout    : used both inside GATv2Conv and for the dropout after each layer
    edge_dim   : optional edge-feature dimension (None -> no edge_attr)
    norm       : "batch" | "layer" | anything else -> no normalisation
    as_encoder : if True, forward() returns the node embeddings of the last
                 layer instead of the head output
    """

    def __init__(
        self,
        *,
        d_in: int,
        d_out: int,
        hidden_dim: int,
        num_layers: int,
        heads: int,
        dropout: float,
        edge_dim: Optional[int] = None,
        norm: str = "batch",
        as_encoder: bool = False,
    ):
        super().__init__()

        self.as_encoder = as_encoder
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        width = hidden_dim * heads  # per-layer output size (heads concatenated)
        for layer_idx in range(num_layers):
            self.convs.append(
                GATv2Conv(
                    in_channels=d_in if layer_idx == 0 else width,
                    out_channels=hidden_dim,
                    heads=heads,
                    concat=True,
                    edge_dim=edge_dim,
                    dropout=dropout,
                )
            )

            if norm == "batch":
                layer_norm = nn.BatchNorm1d(width)
            elif norm == "layer":
                layer_norm = nn.LayerNorm(width)
            else:
                layer_norm = None  # placeholder keeps convs/norms aligned
            self.norms.append(layer_norm)

        self.dropout = nn.Dropout(p=dropout)
        # the head consumes the last layer's output, or the raw input when
        # the stack has no layers
        head_in = width if num_layers > 0 else d_in
        self.head = nn.Linear(head_in, d_out)

    def forward(self, data):
        """Return embeddings (as_encoder) or head outputs for the nodes of `data`."""
        h = data.x
        edge_index = data.edge_index
        edge_attr = getattr(data, "edge_attr", None)

        for conv, norm in zip(self.convs, self.norms):
            h = conv(x=h, edge_index=edge_index, edge_attr=edge_attr)
            if norm is not None:
                h = norm(h)
            h = self.dropout(torch.relu(h))

        return h if self.as_encoder else self.head(h)


@register("gatv2")
def build_gatv2(
    *,
    d_in: int,
    d_out: int = 1,
    hidden_dim: int,
    num_layers: int,
    heads: int,
    dropout: float,
    edge_dim: Optional[int] = None,
    norm: str = "batch",
    as_encoder: bool = False,
    **_,
) -> nn.Module:
    """
    Registry entry for the custom GATv2 stack.

    Passes the named arguments to ``_GATv2Net``; any additional keyword is
    accepted and dropped.
    """
    net_kwargs = dict(
        d_in=d_in,
        d_out=d_out,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        heads=heads,
        dropout=dropout,
        edge_dim=edge_dim,
        norm=norm,
        as_encoder=as_encoder,
    )
    return _GATv2Net(**net_kwargs)


# ------------------------------------------------------------------ Simple 2-layer baseline

class _GNNNodeSimple(nn.Module):
    """
    Baseline: two GCNConv layers (ReLU after each), dropout, then a linear head.

    With ``as_encoder=True`` ``forward`` returns the ``hidden_dim`` features
    after dropout and skips the head. Extra keyword arguments are ignored.
    """

    def __init__(
        self,
        *,
        d_in: int,
        d_out: int,
        hidden_dim: int,
        dropout: float,
        as_encoder: bool = False,
        **_,
    ):
        super().__init__()
        self.as_encoder = as_encoder
        self.conv1 = GCNConv(in_channels=d_in, out_channels=hidden_dim)
        self.conv2 = GCNConv(in_channels=hidden_dim, out_channels=hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.head = nn.Linear(in_features=hidden_dim, out_features=d_out)

    def forward(self, data):
        """Return embeddings (as_encoder) or head outputs for the nodes of `data`."""
        h = data.x
        edge_index = data.edge_index
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(x=h, edge_index=edge_index))
        h = self.dropout(h)
        return h if self.as_encoder else self.head(h)


@register("simple")
@register("baseline")
def build_simple(
    *,
    d_in: int,
    d_out: int = 1,
    hidden_dim: int,
    dropout: float,
    as_encoder: bool = False,
    **_,
) -> nn.Module:
    """
    Registry entry for the two-layer GCN baseline.

    Only the named arguments reach ``_GNNNodeSimple``; other keywords are
    accepted and dropped.
    """
    net_kwargs = dict(
        d_in=d_in,
        d_out=d_out,
        hidden_dim=hidden_dim,
        dropout=dropout,
        as_encoder=as_encoder,
    )
    return _GNNNodeSimple(**net_kwargs)


# ======================================================================
# Edge decoders (z_u, z_v) -> logit
# ======================================================================

class DotDecoder(nn.Module):
    """Inner product of z_u and z_v along the last dimension, kept as a size-1 axis."""

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        elementwise = z_src * z_dst
        return torch.sum(elementwise, dim=-1, keepdim=True)


class ConcatMLPDecoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the concatenation [z_u || z_v]."""

    def __init__(self, d_in: int, hidden: int = 128):
        super().__init__()
        layers = [
            nn.Linear(2 * d_in, hidden),  # input is both endpoints side by side
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        pair = torch.cat([z_src, z_dst], dim=-1)
        return self.mlp(pair)


class HadamardMLPDecoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the elementwise product z_u * z_v."""

    def __init__(self, d_in: int, hidden: int = 128):
        super().__init__()
        layers = [
            nn.Linear(d_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        product = z_src * z_dst
        return self.mlp(product)


class BilinearDecoder(nn.Module):
    """Bilinear score z_uᵀ W z_v with a learned square matrix W (Xavier-uniform init)."""

    def __init__(self, d_in: int):
        super().__init__()
        weight = torch.empty(d_in, d_in)
        self.W = nn.Parameter(weight)
        nn.init.xavier_uniform_(self.W)

    def forward(self, z_src: torch.Tensor, z_dst: torch.Tensor) -> torch.Tensor:
        # project the source embedding first, then take the inner product
        # with the destination embedding
        projected = z_src @ self.W
        return (projected * z_dst).sum(-1, keepdim=True)


# Small decoder registry + builder (kept local for simplicity).
# Each factory takes the embedding size `d_in` plus free-form keyword options;
# options a decoder does not use are silently ignored.
_DECODERS = {
    "dot": lambda d_in, **kw: DotDecoder(),
    "concat_mlp": lambda d_in, **kw: ConcatMLPDecoder(d_in, hidden=kw.get("mlp_hidden", 128)),
    "hadamard_mlp": lambda d_in, **kw: HadamardMLPDecoder(d_in, hidden=kw.get("mlp_hidden", 128)),
    "bilinear": lambda d_in, **kw: BilinearDecoder(d_in),
}


def build_decoder(name: str, d_in: int, **kwargs) -> nn.Module:
    """
    Create an edge decoder by name. Example:
        dec = build_decoder(cfg.decoder, d_in=latent_dim, mlp_hidden=128)

    Parameters
    ----------
    name     : decoder id, matched case-insensitively against the registry
    d_in     : size of the node embeddings fed to the decoder
    **kwargs : decoder options (``mlp_hidden`` for the MLP decoders)

    Raises
    ------
    ValueError : `name` is not a string (e.g. an unset ``cfg.decoder``) or is
                 not a registered decoder
    """
    # A config may leave the decoder unset (None) while still selecting an
    # edge task; fail with a readable message instead of an AttributeError
    # from `.lower()`.
    if not isinstance(name, str):
        raise ValueError(
            f"Decoder name must be a string, got {name!r}. Available: {list(_DECODERS)}"
        )

    try:
        factory = _DECODERS[name.lower()]
    except KeyError as exc:
        raise ValueError(f"Unknown decoder '{name}'. Available: {list(_DECODERS)}") from exc
    return factory(d_in, **kwargs)

# Example wiring for an edge task. Expects `cfg`, `d_in`, `edge_dim` and
# `build_model` to be in scope.
enc = build_model(
    name=cfg.model_name,
    d_in=d_in, hidden_dim=cfg.hidden_dim, num_layers=cfg.num_layers,
    heads=cfg.heads, dropout=cfg.dropout, norm=cfg.norm,
    as_encoder=True,              # <- new switch
    edge_dim=edge_dim,            # if you have edge_attr and use GATv2
)
# Only the custom "gatv2" stack returns concatenated heads
# (hidden_dim * heads). The torch_geometric GCN / GraphSAGE / GAT builders are
# created with out_channels=hidden_dim, so their embeddings are hidden_dim wide.
latent_dim = cfg.hidden_dim * (cfg.heads if cfg.model_name.lower() == "gatv2" else 1)
dec = build_decoder(cfg.decoder, d_in=latent_dim, **(cfg.decoder_kwargs or {}))

#### loss 

In [None]:
"""
Utility functions that scan a DataLoader once to derive class imbalance
statistics (for binary / multi-label classification).

- compute_class_weights       : for NODE labels (batch.y)
- compute_edge_class_weights  : for EDGE labels (batch.edge_label)
"""

from __future__ import annotations
from typing import Iterable
import torch
from torch_geometric.loader import DataLoader


def _reduce_pos_neg(y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Count positives and negatives per target.

    y: (..., T), expected to hold values in {0,1}; all leading dimensions are
    flattened into one.
    Returns (pos_counts[T], neg_counts[T]) as float tensors.
    """
    flat = y.reshape(-1, y.size(-1)).float()
    positives = torch.sum(flat, dim=0)
    negatives = torch.sum(1.0 - flat, dim=0)
    return positives, negatives


def compute_class_weights(
    train_loader: DataLoader,
    eps: float = 1e-6,
) -> float | torch.Tensor:
    """
    Return a `pos_weight` (scalar or per-target tensor) suitable for
    torch.nn.functional.binary_cross_entropy_with_logits on NODE tasks.

    The loader is scanned once; positives and negatives in ``batch.y`` are
    counted per target and the weight is ``(n_neg + eps) / (n_pos + eps)``.

    Parameters
    ----------
    train_loader : iterable of batches exposing ``.y``
    eps          : added to both counts so a target without positives does
                   not divide by zero

    Returns
    -------
    A Python float when there is a single target, otherwise a tensor with one
    weight per target.

    Raises
    ------
    ValueError : the loader yielded no batches
    """
    n_pos = n_neg = None
    for batch in train_loader:
        y = batch.y
        # A flat label vector holds one target per node; give it an explicit
        # target axis so nodes are not mistaken for separate targets (same
        # guard as compute_edge_class_weights).
        if y.dim() == 1:
            y = y.unsqueeze(-1)
        pos, neg = _reduce_pos_neg(y)
        n_pos = pos if n_pos is None else n_pos + pos
        n_neg = neg if n_neg is None else n_neg + neg

    if n_pos is None:
        raise ValueError("Training loader yielded no batches")

    pos_weight = (n_neg + eps) / (n_pos + eps)
    return pos_weight.item() if pos_weight.numel() == 1 else pos_weight


def compute_edge_class_weights(
    train_loader: DataLoader,
    eps: float = 1e-6,
) -> float | torch.Tensor:
    """
    Return a `pos_weight` (scalar or per-target tensor) for EDGE tasks.

    Scans ``batch.edge_label`` over the whole loader, counts positives and
    negatives per target and returns ``(n_neg + eps) / (n_pos + eps)``: a
    Python float for a single target, a tensor otherwise.

    Raises ValueError when the loader yields no batches.
    """
    pos_total = neg_total = None
    for batch in train_loader:
        labels = batch.edge_label
        # a flat label vector gets an explicit target axis: (E,) -> (E, 1)
        labels = labels.unsqueeze(-1) if labels.dim() == 1 else labels
        pos, neg = _reduce_pos_neg(labels)
        if pos_total is None:
            pos_total, neg_total = pos, neg
        else:
            pos_total, neg_total = pos_total + pos, neg_total + neg

    if pos_total is None:
        raise ValueError("Training loader yielded no batches")

    ratio = (neg_total + eps) / (pos_total + eps)
    if ratio.numel() == 1:
        return ratio.item()
    return ratio


In [None]:
"""
Core loss zoo + registry.

Each build_* returns a callable:

    loss_fn(pred, target, *, sample_weight=None, node_weight=None)

- Accepts either `sample_weight` (new, generic) or `node_weight` (back-compat alias)
- Supports class-imbalance `pos_weight` for BCE-style losses
"""

from __future__ import annotations
from typing import Any, Callable, Dict, Optional, Union, Sequence

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

from jp_da_imb.gnn.loss.class_weights import (
    compute_class_weights,
    compute_edge_class_weights,
)

# --------------------------------------------------------------------------- registry
# A loss callable maps (pred, target, ...) to a tensor.
LossCallable = Callable[..., torch.Tensor]
# Registered builders, keyed by lower-cased loss name.
LOSSES: Dict[str, Callable[..., LossCallable]] = {}


def register(name: str) -> Callable[[Callable[..., LossCallable]], Callable[..., LossCallable]]:
    """
    Decorator factory that stores a loss builder in ``LOSSES``.

    The builder is stored under ``name.lower()`` and returned unchanged.
    Registering a second builder under the same key raises RuntimeError.
    """
    registry_key = name.lower()

    def _store(builder: Callable[..., LossCallable]) -> Callable[..., LossCallable]:
        if registry_key in LOSSES:
            raise RuntimeError(f"Loss {name} already registered")
        LOSSES[registry_key] = builder
        return builder

    return _store


# Explicit (non-"auto") forms accepted for `class_weights` in build_loss:
# a scalar, a tensor, or a sequence of floats.
PosWeightLike = Union[float, torch.Tensor, Sequence[float]]


def _auto_pos_weight(train_loader: DataLoader) -> float | torch.Tensor:
    """
    Compute a pos_weight from the loader, choosing the label source by
    inspecting its first batch.

    Batches with an ``edge_label`` attribute are treated as EDGE batches
    (checked first); batches with ``y`` as NODE batches.

    Raises ValueError when the loader yields no batches or the first batch
    has neither attribute.
    """
    probe = next(iter(train_loader), None)
    if probe is None:
        raise ValueError("Training loader yielded no batches")

    if hasattr(probe, "edge_label"):
        return compute_edge_class_weights(train_loader)
    if hasattr(probe, "y"):
        return compute_class_weights(train_loader)
    raise ValueError("Could not find 'y' or 'edge_label' on training batches")


def build_loss(
    *,
    name: str,
    class_weights: Optional[Union[PosWeightLike, str]] = None,
    train_loader: Optional[DataLoader] = None,
    node_reduction: str = "mean",
    **loss_kw: Any,
) -> LossCallable:
    """
    Look up a registered loss builder and instantiate it.

    Parameters
    ----------
    name : registered loss id -> "bce", "focal_bce", "mse", ... (case-insensitive)
    class_weights : None | scalar | list/tensor | "auto"
        If "auto", a single pass over train_loader is made to compute pos_weight.
    train_loader : required when class_weights == "auto"
    node_reduction : "mean" | "sum" | "none" (applied after sample weighting);
        forwarded to the builder as ``reduction``
    **loss_kw : forwarded to the specific builder (e.g. gamma=2.0 for focal BCE)

    Returns
    -------
    loss_fn(pred, target, *, sample_weight=None, node_weight=None) -> scalar Tensor

    Raises
    ------
    ValueError
        Unknown loss name, or a ``class_weights`` string other than "auto".
    RuntimeError
        ``class_weights == "auto"`` without a ``train_loader``.

    Notes
    -----
    ``pos_weight`` is forwarded only to builders whose signature accepts it
    (an explicit ``pos_weight`` parameter or ``**kwargs``). For any other
    builder it is dropped; a ``UserWarning`` is emitted when the dropped
    value is not None so ignored class weights do not go unnoticed.
    """
    import inspect
    import warnings

    # Resolve the builder first: a typo in `name` should fail before the
    # potentially expensive "auto" pass over the training loader.
    try:
        builder = LOSSES[name.lower()]
    except KeyError as exc:
        raise ValueError(f"Unknown loss '{name}'. Available: {list(LOSSES)}") from exc

    # resolve pos_weight
    if isinstance(class_weights, str):
        if class_weights != "auto":
            raise ValueError("class_weights string must be 'auto' or numeric/list")
        if train_loader is None:
            raise RuntimeError("'class_weights: auto' requires the training DataLoader")
        pos_weight = _auto_pos_weight(train_loader)
    else:
        pos_weight = class_weights  # could be None | scalar | list/tensor

    # attach resolved values to kwargs expected by individual builders
    builder_kw = {
        **loss_kw,
        "pos_weight": pos_weight,
        "reduction": node_reduction,
    }

    # Ask the builder itself whether it understands `pos_weight` instead of
    # keeping a hand-maintained list of loss names that can drift out of sync.
    params = inspect.signature(builder).parameters
    accepts_pos_weight = "pos_weight" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if not accepts_pos_weight:
        builder_kw.pop("pos_weight", None)
        if pos_weight is not None:
            warnings.warn(
                f"Loss '{name}' does not accept pos_weight; class weights are ignored.",
                UserWarning,
                stacklevel=2,
            )

    return builder(**builder_kw)


# --------------------------------------------------------------------------- helpers

def _align_weight(weight: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Return *weight* viewed so that it broadcasts onto exactly *shape*.

    Two alignments are tried, in this order:

    1. standard (trailing-axis) broadcasting, e.g. ``(N,)`` onto ``(B, N)``;
    2. leading-axis alignment for per-sample weights, e.g. ``(N,)`` onto
       ``(N, T)`` by viewing the weights as ``(N, 1)``.

    An alignment is accepted only if the broadcast result has exactly
    *shape*; this rejects the silent ``(N,) x (N, 1) -> (N, N)`` blow-up.
    When both alignments are valid (square loss), the first one wins.

    Raises
    ------
    ValueError
        If neither alignment yields *shape*.
    """
    # Surplus trailing singleton axes, e.g. (N, 1) weights for an (N,) loss.
    while weight.dim() > len(shape) and weight.size(-1) == 1:
        weight = weight.squeeze(-1)

    candidates = [weight]
    if weight.dim() < len(shape):
        pad = (1,) * (len(shape) - weight.dim())
        candidates.append(weight.reshape(tuple(weight.shape) + pad))

    for cand in candidates:
        try:
            if torch.broadcast_shapes(cand.shape, shape) == shape:
                return cand
        except RuntimeError:
            # shapes are not broadcastable at all under this alignment
            continue

    raise ValueError(
        f"weight of shape {tuple(weight.shape)} cannot be aligned with "
        f"loss of shape {tuple(shape)}"
    )


def _apply_weights(
    loss_vec: torch.Tensor,
    weight: Optional[torch.Tensor],
    reduction: str,
) -> torch.Tensor:
    """
    Multiply an element-wise loss by optional weights, then reduce it.

    loss_vec : (..., N) or (..., E) or (..., N, T) — elementwise
    weight   : None, or a tensor that either broadcasts onto loss_vec in the
               usual way or matches its leading axes (per-sample weights such
               as (N,) for an (N, T) loss)
    reduction: "mean" -> mean over all elements, "sum" -> sum over all
               elements, anything else -> the weighted tensor is returned
               unreduced

    Raises ValueError when the weights cannot be aligned with loss_vec.
    """
    if weight is not None:
        loss_vec = loss_vec * _align_weight(weight, loss_vec.shape)

    if reduction == "mean":
        return loss_vec.mean()
    if reduction == "sum":
        return loss_vec.sum()
    return loss_vec  # "none"


def _coalesce_weight(node_weight: Optional[torch.Tensor], sample_weight: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Pick the effective weights: ``sample_weight`` when given, else the legacy ``node_weight``."""
    if sample_weight is None:
        return node_weight
    return sample_weight


# --------------------------------------------------------------------------- BCE

@register("bce")
def build_bce(
    *,
    pos_weight: Optional[float | torch.Tensor] = None,
    reduction: str = "mean",
) -> LossCallable:
    """
    Binary / multi-label BCE-with-logits with optional global *pos_weight*.

    Parameters
    ----------
    pos_weight : None, scalar or tensor passed to
        ``F.binary_cross_entropy_with_logits`` as ``pos_weight``.
    reduction : forwarded to ``_apply_weights`` after sample weighting.

    Returns
    -------
    loss_fn(pred, target, *, sample_weight=None, node_weight=None)
        ``pred`` are logits. ``sample_weight`` (alias: ``node_weight``) gives
        element-wise weights; ``sample_weight`` wins when both are passed.
    """
    def _loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        *,
        sample_weight: Optional[torch.Tensor] = None,
        node_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w = _coalesce_weight(node_weight, sample_weight)

        # BCE-with-logits needs floating-point targets; labels are often
        # stored as integers, so match the logits' dtype explicitly.
        target = target.to(dtype=pred.dtype)

        pw = None
        if pos_weight is not None:
            # Same dtype/device as the logits to avoid mixed-precision or
            # cross-device errors.
            pw = torch.as_tensor(pos_weight, dtype=pred.dtype, device=pred.device)

        loss_vec = F.binary_cross_entropy_with_logits(
            pred,
            target,
            pos_weight=pw,
            reduction="none",
        )
        return _apply_weights(loss_vec, w, reduction)

    return _loss


# --------------------------------------------------------------------------- focal BCE

@register("focal_bce")
def build_focal_bce(
    *,
    gamma: float = 2.0,
    alpha: float | None = None,  # class balance factor in [0,1]
    pos_weight: Optional[float | torch.Tensor] = None,
    reduction: str = "mean",
) -> LossCallable:
    """
    Focal BCE for imbalanced binary classification.

    Parameters
    ----------
    gamma : focusing exponent applied to ``(1 - p_t)``.
    alpha : optional class-balance factor; elements with ``target == 1`` are
        scaled by ``alpha`` and all others by ``1 - alpha``.
    pos_weight : optional positive-class weight with the same meaning as in
        ``F.binary_cross_entropy_with_logits``. Defaults to None, which keeps
        the unweighted behaviour. It exists so this builder can be called
        with the same ``pos_weight`` keyword as ``build_bce``.
    reduction : forwarded to ``_apply_weights`` after sample weighting.

    Returns
    -------
    loss_fn(pred, target, *, sample_weight=None, node_weight=None)
        ``pred`` are logits; ``sample_weight`` wins over ``node_weight``.
    """
    def _loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        *,
        sample_weight: Optional[torch.Tensor] = None,
        node_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        w = _coalesce_weight(node_weight, sample_weight)

        # BCE-with-logits needs floating-point targets.
        target = target.to(dtype=pred.dtype)

        # Unweighted BCE drives the focal modulation ...
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction="none")

        # p_t = exp(-bce) is numerically stable proxy for sigmoid-anchored prob term
        p_t = torch.exp(-bce)
        focal_term = (1 - p_t) ** gamma

        # ... while the class weighting (if any) only scales the loss value,
        # so it must not leak into p_t.
        if pos_weight is not None:
            pw = torch.as_tensor(pos_weight, dtype=pred.dtype, device=pred.device)
            base = F.binary_cross_entropy_with_logits(
                pred, target, pos_weight=pw, reduction="none"
            )
        else:
            base = bce

        loss_vec = focal_term * base
        if alpha is not None:
            a = torch.as_tensor(alpha, dtype=pred.dtype, device=pred.device)
            alpha_t = torch.where(target == 1, a, 1 - a)
            loss_vec = alpha_t * loss_vec

        return _apply_weights(loss_vec, w, reduction)

    return _loss


# --------------------------------------------------------------------------- MSE

@register("mse")
def build_mse(
    *,
    reduction: str = "mean",
) -> LossCallable:
    """
    Squared-error loss with optional element-wise weights.

    The returned callable computes the unreduced ``F.mse_loss`` and passes it,
    together with ``sample_weight`` (or ``node_weight`` when that is None) and
    *reduction*, to ``_apply_weights``.
    """
    def _loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        *,
        sample_weight: Optional[torch.Tensor] = None,
        node_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # sample_weight takes precedence; node_weight is the back-compat alias
        weights = node_weight if sample_weight is None else sample_weight
        per_element = F.mse_loss(pred, target, reduction="none")
        return _apply_weights(per_element, weights, reduction)

    return _loss


# --------------------------------------------------------------------------- MAE (L1)

@register("mae")
def build_mae(
    *,
    reduction: str = "mean",
) -> LossCallable:
    """
    Absolute-error (L1) loss with optional element-wise weights.

    The returned callable computes the unreduced ``F.l1_loss`` and passes it,
    together with ``sample_weight`` (or ``node_weight`` when that is None) and
    *reduction*, to ``_apply_weights``.
    """
    def _loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        *,
        sample_weight: Optional[torch.Tensor] = None,
        node_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # sample_weight takes precedence; node_weight is the back-compat alias
        weights = node_weight if sample_weight is None else sample_weight
        per_element = F.l1_loss(pred, target, reduction="none")
        return _apply_weights(per_element, weights, reduction)

    return _loss


# --------------------------------------------------------------------------- Huber

@register("huber")
def build_huber(
    *,
    delta: float = 1.0,
    reduction: str = "mean",
) -> LossCallable:
    """
    Huber loss with threshold *delta* and optional element-wise weights.

    The returned callable computes the unreduced ``F.huber_loss`` and passes
    it, together with ``sample_weight`` (or ``node_weight`` when that is None)
    and *reduction*, to ``_apply_weights``.
    """
    def _loss(
        pred: torch.Tensor,
        target: torch.Tensor,
        *,
        sample_weight: Optional[torch.Tensor] = None,
        node_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # sample_weight takes precedence; node_weight is the back-compat alias
        weights = node_weight if sample_weight is None else sample_weight
        per_element = F.huber_loss(pred, target, delta=delta, reduction="none")
        return _apply_weights(per_element, weights, reduction)

    return _loss


In [None]:
"""
Simple metric registry.

All classification metrics expect LOGITS (will apply sigmoid internally).
Regression metrics expect raw predictions.
"""

from __future__ import annotations
from typing import Callable, Dict

import torch
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, precision_score, recall_score

# Registry of metric functions (prediction tensor, target tensor) -> float; filled by `register`.
METRICS: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]] = {}


def register(name: str):
    """Return a decorator that stores a metric in ``METRICS`` under ``name.lower()``.

    The decorated function is returned unchanged. An existing entry with the
    same key is overwritten.
    """
    def _bind(metric_fn):
        METRICS.update({name.lower(): metric_fn})
        return metric_fn

    return _bind


# ---------------------- classification

@register("auc")
def auroc(logits: torch.Tensor, target: torch.Tensor) -> float:
    """ROC-AUC from ``roc_auc_score`` on flattened targets and sigmoid scores."""
    labels = target.detach().cpu().numpy().ravel()
    probabilities = torch.sigmoid(logits.detach()).cpu().numpy().ravel()
    return roc_auc_score(labels, probabilities)


@register("pr_auc")
def pr_auc(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Average precision (area under Precision-Recall curve) on flattened inputs."""
    labels = target.detach().cpu().numpy().ravel()
    probabilities = torch.sigmoid(logits.detach()).cpu().numpy().ravel()
    return average_precision_score(labels, probabilities)


@register("acc")
def accuracy(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of elements whose thresholded prediction equals the target.

    Predictions are ``sigmoid(logits) > 0.5``. When *logits* and *target*
    differ only by singleton axes (for example ``[E]`` against ``[E, 1]``),
    the predictions are reshaped to the target's shape first. Without that,
    the comparison would broadcast to ``[E, E]`` and report a meaningless
    score.
    """
    pred = (logits.detach().sigmoid() > 0.5).float()
    target = target.detach()
    if pred.shape != target.shape and pred.squeeze().shape == target.squeeze().shape:
        pred = pred.reshape(target.shape)
    return (pred == target).float().mean().item()


@register("f1")
def f1(logits: torch.Tensor, target: torch.Tensor) -> float:
    """F1 score of predictions thresholded at sigmoid > 0.5 (``zero_division=0``)."""
    probabilities = torch.sigmoid(logits.detach()).cpu().numpy().ravel()
    hard_labels = (probabilities > 0.5).astype(int)
    truth = target.detach().cpu().numpy().ravel()
    return f1_score(truth, hard_labels, zero_division=0)


@register("precision")
def precision(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Precision of predictions thresholded at sigmoid > 0.5 (``zero_division=0``)."""
    probabilities = torch.sigmoid(logits.detach()).cpu().numpy().ravel()
    hard_labels = (probabilities > 0.5).astype(int)
    truth = target.detach().cpu().numpy().ravel()
    return precision_score(truth, hard_labels, zero_division=0)


@register("recall")
def recall(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Recall of predictions thresholded at sigmoid > 0.5 (``zero_division=0``)."""
    probabilities = torch.sigmoid(logits.detach()).cpu().numpy().ravel()
    hard_labels = (probabilities > 0.5).astype(int)
    truth = target.detach().cpu().numpy().ravel()
    return recall_score(truth, hard_labels, zero_division=0)


# ---------------------- regression

@register("mae")
def mae(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Mean absolute error between *pred* and *target* as a Python float."""
    l1 = torch.nn.functional.l1_loss
    mean_abs_err = l1(pred, target, reduction="mean")
    return mean_abs_err.item()


@register("r2")
def r2(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Coefficient of determination ``1 - SS_res / SS_tot`` over all elements.

    Edge cases handled explicitly:

    * *pred* and *target* that differ only by singleton axes (e.g. ``[N]``
      against ``[N, 1]``) are aligned before subtracting, so the residual
      cannot silently broadcast to ``[N, N]``.
    * A constant target makes ``SS_tot`` zero. Instead of returning nan/-inf
      the score is 1.0 for a perfect prediction and 0.0 otherwise, which keeps
      batch-averaged values finite.
    """
    if pred.shape != target.shape and pred.squeeze().shape == target.squeeze().shape:
        pred = pred.reshape(target.shape)

    ss_res = torch.sum((target - pred) ** 2)
    ss_tot = torch.sum((target - target.mean()) ** 2)
    if ss_tot.item() == 0.0:
        return 1.0 if ss_res.item() == 0.0 else 0.0
    return (1 - ss_res / ss_tot).item()


#### Trainer

In [None]:
"""
High-level training/evaluation loop with early stopping, schedulers,
TensorBoard logging, and checkpointing.

Backwards-compatible:
- Node tasks: behavior unchanged.
- Edge tasks: set cfg.task="edge_clf" (or set cfg.decoder) to enable
  encoder+decoder pipeline and train on batch.edge_label.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

from jp_da_imb.gnn.config import GNNConfig
from jp_da_imb.gnn.models import build_model, build_decoder
from jp_da_imb.gnn.loss.loss_functions import build_loss
from jp_da_imb.gnn.loss.metrics import METRICS


# --------------------------------------------------------------------------- small utilities

def _device_from_cfg(cfg: GNNConfig) -> torch.device:
    """Resolve ``cfg.device`` to a usable ``torch.device``.

    A CUDA request (``"cuda"``, an indexed form such as ``"cuda:1"``, or an
    equivalent ``torch.device``) is honoured when CUDA is available and the
    requested index exists. Every other value — including strings torch
    cannot parse — resolves to the CPU, which was already the fallback for
    anything other than the literal ``"cuda"``.
    """
    cpu = torch.device("cpu")
    try:
        requested = torch.device(cfg.device)
    except (RuntimeError, TypeError, ValueError):
        # unparsable or missing device spec -> keep the historical CPU fallback
        return cpu

    if requested.type != "cuda" or not torch.cuda.is_available():
        return cpu
    if requested.index is not None and requested.index >= torch.cuda.device_count():
        # the requested GPU does not exist on this machine
        return cpu
    return requested


def _default_log_dir(run_name: str) -> str:
    """Return the default TensorBoard directory for *run_name*: ``runs/<run_name>`` as a string."""
    log_root = Path("runs")
    return str(log_root.joinpath(run_name))


def _save_checkpoint(model: nn.Module, path: Path, meta: Dict[str, Any]) -> None:
    """Write the model's ``state_dict`` and *meta* to *path* via ``torch.save``.

    Missing parent directories of *path* are created first. The saved object
    is a dict with the keys ``"state_dict"`` and ``"meta"``.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    payload = dict(state_dict=model.state_dict(), meta=meta)
    torch.save(payload, path)


# --------------------------------------------------------------------------- core

class Trainer:
    """
    Training / evaluation driver for node-level and edge-level GNN tasks.

    The constructor draws one batch from ``train_dl`` to infer input sizes,
    builds the model (encoder + decoder for edge tasks), the loss, the metric,
    the optimiser and the optional LR scheduler, and opens a TensorBoard
    writer. ``fit`` trains with early stopping on the validation loss and
    saves the best weights to ``<ckpt_dir>/best.pt``.

    Parameters
    ----------
    cfg      : GNNConfig
    train_dl : DataLoader
    val_dl   : DataLoader
    test_dl  : DataLoader
    log_dir  : optional path for TensorBoard logs (default: ./runs/<run_name>)
    ckpt_dir : optional path to save checkpoints (default: ./checkpoints)

    Raises
    ------
    ValueError
        Empty training loader, missing ``cfg.decoder`` for an edge task, or an
        unknown metric / optimiser / LR scheduler name.
    NotImplementedError
        Edge labels with more than one target column.
    """

    # ----------------------------- constructor
    def __init__(
        self,
        cfg: GNNConfig,
        train_dl: DataLoader,
        val_dl: DataLoader,
        test_dl: DataLoader,
        *,
        log_dir: Optional[str | Path] = None,
        ckpt_dir: Optional[str | Path] = None,
    ):
        self.cfg = cfg
        self.train_dl, self.val_dl, self.test_dl = train_dl, val_dl, test_dl
        self.device = _device_from_cfg(cfg)
        self.is_edge = cfg.is_edge_task()

        # --- sample batch for shape inference --------------------------------
        sample = next(iter(train_dl), None)
        if sample is None:
            # a bare StopIteration from next() would give no hint about the cause
            raise ValueError("Training DataLoader yielded no batches; cannot infer model dimensions")
        d_in = sample.x.size(-1)
        edge_dim = (
            sample.edge_attr.size(-1)
            if getattr(sample, "edge_attr", None) is not None
            else None
        )

        # --- build model(s) ---------------------------------------------------
        if self.is_edge:
            # Encoder returns node embeddings; decoder maps (z_src,z_dst)->logit(s)
            self.encoder = build_model(
                name=cfg.model_name,
                d_in=d_in,
                hidden_dim=cfg.hidden_dim,
                num_layers=cfg.num_layers,
                heads=cfg.heads,
                dropout=cfg.dropout,
                norm=cfg.norm,
                edge_dim=edge_dim,  # used by gatv2
                as_encoder=True,    # <- key switch
            ).to(self.device)

            # Infer latent dim robustly by a dry forward on the sample batch.
            # Run the probe in eval mode so it neither applies dropout nor
            # updates normalisation statistics, then restore the previous mode.
            was_training = self.encoder.training
            self.encoder.eval()
            with torch.no_grad():
                z_probe = self.encoder(sample.to(self.device))
            self.encoder.train(was_training)
            latent_dim = z_probe.size(-1)

            if not cfg.decoder:
                raise ValueError("Edge task requires cfg.decoder to be set (dot/concat_mlp/hadamard_mlp/bilinear).")

            self.decoder = build_decoder(
                cfg.decoder, d_in=latent_dim, **(cfg.decoder_kwargs or {})
            ).to(self.device)

            # simple module container for optimizer
            self.model = nn.ModuleDict({"enc": self.encoder, "dec": self.decoder}).to(self.device)

            # edge targets dimension
            t_edge = sample.edge_label.size(-1) if sample.edge_label.dim() > 1 else 1
            self._edge_out_dim = t_edge  # currently decoders output 1; loss will broadcast if t_edge==1
            if t_edge != 1:
                # You can expand decoders to output multi-target; for now we enforce binary/single-target edges.
                raise NotImplementedError(
                    "Current decoder outputs 1 logit per edge. "
                    "Support for multi-target edge prediction (T>1) not implemented."
                )

        else:
            # Node predictor (unchanged)
            d_out = sample.y.size(-1)
            self.model = build_model(
                name=cfg.model_name,
                d_in=d_in,
                d_out=d_out,
                hidden_dim=cfg.hidden_dim,
                num_layers=cfg.num_layers,
                heads=cfg.heads,
                dropout=cfg.dropout,
                norm=cfg.norm,
                edge_dim=edge_dim,  # used by gatv2
            ).to(self.device)

        # --- build loss callable ---------------------------------------------
        # Use edge_class_weights for edge tasks, else class_weights.
        cw = cfg.edge_class_weights if self.is_edge else cfg.class_weights
        self.loss_fn = build_loss(
            name=cfg.loss_fn,
            class_weights=cw,
            train_loader=train_dl if (isinstance(cw, str) and cw == "auto") else None,
            node_reduction=self.cfg.node_reduction,
        )

        # --- build the metric -------------------------------------------------
        self.metric = METRICS.get(cfg.metric.lower())
        if self.metric is None:
            raise ValueError(f"Unknown metric: {cfg.metric}. Options: {list(METRICS)}")

        # --- optimiser --------------------------------------------------------
        # Dict merge (rather than dict(lr=..., **kw)) lets optimiser_kwargs
        # override "lr" instead of failing with a duplicate-keyword TypeError.
        opt_kw = {"lr": cfg.lr, **(cfg.optimiser_kwargs or {})}
        optimisers = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        optimiser_cls = optimisers.get(cfg.optimiser)
        if optimiser_cls is None:
            raise ValueError(f"Unknown optimiser '{cfg.optimiser}'")
        self.optimizer = optimiser_cls(self.model.parameters(), **opt_kw)

        # --- LR scheduler -----------------------------------------------------
        self.scheduler = None
        if cfg.lr_scheduler:
            schedulers = {
                "step": torch.optim.lr_scheduler.StepLR,
                "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
                "cosine": torch.optim.lr_scheduler.CosineAnnealingLR,
                "onecycle": torch.optim.lr_scheduler.OneCycleLR,
            }
            scheduler_cls = schedulers.get(cfg.lr_scheduler.lower())
            if scheduler_cls is None:
                raise ValueError(f"Unknown lr_scheduler '{cfg.lr_scheduler}'")
            self.scheduler = scheduler_cls(self.optimizer, **(cfg.lr_scheduler_kwargs or {}))

        # --- logging ----------------------------------------------------------
        tb_dir = _default_log_dir(cfg.run_name) if log_dir is None else str(log_dir)
        self.writer = SummaryWriter(log_dir=tb_dir)
        print(f"TensorBoard logs -> {tb_dir}")

        # --- checkpointing ----------------------------------------------------
        self.ckpt_dir = Path(ckpt_dir or "./checkpoints")
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.best_val_loss = float("inf")
        self.patience_counter = 0

        # history
        self.history: Dict[str, List[float]] = dict(train=[], val=[], val_metric=[])

    # ----------------------------- helpers

    def _forward_logits(self, batch) -> torch.Tensor:
        """
        Returns logits for the current task.
        - Node: model(batch) -> [N, T]
        - Edge: decoder(z[src], z[dst]) -> [E] or [E, T]
        """
        if not self.is_edge:
            return self.model(batch)

        z = self.encoder(batch)
        src, dst = batch.edge_index
        logits = self.decoder(z[src], z[dst])
        # [E, 1] -> [E]. Guarded on rank so a single-edge [1] output is not
        # collapsed to a 0-dim scalar.
        if logits.dim() > 1:
            logits = logits.squeeze(-1)
        return logits

    def _forward_loss_and_metric(self, batch) -> Tuple[torch.Tensor, float, torch.Tensor]:
        """
        Returns (loss_scalar, metric_value, logits_tensor)

        The batch is moved to ``self.device`` first. Loss and metric receive
        predictions and targets of matching rank, so neither silently
        broadcasts ``[E]`` against ``[E, 1]``.
        """
        batch = batch.to(self.device)
        logits = self._forward_logits(batch)

        # targets + optional weights
        if self.is_edge:
            target = batch.edge_label
            if target.dim() == 1:
                target = target.unsqueeze(-1)  # [E] -> [E, 1]
            if logits.dim() == 1:
                target = target.squeeze(-1)    # back to [E] when logits are [E]
            sample_weight = getattr(batch, "edge_weight", None)
            loss = self.loss_fn(
                pred=logits,
                target=target,
                sample_weight=sample_weight,
            )
            metric_val = self.metric(logits, target)
        else:
            target = batch.y
            node_weight = getattr(batch, "node_weight", None)
            logits = self._maybe_squeeze(logits, target)
            loss = self.loss_fn(pred=logits, target=target, node_weight=node_weight)
            metric_val = self.metric(logits, target)

        return loss, float(metric_val), logits

    @staticmethod
    def _maybe_squeeze(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # convenience: if pred is [N] and target is [N, 1], lift pred to [N, 1]
        # so both have the same shape; otherwise pred is returned untouched
        if pred.dim() == 1 and target.dim() == 2 and target.size(-1) == 1:
            return pred.unsqueeze(-1)
        return pred

    # ----------------------------- epoch funcs

    def _train_epoch(self) -> float:
        """Run one optimisation pass over ``train_dl``; return the mean batch loss."""
        self.model.train()
        total = 0.0
        for batch in self.train_dl:
            self.optimizer.zero_grad()
            loss, _, _ = self._forward_loss_and_metric(batch)
            loss.backward()
            if self.cfg.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()
            total += loss.item()
        return total / max(1, len(self.train_dl))

    @torch.no_grad()
    def _eval(self, loader: DataLoader) -> tuple[float, float]:
        """Return (mean batch loss, mean batch metric) over *loader* in eval mode."""
        self.model.eval()
        tot_loss, tot_metric, n_batches = 0.0, 0.0, 0
        for batch in loader:
            loss, metric_val, _ = self._forward_loss_and_metric(batch)
            tot_loss += loss.item()
            tot_metric += metric_val
            n_batches += 1
        return tot_loss / max(1, n_batches), tot_metric / max(1, n_batches)

    # ----------------------------- main loop

    def fit(self) -> None:
        """
        Train for up to ``cfg.epochs`` epochs.

        After every epoch the scheduler is stepped (``ReduceLROnPlateau`` on
        the validation loss), scalars are logged, and the model is
        checkpointed whenever the validation loss improves. Training stops
        early after ``cfg.patience`` epochs without improvement.
        """
        for epoch in range(1, self.cfg.epochs + 1):
            train_loss = self._train_epoch()
            val_loss, val_metric = self._eval(self.val_dl)

            # scheduler step
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # tensorboard
            if self.writer:
                self.writer.add_scalar("loss/train", train_loss, epoch)
                self.writer.add_scalar("loss/val", val_loss, epoch)
                self.writer.add_scalar(f"{self.cfg.metric}/val", val_metric, epoch)

            # console
            lr = self.optimizer.param_groups[0]["lr"]
            if self.cfg.print_lr_each_epoch:
                print(
                    f"[{epoch:03d}/{self.cfg.epochs}] "
                    f"train={train_loss:.4f}  val={val_loss:.4f}  lr={lr:.2e}  "
                    f"{self.cfg.metric}(val)={val_metric:.4f}"
                )
            else:
                print(
                    f"[{epoch:03d}/{self.cfg.epochs}] train={train_loss:.4f}  val={val_loss:.4f}  {self.cfg.metric}(val)={val_metric:.4f}"
                )

            self.history["train"].append(train_loss)
            self.history["val"].append(val_loss)
            self.history.setdefault("val_metric", []).append(val_metric)

            # early stopping + checkpoint
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                _save_checkpoint(
                    self.model,
                    path=self.ckpt_dir / "best.pt",
                    meta=dict(cfg=self.cfg.to_dict(), epoch=epoch, val_loss=val_loss),
                )
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.cfg.patience:
                    print(f"Early stopping (no val improvement in {self.cfg.patience} epochs).")
                    break

        if self.writer:
            self.writer.flush()

    # ----------------------------- test / inference

    @torch.no_grad()
    def evaluate(
        self,
        split: str = "test",
        *,
        return_df: bool = False,
    ) -> Dict[str, float] | Tuple[Dict[str, float], pd.DataFrame]:
        """
        Compute loss + metric on *split*; optionally also return a tidy DataFrame.

        *split* must be "train", "val" or "test" (KeyError otherwise). Loss
        and metric are averaged over batches.

        Node DF columns:  snap_time, node, pred[0..], target[0..]
        Edge DF columns:  snap_time, src, dst,  pred[0..], target[0..]
        """
        loader = {"train": self.train_dl, "val": self.val_dl, "test": self.test_dl}[split]

        self.model.eval()
        total_loss = total_metric = 0.0
        records: List[Dict[str, Any]] = []

        for batch in loader:
            loss, metric_val, logits = self._forward_loss_and_metric(batch)
            total_loss += loss.item()
            total_metric += metric_val

            if return_df:
                if self.is_edge:
                    self._append_edge_records(records, batch, logits)
                else:
                    self._append_node_records(records, batch, logits)

        stats = {
            "loss": total_loss / max(1, len(loader)),
            f"{self.cfg.metric}": total_metric / max(1, len(loader)),
        }
        if not return_df:
            return stats

        df = pd.DataFrame.from_records(records)
        return stats, df

    @torch.no_grad()
    def predict(self, loader: Optional[DataLoader] = None) -> List[torch.Tensor]:
        """
        Forward-only predictions (list of CPU tensors, one per batch).

        ``self.test_dl`` is used only when *loader* is None. An explicit but
        empty loader yields an empty list rather than silently falling back
        to the test set.
        """
        if loader is None:
            loader = self.test_dl
        self.model.eval()
        out: List[torch.Tensor] = []
        for batch in loader:
            batch = batch.to(self.device)
            logits = self._forward_logits(batch)
            out.append(logits.detach().cpu())
        return out

    # ----------------------------- dataframe helpers

    @staticmethod
    def _ensure_batch(batch) -> Batch:
        # PyG sometimes hands a Data if batch_size==1; normalize to Batch API
        return batch if isinstance(batch, Batch) else Batch.from_data_list([batch])

    def _append_node_records(self, records: List[Dict[str, Any]], batch, logits: torch.Tensor) -> None:
        """Append one record per node: snap_time, per-graph node index, pred*, target*."""
        batch = self._ensure_batch(batch)
        # logits: [sum_N, T], target: [sum_N, T]; 1-D inputs are treated as T == 1
        pred = logits.detach().cpu()
        tgt = batch.y.detach().cpu()
        if pred.dim() == 1:
            pred = pred.unsqueeze(-1)
        if tgt.dim() == 1:
            tgt = tgt.unsqueeze(-1)

        n_tar = pred.size(-1)
        pred_keys = ["pred" + (str(t) if n_tar > 1 else "") for t in range(n_tar)]
        tgt_keys = ["target" + (str(t) if n_tar > 1 else "") for t in range(n_tar)]

        # Convert once to plain Python containers instead of calling .item()
        # on (possibly GPU-resident) tensors for every single value.
        pred_rows = pred.tolist()
        tgt_rows = tgt.tolist()
        snap_times = batch.snap_time.detach().cpu().tolist()  # [num_graphs]
        # number of nodes per graph can be derived from ptr diffs
        ptr = batch.ptr.detach().cpu().tolist()  # len = num_graphs+1

        for g, raw_ts in enumerate(snap_times):
            start, end = ptr[g], ptr[g + 1]
            ts = pd.Timestamp(int(raw_ts))
            for node in range(end - start):
                row = start + node
                rec: Dict[str, Any] = {"snap_time": ts, "node": node}
                for t in range(n_tar):
                    rec[pred_keys[t]] = float(pred_rows[row][t])
                    rec[tgt_keys[t]] = float(tgt_rows[row][t])
                records.append(rec)

    def _append_edge_records(self, records: List[Dict[str, Any]], batch, logits: torch.Tensor) -> None:
        """Append one record per edge: snap_time, per-graph src/dst index, pred*, target*."""
        batch = self._ensure_batch(batch)
        # logits: [sum_E] or [sum_E, T], edge_label [sum_E] or [sum_E, T]
        pred = logits.detach().cpu()
        tgt = batch.edge_label.detach().cpu()
        if tgt.dim() == 1:
            tgt = tgt.unsqueeze(-1)
        if pred.dim() == 1:
            pred = pred.unsqueeze(-1)

        n_tar = pred.size(-1)
        pred_keys = ["pred" + (str(t) if n_tar > 1 else "") for t in range(n_tar)]
        tgt_keys = ["target" + (str(t) if n_tar > 1 else "") for t in range(n_tar)]

        # Move index tensors to the CPU once; per-edge .item() calls on device
        # tensors would force a synchronisation for every edge.
        edge_index = batch.edge_index.detach().cpu()
        src = edge_index[0].tolist()
        dst = edge_index[1].tolist()
        # map edges to graphs via source node's graph id
        gids = batch.batch.detach().cpu()[edge_index[0]].tolist()  # graph id per edge
        ptr = batch.ptr.detach().cpu().tolist()                    # node-offset per graph
        snap_times = batch.snap_time.detach().cpu().tolist()       # [num_graphs]
        pred_rows = pred.tolist()
        tgt_rows = tgt.tolist()

        for i, pred_row in enumerate(pred_rows):  # over edges
            g = int(gids[i])
            node_off = int(ptr[g])
            ts = pd.Timestamp(int(snap_times[g]))
            src_local = int(src[i] - node_off)
            dst_local = int(dst[i] - node_off)

            rec: Dict[str, Any] = {"snap_time": ts, "src": src_local, "dst": dst_local}
            for t in range(n_tar):
                rec[pred_keys[t]] = float(pred_row[t])
                rec[tgt_keys[t]] = float(tgt_rows[i][t])
            records.append(rec)


In [None]:
Notes / rationale

Edge vs node: a single self.is_edge flag (from cfg.is_edge_task()) drives the branch everywhere.

Encoders/decoders: for edges we instantiate the node encoder (as_encoder=True), infer latent_dim by a dry forward, then build the decoder with build_decoder.

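Losses: build_loss now accepts "auto" weights for both node and edge tasks. In edge mode the trainer passes cfg.edge_class_weights; in node mode it passes cfg.class_weights. There is no automatic fallback from one to the other, so if you prefer edge tasks to reuse class_weights, that has to be added in the trainer.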

Sample weights: wired via sample_weight= (or node_weight for node tasks). If you later add edge_weight tensors to each Data, they’ll be picked up automatically.

Schedulers: supports step, plateau, cosine, onecycle (unchanged). plateau steps on val_loss.

Evaluate (tidy DF):

Node: outputs snap_time, node, pred*, target*.

Edge: outputs snap_time, src, dst, pred*, target*. We robustly map edges to their graph via batch.batch[src], and convert global node ids to per-graph indices using batch.ptr.

Multi-target edges: currently raises if T_edge > 1 (decoders return 1 logit). If you want multi-target, we can extend decoders to output out_dim=T_edge and remove that check.

Optional per-edge smoothing weights (the weight_smooth idea from your separate code) are not implemented yet. If you want them, we can add a small helper — for example inside build_dataloaders, or as a standalone compute_edge_weights(train_set) function — that computes an edge_weight tensor and attaches it to each snapshot after the split. The trainer will then consume those weights automatically.