# DATASCI 315, Group Work 12: Los Angeles Traffic Prediction with Graph Neural Networks

In this assignment, we'll use a graph neural network to perform traffic speed forecasting. Specifically, our goal is to predict the future speeds of vehicles on different road segments based on past traffic data collected from sensors. The model takes historical traffic speed observations from multiple locations over a period of time and infers the expected traffic speeds at those locations for upcoming time intervals (e.g., 15, 30, or 45 minutes ahead). This is a spatiotemporal prediction problem, as it requires understanding both how traffic changes over time and how it propagates through the road network. The predicted speeds can be used to improve traffic management, route planning, and overall efficiency in intelligent transportation systems.

**Dataset:** The LA traffic dataset we will analyze for this assignment was collected by the California Department of Transportation from 228 sensor stations across District 7 (Los Angeles area). Traffic speed data was recorded every 5 minutes on weekdays between May 1 and June 30, 2012. Each station produced 288 readings per day. The dataset captures both spatial and temporal aspects of traffic by representing the network as a graph where nodes are road sensors and edges reflect proximity.

**Method:** In this assignment you will use a graph neural network (GNN) called Spatial-Temporal Graph Attention Network (ST-GAT) to model traffic speed prediction. This GNN combines a Graph Attention Network (GAT) with a Long Short-Term Memory (LSTM) network to capture both spatial and temporal dependencies in the traffic data. The GAT component uses a multi-head attention mechanism to dynamically learn the importance of each road segment's connections, enabling the model to identify which nearby segments most influence traffic flow. To prepare the data for this structure, a "Speed2Vec" method is introduced to convert time-series speed readings into node features. These features are then passed through the GAT to capture spatial relationships and through the LSTM to model how traffic evolves over time. This hybrid architecture allows the model to make accurate traffic forecasts by leveraging both dynamic spatial interactions and sequential temporal patterns.

**Reference:** This dataset and method are introduced in "Spatial-Temporal Graph Attention Networks: A Deep Learning Approach for Traffic Forecasting" by Zhang, Yu, and Liu (2019).

**Instructions:** For this task, it would be helpful to select GPU as the runtime.

We will go through the following steps:

1. Installation and Setup
2. Creating a DataLoader
3. Constructing the Model
4. Developing Training and Evaluation Functions
5. Training the Model
6. Testing the Model

## Installation and setup

Confirm that you are on a GPU.
The cell below should print `Using cuda`.


In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

The traffic dataset is stored in the `data/` directory.

In [None]:
from pathlib import Path

DATA_DIR = Path("data")

## Creating a Dataloader
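Now, we define a dataset class that processes the raw `.csv` files into a PyTorch Geometric dataset, along with helpers for building the adjacency weights and splitting the data. We will wrap the resulting splits in dataloaders before training.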
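
To make the graph construction concrete, the sketch below applies the same thresholded Gaussian kernel that `distance_to_weight` uses to a made-up 3-sensor distance matrix. The distances and the helper name `toy_weights` are invented for illustration; only the formula and the default values of `sigma2` and `epsilon` mirror the function in the next cell.

```python
import torch

# Hypothetical pairwise distances (in meters) between three sensors.
toy_distances = torch.tensor(
    [
        [0.0, 1000.0, 5000.0],
        [1000.0, 0.0, 2500.0],
        [5000.0, 2500.0, 0.0],
    ]
)


def toy_weights(distances, sigma2=0.1, epsilon=0.5):
    """Keep exp(-d^2 / sigma2) only where it is >= epsilon, and drop self loops."""
    scaled = distances / 10000.0
    kernel = torch.exp(-(scaled**2) / sigma2)
    off_diagonal = 1.0 - torch.eye(distances.shape[0])
    return kernel * (kernel >= epsilon) * off_diagonal


weights = toy_weights(toy_distances)
print(weights)

# 0/1 adjacency with self loops, as in the gat_version=True branch.
binary = (weights > 0).float() + torch.eye(3)
print(binary)
```

With these toy values, the sensors 1 km and 2.5 km apart stay connected, while the pair 5 km apart falls below `epsilon` and gets no edge.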

In [None]:
from pathlib import Path
from shutil import copyfile

import pandas as pd
import torch
from torch_geometric.data import Data, InMemoryDataset


def distance_to_weight(distance_matrix, sigma2=0.1, epsilon=0.5, *, gat_version=False):
    """
    Given distances between all nodes, convert into a weight matrix.

    :param distance_matrix: Matrix of distances between nodes
    :param sigma2: User configurable parameter to adjust sparsity of matrix
    :param epsilon: User configurable parameter to adjust sparsity of matrix
    :param gat_version: If true, use 0/1 weights with self loops. Otherwise, use float
    :return: Adjacency weight matrix
    """
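    # Accept NumPy arrays (e.g. from pandas) as well as tensors
    distance_matrix = torch.as_tensor(distance_matrix, dtype=torch.float)
    num_nodes = distance_matrix.shape[0]
    normalized_distances = distance_matrix / 10000.0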
    squared_distances = normalized_distances * normalized_distances
    non_self_mask = torch.ones([num_nodes, num_nodes]) - torch.eye(num_nodes)

    # Compute weights using Gaussian kernel (refer to Eq.10 in paper)
    adjacency_weights = (
        torch.exp(-squared_distances / sigma2)
        * (torch.exp(-squared_distances / sigma2) >= epsilon)
        * non_self_mask
    )

    # If using the GAT version, round to 0/1 and include self loops
    if gat_version:
        adjacency_weights[adjacency_weights > 0] = 1
        adjacency_weights += torch.eye(num_nodes)

    return adjacency_weights


class TrafficDataset(InMemoryDataset):
    """Dataset for Graph Neural Networks."""

    def __init__(
        self,
        config,
        adjacency_weights,
        root="",
        transform=None,
        pre_transform=None,
    ):
        self.config = config
        self.adjacency_weights = adjacency_weights
        super().__init__(root, transform, pre_transform)
        self.data, self.slices, self.n_node, self.mean, self.std_dev = torch.load(
            self.processed_paths[0], weights_only=False
        )

    @property
    def raw_file_names(self):
        return [str(Path(self.raw_dir) / "PeMSD7_V_228.csv")]

    @property
    def processed_file_names(self):
        return ["./data.pt"]

    def download(self):
        copyfile(
            DATA_DIR / "PeMSD7_V_228.csv",
            str(Path(self.raw_dir) / "PeMSD7_V_228.csv"),
        )

    def process(self):
        """Process the raw datasets into saved .pt dataset for later use.

        Note that any self.fields set here won't exist if loading from the .pt file.
        """
        # Data preprocessing and loading; convert to a tensor so torch ops apply
        data = torch.tensor(
            pd.read_csv(self.raw_file_names[0], header=None).values, dtype=torch.float
        )
        # The statistics below also cover the validation and test portions.
        # Normally the mean and std_dev would come from the training data only.
        mean = torch.mean(data)
        std_dev = torch.std(data)
        data = z_score(data, mean, std_dev)

        _, n_node = data.shape
        n_window = self.config["N_PRED"] + self.config["N_HIST"]

        # Manipulate nxn matrix into 2 x num_edges format
        edge_index = torch.zeros((2, n_node**2), dtype=torch.long)
        # Create an edge_attr matrix with our weights (num_edges x 1)
        edge_attr = torch.zeros((n_node**2, 1))
        num_edges = 0
        for i in range(n_node):
            for j in range(n_node):
                if self.adjacency_weights[i, j] != 0.0:
                    edge_index[0, num_edges] = i
                    edge_index[1, num_edges] = j
                    edge_attr[num_edges] = self.adjacency_weights[i, j]
                    num_edges += 1
        # Keep only the first num_edges entries
        edge_index = edge_index[:, :num_edges]
        edge_attr = edge_attr[:num_edges, :]

        sequences = []
        # T x F x N
        for day_idx in range(self.config["N_DAYS"]):
            for slot_idx in range(self.config["N_SLOT"]):
                # For each time point construct a different graph with Data object
                graph = Data()
                graph.__num_nodes__ = n_node

                graph.edge_index = edge_index
                graph.edge_attr = edge_attr

                # (F,N) switched to (N,F)
                start_idx = day_idx * self.config["N_DAY_SLOT"] + slot_idx
                end_idx = start_idx + n_window
                # [21, 228]
                full_window = torch.swapaxes(data[start_idx:end_idx, :], 0, 1)
                graph.x = torch.FloatTensor(full_window[:, 0 : self.config["N_HIST"]])
                graph.y = torch.FloatTensor(full_window[:, self.config["N_HIST"] : :])
                sequences += [graph]

        # Make the actual dataset
        data, slices = self.collate(sequences)
        torch.save((data, slices, n_node, mean, std_dev), self.processed_paths[0])


def get_splits(dataset: TrafficDataset, n_slot, splits):
    """Split data into train, val, and test subsets.

    :param dataset: TrafficDataset object to split
    :param n_slot: Number of possible sliding windows in a day
    :param splits: (train, val, test) ratios
    """
    split_train, split_val, _ = splits
    i = n_slot * split_train
    j = n_slot * split_val
    train = dataset[:i]
    val = dataset[i : i + j]
    test = dataset[i + j :]

    return train, val, test

## Build the Model

Using PyG's built-in layers, create a Spatio-Temporal Graph Attention Network (ST-GAT) as presented in https://ieeexplore.ieee.org/document/8903252.

This model is a PyTorch model containing an initialization function for setting up the model architecture and a forward function for performing a forward pass of data through the model.
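
The tensor reshaping between the GAT and LSTM stages is the easiest part of the forward pass to get wrong, so here is a minimal shape walk-through. It is a sketch only: a random tensor stands in for the GAT output, and the sizes (batch of 50, 228 nodes, 12 history steps, 9 predicted steps) are assumed from the default configuration used later in this notebook.

```python
import torch

# Sizes assumed from the default configuration of this assignment.
batch_size, n_nodes, n_hist, n_pred = 50, 228, 12, 9

# Stand-in for the GAT output: one row per node across the whole batch.
gat_out = torch.randn(batch_size * n_nodes, n_hist)  # [11400, 12]

x = gat_out.reshape(batch_size, n_nodes, n_hist)  # [50, 228, 12]
x = torch.movedim(x, 2, 0)  # [12, 50, 228] = (seq_len, batch, features)

lstm1 = torch.nn.LSTM(input_size=n_nodes, hidden_size=32)
lstm2 = torch.nn.LSTM(input_size=32, hidden_size=128)
linear = torch.nn.Linear(128, n_nodes * n_pred)

x, _ = lstm1(x)  # [12, 50, 32]
x, _ = lstm2(x)  # [12, 50, 128]
last_step = x[-1]  # [50, 128], hidden state after the final timestep

out = linear(last_step).reshape(batch_size * n_nodes, n_pred)  # [11400, 9]
print(out.shape)
```

Note that the LSTM treats the 228 node values at each timestep as one feature vector, so the time axis has to be moved to the front before the recurrent layers see the data.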

In [None]:
import torch
from torch.nn import functional
from torch_geometric.nn import GATConv


class StGat(torch.nn.Module):
    """Spatio-Temporal Graph Attention Network.

    As presented in https://ieeexplore.ieee.org/document/8903252
    """

    def __init__(self, in_channels, out_channels, n_nodes, heads=8, dropout=0.0):
        """Initialize the ST-GAT model.

        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param n_nodes: Number of nodes in the graph
        :param heads: Number of attention heads to use in graph
        :param dropout: Dropout probability on output of Graph Attention Network
        """
        super().__init__()
        self.n_pred = out_channels
        self.heads = heads
        self.dropout = dropout
        self.n_nodes = n_nodes

        lstm1_hidden_size = 32
        lstm2_hidden_size = 128

        # Single graph attentional layer with 8 attention heads
        self.gat = GATConv(
            in_channels=in_channels,
            out_channels=in_channels,
            heads=heads,
            dropout=0,
            concat=False,
        )

        # Add two LSTM layers
        self.lstm1 = torch.nn.LSTM(
            input_size=self.n_nodes,
            hidden_size=lstm1_hidden_size,
            num_layers=1,
        )
        for name, param in self.lstm1.named_parameters():
            if "bias" in name:
                torch.nn.init.constant_(param, 0.0)
            elif "weight" in name:
                torch.nn.init.xavier_uniform_(param)
        self.lstm2 = torch.nn.LSTM(
            input_size=lstm1_hidden_size,
            hidden_size=lstm2_hidden_size,
            num_layers=1,
        )
        for name, param in self.lstm2.named_parameters():
            if "bias" in name:
                torch.nn.init.constant_(param, 0.0)
            elif "weight" in name:
                torch.nn.init.xavier_uniform_(param)

        # Fully-connected neural network
        self.linear = torch.nn.Linear(lstm2_hidden_size, self.n_nodes * self.n_pred)
        torch.nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, data, device):
        """Forward pass of the ST-GAT model.

        :param data: Data to make a pass on
        :param device: Device to operate on
        """
        x, edge_index = data.x, data.edge_index
        # Make sure the node features are floating point and on the target device
        x = x.float().to(device)

        # GAT layer: output of gat: [11400, 12]
        x = self.gat(x, edge_index)
        x = functional.dropout(x, self.dropout, training=self.training)

        # RNN: 2 LSTM
        # [batchsize*n_nodes, seq_length] -> [batch_size, n_nodes, seq_length]
        batch_size = data.num_graphs
        n_node = int(data.num_nodes / batch_size)
        x = torch.reshape(x, (batch_size, n_node, data.num_features))
        # For lstm: x should be (seq_length, batch_size, n_nodes)
        # sequence length = 12, batch_size = 50, n_node = 228
        x = torch.movedim(x, 2, 0)
        # [12, 50, 228] -> [12, 50, 32]
        x, _ = self.lstm1(x)
        # [12, 50, 32] -> [12, 50, 128]
        x, _ = self.lstm2(x)

        # Output contains h_t for each timestep, only the last one has all inputs
        # [12, 50, 128] -> [50, 128]; index directly so a batch of size 1 keeps its batch dim
        x = x[-1, :, :]
        # [50, 128] -> [50, 228*9]
        x = self.linear(x)

        # Now reshape into final output
        s = x.shape
        # [50, 228*9] -> [50, 228, 9]
        x = torch.reshape(x, (s[0], self.n_nodes, self.n_pred))
        # [50, 228, 9] ->  [11400, 9]
        return torch.reshape(x, (s[0] * self.n_nodes, self.n_pred))

## Create Train and Evaluation Functions

Create a train function which performs a forward and a backward pass using the model.

Create an evaluation function which performs only a forward pass using the model.

These functions will be used in various stages of overall model training and testing.

In [None]:
@torch.no_grad()
def eval_model(model, device, dataloader, eval_type=""):
    """Evaluate model on data.

    :param model: Model to evaluate
    :param device: Device to evaluate on
    :param dataloader: Data loader
    :param eval_type: Name of evaluation type, e.g. Train/Val/Test
    """
    model.eval()
    model.to(device)

    mae = 0
    rmse = 0
    mape = 0
    n = 0

    # Evaluate model on all data
    for i, raw_batch in enumerate(dataloader):
        current_batch = raw_batch.to(device)
        if current_batch.x.shape[0] == 1:
            pass
        else:
            pred = model(current_batch, device)
            truth = current_batch.y.view(pred.shape)
            if i == 0:
                y_pred = torch.zeros(len(dataloader), pred.shape[0], pred.shape[1])
                y_truth = torch.zeros(len(dataloader), pred.shape[0], pred.shape[1])
            truth = un_z_score(truth, dataloader.dataset.mean, dataloader.dataset.std_dev)
            pred = un_z_score(pred, dataloader.dataset.mean, dataloader.dataset.std_dev)
            y_pred[i, : pred.shape[0], :] = pred
            y_truth[i, : pred.shape[0], :] = truth
            rmse += calc_rmse(truth, pred)
            mae += calc_mae(truth, pred)
            mape += calc_mape(truth, pred)
            n += 1
    rmse, mae, mape = rmse / n, mae / n, mape / n

    print(f"{eval_type}, MAE: {mae}, RMSE: {rmse}, MAPE: {mape}")

    # Get the average score for each metric in each batch
    return rmse, mae, mape, y_pred, y_truth


def train(model, device, dataloader, optimizer, loss_fn, epoch):
    """Train model on data.

    :param model: Model to train
    :param device: Device to train on
    :param dataloader: Data loader
    :param optimizer: Optimizer to use
    :param loss_fn: Loss function
    :param epoch: Current epoch
    """
    model.train()
    for _, raw_batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        current_batch = raw_batch.to(device)
        optimizer.zero_grad()
        y_pred = torch.squeeze(model(current_batch, device))
        loss = loss_fn()(y_pred.float(), torch.squeeze(current_batch.y).float())
        writer.add_scalar("Loss/train", loss, epoch)
        loss.backward()
        optimizer.step()

    return loss

In order to evaluate the performance of the model, we need to define some evaluation metrics.

* The Z-score normalizes data using mean and standard deviation.
* MAPE is mean absolute percentage error.
* RMSE is root mean squared error.
* MAE is mean absolute error.
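
As a quick sanity check on these definitions, the sketch below evaluates each metric on three made-up speed readings and verifies that z-score normalization can be undone. The numbers are illustrative only and are not taken from the dataset.

```python
import torch

# Hypothetical ground-truth and predicted speeds for three readings.
truth = torch.tensor([50.0, 60.0, 40.0])
pred = torch.tensor([45.0, 66.0, 40.0])

# Absolute errors are 5, 6, and 0.
mae = torch.mean(torch.abs(pred - truth))
rmse = torch.sqrt(torch.mean((pred - truth) ** 2))
mape = torch.mean(torch.abs(pred - truth) / truth * 100)

print(f"MAE:  {mae:.3f}")  # 3.667
print(f"RMSE: {rmse:.3f}")  # 4.509
print(f"MAPE: {mape:.3f}")  # 6.667 (percent)

# Z-score round trip: normalize, then undo the normalization.
mean, std = truth.mean(), truth.std()
normed = (truth - mean) / std
restored = normed * std + mean
print(torch.allclose(restored, truth))
```

RMSE is larger than MAE here because squaring puts more weight on the largest error.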

In [None]:
def z_score(x, mean, std):
    """Z-score normalization function: z = (X - mu) / sigma.

    :param x: torch array, input array to be normalized.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    :return: torch array, z-score normalized array.
    """
    return (x - mean) / std


def un_z_score(x_normed, mean, std):
    """Undo the Z-score calculation.

    :param x_normed: torch array, input array to be un-normalized.
    :param mean: float, the value of mean.
    :param std: float, the value of standard deviation.
    """
    return x_normed * std + mean


def calc_mape(ground_truth, prediction):
    """Mean absolute percentage error, given as a % (e.g. 99 -> 99%).

    :param ground_truth: torch array, ground truth.
    :param prediction: torch array, prediction.
    :return: torch scalar, MAPE averages on all elements of input.
    """
    return torch.mean(torch.abs(prediction - ground_truth) / (ground_truth + 1e-15) * 100)


def calc_rmse(ground_truth, prediction):
    """Root mean squared error.

    :param ground_truth: torch array, ground truth.
    :param prediction: torch array, prediction.
    :return: torch scalar, RMSE averages on all elements of input.
    """
    return torch.sqrt(torch.mean((prediction - ground_truth) ** 2))


def calc_mae(ground_truth, prediction):
    """Mean absolute error.

    :param ground_truth: torch array, ground truth.
    :param prediction: torch array, prediction.
    :return: torch scalar, MAE averages on all elements of input.
    """
    return torch.mean(torch.abs(prediction - ground_truth))

Now, let's put it all together. We use the `train` and `eval_model` functions along with the model and dataloaders to create a training function (`model_train`) and a testing function (`model_test`).

We also build in TensorBoard support for logging the training metrics over time.


In [None]:
import time
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Make a tensorboard writer
writer = SummaryWriter()


def model_train(train_dataloader, val_dataloader, config, device):
    """Train the ST-GAT model. Evaluate on validation dataset as you go.

    :param train_dataloader: Data loader of training dataset
    :param val_dataloader: Dataloader of val dataset
    :param config: configuration to use
    :param device: Device to evaluate on
    """
    # Make the model. Each datapoint in the graph is 228x12: N x F
    model = StGat(
        in_channels=config["N_HIST"],
        out_channels=config["N_PRED"],
        n_nodes=config["N_NODE"],
        dropout=config["DROPOUT"],
    )
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["INITIAL_LR"],
        weight_decay=config["WEIGHT_DECAY"],
    )
    loss_fn = torch.nn.MSELoss

    model.to(device)

    # For every epoch, train the model on training dataset
    # Evaluate model on validation dataset
    for epoch in range(config["EPOCHS"]):
        loss = train(model, device, train_dataloader, optimizer, loss_fn, epoch)
        print(f"Loss: {loss:.3f}")
        if epoch % 5 == 0:
            # eval_model returns (rmse, mae, mape, y_pred, y_truth)
            train_rmse, train_mae, train_mape, _, _ = eval_model(
                model, device, train_dataloader, "Train"
            )
            val_rmse, val_mae, val_mape, _, _ = eval_model(model, device, val_dataloader, "Valid")
            writer.add_scalar("MAE/train", train_mae, epoch)
            writer.add_scalar("RMSE/train", train_rmse, epoch)
            writer.add_scalar("MAPE/train", train_mape, epoch)
            writer.add_scalar("MAE/val", val_mae, epoch)
            writer.add_scalar("RMSE/val", val_rmse, epoch)
            writer.add_scalar("MAPE/val", val_mape, epoch)

    writer.flush()
    # Save the model
    timestr = time.strftime("%m-%d-%H%M%S")
    checkpoint_path = Path(config["CHECKPOINT_DIR"]) / f"model_{timestr}.pt"
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        checkpoint_path,
    )

    return model


def model_test(model, test_dataloader, device):
    """Test the ST-GAT model.

    :param model: Model to test
    :param test_dataloader: Data loader of test dataset
    :param device: Device to evaluate on
    """
    eval_model(model, device, test_dataloader, "Test")

---

**Problem 1:** Interpreting Training Metrics with TensorBoard

TensorBoard is a graphical tool to monitor your training process. Run the following code cell (which will initially be empty) and proceed to the training step below. After training is done, click the refresh button (in the top right corner) to view the training results.

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./runs

After training and refreshing the TensorBoard dashboard, answer the following questions in the markdown cell below:

1. At which epoch do the validation metrics (MAE, MAPE, and RMSE) plateau?
2. If you observe signs of overfitting, at which epoch does it begin?

Attach a screenshot of the TensorBoard output to support your answer.

**Your answers here:**

> BEGIN SOLUTION

1. The validation metrics (MAE, MAPE, RMSE) plateau around epoch 40-50. After this point, the metrics show minimal improvement.

2. Signs of overfitting begin around epoch 45-50, where training loss continues to decrease but validation metrics start to level off or slightly increase.

(Note: Exact epochs may vary based on random initialization)

> END SOLUTION

In [None]:
# Mark as complete after writing your analysis
# BEGIN SOLUTION
problem_1_completed = True
# END SOLUTION

In [None]:
# Test assertions
assert problem_1_completed, "Please complete the analysis above and set problem_1_completed = True"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Manual review required for written answers
assert problem_1_completed is True, "problem_1_completed must be True"
# END HIDDEN TESTS

## Start Training

Now, create your dataloaders and start training!

In our default configuration, we train for 60 epochs with a batch size of 50. You can view your training progress in the TensorBoard above by clicking the "refresh" button to see new data. Training and validation performance are updated every 5 epochs.
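
To get a feel for the dataset size before training, the arithmetic below works out how many sliding windows the configuration produces. The values are copied from the configuration in the next cell and from the 34/5/5 day split; the lowercase variable names are local to this sketch.

```python
import math

n_day_slot = 24 * 60 // 5  # 5-minute readings per day: 288
n_hist, n_pred = 12, 9  # history and prediction lengths
n_days = 44
batch_size = 50

n_window = n_hist + n_pred  # readings covered by one sample: 21
n_slot = n_day_slot - n_window + 1  # windows that fit in one day: 268
total_windows = n_days * n_slot  # 11792

train_size = 34 * n_slot  # 9112
val_size = 5 * n_slot  # 1340
test_size = 5 * n_slot  # 1340
batches_per_epoch = math.ceil(train_size / batch_size)  # 183

print(n_slot, total_windows, train_size, val_size, test_size, batches_per_epoch)
```

Each window never crosses a day boundary, which is why the count per day is 268 rather than 288.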

In [None]:
from torch_geometric.loader import DataLoader

# Constant config to use throughout
config = {
    "BATCH_SIZE": 50,
    "EPOCHS": 60,
    "WEIGHT_DECAY": 5e-5,
    "INITIAL_LR": 3e-4,
    "CHECKPOINT_DIR": "./runs",
    "N_PRED": 9,
    "N_HIST": 12,
    "DROPOUT": 0.2,
    # Number of possible 5 minute measurements per day
    "N_DAY_SLOT": 288,
    # Number of days worth of data in the dataset
    "N_DAYS": 44,
    # If false, use GCN paper weight matrix, if true, use GAT paper weight matrix
    "USE_GAT_WEIGHTS": True,
    "N_NODE": 228,
}
# Number of possible windows in a day
config["N_SLOT"] = config["N_DAY_SLOT"] - (config["N_PRED"] + config["N_HIST"]) + 1

# Load the distance matrix and convert to adjacency weights
distance_matrix = pd.read_csv(DATA_DIR / "PeMSD7_W_228.csv", header=None).values
adjacency_weights = distance_to_weight(distance_matrix, gat_version=config["USE_GAT_WEIGHTS"])
dataset = TrafficDataset(config, adjacency_weights)

# Total of 44 days in the dataset: use 34 for training, 5 for val, 5 for test
d_train, d_val, d_test = get_splits(dataset, config["N_SLOT"], (34, 5, 5))
train_dataloader = DataLoader(d_train, batch_size=config["BATCH_SIZE"], shuffle=True)
val_dataloader = DataLoader(d_val, batch_size=config["BATCH_SIZE"], shuffle=True)
test_dataloader = DataLoader(d_test, batch_size=config["BATCH_SIZE"], shuffle=False)

# Get GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

# Configure and train model
config["N_NODE"] = dataset.n_node
model = model_train(train_dataloader, val_dataloader, config, device)

## Understanding Graph Structure in PyTorch Geometric

The simplest graphs in PyTorch Geometric are made up of two components: nodes and edge index. Nodes (usually denoted by `x`) contain the features of the nodes, and `edge_index` is a 2-by-`number_of_edges` tensor in which the first row stores the starting node index of edges and the second row stores the destination node index of edges. Hence, each column contains the indices of the two nodes of an edge.

**Reference:** See the [PyTorch Geometric Data documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html) for more details.
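
As a small illustration of this layout, the sketch below writes out `x` and `edge_index` for a hypothetical 3-node graph using plain PyTorch tensors. The graph and feature values are made up; the row convention (sources on top, destinations below) follows the description above.

```python
import torch

# Hypothetical graph: an undirected edge 0 -- 1 and a directed edge 2 -> 0.
edge_index = torch.tensor(
    [
        [0, 1, 2],  # source node of each edge
        [1, 0, 0],  # destination node of each edge
    ],
    dtype=torch.long,
)
# One feature per node, so x has shape [num_nodes, 1].
x = torch.tensor([[10.0], [20.0], [30.0]])

num_edges = edge_index.shape[1]
# Count how many edges point into each node.
in_degree = torch.bincount(edge_index[1], minlength=x.shape[0])

print(num_edges)  # 3
print(in_degree)  # tensor([2, 1, 0])
```

The undirected edge takes two columns (one per direction), while the directed edge takes only one.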

---

**Problem 2: Graph Construction in PyTorch Geometric**

In this problem, you will practice constructing graphs using PyTorch Geometric's edge index format.

#### Part (a): Creating a Linear Graph

Create an `edge_index` tensor that represents the following undirected graph:

```
0 --- 1 --- 2 --- 3
```

where the numbers are the node indices.

**Hint:** For an undirected graph, you need to include edges in both directions (e.g., both 0->1 and 1->0).

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Data

# BEGIN SOLUTION
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=torch.long)
# END SOLUTION
x = torch.tensor([[-1], [0], [1], [2]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
G = pyg_utils.to_networkx(data, to_undirected=True)

plt.figure(figsize=(3, 3))
nx.draw(G, with_labels=True, node_color="lightblue", font_weight="bold")
plt.show()

In [None]:
# Test assertions
assert edge_index.shape[0] == 2, "edge_index should have 2 rows"
assert data.num_nodes == 4, "Graph should have 4 nodes"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert edge_index.shape[1] == 6, "A linear graph 0-1-2-3 should have 6 directed edges"
assert data.num_edges == 6, "Graph should have 6 edges"
# Check that all expected edges exist
edges_set = set(zip(edge_index[0].tolist(), edge_index[1].tolist(), strict=False))
assert (0, 1) in edges_set, "Missing edge 0->1"
assert (1, 0) in edges_set, "Missing edge 1->0"
assert (1, 2) in edges_set, "Missing edge 1->2"
assert (2, 1) in edges_set, "Missing edge 2->1"
assert (2, 3) in edges_set, "Missing edge 2->3"
assert (3, 2) in edges_set, "Missing edge 3->2"
# END HIDDEN TESTS

#### Part (b): Creating a Triangle Graph

Create an `edge_index` tensor that represents an undirected triangle with three nodes (0, 1, 2) where each node is connected to every other node.

In [None]:
# BEGIN SOLUTION
edge_index = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]], dtype=torch.long)
# END SOLUTION
x = torch.tensor([[0], [1], [2]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
G = pyg_utils.to_networkx(data, to_undirected=True)

plt.figure(figsize=(3, 3))
nx.draw(G, with_labels=True, node_color="lightblue", font_weight="bold")
plt.show()

In [None]:
# Test assertions
assert edge_index.shape[0] == 2, "edge_index should have 2 rows"
assert edge_index.shape[1] == 6, "A triangle should have 6 directed edges (3 undirected edges)"
assert data.num_nodes == 3, "Graph should have 3 nodes"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Check that all expected edges exist for a complete triangle
edges_set = set(zip(edge_index[0].tolist(), edge_index[1].tolist(), strict=False))
assert (0, 1) in edges_set and (1, 0) in edges_set, "Missing edge between 0 and 1"
assert (0, 2) in edges_set and (2, 0) in edges_set, "Missing edge between 0 and 2"
assert (1, 2) in edges_set and (2, 1) in edges_set, "Missing edge between 1 and 2"
# END HIDDEN TESTS

#### Part (c): Batching Graphs as Disconnected Subgraphs

One way to batch multiple graphs is to create a single disconnected graph that contains all subgraphs.
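
The key step is shifting the node indices of each later graph so they do not collide with earlier ones. The sketch below shows the idea with plain tensors on two hypothetical single-edge graphs; the variable names are illustrative.

```python
import torch

# Two hypothetical graphs, each a single undirected edge 0 -- 1.
edge_index_a = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
edge_index_b = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
num_nodes_a = 2

# Shift the second graph's indices past the first graph's nodes, then concatenate.
combined = torch.cat([edge_index_a, edge_index_b + num_nodes_a], dim=1)

# Records which original graph each of the four nodes belongs to.
batch_vector = torch.tensor([0, 0, 1, 1])

print(combined)
# tensor([[0, 1, 2, 3],
#         [1, 0, 3, 2]])
```

Because no edge connects nodes from different graphs, message passing on the combined graph never mixes information between them.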

Create an `edge_index` tensor that contains two disconnected subgraphs:
- A linear graph: `0 --- 1 --- 2 --- 3`
- A triangle graph: `4 --- 5 --- 6` (with node 4 connected to 5, 5 to 6, and 6 to 4)

Note that the triangle uses node indices 4, 5, 6 to avoid overlap with the linear graph.

In [None]:
# BEGIN SOLUTION
# Linear graph edges (nodes 0-3) + Triangle edges (nodes 4-6)
edge_index = torch.tensor(
    [[0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6], [1, 0, 2, 1, 3, 2, 5, 6, 4, 6, 4, 5]],
    dtype=torch.long,
)
# END SOLUTION
x = torch.tensor([[0], [1], [2], [3], [4], [5], [6]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
G = pyg_utils.to_networkx(data, to_undirected=True)

plt.figure(figsize=(3, 3))
nx.draw(G, with_labels=True, node_color="lightblue", font_weight="bold")
plt.show()

In [None]:
# Test assertions
assert edge_index.shape[0] == 2, "edge_index should have 2 rows"
assert edge_index.shape[1] == 12, "Combined graph should have 12 directed edges"
assert data.num_nodes == 7, "Graph should have 7 nodes"
print("All tests passed!")

# BEGIN HIDDEN TESTS
edges_set = set(zip(edge_index[0].tolist(), edge_index[1].tolist(), strict=False))
# Check linear graph edges
assert (0, 1) in edges_set and (1, 0) in edges_set, "Missing edge in linear graph"
# Check triangle edges
assert (4, 5) in edges_set and (5, 4) in edges_set, "Missing edge 4-5 in triangle"
assert (5, 6) in edges_set and (6, 5) in edges_set, "Missing edge 5-6 in triangle"
# END HIDDEN TESTS

#### Part (d): Using Batch.from_data_list

The above operation can be done more easily using a convenience function provided by PyTorch Geometric.

Create two separate graphs, each with topology `0 --- 1 --- 2`, and combine them using `Batch.from_data_list()`.

**Reference:** See the [Batch documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Batch.html#torch_geometric.data.Batch.from_data_list) for details on how this function works.

In [None]:
from torch_geometric.data import Batch

# First graph: 0 --- 1 --- 2
# BEGIN SOLUTION
edge_index1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# END SOLUTION
x1 = torch.tensor([[1], [2], [3]], dtype=torch.float)
data1 = Data(x=x1, edge_index=edge_index1)

# Second graph: 0 --- 1 --- 2
# BEGIN SOLUTION
edge_index2 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# END SOLUTION
x2 = torch.tensor([[4], [5], [6]], dtype=torch.float)
data2 = Data(x=x2, edge_index=edge_index2)

# Combine with Batch.from_data_list
# BEGIN SOLUTION
batch = Batch.from_data_list([data1, data2])
# END SOLUTION

G = pyg_utils.to_networkx(batch, to_undirected=True)

plt.figure(figsize=(3, 3))
nx.draw(G, with_labels=True, node_color="lightblue", font_weight="bold")
plt.show()

In [None]:
# Test assertions
assert batch.num_graphs == 2, "Batch should contain 2 graphs"
assert batch.num_nodes == 6, "Batch should have 6 total nodes"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert edge_index1.shape[1] == 4, "First graph should have 4 directed edges"
assert edge_index2.shape[1] == 4, "Second graph should have 4 directed edges"
assert batch.batch is not None, "Batch should have batch attribute"
# END HIDDEN TESTS

## Test the model

Now that we have a trained model, we can test it on the test dataset and visualize its performance.

In [None]:
def plot_prediction(_test_dataloader, y_pred, y_truth, node, config):
    """Plot predictions vs ground truth for a specific node."""
    # Calculate the truth
    s = y_truth.shape
    y_truth = y_truth.reshape(s[0], config["BATCH_SIZE"], config["N_NODE"], s[-1])
    # Just get the first prediction out for the nth node
    y_truth = y_truth[:, :, node, 0]
    # Flatten to get the predictions for entire test dataset
    y_truth = torch.flatten(y_truth)
    day0_truth = y_truth[: config["N_SLOT"]]

    # Calculate the predicted
    s = y_pred.shape
    y_pred = y_pred.reshape(s[0], config["BATCH_SIZE"], config["N_NODE"], s[-1])
    # Just get the first prediction out for the nth node
    y_pred = y_pred[:, :, node, 0]
    # Flatten to get the predictions for entire test dataset
    y_pred = torch.flatten(y_pred)
    # Just grab the first day
    day0_pred = y_pred[: config["N_SLOT"]]
    t = list(range(0, config["N_SLOT"] * 5, 5))
    plt.plot(t, day0_pred, label="ST-GAT")
    plt.plot(t, day0_truth, label="truth")
    plt.xlabel("Time (minutes)")
    plt.ylabel("Speed prediction")
    plt.title("Predictions of traffic over time")
    plt.legend()
    plt.savefig("predicted_times.png")
    plt.show()


_, _, _, y_pred, y_truth = eval_model(model, device, test_dataloader, "Test")
plot_prediction(test_dataloader, y_pred, y_truth, 0, config)

---

**Problem 3:** Analyzing Traffic Predictions Across Time

The variable `y_pred_reshape` contains predictions for 9 future time points (0, 5, 10, 15, 20, 25, 30, 35, 40 minutes ahead) based on the previous 12 time points for each of the 228 nodes.

For each node, determine which of the first 8 time points (indices 0-7, corresponding to 0-35 minutes ahead) has the lowest predicted speed (worst traffic). Then, create a histogram showing how many nodes have their worst traffic at each time point.

**Hint:** Use `torch.argmin()` to find the index of the minimum value along a specific dimension.
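
As a minimal sketch of how the `dim` argument behaves, consider a small made-up tensor with one row per node and one column per time point. The tensor `speeds` and its values are hypothetical and are not taken from the traffic dataset.

```python
import torch

# Hypothetical toy predictions: 3 nodes (rows) x 4 time points (columns).
speeds = torch.tensor(
    [
        [60.0, 55.0, 40.0, 58.0],
        [30.0, 35.0, 45.0, 50.0],
        [62.0, 61.0, 63.0, 20.0],
    ]
)

# dim=1 reduces over the time axis, returning one index per node.
slowest_idx = torch.argmin(speeds, dim=1)
print(slowest_idx)  # tensor([2, 0, 3])

# Slicing the columns first restricts the search to the first 3 time points.
slowest_idx_first3 = torch.argmin(speeds[:, :3], dim=1)
print(slowest_idx_first3)  # tensor([2, 0, 1])
```

Note how the last node's answer changes once its slowest time point is excluded by the slice.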

In [None]:
y_pred_reshape = y_pred.reshape(
    y_pred.shape[0], config["BATCH_SIZE"], config["N_NODE"], y_pred.shape[-1]
)[0][0]

fig, ax = plt.subplots(figsize=(4, 4))
t = torch.tensor(list(range(0, config["N_SLOT"] * 5, 5)))
t_ticks = torch.arange(len(t))
# BEGIN SOLUTION
# Find the time index with the minimum speed (worst traffic) for each node
# y_pred_reshape has shape [N_NODE, N_PRED] where N_PRED=9 (but we use first 8)
worst_times_index = torch.argmin(y_pred_reshape[:, :8], dim=1).cpu().numpy()
# END SOLUTION
# Pass plain Python lists so matplotlib renders the tick labels as numbers
ax.hist(worst_times_index, bins=torch.linspace(-0.5, 7.5, 9).tolist(), edgecolor="black")
ax.set_xticks(list(range(8)))
ax.set_xticklabels([5 * i for i in range(8)])
ax.set_xlabel("Time (minutes)")
ax.set_ylabel("Node counts")

plt.show()

In [None]:
# Test assertions
assert (
    len(worst_times_index) == config["N_NODE"]
), f"Should have {config['N_NODE']} predictions, one per node"
assert worst_times_index.min() >= 0, "Index should be non-negative"
assert worst_times_index.max() <= 7, "Index should be at most 7 (8 time points)"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Verify that worst_times_index represents argmin over time dimension
manual_check = torch.argmin(y_pred_reshape[:, :8], dim=1).cpu().numpy()
assert (worst_times_index == manual_check).all(), "Should be argmin over time dimension"
# END HIDDEN TESTS

## Exploring Graph Neural Network Properties: Equivariance

Graph neural networks have several properties that make them a natural fit for graph-structured data. One of the most important is that predictions should not depend on how we label the nodes of the graph. In our example, the predicted traffic at a given sensor should be the same after the nodes are relabeled. This property is called *permutation equivariance*, or simply *equivariance*.

Suppose that we have nodes $1, 2, 3$ and relabel them as $3, 1, 2$. Let $x_i$ ($i=1,2,3$) be the features of node $i$ under the original labeling, and let $f(x_i)$ ($i=1,2,3$) denote the corresponding predictions. Let $\mathbf{x}$ be the original features $(x_1, x_2, x_3)$ and $\mathbf{x'}$ be the reordered features $(x_3, x_1, x_2)$. Then by equivariance, the following should hold:

$$
\begin{aligned}
\text{1st element of } f(\mathbf{x'}) &= \text{3rd element of } f(\mathbf{x}) \\
\text{2nd element of } f(\mathbf{x'}) &= \text{1st element of } f(\mathbf{x}) \\
\text{3rd element of } f(\mathbf{x'}) &= \text{2nd element of } f(\mathbf{x})
\end{aligned}
$$

In other words, permuting the input nodes permutes the outputs in exactly the same way.

We will confirm that the graph neural network we trained is *equivariant*.
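
Before turning to the trained model, the idea can be illustrated with a much simpler layer. The sketch below is not the ST-GAT model: `neighbor_sum_layer` is a hypothetical toy layer that sums neighbor features through a dense adjacency matrix and applies a shared linear map. The 3-node path graph and random features are likewise made up for illustration. The permutation `[2, 0, 1]` is the 0-based version of the relabeling $(x_3, x_1, x_2)$ used above.

```python
import torch

torch.manual_seed(0)


def neighbor_sum_layer(feats, adj, weight):
    """Toy message-passing layer: sum neighbor features, then apply a shared linear map."""
    return (adj @ feats) @ weight


n_nodes, n_feat, n_out = 3, 4, 2
feats = torch.randn(n_nodes, n_feat)
weight = torch.randn(n_feat, n_out)

# Undirected path graph 0 - 1 - 2 as a dense adjacency matrix.
adj = torch.tensor(
    [
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 1.0],
        [0.0, 1.0, 0.0],
    ]
)

# Reorder the nodes as (x_3, x_1, x_2), i.e. 0-based indices [2, 0, 1].
perm = torch.tensor([2, 0, 1])
feats_perm = feats[perm]
adj_perm = adj[perm][:, perm]  # reorder rows and columns consistently

out = neighbor_sum_layer(feats, adj, weight)
out_perm = neighbor_sum_layer(feats_perm, adj_perm, weight)

# Equivariance: output on permuted input equals permuted output on original input.
print(torch.allclose(out_perm, out[perm]))  # True
```

The key detail is that the graph structure must be relabeled together with the features. Permuting only the features while leaving the adjacency unchanged would describe a different graph.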

---

**Problem 4:** Demonstrating GNN Equivariance

In this problem, you will verify that graph neural networks are equivariant to node permutations.

#### Part (a): Generating Random Permutations

Read the [`torch.randperm` documentation](https://pytorch.org/docs/stable/generated/torch.randperm.html) and use it to generate a random permutation of length 4.

In [None]:
# BEGIN SOLUTION
idx_perm = torch.randperm(4)
# END SOLUTION
idx_perm

In [None]:
# Test assertions
assert idx_perm.shape == torch.Size([4]), "Permutation should have length 4"
assert set(idx_perm.tolist()) == {
    0,
    1,
    2,
    3,
}, "Permutation should contain exactly 0, 1, 2, 3"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert idx_perm.dtype == torch.int64, "Permutation should be int64 dtype"
# END HIDDEN TESTS

#### Part (b): Understanding argsort

You can obtain the new index of the $i$-th element by using `argsort()` on `idx_perm`. Read the [`torch.argsort` documentation](https://pytorch.org/docs/stable/generated/torch.argsort.html) and verify this by applying `argsort` to your permutation from part (a).

**Example:** If `idx_perm = [2, 0, 1, 3]`, then element 0 moved to position 1, element 1 moved to position 2, element 2 moved to position 0, and element 3 stayed at position 3. So `argsort(idx_perm) = [1, 2, 0, 3]`.
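
The example above can be checked directly. The sketch below uses separate names (`example_perm`, `example_new`, `values`) so that it does not overwrite the `idx_perm` you generated in part (a). The tensor `values` is made up for illustration.

```python
import torch

example_perm = torch.tensor([2, 0, 1, 3])
example_new = torch.argsort(example_perm)
print(example_new)  # tensor([1, 2, 0, 3])

# Apply the permutation to some made-up values.
values = torch.tensor([10, 20, 30, 40])
permuted = values[example_perm]
print(permuted)  # tensor([30, 10, 20, 40])

# example_new[i] is the position where original element i ended up,
# so indexing the permuted tensor with it restores the original order.
print(permuted[example_new])  # tensor([10, 20, 30, 40])
```

In other words, `argsort` of a permutation yields its inverse permutation.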

In [None]:
# BEGIN SOLUTION
idx_new = torch.argsort(idx_perm)
# END SOLUTION
idx_new

In [None]:
# Test assertions
assert idx_new.shape == torch.Size([4]), "argsort result should have length 4"
assert set(idx_new.tolist()) == {
    0,
    1,
    2,
    3,
}, "argsort should contain exactly 0, 1, 2, 3"
# Verify argsort property: idx_perm[idx_new[i]] should give sorted order
assert (idx_perm[idx_new] == torch.arange(4)).all(), "argsort property should hold"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Additional verification of argsort
for idx in range(4):
    assert (
        idx_perm[idx_new[idx]] == idx
    ), f"argsort[{idx}] should point to where {idx} is in idx_perm"
# END HIDDEN TESTS

#### Part (c): Applying Permutations to Tensors

Using the permutation `idx_perm` generated above, permute the tensor `[1, 6, 7, 2]` by indexing with the permutation.

In [None]:
# BEGIN SOLUTION
result = torch.tensor([1, 6, 7, 2])[idx_perm]
# END SOLUTION
result

In [None]:
# Test assertions
original = torch.tensor([1, 6, 7, 2])
assert result.shape == original.shape, "Result should have same shape as original"
assert (result == original[idx_perm]).all(), "Result should be original indexed by idx_perm"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Verify permutation was applied correctly
for idx in range(4):
    assert (
        result[idx] == original[idx_perm[idx]]
    ), f"Position {idx} should have value from position idx_perm[{idx}]"
# END HIDDEN TESTS

#### Part (d): Getting Dataset Length

Now we are going to permute the test dataset `d_test`. First, get the length of the dataset and print it.

In [None]:
# BEGIN SOLUTION
len_test = len(d_test)
# END SOLUTION
print(len_test)

In [None]:
# Test assertions
assert len_test == 1340, f"Test dataset should have 1340 samples, got {len_test}"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert len_test > 0, "Test dataset should not be empty"
assert len_test == len(d_test), "len_test should equal len(d_test)"
# END HIDDEN TESTS

#### Part (e): Extracting a Batch

Obtain a batch from the test dataloader and compute the number of nodes in the batch.

**Hint:** Inspect the `eval_model` function in the earlier code and iterate through `test_dataloader`, calling `break` after the first iteration to get a single batch.
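
The loop-and-break pattern can be sketched on a stand-in dataloader. The example below assumes a plain `torch.utils.data.DataLoader` over a made-up `TensorDataset`. The notebook's `test_dataloader` yields graph batch objects instead of tuples, so the unpacking differs there, but the iteration pattern is the same.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 10 samples with 3 features each.
toy_dataset = TensorDataset(torch.arange(30, dtype=torch.float32).reshape(10, 3))
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)

# Pattern from the hint: enter the loop once, then stop.
for (first_batch,) in toy_loader:
    break

# Equivalent form using the iterator protocol directly.
(first_batch_alt,) = next(iter(toy_loader))

print(first_batch.shape)  # torch.Size([4, 3])
print(torch.equal(first_batch, first_batch_alt))  # True
```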

In [None]:
# BEGIN SOLUTION
# Obtain a batch from the test dataloader
for raw_batch in test_dataloader:
    batch = raw_batch.to(device)
    break
len_x = len(batch.x)
# END SOLUTION

In [None]:
# Test assertions
assert batch is not None, "batch should be defined"
assert len_x > 0, "len_x should be positive"
assert len_x == batch.x.shape[0], "len_x should equal batch.x.shape[0]"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert hasattr(batch, "edge_index"), "batch should have edge_index attribute"
assert batch.x.device.type == device, f"batch should be on {device}"
# END HIDDEN TESTS

#### Part (f): Demonstrating Equivariance

Feed the permuted batch to the model and demonstrate equivariance by comparing the output from the permuted input with the permuted output from the non-permuted input.

First, we define a simple graph attention module which is the first layer of the full traffic prediction model.

In [None]:
model_gat = GATConv(
    in_channels=config["N_HIST"],
    out_channels=config["N_PRED"],
    heads=8,
    dropout=0,
)
model_gat.to(device)  # use the same device as the batch instead of hard-coding CUDA

Next, we define the index tensors required to permute the model input and output.
These two tensors will be used to permute the nodes and to relabel their associated edge indices.

`idx_perm` contains the permutation and `idx_new` stores the new position of the original indices.

In [None]:
idx_perm = torch.randperm(len_x).to(device)
idx_new = torch.argsort(idx_perm)

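Compute the model output on the original (unpermuted) batch, then apply `idx_perm` to that output so it can be compared with the output on the permuted batch.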

In [None]:
pred_preperm = model_gat(batch.x, batch.edge_index)
pred_preperm = pred_preperm[idx_perm]

Write code that permutes the batch data. You need to:
1. Permute the node features `batch_postperm.x` using `idx_perm`
2. Update the edge indices using `idx_new` (since edges point to node indices, which have moved)
3. Run the model on the permuted data
4. Compare with the pre-permuted output to verify equivariance
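
To see why step 2 uses `idx_new` rather than `idx_perm`, it may help to trace a tiny graph by hand. The sketch below is a made-up 4-node example with hypothetical names (`toy_x`, `toy_edge_index`, `toy_perm`, `toy_new`) and is independent of the traffic batch. It relabels both rows of the edge index in one indexing operation, which is one possible way to write this step.

```python
import torch

# Hypothetical toy graph: 4 nodes with directed edges 0->1, 1->2, 2->3.
toy_x = torch.tensor([[10.0], [20.0], [30.0], [40.0]])
toy_edge_index = torch.tensor(
    [
        [0, 1, 2],  # source nodes
        [1, 2, 3],  # target nodes
    ]
)

toy_perm = torch.tensor([2, 0, 3, 1])  # new position j holds old node toy_perm[j]
toy_new = torch.argsort(toy_perm)  # old node i now sits at position toy_new[i]

toy_x_perm = toy_x[toy_perm]
toy_edge_index_perm = toy_new[toy_edge_index]  # relabel sources and targets together
print(toy_edge_index_perm)  # tensor([[1, 3, 0], [3, 0, 2]])

# Every relabeled edge should still connect the same feature values as before.
src_match = torch.equal(toy_x[toy_edge_index[0]], toy_x_perm[toy_edge_index_perm[0]])
dst_match = torch.equal(toy_x[toy_edge_index[1]], toy_x_perm[toy_edge_index_perm[1]])
print(src_match, dst_match)  # True True
```

Node features are gathered with the permutation itself, whereas edge endpoints store node labels and therefore need the inverse mapping from old label to new position.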

In [None]:
batch_postperm = batch.clone()
batch_postperm.x = batch.x[idx_perm]
# BEGIN SOLUTION
# Extract each row of edge_index, apply idx_new to map old indices to new positions
edge_index_new = torch.stack((idx_new[batch.edge_index[0]], idx_new[batch.edge_index[1]]))
# Update edge index of batch_postperm
batch_postperm.edge_index = edge_index_new
# Call model with permuted graph
pred_postperm = model_gat(batch_postperm.x, batch_postperm.edge_index)
# END SOLUTION
# Check equivariance
print(f"The two tensors are the same: {torch.allclose(pred_preperm, pred_postperm)}")

In [None]:
# Test assertions
assert torch.allclose(
    pred_preperm, pred_postperm
), "Model should be equivariant: permuted output should match"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Additional equivariance check
assert pred_postperm.shape == pred_preperm.shape, "Output shapes should match"
# Check that the edge indices were properly transformed
assert (
    batch_postperm.edge_index.shape == batch.edge_index.shape
), "Edge index shape should be preserved"
# END HIDDEN TESTS