In [28]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'src'))

v_dataset = pd.read_csv(r"../dataset/V_small_4.csv")
w_dataset = pd.read_csv(r"../dataset/W_small_4.csv")
# Convert main traffic DataFrame to numpy: [time_steps, num_nodes]
data_np = v_dataset.values

# Define historical and prediction window lengths
n_his, n_pred = 12, 3  # e.g., use past 12 steps to predict next 3

from dataloader import STGCNDataset
stgcn_dataset = STGCNDataset(data_np, n_his, n_pred)
from torch.utils.data import DataLoader
stgcn_loader = DataLoader(stgcn_dataset, batch_size=64, shuffle=True)

# Load adjacency matrix (defines graph connectivity)
w = w_dataset.values
adj = torch.from_numpy(w).float()  # shape: [num_nodes, num_nodes

# Inspect shapes
x, y = stgcn_dataset[0]
print(f"Input shape: {x.shape}")   # Expected: [n_his, num_nodes, 1]
print(f"Target shape: {y.shape}")  # Expected: [n_pred, num_nodes, 1]
print(f"Adjacency shape: {adj.shape}")  # Expected: [num_nodes, num_nodes]

Input shape: torch.Size([12, 4])
Target shape: torch.Size([3, 4])
Adjacency shape: torch.Size([3, 4])


In [21]:
class Temporal(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(Temporal, self).__init__()
        # kernel_size -> Size of 1D temporal kernel
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, kernel_size))

    def forward(self, x):
        
        # convert to NCHW
        x = x.permute(0, 3, 2, 1)  # [batch_size, in_channels, num_nodes, n_his]

        h = self.conv1(x) + torch.sigmoid(self.conv2(x))
        h = F.relu(h)

        # Recurrent connection
        out = h + self.conv3(x)

        # convert back to NHWC
        out = out.permute(0, 3, 2, 1)  # [batch_size, n_his_out, num_nodes, out_channels]

        return out


In [26]:
for batch_x, batch_y in stgcn_loader:
    print(f"Batch input shape: {batch_x.shape}")
    print(f"Batch target shape: {batch_y.shape}")
    
    # The dataloader should give us data in shape [batch_size, n_his, num_nodes, 1]
    break

# Create temporal conv layer
# For STGCN, we typically expect:
# - in_channels: number of features (1 for traffic flow)
# - out_channels: number of output features
# - kernel_size: temporal window size

temporal_layer = Temporal(in_channels=1, out_channels=64, kernel_size=3)

# Test with a small batch
try:
    # Ensure input has the right shape: [batch_size, n_his, num_nodes, in_channels]
    if len(batch_x.shape) == 3:
        batch_x = batch_x.unsqueeze(-1)
    
    print(f"Input to temporal layer: {batch_x.shape}")
    output = temporal_layer(batch_x)
    print(f"Output from temporal layer: {output.shape}")
    print("✓ Temporal convolution works!")
    
except Exception as e:
    print(f"✗ Error in temporal convolution: {e}")
    print(f"Input shape was: {batch_x.shape}")

Batch input shape: torch.Size([64, 12, 4])
Batch target shape: torch.Size([64, 3, 4])
Input to temporal layer: torch.Size([64, 12, 4, 1])
Output from temporal layer: torch.Size([64, 10, 4, 64])
✓ Temporal convolution works!
