idea: fine tune on test scenario before doing prediction 

In [1]:
import tensorflow as tf
tf.config.list_physical_devices('GPU')



[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [2]:
# load data
import numpy as np
train_file = np.load('data/train.npz')
train_data = train_file['data']
print("train_data's shape", train_data.shape)
test_file = np.load('data/test_input.npz')
test_data = test_file['data']
print("test_data's shape", test_data.shape)


train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)


In [3]:
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, Input, RepeatVector, TimeDistributed, Dropout

In [17]:
import pickle

def save_model(model, filepath='lstm_2.pkl'):
    """Save model and scaler together in a pickle file"""
    model_json = model.to_json()
    model_weights = model.get_weights()
    data = {
        'model_json': model_json,
        'model_weights': model_weights,
    }
    with open(filepath, 'wb') as f:
        pickle.dump(data, f)
    print(f"Model saved to {filepath}")

def load_model(filepath='lstm_2.pkl'):
    """Load model and scaler from pickle file"""
    with open(filepath, 'rb') as f:
        data = pickle.load(f)
    
    # Reconstruct model
    model = tf.keras.models.model_from_json(data['model_json'])
    model.set_weights(data['model_weights'])
    model.compile(optimizer='adam', loss='mse')
    
    return model

In [37]:
from tensorflow.keras.layers import (
    Input, LSTM, Dense, Dropout, RepeatVector, TimeDistributed, 
    Concatenate, Activation, Dot, Layer, BatchNormalization, 
    LayerNormalization, Add
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import Orthogonal, GlorotUniform
from tensorflow.keras.regularizers import l2
import tensorflow.keras.backend as K

class ScaleLayer(Layer):
    def __init__(self, scale_factor, **kwargs):
        super().__init__(**kwargs)
        self.scale_factor = scale_factor
    
    def call(self, x):
        return x / self.scale_factor

class MaxSubtractLayer(Layer):
    def call(self, x):
        return x - K.max(x, axis=-1, keepdims=True)

def create_lstm_encoder_decoder(input_dim, output_dim, timesteps_in, timesteps_out, 
                               lstm_units=128, num_layers=2, loss_fn='mse', lr=0.001):
    
    # GRADIENT MITIGATION STRATEGY 1: Keep proper weight initialization (critical for gradients)
    lstm_init = Orthogonal(gain=1.0)
    dense_init = GlorotUniform()
    
    # Encoder with minimal regularization for overfitting
    encoder_inputs = Input(shape=(timesteps_in, input_dim))
    x = encoder_inputs
    
    # REMOVED: Layer normalization for overfitting experiment
    # x = LayerNormalization()(x)
    
    encoder_states = []
    for i in range(num_layers):
        # STRATEGY 3: Keep residual connections (help gradients, don't prevent overfitting)
        if i > 0 and x.shape[-1] == lstm_units:
            residual = x
        else:
            residual = None
            
        lstm_out = LSTM(
            lstm_units, 
            return_sequences=True,
            recurrent_initializer=lstm_init,
            kernel_initializer=lstm_init,
            # REMOVED: All dropout for overfitting
            # recurrent_dropout=0.1,
            # dropout=0.1,
            # REMOVED: L2 regularization for overfitting
            # kernel_regularizer=l2(1e-4),
            # recurrent_regularizer=l2(1e-4)
        )(x)
        
        # REMOVED: Layer normalization for overfitting
        # lstm_out = LayerNormalization()(lstm_out)
        
        # Keep residual connection (helps gradients)
        if residual is not None:
            x = Add()([lstm_out, residual])
        else:
            x = lstm_out
            
        encoder_states.append(x)
    
    encoder_outputs = x  # (batch, timesteps_in, lstm_units)
    
    # Get final state for decoder initialization  
    encoder_state = LSTM(
        lstm_units, 
        return_state=True,
        recurrent_initializer=lstm_init,
        kernel_initializer=lstm_init,
        # REMOVED: Regularization for overfitting
        # kernel_regularizer=l2(1e-4),
        # recurrent_regularizer=l2(1e-4)
    )(encoder_outputs)
    _, state_h, state_c = encoder_state

    # Decoder with minimal regularization
    decoder_input = RepeatVector(timesteps_out)(state_h)
    
    decoder_outputs, _, _ = LSTM(
        lstm_units, 
        return_sequences=True, 
        return_state=True,
        recurrent_initializer=lstm_init,
        kernel_initializer=lstm_init,
        # REMOVED: All dropout and regularization
        # recurrent_dropout=0.1,
        # dropout=0.1,
        # kernel_regularizer=l2(1e-4),
        # recurrent_regularizer=l2(1e-4)
    )(decoder_input, initial_state=[state_h, state_c])
    
    # REMOVED: Layer normalization
    # decoder_outputs = LayerNormalization()(decoder_outputs)
    
    # KEEP: Scaled attention (essential for numerical stability, not regularization)
    scale_factor = (lstm_units ** 0.5)
    attention_scores = Dot(axes=[2, 2])([decoder_outputs, encoder_outputs])
    attention_scores = ScaleLayer(scale_factor)(attention_scores)
    
    # KEEP: Temperature scaling and max subtraction (numerical stability, not regularization)
    attention_scores = ScaleLayer(2.0)(attention_scores)  # Temperature scaling
    attention_scores = MaxSubtractLayer()(attention_scores)
    attention_weights = Activation('softmax')(attention_scores)
    
    # Apply attention weights
    attention_context = Dot(axes=[2, 1])([attention_weights, encoder_outputs])
    
    # Combine context with decoder outputs
    combined = Concatenate()([attention_context, decoder_outputs])
    # REMOVED: Layer normalization
    # combined = LayerNormalization()(combined)
    
    # Output layers without regularization
    x = TimeDistributed(Dense(
        256, 
        activation='relu',
        kernel_initializer=dense_init,
        # REMOVED: L2 regularization
        # kernel_regularizer=l2(1e-4)
    ))(decoder_outputs)
    # REMOVED: Normalization and dropout
    # x = TimeDistributed(LayerNormalization())(x)
    # x = TimeDistributed(Dropout(0.1))(x)
    
    x = TimeDistributed(Dense(
        64, 
        activation='relu',
        kernel_initializer=dense_init,
        # REMOVED: L2 regularization
        # kernel_regularizer=l2(1e-4)
    ))(x)
    # REMOVED: Normalization and dropout
    # x = TimeDistributed(LayerNormalization())(x)
    # x = TimeDistributed(Dropout(0.1))(x)
    
    # Final output layer
    outputs = TimeDistributed(Dense(
        output_dim, 
        activation='linear',
        kernel_initializer=dense_init
    ))(x)

    # KEEP: Aggressive gradient clipping (essential for gradient stability)
    model = Model(encoder_inputs, outputs)
    model.compile(
        optimizer=Adam(
            learning_rate=lr,
            clipnorm=0.1,       # Even more aggressive clipping for stability without normalization
            beta_1=0.9,         
            beta_2=0.999,       
            epsilon=1e-7        
        ),
        loss=loss_fn,
        metrics=['mae']
    )

    return model

In [19]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class GradientMonitoringCallback(Callback):
    def __init__(self, clip_min=1e-4, clip_max=1e2, monitor_frequency=3):
        """
        Monitor gradient norms during training
        
        Args:
            clip_min: Minimum threshold for gradient norms
            clip_max: Maximum threshold for gradient norms  
            monitor_frequency: How often to check gradients (every N batches)
        """
        print(f"🔧 GradientMonitoringCallback initialized with clip_min={clip_min}, clip_max={clip_max}, monitor_freq={monitor_frequency}")
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.monitor_frequency = monitor_frequency
        self.batch_count = 0
        self.total_calls = 0
        self.gradient_checks = 0
        self.fallback_calls = 0
        
    def on_train_begin(self, logs=None):
        print("🚀 GradientMonitoringCallback: Training started!")
        self.batch_count = 0
        self.total_calls = 0
        self.gradient_checks = 0
        self.fallback_calls = 0
        
    # def on_epoch_begin(self, epoch, logs=None):
    #     print(f"📍 GradientMonitoringCallback: Starting epoch {epoch + 1}")
        
    # def on_train_batch_begin(self, batch, logs=None):
    #     # Just to prove we're being called
    #     if batch % 50 == 0:  # Print every 50 batches to avoid spam
    #         print(f"⚡ GradientMonitoringCallback: Batch {batch} starting")
        
    def on_train_batch_end(self, batch, logs=None):
        self.batch_count += 1
        self.total_calls += 1
        
        # Print every time to show we're being called
        # if batch % 50 == 0:  # Print every 50 batches
            # print(f"📊 GradientMonitoringCallback: Batch {batch} ended (total calls: {self.total_calls})")
        
        # Only monitor every N batches to avoid performance overhead
        if self.batch_count % self.monitor_frequency != 0:
            return
            
        # print(f"🔍 GradientMonitoringCallback: Checking gradients at batch {batch} (check #{self.gradient_checks + 1})")
        
        # Get gradients from the optimizer's current state
        try:
            # Access the model's optimizer to get gradient information
            optimizer = self.model.optimizer
            print(f"   📋 Optimizer type: {type(optimizer).__name__}")
            
            # Get trainable variables
            trainable_vars = self.model.trainable_variables
            print(f"   📈 Number of trainable variables: {len(trainable_vars)}")
            
            if hasattr(optimizer, 'get_gradients'):
                print("   ✅ Optimizer has get_gradients method")
                # For some optimizers, we can access gradients directly
                grads = optimizer.get_gradients(self.model.total_loss, trainable_vars)
                print(f"   📊 Retrieved {len([g for g in grads if g is not None])} gradients")
            else:
                print("   ❌ Optimizer doesn't have get_gradients, using variable norms")
                # Alternative approach: check the current variable states
                grad_norms = []
                for i, var in enumerate(trainable_vars):
                    if var is not None:
                        var_norm = tf.norm(var)
                        grad_norms.append(var_norm)
                        if i < 3:  # Print first 3 for debugging
                            print(f"      Variable {i} norm: {float(var_norm.numpy()):.2e}")
                
                self._check_norms(grad_norms, "Variable")
                self.gradient_checks += 1
                return
                
            # Compute gradient norms
            grad_norms = []
            for i, grad in enumerate(grads):
                if grad is not None:
                    grad_norm = tf.norm(grad)
                    grad_norms.append(grad_norm)
                    if i < 3:  # Print first 3 for debugging
                        print(f"      Gradient {i} norm: {float(grad_norm.numpy()):.2e}")
                    
            print(f"   ✅ Computed {len(grad_norms)} gradient norms")
            self._check_norms(grad_norms, "Gradient")
            self.gradient_checks += 1
            
        except Exception as e:
            print(f"   ❌ Exception in gradient monitoring: {str(e)}")
            self.fallback_calls += 1
            # Fallback: just monitor the loss for signs of instability
            print('   🔄 Fallback: monitoring loss only')
            if logs:
                loss_value = logs.get('loss', 0)
                print(f"   📉 Current loss: {loss_value:.2e}")
                if np.isnan(loss_value) or np.isinf(loss_value):
                    print(f"   ⚠️  WARNING: Loss became {loss_value} at batch {batch}")
                elif loss_value > 1e6:
                    print(f"   ⚠️  WARNING: Very large loss {loss_value:.2e} at batch {batch}")
    
    def _check_norms(self, norms, norm_type="Gradient"):
        """Check if norms are within acceptable range"""
        print(f"   🔬 Checking {len(norms)} {norm_type.lower()} norms...")
        warnings = 0
        
        for idx, norm in enumerate(norms):
            try:
                norm_value = float(norm.numpy()) if hasattr(norm, 'numpy') else float(norm)
                
                if norm_value > self.clip_max:
                    print(f"   ⚠️  WARNING: {norm_type} norm {norm_value:.2e} is too large (layer {idx})")
                    warnings += 1
                elif norm_value < self.clip_min:
                    print(f"   ⚠️  WARNING: {norm_type} norm {norm_value:.2e} is too small (layer {idx})")
                    warnings += 1
                elif np.isnan(norm_value) or np.isinf(norm_value):
                    print(f"   ⚠️  WARNING: {norm_type} norm is {norm_value} (layer {idx})")
                    warnings += 1
                    
            except Exception as e:
                print(f"   ❌ Cannot convert norm to float for layer {idx}: {str(e)}")
                continue
                
        if warnings == 0:
            print(f"   ✅ All {norm_type.lower()} norms are within acceptable range")
        else:
            print(f"   ⚠️  Found {warnings} norm warnings")
    
    # def on_epoch_end(self, epoch, logs=None):
    #     """Print summary at end of each epoch"""
    #     print(f"📈 GradientMonitoringCallback: Epoch {epoch + 1} completed")
    #     print(f"   📊 Total batch calls: {self.total_calls}")
    #     print(f"   🔍 Gradient checks performed: {self.gradient_checks}")
    #     print(f"   🔄 Fallback calls: {self.fallback_calls}")
    #     
    #     if logs:
    #         loss = logs.get('loss', 0)
    #         val_loss = logs.get('val_loss', 0)
    #         print(f"   📉 Final epoch loss: {loss:.2e}")
    #         if val_loss:
    #             print(f"   📉 Final epoch val_loss: {val_loss:.2e}")
    #         
    #         if np.isnan(loss) or np.isinf(loss):
    #             print(f"   ⚠️  WARNING: Training loss became unstable: {loss}")
    #         if val_loss and (np.isnan(val_loss) or np.isinf(val_loss)):
    #             print(f"   ⚠️  WARNING: Validation loss became unstable: {val_loss}")
        
    def on_train_end(self, logs=None):
        print("🏁 GradientMonitoringCallback: Training completed!")
        print(f"   📊 Final stats - Total calls: {self.total_calls}, Gradient checks: {self.gradient_checks}, Fallbacks: {self.fallback_calls}")
        
        if self.total_calls == 0:
            print("   ❌ ERROR: Callback was never called! Check if it's properly added to callbacks list.")
        elif self.gradient_checks == 0 and self.fallback_calls == 0:
            print("   ⚠️  WARNING: No gradient monitoring was performed. Check monitor_frequency setting.")
        else:
            print("   ✅ Gradient monitoring completed successfully!")

In [20]:
from keras.src.callbacks import LearningRateScheduler, EarlyStopping, Callback
from keras.src.optimizers import Adam
from keras import Model
import numpy as np


def exponential_decay_schedule(epoch, lr):
    decay_rate = 0.9
    decay_steps = 5
    if epoch % decay_steps == 0 and epoch:
        print('Learning rate update:', lr * decay_rate)
        return lr * decay_rate
    return lr


# Custom callback to monitor LR and stop training
class LRThresholdCallback(Callback):
    def __init__(self, threshold=9e-5):
        super().__init__()
        self.threshold = threshold
        self.should_stop = False

    def on_epoch_end(self, epoch, logs=None):
        lr = float(self.model.optimizer.learning_rate.numpy())
        if lr < self.threshold:
            print(f"\nLearning rate {lr:.6f} < threshold {self.threshold}, moving to next phase.")
            self.model.stop_training = True



In [21]:
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import save_model

class SaveBestModelCallback(Callback):
    def __init__(self, save_path='best_model.keras', monitor='val_loss'):
        super().__init__()
        self.best = float('inf')
        self.monitor = monitor
        self.save_path = save_path

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if current is not None and current < self.best:
            self.best = current
            print(f"\nNew best {self.monitor}: {current:.6f}. Saving model...")
            save_model(self.model, 'lstm_2.pkl')


In [36]:
def train_model(train_data, batch_size=32, validation_split=0.2, Tobs=50, Tpred=60, epochs1=50, epochs2=50):
    n_scenarios = train_data.shape[0]
    n_agents = train_data.shape[1]
    X_train_raw = []
    y_train_deltas = []

    # Counters for pruning reasons
    pruned_zero_frame = 0
    pruned_observed_or_future_zero = 0
    total_agents = n_scenarios * n_agents
    
    for i in range(n_scenarios):
        for agent_id in range(n_agents):
            agent_data = train_data[i, agent_id, :, :]  # shape (110, 6)
    
            # Skip if any time step in the observation+prediction window is all zeros
            if np.any(np.all(agent_data[:Tobs + Tpred] == 0, axis=1)):
                pruned_zero_frame += 1
                continue
    
            observed = agent_data[:Tobs]         # shape (Tobs, 6)
            future = agent_data[Tobs:Tobs + Tpred, :2]  # only position_x, position_y
            last_obs_pos = observed[-1, :2]
    
            # Skip if observed or future window contains any full-zero frame
            if np.any(np.all(observed == 0, axis=1)) or np.any(np.all(future == 0, axis=1)):
                pruned_observed_or_future_zero += 1
                continue
    
            # Compute deltas
            delta = np.diff(np.vstack([last_obs_pos, future]), axis=0)  # shape (Tpred, 2)
    
            X_train_raw.append(observed)
            y_train_deltas.append(delta)
    
    # Print pruning summary
    print(f"Total agents: {total_agents}")
    print(f"Pruned due to zero frame in Tobs+Tpred: {pruned_zero_frame}")
    print(f"Pruned due to zero frame in observed or future window: {pruned_observed_or_future_zero}")
    print(f"Remaining valid agents: {len(X_train_raw)}")
    
    
    X_train = np.array(X_train_raw)     # shape (N_valid, Tobs, 6)
    y_train = np.array(y_train_deltas)  # shape (N_valid, Tpred, 2)
    
    
    print(f"ex. y_train {y_train[0]}")


    print(f"Training on {X_train.shape[0]} valid agent trajectories.")
    print(f"Input shape: {X_train.shape}, Delta Output shape: {y_train.shape}")
    
    # --- Normalize Input and Output ---
    X_mean = X_train.mean(axis=(0, 1), keepdims=True)  # shape: (1, 1, 6)
    X_std = X_train.std(axis=(0, 1), keepdims=True) + 1e-8

    y_mean = y_train.mean(axis=(0, 1), keepdims=True)  # shape: (1, 1, 2)
    y_std = y_train.std(axis=(0, 1), keepdims=True) + 1e-8

    X_std = np.where(X_std < 1e-6, 1.0, X_std)
    y_std = np.where(y_std < 1e-6, 1.0, y_std)

    X_train = (X_train - X_mean) / X_std
    y_train = (y_train - y_mean) / y_std 
    
    print("X_train NaNs:", np.isnan(X_train).sum())
    print("y_train NaNs:", np.isnan(y_train).sum())

    print("Any std == 0?", np.any(X_std == 0), np.any(y_std == 0))
    
    X_mean, X_std, y_mean, y_std = None, None, None, None
    
    # print(X_train[:2])
    # print(y_train[:2])
    
    model = create_lstm_encoder_decoder(
        input_dim=X_train.shape[-1],
        output_dim=2,
        timesteps_in=Tobs,
        timesteps_out=Tpred,
        loss_fn='mse',
        lr=0.001
    )
    
    gradient_monitoring_callback = GradientMonitoringCallback(clip_min=1e-4, clip_max=1e2)
    
    save_best_callback = SaveBestModelCallback(save_path='lstm2', monitor='val_loss')



    phase1_callbacks = [
        # LearningRateScheduler(exponential_decay_schedule),
        # EarlyStopping(patience=4, restore_best_weights=True, monitor='val_loss'),
        LRThresholdCallback(threshold=9e-5),
        gradient_monitoring_callback,
        save_best_callback
    ]

    print("\n--- Phase 1: Training ---")
    model.fit(
        X_train, y_train,
        epochs=epochs1,
        batch_size=batch_size,
        validation_split=validation_split,
        callbacks=phase1_callbacks,
        verbose=1
    )

    print("\n--- Phase 2: Fine-tuning ---")
    model.compile(
        optimizer=Adam(
            learning_rate=1e-4,
            clipnorm=0.25,      # More aggressive clipping
            # clipvalue=0.5,      # Also clip individual gradients
            beta_1=0.9,         # Standard momentum
            beta_2=0.999,       # Standard RMSprop decay
            epsilon=1e-7        # Smaller epsilon for stability
        ),
        loss='mse',
        metrics=['mae']
    )
    phase2_callbacks = [
        # LearningRateScheduler(exponential_decay_schedule),
        # EarlyStopping(patience=3, restore_best_weights=True, monitor='val_loss'), 
        gradient_monitoring_callback
    ]
    
    model.fit(
        X_train, y_train,
        epochs=epochs2,
        batch_size=batch_size,
        validation_split=validation_split,
        callbacks=phase2_callbacks,
        verbose=1
    )
    
    print(f"X_mean:{X_mean}, X_std:{X_std}, y_mean:{y_mean}, y_std:{y_std}")

    # Return model and normalization parameters
    return model, X_mean, X_std, y_mean, y_std

In [9]:
def plot_mae_by_timestep(y_true, y_pred):
    """
    Visualize MAE across timesteps in the prediction horizon.
    
    Args:
        y_true (np.ndarray): shape (N, Tpred, 2)
        y_pred (np.ndarray): shape (N, Tpred, 2)
    """
    mae_per_timestep = np.mean(np.abs(y_true - y_pred), axis=(0, 2))  # shape (Tpred,)
    
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 4))
    plt.plot(mae_per_timestep, label='MAE per Timestep')
    plt.xlabel('Timestep')
    plt.ylabel('MAE (meters)')
    plt.title('Mean Absolute Error Over Prediction Horizon')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [10]:
def reconstruct_absolute_positions(pred_deltas, last_observed_positions):
    """
    Reconstruct absolute predicted positions by adding deltas to the last observed position.

    Args:
        pred_deltas: np.ndarray of shape (N, Tpred, 2)
        last_observed_positions: np.ndarray of shape (N, 2)

    Returns:
        np.ndarray of shape (N, Tpred, 2)
    """
    return last_observed_positions[:, None, :] + np.cumsum(pred_deltas, axis=1)

In [11]:
# # figure out stats
# def stats():
#     n_scenarios = train_data.shape[0]
#     X_train_raw = []
#     y_train_deltas = []
#     
#     for i in range(n_scenarios):
#         ego_data = train_data[i, 0, :, :]
#         if np.all(ego_data == 0):
#             continue
#     
#         observed = ego_data[:Tobs]            # shape (50, 6)
#         future = ego_data[Tobs:Tobs+Tpred, :2]
#         last_obs_pos = observed[-1, :2]
#     
#         if np.any(np.all(observed == 0, axis=1)) or np.any(np.all(future == 0, axis=1)):
#             continue
#     
#         # Compute deltas w.r.t. previous future timestep
#         delta = np.diff(np.vstack([last_obs_pos, future]), axis=0)  # (60, 2)
#     
#         X_train_raw.append(observed)
#         y_train_deltas.append(delta)
#     
#     
#     X_train = np.array(X_train_raw)
#     y_train = np.array(y_train_deltas)
#     
#     print(f"{X_train.shape[0]} valid sequences.")
#     print(f"Input shape: {X_train.shape}, Delta Output shape: {y_train.shape}")
#     
#     # --- Normalize Input and Output ---
#     X_mean = X_train.mean(axis=(0, 1), keepdims=True)  # shape: (1, 1, 6)
#     X_std = X_train.std(axis=(0, 1), keepdims=True) + 1e-8
#     
#     y_mean = y_train.mean(axis=(0, 1), keepdims=True)  # shape: (1, 1, 2)
#     y_std = y_train.std(axis=(0, 1), keepdims=True) + 1e-8
#     
#     X_train = (X_train - X_mean) / X_std
#     y_train = (y_train - y_mean) / y_std
#     return X_mean, X_std, y_mean, y_std
# X_mean, X_std, y_mean, y_std = stats()

In [12]:
def forecast_positions(scenario_data, Tobs, Tpred, model, X_mean=None, X_std=None, y_mean=None, y_std=None):
    """
    Use LSTM model to forecast future deltas and reconstruct absolute positions.
    Applies normalization only if statistics are provided.

    Args:
        scenario_data (numpy.ndarray): Shape (agents, time_steps, dimensions)
        Tobs (int): Number of observed time steps
        Tpred (int): Number of future time steps to predict
        model (Model): Trained LSTM model
        X_mean, X_std: Optional normalization stats for input
        y_mean, y_std: Optional normalization stats for output

    Returns:
        numpy.ndarray: Predicted absolute positions of shape (agents, Tpred, 2)
    """
    agents, _, _ = scenario_data.shape
    predicted_positions = np.zeros((agents, Tpred, 2))

    for agent_idx in range(agents):
        agent_data = scenario_data[agent_idx, :Tobs, :]  # shape (Tobs, 6)

        # Skip if fully padded
        if np.all(agent_data == 0):
            continue

        X_pred = np.expand_dims(agent_data, axis=0)  # shape (1, Tobs, 6)

        # Normalize if stats are provided
        if X_mean is not None and X_std is not None:
            X_pred = (X_pred - X_mean) / X_std

        # Predict deltas (normalized or raw)
        pred_deltas = model.predict(X_pred, verbose=0)  # shape (1, Tpred, 2)
        
        print("pred deltas")
        print(pred_deltas[:,:5])

        # Denormalize if stats are provided
        if y_mean is not None and y_std is not None:
            pred_deltas = pred_deltas * y_std + y_mean

        # Reconstruct absolute positions
        last_pos = agent_data[Tobs - 1, :2]  # shape (2,)
        abs_positions = reconstruct_absolute_positions(
            pred_deltas=pred_deltas,
            last_observed_positions=np.expand_dims(last_pos, axis=0)
        )[0]

        predicted_positions[agent_idx] = abs_positions

    return predicted_positions

In [13]:
def finetune_forecast_positions(scenario_data, Tobs, Tpred, model, 
                                 X_mean=None, X_std=None, y_mean=None, y_std=None, 
                                 epochs=3, lr=1e-4):
    """
    Fine-tune on valid agents from a scenario, then forecast their future positions.

    Returns:
        numpy.ndarray: Predicted absolute positions of shape (agents, Tpred, 2)
    """
    import copy
    from tensorflow.keras.optimizers import Adam

    agents, total_steps, _ = scenario_data.shape
    assert total_steps >= Tobs + Tpred, "Not enough time steps for observation + prediction"

    # Prepare fine-tuning data
    X_finetune = []
    y_finetune = []

    for agent_idx in range(agents):
        agent_traj = scenario_data[agent_idx, :, :]  # shape (time_steps, 6)
        
        segment = agent_traj[:Tobs + Tpred]
        if np.any(np.all(segment == 0, axis=1)):
            continue

        observed = segment[:Tobs]            # (Tobs, 6)
        future = segment[Tobs:Tobs+Tpred, :2]
        last_obs_pos = observed[-1, :2]

        if np.any(np.all(observed == 0, axis=1)) or np.any(np.all(future == 0, axis=1)):
            continue

        delta = np.diff(np.vstack([last_obs_pos, future]), axis=0)  # (Tpred, 2)

        X_finetune.append(observed)
        y_finetune.append(delta)

    if len(X_finetune) == 0:
        print("No valid agents found for fine-tuning.")
        return np.zeros((agents, Tpred, 2))

    X_finetune = np.array(X_finetune)
    y_finetune = np.array(y_finetune)

    # Normalize if stats provided
    if X_mean is not None and X_std is not None:
        X_finetune = (X_finetune - X_mean) / X_std

    if y_mean is not None and y_std is not None:
        y_finetune = (y_finetune - y_mean) / y_std

    # Clone and compile the model to avoid modifying the original
    model_finetune = copy.deepcopy(model)
    model_finetune.compile(optimizer=Adam(learning_rate=lr), loss='mse') #todo: fix optimizer 

    # Fine-tune
    model_finetune.fit(X_finetune, y_finetune, epochs=epochs, verbose=0)

    # Predict for each agent
    predicted_positions = np.zeros((agents, Tpred, 2))

    for agent_idx in range(agents):
        agent_obs = scenario_data[agent_idx, :Tobs, :]

        if np.any(np.all(agent_obs == 0, axis=1)):
            continue

        X_pred = np.expand_dims(agent_obs, axis=0)

        if X_mean is not None and X_std is not None:
            X_pred = (X_pred - X_mean) / X_std

        pred_deltas = model_finetune.predict(X_pred, verbose=0)

        if y_mean is not None and y_std is not None:
            pred_deltas = pred_deltas * y_std + y_mean

        last_pos = agent_obs[Tobs - 1, :2]
        abs_positions = reconstruct_absolute_positions(
            pred_deltas=pred_deltas,
            last_observed_positions=np.expand_dims(last_pos, axis=0)
        )[0]

        predicted_positions[agent_idx] = abs_positions

    return predicted_positions

In [14]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def make_gif(data_matrix1, data_matrix2, name='comparison'):
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation

    cmap1 = plt.cm.get_cmap('viridis', 50)
    cmap2 = plt.cm.get_cmap('plasma', 50)

    assert data_matrix1.shape[1] == data_matrix2.shape[1], "Both matrices must have same number of timesteps"
    timesteps = data_matrix1.shape[1]

    fig, axes = plt.subplots(1, 2, figsize=(18, 9))
    ax1, ax2 = axes

    def update(frame):
        for ax in axes:
            ax.clear()

        for i in range(data_matrix1.shape[0]):
            for (data_matrix, ax, cmap) in [(data_matrix1, ax1, cmap1), (data_matrix2, ax2, cmap2)]:
                x = data_matrix[i, frame, 0]
                y = data_matrix[i, frame, 1]
                if x != 0 and y != 0:
                    xs = data_matrix[i, :frame+1, 0]
                    ys = data_matrix[i, :frame+1, 1]
                    mask = (xs != 0) & (ys != 0)
                    xs = xs[mask]
                    ys = ys[mask]
                    if len(xs) > 0 and len(ys) > 0:
                        color = cmap(i)
                        ax.plot(xs, ys, alpha=0.9, color=color)
                        ax.scatter(x, y, s=80, color=color)

        # Plot ego vehicle (index 0) on both
        ax1.plot(data_matrix1[0, :frame, 0], data_matrix1[0, :frame, 1], color='tab:orange', label='Ego Vehicle')
        ax1.scatter(data_matrix1[0, frame, 0], data_matrix1[0, frame, 1], s=80, color='tab:orange')
        ax1.set_title('Prediction')

        ax2.plot(data_matrix2[0, :frame, 0], data_matrix2[0, :frame, 1], color='tab:orange', label='Ego Vehicle')
        ax2.scatter(data_matrix2[0, frame, 0], data_matrix2[0, frame, 1], s=80, color='tab:orange')
        ax2.set_title('Actual')

        for ax, data_matrix in zip(axes, [data_matrix1, data_matrix2]):
            ax.set_xlim(data_matrix[:, :, 0][data_matrix[:, :, 0] != 0].min() - 10,
                        data_matrix[:, :, 0][data_matrix[:, :, 0] != 0].max() + 10)
            ax.set_ylim(data_matrix[:, :, 1][data_matrix[:, :, 1] != 0].min() - 10,
                        data_matrix[:, :, 1][data_matrix[:, :, 1] != 0].max() + 10)
            ax.legend()
            ax.set_xlabel('X')
            ax.set_ylabel('Y')

        # Compute MSE over non-zero entries up to current frame
        mask = (data_matrix2[:, :frame+1, :] != 0) & (data_matrix1[:, :frame+1, :] != 0)
        mse = np.mean((data_matrix1[:, :frame+1, :][mask] - data_matrix2[:, :frame+1, :][mask]) ** 2)

        fig.suptitle(f"Timestep {frame} - MSE: {mse:.4f}", fontsize=16)
        return ax1.collections + ax1.lines + ax2.collections + ax2.lines

    anim = animation.FuncAnimation(fig, update, frames=list(range(0, timesteps, 3)), interval=100, blit=True)
    anim.save(f'trajectory_visualization_{name}.gif', writer='pillow')
    plt.close()


In [38]:
# overfit small
model, X_mean, X_std, y_mean, y_std = train_model(train_data[:3,...],epochs1=300, epochs2=300, validation_split=0)

Total agents: 150
Pruned due to zero frame in Tobs+Tpred: 132
Pruned due to zero frame in observed or future window: 0
Remaining valid agents: 18
Training on 18 valid agent trajectories.
Input shape: (18, 50, 6), Delta Output shape: (18, 60, 2)
X_train NaNs: 0
y_train NaNs: 0
Any std == 0? False False
🔧 GradientMonitoringCallback initialized with clip_min=0.0001, clip_max=100.0, monitor_freq=3

--- Phase 1: Training ---
🚀 GradientMonitoringCallback: Training started!
Epoch 1/300
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 9s/step - loss: 1.0058 - mae: 0.6692
Epoch 2/300
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 194ms/step - loss: 0.8362 - mae: 0.6214
Epoch 3/300
   📋 Optimizer type: Adam
   📈 Number of trainable variables: 18
   ❌ Optimizer doesn't have get_gradients, using variable norms
      Variable 0 norm: 2.45e+00
      Variable 1 norm: 1.13e+01
      Variable 2 norm: 1.13e+01
   🔬 Checking 18 variable norms...
   ✅ All variable norms are wit

In [41]:
# visualize regular prediction

# model = load_model()

# Parameters
Tobs = 50
Tpred = 60

data = train_data[1]

# Select a test scenario (can use any valid index)
test_scenario = data.copy()  # shape (agents, time_steps, features)


# Forecast future positions
predicted_positions = forecast_positions(test_scenario, Tobs, Tpred, model, X_mean, X_std, y_mean, y_std)

# Create combined matrix of past observed + predicted for ego agent (agent 0)
ego_past = test_scenario[0, :Tobs, :2]               # shape (Tobs, 2)
ego_future = predicted_positions[0]                  # shape (Tpred, 2)

print(ego_future[:5])
print(test_scenario[0, Tobs:Tobs+5, :2])
ego_full = np.concatenate([ego_past, ego_future], axis=0)  # shape (Tobs + Tpred, 2)

# Create updated scenario with predicted ego and original others
updated_scenario = test_scenario.copy()
updated_scenario[0, :Tobs+Tpred, :2] = ego_full  # Replace ego trajectory

# Visualize
make_gif(updated_scenario, data, name='lstm2')

pred deltas
[[[ 0.6200853  -1.4685496 ]
  [ 0.6323009  -1.6576537 ]
  [ 0.5777124  -1.6909549 ]
  [ 0.64099866 -1.8064063 ]
  [ 0.82729065 -1.7919458 ]]]
pred deltas
[[[ 0.62008464 -1.4685491 ]
  [ 0.6322974  -1.6576431 ]
  [ 0.57770175 -1.690933  ]
  [ 0.6409859  -1.8063818 ]
  [ 0.8272821  -1.7919288 ]]]
pred deltas
[[[ 0.62008655 -1.46855   ]
  [ 0.632305   -1.6576673 ]
  [ 0.5777259  -1.6909825 ]
  [ 0.64101493 -1.8064376 ]
  [ 0.827302   -1.7919668 ]]]
pred deltas
[[[ 0.6200856  -1.4685488 ]
  [ 0.63230103 -1.6576543 ]
  [ 0.5777122  -1.6909554 ]
  [ 0.6409991  -1.8064065 ]
  [ 0.82729053 -1.7919452 ]]]
pred deltas
[[[ 0.62008476 -1.468549  ]
  [ 0.63229895 -1.6576478 ]
  [ 0.5777066  -1.6909428 ]
  [ 0.6409916  -1.8063927 ]
  [ 0.82728565 -1.7919362 ]]]
pred deltas
[[[ 0.91566324 -1.9718213 ]
  [ 0.78065526 -1.9445078 ]
  [ 0.5907213  -1.898335  ]
  [ 0.6161543  -1.6380386 ]
  [ 0.56206137 -1.2853904 ]]]
pred deltas
[[[ 0.8650169  -1.8976893 ]
  [ 0.68931925 -1.851831  ]
  [ 0.55

  cmap1 = plt.cm.get_cmap('viridis', 50)
  cmap2 = plt.cm.get_cmap('plasma', 50)


In [None]:
# Train the model
model, X_mean, X_std, y_mean, y_std = train_model(train_data,epochs1=2, epochs2=0)

# Save the model 
save_model(model)

In [None]:
# visualize regular prediction

# model = load_model()

# Parameters
Tobs = 50
Tpred = 60

data = train_data[0]

# Select a test scenario (can use any valid index)
test_scenario = data.copy()  # shape (agents, time_steps, features)


# Forecast future positions
predicted_positions = forecast_positions(test_scenario, Tobs, Tpred, model, X_mean, X_std, y_mean, y_std)

# Create combined matrix of past observed + predicted for ego agent (agent 0)
ego_past = test_scenario[0, :Tobs, :2]               # shape (Tobs, 2)
ego_future = predicted_positions[0]                  # shape (Tpred, 2)

print(ego_future[:5])
ego_full = np.concatenate([ego_past, ego_future], axis=0)  # shape (Tobs + Tpred, 2)

# Create updated scenario with predicted ego and original others
updated_scenario = test_scenario.copy()
updated_scenario[0, :Tobs+Tpred, :2] = ego_full  # Replace ego trajectory

# Visualize
make_gif(updated_scenario, data, name='lstm2')

In [None]:
# visualize prediction

# model = load_model()

# Parameters
Tobs = 50
Tpred = 60

data = train_data[0]

# Select a test scenario (can use any valid index)
test_scenario = data.copy()  # shape (agents, time_steps, features)


# Forecast future positions
predicted_positions = finetune_forecast_positions(test_scenario, Tobs, Tpred, model, X_mean, X_std, y_mean, y_std)

# Create combined matrix of past observed + predicted for ego agent (agent 0)
ego_past = test_scenario[0, :Tobs, :2]               # shape (Tobs, 2)
ego_future = predicted_positions[0]                  # shape (Tpred, 2)
ego_full = np.concatenate([ego_past, ego_future], axis=0)  # shape (Tobs + Tpred, 2)

# Create updated scenario with predicted ego and original others
updated_scenario = test_scenario.copy()
updated_scenario[0, :Tobs+Tpred, :2] = ego_full  # Replace ego trajectory

# Visualize
make_gif(updated_scenario, data, name='lstm2')

In [None]:
from sklearn.metrics import mean_squared_error


def evaluate_mse(train_data, model, Tobs=50, Tpred=60):
    """
    Computes LSTM prediction for ego agent and evaluates MSE with progress reporting.
    """
    N = train_data.shape[0]
    mse_list = []
    valid_scenarios = 0
    
    print(f"Evaluating {N} scenarios...")
    
    # Progress reporting variables
    report_interval = max(1, N // 10)  # Report at 10% intervals
    
    for i in range(N):
        # Progress reporting
        if i % report_interval == 0 or i == N-1:
            print(f"Processing scenario {i+1}/{N} ({(i+1)/N*100:.1f}%)")
        
        scenario_data = train_data[i]
        ego_agent_data = scenario_data[0]
        ground_truth = ego_agent_data[Tobs:Tobs+Tpred, :2]
        
        # Skip if ground truth contains all zeros (padded)
        if np.all(ground_truth == 0):
            continue
            
        valid_scenarios += 1
        
        # Forecast future positions
        predicted_positions = forecast_positions(
            ego_agent_data[np.newaxis, :, :],
            Tobs, Tpred, model, X_mean, X_std, y_mean, y_std
        )
        
        # Compute MSE
        mse = mean_squared_error(ground_truth, predicted_positions[0])
        mse_list.append(mse)
        
        # Occasional MSE reporting
        if i % report_interval == 0:
            print(f"  Current scenario MSE: {mse:.4f}")
    
    # Final results
    if mse_list:
        overall_mse = np.mean(mse_list)
        print(f"Evaluation complete: {valid_scenarios} valid scenarios")
        print(f"Mean Squared Error (MSE): {overall_mse:.4f}")
        print(f"Min MSE: {np.min(mse_list):.4f}, Max MSE: {np.max(mse_list):.4f}")
        return overall_mse
    else:
        print("No valid scenarios for evaluation.")
        return None

In [None]:
# Evaluate on training data
evaluate_mse(train_data, model)

In [ ]:
import pandas as pd
import numpy as np

def generate_submission(data, output_csv, Tobs=50, Tpred=60):
    """
    Applies forecasting and generates a submission CSV with format:
    index,x,y where index is auto-generated and matches submission key.
    
    Args:
        data (np.ndarray): Test data of shape (num_scenarios, 50, 50, 6).
        output_csv (str): Output CSV file path.
        Tobs (int): Observed time steps (default 50).
        Tpred (int): Prediction time steps (default 60).
    """

    predictions = []

    for i in range(data.shape[0]):
        scenario_data = data[i]            # Shape: (50, 50, 6)
        ego_agent_data = scenario_data[0]  # Shape: (50, 6)

        # Predict future positions for the ego agent
        predicted_positions = finetune_forecast_positions(
            ego_agent_data[np.newaxis, :, :], Tobs, Tpred, model
        )  # Shape: (1, 60, 2)

        # Append 60 predictions (x, y) for this scenario
        predictions.extend(predicted_positions[0])  # Shape: (60, 2)

    # Create DataFrame without explicit ID
    submission_df = pd.DataFrame(predictions, columns=["x", "y"])
    submission_df.index.name = 'index'  # Match Kaggle format

    # Save CSV with index
    submission_df.to_csv(output_csv)
    print(f"Submission file '{output_csv}' saved with shape {submission_df.shape}")

generate_submission(test_data, 'lstm_submission.csv')