# Other experiments/models we ran

Other experiments we ran, building on how well our baseline models worked. None of them achieved a good final MAE, so they are not explained in detail in the blog post.

## Setup

In [None]:
# Install PyTorch Geometric (PyG) plus its torch-scatter / torch-sparse extensions.
# The wheel index URL targets torch 2.4.0 with CUDA 12.1; change it if the
# runtime's torch/CUDA build differs, or the extension wheels will not match.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install -q torch-geometric

In [None]:
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

In [None]:
# Mount Google Drive (Colab only) so the shared dataset folder is reachable,
# then list the mount point as a quick sanity check.
from google.colab import drive
drive.mount('/content/drive')
!ls /content/drive

In [None]:
# Prefer the GPU when one is visible; everything below moves tensors to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## Load Dataset

Replace `dataset_dest` with the folder that contains your SkyNet graphs.

In [None]:
class SkyNetDataset(InMemoryDataset):
    """
    Read-only wrapper around an already-processed SkyNet graph collection.

    Nothing is downloaded or processed here: ``process`` is a no-op and the
    collated graphs are loaded directly from ``processed_paths[0]``
    (``flight_graphs.pt``).
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # weights_only=False unpickles arbitrary Python objects; only point
        # this at files you trust.
        loaded = torch.load(self.processed_paths[0], weights_only=False)
        self.data, self.slices = loaded

    @property
    def raw_file_names(self):
        return list()

    @property
    def processed_file_names(self):
        return ['flight_graphs.pt']

    def process(self):
        return None

# Root folder of the processed SkyNet graphs (PyG looks for
# processed/flight_graphs.pt under this root). Change it to your own copy.
dataset_dest = "/content/drive/Shareddrives/CS_224W_Project/data/data/skynet_clean_graphs"
dataset = SkyNetDataset(root=dataset_dest)

## TGN

### Formulate dataset for TGN

In [None]:
# Sequence dataset shared by the T-GCN and TGN experiments: each sample is a
# sliding window of consecutive graph snapshots.
class SequenceSkyNetDataset(torch.utils.data.Dataset):
    """
    Wrap a PyG SkyNetDataset into fixed-length sequences of snapshots.

    Snapshots are ordered by ``block_idx`` when the graphs carry it, otherwise
    by their position in ``base_dataset``.

    Each item is a tuple:
      node_seq:   (T, N, F_node)  node features for T consecutive snapshots
      edge_index: (2, E)          edges of the last snapshot
      edge_attr:  (E, F_edge)     z-scored edge features of the last snapshot
      y:          (E,)            departure delay labels of the last snapshot

    Args:
        base_dataset: indexable collection of graphs exposing ``x``,
            ``edge_index``, ``edge_attr``, ``y`` and optionally ``block_idx``.
        history_len: number of snapshots T per sample (must be >= 1).
        require_edges: skip target snapshots that have no edges.
        binary_threshold: stored on the instance; not used by this class.

    Raises:
        ValueError: if ``history_len < 1``, or if no snapshot has any edges
            (edge normalisation statistics would be undefined).

    Note: edge mean/std are computed over *all* snapshots, including the
    ones later used as validation/test targets.
    """
    def __init__(self, base_dataset, history_len=4, require_edges=True, binary_threshold=15):
        if history_len < 1:
            raise ValueError(f"history_len must be >= 1, got {history_len}")

        self.base = base_dataset
        self.history_len = history_len
        self.require_edges = require_edges
        self.binary_threshold = binary_threshold

        # Edge-feature normalisation stats. Fetch each graph once instead of
        # indexing the dataset twice per graph.
        all_edge_attrs = []
        for i in range(len(base_dataset)):
            edge_attr = base_dataset[i].edge_attr
            if edge_attr.numel() > 0:
                all_edge_attrs.append(edge_attr)
        if not all_edge_attrs:
            raise ValueError("base_dataset contains no edges; cannot compute edge_attr statistics")

        all_edge_attrs = torch.cat(all_edge_attrs, dim=0)
        self.edge_mean = all_edge_attrs.mean(dim=0)
        self.edge_std = all_edge_attrs.std(dim=0)
        # Near-constant columns (and NaN std from a single row) would make the
        # z-score blow up; fall back to std = 1 for them.
        self.edge_std[~(self.edge_std >= 1e-6)] = 1.0

        # Chronological order when block_idx is available.
        indices = list(range(len(base_dataset)))
        if hasattr(base_dataset[0], "block_idx"):
            indices.sort(key=lambda i: int(base_dataset[i].block_idx))
        self.sorted_indices = indices

        # A target position needs history_len - 1 predecessors in the window.
        self.valid_positions = [
            pos
            for pos in range(history_len - 1, len(indices))
            if not (require_edges and base_dataset[indices[pos]].edge_index.size(1) == 0)
        ]

    def __len__(self):
        return len(self.valid_positions)

    def __getitem__(self, idx):
        """
        Return (node_seq, edge_index, edge_attr, y) for the idx-th valid target.

        ``idx`` indexes ``valid_positions``, not ``base_dataset``.
        """
        pos = self.valid_positions[idx]
        # Base-dataset indices of the history window, oldest first.
        window = self.sorted_indices[pos - self.history_len + 1 : pos + 1]
        graphs = [self.base[i] for i in window]

        node_seq = torch.stack([g.x for g in graphs], dim=0)

        # The last snapshot supplies edges, edge features and labels.
        target = graphs[-1]
        edge_attr = (target.edge_attr - self.edge_mean) / self.edge_std
        return node_seq, target.edge_index, edge_attr, target.y

In [None]:
# Build sequence dataset
history_len = 8  # 8 * 6h = 48h of history
seq_dataset = SequenceSkyNetDataset(dataset, history_len=history_len)

num_samples = len(seq_dataset)
train_end = int(0.7 * num_samples)
val_end   = int(0.85 * num_samples)

train_set = torch.utils.data.Subset(seq_dataset, range(0, train_end))
val_set   = torch.utils.data.Subset(seq_dataset, range(train_end, val_end))
test_set  = torch.utils.data.Subset(seq_dataset, range(val_end, num_samples))

train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=1, shuffle=False)
test_loader  = DataLoader(test_set,  batch_size=1, shuffle=False)

print(len(train_loader), len(val_loader), len(test_loader))

In [None]:
# Dataset wrapper for TGN
class TGNFlightDataset:
    """
    Flatten a (Subset of a) SequenceSkyNetDataset into a list of TGN interactions.

    Every edge of every target snapshot becomes one interaction dict with keys
    ``src``, ``dst``, ``time``, ``edge_attr``, ``label``, ``node_features`` and
    ``snapshot`` (the sequence-dataset index the edge came from). Interactions
    are appended in the order of the subset indices; for the contiguous
    ``range`` subsets used in this notebook that order is chronological.
    """
    def __init__(self, subset_or_dataset):
        # Unwrap a Subset so we can reach the sequence dataset's bookkeeping.
        if isinstance(subset_or_dataset, torch.utils.data.Subset):
            self.seq_dataset = subset_or_dataset.dataset
            self.indices = subset_or_dataset.indices
        else:
            self.seq_dataset = subset_or_dataset
            self.indices = list(range(len(subset_or_dataset)))

        self.interactions = []

        print("Building TGN dataset...")
        for subset_idx in self.indices:
            node_seq, edge_index, edge_attr, y = self.seq_dataset[subset_idx]

            pos = self.seq_dataset.valid_positions[subset_idx]
            graph_idx = self.seq_dataset.sorted_indices[pos]
            graph = self.seq_dataset.base[graph_idx]

            timestamp = int(graph.block_idx) if hasattr(graph, 'block_idx') else int(subset_idx)

            # Only the most recent snapshot's node features are used.
            node_features = node_seq[-1]

            # Strip a leading batch dimension if the item was pre-collated.
            if edge_index.dim() == 3:
                edge_index = edge_index[0]
            if y.dim() == 2:
                y = y[0]
            if edge_attr.dim() == 3:
                edge_attr = edge_attr[0]

            src, dst = edge_index[0], edge_index[1]
            for e in range(len(src)):
                self.interactions.append({
                    'src': src[e].item(),
                    'dst': dst[e].item(),
                    'time': timestamp,
                    'edge_attr': edge_attr[e],
                    'label': y[e].item(),
                    'node_features': node_features,
                    'snapshot': subset_idx,
                })

    def get_temporal_batches(self, batch_size=200):
        """
        Return chronological batches of at most ``batch_size`` interactions.

        A batch never spans two snapshots, so the single ``node_features``
        tensor returned with each batch belongs to every interaction in it.

        Returns:
            list of (src, dst, t, edge_attr, labels, node_features) tuples;
            empty when there are no interactions.

        Raises:
            ValueError: if ``batch_size < 1``.
        """
        from itertools import groupby

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        batches = []
        for _, group in groupby(self.interactions, key=self._snapshot_key):
            group = list(group)
            for start in range(0, len(group), batch_size):
                batches.append(self._collate(group[start:start + batch_size]))
        return batches

    @staticmethod
    def _snapshot_key(interaction):
        """Identify the snapshot an interaction came from (falls back to its time)."""
        return interaction.get('snapshot', interaction['time'])

    @staticmethod
    def _collate(chunk):
        """Stack interaction dicts from a single snapshot into batch tensors."""
        src = torch.tensor([x['src'] for x in chunk], dtype=torch.long)
        dst = torch.tensor([x['dst'] for x in chunk], dtype=torch.long)
        t = torch.tensor([x['time'] for x in chunk], dtype=torch.long)
        edge_attr = torch.stack([x['edge_attr'] for x in chunk])
        labels = torch.tensor([x['label'] for x in chunk], dtype=torch.float)
        return src, dst, t, edge_attr, labels, chunk[0]['node_features']

In [None]:
print("\n=== Training TGN ===")
# Flatten the train / test sequence subsets into chronological interaction lists.
tgn_train_dataset, tgn_test_dataset = (TGNFlightDataset(split) for split in (train_set, test_set))

### Model formulation

In [None]:
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
)

# Re-implementation of PyG's TimeEncoder (the original is also importable via
# `from torch_geometric.nn.models.tgn import TimeEncoder`).
class TimeEncoder(nn.Module):
    """
    Learnable cosine time encoding, as used by TGN.

    Maps each timestamp t to cos(W t + b) with learned W, b in R^out_channels.
    The cosine keeps every component in [-1, 1] whatever the magnitude of t.
    A purely linear map of raw timestamps (e.g. block indices) grows without
    bound.
    """
    def __init__(self, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.lin = nn.Linear(1, out_channels)

    def reset_parameters(self):
        """Re-initialise the linear projection."""
        self.lin.reset_parameters()

    def forward(self, t):
        """
        Args:
            t: timestamps (floating point); flattened to [E].
        Returns:
            [E, out_channels] time encodings in [-1, 1].
        """
        return self.lin(t.view(-1, 1)).cos()

import copy

class FlightTGN(nn.Module):
    """
    Temporal Graph Network for flight delay prediction.

    Each node (airport) keeps a memory vector in a PyG ``TGNMemory``. A flight
    (edge) is scored from [src memory, dst memory, edge features, time
    encoding]. After a batch has been scored, ``update_memory`` feeds that
    batch's interactions into the memory.

    Args:
        num_nodes: number of nodes tracked by the memory.
        memory_dim: dimension of each node's memory vector.
        time_dim: dimension of the time encodings.
        embedding_dim: hidden size of the edge encoder.
        edge_feat_dim: dimension of the edge features.
        node_feat_dim: dimension of the node features put into memory messages.

    Raw memory messages are [src node features, edge features, time encoding],
    so their width is ``node_feat_dim + edge_feat_dim + time_dim``.
    """
    def __init__(self, num_nodes, memory_dim=100, time_dim=100,
                 embedding_dim=100, edge_feat_dim=8, node_feat_dim=9):
        super().__init__()

        self.num_nodes = num_nodes
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim
        self.node_feat_dim = node_feat_dim

        # Message dimension: node features + edge features + time encoding
        raw_msg_dim = node_feat_dim + edge_feat_dim + time_dim

        # TGN memory module: tracks interaction history for each node.
        self.memory = TGNMemory(
            num_nodes=num_nodes,
            raw_msg_dim=raw_msg_dim,
            memory_dim=memory_dim,
            time_dim=time_dim,
            message_module=IdentityMessage(
                raw_msg_dim=raw_msg_dim,
                memory_dim=memory_dim,
                time_dim=time_dim
            ),
            aggregator_module=LastAggregator(),
        )

        self.time_encoder = TimeEncoder(time_dim)

        # Edge embedding: src memory, dst memory, edge features, time encoding.
        self.edge_encoder = nn.Sequential(
            nn.Linear(2 * memory_dim + edge_feat_dim + time_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

        # Prediction head: regression (delay in minutes).
        self.predictor = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )

    def forward(self, src, dst, t, edge_attr, node_features):
        """
        Args:
            src: source node indices [E]
            dst: destination node indices [E]
            t: timestamps [E] (block indices in this notebook)
            edge_attr: edge features [E, edge_feat_dim] (a single 1-D row is
                treated as one edge)
            node_features: accepted for API symmetry with ``update_memory``;
                not used by the prediction path.

        Returns:
            predictions: [E] predicted delays (in whatever scale the labels
            were trained on).
        """
        src_memory, _ = self.memory(src)  # [E, memory_dim]
        dst_memory, _ = self.memory(dst)  # [E, memory_dim]

        time_encoding = self.time_encoder(t.float())  # [E, time_dim]

        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(0)

        edge_emb = torch.cat([src_memory, dst_memory, edge_attr, time_encoding], dim=-1)
        edge_emb = self.edge_encoder(edge_emb)

        return self.predictor(edge_emb).squeeze(-1)

    def update_memory(self, src, dst, t, edge_attr, node_features):
        """
        Push a batch of interactions into the TGN memory.

        The raw message for each interaction is
        [node_features[src], edge_attr, time_encoder(t)]. When
        ``node_features`` is None, a zero block of width ``node_feat_dim`` is
        used so the message keeps the width the memory was built with.
        """
        time_encoding = self.time_encoder(t.float())

        if node_features is not None:
            src_node_feat = node_features[src]
        else:
            src_node_feat = torch.zeros(
                len(src), self.node_feat_dim, device=src.device, dtype=edge_attr.dtype
            )

        raw_msg = torch.cat([src_node_feat, edge_attr, time_encoding], dim=-1)
        self.memory.update_state(src, dst, t, raw_msg)

    def reset_memory(self):
        """Reset memory between epochs or between train and test."""
        self.memory.reset_state()



def train_tgn(model, train_dataset, optimizer, device, normalize=True):
    """
    Train a FlightTGN for one epoch over chronologically ordered batches.

    Memory is reset first. After each optimisation step, the batch's
    interactions are written into memory (without gradients). Predictions for
    a batch therefore only use memory from earlier batches.

    Args:
        model: FlightTGN (or a module with the same forward / update_memory /
            reset_memory API).
        train_dataset: object with ``get_temporal_batches(batch_size)``.
        optimizer: torch optimizer over ``model.parameters()``.
        device: device to run on.
        normalize: z-score labels with the mean/std of all training labels.

    Returns:
        (mean per-interaction MSE, y_mean, y_std). The loss is on the
        normalised scale when ``normalize`` is True; y_mean/y_std are None
        when it is False.

    Raises:
        ValueError: if the dataset yields no interactions.
    """
    model.train()
    model.reset_memory()

    batches = train_dataset.get_temporal_batches(batch_size=200)
    if not batches:
        raise ValueError("train_dataset produced no interactions")

    y_mean = y_std = None
    if normalize:
        all_labels = torch.cat([b[4] for b in batches])
        y_mean = all_labels.mean()
        y_std = all_labels.std()
        # Degenerate (constant / single) label sets fall back to std = 1.
        y_std = y_std if y_std > 1e-6 else torch.tensor(1.0)
        y_mean_dev, y_std_dev = y_mean.to(device), y_std.to(device)

    total_loss = 0.0
    total_samples = 0

    for src, dst, t, edge_attr, labels, node_features in batches:
        src, dst, t = src.to(device), dst.to(device), t.to(device)
        edge_attr = edge_attr.to(device)
        labels = labels.to(device)
        node_features = node_features.to(device)

        target = (labels - y_mean_dev) / y_std_dev if normalize else labels

        optimizer.zero_grad()
        preds = model(src, dst, t, edge_attr, node_features)
        loss = F.mse_loss(preds, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        with torch.no_grad():
            model.update_memory(src, dst, t, edge_attr, node_features)

        total_loss += loss.item() * len(labels)
        total_samples += len(labels)

    return total_loss / total_samples, y_mean, y_std


@torch.no_grad()
def evaluate_tgn(model, test_dataset, device, y_mean=None, y_std=None):
    """
    Evaluate a FlightTGN by replaying a dataset chronologically.

    Memory is reset first. Each batch is predicted before its interactions
    are written into memory.

    Args:
        model: FlightTGN (or a module with the same API).
        test_dataset: object with ``get_temporal_batches(batch_size)``.
        device: device to run on.
        y_mean, y_std: label normalisation statistics (tensor or float) as
            returned by ``train_tgn``. When both are given, predictions are
            mapped back to the original scale before computing errors.

    Returns:
        dict with ``MSE``, ``MAE`` and ``RMSE`` over all interactions.

    Raises:
        ValueError: if the dataset yields no interactions.
    """
    model.eval()
    model.reset_memory()

    batches = test_dataset.get_temporal_batches(batch_size=200)
    if not batches:
        raise ValueError("test_dataset produced no interactions")

    denormalize = y_mean is not None and y_std is not None
    if denormalize:
        # as_tensor also accepts plain floats; done once, outside the loop.
        y_mean = torch.as_tensor(y_mean, dtype=torch.float, device=device)
        y_std = torch.as_tensor(y_std, dtype=torch.float, device=device)

    all_preds, all_labels = [], []

    for src, dst, t, edge_attr, labels, node_features in batches:
        src, dst, t = src.to(device), dst.to(device), t.to(device)
        edge_attr = edge_attr.to(device)
        labels = labels.to(device)
        node_features = node_features.to(device)

        preds = model(src, dst, t, edge_attr, node_features)
        if denormalize:
            preds = preds * y_std + y_mean

        model.update_memory(src, dst, t, edge_attr, node_features)

        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

    err = torch.cat(all_preds) - torch.cat(all_labels)
    mse = (err ** 2).mean().item()
    mae = err.abs().mean().item()
    return {"MSE": mse, "MAE": mae, "RMSE": mse ** 0.5}

In [None]:
# One memory slot per node of the snapshot graphs.
num_nodes = dataset[0].num_nodes
model = FlightTGN(num_nodes, memory_dim=100, time_dim=100, embedding_dim=100,
                  edge_feat_dim=8, node_feat_dim=9)
model = model.to(device)

In [None]:
# Adam with light L2 regularisation.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [None]:
# Training loop.
# The checkpoint is selected on the validation split and the test split is
# scored once, with the selected weights. Selecting on test would make the
# reported test numbers optimistically biased.
tgn_val_dataset = TGNFlightDataset(val_set)

best_mae = float('inf')   # best validation MAE so far
best_epoch = None
best_state = None

for epoch in range(1, 51):
    train_loss, y_mean, y_std = train_tgn(model, tgn_train_dataset, optimizer, device)
    val_metrics = evaluate_tgn(model, tgn_val_dataset, device, y_mean, y_std)

    if val_metrics['MAE'] < best_mae:
        best_mae = val_metrics['MAE']
        best_epoch = epoch
        best_state = copy.deepcopy(model.state_dict())

    print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | "
          f"Val MAE: {val_metrics['MAE']:.2f}, RMSE: {val_metrics['RMSE']:.2f}")

# Restore the best-on-validation weights, then evaluate on test once.
if best_state is not None:
    model.load_state_dict(best_state)
best_metrics = evaluate_tgn(model, tgn_test_dataset, device, y_mean, y_std)

print(f"\nBest epoch (by val MAE): {best_epoch}")
print(f"Best Test Metrics: {best_metrics}")

## GCN + temporal features

### Dataset

In [None]:
class SkyNetDataset(InMemoryDataset):
    """
    SkyNet snapshot graphs with per-edge train/test masks.

    Loads the pre-processed ``flight_graphs.pt`` and attaches boolean
    ``train_edge_mask`` / ``test_edge_mask`` (one entry per edge, complementary)
    to every graph, based on its ``time`` attribute:

    * ``'month_split'``: graphs from Jan–Sep are train, Oct–Dec are test.
    * ``'week_split'``: graphs from days 1–21 of a month are train,
      days >= 22 are test.

    Raises:
        ValueError: if ``mode`` is not one of the two values above.
    """
    _MODES = ('month_split', 'week_split')

    def __init__(self, root, mode='month_split', transform=None, pre_transform=None):
        # Validate up front so a bad mode fails even on an empty dataset.
        if mode not in self._MODES:
            raise ValueError("mode must be 'month_split' or 'week_split'")
        super().__init__(root, transform, pre_transform)
        self.mode = mode

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # trusted files.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

        new_list = []
        for i in range(len(self)):
            data = self.get(i)
            num_edges = data.edge_index.size(1)
            # Whole-graph split: every edge of a snapshot gets the same label.
            data.train_edge_mask = torch.full((num_edges,), self._is_train(data.time), dtype=torch.bool)
            data.test_edge_mask = ~data.train_edge_mask
            new_list.append(data)

        self.data, self.slices = self.collate(new_list)

    def _is_train(self, t):
        """Return True if a snapshot with timestamp ``t`` belongs to the train split."""
        if self.mode == 'month_split':
            return 1 <= t.month <= 9
        return t.day <= 21

    @property
    def raw_file_names(self): return []
    @property
    def processed_file_names(self): return ['flight_graphs.pt']
    def process(self): pass

In [None]:
# Load the graphs with month-based train/test edge masks attached.
dataset_location = '/content/drive/Shareddrives/CS_224W_Project/data/data/skynet_clean_graphs'
dataset = SkyNetDataset(dataset_location, mode='month_split')

In [None]:
# Temporal train/test split by calendar month. Each graph is one 6-hour
# window, so Jan–Sep graphs train and Oct–Dec graphs test (no shuffling).
train_graphs, test_graphs = [], []
for g in dataset:
    (train_graphs if g.time.month <= 9 else test_graphs).append(g)

def merge_graphs(graph_list, is_test=False):
    """
    Merge snapshot graphs into one PyG ``Data`` object (a disjoint union).

    Each graph's edge indices are offset by the number of nodes in the graphs
    before it. Snapshots therefore stay disconnected from one another, and
    message passing never crosses a snapshot boundary.

    Args:
        graph_list: non-empty iterable of ``Data`` with ``x``, ``edge_index``,
            ``edge_attr`` and ``y``.
        is_test: False (default) marks every edge as train
            (``train_edge_mask`` all True, ``test_edge_mask`` all False);
            True swaps the two masks.

    Returns:
        Data with concatenated ``x`` / ``edge_attr`` / ``y``, shifted
        ``edge_index``, both masks, and ``airport_ids``. For each merged node,
        ``airport_ids`` is the snapshot's own ``airport_ids`` if present,
        otherwise the node's index within its snapshot. The fallback assumes
        node i is the same airport in every snapshot — TODO confirm against
        the graph builder.

    Raises:
        ValueError: if ``graph_list`` is empty.
    """
    graph_list = list(graph_list)
    if not graph_list:
        raise ValueError("merge_graphs needs at least one graph")

    x_list, ei_list, ea_list, y_list, id_list = [], [], [], [], []
    node_offset = 0

    for g in graph_list:
        num_nodes = g.num_nodes
        x_list.append(g.x)
        # Shift edge endpoints so they point into this graph's block of nodes.
        ei_list.append(g.edge_index + node_offset)
        ea_list.append(g.edge_attr)
        y_list.append(g.y)
        own_ids = getattr(g, 'airport_ids', None)
        id_list.append(own_ids if own_ids is not None else torch.arange(num_nodes))
        node_offset += num_nodes

    x = torch.cat(x_list, dim=0)
    edge_index = torch.cat(ei_list, dim=1)
    edge_attr = torch.cat(ea_list, dim=0)
    y = torch.cat(y_list, dim=0)

    train_mask = torch.full((edge_index.size(1),), not is_test, dtype=torch.bool)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=y,
        train_edge_mask=train_mask,
        test_edge_mask=~train_mask,
        airport_ids=torch.cat(id_list, dim=0),
    )
    data.num_nodes = x.size(0)
    return data

In [None]:
def add_temporal_features(graph):
    """
    Append 6 cyclic calendar features to every edge of ``graph``.

    From ``graph.time`` it derives day-of-week (period 7), hour (period 24)
    and month (period 12). Each is encoded as (sin, cos) of 2*pi*value/period,
    in that order. The same 6 values are repeated for every edge and
    concatenated onto ``edge_attr`` (8 -> 14 columns for SkyNet graphs). The
    graph is modified in place and also returned.
    """
    ts = graph.time
    cycles = ((ts.weekday(), 7), (ts.hour, 24), (ts.month, 12))

    encoded = []
    for value, period in cycles:
        angle = 2 * np.pi * value / period
        encoded.extend((np.sin(angle), np.cos(angle)))

    n_edges = graph.edge_attr.size(0)
    per_edge = torch.tensor(encoded, dtype=torch.float).unsqueeze(0).repeat(n_edges, 1)

    graph.edge_attr = torch.cat([graph.edge_attr, per_edge], dim=1)
    return graph

In [None]:
# Split by month: Jan–Sep -> train, Oct–Dec -> test
train_graphs = [g for g in dataset if g.time.month <= 9]
test_graphs = [g for g in dataset if g.time.month >= 10]

# Append 6 cyclic temporal features (day-of-week, hour, month as sin/cos) to every edge.
# NOTE(review): with this split the month features only cover Jan–Sep during
# training, so the Oct–Dec test values lie outside what the model has seen.
train_graphs = [add_temporal_features(g) for g in train_graphs]
test_graphs = [add_temporal_features(g) for g in test_graphs]

print(f"New edge features: {train_graphs[0].edge_attr.shape}")  # Should be [num_edges, 14]

# Normalize x, edge_attr and y using training-split statistics only.
# (Near-)constant columns get std = 1.0, as in SequenceSkyNetDataset, rather
# than std + 1e-8: dividing by ~1e-8 would turn any train/test shift in such
# a column into values of order 1e8.
all_train_x = torch.cat([g.x for g in train_graphs], dim=0)
x_mean = all_train_x.mean(dim=0, keepdim=True)
x_std = all_train_x.std(dim=0, keepdim=True)
x_std[~(x_std >= 1e-6)] = 1.0

all_train_edge = torch.cat([g.edge_attr for g in train_graphs], dim=0)
edge_mean = all_train_edge.mean(dim=0, keepdim=True)
edge_std = all_train_edge.std(dim=0, keepdim=True)
edge_std[~(edge_std >= 1e-6)] = 1.0

all_train_y = torch.cat([g.y for g in train_graphs])
y_mean = all_train_y.mean()
y_std = all_train_y.std()
if not y_std >= 1e-6:
    y_std = torch.tensor(1.0)

for g in train_graphs + test_graphs:
    g.x = (g.x - x_mean) / x_std
    g.edge_attr = (g.edge_attr - edge_mean) / edge_std
    g.y = (g.y - y_mean) / y_std

print(f"Y: mean={y_mean:.2f}, std={y_std:.2f}")

# Merge each split into one disjoint-union graph.
print("Merging graphs...")
train_graph = merge_graphs(train_graphs)
test_graph = merge_graphs(test_graphs)

# merge_graphs marks every edge as train; flip the masks for the test graph.
test_graph.train_edge_mask[:] = False
test_graph.test_edge_mask[:] = True

print(f"Train graph: {train_graph}")
print(f"Test graph: {test_graph}")

train_loader = DataLoader([train_graph], batch_size=1)
test_loader = DataLoader([test_graph], batch_size=1)

### Model

Same GCN as used in the main experiments.

In [None]:
# ===========================================================
# Three-layer GCN (BatchNorm + dropout) with an MLP edge head
# ===========================================================
class EdgeRegressionGCN3(nn.Module):
    """
    Three GCNConv layers followed by an MLP that regresses one value per edge.

    Every GCN layer is followed by BatchNorm, ReLU and dropout. An edge
    (u, v) is scored from the concatenation [h_u, h_v, edge_attr].
    """
    def __init__(self, node_in, edge_in, hid=128, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # Construction order is kept fixed (convs, norms, head) so parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = GCNConv(node_in, hid)
        self.conv2 = GCNConv(hid, hid)
        self.conv3 = GCNConv(hid, hid)

        self.bn1 = nn.BatchNorm1d(hid)
        self.bn2 = nn.BatchNorm1d(hid)
        self.bn3 = nn.BatchNorm1d(hid)

        head_layers = [
            nn.Linear(2 * hid + edge_in, hid), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hid, hid // 2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hid // 2, 1),
        ]
        self.edge_mlp = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, edge_attr):
        h = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = F.relu(norm(conv(h, edge_index)))
            h = F.dropout(h, p=self.dropout, training=self.training)

        src, dst = edge_index
        edge_repr = torch.cat((h[src], h[dst], edge_attr), dim=1)
        return self.edge_mlp(edge_repr).squeeze(-1)

In [None]:
# Build the temporal-feature GCN; input widths come from the merged training graph.
model = EdgeRegressionGCN3(
    train_graph.x.size(1),
    train_graph.edge_attr.size(1),
    hid=256,
    dropout=0.3,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
# Train and evaluate
# NOTE(review): `train` and `evaluate` are not defined anywhere in this
# notebook. Presumably they come from the main-experiments notebook; define or
# import them before running this cell. Whether `evaluate` maps predictions
# back with y_mean / y_std is not visible here — confirm before comparing MAEs.
trained_model, loss_history = train(model, train_loader, optimizer, criterion, epochs=70)
results = evaluate(trained_model, test_loader)

print("\n=== FINAL TEST RESULTS ===")
print("RMSE:", results["rmse"])
print("MAE :", results["mae"])

## GCN + memory

### Dataset

In [None]:
train_graphs = [g for g in dataset if g.time.month <= 9]   # Jan-Sep
test_graphs = [g for g in dataset if g.time.month >= 10]   # Oct-Dec
# Chronological order within each split.
train_graphs = sorted(train_graphs, key=lambda g: g.time)
test_graphs = sorted(test_graphs, key=lambda g: g.time)

# NORMALIZE with training-split statistics only. (Near-)constant columns get
# std = 1.0 instead of std + 1e-8, which would blow up any test-time shift.
# Node features
all_train_x = torch.cat([g.x for g in train_graphs], dim=0)
x_mean = all_train_x.mean(dim=0, keepdim=True)
x_std = all_train_x.std(dim=0, keepdim=True)
x_std[~(x_std >= 1e-6)] = 1.0

for g in train_graphs + test_graphs:
    g.x = (g.x - x_mean) / x_std

# Edge features
all_train_edge = torch.cat([g.edge_attr for g in train_graphs], dim=0)
edge_mean = all_train_edge.mean(dim=0, keepdim=True)
edge_std = all_train_edge.std(dim=0, keepdim=True)
edge_std[~(edge_std >= 1e-6)] = 1.0

for g in train_graphs + test_graphs:
    g.edge_attr = (g.edge_attr - edge_mean) / edge_std

# Targets
all_train_y = torch.cat([g.y for g in train_graphs])
y_mean = all_train_y.mean()
y_std = all_train_y.std()
if not y_std >= 1e-6:
    y_std = torch.tensor(1.0)

for g in train_graphs + test_graphs:
    g.y = (g.y - y_mean) / y_std

print(f"Y normalization: mean={y_mean:.2f}, std={y_std:.2f}")

train_graph = merge_graphs(train_graphs)
test_graph = merge_graphs(test_graphs)

# merge_graphs marks every edge as train. Without flipping the masks here,
# evaluate_with_memory would select no edges and divide by zero.
test_graph.train_edge_mask[:] = False
test_graph.test_edge_mask[:] = True

# EdgeRegressionGCNWithMemory indexes its per-airport memory with
# batch.airport_ids. If the merged graph lacks it, map node i of every
# snapshot to airport i (assumes each snapshot has the same, identically
# ordered node set — TODO confirm).
for merged, parts in ((train_graph, train_graphs), (test_graph, test_graphs)):
    if getattr(merged, 'airport_ids', None) is None:
        merged.airport_ids = torch.cat([torch.arange(g.num_nodes) for g in parts])

train_loader = DataLoader([train_graph], batch_size=1)
test_loader = DataLoader([test_graph], batch_size=1)


### Model

In [None]:
# ===========================================================
# GCN model with BN + Dropout + Edge MLP
# Added memory and different edges for message passing vs loss
# NOTE(review): the same edge_index is used for message passing and for the
# edge predictions inside this class; any train/test separation happens via
# the edge masks applied by the caller — confirm that is what is meant above.
# ===========================================================
class EdgeRegressionGCNWithMemory(nn.Module):
    """
    Edge-regression GCN whose node inputs are augmented with a per-airport memory.

    forward():
      1. concatenates ``self.memory[airport_ids]`` onto the node features,
      2. runs three GCNConv -> BatchNorm -> ReLU -> dropout layers,
      3. scores each edge from [z_src, z_dst, edge_attr],
      4. then, under torch.no_grad(), overwrites the memory with a GRUCell
         step whose input is the per-airport mean of z_to_mem(z).

    NOTE(review): z_to_mem and memory_update are only ever called inside
    torch.no_grad(), so they receive no gradients and keep their initial
    weights. self.memory is an nn.Parameter, so the optimizer also updates it
    between these overwrites. Confirm this mix is intended.
    """
    def __init__(self, node_in, edge_in, num_airports, hid=128, dropout=0.3, memory_dim=64):
        super().__init__()
        self.dropout = dropout
        self.memory_dim = memory_dim
        self.num_airports = num_airports

        # Airport memory (persistent state): one row per airport, zero-initialised;
        # reset_memory() zeroes it again.
        self.memory = nn.Parameter(torch.zeros(num_airports, memory_dim))

        # GCN layers (now take node features + memory)
        self.conv1 = GCNConv(node_in + memory_dim, hid)
        self.conv2 = GCNConv(hid, hid)
        self.conv3 = GCNConv(hid, hid)

        self.bn1 = nn.BatchNorm1d(hid)
        self.bn2 = nn.BatchNorm1d(hid)
        self.bn3 = nn.BatchNorm1d(hid)

        # Edge prediction MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(hid*2 + edge_in, hid),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid, hid//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid//2, 1)
        )

        # Project node embeddings to memory dimension
        self.z_to_mem = nn.Linear(hid, memory_dim)

        # Memory update GRU
        self.memory_update = nn.GRUCell(memory_dim, memory_dim)

    def forward(self, x, edge_index, edge_attr, airport_ids):
        """
        Predict one value per edge, then update the airport memory in place.

        ``airport_ids`` gives, for every row of ``x``, the memory row to use
        (its length must equal x.size(0) for the concat below). Predictions
        use the memory as it was *before* this call.
        """
        node_memory = self.memory[airport_ids]
        x_with_mem = torch.cat([x, node_memory], dim=1)

        # GCN layers
        z = F.relu(self.bn1(self.conv1(x_with_mem, edge_index)))
        z = F.dropout(z, p=self.dropout, training=self.training)

        z = F.relu(self.bn2(self.conv2(z, edge_index)))
        z = F.dropout(z, p=self.dropout, training=self.training)

        z = F.relu(self.bn3(self.conv3(z, edge_index)))
        z = F.dropout(z, p=self.dropout, training=self.training)

        # Edge prediction
        src, dst = edge_index
        e = torch.cat([z[src], z[dst], edge_attr], dim=1)
        pred = self.edge_mlp(e).squeeze(-1)

        # Update memory (no gradients, after prediction)
        with torch.no_grad():
            z_proj = self.z_to_mem(z)  # [num_nodes, memory_dim]

            # Per-airport mean of the projected embeddings; airports with no
            # node in this call get a zero input (0 / (0 + 1e-8) = 0).
            new_memory = torch.zeros_like(self.memory)
            counts = torch.zeros(self.num_airports, device=x.device)

            new_memory.index_add_(0, airport_ids, z_proj)
            counts.index_add_(0, airport_ids, torch.ones(len(airport_ids), device=x.device))

            new_memory = new_memory / (counts.unsqueeze(1) + 1e-8)
            # GRUCell(input=mean embedding, hidden=old memory) -> new memory.
            self.memory.data = self.memory_update(new_memory, self.memory)

        return pred

    def reset_memory(self):
        """Zero the airport memory in place."""
        self.memory.data.zero_()

In [None]:
def train_with_memory(model, loader, optimizer, criterion, epochs=40):
    """
    Train a memory GCN for ``epochs`` epochs, resetting its memory at each start.

    The loss is computed only on edges selected by ``train_edge_mask``.
    Batches are moved to the module-level ``device``. Prints and records the
    summed loss of every epoch.

    Returns:
        (model, list of per-epoch summed losses).
    """
    model.train()
    history = []

    for epoch_idx in range(1, epochs + 1):
        model.reset_memory()  # fresh memory for each pass over the data
        epoch_loss = 0

        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.airport_ids)
            mask = batch.train_edge_mask
            loss = criterion(out[mask], batch.y[mask])

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        history.append(epoch_loss)
        print(f"Epoch {epoch_idx:03d} | Loss = {epoch_loss:.4f}")

    return model, history


In [None]:
# ===========================================================
# Evaluation function
# ===========================================================
@torch.no_grad()
def evaluate_with_memory(model, loader, y_mean, y_std):
    """
    Score a memory GCN on the edges selected by ``test_edge_mask``.

    Memory is reset first. Predictions and targets are mapped back to the
    original label scale (``* y_std + y_mean``) before errors are computed.
    Batches are moved to the module-level ``device``.

    Returns:
        dict with ``mse``, ``rmse`` and ``mae`` averaged over all masked edges.

    Raises:
        ValueError: if ``test_edge_mask`` selects no edge in any batch.
    """
    model.eval()
    model.reset_memory()  # Reset memory for test set

    total_mse = 0.0
    total_mae = 0.0
    total_count = 0

    for batch in loader:
        batch = batch.to(device)

        pred_norm = model(batch.x, batch.edge_index, batch.edge_attr, batch.airport_ids)

        # Denormalize
        pred = pred_norm * y_std + y_mean
        target = batch.y * y_std + y_mean

        mask = batch.test_edge_mask
        p, t = pred[mask], target[mask]

        total_mse += F.mse_loss(p, t, reduction='sum').item()
        total_mae += F.l1_loss(p, t, reduction='sum').item()
        total_count += t.numel()

    if total_count == 0:
        raise ValueError(
            "No edges selected by test_edge_mask; mark the test graph's edges as "
            "test before evaluating (merge_graphs marks every edge as train)."
        )

    mse = total_mse / total_count
    return {
        "mse": mse,
        "rmse": mse ** 0.5,
        "mae": total_mae / total_count
    }


In [None]:
# One memory row per node of a single snapshot.
num_airports = train_graphs[0].num_nodes
print(f"\nNumber of airports: {num_airports}")

model = EdgeRegressionGCNWithMemory(
    train_graph.x.size(1),          # node_in
    train_graph.edge_attr.size(1),  # edge_in
    num_airports,
    hid=256,
    dropout=0.3,
    memory_dim=64,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
# 100 epochs; memory is reset at the start of each epoch.
model, loss_history = train_with_memory(
    model=model,
    loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=100,
)

In [None]:
# Errors are reported on the original (de-normalised) label scale.
results = evaluate_with_memory(model, test_loader, y_mean, y_std)
print("\n=== FINAL TEST RESULTS ===")
for label, key in (("RMSE", "rmse"), ("MAE ", "mae")):
    print(f"{label}: {results[key]:.2f} minutes")