In [None]:
# ======================================================================
# gnn_experiment.py
# ======================================================================
from __future__ import annotations

import json
import pathlib
from dataclasses import dataclass, field, asdict
from typing import Dict, Any, Optional, Union, List

import networkx as nx
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# ------------------------------------------------------------- optional
try:
    import yaml
except ImportError:   # YAML is optional; JSON works out-of-the-box
    yaml = None


# ----------------------------------------------------------------------
# 1. Configuration object
# ----------------------------------------------------------------------
@dataclass
class GNNConfig:
    """
    Hyper-parameters and run settings for a ``GNNExperiment``.

    Build it directly, or coerce another ``GNNConfig``, a plain ``dict`` or a
    JSON / YAML file into one with :meth:`load`.
    """

    # --- task / model --------------------------------------------------
    task: str            = "node_reg"               # {"node_reg", "node_clf", "edge_clf"}
    model_name: str      = "gcn"                    # {"gcn", "graphsage", "gat", "gatv2"}
    num_layers: int      = 2
    hidden_dim: int      = 128
    heads: int           = 8                        # for attention models
    norm: Optional[str]  = "batch"                  # {"batch", "layer", None}

    # --- dataset & target ---------------------------------------------
    target_col: str      = "target"                 # node column to predict

    # --- data splitting -----------------------------------------------
    split_mode: str      = "date"                   # {"date", "ratio"}
    cutoff_date: Optional[str] = None               # required if split_mode == "date"
    val_ratio: float     = 0.10                     # used if split_mode == "ratio"
    test_ratio: float    = 0.10                     # used if split_mode == "ratio"
    shuffle_in_split: bool = False                  # shuffle batches inside DataLoader?

    # --- optimisation --------------------------------------------------
    batch_size: int      = 32
    lr: float            = 1e-3
    epochs: int          = 100
    patience: int        = 10
    grad_clip: float     = 1.0

    # --- loss / optimiser ---------------------------------------------
    loss_fn: str         = "mse"                    # {"mse", "mae", "bce", "cross_entropy"}
    class_weights: Optional[list] = None            # for BCE / CE
    optimiser: str       = "adam"                   # {"adam", "adamw", "sgd"}
    optimiser_kwargs: Dict[str, Any] = field(default_factory=dict)

    # --- misc ----------------------------------------------------------
    device: str          = "cuda"                   # fallback to cpu if unavailable
    run_name: str        = "default_run"
    seed: int            = 42

    # ------------------------------------------------------------------
    # helper constructor
    @staticmethod
    def load(
        cfg: Union["GNNConfig", str, pathlib.Path, Dict[str, Any]],
    ) -> "GNNConfig":
        """
        Coerce *cfg* into a ``GNNConfig``.

        Parameters
        ----------
        cfg
            * ``GNNConfig`` - returned unchanged (the very same object);
            * ``dict`` - unpacked as keyword arguments;
            * ``str`` / ``pathlib.Path`` - path to a ``.json``, ``.yml`` or
              ``.yaml`` file (suffix matched case-insensitively) whose
              top-level value is a mapping of field names to values.  An
              empty YAML file yields the default configuration.

        Raises
        ------
        ValueError
            Unsupported file suffix, YAML requested while PyYAML is not
            installed, or malformed JSON (``json.JSONDecodeError``).
        TypeError
            *cfg* has an unsupported type, the file does not hold a mapping,
            or the mapping contains keys that are not config fields.
        """
        if isinstance(cfg, GNNConfig):
            return cfg
        if isinstance(cfg, dict):
            return GNNConfig(**cfg)
        if not isinstance(cfg, (str, pathlib.Path)):
            raise TypeError(f"Unsupported cfg type: {type(cfg)}")

        path = pathlib.Path(cfg)
        suffix = path.suffix.lower()

        # Pick the parser *before* touching the file so an unsupported format
        # is reported as such, not as a missing-file error.
        if suffix == ".json":
            parse = json.load
        elif suffix in {".yml", ".yaml"}:
            if yaml is None:
                raise ValueError(
                    f"Cannot read YAML config '{path}': PyYAML is not installed"
                )
            parse = yaml.safe_load
        else:
            raise ValueError(
                f"Unsupported file type for config path: '{path.suffix}'"
            )

        with open(path, "r", encoding="utf-8") as f:
            data = parse(f)

        if data is None:                  # empty YAML document -> all defaults
            data = {}
        if not isinstance(data, dict):
            raise TypeError(
                f"Config file '{path}' must contain a mapping, "
                f"got {type(data).__name__}"
            )
        return GNNConfig(**data)

    # pretty-print helper
    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain (recursively copied) ``dict``."""
        return asdict(self)


# ----------------------------------------------------------------------
# 2. Core experiment wrapper
# ----------------------------------------------------------------------
class GNNExperiment:
    """
    End-to-end wrapper around:
       raw node/edge time-series  →  PyG snapshots  →  loaders  →  model

    Typical use::

        exp = GNNExperiment(node_frames, edge_frames, graph, cfg)
        exp.prepare_snapshots().build_loaders()

    Only snapshot creation and loader construction are implemented; the
    model / training / persistence methods are placeholders (see TODOs).
    """

    # ------------ constructor -------------------------------------------------
    def __init__(
        self,
        node_frames: Dict[str, pd.DataFrame],
        edge_frames: Dict[str, pd.DataFrame],
        graph: nx.DiGraph,
        cfg: Union[GNNConfig, str, pathlib.Path, Dict[str, Any], None] = None,
    ):
        """
        Parameters
        ----------
        node_frames
            One DataFrame per graph node, keyed by node name and indexed by
            timestamp.  Each must contain ``cfg.target_col``; every other
            column is used as a node feature.
        edge_frames
            One DataFrame per directed edge, keyed by ``_edge_key(src, dst)``
            (``"src-dst"`` lower-cased) and indexed by timestamp; all of its
            columns are used as edge features.
        graph
            Directed graph whose nodes / edges define the snapshot topology.
        cfg
            Anything accepted by ``GNNConfig.load``.  ``None`` (the default)
            creates a fresh default ``GNNConfig`` for this experiment.
        """
        # raw inputs
        self.node_frames = node_frames     # {"Tokyo": df, ...}
        self.edge_frames = edge_frames     # {"tokyo-chubu": df, ...}
        self.graph       = graph

        # config - built per instance: a `GNNConfig()` default argument would
        # be one object shared (and mutated) by every experiment.
        self.cfg: GNNConfig = GNNConfig() if cfg is None else GNNConfig.load(cfg)

        # runtime placeholders
        self.reg_order:  List[str]     = []   # alphabetical node order
        self.edge_order: List[tuple]   = []   # (src_idx, dst_idx)
        self.snapshots: Optional[List[Data]] = None

        self.train_dl = self.val_dl = self.test_dl = None

        self.model = None
        self.optimizer = None
        self.loss_fn = None
        self.history: Dict[str, list] = {}

        # device handling: honour "cuda" / "cuda:N" when CUDA is available,
        # otherwise fall back to the CPU.
        wants_cuda = str(self.cfg.device).startswith("cuda")
        self.device = torch.device(
            self.cfg.device if (wants_cuda and torch.cuda.is_available()) else "cpu"
        )

    # =========================== helpers =====================================
    @staticmethod
    def _edge_key(src: str, dst: str) -> str:
        """Convert (src, dst) into canonical 'src-dst' lowercase key."""
        return f"{src}-{dst}".lower()

    @staticmethod
    def _ensure_same_index(idxs: List[pd.DatetimeIndex]) -> List[pd.Timestamp]:
        """
        Return the sorted intersection of all supplied indexes.

        Raises
        ------
        ValueError
            If *idxs* is empty or the indexes share no timestamp.
        """
        if not idxs:
            raise ValueError("No DataFrames supplied to intersect timestamps over")
        common = set(idxs[0]).intersection(*idxs[1:])
        if not common:
            raise ValueError("No common timestamps across supplied DataFrames")
        return sorted(common)

    # =================== 1. snapshot creation ================================
    def prepare_snapshots(self) -> "GNNExperiment":
        """
        Build one ``torch_geometric.data.Data`` object per shared timestamp.

        Each snapshot holds
          * ``x``          - float32 ``[N, F_node]``: node-frame columns
                             except ``cfg.target_col``;
          * ``y``          - float32 ``[N]``: node targets;
          * ``edge_index`` - int64 ``[2, E]`` (row 0 = source, row 1 = target);
          * ``edge_attr``  - float32 ``[E, F_edge]`` (``[0, 0]`` if no edges);
          * ``snap_time``  - int64 ``[1]``: ``pd.Timestamp(ts).value``.

        Nodes are ordered alphabetically (``self.reg_order``); edges follow
        ``graph.edges`` iteration order (``self.edge_order``).  Only
        timestamps present in *every* node and edge frame are used.
        Stores the list in ``self.snapshots`` and returns ``self``.

        Raises
        ------
        ValueError
            The graph has no nodes, or no timestamp is shared by all frames.
        KeyError
            A graph node / edge has no corresponding frame.
        """
        # --- deterministic node order ----------------------------------------
        self.reg_order = sorted(self.graph.nodes)
        if not self.reg_order:
            raise ValueError("Graph has no nodes; cannot build snapshots")
        node_pos = {n: i for i, n in enumerate(self.reg_order)}

        # --- deterministic edge order ----------------------------------------
        self.edge_order = [(node_pos[src], node_pos[dst])
                           for (src, dst) in self.graph.edges]

        # --- resolve frames once, with a clear error for anything missing ----
        missing_nodes = [n for n in self.reg_order if n not in self.node_frames]
        if missing_nodes:
            raise KeyError(f"No node frame for node(s): {missing_nodes}")
        edge_keys = [
            self._edge_key(self.reg_order[s], self.reg_order[d])
            for s, d in self.edge_order
        ]
        missing_edges = [k for k in edge_keys if k not in self.edge_frames]
        if missing_edges:
            raise KeyError(f"No edge frame for edge key(s): {missing_edges}")
        node_dfs = [self.node_frames[n] for n in self.reg_order]
        edge_dfs = [self.edge_frames[k] for k in edge_keys]

        # --- intersect timestamps --------------------------------------------
        node_idxs = [df.index for df in self.node_frames.values()]
        edge_idxs = [df.index for df in self.edge_frames.values()]
        ts_common = self._ensure_same_index(node_idxs + edge_idxs)

        # --- edge index tensor (shape [2, E]) --------------------------------
        # Topology is static, so build it once.  `edge_order` is a list of
        # (src, dst) pairs, i.e. [E, 2]; PyG expects [2, E], hence the
        # transpose.  reshape(-1, 2) also makes an edge-less graph [2, 0].
        edge_index = (
            torch.tensor(self.edge_order, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
        )

        target_col = self.cfg.target_col

        # --- construct snapshot objects --------------------------------------
        snapshots: List[Data] = []
        for ts in ts_common:
            # ---- node matrix & targets
            feats, tgts = [], []
            for df in node_dfs:
                row = df.loc[ts]
                tgts.append(row[target_col])
                feats.append(row.drop(target_col).to_numpy(dtype=np.float32))
            x = torch.tensor(np.vstack(feats), dtype=torch.float32)
            y = torch.tensor(tgts, dtype=torch.float32)

            # ---- edge attribute matrix (np.vstack cannot stack zero rows)
            if edge_dfs:
                edge_rows = [df.loc[ts].to_numpy(dtype=np.float32) for df in edge_dfs]
                edge_attr = torch.tensor(np.vstack(edge_rows), dtype=torch.float32)
            else:
                edge_attr = torch.empty((0, 0), dtype=torch.float32)

            snapshots.append(
                Data(
                    x=x,
                    # clone so snapshots never alias one another's tensors
                    edge_index=edge_index.clone(),
                    edge_attr=edge_attr,
                    y=y,
                    snap_time=torch.tensor([pd.Timestamp(ts).value]),
                )
            )

        self.snapshots = snapshots
        return self

    # =================== 2. build loaders ====================================
    def build_loaders(self) -> "GNNExperiment":
        """
        Create `train_dl`, `val_dl`, `test_dl` according to `self.cfg`.
        Must be called **after** `prepare_snapshots`.

        * ``split_mode == "date"``: snapshots at or before ``cutoff_date`` go
          to train; the rest is split chronologically 50/50 into val / test
          (with a single held-out snapshot, val is empty).
        * ``split_mode == "ratio"``: the last ``test_ratio`` fraction is test,
          the ``val_ratio`` fraction before it is val, the rest is train.

        Only the train loader may shuffle (``cfg.shuffle_in_split``); its RNG
        is seeded from ``cfg.seed`` so shuffled runs are reproducible.

        Raises
        ------
        RuntimeError
            ``prepare_snapshots`` has not been called.
        ValueError
            Missing ``cutoff_date``, an empty train / hold-out side of a date
            split, invalid ratios, or an unknown ``split_mode``.
        """
        if self.snapshots is None:
            raise RuntimeError("Call prepare_snapshots() before build_loaders()")

        # ---- chronological sort (snap_time is 1-elem tensor) -----------------
        snaps_sorted = sorted(self.snapshots, key=lambda g: g.snap_time.item())

        if self.cfg.split_mode == "date":
            if self.cfg.cutoff_date is None:
                raise ValueError("`cutoff_date` must be set in cfg for date split")

            cutoff_int = pd.Timestamp(self.cfg.cutoff_date).value
            train_set = [g for g in snaps_sorted if g.snap_time.item() <= cutoff_int]
            holdout   = [g for g in snaps_sorted if g.snap_time.item() >  cutoff_int]

            if not train_set:
                raise ValueError("No snapshots on or before cutoff_date for training")
            if not holdout:
                raise ValueError("No snapshots after cutoff_date for val/test split")

            # simple 50-50 val/test split on holdout
            mid = len(holdout) // 2
            val_set, test_set = holdout[:mid], holdout[mid:]

        elif self.cfg.split_mode == "ratio":
            val_ratio, test_ratio = self.cfg.val_ratio, self.cfg.test_ratio
            # Guard against a negative n_train, which would make the slices
            # below silently overlap / misbehave.
            if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio >= 1:
                raise ValueError(
                    "val_ratio and test_ratio must be >= 0 and sum to < 1 "
                    f"(got {val_ratio} and {test_ratio})"
                )

            n_total = len(snaps_sorted)
            n_test  = int(n_total * test_ratio)
            n_val   = int(n_total * val_ratio)
            n_train = n_total - n_val - n_test

            train_set = snaps_sorted[:n_train]
            val_set   = snaps_sorted[n_train:n_train + n_val]
            test_set  = snaps_sorted[n_train + n_val:]

        else:
            raise ValueError(f"Unknown split_mode '{self.cfg.split_mode}'")

        # ---- DataLoaders ------------------------------------------------------
        generator = None
        if self.cfg.shuffle_in_split:
            generator = torch.Generator().manual_seed(self.cfg.seed)

        self.train_dl = DataLoader(
            train_set,
            batch_size=self.cfg.batch_size,
            shuffle=self.cfg.shuffle_in_split,
            generator=generator,
        )
        self.val_dl = DataLoader(
            val_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
        )
        self.test_dl = DataLoader(
            test_set,
            batch_size=self.cfg.batch_size,
            shuffle=False,
        )

        return self

    # =================== placeholders for later ==============================
    def init_model(self) -> "GNNExperiment":
        """Instantiate the requested GNN backbone + head (TODO; currently a no-op)."""
        # TODO
        return self

    def compile(self) -> "GNNExperiment":
        """Attach loss fn, optimiser, schedulers, etc. (TODO; currently a no-op)."""
        # TODO
        return self

    def train(self, debug: bool = False):
        """Main training loop (TODO; currently does nothing)."""
        # TODO
        pass

    def evaluate(self, split: str = "val") -> Dict[str, float]:
        """Evaluate on a chosen split (TODO; currently returns an empty dict)."""
        # TODO
        return {}

    def predict(self, node_frames_new, edge_frames_new, timestamps):
        """Predict on unseen timestamps (TODO; currently returns None)."""
        # TODO
        pass

    def plot_history(self, metric: str = "loss"):
        """Matplotlib learning-curve plot (TODO; currently does nothing)."""
        # TODO
        pass

    def save(self, run_dir: Union[str, pathlib.Path]):
        """Save config, weights, scaler, etc. (TODO; currently does nothing)."""
        # TODO
        pass

    @classmethod
    def load(cls, run_dir: Union[str, pathlib.Path]) -> "GNNExperiment":
        """Reload a previously saved experiment (TODO; currently returns None)."""
        # TODO
        pass
