# Training SpatioTemporal TCN Autoencoder

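This notebook trains and evaluates the **SpatioTemporal TCN Autoencoder** on the freeway dataset, mirroring the baseline notebooks. The pipeline has four steps:

1. Build sliding windows of per-timestamp traffic graphs.
2. Train on windows built from anomaly-free timestamps.
3. Derive per-node reconstruction-error thresholds from the training data.
4. Score the test split with ROC AUC and with crash-detection delay at fixed false-positive rates.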

In [1]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from models import SpatioTemporalTCNAutoencoder
from parameters import STAEParameters, TrainingParameters
from training import train_sttcn_ae, test_sttcn_ae, compute_anomaly_threshold_sttcn_ae

from datautils import (
    get_full_data, normalize_data, label_anomalies,
    generate_edges
)
from torch_geometric.data import Data as PyGData
from tqdm import tqdm




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from metrics import calculate_accuracy, crash_detection_delay, calculate_tp_fp, find_thresholds, find_delays, find_percent, discrete_fp_delays, generate_anomaly_labels, calculate_auc

### Time Series Sequencing

In [3]:
from torch_geometric.data import Data
from datautils import generate_edges, temporalize_sequence
def sequence_sttcn_ae(df, timesteps, hide_anomalies=True, milemarkers=None):
    """
    Build sliding-window graph sequences for the SpatioTemporal TCN Autoencoder.

    Uses the same column conventions as the baseline notebooks: ``unix_time``,
    ``milemarker``, the ``occ``/``speed``/``volume`` features, and ``anomaly``
    (read only when ``hide_anomalies`` is True).

    Args:
        df: long-format frame with one row per (timestamp, node).
        timesteps: window length forwarded to ``temporalize_sequence``.
        hide_anomalies: if True, drop every timestamp in which any row has a
            truthy ``anomaly`` value before windowing.
        milemarkers: milemarkers forwarded to ``generate_edges``. Defaults to
            ``list(range(49))``, the graph used by the baselines.

    Returns:
        temporal_windows: result of ``temporalize_sequence`` over the list of
            per-timestamp PyG ``Data`` graphs (one graph per kept timestamp).
        kept_indices: indices into the sorted unique ``unix_time`` array of the
            timestamps that were kept (useful for mapping labels).

    NOTE(review): with ``hide_anomalies=True`` the dropped timestamps are simply
    removed from the graph list, so a window can join frames that were not
    adjacent in time.
    """
    if milemarkers is None:
        milemarkers = list(range(49))   # same graph used by baselines
    static_edges = generate_edges(milemarkers=milemarkers)

    sequence = []
    kept_indices = []

    # One pass over the frame instead of a full boolean scan per timestamp.
    # groupby sorts its keys, so `idx` matches positions in np.unique(df['unix_time']).
    grouped = df.groupby('unix_time', sort=True)
    for idx, (_, frame_t) in enumerate(tqdm(grouped, total=grouped.ngroups)):
        if hide_anomalies and bool(np.any(frame_t['anomaly'].values)):
            continue

        kept_indices.append(idx)
        # Several rows share each milemarker. A stable sort keeps their input row
        # order, whereas quicksort's tie order is unspecified. This makes node
        # ordering deterministic within a milemarker.
        xt = (
            frame_t.sort_values('milemarker', kind='stable')
            [['occ', 'speed', 'volume']]
            .to_numpy()
        )
        x_tensor = torch.tensor(xt, dtype=torch.float32)
        sequence.append(Data(x=x_tensor, edge_index=static_edges))

    temporal_windows = temporalize_sequence(graph_sequence=sequence, timesteps=timesteps)
    return temporal_windows, kept_indices


### Defining Hyperparameters

In [None]:

# Seed the RNGs so the stochastic parts of training (e.g. weight init, dropout,
# anything random inside train_sttcn_ae) are repeatable under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): the non-round values below look like the output of a
# hyperparameter search — TODO record where they came from.
stae_params = STAEParameters(
    num_features=3,          # occ, speed, volume
    latent_dim=64,
    gcn_hidden_dim=128,
    lstm_hidden_dim=64,      # not used by TCN, but present in dataclass
    lstm_num_layers=1,       # not used by TCN
    dropout=0.02165472020554443
)

training_params = TrainingParameters(
    learning_rate=0.0023750472284281726,
    batch_size=1,
    timesteps=4,
    n_epochs=5
)

# Reconstruction-loss weights, presumably one per feature (occ, speed, volume) — TODO confirm in training.py.
mse_weights = [1, 1, 1]

### Prepare Train Data

In [5]:
train_df, test_df, valid_days = get_full_data()
# Normalize first, then attach anomaly labels.
train_df = label_anomalies(normalize_data(train_df))

train_sequence, kept_train = sequence_sttcn_ae(train_df, training_params.timesteps)
# Show the number of windows and the first graph of the first window.
first_graph = train_sequence[0][0] if len(train_sequence) > 0 else None
len(train_sequence), first_graph

100%|██████████| 13440/13440 [00:29<00:00, 453.24it/s]


(6417, Data(x=[196, 3], edge_index=[2, 1832]))

### Model Training

In [6]:
model, losses = train_sttcn_ae(
    stae_params, training_params, train_sequence,
    mse_weights=mse_weights, verbose=True,
)
# A tail slice already covers runs shorter than 10 steps.
recent_loss = np.mean(losses[-10:])
print('Trained with', len(losses), 'steps. Last 10:', recent_loss)

  0%|          | 0/5 [00:00<?, ?it/s][W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.




 20%|██        | 1/5 [06:38<26:32, 398.08s/it]

Epoch 0: last 100 avg loss 0.02576844731345773


 40%|████      | 2/5 [13:06<19:36, 392.30s/it]

Epoch 1: last 100 avg loss 0.024698820868507027


 60%|██████    | 3/5 [19:25<12:52, 386.14s/it]

Epoch 2: last 100 avg loss 0.02578314106911421


 80%|████████  | 4/5 [25:56<06:28, 388.03s/it]

Epoch 3: last 100 avg loss 0.025039336839690805


100%|██████████| 5/5 [32:34<00:00, 390.98s/it]

Epoch 4: last 100 avg loss 0.026594945825636387
Trained with 32085 steps. Last 10: 0.028189396020025015





### Computing Node-level Anomaly Thresholds

In [7]:
thresh = compute_anomaly_threshold_sttcn_ae(
    train_sequence, model, mse_weights, method='max'
)
# Peek at the first few per-node thresholds; a scalar result is shown as-is.
thresh_preview = thresh[:5] if hasattr(thresh, '__len__') else thresh
thresh_preview

100%|██████████| 6417/6417 [03:20<00:00, 32.04it/s]


array([0.13820803, 0.37218288, 0.29510638, 0.40703198, 0.1363681 ],
      dtype=float32)

### Loading Test Data

In [8]:
# `test_df` was returned by the same get_full_data() call as the training split.
# Take an independent copy rather than reloading the dataset. This avoids a
# second load and guarantees the test frame matches the train/test split used
# above. The copy also keeps later preprocessing from touching `test_df`.
df_test_data = test_df.copy()

In [9]:
# Same preprocessing steps as the training split, but keep every timestamp
# (anomalies included) so the whole test period is scored.
test_normalized = normalize_data(df_test_data)
test_labeled = label_anomalies(test_normalized)
test_data, kept_test_indices = sequence_sttcn_ae(
    test_labeled, training_params.timesteps, hide_anomalies=False
)

100%|██████████| 4800/4800 [00:04<00:00, 1058.24it/s]


In [10]:
# Score every test window: per-window, per-node errors plus reconstructed and true speeds.
test_errors, test_recons_speeds, test_true_speeds = test_sttcn_ae(
    test_data,
    mse_weights,
    model,
    verbose=True,
)

100%|██████████| 4781/4781 [02:28<00:00, 32.22it/s]


In [11]:
# (windows, nodes)
np.shape(test_errors)

(4781, 196)

In [12]:
import warnings

# Build labels from the same frame and timestamp indices that produced
# `test_errors` (df_test_data / kept_test_indices from the test-data cell).
# The original cell re-sequenced the raw `test_df` instead; that added a
# second pass over the data and produced windows that were never used.
W, N = test_errors.shape          # windows x nodes
T = training_params.timesteps

anomaly_labels_all = generate_anomaly_labels(df_test_data, kept_test_indices)

# NOTE(review): labels are sliced assuming window i is labelled by frame i + (T - 1).
# Surface it when the frame/window counts imply a different offset.
implied_offset = len(kept_test_indices) - W
if implied_offset != T - 1:
    warnings.warn(
        f"{len(kept_test_indices)} frames produced {W} windows (offset {implied_offset}), "
        f"but labels are sliced with offset T - 1 = {T - 1}; "
        "verify the window alignment used by temporalize_sequence."
    )

start = (T - 1) * N
stop  = start + W * N
anomaly_labels_windows = anomaly_labels_all[start:stop]

assert len(anomaly_labels_windows) == W * N, (len(anomaly_labels_windows), W * N)

auc = calculate_auc(test_errors, anomaly_labels_windows)
print("AUC:", auc)

  0%|          | 0/4800 [00:00<?, ?it/s]

100%|██████████| 4800/4800 [00:04<00:00, 1082.78it/s]


AUC: 0.6093723009703806


In [19]:
# Mean reconstruction error over all windows and nodes.
test_errors.mean()

0.015201215

In [None]:

import warnings

W, N = test_errors.shape          # windows x nodes
T = training_params.timesteps

# Per-node labels for every kept timestamp, sliced to the windows that were scored.
start_nodes = (T - 1) * N
stop_nodes  = start_nodes + W * N
anomaly_labels = generate_anomaly_labels(df_test_data, kept_test_indices)
anomaly_labels_win = anomaly_labels[start_nodes:stop_nodes]
assert anomaly_labels_win.shape[0] == W * N

# NOTE(review): slicing assumes window i is labelled by frame i + (T - 1).
implied_offset = len(kept_test_indices) - W
if implied_offset != T - 1:
    warnings.warn(
        f"{len(kept_test_indices)} frames produced {W} windows (offset {implied_offset}), "
        f"but labels are sliced with offset T - 1 = {T - 1}; "
        "verify the window alignment used by temporalize_sequence."
    )

# The `[0::N]` stride below takes one crash flag per timestamp. It is only valid
# when rows form contiguous, time-sorted blocks of exactly N rows, so check that
# layout first instead of silently misaligning crash reports.
unix_times = df_test_data['unix_time'].to_numpy()
unique_times = np.unique(unix_times)
assert unix_times.size == unique_times.size * N, "expected exactly N rows per unix_time"
assert (unix_times.reshape(-1, N) == unique_times[:, None]).all(), \
    "expected df_test_data sorted by unix_time in contiguous blocks of N rows"

crash_reported_all = df_test_data['crash_record'].to_numpy()[0::N][kept_test_indices]
crash_reported_win = crash_reported_all[(T - 1):(T - 1 + W)]
assert crash_reported_win.shape[0] == W

delay_results = np.array(find_delays(thresh, test_errors, anomaly_labels_win, crash_reported_win))


100%|██████████| 1000/1000 [01:17<00:00, 12.94it/s]
100%|██████████| 98/98 [00:00<00:00, 103.89it/s]


In [16]:
# Report mean detection delay and miss rate at fixed false-positive-rate operating points.
discrete_fp_delays(
    thresh,
    test_errors,
    anomaly_labels_win,
    crash_reported_win,
)

100%|██████████| 1000/1000 [01:17<00:00, 12.92it/s]


Found FPR of 0.009807928075194115 for 0.01
Found FPR of 0.025337147527584796 for 0.025
Found FPR of 0.049856967715570084 for 0.05
Found FPR of 0.10012259910093993 for 0.1
Found FPR of 0.19493257049448304 for 0.2
FPR 1% gives mean delay of -4.055555555555555 +/- 9.334655990936987 while missing 0.25%.
FPR 2.5% gives mean delay of -6.833333333333333 +/- 8.797095480264431 while missing 0.25%.
FPR 5% gives mean delay of -9.444444444444445 +/- 6.512570939875709 while missing 0.25%.
FPR 10% gives mean delay of -8.318181818181818 +/- 6.81333160820063 while missing 0.08333333333333337%.
FPR 20% gives mean delay of -12.416666666666666 +/- 4.517712056143266 while missing 0.0%.
