In [None]:
# ======================================================================
# gnn_experiment.py
# ======================================================================
from __future__ import annotations

import json
import pathlib
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, Union, List

import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# ------------------------------------------------------------- optional
try:
    import yaml
except ImportError:   # YAML is optional; JSON works out-of-the-box
    yaml = None


# ----------------------------------------------------------------------
# 1. Configuration object
# ----------------------------------------------------------------------
@dataclass
class GNNConfig:
    """
    Flat hyper-parameter container for a GNN experiment.

    Construct directly, or via :meth:`GNNConfig.load` from an existing
    instance, a ``dict`` of field values, or a path to a JSON/YAML file.
    """

    # --- task / model --------------------------------------------------
    task: str            = "node_reg"               # {"node_reg", "node_clf", "edge_clf"}
    model_name: str      = "gcn"                    # {"gcn", "graphsage", "gat", "gatv2"}
    num_layers: int      = 2
    hidden_dim: int      = 128
    heads: int           = 8                        # for attention models
    norm: Optional[str]  = "batch"                  # {"batch", "layer", None}

    # --- dataset & target ---------------------------------------------
    target_col: str      = "target"                 # node column to predict

    # --- data splitting -----------------------------------------------
    split_mode: str      = "date"                   # {"date", "ratio"}
    cutoff_date: Optional[str] = None               # required if split_mode == "date"
    val_ratio: float     = 0.10                     # used if split_mode == "ratio"
    test_ratio: float    = 0.10                     # used if split_mode == "ratio"
    shuffle_in_split: bool = False                  # shuffle batches inside DataLoader?

    # --- optimisation --------------------------------------------------
    batch_size: int      = 32
    lr: float            = 1e-3
    epochs: int          = 100
    patience: int        = 10
    grad_clip: float     = 1.0

    # --- loss / optimiser ---------------------------------------------
    loss_fn: str         = "mse"                    # {"mse", "mae", "bce", "cross_entropy"}
    class_weights: Optional[list] = None            # for BCE / CE
    optimiser: str       = "adam"                   # {"adam", "adamw", "sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- misc ----------------------------------------------------------
    device: str          = "cuda"                   # fallback to cpu if unavailable
    run_name: str        = "default_run"
    seed: int            = 42

    # ------------------------------------------------------------------
    # helper constructor
    @staticmethod
    def load(cfg: Union["GNNConfig", str, pathlib.Path, Dict[str, Any]]) -> "GNNConfig":
        """
        Coerce *cfg* into a GNNConfig.

        Args:
            cfg: an existing GNNConfig (returned as-is, not copied), a dict of
                field values, or a path (``str`` / ``pathlib.Path``) to a
                ``.json``, ``.yml`` or ``.yaml`` file whose top level is a
                mapping. The suffix check is case-insensitive; an empty YAML
                document yields the default config.

        Raises:
            ValueError: unsupported file suffix, or a YAML file while PyYAML
                is not installed.
            TypeError: unsupported *cfg* type, file content that is not a
                mapping, or unknown field names.
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)
        if not isinstance(cfg, (str, pathlib.Path)):
            raise TypeError(f"Unsupported cfg type: {type(cfg)}")

        path = pathlib.Path(cfg)
        suffix = path.suffix.lower()
        if suffix == ".json":
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif suffix in {".yml", ".yaml"}:
            if yaml is None:
                # Still a ValueError, but say *why* the file cannot be read.
                raise ValueError(
                    f"Reading '{path.name}' requires PyYAML, which is not installed"
                )
            with open(path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        else:
            raise ValueError("Unsupported file type for config path")

        if data is None:                      # empty YAML document
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file '{path}' must contain a mapping, got {type(data).__name__}"
            )
        return GNNConfig(**data)

    # pretty-print helper
    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain (deep-copied) dict via ``dataclasses.asdict``."""
        return asdict(self)


# ----------------------------------------------------------------------
# 2. Core experiment wrapper
# ----------------------------------------------------------------------
class GNNExperiment:
    """
    End-to-end wrapper around:
       raw node/edge time-series  →  PyG snapshots  →  loaders  →  model

    Pipeline steps return ``self`` so they can be chained.
    """

    # ------------ constructor -------------------------------------------------
    def __init__(
        self,
        node_frames: Dict[str, pd.DataFrame],
        edge_frames: Dict[str, pd.DataFrame],
        graph: nx.DiGraph,
        cfg: Optional[Union[GNNConfig, str, Dict[str, Any]]] = None,
    ):
        """
        Args:
            node_frames: ``{node_name: DataFrame}`` indexed by timestamp; each
                row holds the node's features plus ``cfg.target_col``.
            edge_frames: ``{"src-dst" (lower-case): DataFrame}`` of edge
                features indexed by timestamp.
            graph: directed graph whose nodes and edges define the topology.
            cfg: GNNConfig, dict of fields, JSON/YAML path, or ``None`` for a
                fresh default GNNConfig.
        """
        # raw inputs
        self.node_frames = node_frames     # {"Tokyo": df, ...}
        self.edge_frames = edge_frames     # {"tokyo-chubu": df, ...}
        self.graph       = graph

        # config: build the default per instance instead of sharing one
        # GNNConfig() evaluated once when the method was defined
        self.cfg: GNNConfig = GNNConfig() if cfg is None else GNNConfig.load(cfg)

        # runtime placeholders
        self.reg_order:  List[str]     = []   # sorted node order
        self.edge_order: List[tuple]   = []   # (src_idx, dst_idx)
        self.snapshots   = None               # list[Data]

        self.train_dl = self.val_dl = self.test_dl = None

        self.model = None
        self.optimizer = None
        self.loss_fn = None
        self.history: Dict[str, list] = {}

        # device handling: fall back to CPU when CUDA is requested but missing
        self.device = torch.device(
            "cuda" if (self.cfg.device == "cuda" and torch.cuda.is_available()) else "cpu"
        )

    # =========================== helpers =====================================
    @staticmethod
    def _edge_key(src: str, dst: str) -> str:
        """Convert (src, dst) into the canonical lower-case ``'src-dst'`` key."""
        return f"{src}-{dst}".lower()

    @staticmethod
    def _ensure_same_index(idxs: List[pd.DatetimeIndex]) -> List[pd.Timestamp]:
        """
        Return the sorted intersection of all supplied indexes.

        Raises:
            ValueError: if *idxs* is empty or the indexes share no labels.
        """
        if not idxs:
            # set.intersection() with no arguments would raise an opaque TypeError
            raise ValueError("No DataFrames supplied to intersect timestamps over")
        common = sorted(set.intersection(*(set(i) for i in idxs)))
        if not common:
            raise ValueError("No common timestamps across supplied DataFrames")
        return common

    # =================== 1. snapshot creation ================================
    def prepare_snapshots(self) -> "GNNExperiment":
        """
        Build `torch_geometric.data.Data` objects for each shared timestamp.

        Each snapshot carries ``x`` (node features without the target column),
        ``y`` (targets), ``edge_index`` (shape [2, E]), ``edge_attr`` and
        ``snap_time`` (``pd.Timestamp.value``, int64 nanoseconds).
        Stores them in `self.snapshots` and returns self.
        """
        # --- deterministic node order ----------------------------------------
        self.reg_order = sorted(self.graph.nodes)
        node_pos = {n: i for i, n in enumerate(self.reg_order)}

        # --- deterministic edge order ----------------------------------------
        self.edge_order = [(node_pos[src], node_pos[dst])
                           for (src, dst) in self.graph.edges]

        # --- static topology: identical for every timestamp, build once ------
        # PyG expects edge_index of shape [2, E]; edge_order is a list of
        # (src, dst) pairs, i.e. [E, 2], so it must be transposed.
        # reshape(-1, 2) keeps the shape correct ([2, 0]) for an edgeless graph.
        edge_index = (
            torch.tensor(self.edge_order, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
        )
        edge_keys = [
            self._edge_key(self.reg_order[src_idx], self.reg_order[dst_idx])
            for src_idx, dst_idx in self.edge_order
        ]

        # --- intersect timestamps --------------------------------------------
        node_idxs = [df.index for df in self.node_frames.values()]
        edge_idxs = [df.index for df in self.edge_frames.values()]
        ts_common = self._ensure_same_index(node_idxs + edge_idxs)

        # --- construct snapshot objects --------------------------------------
        snapshots: List[Data] = []
        for ts in ts_common:
            # ---- node matrix & targets
            feats, tgts = [], []
            for region in self.reg_order:
                row = self.node_frames[region].loc[ts]
                tgts.append(row[self.cfg.target_col])
                feats.append(row.drop(self.cfg.target_col).to_numpy(dtype=np.float32))
            x = torch.tensor(np.vstack(feats), dtype=torch.float32)
            y = torch.tensor(tgts, dtype=torch.float32)

            # ---- edge attribute matrix (rows follow self.edge_order)
            edge_rows = [
                self.edge_frames[key].loc[ts].to_numpy(dtype=np.float32)
                for key in edge_keys
            ]
            edge_attr = torch.tensor(np.vstack(edge_rows), dtype=torch.float32)

            snapshots.append(
                Data(
                    x=x,
                    # clone so each snapshot owns its tensor, as before
                    edge_index=edge_index.clone(),
                    edge_attr=edge_attr,
                    y=y,
                    snap_time=torch.tensor([pd.Timestamp(ts).value]),
                )
            )

        self.snapshots = snapshots
        return self

    # =================== 2. build loaders ====================================
    def build_loaders(self) -> "GNNExperiment":
        """
        Create `train_dl`, `val_dl`, `test_dl` according to `self.cfg`.
        Must be called **after** `prepare_snapshots`.

        Raises:
            RuntimeError: if prepare_snapshots() has not run.
            ValueError: missing cutoff_date, nothing after the cutoff,
                ratios summing above 1, or an unknown split_mode.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() before build_loaders()")

        # ---- chronological sort (snap_time is 1-elem tensor) -----------------
        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())

        if self.cfg.split_mode == "date":
            if self.cfg.cutoff_date is None:
                raise ValueError("`cutoff_date` must be set in cfg for date split")

            cutoff_int = pd.Timestamp(self.cfg.cutoff_date).value
            train_set = [g for g in snaps_sorted if g.snap_time.item() <= cutoff_int]
            holdout   = [g for g in snaps_sorted if g.snap_time.item() >  cutoff_int]

            if not holdout:
                raise ValueError("No snapshots after cutoff_date for val/test split")

            # simple 50-50 val/test split on holdout (val is empty if only one
            # snapshot lies after the cutoff)
            mid = len(holdout) // 2
            val_set, test_set = holdout[:mid], holdout[mid:]

        elif self.cfg.split_mode == "ratio":
            n_total = len(snaps_sorted)
            n_test  = int(n_total * self.cfg.test_ratio)
            n_val   = int(n_total * self.cfg.val_ratio)
            n_train = n_total - n_val - n_test
            if n_train < 0:
                raise ValueError("val_ratio + test_ratio must not exceed 1")

            train_set = snaps_sorted[:n_train]
            val_set   = snaps_sorted[n_train:n_train + n_val]
            test_set  = snaps_sorted[n_train + n_val:]

        else:
            raise ValueError(f"Unknown split_mode '{self.cfg.split_mode}'")

        # ---- DataLoaders ------------------------------------------------------
        self.train_dl = DataLoader(
            train_set,
            batch_size=self.cfg.batch_size,
            shuffle=self.cfg.shuffle_in_split,
        )
        self.val_dl = DataLoader(
            val_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
        )
        self.test_dl = DataLoader(
            test_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
        )

        return self

    # =================== placeholders for later ==============================
    def init_model(self) -> "GNNExperiment":
        """Instantiate the requested GNN backbone + head (TODO)."""
        # TODO
        return self

    def compile(self) -> "GNNExperiment":
        """Attach loss fn, optimiser, schedulers, etc. (TODO)."""
        # TODO
        return self

    def train(self, debug: bool = False):
        """Main training loop (TODO)."""
        # TODO
        pass

    def evaluate(self, split: str = "val") -> Dict[str, float]:
        """Evaluate on a chosen split (TODO)."""
        # TODO
        return {}

    def predict(self, node_frames_new, edge_frames_new, timestamps):
        """Predict on unseen timestamps (TODO)."""
        # TODO
        pass

    def plot_history(self, metric: str = "loss"):
        """Matplotlib learning-curve plot (TODO)."""
        # TODO
        pass

    def save(self, run_dir: Union[str, pathlib.Path]):
        """Save config, weights, scaler, etc. (TODO)."""
        # TODO
        pass

    @classmethod
    def load(cls, run_dir: Union[str, pathlib.Path]) -> "GNNExperiment":
        """Reload a previously saved experiment (TODO)."""
        # TODO
        pass


### Model zoo 

In [None]:
class GNNModelSimple(torch.nn.Module):
    """
    Baseline: two GCNConv layers with ReLU, mean-pooled per graph, then a linear head.

    Because of ``global_mean_pool`` over ``data.batch`` the output has one row
    per graph in the batch, not one per node.
    NOTE(review): the config defaults to task="node_reg"; confirm a
    graph-level output is intended. GCNConv / global_mean_pool must already be
    imported from torch_geometric.nn when this cell runs.
    """

    def __init__(self, d_in, hidden, d_out):
        super().__init__()
        self.conv1 = GCNConv(in_channels=d_in, out_channels=hidden)
        self.conv2 = GCNConv(in_channels=hidden, out_channels=hidden)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        edges = data.edge_index
        h = self.conv1(data.x, edges).relu()
        h = self.conv2(h, edges).relu()
        pooled = global_mean_pool(h, data.batch)
        return self.fc(pooled)

class GNNConv(torch.nn.Module):
    """
    PyG ``GCN`` stack (batch norm, ReLU) followed by a linear head.

    No pooling is applied, so the output has one row per node.

    Args:
        d_in: input node-feature width.
        hidden: width of every GCN layer.
        d_out: output width of the linear head.
        layers: number of GCN layers.
    """

    def __init__(self, d_in, hidden, d_out, layers):
        super().__init__()
        self.body = GCN(in_channels=d_in, hidden_channels=hidden,
                        num_layers=layers, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        node_repr = self.body(data.x, data.edge_index)
        return self.fc(node_repr)

class GNNSage(torch.nn.Module):
    """
    PyG ``GraphSAGE`` stack (batch norm, ReLU) followed by a linear head.

    Node-level output (no pooling). Arguments mirror :class:`GNNConv`.
    """

    def __init__(self, d_in, hidden, d_out, layers):
        super().__init__()
        self.body = GraphSAGE(in_channels=d_in, hidden_channels=hidden,
                              num_layers=layers, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        node_repr = self.body(data.x, data.edge_index)
        return self.fc(node_repr)

class GNNGAT(torch.nn.Module):
    """
    PyG ``GAT`` stack with ``heads`` attention heads, batch norm and ReLU, plus a linear head.

    Node-level output (no pooling); edge features are not used.
    NOTE(review): PyG's GAT concatenates heads by default, which presumably
    requires ``hidden`` to be divisible by ``heads`` — confirm.
    """

    def __init__(self, d_in, hidden, d_out, layers, heads):
        super().__init__()
        self.body = GAT(in_channels=d_in, hidden_channels=hidden, num_layers=layers,
                        heads=heads, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        node_repr = self.body(data.x, data.edge_index)
        return self.fc(node_repr)

class GNNGAT2(torch.nn.Module):
    """
    Two GATv2Conv layers that consume edge features, mean-pooled per graph, then a linear head.

    Layer 1 concatenates ``heads`` heads (width ``hidden * heads``); layer 2 is
    single-head (width ``hidden``).
    NOTE(review): ``layers`` is accepted for signature parity with the other
    backbones but is unused — depth is fixed at two.
    """

    def __init__(self, d_in, hidden, d_out, layers, heads, edge_dim):
        super().__init__()
        self.conv1 = GATv2Conv(d_in, hidden, heads=heads, concat=True, edge_dim=edge_dim)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1, concat=False, edge_dim=edge_dim)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        edges, feats = data.edge_index, data.edge_attr
        h = torch.relu(self.conv1(data.x, edges, feats))
        h = torch.relu(self.conv2(h, edges, feats))
        graph_repr = global_mean_pool(h, data.batch)
        return self.fc(graph_repr)

### Initialize Model 

In [None]:
# 3.  EXPERIMENT  (unchanged until indicated)
class GNNExperiment:
    """
    Model-factory fragment of the experiment wrapper.

    NOTE(review): running this cell rebinds ``GNNExperiment`` to a class that
    defines only ``init_model``; the constructor, ``prepare_snapshots`` and
    ``build_loaders`` from the earlier cell are not carried over, yet
    ``init_model`` reads ``self.snapshots``, ``self.cfg`` and ``self.device``.
    """

    def init_model(self) -> "GNNExperiment":
        """
        Instantiate the backbone named by ``cfg.model_name`` and move it to ``self.device``.

        Input width and edge-feature width come from the first snapshot; the
        head has a single output channel.

        Raises:
            RuntimeError: if ``prepare_snapshots()`` has not been run.
            ValueError: unknown model name, or GATv2 without edge features.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() first.")

        first = self.snapshots[0]
        in_dim = first.x.shape[1]
        out_dim = 1                                  # can be generalised later
        edge_feat_dim = None if first.edge_attr is None else first.edge_attr.shape[1]

        cfg = self.cfg
        key = cfg.model_name.lower()

        if key in ("simple", "baseline"):
            net = GNNModelSimple(in_dim, cfg.hidden_dim, out_dim)
        elif key == "gcn":
            net = GNNConv(in_dim, cfg.hidden_dim, out_dim, cfg.num_layers)
        elif key in ("graphsage", "sage"):
            net = GNNSage(in_dim, cfg.hidden_dim, out_dim, cfg.num_layers)
        elif key == "gat":
            net = GNNGAT(in_dim, cfg.hidden_dim, out_dim, cfg.num_layers, cfg.heads)
        elif key in ("gatv2", "gat2"):
            if edge_feat_dim is None:
                raise ValueError("GATv2 selected but snapshots contain no edge features.")
            net = GNNGAT2(in_dim, cfg.hidden_dim, out_dim,
                          cfg.num_layers, cfg.heads, edge_feat_dim)
        else:
            raise ValueError(f"Unknown model_name '{self.cfg.model_name}'")

        self.model = net.to(self.device)
        return self

### Compile 

In [None]:
def compile(self) -> "GNNExperiment":
    """
    Attach a loss criterion and an optimiser to an experiment object.

    Reads ``self.cfg.loss_fn`` / ``self.cfg.optimiser`` (case-insensitive) and
    stores the results on ``self.loss_fn`` / ``self.optimizer``. Must be
    called **after** init_model().

    NOTE(review): defined at module level, so it shadows the builtin
    ``compile`` here; presumably meant to be attached to GNNExperiment.

    Returns:
        ``self``, for chaining.

    Raises:
        RuntimeError: if ``self.model`` is None.
        ValueError: for an unrecognised loss or optimiser name.
    """
    if self.model is None:
        raise RuntimeError("Call init_model() before compile()")

    cfg = self.cfg

    def _class_weights():
        # empty / missing class_weights mean "unweighted"
        if not cfg.class_weights:
            return None
        return torch.tensor(cfg.class_weights, dtype=torch.float32, device=self.device)

    # ---------- LOSS ------------------------------------------------
    loss_builders = {
        "mse": torch.nn.MSELoss,
        "mae": torch.nn.L1Loss,
        "l1": torch.nn.L1Loss,
        "bce": lambda: torch.nn.BCEWithLogitsLoss(pos_weight=_class_weights()),
        "cross_entropy": lambda: torch.nn.CrossEntropyLoss(weight=_class_weights()),
        "ce": lambda: torch.nn.CrossEntropyLoss(weight=_class_weights()),
    }
    build_loss = loss_builders.get(cfg.loss_fn.lower())
    if build_loss is None:
        raise ValueError(f"Unknown loss_fn '{self.cfg.loss_fn}'")
    self.loss_fn = build_loss()

    # ---------- OPTIMISER ------------------------------------------
    opt_kwargs = dict(lr=cfg.lr, **cfg.optimiser_kwargs)
    optimiser_classes = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
    }
    opt_cls = optimiser_classes.get(cfg.optimiser.lower())
    if opt_cls is None:
        raise ValueError(f"Unknown optimiser '{self.cfg.optimiser}'")
    self.optimizer = opt_cls(self.model.parameters(), **opt_kwargs)

    # schedulers or other callbacks could be attached here later
    return self

# ---------- train / evaluate / etc. to be filled later ----------
# NOTE(review): these module-level stubs take ``self`` and are presumably meant
# to be attached to GNNExperiment; each currently does nothing.
def train(self, debug: bool = False):
    """Placeholder training loop; returns None."""
    return None


def evaluate(self, split: str = "val") -> Dict[str, float]:
    """Placeholder evaluation hook; returns an empty metrics dict."""
    return dict()


def predict(self, node_frames_new, edge_frames_new, timestamps):
    """Placeholder prediction hook; returns None."""
    return None


def plot_history(self, metric: str = "loss"):
    """Placeholder learning-curve plot; returns None."""
    return None


def save(self, run_dir: Union[str, pathlib.Path]):
    """Placeholder persistence hook; returns None."""
    return None


# NOTE(review): ``@classmethod`` on a module-level function produces a
# ``classmethod`` object, which is not callable until attached to a class.
@classmethod
def load(cls, run_dir: Union[str, pathlib.Path]) -> "GNNExperiment":
    """Placeholder for reloading a saved experiment; returns None."""
    return None

### Extra metric utilities

In [None]:
# 0.  SMALL UTILS
# ----------------------------------------------------------------------
def mean_absolute_error(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Return mean(|preds - targets|) as a Python float (broadcasting applies)."""
    abs_err = (preds - targets).abs()
    return abs_err.mean().item()


In [None]:
# ======================================================================
# gnn_experiment.py
# ======================================================================
from __future__ import annotations
import json, pathlib, numpy as np, pandas as pd, torch, networkx as nx
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, Union, List
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    GCNConv, GATv2Conv, GCN, GraphSAGE, GAT, global_mean_pool
)

try:  import yaml
except ImportError: yaml = None


# ----------------------------------------------------------------------
# 0.  SMALL UTILS
# ----------------------------------------------------------------------
def mean_absolute_error(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Mean absolute error between two broadcast-compatible tensors, as a Python float."""
    return torch.sub(preds, targets).abs().mean().item()

def has_nan(t: torch.Tensor) -> bool:
    """
    Return True if *t* contains any NaN or +/-inf entry.

    Despite the name, infinities count as well. The result is a plain Python
    ``bool`` rather than a 0-d tensor, matching the annotation.
    """
    # isfinite is False exactly for NaN and +/-inf, so one pass covers both
    return bool((~torch.isfinite(t)).any())


# ----------------------------------------------------------------------
# 1.  CONFIG
# ----------------------------------------------------------------------
@dataclass
class GNNConfig:
    """
    Flat hyper-parameter container for a GNN experiment.

    Construct directly, or via :meth:`GNNConfig.load` from an existing
    instance, a ``dict`` of field values, or a path to a JSON/YAML file.
    """

    # --- task / model --------------------------------------------------
    task: str            = "node_reg"              # {"node_reg", "node_clf", "edge_clf"}
    model_name: str      = "gcn"                   # {"simple","gcn","graphsage","gat","gatv2"}
    num_layers: int      = 2
    hidden_dim: int      = 128
    heads: int           = 8
    norm: Optional[str]  = "batch"

    # --- dataset & target ---------------------------------------------
    target_col: str      = "target"

    # --- splitting -----------------------------------------------------
    split_mode: str      = "date"                  # {"date","ratio"}
    cutoff_date: Optional[str] = None
    val_ratio: float     = 0.10
    test_ratio: float    = 0.10
    shuffle_in_split: bool = False

    # --- optimisation --------------------------------------------------
    batch_size: int      = 32
    lr: float            = 1e-3
    epochs: int          = 100
    patience: int        = 10
    grad_clip: float     = 1.0

    # --- loss / optimiser ---------------------------------------------
    loss_fn: str         = "mse"                   # {"mse","mae","bce","cross_entropy"}
    class_weights: Optional[list] = None
    optimiser: str       = "adam"                  # {"adam","adamw","sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- misc ----------------------------------------------------------
    device: str          = "cuda"
    run_name: str        = "default_run"
    seed: int            = 42

    # ---------- helper constructor ----------
    @staticmethod
    def load(cfg: Union["GNNConfig", str, pathlib.Path, Dict[str, Any]]) -> "GNNConfig":
        """
        Coerce *cfg* into a GNNConfig.

        Args:
            cfg: a GNNConfig (returned as-is, not copied), a dict of field
                values, or a path to a config file. A ``.json`` suffix
                (case-insensitive) is parsed as JSON; any other suffix is
                parsed as YAML. An empty YAML document yields the defaults.

        Raises:
            RuntimeError: a YAML file is given but PyYAML is not installed.
            TypeError: *cfg* is not path-like, the file does not contain a
                mapping, or it contains unknown field names.
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)

        path = pathlib.Path(cfg)          # raises TypeError for non path-like input
        with open(path, "r", encoding="utf-8") as f:
            if path.suffix.lower() == ".json":
                data = json.load(f)
            else:
                if yaml is None:
                    raise RuntimeError("PyYAML not installed")
                data = yaml.safe_load(f)

        if data is None:                  # empty YAML document
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file '{path}' must contain a mapping, got {type(data).__name__}"
            )
        return GNNConfig(**data)

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain (deep-copied) dict via ``dataclasses.asdict``."""
        return asdict(self)


# ----------------------------------------------------------------------
# 2.  MODEL ZOO  (shortened comments)
# ----------------------------------------------------------------------
class GNNModelSimple(torch.nn.Module):
    """
    Two GCNConv layers (ReLU), mean-pooled over ``data.batch``, then a linear head.

    Output has one row per graph in the batch.
    NOTE(review): the config defaults to task="node_reg"; confirm a
    graph-level output is intended.
    """

    def __init__(self, d_in, hidden, d_out):
        super().__init__()
        self.conv1 = GCNConv(in_channels=d_in, out_channels=hidden)
        self.conv2 = GCNConv(in_channels=hidden, out_channels=hidden)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        adjacency = data.edge_index
        hidden_state = torch.relu(self.conv1(data.x, adjacency))
        hidden_state = torch.relu(self.conv2(hidden_state, adjacency))
        per_graph = global_mean_pool(hidden_state, data.batch)
        return self.fc(per_graph)

class GNNConv(torch.nn.Module):
    """
    PyG ``GCN`` stack (batch norm, ReLU) plus a linear head; node-level output.

    Args:
        d_in: input node-feature width.
        hidden: width of every GCN layer.
        d_out: output width of the head.
        layers: number of GCN layers.
    """

    def __init__(self, d_in, hidden, d_out, layers):
        super().__init__()
        self.body = GCN(in_channels=d_in, hidden_channels=hidden,
                        num_layers=layers, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(embeddings)

class GNNSage(torch.nn.Module):
    """PyG ``GraphSAGE`` stack (batch norm, ReLU) plus a linear head; node-level output."""

    def __init__(self, d_in, hidden, d_out, layers):
        super().__init__()
        self.body = GraphSAGE(in_channels=d_in, hidden_channels=hidden,
                              num_layers=layers, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(embeddings)

class GNNGAT(torch.nn.Module):
    """
    PyG ``GAT`` stack (``heads`` heads, batch norm, ReLU) plus a linear head; node-level output.

    Edge features are not used.
    NOTE(review): PyG's GAT concatenates heads by default, which presumably
    requires ``hidden`` to be divisible by ``heads`` — confirm.
    """

    def __init__(self, d_in, hidden, d_out, layers, heads):
        super().__init__()
        self.body = GAT(in_channels=d_in, hidden_channels=hidden, num_layers=layers,
                        heads=heads, norm='batch', act='relu')
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        embeddings = self.body(data.x, data.edge_index)
        return self.fc(embeddings)

class GNNGAT2(torch.nn.Module):
    """
    Two edge-aware GATv2Conv layers, mean-pooled per graph, then a linear head.

    Layer 1 concatenates ``heads`` heads (width ``hidden * heads``); layer 2
    uses one head (width ``hidden``).
    NOTE(review): ``layers`` is unused — the depth is fixed at two.
    """

    def __init__(self, d_in, hidden, d_out, layers, heads, edge_dim):
        super().__init__()
        self.conv1 = GATv2Conv(d_in, hidden, heads=heads, concat=True, edge_dim=edge_dim)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1, concat=False, edge_dim=edge_dim)
        self.fc = torch.nn.Linear(hidden, d_out)

    def forward(self, data):
        ei, ea = data.edge_index, data.edge_attr
        h = self.conv1(data.x, ei, ea).relu()
        h = self.conv2(h, ei, ea).relu()
        return self.fc(global_mean_pool(h, data.batch))


# ----------------------------------------------------------------------
# 3.  EXPERIMENT (prepare_snapshots + loaders + model/compile as before)
# ----------------------------------------------------------------------
class GNNExperiment:
    """
    End-to-end experiment wrapper:
    raw node/edge time-series -> PyG snapshots -> loaders -> model -> train/eval.

    Pipeline steps return ``self`` so they can be chained::

        exp.prepare_snapshots().build_loaders().init_model().compile().train()
    """

    # ------- constructor --------------------------------------------------
    def __init__(self, node_frames, edge_frames, graph, cfg=None):
        """
        Args:
            node_frames: ``{node_name: DataFrame}`` indexed by timestamp; each
                row holds the node's features plus ``cfg.target_col``.
            edge_frames: ``{"src-dst" (lower-case): DataFrame}`` of edge
                features indexed by timestamp.
            graph: object exposing ``.nodes`` and ``.edges`` (e.g. a networkx
                graph); defines node order and topology.
            cfg: GNNConfig, dict of fields, JSON/YAML path, or ``None`` for a
                fresh default GNNConfig.
        """
        self.node_frames, self.edge_frames, self.graph = node_frames, edge_frames, graph
        # A new default per instance: a shared ``cfg=GNNConfig()`` default would
        # be mutated for every later experiment by update_config().
        self.cfg: GNNConfig = GNNConfig() if cfg is None else GNNConfig.load(cfg)

        self.reg_order:  List[str]   = []      # sorted node names
        self.edge_order: List[tuple] = []      # (src_idx, dst_idx) per graph edge
        self.snapshots = None                  # list[Data] after prepare_snapshots()

        self.train_dl = self.val_dl = self.test_dl = None
        self.model = self.optimizer = self.loss_fn = None
        self.metric_fn = mean_absolute_error
        self.history: Dict[str, list] = {"train_loss": [], "val_loss": [], "val_metric": []}
        self.best_val_loss = float("inf")
        self.patience_counter = 0
        self.best_ckpt = None                  # CPU copy of the best state_dict

        # fall back to CPU when CUDA is requested but unavailable
        self.device = torch.device(
            "cuda" if (self.cfg.device == "cuda" and torch.cuda.is_available()) else "cpu"
        )

    # ------- helpers ------------------------------------------------------
    def _edge_key(self, s, d):
        """Canonical ``edge_frames`` key for edge (s, d): ``"s-d"`` lower-cased."""
        return f"{s}-{d}".lower()

    @staticmethod
    def _ensure_same_index(idxs):
        """
        Return the sorted intersection of the supplied indexes.

        Raises:
            ValueError: if *idxs* is empty or the indexes share no labels.
        """
        if not idxs:
            # set.intersection() with no arguments would raise an opaque TypeError
            raise ValueError("No DataFrames supplied to intersect timestamps over")
        common = sorted(set.intersection(*(set(i) for i in idxs)))
        if not common:
            raise ValueError("No common timestamps across supplied DataFrames")
        return common

    def _move_batch(self, data):
        """Move a batch to ``self.device`` and cast node/edge features to float32."""
        data = data.to(self.device)
        data.x = data.x.float()
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr.float()
        return data

    # ------- 1. snapshots -------------------------------------------------
    def prepare_snapshots(self):
        """
        Build one ``torch_geometric.data.Data`` per timestamp shared by all frames.

        Each snapshot carries ``x`` (node features without the target column,
        rows in ``self.reg_order``), ``y`` (targets), ``edge_index`` (shape
        [2, E]), ``edge_attr`` (rows in ``self.edge_order``) and ``snap_time``
        (``pd.Timestamp.value``, int64 nanoseconds). Stores the list in
        ``self.snapshots`` and returns self.
        """
        self.reg_order = sorted(self.graph.nodes)
        pos = {n: i for i, n in enumerate(self.reg_order)}
        self.edge_order = [(pos[s], pos[d]) for (s, d) in self.graph.edges]

        ts_common = self._ensure_same_index(
            [df.index for df in self.node_frames.values()]
            + [df.index for df in self.edge_frames.values()]
        )

        # Static topology, built once. PyG expects edge_index of shape [2, E];
        # edge_order is a list of (src, dst) pairs ([E, 2]), so transpose it.
        # reshape(-1, 2) keeps an edgeless graph at shape [2, 0].
        edge_index = (
            torch.tensor(self.edge_order, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
        )
        edge_keys = [
            self._edge_key(self.reg_order[s_i], self.reg_order[d_i])
            for s_i, d_i in self.edge_order
        ]

        snaps = []
        for ts in ts_common:
            feats, tgts = [], []
            for region in self.reg_order:
                row = self.node_frames[region].loc[ts]
                tgts.append(row[self.cfg.target_col])
                feats.append(row.drop(self.cfg.target_col).to_numpy(dtype=np.float32))
            x = torch.tensor(np.vstack(feats), dtype=torch.float32)
            y = torch.tensor(tgts, dtype=torch.float32)

            edge_rows = [self.edge_frames[k].loc[ts].to_numpy(np.float32) for k in edge_keys]
            edge_attr = torch.tensor(np.vstack(edge_rows), dtype=torch.float32)

            snaps.append(
                Data(x=x, edge_index=edge_index.clone(), edge_attr=edge_attr,
                     y=y, snap_time=torch.tensor([pd.Timestamp(ts).value]))
            )
        self.snapshots = snaps
        return self

    # ------- 2. loaders ---------------------------------------------------
    def build_loaders(self):
        """
        Split snapshots chronologically into train/val/test DataLoaders.

        ``split_mode="date"``: snapshots at or before ``cutoff_date`` train;
        the rest are halved into val (earlier) and test (later).
        ``split_mode="ratio"``: the last ``test_ratio`` fraction is test, the
        preceding ``val_ratio`` fraction is val.

        Raises:
            RuntimeError: if prepare_snapshots() has not run.
            ValueError: missing cutoff_date, nothing after the cutoff, ratios
                summing above 1, or an unknown split_mode.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() first")

        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())
        if self.cfg.split_mode == "date":
            if self.cfg.cutoff_date is None:
                # pd.Timestamp(None) is NaT, which would silently send every
                # snapshot to the hold-out set
                raise ValueError("`cutoff_date` must be set in cfg for date split")
            cut = pd.Timestamp(self.cfg.cutoff_date).value
            train_set = [g for g in snaps_sorted if g.snap_time.item() <= cut]
            hold      = [g for g in snaps_sorted if g.snap_time.item() >  cut]
            if not hold:
                raise ValueError("No snapshots after cutoff_date for val/test split")
            mid = len(hold) // 2
            val_set, test_set = hold[:mid], hold[mid:]
        elif self.cfg.split_mode == "ratio":
            n = len(snaps_sorted)
            n_test = int(n * self.cfg.test_ratio)
            n_val  = int(n * self.cfg.val_ratio)
            n_train = n - n_val - n_test
            if n_train < 0:
                raise ValueError("val_ratio + test_ratio must not exceed 1")
            train_set = snaps_sorted[:n_train]
            val_set   = snaps_sorted[n_train:n_train+n_val]
            test_set  = snaps_sorted[n_train+n_val:]
        else:
            raise ValueError(f"Unknown split_mode '{self.cfg.split_mode}'")

        self.train_dl = DataLoader(train_set, batch_size=self.cfg.batch_size,
                                   shuffle=self.cfg.shuffle_in_split)
        self.val_dl   = DataLoader(val_set,   batch_size=self.cfg.batch_size)
        self.test_dl  = DataLoader(test_set,  batch_size=self.cfg.batch_size)
        return self

    def update_config(self, **kwargs):
        """
        Set GNNConfig fields in place and return self.

        All keys are validated before any is applied. Already-built loaders,
        models and optimisers are NOT rebuilt; re-run those steps afterwards.

        Raises:
            AttributeError: if a key is not a GNNConfig field.
        """
        from dataclasses import fields

        valid = {f.name for f in fields(self.cfg)}
        for k in kwargs:
            if k not in valid:
                raise AttributeError(k)
        for k, v in kwargs.items():
            setattr(self.cfg, k, v)
        return self

    # ------- 3. model / compile -------------------------------------------
    def init_model(self):
        """
        Instantiate the backbone named by ``cfg.model_name`` on ``self.device``.

        Input and edge-feature widths come from the first snapshot; the head
        has a single output channel.

        Raises:
            RuntimeError: if no snapshots are available.
            ValueError: unknown model name, or GATv2 without edge features.
        """
        if not self.snapshots:
            raise RuntimeError("No snapshots available; call prepare_snapshots() first")

        first = self.snapshots[0]
        d_in  = first.x.shape[1]
        d_out = 1
        edge_dim = first.edge_attr.shape[1] if first.edge_attr is not None else None

        name = self.cfg.model_name.lower()
        if name in {"simple", "baseline"}:
            model = GNNModelSimple(d_in, self.cfg.hidden_dim, d_out)
        elif name == "gcn":
            model = GNNConv(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers)
        elif name in {"graphsage", "sage"}:
            model = GNNSage(d_in, self.cfg.hidden_dim, d_out, self.cfg.num_layers)
        elif name == "gat":
            model = GNNGAT(d_in, self.cfg.hidden_dim, d_out,
                           self.cfg.num_layers, self.cfg.heads)
        elif name in {"gatv2", "gat2"}:
            if edge_dim is None:
                raise ValueError("GATv2 selected but snapshots contain no edge features.")
            model = GNNGAT2(d_in, self.cfg.hidden_dim, d_out,
                            self.cfg.num_layers, self.cfg.heads, edge_dim)
        else:
            raise ValueError(f"Unknown model '{self.cfg.model_name}'")

        self.model = model.to(self.device)
        return self

    def compile(self):
        """
        Attach the loss criterion and optimiser named in ``self.cfg``; return self.

        Raises:
            RuntimeError: if init_model() has not run.
            ValueError: unknown loss or optimiser name.
        """
        if self.model is None:
            raise RuntimeError("Call init_model() before compile()")

        loss = self.cfg.loss_fn.lower()
        if loss == "mse":
            self.loss_fn = torch.nn.MSELoss()
        elif loss in {"mae", "l1"}:
            self.loss_fn = torch.nn.L1Loss()
        elif loss == "bce":
            w = (torch.tensor(self.cfg.class_weights, dtype=torch.float32, device=self.device)
                 if self.cfg.class_weights else None)
            self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=w)
        elif loss in {"cross_entropy", "ce"}:
            w = (torch.tensor(self.cfg.class_weights, dtype=torch.float32, device=self.device)
                 if self.cfg.class_weights else None)
            self.loss_fn = torch.nn.CrossEntropyLoss(weight=w)
        else:
            raise ValueError(f"Unknown loss_fn '{self.cfg.loss_fn}'")

        opt_name = self.cfg.optimiser.lower()
        opt_kwargs = dict(lr=self.cfg.lr, **self.cfg.optimiser_kwargs)
        if opt_name == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), **opt_kwargs)
        elif opt_name == "adamw":
            self.optimizer = torch.optim.AdamW(self.model.parameters(), **opt_kwargs)
        elif opt_name == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), **opt_kwargs)
        else:
            raise ValueError(f"Unknown optimiser '{self.cfg.optimiser}'")

        return self

    # ==================================================================
    # ### 4. TRAIN / EVALUATE ###########################################
    # ==================================================================
    def _train_epoch(self, debug: bool = False) -> float:
        """
        Run one optimisation pass over ``train_dl``; return the mean batch loss.

        Raises:
            ValueError: if the training loader has no batches.
            RuntimeError: with ``debug=True``, on a NaN/inf loss or output.
        """
        n_batches = len(self.train_dl)
        if n_batches == 0:
            raise ValueError("Training split is empty; check the split configuration")

        self.model.train()
        total = 0.0
        for step, data in enumerate(self.train_dl):
            data = self._move_batch(data)

            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.loss_fn(out.squeeze(-1), data.y)
            if debug and (has_nan(loss) or has_nan(out)):
                raise RuntimeError(f"NaN detected at step {step}")

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
            self.optimizer.step()
            total += loss.item()
        return total / n_batches

    @torch.no_grad()
    def _eval_loader(self, loader: DataLoader) -> tuple[float, float]:
        """
        Return ``(mean loss, mean metric)`` over *loader* in eval mode.

        Both are unweighted averages over batches.

        Raises:
            ValueError: if *loader* has no batches.
        """
        n = len(loader)
        if n == 0:
            raise ValueError("Cannot evaluate on an empty split; check the split configuration")

        self.model.eval()
        total_loss = total_metric = 0.0
        for data in loader:
            data = self._move_batch(data)
            out = self.model(data)
            loss = self.loss_fn(out.squeeze(-1), data.y)
            metric = self.metric_fn(out.squeeze(-1).cpu(), data.y.cpu())
            total_loss   += loss.item()
            total_metric += metric
        return total_loss / n, total_metric / n

    def train(self, debug: bool = False, restore_best: bool = True):
        """
        Fit the model with early stopping on validation loss.

        Each epoch appends to ``self.history``. When validation loss improves,
        a CPU copy of the weights is stored in ``self.best_ckpt``; training
        stops after ``cfg.patience`` epochs without improvement.

        Args:
            debug: raise as soon as a NaN/inf loss or output appears.
            restore_best: load ``best_ckpt`` back into the model when training
                ends, so later evaluation uses the best weights.

        Returns:
            self

        Raises:
            RuntimeError: if init_model()/compile()/build_loaders() were skipped.
        """
        if any(v is None for v in (self.model, self.optimizer, self.loss_fn)):
            raise RuntimeError("Call init_model() and compile() before train()")
        if self.train_dl is None or self.val_dl is None:
            raise RuntimeError("Call build_loaders() before train()")

        for epoch in range(1, self.cfg.epochs + 1):
            tr_loss = self._train_epoch(debug)
            val_loss, val_metric = self._eval_loader(self.val_dl)

            self.history["train_loss"].append(tr_loss)
            self.history["val_loss"].append(val_loss)
            self.history["val_metric"].append(val_metric)

            print(f"[{epoch:03d}/{self.cfg.epochs}] "
                  f"train={tr_loss:.4f}  val={val_loss:.4f}  metric={val_metric:.4f}")

            # ----- early stopping -----
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                # copy=True is essential: for a CPU model ``v.cpu()`` returns the
                # live tensor, so later optimiser steps would overwrite the
                # "best" checkpoint in place.
                self.best_ckpt = {k: v.detach().to("cpu", copy=True)
                                  for k, v in self.model.state_dict().items()}
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.cfg.patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

        if restore_best and self.best_ckpt is not None:
            self.model.load_state_dict(self.best_ckpt)
        return self

    def evaluate(self, split: str = "test") -> Dict[str, float]:
        """
        Return ``{"loss": ..., "metric": ...}`` on the chosen split.

        Raises:
            ValueError: *split* is not one of train/val/test.
            RuntimeError: build_loaders() has not been run.
        """
        loaders = {"train": self.train_dl, "val": self.val_dl, "test": self.test_dl}
        if split not in loaders:
            raise ValueError("split must be one of {'train','val','test'}")
        loader = loaders[split]
        if loader is None:
            raise RuntimeError("Call build_loaders() before evaluate()")

        loss, metric = self._eval_loader(loader)
        return {"loss": loss, "metric": metric}

    # -------- stubs for predict / save / load / plot (later) ----------
    def predict(self, node_frames_new, edge_frames_new, timestamps): ...
    def plot_history(self, metric: str = "loss"): ...
    def save(self, run_dir): ...
    @classmethod
    def load(cls, run_dir): ...


# ======================================================================
# USAGE EXAMPLE (not executed here)
# ======================================================================
# exp = (GNNExperiment(nodes_dict, edges_dict, G, {"model_name":"gatv2", "cutoff_date":"2023-06-01"})
#        .prepare_snapshots()
#        .build_loaders()
#        .init_model()
#        .compile()
#        .train(debug=False))
#
# print(exp.evaluate("test"))


### Predictions and plotting 

In [None]:
# =====================================================================
# inside class GNNExperiment  (add/replace these two methods)
# =====================================================================
# --------------------------------------------------
# PREDICT on *unseen* data
# --------------------------------------------------
@torch.no_grad()
def predict(
    self,
    node_frames_new: Dict[str, pd.DataFrame],
    edge_frames_new: Dict[str, pd.DataFrame],
    timestamps: Optional[List[pd.Timestamp]] = None,
    return_df: bool = True,
):
    """
    Forward-pass the fitted model on future (or out-of-sample) snapshots.

    Snapshots are rebuilt using the node order ``self.reg_order`` and edge
    order ``self.edge_order``, so output column i always refers to region
    ``self.reg_order[i]``.

    Args
    ----
    node_frames_new : dict[str → pd.DataFrame]
        Same structure / columns as training node_frames: one frame per
        region name in ``self.reg_order``, indexed by timestamp. The
        ``cfg.target_col`` column is dropped before scoring.
    edge_frames_new : dict[str → pd.DataFrame]
        Same structure / columns as training edge_frames, keyed by the
        lower-cased ``"<src>-<dst>"`` region pair.
    timestamps      : iterable of timestamps (optional)
        Anything ``pd.Timestamp`` accepts. If None, use the sorted
        *intersection* of all node/edge frame indices.
    return_df       : if True, returns a tidy DataFrame; else raw torch.Tensor.

    Returns
    -------
    pd.DataFrame (index = timestamps, columns = self.reg_order) or
    torch.Tensor of shape [len(timestamps), len(self.reg_order)].
    An empty timestamp selection yields an empty frame/tensor.

    Raises
    ------
    RuntimeError : no model has been trained or loaded.
    ValueError   : timestamps is None and no frames were supplied.
    """
    if self.model is None:
        raise RuntimeError("Train or load a model before calling predict()")

    n_nodes = len(self.reg_order)

    # 1. decide which timestamps to score ---------------------------------
    if timestamps is None:
        idxs = [df.index for df in node_frames_new.values()] + \
               [df.index for df in edge_frames_new.values()]
        if not idxs:
            # set.intersection() with zero arguments raises an opaque TypeError
            raise ValueError(
                "timestamps=None needs at least one node or edge frame to "
                "infer the timestamps from"
            )
        timestamps = sorted(set.intersection(*(set(ix) for ix in idxs)))
    else:
        timestamps = [pd.Timestamp(ts) for ts in timestamps]

    if not timestamps:
        # Nothing to score: the DataLoader would yield no batches and
        # torch.cat([]) would raise, so return an empty result instead.
        y_empty = torch.empty((0, n_nodes), dtype=torch.float32)
        if return_df:
            return pd.DataFrame(
                y_empty.numpy(),
                index=pd.to_datetime(timestamps),
                columns=self.reg_order,
            )
        return y_empty

    # 2. build snapshots *in the same node/edge order* --------------------
    # Topology is identical for every timestamp, so edge_index and the
    # edge-frame keys are built once. reshape(-1, 2) keeps a valid [2, 0]
    # edge_index for a graph without edges.
    edge_index = (
        torch.tensor(self.edge_order, dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )
    edge_keys = [
        f"{self.reg_order[s_idx]}-{self.reg_order[d_idx]}".lower()
        for s_idx, d_idx in self.edge_order
    ]

    snaps = []
    for ts in timestamps:
        # one feature row per region, target column removed
        x = torch.tensor(
            np.vstack([
                node_frames_new[region].loc[ts]
                .drop(self.cfg.target_col)
                .to_numpy(dtype=np.float32)
                for region in self.reg_order
            ]),
            dtype=torch.float32,
        )
        if edge_keys:
            edge_attr = torch.tensor(
                np.vstack([
                    edge_frames_new[key].loc[ts].to_numpy(np.float32)
                    for key in edge_keys
                ]),
                dtype=torch.float32,
            )
        else:
            edge_attr = None  # np.vstack([]) would raise for an edgeless graph
        snaps.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr))

    loader = DataLoader(snaps, batch_size=self.cfg.batch_size, shuffle=False)

    # 3. forward pass ------------------------------------------------------
    was_training = self.model.training
    self.model.eval()
    preds = []
    try:
        for data in loader:
            data = data.to(self.device)
            data.x = data.x.float()
            if data.edge_attr is not None:
                data.edge_attr = data.edge_attr.float()
            preds.append(self.model(data).squeeze(-1).cpu())
    finally:
        # predict() should not permanently switch the model into eval mode
        self.model.train(was_training)

    # The DataLoader concatenates snapshots node-wise -> [len(timestamps) * N].
    # This assumes the model emits one scalar per node.
    y_hat = torch.cat(preds, dim=0).reshape(len(timestamps), n_nodes)

    if return_df:
        return pd.DataFrame(
            y_hat.numpy(),
            index=pd.to_datetime(timestamps),
            columns=self.reg_order,
        )
    return y_hat                                     # raw tensor


# --------------------------------------------------
# PLOT learning curves
# --------------------------------------------------
def plot_history(self, metric: str = "loss"):
    """
    Plot training vs validation curves stored in ``self.history``.

    metric : "loss"                        -> train_loss and val_loss per epoch
             "metric" | "mae" | "accuracy" -> val_metric per epoch (the three
                                              aliases plot the same series and
                                              only change the labels)

    Raises
    ------
    RuntimeError : ``history["train_loss"]`` is empty (train() not run yet).
    ValueError   : metric history is empty, or ``metric`` is not recognised.
    """
    metric_aliases = {"metric", "mae", "accuracy"}

    if not self.history["train_loss"]:
        raise RuntimeError("Nothing in history — did you call train()?")

    if metric == "loss":
        series = [
            (self.history["train_loss"], "Train Loss"),
            (self.history["val_loss"], "Val Loss"),
        ]
        ylabel, title = "Loss", "Training / Validation Loss"
    elif metric in metric_aliases:
        if not self.history["val_metric"]:
            raise ValueError("Metric history empty; choose 'loss' instead.")
        name = metric.upper()
        series = [(self.history["val_metric"], f"Val {name}")]
        ylabel, title = name, f"Validation {name} Curve"
    else:
        # list every accepted value, not just 'loss'/'metric'
        raise ValueError(
            "metric must be 'loss' or one of 'metric', 'mae', 'accuracy'"
        )

    plt.figure(figsize=(6, 4))
    for values, label in series:
        plt.plot(values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

### Loading and predicting

In [None]:
# =====================================================================
# inside class GNNExperiment  (add or replace the stubs)
# =====================================================================
    # --------------------------------------------------
    # SAVE everything needed to reproduce / predict
    # --------------------------------------------------
    def save(self, run_dir: Union[str, pathlib.Path]):
        """
        Persist config, model weights, and run metadata into `run_dir/`.

        Creates:
            run_dir/
              ├─ cfg.json
              ├─ meta.json             (node & edge ordering, dims)
              ├─ model.pt              (best weights if available else current)
              └─ history.json          (training curves)

        The weights and all JSON payloads are prepared before anything is
        written. A missing model or a non-JSON-serialisable
        config/meta/history therefore fails without leaving a half-populated
        run directory.

        Raises
        ------
        RuntimeError : there is neither a best checkpoint nor a model to save.
        TypeError    : cfg / meta / history contain values json cannot encode.
        """
        # 0) resolve weights first: prefer the early-stopping checkpoint ----
        # NOTE(review): train() builds best_ckpt with `v.cpu()`, which returns
        # the *same* tensor when the model already lives on CPU, so on CPU the
        # "best" checkpoint may alias the live (later-updated) weights — verify.
        if self.best_ckpt is not None:
            state = self.best_ckpt
        elif self.model is not None:
            state = self.model.state_dict()
        else:
            raise RuntimeError(
                "Nothing to save: train or load a model before calling save()"
            )

        # 1) metadata we need to rebuild the model without original data ----
        first = self.snapshots[0] if self.snapshots else None
        meta = {
            "reg_order":  self.reg_order,
            "edge_order": self.edge_order,
            "input_dim":  int(first.x.shape[1]) if first is not None else None,
            "edge_dim": (
                int(first.edge_attr.shape[1])
                if first is not None and first.edge_attr is not None
                else None
            ),
        }

        # 2) serialise everything up front so encoding errors abort early ---
        payloads = {
            "cfg.json":     json.dumps(self.cfg.to_dict(), indent=2),
            "meta.json":    json.dumps(meta, indent=2),
            "history.json": json.dumps(self.history, indent=2),
        }

        # 3) write ------------------------------------------------------------
        run_dir = pathlib.Path(run_dir)
        run_dir.mkdir(parents=True, exist_ok=True)
        for filename, text in payloads.items():
            (run_dir / filename).write_text(text, encoding="utf-8")
        torch.save(state, run_dir / "model.pt")

        print(f"✨  Run saved to: {run_dir.resolve()}")


    # --------------------------------------------------
    # LOAD a previously saved run
    # --------------------------------------------------
    @classmethod
    def load(
        cls,
        run_dir: Union[str, pathlib.Path],
        node_frames: Optional[Dict[str, pd.DataFrame]] = None,
        edge_frames: Optional[Dict[str, pd.DataFrame]] = None,
        graph: Optional[nx.DiGraph] = None,
    ) -> "GNNExperiment":
        """
        Recreate an experiment from disk.  If you plan to *predict*, you must
        also supply `node_frames`, `edge_frames`, and `graph` so the snapshot
        builder can run; otherwise they can be left None.

        Reads `cfg.json`, `meta.json` and `model.pt` from `run_dir`.
        `history.json` is optional; when it is absent the history created by
        the constructor is kept.

        Usage
        -----
        exp = GNNExperiment.load("runs/my_run", node_frames, edge_frames, G)
        ŷ = exp.predict(...)

        Raises
        ------
        FileNotFoundError : cfg.json, meta.json or model.pt is missing.
        ValueError        : meta.json has no input_dim, or the stored
                            model_name is unknown.
        """
        run_dir = pathlib.Path(run_dir)
        cfg  = GNNConfig.load(json.loads((run_dir / "cfg.json").read_text(encoding="utf-8")))
        meta = json.loads((run_dir / "meta.json").read_text(encoding="utf-8"))

        # save() stores input_dim=None when it had no snapshots; fail with a
        # clear message instead of an obscure error from a model constructor.
        d_in, edge_dim = meta["input_dim"], meta["edge_dim"]
        if d_in is None:
            raise ValueError(
                f"{run_dir / 'meta.json'} has no input_dim (the run was saved "
                "without snapshots), so the model cannot be rebuilt"
            )

        # --------------------------------------------------------------
        # 1. Create (possibly empty) experiment instance
        # --------------------------------------------------------------
        if node_frames is None:  node_frames = {}
        if edge_frames is None:  edge_frames = {}
        if graph is None:        graph = nx.DiGraph()

        exp = cls(node_frames, edge_frames, graph, cfg)

        # restore deterministic ordering (JSON turned the edge tuples into lists)
        exp.reg_order  = meta["reg_order"]
        exp.edge_order = [tuple(t) for t in meta["edge_order"]]

        # --------------------------------------------------------------
        # 2. Re-instantiate the model (input & edge dims come from meta)
        # --------------------------------------------------------------
        name = cfg.model_name.lower()
        if name in {"simple", "baseline"}:
            model = GNNModelSimple(d_in, cfg.hidden_dim, 1)
        elif name == "gcn":
            model = GNNConv(d_in, cfg.hidden_dim, 1, cfg.num_layers)
        elif name in {"graphsage", "sage"}:
            model = GNNSage(d_in, cfg.hidden_dim, 1, cfg.num_layers)
        elif name == "gat":
            model = GNNGAT(d_in, cfg.hidden_dim, 1, cfg.num_layers, cfg.heads)
        elif name in {"gatv2", "gat2"}:
            model = GNNGAT2(d_in, cfg.hidden_dim, 1, cfg.num_layers, cfg.heads, edge_dim)
        else:
            raise ValueError(f"Unknown model '{cfg.model_name}'")

        exp.model = model.to(exp.device)

        # --------------------------------------------------------------
        # 3. Load weights
        # --------------------------------------------------------------
        # SECURITY: torch.load unpickles model.pt. Only load run directories
        # you trust (newer torch versions offer weights_only=True for this).
        state_dict = torch.load(run_dir / "model.pt", map_location=exp.device)
        exp.model.load_state_dict(state_dict)

        # --------------------------------------------------------------
        # 4. Training curves are optional
        # --------------------------------------------------------------
        history_path = run_dir / "history.json"
        if history_path.exists():
            exp.history = json.loads(history_path.read_text(encoding="utf-8"))

        print(f"🔄  Loaded run from {run_dir.resolve()}")
        return exp


### Scaling Data 

In [None]:
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import numpy as np

# NOTE: this cell scales node_frames[region]["target"] IN PLACE. Re-running it
# without reloading the raw frames would scale the already-scaled values again.

# -----------------------------------------------------------
# 1️⃣  split timestamps once so we reuse it everywhere
# -----------------------------------------------------------
CUTOFF_DATE = "2023-06-01"            # single source for scalers *and* model split
cutoff = pd.Timestamp(CUTOFF_DATE)

train_idx = node_frames["Tokyo"].index <= cutoff
test_idx  = ~train_idx                # everything after cutoff

# -----------------------------------------------------------
# 2️⃣  one scaler *per region*
# -----------------------------------------------------------
y_scalers = {}          # region -> fitted scaler
for region, df in node_frames.items():
    # Build the mask from this region's own index so frames whose index
    # differs from Tokyo's are still split correctly.
    region_train = df.index <= cutoff

    # fit on TRAIN part only (no look-ahead into the test period)
    scaler = MinMaxScaler()
    scaler.fit(df.loc[region_train, ["target"]].to_numpy())
    y_scalers[region] = scaler

    # Overwrite the column in BOTH splits. Transform the same ndarray form used
    # for fitting; passing a DataFrame triggers sklearn's feature-name warning.
    df["target"] = scaler.transform(df[["target"]].to_numpy()).ravel()

# -----------------------------------------------------------
# 3️⃣  build & train your GNNExperiment
# -----------------------------------------------------------
exp = (
    GNNExperiment(node_frames, edge_frames, G,
                  {"cutoff_date": CUTOFF_DATE})
      .prepare_snapshots()
      .build_loaders()
      .init_model()
      .compile()
      .train()
)

# -----------------------------------------------------------
# 4️⃣  get scaled predictions (held-out period only)
# -----------------------------------------------------------
# Requires the full predict() implementation; the class stub returns None.
scaled_preds = exp.predict(node_frames, edge_frames)   # DataFrame, still scaled
# predict() scores every common timestamp, so keep only the post-cutoff rows
# to make the metrics out-of-sample.
scaled_preds = scaled_preds.loc[scaled_preds.index > cutoff]

# -----------------------------------------------------------
# 5️⃣  inverse-transform for human-readable numbers
# -----------------------------------------------------------
pred_orig_units = pd.DataFrame(
    {
        region: y_scalers[region].inverse_transform(
            scaled_preds[[region]].to_numpy()
        ).ravel()
        for region in scaled_preds.columns
    },
    index=scaled_preds.index,
)

# -----------------------------------------------------------
# 6️⃣  compute metrics in natural units
# -----------------------------------------------------------
# node_frames now contain *scaled* targets, so inverse-transform them as well.
# Rows are selected by the prediction timestamps (not a positional mask), so
# truth and predictions line up exactly.
y_true_orig = pd.DataFrame(
    {
        region: y_scalers[region].inverse_transform(
            node_frames[region].loc[pred_orig_units.index, ["target"]].to_numpy()
        ).ravel()
        for region in pred_orig_units.columns
    },
    index=pred_orig_units.index,
)

mae_per_region = (pred_orig_units - y_true_orig).abs().mean()
print(mae_per_region)
