In [1]:


import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch.utils.data import Dataset, DataLoader
import scipy.sparse as sp
import pickle
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

class DataLoader_METR_LA:
    """
    Optimized DataLoader for the METR-LA traffic speed dataset
    Loads continuous time series where each time step's output is the input for the next time step
    """

    def __init__(self, data_path, batch_size, test_batch_size=None,
                 seq_length=12, normalize=True):
        self.data_path = data_path
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size or batch_size
        self.seq_length = seq_length
        self.normalize = normalize

        self.load_data()

    def load_data(self):
        # Load speed data
        df = pd.read_hdf(os.path.join(self.data_path, 'METR-LA.h5'))
        values = df.values  # shape: [time, nodes]
        self.num_nodes = values.shape[1]
        print(f"[Data] Loaded speed matrix: {values.shape}")

        # Load adjacency matrix
        adj_file = 'adj_METR-LA.pkl'
        if not os.path.exists(os.path.join(self.data_path, adj_file)):
            adj_file = 'adj_mx.pkl'

        print(f"[Graph] Loading adjacency matrix from: {adj_file}")
        with open(os.path.join(self.data_path, adj_file), 'rb') as f:
            adj_data = pickle.load(f, encoding='latin1')

        # Handle different formats of adjacency matrix data
        if isinstance(adj_data, (list, tuple)) and len(adj_data) == 3:
            _, _, adj_mx = adj_data
        else:
            adj_mx = adj_data[-1] if isinstance(adj_data, (list, tuple)) else adj_data

        # Ensure the adjacency matrix is a 2D dense matrix
        adj_mx = np.array(adj_mx)
        if adj_mx.ndim > 2:
            adj_mx = adj_mx[0]
        print(f"[Graph] Adjacency shape: {adj_mx.shape}")

        # Convert to PyTorch Geometric format
        adj_coo = sp.coo_matrix(adj_mx)
        indices = np.vstack((adj_coo.row, adj_coo.col))
        self.edge_index = torch.LongTensor(indices)
        self.edge_weight = torch.FloatTensor(adj_coo.data)

        # Data normalization
        if self.normalize:
            self.mean = np.mean(values)
            self.std = np.std(values)
            values = (values - self.mean) / self.std
        else:
            self.mean, self.std = 0, 1

        # Dataset split (70/10/20)
        num_samples = len(values)
        num_train = int(num_samples * 0.7)
        num_test = int(num_samples * 0.2)
        num_val = num_samples - num_train - num_test

        train_data = values[:num_train]
        val_data = values[num_train:num_train + num_val]
        test_data = values[num_train + num_val:]

        print(f"[Split] Train: {train_data.shape}, Val: {val_data.shape}, Test: {test_data.shape}")

        # Create PyTorch datasets and data loaders
        self.train_dataset = METR_LA_Dataset(train_data, self.seq_length)
        self.val_dataset = METR_LA_Dataset(val_data, self.seq_length)
        self.test_dataset = METR_LA_Dataset(test_data, self.seq_length)

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size,
                                       shuffle=True, drop_last=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.test_batch_size,
                                     shuffle=False, drop_last=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.test_batch_size,
                                      shuffle=False, drop_last=True)

    def get_loaders(self):
        return (self.train_loader, self.val_loader, self.test_loader,
                self.edge_index, self.edge_weight, self.mean, self.std)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

class METR_LA_Dataset(Dataset):
    """
    Optimized METR-LA dataset class
    Each sample is a continuous window of length seq_length, including input and target
    x1, x2, ..., x_{seq_length} as input
    x2, x3, ..., x_{seq_length+1} as target
    """
    def __init__(self, data, seq_length):
        self.data = data
        self.seq_length = seq_length
        # We need one more time step than the sequence length (for the label)
        self.num_samples = len(data) - seq_length

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Get window of length seq_length+1
        time_window = self.data[idx : idx + self.seq_length + 1]

        # Swap time and node dimensions to form [num_nodes, total_time_steps]
        time_window = np.transpose(time_window, (1, 0))

        # Separate input sequence and target sequence (target is input shifted by one step)
        # Input: x1, x2, ..., x_{seq_length}
        x_seq = time_window[:, :-1]
        # Target: x2, x3, ..., x_{seq_length+1}
        y_target = time_window[:, 1:]

        # Add feature dimension, convert to tensor
        x_seq = torch.FloatTensor(x_seq).unsqueeze(-1)  # [num_nodes, seq_length, 1]
        y_target = torch.FloatTensor(y_target)          # [num_nodes, seq_length]

        return x_seq, y_target


In [3]:

def mse_loss(preds, labels):
    """Mean squared error loss"""
    return torch.mean((preds - labels) ** 2)

def mae_loss(preds, labels):
    """Mean absolute error loss"""
    return torch.mean(torch.abs(preds - labels))

def rmse_loss(preds, labels):
    """Root mean squared error loss"""
    return torch.sqrt(mse_loss(preds, labels))

def metric(preds, labels):
    """Calculate multiple evaluation metrics"""
    mse_val = mse_loss(preds, labels).item()
    mae_val = mae_loss(preds, labels).item()
    rmse_val = rmse_loss(preds, labels).item()
    return mse_val, mae_val, rmse_val


def train(model, train_loader, optimizer, criterion, device, edge_index, edge_weight, epoch):
    """Train model for one epoch"""
    model.train()
    total_loss = 0

    with tqdm(train_loader, desc=f"Epoch {epoch}") as t:
        for i, (x_seq, y_target) in enumerate(t):
            optimizer.zero_grad()

            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Forward pass
            y_pred = model(x_seq, edge_index, edge_weight)  # [batch_size, num_nodes, seq_length]

            # Calculate loss
            loss = criterion(y_pred, y_target)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Update progress bar
            total_loss += loss.item()
            t.set_postfix(loss=total_loss / (i + 1))

    return total_loss / len(train_loader)


def evaluate(model, data_loader, device, edge_index, edge_weight):
    """Evaluate model on validation or test set"""
    model.eval()
    mse_sum, mae_sum, rmse_sum = 0, 0, 0

    with torch.no_grad():
        for x_seq, y_target in data_loader:
            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Forward pass
            y_pred = model(x_seq, edge_index, edge_weight)  # [batch_size, num_nodes, seq_length]

            # Calculate evaluation metrics
            mse_val, mae_val, rmse_val = metric(y_pred, y_target)
            mse_sum += mse_val
            mae_sum += mae_val
            rmse_sum += rmse_val

    num_batches = len(data_loader)
    return (
        mse_sum / num_batches,
        mae_sum / num_batches,
        rmse_sum / num_batches
    )


def evaluate_future_steps(model, data_loader, device, edge_index, edge_weight, future_steps=3):
    """Evaluate multi-step future prediction"""
    model.eval()
    mse_sum, mae_sum, rmse_sum = 0, 0, 0

    with torch.no_grad():
        for x_seq, y_target in data_loader:
            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Use x_seq as historical data, predict future_steps steps ahead
            y_pred = model.predict_future_steps(
                x_seq, edge_index, edge_weight, future_steps
            )  # [batch_size, num_nodes, future_steps]

            # Get the true future values for the next future_steps from complete target
            true_future = y_target[:, :, :future_steps]

            # Calculate evaluation metrics
            mse_val, mae_val, rmse_val = metric(y_pred, true_future)
            mse_sum += mse_val
            mae_sum += mae_val
            rmse_sum += rmse_val

    num_batches = len(data_loader)
    return (
        mse_sum / num_batches,
        mae_sum / num_batches,
        rmse_sum / num_batches
    )



In [6]:

class GNNLayer(nn.Module):
    """
    Graph Neural Network layer for spatial feature extraction
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1, use_gat=False):
        super(GNNLayer, self).__init__()

        self.use_gat = use_gat

        if use_gat:
            # GAT version
            self.conv = GATConv(in_channels, out_channels, heads=1, dropout=dropout)
        else:
            # GCN version
            self.conv = GCNConv(in_channels, out_channels)

        self.norm = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x, edge_index, edge_weight=None):
        """
        x: node features [batch_size*num_nodes, in_channels]
        edge_index: edge indices [2, num_edges]
        edge_weight: edge weights [num_edges]
        """
        if self.use_gat:
            # GAT doesn't use edge_weight
            x = self.conv(x, edge_index)
        else:
            # GCN uses edge_weight
            x = self.conv(x, edge_index, edge_weight)

        x = self.norm(x)
        x = self.activation(x)
        x = self.dropout(x)

        return x


class RNNBlock(nn.Module):
    """
    RNN/GRU/LSTM block for temporal feature extraction and hidden state propagation
    """
    def __init__(self, input_dim, hidden_dim, rnn_type='gru', dropout=0.1):
        super(RNNBlock, self).__init__()

        self.hidden_dim = hidden_dim
        self.rnn_type = rnn_type.lower()

        # Choose RNN type
        if self.rnn_type == 'gru':
            self.rnn_cell = nn.GRUCell(input_dim, hidden_dim)
        elif self.rnn_type == 'lstm':
            self.rnn_cell = nn.LSTMCell(input_dim, hidden_dim)
        else:
            self.rnn_cell = nn.RNNCell(input_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, h_prev):
        """
        x: current time step features [batch_size*num_nodes, input_dim]
        h_prev: previous time step hidden state
               - GRU/RNN: [batch_size*num_nodes, hidden_dim]
               - LSTM: tuple of two tensors
        """
        # Handle hidden state based on RNN type
        if self.rnn_type == 'lstm':
            h_prev_h, h_prev_c = h_prev
            h_new, c_new = self.rnn_cell(x, (h_prev_h, h_prev_c))
            h_new = self.dropout(h_new)
            return (h_new, c_new)
        else:
            h_new = self.rnn_cell(x, h_prev)
            h_new = self.dropout(h_new)
            return h_new


class MLPPredictor(nn.Module):
    """
    Multi-layer perceptron predictor for generating predictions from RNN output
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout=0.1):
        super(MLPPredictor, self).__init__()

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Input layer
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        self.norms.append(nn.LayerNorm(hidden_dim))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))

        # Output layer
        self.layers.append(nn.Linear(hidden_dim, output_dim))

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """
        x: RNN output [batch_size*num_nodes, input_dim]
        """
        # Process hidden layers
        for i, (layer, norm) in enumerate(zip(self.layers[:-1], self.norms)):
            x = layer(x)
            x = norm(x)
            x = self.activation(x)
            x = self.dropout(x)

        # Output layer
        x = self.layers[-1](x)
        return x


class STGNN(nn.Module):
    """
    Spatio-Temporal Graph Neural Network, strictly implemented according to the architecture
    For time series x1, x2, ..., xT:
    1. Time step t: xt → GNN → RNN(ht-1) → ht → MLP → prediction x(t+1)
    2. Hidden state of each time step is passed to the next time step
    """
    def __init__(self,
                 num_nodes,
                 node_features,
                 hidden_dim,
                 seq_length,
                 gnn_type='gcn',
                 rnn_type='gru',
                 dropout=0.1):
        super(STGNN, self).__init__()

        self.num_nodes = num_nodes
        self.node_features = node_features
        self.hidden_dim = hidden_dim
        self.seq_length = seq_length
        self.gnn_type = gnn_type.lower()
        self.rnn_type = rnn_type.lower()

        # GNN layer - for spatial feature extraction
        self.gnn = GNNLayer(
            in_channels=node_features,
            hidden_channels=hidden_dim,
            out_channels=hidden_dim,
            dropout=dropout,
            use_gat=(gnn_type == 'gat')
        )

        # RNN/GRU/LSTM block - for temporal feature extraction
        self.rnn_block = RNNBlock(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            rnn_type=rnn_type,
            dropout=dropout
        )

        # MLP predictor - to predict next time step based on hidden state
        self.predictor = MLPPredictor(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            output_dim=1,  # single-step prediction
            num_layers=2,
            dropout=dropout
        )

    def _init_hidden_state(self, batch_size, device):
        """Initialize RNN hidden state"""
        if self.rnn_type == 'lstm':
            return (torch.zeros(batch_size * self.num_nodes, self.hidden_dim).to(device),
                    torch.zeros(batch_size * self.num_nodes, self.hidden_dim).to(device))
        else:
            return torch.zeros(batch_size * self.num_nodes, self.hidden_dim).to(device)

    def forward(self, x, edge_index, edge_weight=None):
        """
        Training mode: train on the entire input sequence, predict each subsequent time step

        x: input sequence [batch_size, num_nodes, seq_length, node_features]
        edge_index: graph edge index [2, num_edges]
        edge_weight: edge weights [num_edges]

        returns: predictions for all time steps [batch_size, num_nodes, seq_length]
        """
        batch_size = x.size(0)
        device = x.device

        # Store predictions for all time steps
        predictions = []

        # Initialize RNN hidden state
        h = self._init_hidden_state(batch_size, device)

        # Process each time step
        for t in range(self.seq_length):
            # Extract current time step graph data [batch_size, num_nodes, node_features]
            x_t = x[:, :, t, :]

            # Reshape to [batch_size * num_nodes, node_features] for GNN
            x_t_flat = x_t.reshape(batch_size * self.num_nodes, -1)

            # a. Spatial feature extraction through GNN (corresponding to GNN layers in architecture)
            gnn_out = self.gnn(x_t_flat, edge_index, edge_weight)  # [batch_size * num_nodes, hidden_dim]

            # b. Temporal feature extraction through RNN (corresponding to RNN/GRU/LSTM block in architecture)
            h = self.rnn_block(gnn_out, h)

            # c. Predict next time step based on current hidden state (corresponding to MLP in architecture)
            if self.rnn_type == 'lstm':
                pred_input = h[0]  # For LSTM, use h not c
            else:
                pred_input = h

            # Generate prediction through MLP
            pred = self.predictor(pred_input)  # [batch_size * num_nodes, 1]
            pred = pred.reshape(batch_size, self.num_nodes)  # [batch_size, num_nodes]

            # Store current time step prediction
            predictions.append(pred)

        # Stack all time step predictions [batch_size, num_nodes, seq_length]
        predictions = torch.stack(predictions, dim=2)

        return predictions

    def predict_future_steps(self, x_history, edge_index, edge_weight=None, future_steps=3):
        """
        Test mode: autoregressive prediction of multiple future steps using historical data

        x_history: historical data [batch_size, num_nodes, history_length, node_features]
        edge_index: graph edge index [2, num_edges]
        edge_weight: edge weights [num_edges]
        future_steps: number of future steps to predict

        returns: predictions for future time steps [batch_size, num_nodes, future_steps]
        """
        batch_size = x_history.size(0)
        device = x_history.device

        # First establish hidden state using historical data
        h = self._init_hidden_state(batch_size, device)

        # Process all historical time steps
        history_length = x_history.size(2)
        for t in range(history_length):
            x_t = x_history[:, :, t, :]
            x_t_flat = x_t.reshape(batch_size * self.num_nodes, -1)
            gnn_out = self.gnn(x_t_flat, edge_index, edge_weight)
            h = self.rnn_block(gnn_out, h)

        # Start predicting future steps from the last historical point
        future_predictions = []
        current_input = x_history[:, :, -1, :]  # Last historical point

        # Autoregressive prediction of future steps
        for _ in range(future_steps):
            # 1. Prepare current input
            current_input_flat = current_input.reshape(batch_size * self.num_nodes, -1)

            # 2. Process through GNN
            gnn_out = self.gnn(current_input_flat, edge_index, edge_weight)

            # 3. Update hidden state
            h = self.rnn_block(gnn_out, h)

            # 4. Generate prediction
            if self.rnn_type == 'lstm':
                pred_input = h[0]
            else:
                pred_input = h

            pred = self.predictor(pred_input)  # [batch_size * num_nodes, 1]
            pred = pred.reshape(batch_size, self.num_nodes)  # [batch_size, num_nodes]

            # 5. Save prediction
            future_predictions.append(pred)

            # 6. Use current prediction as input for next step
            current_input = pred.unsqueeze(-1)  # [batch_size, num_nodes, 1]

        # Stack all future predictions [batch_size, num_nodes, future_steps]
        future_predictions = torch.stack(future_predictions, dim=2)

        return future_predictions


In [7]:

def mse_loss(preds, labels):
    """Mean squared error loss"""
    return torch.mean((preds - labels) ** 2)

def mae_loss(preds, labels):
    """Mean absolute error loss"""
    return torch.mean(torch.abs(preds - labels))

def rmse_loss(preds, labels):
    """Root mean squared error loss"""
    return torch.sqrt(mse_loss(preds, labels))

def metric(preds, labels):
    """Calculate multiple evaluation metrics"""
    mse_val = mse_loss(preds, labels).item()
    mae_val = mae_loss(preds, labels).item()
    rmse_val = rmse_loss(preds, labels).item()
    return mse_val, mae_val, rmse_val


def train(model, train_loader, optimizer, criterion, device, edge_index, edge_weight, epoch):
    """Train model for one epoch"""
    model.train()
    total_loss = 0

    with tqdm(train_loader, desc=f"Epoch {epoch}") as t:
        for i, (x_seq, y_target) in enumerate(t):
            optimizer.zero_grad()

            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Forward pass
            y_pred = model(x_seq, edge_index, edge_weight)  # [batch_size, num_nodes, seq_length]

            # Calculate loss
            loss = criterion(y_pred, y_target)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Update progress bar
            total_loss += loss.item()
            t.set_postfix(loss=total_loss / (i + 1))

    return total_loss / len(train_loader)


def evaluate(model, data_loader, device, edge_index, edge_weight):
    """Evaluate model on validation or test set"""
    model.eval()
    mse_sum, mae_sum, rmse_sum = 0, 0, 0

    with torch.no_grad():
        for x_seq, y_target in data_loader:
            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Forward pass
            y_pred = model(x_seq, edge_index, edge_weight)  # [batch_size, num_nodes, seq_length]

            # Calculate evaluation metrics
            mse_val, mae_val, rmse_val = metric(y_pred, y_target)
            mse_sum += mse_val
            mae_sum += mae_val
            rmse_sum += rmse_val

    num_batches = len(data_loader)
    return (
        mse_sum / num_batches,
        mae_sum / num_batches,
        rmse_sum / num_batches
    )


def evaluate_future_steps(model, data_loader, device, edge_index, edge_weight, future_steps=3):
    """Evaluate multi-step future prediction"""
    model.eval()
    mse_sum, mae_sum, rmse_sum = 0, 0, 0

    with torch.no_grad():
        for x_seq, y_target in data_loader:
            # Prepare input data
            x_seq = x_seq.to(device)        # [batch_size, num_nodes, seq_length, 1]
            y_target = y_target.to(device)  # [batch_size, num_nodes, seq_length]
            edge_index = edge_index.to(device)
            if edge_weight is not None:
                edge_weight = edge_weight.to(device)

            # Use x_seq as historical data, predict future_steps steps ahead
            y_pred = model.predict_future_steps(
                x_seq, edge_index, edge_weight, future_steps
            )  # [batch_size, num_nodes, future_steps]

            # Get the true future values for the next future_steps from complete target
            true_future = y_target[:, :, :future_steps]

            # Calculate evaluation metrics
            mse_val, mae_val, rmse_val = metric(y_pred, true_future)
            mse_sum += mse_val
            mae_sum += mae_val
            rmse_sum += rmse_val

    num_batches = len(data_loader)
    return (
        mse_sum / num_batches,
        mae_sum / num_batches,
        rmse_sum / num_batches
    )



In [8]:

def train_stgnn(
    data_path='.',
    batch_size=64,
    epochs=100,
    lr=0.001,
    seq_length=12,
    future_steps=3,  # Predict future steps during testing
    hidden_dim=64,
    gnn_type='gcn',  # 'gcn' or 'gat'
    rnn_type='gru',  # 'rnn', 'gru', or 'lstm'
    dropout=0.1,
    use_cuda=True,
    seed=42,
    save_model=True,
    save_path='models'
):
    """Complete training pipeline for STGNN"""
    # Set random seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)

    print(f"Training device: {device}")
    print(f"Model configuration: GNN={gnn_type}, RNN={rnn_type}, hidden_dim={hidden_dim}")

    # Load dataset
    data_loader = DataLoader_METR_LA(
        data_path, batch_size, seq_length=seq_length
    )
    train_loader, val_loader, test_loader, edge_index, edge_weight, mean, std = data_loader.get_loaders()

    # Initialize model
    model = STGNN(
        num_nodes=207,        # METR-LA has 207 sensors
        node_features=1,      # Input feature dimension
        hidden_dim=hidden_dim,
        seq_length=seq_length,
        gnn_type=gnn_type,
        rnn_type=rnn_type,
        dropout=dropout
    ).to(device)

    # Set optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # Use MAE as the main loss function
    criterion = mae_loss

    # Create model save directory
    if save_model and not os.path.exists(save_path):
        os.makedirs(save_path)

    # Training loop variables
    best_val_mae = float('inf')
    best_model_state = None
    train_losses = []
    val_losses = []
    val_metrics = []

    print("Starting training...")
    for epoch in range(1, epochs + 1):
        # Train for one epoch
        train_loss = train(
            model, train_loader, optimizer, criterion,
            device, edge_index, edge_weight, epoch
        )
        train_losses.append(train_loss)

        # Evaluate on validation set
        val_mse, val_mae, val_rmse = evaluate(
            model, val_loader, device, edge_index, edge_weight
        )
        val_losses.append(val_mae)
        val_metrics.append((val_mse, val_mae, val_rmse))

        # Adjust learning rate based on validation MAE
        scheduler.step(val_mae)

        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val MSE: {val_mse:.4f} | MAE: {val_mae:.4f} | RMSE: {val_rmse:.4f}")

        # Evaluate multi-step future prediction every 5 epochs
        if epoch % 5 == 0:
            future_mse, future_mae, future_rmse = evaluate_future_steps(
                model, val_loader, device, edge_index, edge_weight, future_steps
            )
            print(f"Future {future_steps} steps prediction: MSE: {future_mse:.4f} | MAE: {future_mae:.4f} | RMSE: {future_rmse:.4f}")

        # Save the best model based on validation MAE
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            if save_model:
                best_model_state = model.state_dict().copy()
                torch.save(
                    best_model_state,
                    os.path.join(save_path, f'best_stgnn_{gnn_type}_{rnn_type}.pth')
                )
                print(f"New best model saved! (Val MAE: {val_mae:.4f})")

    # Plot training progress
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss (MAE)')
    plt.plot(val_losses, label='Val MAE')
    plt.xlabel('Epoch')
    plt.ylabel('Loss / MAE')
    plt.title(f'Training Progress ({gnn_type.upper()}-{rnn_type.upper()})')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot([m[0] for m in val_metrics], label='Val MSE')
    plt.plot([m[1] for m in val_metrics], label='Val MAE')
    plt.plot([m[2] for m in val_metrics], label='Val RMSE')
    plt.xlabel('Epoch')
    plt.ylabel('Metric Value')
    plt.title('Validation Metrics')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    if save_model:
        plt.savefig(os.path.join(save_path, f'stgnn_{gnn_type}_{rnn_type}_training.png'))
    plt.show()

    # Load the best saved model for testing
    if save_model and best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Loaded best model for testing")

    # Evaluate one-step prediction on test set
    print("\nEvaluating next-step prediction on test set...")
    test_mse, test_mae, test_rmse = evaluate(
        model, test_loader, device, edge_index, edge_weight
    )
    print(f"[STGNN {gnn_type.upper()}-{rnn_type.upper()}] Test Results (Next Step) → MSE: {test_mse:.4f}, MAE: {test_mae:.4f}, RMSE: {test_rmse:.4f}")

    # Evaluate multi-step future prediction on test set
    print(f"\nEvaluating future {future_steps} steps prediction on test set...")
    future_mse, future_mae, future_rmse = evaluate_future_steps(
        model, test_loader, device, edge_index, edge_weight, future_steps
    )
    print(f"[STGNN {gnn_type.upper()}-{rnn_type.upper()}] Test Results (Future {future_steps} Steps) → MSE: {future_mse:.4f}, MAE: {future_mae:.4f}, RMSE: {future_rmse:.4f}")

    return model, (test_mse, test_mae, test_rmse), (future_mse, future_mae, future_rmse)




In [9]:
# Train model and get necessary variables for visualization
model, (test_mse, test_mae, test_rmse), (future_mse, future_mae, future_rmse) = train_stgnn(
    data_path='./',
    epochs=10,
    batch_size=64,
    hidden_dim=64,
    seq_length=12,
    future_steps=1,
    gnn_type='gcn',
    rnn_type='gru',
    dropout=0.1
)

Training device: cuda
Model configuration: GNN=gcn, RNN=gru, hidden_dim=64
[Data] Loaded speed matrix: (34272, 207)
[Graph] Loading adjacency matrix from: adj_mx.pkl
[Graph] Adjacency shape: (207, 207)
[Split] Train: (23990, 207), Val: (3428, 207), Test: (6854, 207)
Starting training...


Epoch 1:  36%|███▌      | 134/374 [00:11<00:21, 11.31it/s, loss=0.246]


KeyboardInterrupt: 