# Robust Traffic Prediction LSTM (Anti-Mean Bias)
This version implements advanced strategies to overcome the 'mean-prediction' problem and capture traffic spikes.

### Architectural & Loss Fixes:
1. **Weighted MSE Loss:** Penalizes errors on higher traffic counts more heavily, forcing the model to deviate from the mean.
2. **Residual Bi-LSTM:** Adds skip connections between layers to preserve the original signal across the deep network.
3. **Feature Scaling:** Uses `StandardScaler` to handle the relative variance of different directions.
4. **Attention Head:** Refined Attention mechanism to isolate specific temporal triggers for traffic surges.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import joblib
import os

# Seed NumPy's and PyTorch's global random number generators with the same
# fixed value so that both libraries start from a known state.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Feature Engineering & Variance Enhancement

In [None]:
# Load the training table; the path is resolved relative to the working directory.
df = pd.read_csv('../lstm_training_data.csv')

# 1. Cyclical Time Features
# Map the hour onto a circle with period 24 so the encoding wraps around
# (assumes 'hour' holds values in 0-23 -- TODO confirm against the CSV).
hour_angle = 2 * np.pi * df['hour'] / 24
df['hour_sin'] = np.sin(hour_angle)
df['hour_cos'] = np.cos(hour_angle)

# 2. Local Trends (Velocity of traffic change)
for d in ['North', 'South', 'East', 'West']:
    # First difference; the first row has no predecessor, so it is set to 0.
    df[f'{d}_trend'] = df[d].diff().fillna(0)
    # 10-row rolling mean. The first 9 rows have an incomplete window (NaN)
    # and are back-filled with the first complete value. Series.bfill()
    # replaces the deprecated fillna(method='bfill') spelling.
    df[f'{d}_roll_10'] = df[d].rolling(window=10).mean().bfill()

feature_cols = [
    'North', 'South', 'East', 'West',
    'hour_sin', 'hour_cos', 'day_of_week',
    'North_trend', 'South_trend', 'East_trend', 'West_trend',
    'North_roll_10', 'South_roll_10', 'East_roll_10', 'West_roll_10'
]
target_cols = ['target_North', 'target_South', 'target_East', 'target_West']

# Normalize features and targets separately
scaler_x = StandardScaler()
scaler_y = StandardScaler()

# Split by row position: first 80% of rows for training, the rest for testing.
# Both scalers are fitted on the training rows only and then applied to the
# test rows, so test statistics never influence the scaling.
split_idx = int(len(df) * 0.8)
X_train = scaler_x.fit_transform(df.iloc[:split_idx][feature_cols])
Y_train = scaler_y.fit_transform(df.iloc[:split_idx][target_cols])
X_test = scaler_x.transform(df.iloc[split_idx:][feature_cols])
Y_test = scaler_y.transform(df.iloc[split_idx:][target_cols])

# Persist the fitted scalers so the same scaling can be re-applied later.
joblib.dump(scaler_x, 'scaler_x.pkl')
joblib.dump(scaler_y, 'scaler_y.pkl')

In [None]:
class TrafficDataset(Dataset):
    """Sliding-window dataset over row-aligned feature and target arrays.

    Item ``i`` is the pair ``(x[i : i + lookback], y[i + lookback])``: a
    window of ``lookback`` consecutive feature rows and the target row that
    sits immediately after that window.

    Args:
        x: Feature rows; converted to a float32 tensor.
        y: Target rows; converted to a float32 tensor. Presumably has the
            same number of rows as ``x`` -- this is not checked here.
        lookback: Number of consecutive feature rows per sample.
    """

    def __init__(self, x, y, lookback=45):
        self.x = torch.FloatTensor(x)
        self.y = torch.FloatTensor(y)
        self.lookback = lookback

    def __len__(self):
        """Return the number of complete windows (0 if there are too few rows)."""
        # Clamp at zero: a negative length makes len() and DataLoader raise
        # instead of simply yielding no samples.
        return max(0, len(self.x) - self.lookback)

    def __getitem__(self, idx):
        """Return ``(window, target)`` for the window starting at row ``idx``."""
        window_end = idx + self.lookback
        return self.x[idx:window_end], self.y[window_end]

# Wrap the scaled arrays in sliding-window datasets (the class default
# lookback is used for both splits).
train_ds = TrafficDataset(x=X_train, y=Y_train)
test_ds = TrafficDataset(x=X_test, y=Y_test)

# Training batches are reshuffled every epoch; evaluation batches keep the
# original row order.
BATCH_SIZE = 64
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE, shuffle=False)

## 2. Model: Residual Attention-LSTM

In [None]:
class ResidualLSTM(nn.Module):
    """Bidirectional LSTM with attention pooling and an MLP regression head.

    NOTE(review): despite the class name, this implementation contains no
    residual/skip connection; the data flows LSTM -> attention -> head.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of each LSTM direction.
        output_dim: Number of values predicted per sample.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3):
        super().__init__()
        # nn.LSTM only applies dropout between stacked layers; passing a
        # non-zero value with a single layer does nothing and emits a
        # UserWarning, so it is enabled only when there is more than one layer.
        lstm_dropout = 0.2 if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True,
                            dropout=lstm_dropout)

        # Attention to weigh important timesteps: one scalar score per step.
        # The input width is hidden_dim * 2 because the LSTM is bidirectional.
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

        # Regression head applied to the attention-pooled context vector.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        """Predict ``output_dim`` values for each sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, hidden_dim * 2)

        # Apply Attention: softmax over the time axis (dim=1) turns the
        # per-step scores into weights that sum to 1 for each sequence.
        attn_weights = torch.softmax(self.attention(lstm_out), dim=1)
        # Weighted sum over time collapses the sequence to one vector.
        context = torch.sum(attn_weights * lstm_out, dim=1)

        return self.fc(context)

## 3. Weighted Loss Function
This is the most critical part for fixing the 'flat line' prediction. We penalize errors more heavily when the actual (standardized) value lies more than 0.5 standard deviations from the mean.

In [None]:
def weighted_mse_loss(preds, targets, threshold=0.5, spike_weight=4.0):
    """Mean squared error that up-weights samples with large-magnitude targets.

    Every element's squared error is multiplied by ``spike_weight`` when
    ``abs(target) > threshold`` and by 1 otherwise; the weighted errors are
    then averaged over all elements. When the targets are standardized, the
    threshold is expressed in standard deviations from the mean.

    Args:
        preds: Predicted values.
        targets: Ground-truth values, broadcastable against ``preds``.
        threshold: Absolute target value above which an element is
            up-weighted (strict comparison).
        spike_weight: Multiplier applied to up-weighted elements.

    Returns:
        Scalar tensor holding the weighted mean squared error.
    """
    squared_error = (preds - targets) ** 2

    # Build the weights as tensors matching the error's dtype and device
    # instead of relying on scalar overloads of torch.where.
    is_spike = torch.abs(targets) > threshold
    weights = torch.where(
        is_spike,
        torch.full_like(squared_error, spike_weight),
        torch.ones_like(squared_error),
    )
    return (squared_error * weights).mean()

## 4. Training Loop

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResidualLSTM(len(feature_cols), 128, len(target_cols)).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the learning rate after 3 epochs without validation improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)

best_loss = float('inf')
for epoch in range(50):
    model.train()
    total_train = 0.0
    n_train = 0
    for bx, by in train_loader:
        # The model head contains BatchNorm1d, which raises in training mode
        # when a batch holds a single sample; skip such a trailing batch.
        if bx.size(0) < 2:
            continue
        bx, by = bx.to(device), by.to(device)
        optimizer.zero_grad()
        pred = model(bx)
        loss = weighted_mse_loss(pred, by)
        loss.backward()
        optimizer.step()
        # Accumulate per-sample so a short final batch is not over-weighted.
        total_train += loss.item() * bx.size(0)
        n_train += bx.size(0)

    model.eval()
    total_val = 0.0
    n_val = 0
    with torch.no_grad():
        for bx, by in test_loader:
            bx, by = bx.to(device), by.to(device)
            total_val += weighted_mse_loss(model(bx), by).item() * bx.size(0)
            n_val += bx.size(0)

    # Sample-weighted mean of the batch losses; this value drives both the
    # LR scheduler and the checkpoint selection below.
    val_loss = total_val / n_val
    scheduler.step(val_loss)

    # Keep only the weights of the epoch with the lowest validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

    if epoch % 5 == 0:
        print(f"Epoch {epoch} | Train Loss: {total_train/n_train:.4f} | Val Loss: {val_loss:.4f}")

## 5. Verification
Checking if the model now tracks the spikes.

In [None]:
# Restore the best checkpoint. map_location lets a checkpoint that was saved
# on a GPU load on whichever device is in use now.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# A single batch holds at most one loader batch of samples, which is fewer
# than the 200 points the plot slices, so gather batches until enough
# samples are available (or the loader is exhausted).
n_plot = 200
pred_chunks, actual_chunks = [], []
n_collected = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        pred_chunks.append(model(x_batch.to(device)).cpu().numpy())
        actual_chunks.append(y_batch.numpy())
        n_collected += len(y_batch)
        if n_collected >= n_plot:
            break

preds = np.concatenate(pred_chunks)

# Undo the target scaling so both curves are in the original units.
preds_rescaled = scaler_y.inverse_transform(preds)
actual_rescaled = scaler_y.inverse_transform(np.concatenate(actual_chunks))

# Column 0 corresponds to the first target column ('target_North').
plt.figure(figsize=(15, 5))
plt.plot(actual_rescaled[:n_plot, 0], label='Actual (North)', alpha=0.8, color='blue')
plt.plot(preds_rescaled[:n_plot, 0], label='Predicted (North)', linestyle='--', color='red')
plt.title('Improved Traffic Prediction with Spike Weighting')
plt.legend()
plt.show()