<center>
    <h1>Verbal Explanation of Spatial Temporal GNNs for Traffic Forecasting</h1>
    <h2>[TEST] Error of the STGNN by event kind on Metr-LA</h2>
</center>

---

In this notebook the prediction error of the *STGNN* on the *Metr-LA* dataset is measured separately for each kind of traffic event: severe congestion, congestion and free flow.

In [1]:
import sys
import os

# Make the parent of the working directory importable, so that the
# project's `src` package can be resolved by the cells below.
sys.path.append(os.pardir)

In [2]:
# Enable IPython's autoreload extension in mode 2, so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from src.utils.seed import set_random_seed

# Fix the random seed so that stochastic operations are repeatable across
# runs. Which libraries get seeded is decided by `set_random_seed`.
# NOTE(review): presumably it seeds torch and numpy as well — confirm.
SEED = 42
set_random_seed(SEED)

In [4]:
import torch

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# The same device is used both to build and to query the model.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# 1 Loading the Data
In this section the scaler, the trained STGNN, the sensor locations and the train, validation and test splits are loaded.

In [5]:
import os

# Folder holding the raw and processed Metr-LA files, expressed relative
# to the working directory.
BASE_DATA_DIR = os.path.join(os.pardir, 'data', 'metr-la')

In [6]:
import pickle
# Load the pickled scaler; it is used further down to scale the model
# inputs and to un-scale the model predictions.
# SECURITY: `pickle.load` can execute arbitrary code, so this file must come
# from a trusted source (presumably the project's own preprocessing step).
with open(os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl'), 'rb') as f:
    scaler = pickle.load(f)

In [7]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.data.data_extraction import get_adjacency_matrix

# Load the adjacency matrix structure of the Metr-LA sensor graph.
adj_matrix_structure = get_adjacency_matrix(
    os.path.join(BASE_DATA_DIR, 'raw', 'adj_mx_metr_la.pkl'))

# Unpack the header of the adjacency matrix, the node indices dictionary
# and the matrix itself.
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Build the STGNN.
# NOTE(review): the positional hyper-parameters (9, 1, 12, 12, ..., 64) are
# defined by the `SpatialTemporalGNN` constructor and presumably have to
# match the ones used when the checkpoint was trained — confirm.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

stgnn_checkpoints_path = os.path.join('..', 'models', 'checkpoints',
                                      'st_gnn_metr_la.pth')

# Remap the stored tensors onto the selected device: without `map_location`
# a checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# SECURITY: `torch.load` relies on unpickling, load trusted checkpoints only.
stgnn_checkpoints = torch.load(stgnn_checkpoints_path, map_location=DEVICE)
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Set the model in evaluation mode (the trailing `;` hides the cell output).
spatial_temporal_gnn.eval();

In [8]:
from src.data.data_extraction import get_locations_dataframe

# Read the latitude and longitude of every sensor into a dataframe.
locations_csv_path = os.path.join(
    BASE_DATA_DIR, 'raw', 'graph_sensor_locations_metr_la.csv')
locations_df = get_locations_dataframe(locations_csv_path, has_header=True)

In [9]:
# Invert the node ids dictionary: every value of `node_ids_dict` becomes a
# key that maps back to its original key.
node_pos_dict = {
    position: node_id for node_id, position in node_ids_dict.items()
}

In [10]:
import os
import numpy as np
from src.spatial_temporal_gnn.prediction import predict

# Load the input (x) and target (y) arrays of the train, validation and
# test splits from the folder of processed data.
processed_dir = os.path.join(BASE_DATA_DIR, 'processed')

x_train = np.load(os.path.join(processed_dir, 'x_train.npy'))
y_train = np.load(os.path.join(processed_dir, 'y_train.npy'))

x_val = np.load(os.path.join(processed_dir, 'x_val.npy'))
y_val = np.load(os.path.join(processed_dir, 'y_val.npy'))

x_test = np.load(os.path.join(processed_dir, 'x_test.npy'))
y_test = np.load(os.path.join(processed_dir, 'y_test.npy'))

In [11]:
from src.data.dataloaders import get_dataloader

# Wrap every split in a dataloader built with `batch_size=64`. Only the
# training split is shuffled; validation and test are built with
# `shuffle=False`.
train_dataloader = get_dataloader(
    x_train, y_train, batch_size=64, shuffle=True)
val_dataloader = get_dataloader(
    x_val, y_val, batch_size=64, shuffle=False)
test_dataloader = get_dataloader(
    x_test, y_test, batch_size=64, shuffle=False)

# 2 Predictions
In this section the predictions of the STGNN are evaluated separately for each event kind.

In [16]:
from src.utils.config import CONGESTION_THRESHOLD_MPH, SEVERE_CONGESTION_THRESHOLD_MPH
from typing import Tuple, Optional
from src.data.data_processing import Scaler
from src.spatial_temporal_gnn.metrics import MAE, RMSE, MAPE
from torch.utils.data import DataLoader


def _mean_or_nan(total: float, count: int) -> float:
    """Return `total / count`, or NaN when `count` is zero."""
    return total / count if count > 0 else float('nan')


def validate(
    model: SpatialTemporalGNN, val_dataloader: DataLoader, scaler: Scaler,
    device: str
    ) -> Tuple[Tuple[float, float, float], Tuple[float, float, float],
               Tuple[float, float, float]]:
    """
    Calculate the MAE, RMSE and MAPE scores of the Spatial-Temporal GNN
    on the given dataloader, separately for each event kind.

    The targets of each batch are split in three event kinds by means of
    `SEVERE_CONGESTION_THRESHOLD_MPH` and `CONGESTION_THRESHOLD_MPH`:
    severe congestions (`y <= SEVERE_CONGESTION_THRESHOLD_MPH`),
    congestions (`SEVERE_CONGESTION_THRESHOLD_MPH < y
    <= CONGESTION_THRESHOLD_MPH`) and free flows
    (`y > CONGESTION_THRESHOLD_MPH`). For each kind, the targets that
    belong to the other kinds are set to zero before the criterions are
    applied.
    NOTE(review): this presumably relies on the criterions ignoring the
    zero-valued targets — confirm against `MAE`, `RMSE` and `MAPE`.

    Each score is the mean of the per-batch scores over the batches that
    contain at least one strictly positive target of the event kind.

    Arguments
    ---------
    model : SpatialTemporalGNN
        The spatial temporal graph neural network to be evaluated.
    val_dataloader : DataLoader
        The data loader of the set to evaluate the model on.
    scaler : Scaler
        The scaler used to scale the input data and to un-scale the
        predictions.
    device : str
        The device to run the model on.

    Returns
    -------
    tuple of float
        The average MAE scores on, in order, severe congestions,
        congestions and free flows.
    tuple of float
        The average RMSE scores, in the same order.
    tuple of float
        The average MAPE scores, in the same order.
        Whenever an event kind never occurs in the dataloader, its
        scores are NaN.
    """
    torch.cuda.empty_cache()

    # Initialize the validation criterions, in the order of the results.
    criterions = (MAE(), RMSE(), MAPE())
    event_kinds = ('severe_congestions', 'congestions', 'free_flows')

    # Initialize, for each event kind, the running error of each criterion
    # and the number of batches in which the kind occurs.
    running_errors = {kind: [0.] * len(criterions) for kind in event_kinds}
    batch_counts = {kind: 0 for kind in event_kinds}

    with torch.no_grad():
        for x, y in val_dataloader:
            # Get the data.
            x = x.type(torch.float32).to(device=device)
            y = y.type(torch.float32).to(device=device)

            # For each event kind, the mask of the targets that do NOT
            # belong to it and have therefore to be zeroed.
            excluded_masks = {
                'severe_congestions': y > SEVERE_CONGESTION_THRESHOLD_MPH,
                'congestions': (
                    (y <= SEVERE_CONGESTION_THRESHOLD_MPH)
                    | (y > CONGESTION_THRESHOLD_MPH)),
                'free_flows': y <= CONGESTION_THRESHOLD_MPH,
            }
            # `masked_fill` returns a new tensor, `y` is left untouched.
            y_by_kind = {
                kind: y.masked_fill(mask, 0.)
                for kind, mask in excluded_masks.items()
            }

            # Scale the input data.
            x = scaler.scale(x)

            # Compute the output.
            y_pred = model(x)

            # Un-scale the predictions.
            y_pred = scaler.un_scale(y_pred)

            # Get the prediction errors and update the running errors.
            for kind in event_kinds:
                y_kind = y_by_kind[kind]
                # Skip the batches where the event kind never occurs.
                if not (y_kind > 0).any():
                    continue
                batch_counts[kind] += 1
                for index, criterion in enumerate(criterions):
                    error = criterion(y_pred, y_kind)
                    running_errors[kind][index] += error.item()

    # Remove unused tensors from gpu memory.
    torch.cuda.empty_cache()

    # One tuple per criterion, each holding one average per event kind.
    # A kind that never occurred yields NaN instead of a division by zero.
    return tuple(
        tuple(
            _mean_or_nan(running_errors[kind][index], batch_counts[kind])
            for kind in event_kinds)
        for index in range(len(criterions)))
                  

In [17]:
# Evaluate the STGNN on the test set; each result holds one score per event
# kind, in the order: severe congestions, congestions, free-flows.
mae, rmse, mape = validate(spatial_temporal_gnn, test_dataloader, scaler, DEVICE)

result_headings = (
    'Results on the test set on severe congestions:',
    'Results on the test set on congestions:',
    'Results on the test set on free-flows:',
)

# Report the three scores of each event kind on a line of its own.
for heading, kind_mae, kind_rmse, kind_mape in zip(
        result_headings, mae, rmse, mape):
    print(heading,
          f'MAE {kind_mae:.3g} - RMSE {kind_rmse:.3g}',
          f'- MAPE {kind_mape * 100.:.3g} %')


Results on the test set on severe congestions: MAE 12.2 - RMSE 17.9 - MAPE 65.1 %
Results on the test set on congestions: MAE 5.21 - RMSE 7.73 - MAPE 10.6 %
Results on the test set on free-flows: MAE 1.96 - RMSE 4.16 - MAPE 3.02 %
