In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# ===== Dataset =====
class TrafficDataset(Dataset):
    def __init__(self, csv_path, window=4):
        df = pd.read_csv(csv_path)
        df["timestamp"] = pd.to_datetime(df["timestamp"])
        df = df.sort_values(["timestamp", "segment_id"])

        self.segment_ids = sorted(df["segment_id"].unique())
        self.timestamps = sorted(df["timestamp"].unique())

        # Normalize features
        scaler = MinMaxScaler()
        df[["vehicle_count", "avg_speed"]] = scaler.fit_transform(df[["vehicle_count", "avg_speed"]])

        self.X, self.y = [], []
        for i in range(len(self.timestamps) - window - 1):
            x_window = []
            for j in range(window):
                temp = df[df["timestamp"] == self.timestamps[i + j]].sort_values("segment_id")
                x_window.append(temp[["vehicle_count", "avg_speed"]].values)
            x_tensor = torch.tensor(np.stack(x_window), dtype=torch.float32)  # [T, N, F]

            current = df[df["timestamp"] == self.timestamps[i + window]].sort_values("segment_id")["avg_speed"].values
            prev = df[df["timestamp"] == self.timestamps[i + window - 1]].sort_values("segment_id")["avg_speed"].values
            y_tensor = torch.tensor(current - prev, dtype=torch.float32)  # shape: [N]

            self.X.append(x_tensor)                 # [T, N, F]
            self.y.append(y_tensor.unsqueeze(0))    # [1, N]

        self.X = torch.stack(self.X)      # [B, T, N, F]
        self.y = torch.cat(self.y, dim=0) # [B, N]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# ===== Refined Model =====
class GNNTransformerRefined(nn.Module):
    def __init__(self, num_nodes, feature_dim=2, hidden_dim=64, window=4):
        super().__init__()
        self.gcn1 = GCNConv(feature_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.window = window
        self.ln = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.2)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4),
            num_layers=2
        )
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index):
        B, T, N, feat_dim = x.shape

        # Flatten for GCN input
        x = x.view(B * T * N, feat_dim)
        x = self.gcn1(x, edge_index)
        x = torch.relu(x)
        x = self.gcn2(x, edge_index)
        x = self.ln(x)
        x = self.dropout(x)

        # Reshape back to [B, T, N, hidden]
        x = x.view(B, T, N, -1).permute(0, 2, 1, 3)  # [B, N, T, H]
        x = x.reshape(B * N, T, -1)  # Flatten to [B*N, T, H] for transformer

        x = self.transformer(x.permute(1, 0, 2))  # [T, B*N, H]
        x = x.mean(dim=0)  # [B*N, H]
        x = self.fc(x)     # [B*N, 1]
        x = x.view(B, N)   # Final output: [B, N]

        return x


# ===== Graph builder =====
def build_edge_index(num_segments):
    edges = []
    for i in range(num_segments - 1):
        edges.append([i, i + 1])
        edges.append([i + 1, i])
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

# ===== Training =====
def train_model():
    dataset = TrafficDataset("synthetic_traffic_data.csv")
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    model = GNNTransformerRefined(num_nodes=len(dataset.segment_ids))
    edge_index = build_edge_index(len(dataset.segment_ids))

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.HuberLoss()

    for epoch in range(30):
        model.train()
        total_loss = 0
        for x_batch, y_batch in loader:
            y_pred = model(x_batch, edge_index)
            loss = loss_fn(y_pred, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss / len(loader):.4f}")

    torch.save(model.state_dict(), "gnn_transformer_refined.pt")
    print("Refined model saved.")

if __name__ == "__main__":
    train_model()




Epoch 1 | Loss: 0.1676
Epoch 2 | Loss: 0.0252
Epoch 3 | Loss: 0.0241
Epoch 4 | Loss: 0.0227
Epoch 5 | Loss: 0.0224
Epoch 6 | Loss: 0.0223
Epoch 7 | Loss: 0.0224
Epoch 8 | Loss: 0.0226
Epoch 9 | Loss: 0.0220
Epoch 10 | Loss: 0.0226
Epoch 11 | Loss: 0.0223
Epoch 12 | Loss: 0.0225
Epoch 13 | Loss: 0.0222
Epoch 14 | Loss: 0.0220
Epoch 15 | Loss: 0.0222
Epoch 16 | Loss: 0.0220
Epoch 17 | Loss: 0.0220
Epoch 18 | Loss: 0.0220
Epoch 19 | Loss: 0.0222
Epoch 20 | Loss: 0.0223
Epoch 21 | Loss: 0.0221
Epoch 22 | Loss: 0.0220
Epoch 23 | Loss: 0.0218
Epoch 24 | Loss: 0.0219
Epoch 25 | Loss: 0.0219
Epoch 26 | Loss: 0.0219
Epoch 27 | Loss: 0.0223
Epoch 28 | Loss: 0.0221
Epoch 29 | Loss: 0.0218
Epoch 30 | Loss: 0.0218
Refined model saved.
