## T-GCN

In [1]:
# Install torch geometric -- for pyg
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-geometric

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m54.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m68.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m34.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import InMemoryDataset, Data
import pandas as pd
import numpy as np
import pickle

  import torch_geometric.typing
  import torch_geometric.typing


In [3]:
from google.colab import drive
drive.mount('/content/drive')
!ls /content/drive

Mounted at /content/drive
MyDrive  Shareddrives


In [4]:
class SkyNetDataset(InMemoryDataset):
  def __init__(self, root, transform=None, pre_transform=None):
    super().__init__(root, transform, pre_transform)
    self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

  @property
  def raw_file_names(self):
    return []

  @property
  def processed_file_names(self):
    return ['flight_graphs.pt']

  def process(self):
    pass

dataset = SkyNetDataset(root="/content/drive/Shareddrives/CS_224W_Project/data/data/skynet_clean_graphs")

In [5]:
# Helper functions
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def calculate_laplacian_with_self_loop(matrix: torch.Tensor) -> torch.Tensor:
    """
    Normalized Laplacian with self-loops:
        L = D^{-1/2} (A + I) D^{-1/2}
    """
    device = matrix.device
    matrix = matrix + torch.eye(matrix.size(0), device=device)
    row_sum = matrix.sum(1)
    d_inv_sqrt = torch.pow(row_sum, -0.5).flatten()
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
    normalized_laplacian = (
        matrix.matmul(d_mat_inv_sqrt).transpose(0, 1).matmul(d_mat_inv_sqrt)
    )
    return normalized_laplacian


class TGraphConvolution(nn.Module):
    """
    Graph convolution used inside the T-GCN gates, but with multi-dimensional
    node features.

    Args:
        adj:           (N, N) adjacency matrix
        input_dim:     F  (node feature dim)
        num_gru_units: H  (hidden dim)
        output_dim:    O

    Inputs:
        inputs:       (B, N, F)
        hidden_state: (B, N*H) or (B, N, H)

    Output:
        (B, N*O)
    """
    def __init__(self, adj, input_dim: int, num_gru_units: int,
                 output_dim: int, bias: float = 0.0):
        super().__init__()
        self._input_dim = input_dim
        self._num_gru_units = num_gru_units
        self._output_dim = output_dim
        self._bias_init_value = bias

        self.register_buffer(
            "laplacian",
            calculate_laplacian_with_self_loop(torch.as_tensor(adj, dtype=torch.float32))
        )
        self.weights = nn.Parameter(
            torch.FloatTensor(self._input_dim + self._num_gru_units, self._output_dim)
        )
        self.biases = nn.Parameter(torch.FloatTensor(self._output_dim))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weights)
        nn.init.constant_(self.biases, self._bias_init_value)

    def forward(self, inputs, hidden_state):
        # inputs: (B, N, F)
        batch_size, num_nodes, in_dim = inputs.shape
        assert in_dim == self._input_dim

        # hidden_state -> (B, N, H)
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.view(batch_size, num_nodes, self._num_gru_units)
        elif hidden_state.dim() == 3:
            assert hidden_state.shape[2] == self._num_gru_units
        else:
            raise ValueError(f"hidden_state must have dim 2 or 3, got {hidden_state.dim()}")

        # concat node features & hidden states
        concat = torch.cat((inputs, hidden_state), dim=2)   # (B, N, F+H)
        B, N, C = concat.shape

        # A * [x, h]
        concat = concat.permute(1, 2, 0).reshape(N, C * B)  # (N, C*B)
        a_times_concat = self.laplacian @ concat            # (N, C*B)
        a_times_concat = a_times_concat.view(N, C, B).permute(2, 0, 1)  # (B, N, C)
        a_times_concat = a_times_concat.reshape(B * N, C)   # (B*N, C)

        outputs = a_times_concat @ self.weights + self.biases   # (B*N, O)
        outputs = outputs.view(B, N, self._output_dim)          # (B, N, O)
        outputs = outputs.reshape(B, N * self._output_dim)      # (B, N*O)
        return outputs


class TGCNCell(nn.Module):
    """
    One T-GCN cell (one time step).

    Args:
        adj:        (N, N)
        input_dim:  F
        hidden_dim: H
    """
    def __init__(self, adj, input_dim: int, hidden_dim: int):
        super().__init__()
        self._input_dim = input_dim
        self._hidden_dim = hidden_dim
        self.register_buffer("adj", torch.as_tensor(adj, dtype=torch.float32))

        # gate [r, u]
        self.graph_conv1 = TGraphConvolution(
            self.adj, self._input_dim, self._hidden_dim, self._hidden_dim * 2, bias=1.0
        )
        # candidate c
        self.graph_conv2 = TGraphConvolution(
            self.adj, self._input_dim, self._hidden_dim, self._hidden_dim
        )

    def forward(self, inputs, hidden_state):
        """
        inputs:       (B, N, F)
        hidden_state: (B, N*H)
        returns:      (B, N*H), (B, N*H)
        """
        # [r, u] = sigmoid(A[x, h]W + b)  -> (B, N*2H)
        concatenation = torch.sigmoid(self.graph_conv1(inputs, hidden_state))
        r, u = torch.chunk(concatenation, chunks=2, dim=1)  # each (B, N*H)

        # candidate c = tanh(A[x, (r * h)]W + b)
        c = torch.tanh(self.graph_conv2(inputs, r * hidden_state))  # (B, N*H)

        # new hidden: u * h + (1 - u) * c
        new_hidden_state = u * hidden_state + (1.0 - u) * c
        return new_hidden_state, new_hidden_state


class TGCN(nn.Module):
    """
    Temporal Graph Convolutional Network over airports.

    Inputs:
        inputs: (B, T, N, F)  – batch, time, num_nodes, node_feat_dim
    Outputs:
        (B, N, H) – node embeddings at final time step
    """
    def __init__(self, adj, input_dim: int, hidden_dim: int, **kwargs):
        super().__init__()
        self._num_nodes = adj.shape[0]
        self._input_dim = input_dim
        self._hidden_dim = hidden_dim
        self.register_buffer("adj", torch.as_tensor(adj, dtype=torch.float32))
        self.tgcn_cell = TGCNCell(self.adj, self._input_dim, self._hidden_dim)

    def forward(self, inputs):
        batch_size, seq_len, num_nodes, in_dim = inputs.shape
        assert num_nodes == self._num_nodes
        assert in_dim == self._input_dim

        hidden_state = torch.zeros(
            batch_size,
            num_nodes * self._hidden_dim,
            device=inputs.device,
            dtype=inputs.dtype,
        )

        output = None
        for t in range(seq_len):
            x_t = inputs[:, t, :, :]          # (B, N, F)
            output, hidden_state = self.tgcn_cell(x_t, hidden_state)

        # output: (B, N*H) -> (B, N, H)
        last_output = output.view(batch_size, num_nodes, self._hidden_dim)
        return last_output

In [6]:
# Edge-level T-GCN delay model
class TGCNDelayModel(nn.Module):
    """
    Full model:
      - T-GCN over airport nodes: node weather+geo → time-aware embeddings
      - MLP over edges: [h_src || h_dst || edge_features] → delay (minutes).

    We use regression (MSE) here.
    """
    def __init__(self, adj, node_feat_dim: int, edge_feat_dim: int,
                 hidden_dim: int = 64, edge_hidden_dim: int = 64):
        super().__init__()
        self.num_nodes = adj.shape[0]

        # Temporal graph encoder over airports
        self.tgcn = TGCN(adj, node_feat_dim, hidden_dim)

        # Edge prediction head
        in_dim = 2 * hidden_dim + edge_feat_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, 1),
        )

    def forward(self, node_seq, edge_index, edge_attr):
        """
        node_seq:   (B, T, N, F_node)
        edge_index: (B, 2, E)  (we assume B=1 with DataLoader(batch_size=1))
        edge_attr:  (B, E, F_edge)

        Returns:
            preds: (B, E)  – predicted delay (minutes, or normalized minutes)
        """
        batch_size = node_seq.shape[0]
        assert batch_size == 1, "TGCNDelayModel currently assumes batch_size=1."

        # Normalise edge_index shape: (B, 2, E) -> (2, E) (since B=1)
        if edge_index.dim() == 3:
            edge_index = edge_index[0]
        src, dst = edge_index         # (E,), (E,)

        # T-GCN over nodes
        node_embeddings = self.tgcn(node_seq)   # (B=1, N, H)

        # Gather embeddings for origin & destination
        h_src = node_embeddings[:, src, :]      # (B=1, E, H)
        h_dst = node_embeddings[:, dst, :]      # (B=1, E, H)

        # Normalise edge_attr shape
        if edge_attr.dim() == 2:
            edge_attr_expanded = edge_attr.unsqueeze(0).expand(batch_size, -1, -1)
        else:
            edge_attr_expanded = edge_attr      # (B=1, E, F_edge)

        edge_inputs = torch.cat([h_src, h_dst, edge_attr_expanded], dim=-1)  # (B=1, E, 2H+F_e)
        preds = self.edge_mlp(edge_inputs).squeeze(-1)                       # (B=1, E)
        return preds

In [7]:
# Static airport adjacency from PyG dataset
def build_adj_from_dataset(pyg_dataset):
    """
    Build a static adjacency over airports by combining edges
    from all snapshots in the dataset.

    Assumes that:
      - All graphs share the same num_nodes
      - Node indices are consistent across time.
    """
    num_nodes = pyg_dataset[0].num_nodes
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float32)

    for data in pyg_dataset:
        ei = data.edge_index
        src = ei[0]
        dst = ei[1]
        adj[src, dst] = 1.0
        adj[dst, src] = 1.0  # undirected for Laplacian

    return adj

adj = build_adj_from_dataset(dataset)
print("Adjacency shape:", adj.shape)

Adjacency shape: torch.Size([107, 107])


In [8]:
# Sequence dataset for T-GCN
class SequenceSkyNetDataset(Dataset):
    """
    Wraps your PyG SkyNetDataset into sequences for T-GCN.

    Each item:
      node_seq:   (T, N, F_node)   – node features for T consecutive time steps
      edge_index: (2, E)           – edges of the last time step
      edge_attr:  (E, F_edge)      – edge features of the last time step
      y:          (E,)             – departure delay labels of the last time step
    """
    def __init__(self, base_dataset, history_len=4, require_edges=True):
        self.base = base_dataset
        self.history_len = history_len
        self.require_edges = require_edges

        # Sort indices by time or block_idx if available
        indices = list(range(len(base_dataset)))
        if hasattr(base_dataset[0], "block_idx"):
            indices = sorted(indices, key=lambda i: int(base_dataset[i].block_idx))

        self.sorted_indices = indices

        # Build list of valid target positions
        valid_positions = []
        for pos in range(history_len - 1, len(self.sorted_indices)):
            idx = self.sorted_indices[pos]
            data = self.base[idx]
            if require_edges and data.edge_index.size(1) == 0:
                continue
            valid_positions.append(pos)

        self.valid_positions = valid_positions

    def __len__(self):
        return len(self.valid_positions)

    def __getitem__(self, idx):
        """
        idx indexes into valid_positions, not directly into base_dataset.
        """
        pos = self.valid_positions[idx]
        # history positions for this target
        hist_positions = self.sorted_indices[pos - self.history_len + 1 : pos + 1]

        graphs = [self.base[i] for i in hist_positions]

        # Node sequence: (T, N, F_node)
        node_seq = torch.stack([g.x for g in graphs], dim=0)

        # Last graph provides edges, edge features, labels
        target = graphs[-1]
        edge_index = target.edge_index
        edge_attr = target.edge_attr
        y = target.y

        return node_seq, edge_index, edge_attr, y

In [9]:
def build_adj_from_seq_subset(seq_dataset, ds_indices):
    """
    Build adjacency using only the graphs that appear in the history windows
    for the given sequence indices (ds_indices are indices into seq_dataset).
    """
    base = seq_dataset.base
    num_nodes = base[0].num_nodes
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float32)

    for ds_idx in ds_indices:
        pos = seq_dataset.valid_positions[ds_idx]
        # include the whole history window used for this sequence
        hist_positions = seq_dataset.sorted_indices[
            pos - seq_dataset.history_len + 1 : pos + 1
        ]
        for idx in hist_positions:
            g = base[idx]
            ei = g.edge_index
            src, dst = ei[0], ei[1]
            adj[src, dst] = 1.0
            adj[dst, src] = 1.0  # undirected
    return adj

In [10]:
# Build sequence dataset
history_len = 4  # 4 * 6h = 24h of history
seq_dataset = SequenceSkyNetDataset(dataset, history_len=history_len)

num_samples = len(seq_dataset)
train_end = int(0.7 * num_samples)
val_end   = int(0.85 * num_samples)

train_set = torch.utils.data.Subset(seq_dataset, range(0, train_end))
val_set   = torch.utils.data.Subset(seq_dataset, range(train_end, val_end))
test_set  = torch.utils.data.Subset(seq_dataset, range(val_end, num_samples))

train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=1, shuffle=False)
test_loader  = DataLoader(test_set,  batch_size=1, shuffle=False)

print(len(train_loader), len(val_loader), len(test_loader))

1019 219 219


In [11]:
# Instantiate T-GCN model
sample_graph = dataset[0]
node_feat_dim = sample_graph.x.size(1)         # should be 9: [lat, lon, 7 weather]
edge_feat_dim = sample_graph.edge_attr.size(1) # should be 13: flight features

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TGCNDelayModel(
    adj=adj,
    node_feat_dim=node_feat_dim,
    edge_feat_dim=edge_feat_dim,
    hidden_dim=64,
    edge_hidden_dim=64,
    # task="regression"
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()  # MSE in minutes^2

In [12]:
# Regression metrics
def regression_metrics(preds, targets):
    """
    preds, targets: 1D tensors with all edges concatenated

    Returns:
      dict: MSE, MAE, RMSE
    """
    diff = preds - targets
    mse = torch.mean(diff ** 2).item()
    mae = torch.mean(diff.abs()).item()
    rmse = mse ** 0.5
    return {"MSE": mse, "MAE": mae, "RMSE": rmse}

In [13]:
@torch.no_grad()
def compute_y_stats_from_loader(loader, device):
    """
    Compute mean and std of the raw delay y from the training loader.
    """
    ys = []
    for _, _, _, y in loader:
        ys.append(y.view(-1).float())
    all_y = torch.cat(ys, dim=0).to(device)

    y_mean = all_y.mean()
    y_std = all_y.std()

    # avoid division by zero
    if y_std.item() < 1e-6:
        y_std = y_std.new_tensor(1.0)

    return y_mean, y_std

def train_one_epoch_norm(model, loader, optimizer, device, y_mean, y_std):
    """
    Train for one epoch using normalized targets:
      y_norm = (y - y_mean) / y_std
    """
    model.train()
    total_loss = 0.0
    total_edges = 0

    for node_seq, edge_index, edge_attr, y in loader:
        # node_seq:  (B=1, T, N, F_node)
        # edge_index:(B=1, 2, E)
        # edge_attr: (B=1, E, F_edge)
        # y:         (B=1, E)
        node_seq = node_seq.to(device)
        edge_index = edge_index.to(device)
        edge_attr = edge_attr.to(device)
        y = y.to(device)

        # normalize targets
        y_norm = (y - y_mean) / y_std

        optimizer.zero_grad()

        preds_norm = model(node_seq, edge_index, edge_attr)  # (B=1, E), in normalized space
        preds_norm = preds_norm.view_as(y_norm)

        loss = loss_fn(preds_norm, y_norm)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        num_edges = y.numel()
        total_loss += loss.item() * num_edges
        total_edges += num_edges

    # average MSE in normalized space
    return total_loss / max(total_edges, 1)


@torch.no_grad()
def evaluate_denorm(model, loader, device, y_mean, y_std):
    """
    Evaluate the model, denormalizing predictions back to minutes before
    computing metrics.
    """
    model.eval()
    all_preds = []
    all_targets = []

    for node_seq, edge_index, edge_attr, y in loader:
        node_seq = node_seq.to(device)
        edge_index = edge_index.to(device)
        edge_attr = edge_attr.to(device)
        y = y.to(device)  # raw minutes

        preds_norm = model(node_seq, edge_index, edge_attr)  # (B=1, E), normalized
        preds_norm = preds_norm.view(-1)

        # denormalize predictions back to minutes
        preds = preds_norm * y_std + y_mean
        targets = y.view(-1)  # already in minutes

        all_preds.append(preds.detach().cpu())
        all_targets.append(targets.detach().cpu())

    if not all_preds:
        return {"MSE": float("nan"), "MAE": float("nan"), "RMSE": float("nan")}

    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)

    return regression_metrics(all_preds, all_targets)

In [16]:
# Different training strategies - Comparing prediction results
# (1) Trained on first 9 months and tested on last 3 months
# (2) Trained on first 3 weeks of each month and tested on last week of each month
from torch.utils.data import Subset

DATASET_START = pd.Timestamp("2013-01-01 00:00:00")
BLOCK_HOURS = 6  # 6-hour blocks

def block_to_timestamp(block_idx: int) -> pd.Timestamp:
    return DATASET_START + pd.Timedelta(hours=BLOCK_HOURS * int(block_idx))

def block_to_calendar_info(block_idx: int):
    """
    Returns:
      year, month, week_of_month (0..3), month_index

    week_of_month definition:
      week 0: days  1- 7
      week 1: days  8-14
      week 2: days 15-21
      week 3: days 22-end (the "last week" of the month)
    """
    ts = block_to_timestamp(block_idx)
    year = ts.year
    month = ts.month

    week_idx_raw = (ts.day - 1) // 7
    week_of_month = min(week_idx_raw, 3)  # clamp so "last week" is always 3

    month_index = year * 12 + (month - 1)
    return year, month, week_of_month, month_index

# Build metadata for each sequence sample in seq_dataset

sample_meta = []

for ds_idx in range(len(seq_dataset)):
    # seq_dataset maps indices -> positions in the underlying PyG dataset
    pos = seq_dataset.valid_positions[ds_idx]
    graph_idx = seq_dataset.sorted_indices[pos]
    g = dataset[graph_idx]

    if not hasattr(g, "block_idx"):
        raise AttributeError(
            "Data objects must have a `block_idx` attribute to build calendar-based splits."
        )

    block_idx = int(g.block_idx)
    year, month, week_of_month, month_index = block_to_calendar_info(block_idx)

    sample_meta.append(
        {
            "ds_idx": ds_idx,          # index in seq_dataset
            "graph_idx": graph_idx,    # index in original dataset
            "block_idx": block_idx,
            "year": year,
            "month": month,
            "week_of_month": week_of_month,
            "month_index": month_index,
        }
    )

unique_months = sorted({m["month_index"] for m in sample_meta})

def month_index_to_str(midx: int) -> str:
    y = midx // 12
    m = midx % 12 + 1
    return f"{y}-{m:02d}"

print("Months present in dataset (YYYY-MM):")
print([month_index_to_str(m) for m in unique_months])

# Strategy (1): first 9 months vs last 3 months

n_months = len(unique_months)
if n_months < 12:
    print(
        f"[WARN] Only {n_months} unique months found; "
        "experiment (1) assumes at least 12. "
        "Falling back to first 75% of months for train, last 25% for test."
    )

if n_months >= 12:
    # earliest 9 months for train, latest 3 for test
    train_months_exp1 = set(unique_months[:9])
    test_months_exp1  = set(unique_months[-3:])
else:
    split = max(1, int(0.75 * n_months))
    train_months_exp1 = set(unique_months[:split])
    test_months_exp1  = set(unique_months[split:])

exp1_train_indices = [m["ds_idx"] for m in sample_meta if m["month_index"] in train_months_exp1]
exp1_test_indices  = [m["ds_idx"] for m in sample_meta if m["month_index"] in test_months_exp1]

print(f"Experiment 1 - #train sequences: {len(exp1_train_indices)}, #test sequences: {len(exp1_test_indices)}")

# Strategy (2): first 3 weeks vs last week of each month

exp2_train_indices = [m["ds_idx"] for m in sample_meta if m["week_of_month"] < 3]
exp2_test_indices  = [m["ds_idx"] for m in sample_meta if m["week_of_month"] == 3]

print(f"Experiment 2 - #train sequences: {len(exp2_train_indices)}, #test sequences: {len(exp2_test_indices)}")

adj_exp1 = build_adj_from_seq_subset(seq_dataset, exp1_train_indices)
adj_exp2 = build_adj_from_seq_subset(seq_dataset, exp2_train_indices)

print("Adjacency (Exp1) shape:", adj_exp1.shape)
print("Adjacency (Exp2) shape:", adj_exp2.shape)

# DataLoaders for both strategies

batch_size = 1  # same as your original T-GCN training

exp1_train_loader = DataLoader(Subset(seq_dataset, exp1_train_indices),
                               batch_size=batch_size, shuffle=True)
exp1_test_loader  = DataLoader(Subset(seq_dataset, exp1_test_indices),
                               batch_size=batch_size, shuffle=False)

exp2_train_loader = DataLoader(Subset(seq_dataset, exp2_train_indices),
                               batch_size=batch_size, shuffle=True)
exp2_test_loader  = DataLoader(Subset(seq_dataset, exp2_test_indices),
                               batch_size=batch_size, shuffle=False)

print("DataLoaders for both experiments are ready.")

Months present in dataset (YYYY-MM):
['2013-01', '2013-02', '2013-03', '2013-04', '2013-05', '2013-06', '2013-07', '2013-08', '2013-09', '2013-10', '2013-11', '2013-12']
Experiment 1 - #train sequences: 1089, #test sequences: 368
Experiment 2 - #train sequences: 1005, #test sequences: 452
Adjacency (Exp1) shape: torch.Size([107, 107])
Adjacency (Exp2) shape: torch.Size([107, 107])
DataLoaders for both experiments are ready.


In [17]:
# Run the two T-GCN experiments (model selection based on MAE)

import copy

loss_fn = nn.MSELoss()

def make_tgcn_model(adj_matrix):
    """Create a new T-GCN delay model with the given adjacency."""
    model = TGCNDelayModel(
        adj=adj_matrix,
        node_feat_dim=node_feat_dim,
        edge_feat_dim=edge_feat_dim,
        hidden_dim=128,
        edge_hidden_dim=64,
    ).to(device)
    return model

def run_experiment(train_loader, test_loader, adj_matrix, num_epochs=50, label=""):
    # 1) compute normalization stats only from training data
    y_mean, y_std = compute_y_stats_from_loader(train_loader, device)
    print(
        f"\n{label}\n"
        f"Target normalization: mean={y_mean.item():.3f}, std={y_std.item():.3f}"
    )

    model = make_tgcn_model(adj_matrix)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

    best_mae = float("inf")
    best_epoch = -1
    best_metrics = None
    best_train_loss = None  # normalized MSE
    best_state_dict = None

    for epoch in range(1, num_epochs + 1):
        # train on normalized y
        train_loss = train_one_epoch_norm(model, train_loader, optimizer, device, y_mean, y_std)

        # evaluate in original units (minutes)
        test_metrics = evaluate_denorm(model, test_loader, device, y_mean, y_std)

        # model selection based on MAE in minutes
        is_best = test_metrics["MAE"] < best_mae
        if is_best:
            best_mae = test_metrics["MAE"]
            best_epoch = epoch
            best_metrics = test_metrics
            best_train_loss = train_loss
            best_state_dict = copy.deepcopy(model.state_dict())

        print(
            f"Epoch {epoch:03d} | "
            f"Train (norm MSE): {train_loss:.4f} | "
            f"Test MSE: {test_metrics['MSE']:.3f}, "
            f"MAE: {test_metrics['MAE']:.3f}, "
            f"RMSE: {test_metrics['RMSE']:.3f}"
        )

    # Restore best model weights (based on MAE)
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    # Print stats for the best epoch
    print(
        f"\nBest {label} epoch: {best_epoch:03d} | "
        f"Train (norm MSE): {best_train_loss:.4f} | "
        f"Test MSE: {best_metrics['MSE']:.3f}, "
        f"MAE: {best_metrics['MAE']:.3f}, "
        f"RMSE: {best_metrics['RMSE']:.3f}"
    )

    # Return model and metrics in original units
    return model, best_metrics

# Number of epochs for the experiments
exp_num_epochs = 100

# (1) Trained on first 9 months and tested on last 3 months
model_exp1, metrics_exp1 = run_experiment(
    exp1_train_loader,
    exp1_test_loader,
    adj_exp1,
    num_epochs=exp_num_epochs,
    label="Exp1: first 9 months vs last 3 months",
)

# (2) Trained on first 3 weeks of each month and tested on last week of each month
model_exp2, metrics_exp2 = run_experiment(
    exp2_train_loader,
    exp2_test_loader,
    adj_exp2,
    num_epochs=exp_num_epochs,
    label="Exp2: first 3 weeks vs last week of each month",
)

print("\nSummary comparison (best MAE over epochs, in minutes)")
print("Exp1 best metrics:", metrics_exp1)
print("Exp2 best metrics:", metrics_exp2)


Exp1: first 9 months vs last 3 months
Target normalization: mean=13.384, std=41.558
Epoch 001 | Train (norm MSE): 1.2785 | Test MSE: 2996.770, MAE: 37.626, RMSE: 54.743
Epoch 002 | Train (norm MSE): 1.2333 | Test MSE: 1658.841, MAE: 33.491, RMSE: 40.729
Epoch 003 | Train (norm MSE): 1.2459 | Test MSE: 1107.792, MAE: 19.139, RMSE: 33.284
Epoch 004 | Train (norm MSE): 1.1525 | Test MSE: 1153.684, MAE: 16.790, RMSE: 33.966
Epoch 005 | Train (norm MSE): 1.1356 | Test MSE: 2523.551, MAE: 41.367, RMSE: 50.235
Epoch 006 | Train (norm MSE): 1.1473 | Test MSE: 1139.415, MAE: 16.216, RMSE: 33.755
Epoch 007 | Train (norm MSE): 1.1119 | Test MSE: 1357.001, MAE: 27.617, RMSE: 36.837
Epoch 008 | Train (norm MSE): 1.1129 | Test MSE: 2120.133, MAE: 37.882, RMSE: 46.045
Epoch 009 | Train (norm MSE): 1.1229 | Test MSE: 1154.454, MAE: 22.071, RMSE: 33.977
Epoch 010 | Train (norm MSE): 1.0861 | Test MSE: 1514.777, MAE: 22.129, RMSE: 38.920
Epoch 011 | Train (norm MSE): 1.0753 | Test MSE: 1146.785, MAE: 2