| Model “block”                         | Import (PyG / PyG-Temporal unless noted)                                                                        | Edge-attr aware?                                | Big selling point                                                              | Use when …                                                          |
| ------------------------------------- | --------------------------------------------------------------------------------------------------------------- | ----------------------------------------------- | ------------------------------------------------------------------------------ | ------------------------------------------------------------------- |
| **TGCN** | `from torch_geometric_temporal.nn.recurrent import TGCN` | **Scalar only** (`edge_weight`; no multi-dim `edge_attr`) | GRU whose gates each apply a GCN over `A`. | You want a light baseline; <1 M params. |
| **STGCN (sandwich)** | `from torch_geometric_temporal.nn.attention import STConv` | **Scalar only** (`edge_weight`) | Chebyshev graph conv between two gated 1-D temporal convs: the classic traffic-forecast block. | Medium horizon (1–48 h), cheap to run. |
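| **Graph WaveNet** | Not shipped in PyG-Temporal (there is no `nn.conv` module); use a reference implementation, or the related `MTGNN` from `torch_geometric_temporal.nn.attention` | **Scalar only** (weights enter the diffusion transition matrices) | Dilated causal CNN ⇄ diffusion graph conv; the receptive field grows exponentially with depth, so 100+ look-back steps stay cheap. | You need a long receptive field or multi-horizon output. |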
| **NNConv + GRU** | `from torch_geometric.nn import NNConv`, wrapped in your own `torch.nn.GRUCell` loop (PyG-Temporal's `GConvGRU` is ChebConv-based and ignores `edge_attr`) | **Yes** (`edge_attr`) | Edge-conditioned weights: a small network maps each tie line's capacity/utilisation to the weights of its message. | Capacities or dynamic line states matter. |
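| **DCRNN** | `from torch_geometric_temporal.nn.recurrent import DCRNN` | **Scalar only** (`edge_weight` sets the diffusion transition probabilities) | Diffusion conv (forward and reverse random walks) inside GRU gates; handles directed flows, which STGCN's symmetric Chebyshev filter does not. | You have directed edges and want a diffusion bias. |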
| **ASTGNN / SAN (Adaptive Sp-T Attn)** | Not in PyG-Temporal under this name; nearest built-ins are `ASTGCN` and `MTGNN` (`torch_geometric_temporal.nn.attention`) and `AGCRN` (`torch_geometric_temporal.nn.recurrent`) | **No** (learns an adaptive adjacency instead) | Learns **extra** soft links via attention on top of the physical graph. | Hidden couplings (e.g. a fuel-price link) aren’t in your 9×9 `A`. |
| **GAT-GRU hybrid** | `from torch_geometric.nn import GATv2Conv`, wrapped in your own GRU cell | **Optional** (`edge_attr` enters the attention scores when `edge_dim` is set) | Attention tells the model which neighbour to trust at each step. | The graph is tiny, so attention cost is trivial, and you want interpretable weights. |
| **Temporal Graph Network (TGN)** | `from torch_geometric.nn import TGNMemory` (built into PyG; no external package needed) | **Yes** (per-event message features) | Continuous-time memory; handles irregular events. | You have a true event log, not fixed snapshots. |
| **Graph Transformer** | `from torch_geometric.nn import TransformerConv` | **Optional** (`edge_attr` via `edge_dim`) | Attention over each node's neighbours per snapshot (pass a complete `edge_index` for full attention), then stack with 1-D temporal self-attention. | GPU budget allows; you want top accuracy and inspectable attention weights. |
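| **DGL spatio-temporal stacks** | `from dgl.nn.pytorch import GraphConv, GATConv, NNConv` (DGL ships the spatial layers; temporal parts are hand-written) | **Yes** (e.g. `NNConv` takes edge features) | Same ideas on the DGL backend, whose distributed training targets billion-edge graphs. | You may migrate to huge graphs later; code still looks like PyTorch. |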
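
The Graph WaveNet row's claim about long look-back windows comes down to simple arithmetic. A stack of causal convolutions with kernel size `k` and dilations `d_1 … d_L` lets one output see `1 + Σ (k − 1)·d_l` past steps, so doubling the dilation per layer makes coverage grow exponentially with depth. The helper and dilation schedules below are illustrative assumptions, not Graph WaveNet's published configuration.

```python
def receptive_field(kernel_size: int, dilations: list[int]) -> int:
    """Number of input steps seen by one output of stacked dilated causal convs."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# Doubling dilations: 7 layers with kernel size 2 already cover 128 steps
doubling = [2 ** i for i in range(7)]      # 1, 2, 4, ..., 64
print(receptive_field(2, doubling))        # 128

# Same depth without dilation covers only 8 steps
print(receptive_field(2, [1] * 7))         # 8
```

Reaching 128 steps with undilated kernel-2 convolutions would take 127 layers, which is why the dilated design stays cheap.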


Quick pick-list
- Start simple: a TGCN or STGCN (`STConv`) stack.
- Edge capacities matter? → switch to NNConv + GRU.
- More than 30 steps of history, or multi-horizon output? → Graph WaveNet or ASTGNN.
- Irregular event timestamps? → TGN via PyG's built-in `TGNMemory`.
- Want explainable attention and have GPU slack? → Graph Transformer (`TransformerConv` stack).
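
For the "edge capacities matter" case, the edge-conditioned update can be sketched in plain NumPy. PyG's `NNConv` documents it as `x_i' = Θ·x_i + Σ_{j∈N(i)} x_j · h_Θ(e_ij)`, where a small network `h_Θ` turns each edge's features into an `F_in × F_out` weight matrix. This sketch assumes sum aggregation (NNConv's default), no bias on the root term, and a single linear layer for `h_Θ`; the function name, weights and the two edge features (capacity, utilisation) are hypothetical stand-ins.

```python
import numpy as np


def edge_conditioned_conv(x, edge_index, edge_attr, root_w, edge_w, edge_b):
    """NNConv-style layer with sum aggregation (illustrative sketch).

    x          : (N, F_in)          node features
    edge_index : (2, E)             row 0 = source j, row 1 = target i
    edge_attr  : (E, D)             edge features, e.g. capacity and utilisation
    root_w     : (F_in, F_out)      self weight Θ
    edge_w     : (D, F_in * F_out)  hypothetical linear edge network h
    edge_b     : (F_in * F_out,)
    """
    f_in, f_out = root_w.shape
    src, dst = edge_index
    # h(e_ji): one (F_in, F_out) weight matrix per edge
    w_edge = (edge_attr @ edge_w + edge_b).reshape(-1, f_in, f_out)
    # message x_j · h(e_ji) for every edge
    msgs = np.einsum("ef,efo->eo", x[src], w_edge)
    out = x @ root_w
    np.add.at(out, dst, msgs)          # sum incoming messages per target node
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2))
edge_index = np.array([[0, 1, 2], [1, 2, 0]])                  # 0→1, 1→2, 2→0
edge_attr = np.array([[1.0, 0.5], [0.2, 0.9], [0.0, 0.0]])     # made-up values
root_w = rng.normal(size=(2, 4))
edge_w = rng.normal(size=(2, 8))
edge_b = np.zeros(8)

out = edge_conditioned_conv(x, edge_index, edge_attr, root_w, edge_w, edge_b)
print(out.shape)  # (3, 4)
```

Because the edge 2→0 has all-zero features and the edge network has zero bias, its message vanishes and node 0 keeps only its own term: the edge features literally gate the message. A recurrent version would feed this output into a GRU cell at every snapshot.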

# jpx_stgnn.py  ─── Minimal end-to-end ST-GNN for nine-region price forecasts
# ----------------------------------------------------------
# Expects two objects from earlier cells:
#   region_dfs : dict[str, pd.DataFrame], one frame per region, unique timestamp
#                index, identical feature columns plus a "target" column
#   adjacency  : np.ndarray (N, N), rows/cols in sorted(region_dfs) order;
#                adjacency[i, j] != 0 means an edge i → j with that weight
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from torch_geometric_temporal.nn.recurrent import TGCN


# ────────────────────────────────────────────────────────────
# 1. Pandas-dict  →  PyG-Temporal dataset
# ────────────────────────────────────────────────────────────
def common_index(region_dfs: dict[str, pd.DataFrame]) -> pd.Index:
    """Sorted timestamps present in every region."""
    frames = iter(region_dfs.values())
    idx = next(frames).index
    for df in frames:
        idx = idx.intersection(df.index)
    return idx.sort_values()


def build_temporal_graph_dataset(region_dfs: dict[str, pd.DataFrame],
                                 adj: np.ndarray,
                                 timestamps=None) -> StaticGraphTemporalSignal:
    regions = sorted(region_dfs)                   # node i ↔ regions[i]
    if timestamps is None:
        timestamps = common_index(region_dfs)
    feat_cols = region_dfs[regions[0]].columns.drop("target")

    # (T, N, F) features and (T, N) targets, built without a per-row loop
    X = np.stack([region_dfs[r].loc[timestamps, feat_cols].to_numpy(dtype=np.float32)
                  for r in regions], axis=1)
    Y = np.stack([region_dfs[r].loc[timestamps, "target"].to_numpy(dtype=np.float32)
                  for r in regions], axis=1)

    src, dst = np.nonzero(adj)
    edge_index  = np.vstack([src, dst]).astype(np.int64)     # (2, E)
    edge_weight = adj[src, dst].astype(np.float32)           # (E,)

    # PyG-Temporal expects NumPy arrays; each snapshot is converted to tensors on access
    return StaticGraphTemporalSignal(edge_index=edge_index,
                                     edge_weight=edge_weight,
                                     features=list(X),
                                     targets=list(Y))


# ────────────────────────────────────────────────────────────
# 2. Chronological train/test splits (no shuffling)
# ────────────────────────────────────────────────────────────
def make_static_split(region_dfs, adj, train_end):
    idx = common_index(region_dfs)
    cut = pd.Timestamp(train_end)
    return (build_temporal_graph_dataset(region_dfs, adj, idx[idx <= cut]),
            build_temporal_graph_dataset(region_dfs, adj, idx[idx >  cut]))


def walk_forward_splits(region_dfs, adj, start_train: str,
                        step_days=30, horizon_days=30):
    """Expanding window: train on (-inf, t_k], test on (t_k, t_k + horizon]."""
    idx = common_index(region_dfs)
    horizon = pd.Timedelta(days=horizon_days)
    t_k = pd.Timestamp(start_train)

    while t_k <= idx[-1] - horizon:
        train_idx = idx[idx <= t_k]
        test_idx  = idx[(idx > t_k) & (idx <= t_k + horizon)]
        if len(train_idx) and len(test_idx):           # skip empty windows
            yield (build_temporal_graph_dataset(region_dfs, adj, train_idx),
                   build_temporal_graph_dataset(region_dfs, adj, test_idx),
                   t_k)
        t_k += pd.Timedelta(days=step_days)


# ────────────────────────────────────────────────────────────
# 3. Model: one-layer TGCN + linear head
#    (swap self.core for NNConv, GraphWaveNet, etc.)
# ────────────────────────────────────────────────────────────
class PriceTGCN(nn.Module):
    def __init__(self, num_nodes, in_feats, hidden=64):
        super().__init__()
        # TGCN has no num_nodes argument; it is kept here so the class can be
        # swapped for cores that need it (e.g. STConv)
        self.num_nodes = num_nodes
        self.core = TGCN(in_channels=in_feats, out_channels=hidden)
        self.head = nn.Linear(hidden, 1)               # scalar per node

    def forward(self, x, edge_index, edge_weight, h=None):
        h = self.core(x, edge_index, edge_weight, h)   # (N, hidden); h=None → zero state
        return self.head(h).squeeze(-1), h             # (N,) prediction, new state


# ────────────────────────────────────────────────────────────
# 4. Training & evaluation utilities
# ────────────────────────────────────────────────────────────
def run_epoch(model, signal, opt=None, h=None):
    """One chronological pass over the snapshots; returns (mean MAE, final state).

    The hidden state is carried from snapshot to snapshot and detached after
    each step (truncated BPTT of length 1), so the GRU keeps its memory
    without back-propagating through the whole history.
    """
    train_mode = opt is not None
    model.train(train_mode)
    loss_fn = nn.L1Loss()
    total, n = 0.0, 0

    with torch.set_grad_enabled(train_mode):
        for snap in signal:                            # snap.edge_attr holds edge_weight
            y_hat, h = model(snap.x, snap.edge_index, snap.edge_attr, h)
            loss = loss_fn(y_hat, snap.y)
            if train_mode:
                opt.zero_grad()
                loss.backward()
                opt.step()
            h = h.detach()
            total += loss.item()
            n += 1

    return total / max(n, 1), h


# ────────────────────────────────────────────────────────────
# 5A. STATIC-SPLIT experiment  (comment out if not needed)
# ────────────────────────────────────────────────────────────
train_sig, test_sig = make_static_split(region_dfs, adjacency, train_end="2024-06-30")

model = PriceTGCN(num_nodes=adjacency.shape[0],
                  in_feats=train_sig.features[0].shape[1])
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    mae, h = run_epoch(model, train_sig, opt)
    print(f"[Static] epoch {epoch:02d} train-MAE = {mae:.4f}")

# start the test pass from the state reached at the end of training
test_mae, _ = run_epoch(model, test_sig, h=h)
print(f"Static-split test MAE: {test_mae:.4f}")


# ────────────────────────────────────────────────────────────
# 5B. WALK-FORWARD back-test  (comment out if not needed)
# ────────────────────────────────────────────────────────────
walk_mae = []
for train_sig, test_sig, t_k in walk_forward_splits(region_dfs, adjacency,
                                                    start_train="2021-01-01",
                                                    step_days=30, horizon_days=30):
    model = PriceTGCN(num_nodes=adjacency.shape[0],
                      in_feats=train_sig.features[0].shape[1])
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(5):                                 # 5 epochs per window
        _, h = run_epoch(model, train_sig, opt)

    mae, _ = run_epoch(model, test_sig, h=h)
    walk_mae.append((t_k.strftime("%Y-%m-%d"), mae))
    print(f"[Walk] train window end {t_k.date()}  test-MAE = {mae:.4f}")
print("Walk-forward MAE series:", walk_mae)
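
The `TGCN` core in the script is a GRU whose gates each read a graph convolution of the current snapshot. The NumPy sketch below is one way to write a single update, assuming the gate layout used by PyG-Temporal's `TGCN` (a GCN per gate, concatenated with the previous state, then a linear map). Biases are omitted, and the weights and the 9-node grid are random stand-ins, not fitted values; `gcn_norm` and `tgcn_step` are hypothetical helper names. It makes the role of the hidden state explicit: the model only accumulates memory if `h` from one snapshot is fed into the next.

```python
import numpy as np


def gcn_norm(adj):
    """Symmetric GCN normalisation D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + np.eye(len(adj))
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def tgcn_step(x, h, a_norm, params):
    """One T-GCN update; x: (N, F), h: (N, H) → new h: (N, H)."""
    def gate(name, state):
        conv = a_norm @ x @ params[f"W_{name}"]            # graph conv of the input
        return np.concatenate([conv, state], axis=1) @ params[f"L_{name}"]

    z = sigmoid(gate("z", h))              # update gate
    r = sigmoid(gate("r", h))              # reset gate
    h_tilde = np.tanh(gate("h", h * r))    # candidate state
    return z * h + (1.0 - z) * h_tilde


rng = np.random.default_rng(0)
n, f, hid = 9, 4, 8                        # 9 regions, 4 features, 8 hidden units
adj = (rng.random((n, n)) < 0.3).astype(float)
adj = np.maximum(adj, adj.T)               # symmetric stand-in grid
np.fill_diagonal(adj, 0.0)
a_norm = gcn_norm(adj)

params = {}
for g in "zrh":
    params[f"W_{g}"] = rng.normal(scale=0.5, size=(f, hid))
    params[f"L_{g}"] = rng.normal(scale=0.5, size=(2 * hid, hid))

h = np.zeros((n, hid))
for t in range(24):                        # carry h across 24 snapshots
    x_t = rng.normal(size=(n, f))
    h = tgcn_step(x_t, h, a_norm, params)
print(h.shape)  # (9, 8)
```

Resetting `h` to zeros before every snapshot would leave only the gated graph convolution of the current input, with no temporal memory.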
