In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import warnings
import scipy.sparse as sp
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Layer, Dropout, Lambda, LayerNormalization, GlobalAveragePooling1D
from tensorflow.keras.layers import Add
from tensorflow.keras.models import Sequential
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Suppress every UserWarning raised later in this process.
warnings.filterwarnings('ignore', category=UserWarning)

# ==================== USER PATHS (edit to match your data layout) ====================
TRAFFIC_H5 = 'data/METR-LA.h5'    # traffic readings (HDF5)
ADJ_PKL = 'data/adj_METR-LA.pkl'  # pickled adjacency matrix, or a container holding one

# ==================== USER SETTINGS ====================
# --- Windowing ---
N_TIMESTEPS = 12        # input (look-back) window length, in time steps
FORECAST_HORIZON = 1    # future time steps predicted per window

# --- Optimisation ---
NUM_EPOCHS = 36         # full passes over the training windows
BATCH_SIZE = 128        # windows per gradient step

# --- Transformer architecture ---
D_MODEL = 64            # width of each projected time-step representation
NUM_HEADS = 4           # attention heads (D_MODEL must be divisible by this)
DFF = 128               # hidden width of the Transformer feed-forward sub-layer
# ========================================================

def load_and_clean_data(h5_path):
    """Load the traffic table stored under key ``'df'`` in an HDF5 file.

    Cleaning steps:
      * if there is more than one column and the first one is non-numeric
        (object or pandas string dtype), it is dropped;
      * every remaining column is coerced to numbers, unparseable cells
        becoming NaN;
      * NaNs are forward-filled, then back-filled, down each column.

    Args:
        h5_path: Path to the HDF5 file.

    Returns:
        ``(traffic_data, n_sensors)``: a float32 array of shape
        ``(n_samples, n_sensors)`` and its column count, or ``(None, None)``
        if anything fails (the error is printed rather than raised).
    """
    print("1) Loading traffic data from HDF5 and cleaning data...")
    try:
        traffic_df = pd.read_hdf(h5_path, 'df')

        if traffic_df.shape[1] > 1:
            first_dtype = traffic_df.iloc[:, 0].dtype
            # Checking only `== 'object'` misses pandas' dedicated string dtypes,
            # which would then be coerced into an all-NaN "sensor" column.
            if (pd.api.types.is_object_dtype(first_dtype)
                    or pd.api.types.is_string_dtype(first_dtype)):
                traffic_df = traffic_df.iloc[:, 1:]

        traffic_df = traffic_df.apply(pd.to_numeric, errors='coerce')
        if traffic_df.isna().any().any():
            # fillna(method=...) is deprecated since pandas 2.1; ffill()/bfill()
            # are the supported spellings of the same operation.
            # NOTE(review): a column that is entirely NaN stays NaN after both fills.
            traffic_df = traffic_df.ffill().bfill()

        traffic_data = traffic_df.to_numpy(dtype=np.float32)
    except FileNotFoundError:
        print(f"Error: The file at {h5_path} was not found.")
        return None, None
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"An error occurred: {e}")
        return None, None

    n_sensors = traffic_data.shape[1]
    print(f"Loaded traffic data with shape: {traffic_data.shape}")
    return traffic_data, n_sensors

def _correlation_adjacency(traffic_matrix, corr_threshold=0.30):
    """Build a binary, self-loop-free adjacency matrix from sensor correlations.

    Sensors ``i`` and ``j`` are linked (value 1.0) when the absolute Pearson
    correlation between columns ``i`` and ``j`` of ``traffic_matrix`` is at
    least ``corr_threshold``. Everything else, including the diagonal, is 0.0.
    NaN correlations (e.g. from a constant column) compare False and so add
    no edge.

    Args:
        traffic_matrix: 2-D array, one row per sample and one column per sensor.
        corr_threshold: Minimum absolute correlation for an edge.

    Returns:
        float32 array of shape ``(n_columns, n_columns)``.
    """
    # np.corrcoef treats rows as variables, hence the transpose. atleast_2d keeps
    # the single-sensor case (where corrcoef returns a 0-d value) square.
    corr = np.atleast_2d(np.corrcoef(np.asarray(traffic_matrix).T))
    np.fill_diagonal(corr, 0)
    adj = (np.abs(corr) >= corr_threshold).astype(np.float32)
    print(f"Generated adjacency with shape: {adj.shape}")
    return adj


def _as_square_matrix(candidate, n_sensors):
    """Return ``candidate`` as a float32 ``(n_sensors, n_sensors)`` array, else None."""
    if sp.issparse(candidate):
        candidate = candidate.toarray()
    try:
        mat = np.asarray(candidate, dtype=np.float32)
    except Exception:
        # Anything that cannot be converted (dicts, ragged lists, non-numeric
        # strings, ...) is simply not a match.
        return None
    return mat if mat.shape == (n_sensors, n_sensors) else None


def robust_adjacency_loader(adj_path, n_sensors, traffic_matrix, corr_threshold=0.30):
    """
    Load an ``(n_sensors, n_sensors)`` adjacency matrix from a pickle file,
    falling back to a correlation-based graph when that is not possible.

    The unpickled object is searched in this order:

    1. a NumPy array or SciPy sparse matrix that is itself the matrix;
    2. a dict holding it under ``'adj'``, ``'adj_mx'``, ``'adj_matrix'``,
       ``'adjacency'`` or ``'A'`` (dense or sparse);
    3. a list/tuple containing it as one of its items (dense or sparse);
    4. any other object (e.g. nested lists) that converts directly to it.

    If the file cannot be opened or unpickled, or nothing of the right shape
    is found, the adjacency is built from ``traffic_matrix`` instead: sensors
    whose absolute Pearson correlation is ``>= corr_threshold`` are linked,
    with no self-loops.

    Args:
        adj_path: Path to the pickle file.
        n_sensors: Expected number of rows/columns of the matrix.
        traffic_matrix: 2-D array (samples x sensors), used only for the fallback.
        corr_threshold: Minimum absolute correlation for a fallback edge.

    Returns:
        float32 array of shape ``(n_sensors, n_sensors)`` (for the fallback,
        the size is the number of columns of ``traffic_matrix``).
    """
    print("2) Attempting to load adjacency matrix from file...")

    # --- Step 1: unpickle the file ---
    try:
        with open(adj_path, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code -- only point
            # adj_path at trusted files.
            # 'latin1' lets Python 3 read pickles written by Python 2
            # (including NumPy arrays).
            adj_loaded = pickle.load(f, encoding='latin1')
    except Exception as e:
        # Any failure (missing file, corrupt pickle, missing module, ...) -> fallback.
        print("\n--- ERROR ---")
        print(f"An explicit error occurred while trying to load the pickle file: {e}")
        print("Moving to correlation fallback as a result.")
        print("---------------\n")
        return _correlation_adjacency(traffic_matrix, corr_threshold)

    # --- Step 2: look for a correctly-shaped matrix inside the loaded object ---
    print("Successfully loaded pickle file. Now validating contents...")

    if isinstance(adj_loaded, np.ndarray):
        mat = _as_square_matrix(adj_loaded, n_sensors)
        if mat is not None:
            print("Found valid matrix in numpy.ndarray.")
            return mat

    if sp.issparse(adj_loaded):
        mat = _as_square_matrix(adj_loaded, n_sensors)
        if mat is not None:
            print("Found valid matrix in scipy.sparse object.")
            return mat

    if isinstance(adj_loaded, dict):
        for key in ('adj', 'adj_mx', 'adj_matrix', 'adjacency', 'A'):
            if key in adj_loaded:
                mat = _as_square_matrix(adj_loaded[key], n_sensors)
                if mat is not None:
                    print(f"Found valid matrix in dictionary with key: '{key}'.")
                    return mat

    if isinstance(adj_loaded, (list, tuple)):
        for item in adj_loaded:
            mat = _as_square_matrix(item, n_sensors)
            if mat is not None:
                print("Found valid matrix as an item in a list/tuple.")
                return mat

    # Last resort for other array-likes (nested lists, DataFrames, ...).
    # Arrays, sparse matrices and dicts were already fully examined above.
    if not (isinstance(adj_loaded, (np.ndarray, dict)) or sp.issparse(adj_loaded)):
        mat = _as_square_matrix(adj_loaded, n_sensors)
        if mat is not None:
            print("Found valid matrix by converting the loaded object directly.")
            return mat

    # --- Step 3: nothing usable inside the pickle -> fallback ---
    print("\n--- WARNING ---")
    print("Successfully loaded pickle file, but could not find a valid matrix inside.")
    print("Moving to correlation fallback.")
    print("---------------\n")
    return _correlation_adjacency(traffic_matrix, corr_threshold)

def create_spatiotemporal_sequences(data, adj_matrix, n_steps, horizon):
    """Slice a (time x sensor) series into sliding input/target windows.

    Window ``i`` uses rows ``i .. i+n_steps-1`` of ``data`` as input and the
    following ``horizon`` rows as the target, for every ``i`` whose target
    still fits inside ``data``.

    Args:
        data: 2-D array of shape ``(n_samples, n_sensors)``.
        adj_matrix: Adjacency matrix shared by every window.
        n_steps: Input window length.
        horizon: Number of future rows to predict.

    Returns:
        ``(X_traffic, X_adj, y)`` with shapes ``(W, n_steps, n_sensors)``,
        ``(W, *adj_matrix.shape)`` and ``(W, n_sensors * horizon)``, where
        ``W = max(n_samples - n_steps - horizon + 1, 0)``. Each row of ``y`` is
        the target block flattened time-major (all sensors at the first target
        step, then the next, ...).

        ``X_adj`` is a read-only broadcast *view* of ``adj_matrix``. It reports
        the full 3-D shape but shares memory with ``adj_matrix``. ``np.tile``
        would instead allocate ``W * n_sensors**2`` elements, e.g. ~5.9 GB of
        float32 for 34,260 windows of 207 sensors. Call ``X_adj.copy()`` if a
        writable array is needed.
    """
    print("3) Creating spatiotemporal sequences...")
    data = np.asarray(data)
    adj_matrix = np.asarray(adj_matrix)
    n_samples, n_sensors = data.shape
    n_windows = max(n_samples - n_steps - horizon + 1, 0)

    # Row indices for all windows at once: (W, 1) + (k,) -> (W, k).
    starts = np.arange(n_windows)[:, None]
    X_traffic = data[starts + np.arange(n_steps)]  # (W, n_steps, n_sensors)
    y = data[starts + n_steps + np.arange(horizon)].reshape(n_windows, n_sensors * horizon)
    X_adj = np.broadcast_to(adj_matrix, (n_windows,) + adj_matrix.shape)

    print(f"X_traffic shape: {X_traffic.shape}, X_adj shape: {X_adj.shape}, y shape: {y.shape}")
    return X_traffic, X_adj, y

# --- Custom Transformer layers (originally written as a separate script) ---
class MultiHeadAttention(Layer):
    """Multi-head scaled dot-product attention.

    ``q``, ``k`` and ``v`` are each projected to ``d_model`` features and split
    into ``num_heads`` heads of ``depth = d_model // num_heads`` features.
    Each head computes ``softmax(q k^T / sqrt(depth)) v``. The heads are then
    merged back to ``d_model`` features and passed through a final
    ``Dense(d_model)``.

    Args:
        d_model: Total feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...),
            also used by Keras when re-creating the layer from a saved config.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, depth)``."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``. Note the argument order: v, k, q.

        Args:
            v, k, q: Expected to be ``(batch, seq, features)`` tensors; ``k`` and
                ``v`` must share their sequence length.
            mask: Optional *additive* mask, added to the attention logits before
                the softmax. Masked positions should therefore hold a large
                negative value (e.g. -1e9). It must broadcast to
                ``(batch, num_heads, seq_q, seq_k)``.

        Returns:
            ``(output, attention_weights)`` with shapes ``(batch, seq_q, d_model)``
            and ``(batch, num_heads, seq_q, seq_k)``.
        """
        batch_size = tf.shape(q)[0]
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        matmul_qk = tf.matmul(q, k, transpose_b=True)
        # Scale by sqrt(depth). Casting to the logits' own dtype (rather than a
        # hard-coded float32) keeps this working under mixed precision.
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            scaled_attention_logits += mask

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)
        # Merge heads: (batch, heads, seq_q, depth) -> (batch, seq_q, d_model).
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, (batch_size, -1, self.d_model))
        return self.dense(output), attention_weights

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "num_heads": self.num_heads})
        return config

class TransformerBlock(Layer):
    """Post-norm Transformer encoder block: self-attention + feed-forward network.

    Computes ``out1 = LayerNorm(x + Dropout(MHA(x, x, x)))`` and then
    ``out = LayerNorm(out1 + Dropout(FFN(out1)))``, where the FFN is
    ``Dense(dff, relu) -> Dense(d_model)``. Both residual additions require the
    input's last dimension to equal ``d_model``.

    Args:
        d_model: Feature width of inputs and outputs.
        num_heads: Attention heads (must divide ``d_model``).
        dff: Hidden width of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...),
            also used by Keras when re-creating the layer from a saved config.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept for get_config(); the sub-layers below
        # hold the actual weights.
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.rate = rate

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = Sequential([
            Dense(dff, activation='relu'),
            Dense(d_model)
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, x, training=None, mask=None):
        """Apply the block to ``x``.

        Args:
            x: Input whose last dimension is ``d_model``.
            training: Dropout phase. Leaving it ``None`` lets Keras supply it
                from the enclosing ``fit()``/``predict()`` call.
            mask: Optional additive attention mask, forwarded to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "dff": self.dff,
            "rate": self.rate,
        })
        return config

# --- Build and Train Combined Model ---
def build_and_train_graph_transformer(X_train, y_train, adj_mx, n_sensors, n_timesteps, horizon,
                                      *, d_model=None, num_heads=None, dff=None,
                                      epochs=None, batch_size=None):
    """
    Build, compile and fit the spatiotemporal graph-Transformer forecaster.

    Architecture (per sample):
      1. graph mixing: every time step's sensor vector is multiplied by the
         sample's adjacency matrix (``X @ A``);
      2. a Dense projection of every time step to ``d_model`` features;
      3. one ``TransformerBlock`` (self-attention over the time axis);
      4. a mean over time (``GlobalAveragePooling1D``) and a ReLU ``Dense``
         head emitting ``n_sensors * horizon`` values.

    The model is compiled with Adam and MSE loss and fitted on all of the
    given data (no validation split).

    Args:
        X_train: Traffic windows, shape ``(W, n_timesteps, n_sensors)``.
        y_train: Targets, shape ``(W, n_sensors * horizon)``.
        adj_mx: Per-sample adjacency matrices, shape ``(W, n_sensors, n_sensors)``.
        n_sensors, n_timesteps, horizon: Sizes used to declare inputs/outputs.
        d_model, num_heads, dff: Transformer sizes. ``None`` means the
            module-level ``D_MODEL`` / ``NUM_HEADS`` / ``DFF`` read at call time.
        epochs, batch_size: ``fit`` settings. ``None`` means ``NUM_EPOCHS`` /
            ``BATCH_SIZE``.

    Returns:
        The trained Keras ``Model`` with inputs ``[traffic_input, adj_input]``.
    """
    d_model = D_MODEL if d_model is None else d_model
    num_heads = NUM_HEADS if num_heads is None else num_heads
    dff = DFF if dff is None else dff
    epochs = NUM_EPOCHS if epochs is None else epochs
    batch_size = BATCH_SIZE if batch_size is None else batch_size

    print("\n4) Building and training Spatiotemporal Graph Transformer model...")

    traffic_input = Input(shape=(n_timesteps, n_sensors), name='traffic_input')
    adj_input = Input(shape=(n_sensors, n_sensors), name='adj_input')

    # --- Step 1: spatial mixing: out[b, t, l] = sum_k x[b, t, k] * A[b, k, l] ---
    graph_conv_output = Lambda(
        lambda x: tf.einsum('ijk,ikl->ijl', x[0], x[1]),
        output_shape=(n_timesteps, n_sensors),
        name='graph_convolution',
    )([traffic_input, adj_input])
    projected_input = Dense(d_model, name='input_projection')(graph_conv_output)

    # --- Step 2: temporal self-attention ---
    # training is passed as None, not hard-coded True. A constant True given at
    # functional-construction time can keep dropout active during predict()
    # (tf.keras 2 semantics); None lets fit()/predict() decide the phase.
    # NOTE(review): nothing here encodes time-step position, and the attention,
    # the per-step Dense layers and the mean-pooling below are all insensitive
    # to the order of the n_timesteps inputs. As a result, the model cannot
    # distinguish recent readings from older ones. Consider adding a positional
    # embedding, or replacing the pooling with Flatten().
    transformer_output = TransformerBlock(d_model, num_heads, dff)(
        projected_input, training=None, mask=None
    )

    # --- Step 3: prediction head ---
    pooled_output = GlobalAveragePooling1D()(transformer_output)
    # NOTE(review): ReLU clamps outputs at 0. That suits [0, 1]-scaled targets,
    # but an output unit that goes negative for every input stops learning.
    output_layer = Dense(n_sensors * horizon, activation='relu', name='output_layer')(pooled_output)

    model = Model(inputs=[traffic_input, adj_input], outputs=output_layer)
    model.compile(optimizer='adam', loss='mse')
    model.summary()

    model.fit([X_train, adj_mx], y_train, epochs=epochs, batch_size=batch_size, verbose=1)
    return model

# --- Main Pipeline Execution ---
if __name__ == "__main__":
    import sys

    traffic_data, N_SENSORS = load_and_clean_data(TRAFFIC_H5)
    if traffic_data is None:
        # exit() is only injected by the `site` module for interactive use;
        # sys.exit is always available, and status 1 reports the failure.
        sys.exit(1)

    adj_mx = robust_adjacency_loader(ADJ_PKL, N_SENSORS, traffic_data)

    # Scale every sensor to [0, 1]. The model is trained on (and so predicts in)
    # this scaled space, which is why the fitted scaler is saved below.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(traffic_data)
    X_traffic, X_adj, y = create_spatiotemporal_sequences(scaled_data, adj_mx, N_TIMESTEPS, FORECAST_HORIZON)

    # No train/test split: this is the final fit on every available window, so
    # there is no held-out data (and the scaler above has seen all of it too).
    print("5) Using the full dataset for final training...")

    # X_adj holds one adjacency matrix per window, aligned with X_traffic.
    model = build_and_train_graph_transformer(X_traffic, y, X_adj, N_SENSORS, N_TIMESTEPS, FORECAST_HORIZON)

    print("✅ Final model trained on all data. Saving...")
    # NOTE(review): the model contains custom layers and a Lambda; reloading it
    # will likely need custom_objects for MultiHeadAttention/TransformerBlock -- confirm.
    model.save('final_graph_transformer_model.h5')
    # Needed to map predictions back to the original units via scaler.inverse_transform.
    with open('final_graph_transformer_scaler.pkl', 'wb') as f:
        pickle.dump(scaler, f)

    # Visualization: not implemented yet.

2025-09-26 12:05:12.183356: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758888312.556627      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758888312.660891      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


1) Loading traffic data from HDF5 and cleaning data...
Loaded traffic data with shape: (34272, 207)
2) Attempting to load adjacency matrix from file...
Successfully loaded pickle file. Now validating contents...
Found valid matrix as an item in a list/tuple.
3) Creating spatiotemporal sequences...
X_traffic shape: (34260, 12, 207), X_adj shape: (34260, 207, 207), y shape: (34260, 207)
5) Using the full dataset for final training...

4) Building and training Spatiotemporal Graph Transformer model...


I0000 00:00:1758888333.687495      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1758888333.688359      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


Epoch 1/36


I0000 00:00:1758888357.238404      68 service.cc:148] XLA service 0x7eb048108440 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1758888357.240268      68 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1758888357.240291      68 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1758888357.854421      68 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m 13/268[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m3s[0m 14ms/step - loss: 0.3670

I0000 00:00:1758888361.534098      68 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 26ms/step - loss: 0.1703
Epoch 2/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - loss: 0.0820
Epoch 3/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 12ms/step - loss: 0.0687
Epoch 4/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - loss: 0.0522
Epoch 5/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 13ms/step - loss: 0.0469
Epoch 6/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 13ms/step - loss: 0.0428
Epoch 7/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 13ms/step - loss: 0.0410
Epoch 8/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 13ms/step - loss: 0.0410
Epoch 9/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 13ms/step - loss: 0.0375
Epoch 10/36
[1m268/268[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 13ms/step - lo