In [1]:
import os
import re
import glob
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from datetime import datetime
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import autocast, GradScaler


In [2]:
def collate_fn(batch):
    if len(batch[0]) == 3:  # Training data with future
        pasts, masks, futures = zip(*batch)
        past = torch.stack(pasts)
        mask = torch.stack(masks)
        future = torch.stack(futures)
        return past, mask, future
    else:  # Test data without future
        pasts, masks = zip(*batch)
        past = torch.stack(pasts)
        mask = torch.stack(masks)
        return past, mask

class TrajectoryDataset(Dataset):
    def __init__(self, input_path=None, data=None, T_past=50, T_future=60,
                 is_test=False, pos_mean=None, pos_std=None, vel_mean=None, vel_std=None):
        if data is not None:
            self.data = data.copy()
        else:
            npz = np.load(input_path)
            self.data = npz['data']
        self.T_past = T_past
        self.T_future = T_future
        self.is_test = is_test
        # Normalization stats
        self.pos_mean = np.array(pos_mean, dtype=np.float32) if pos_mean is not None else None
        self.pos_std  = np.array(pos_std,  dtype=np.float32) if pos_std  is not None else None
        self.vel_mean = np.array(vel_mean, dtype=np.float32) if vel_mean is not None else None
        self.vel_std  = np.array(vel_std,  dtype=np.float32) if vel_std  is not None else None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx].astype(np.float32)  # (num_agents, T, 6)
        # Apply global normalization
        if self.pos_mean is not None:
            scene[..., :2] = (scene[..., :2] - self.pos_mean) / self.pos_std
            scene[..., 2:4] = (scene[..., 2:4] - self.vel_mean) / self.vel_std
        past = scene[:, :self.T_past, :]
        mask = np.sum(np.abs(past[..., :2]), axis=(1, 2)) > 0
        if not self.is_test and scene.shape[1] >= self.T_past + self.T_future:
            future = scene[0, self.T_past:self.T_past + self.T_future, :2]
            return (
                torch.tensor(past, dtype=torch.float32),
                torch.tensor(mask, dtype=torch.bool),
                torch.tensor(future, dtype=torch.float32)
            )
        return (
            torch.tensor(past, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.bool)
        )

class AugmentedTrajectoryDataset(Dataset):
    def __init__(self, input_path=None, data=None, T_past=50, T_future=60, is_test=False,
                 augment_prob=0.7, noise_scale=0.03, rotation_max=20, flip_prob=0.4):
        if data is not None:
            self.data = data
        else:
            npz = np.load(input_path)
            self.data = npz['data']
        self.T_past = T_past
        self.T_future = T_future
        self.is_test = is_test
        
        # Augmentation parameters
        self.augment_prob = augment_prob
        self.noise_scale = noise_scale
        self.rotation_max = rotation_max
        self.flip_prob = flip_prob

    def __len__(self):
        return len(self.data)
    
    def _apply_augmentations(self, past, future=None):
        """Apply augmentations to the trajectory data"""
        # Skip augmentation during testing or with probability (1-augment_prob)
        if self.is_test or np.random.random() > self.augment_prob:
            return past, future
        
        # Create copies to avoid modifying original data
        past_aug = past.copy()
        future_aug = future.copy() if future is not None else None
        
        # 1. Random rotation
        if np.random.random() < 0.7:  # 70% chance of rotation
            angle = np.random.uniform(-self.rotation_max, self.rotation_max)
            angle_rad = np.radians(angle)
            rot_matrix = np.array([
                [np.cos(angle_rad), -np.sin(angle_rad)],
                [np.sin(angle_rad), np.cos(angle_rad)]
            ])
            
            # Apply rotation to positions
            for i in range(past_aug.shape[0]):
                for t in range(past_aug.shape[1]):
                    past_aug[i, t, :2] = np.dot(rot_matrix, past_aug[i, t, :2])
                    if past_aug.shape[2] > 2:  # If we have velocity data
                        past_aug[i, t, 2:4] = np.dot(rot_matrix, past_aug[i, t, 2:4])
            
            # Apply same rotation to future if it exists
            if future_aug is not None:
                for t in range(future_aug.shape[0]):
                    future_aug[t, :2] = np.dot(rot_matrix, future_aug[t, :2])
        
        # 2. Add Gaussian noise to positions
        if np.random.random() < 0.6:  # 60% chance of adding noise
            noise = np.random.normal(0, self.noise_scale, past_aug[..., :2].shape)
            past_aug[..., :2] += noise
            
            # Update velocities to match the noisy positions
            if past_aug.shape[2] > 2:
                for i in range(past_aug.shape[0]):
                    for t in range(1, past_aug.shape[1]):
                        past_aug[i, t, 2:4] = past_aug[i, t, :2] - past_aug[i, t-1, :2]
        
        # 3. Random horizontal flipping
        if np.random.random() < self.flip_prob:
            # Flip x-coordinates
            past_aug[..., 0] = -past_aug[..., 0]
            # Flip x-velocities
            if past_aug.shape[2] > 2:
                past_aug[..., 2] = -past_aug[..., 2]
            
            # Flip future trajectory if it exists
            if future_aug is not None:
                future_aug[..., 0] = -future_aug[..., 0]
        
        return past_aug, future_aug

    def __getitem__(self, idx):
        scene = self.data[idx]  # (num_agents, T, 6)
        past = scene[:, :self.T_past, :]
        mask = np.sum(np.abs(past[..., :2]), axis=(1, 2)) > 0
        
        if not self.is_test and scene.shape[1] >= self.T_past + self.T_future:
            future = scene[0, self.T_past:self.T_past + self.T_future, :2]
            
            # Apply augmentations
            past_aug, future_aug = self._apply_augmentations(past, future)
            
            return (
                torch.tensor(past_aug, dtype=torch.float32),
                torch.tensor(mask, dtype=torch.bool),
                torch.tensor(future_aug, dtype=torch.float32)
            )
        
        # For test data, no augmentation
        if self.is_test:
            return (
                torch.tensor(past, dtype=torch.float32),
                torch.tensor(mask, dtype=torch.bool)
            )
            
        return (
            torch.tensor(past, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.bool)
        )

In [3]:
class BernsteinLayer(nn.Module):
    def __init__(self, n_degree, T=60):
        super().__init__()
        self.n_degree = n_degree
        self.T = T
        # Generate Bernstein basis functions
        time_points = torch.linspace(0, 1, T)
        bern = torch.zeros(n_degree + 1, T)
        for i in range(n_degree + 1):
            bern[i] = math.comb(n_degree, i) * (time_points ** i) * ((1 - time_points) ** (n_degree - i))
        self.register_buffer('bernstein_values', bern)
        # Calculate pseudoinverse for control-point regression
        self.register_buffer('pinv', torch.pinverse(bern.T))
        
        # Log the shapes for debugging
        print(f"Bernstein basis shape: {bern.shape}, Pinv shape: {self.pinv.shape}")

    def forward(self, control_points):
        """Convert control points to trajectory using Bernstein polynomials"""
        # control_points: (B, n+1, 2)
        B, n_cp, D = control_points.shape
        if n_cp != self.n_degree + 1:
            raise ValueError(f"Expected {self.n_degree+1} control points, got {n_cp}")
            
        cp = control_points.transpose(1, 2)  # (B,2,n+1)
        traj = torch.matmul(cp, self.bernstein_values)  # (B,2,T)
        return traj.transpose(1, 2)  # (B,T,2)

    def inverse(self, trajectories):
        """Convert trajectory to control points using pseudoinverse"""
        # trajectories: (B,T,2) -> (B,n+1,2)
        B, T, D = trajectories.shape
        
        # Verify input dimensions match expected dimensions
        if T != self.T:
            raise ValueError(f"Expected trajectories with length {self.T}, but got {T}")
            
        # Use einsum for efficient batch matrix multiplication
        try:
            # pinv: (T,n+1)
             return torch.einsum('nt,btd->bnd', self.pinv, trajectories)
        except RuntimeError as e:
            print(f"Error in Bernstein inverse: Pinv shape={self.pinv.shape}, Traj shape={trajectories.shape}")
            raise e

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos_enc = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_enc', pos_enc.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pos_enc[:, :seq_len, :]

In [5]:
class SymmetricAttention(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=256, dropout=0.1):
        super().__init__()
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.rpe_processor = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

    def forward(self, source, target, rpe=None):
        B, Ns, D = source.shape
        _, Nt, _ = target.shape
        query = self.query_proj(target)
        key = self.key_proj(source)
        value = self.value_proj(source)
        if rpe is not None:
            # rpe: (B, N_target, N_source, D)
            proc = self.rpe_processor(rpe)                         # (B, N_target, N_source, D)
            key_expanded = key.unsqueeze(1)                        # (B, 1, N_source, D)
            enhanced = key_expanded + proc                         # broadcasts to (B, N_target, N_source, D)
            outputs = []
            for i in range(Nt):
                q = query[:, i:i+1]                                # (B, 1, D)
                k = enhanced[:, i, :, :]                           # (B, N_source, D)
                v = value                                          # (B, N_source, D)
                out, _ = self.multihead_attn(q, k, v)              # now all 3-D
                outputs.append(out)
            attn_output = torch.cat(outputs, dim=1)                # (B, N_target, D)
        else:
            attn_output, _ = self.multihead_attn(query, key, value)
        tgt2 = self.norm1(target + self.dropout(attn_output))
        ff_out = self.ff(tgt2)
        return self.norm2(tgt2 + self.dropout(ff_out))

In [6]:
class SIMPLModel(nn.Module):
    def __init__(self, feature_dim=6, d_model=128, nhead=8,
                 num_layers_temporal=2, num_layers_social=2,
                 dim_feedforward=256, T_past=50, T_future=60,
                 polynomial_degree=5, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.T_past = T_past
        self.T_future = T_future
        self.polynomial_degree = polynomial_degree

        # Feature normalization layers
        self.position_norm = nn.LayerNorm([2])
        self.velocity_norm = nn.LayerNorm([2])
        self.other_norm = nn.LayerNorm([feature_dim-4]) if feature_dim>4 else None

        # Input embedding
        self.input_embed = nn.Linear(feature_dim, d_model)
        self.time_pos_enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=1
        )

        # Temporal encoders
        self.temporal_encoders = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            )
            for _ in range(num_layers_temporal)
        ])

        # Relative position embedding
        self.rpe_generator = nn.Sequential(
            nn.Linear(feature_dim, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )

        # Social attention layers
        self.social_encoders = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=d_model,
                num_heads=nhead,
                dropout=dropout,
                batch_first=True
            )
            for _ in range(num_layers_social)
        ])
        
        # Social layer norms
        self.social_norms = nn.ModuleList([
            nn.LayerNorm(d_model)
            for _ in range(num_layers_social)
        ])

        # Control point predictor (with proper initialization)
        self.control_point_predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.LayerNorm(dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, dim_feedforward),
            nn.LayerNorm(dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, 2*(polynomial_degree+1))
        )

        # Create Bernstein layer with matching dimension
        self.bernstein_layer = BernsteinLayer(polynomial_degree, T_future)
        
        # Skip connection
        self.temporal_skip = nn.Linear(d_model, d_model)
        
        # Initialize weights with improved method
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize model weights with Xavier/Kaiming initialization"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Use Kaiming/He initialization for ReLU layers
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def compute_relative_position_embedding(self, past, mask):
        """Generate relative position embedding between agents"""
        B, N, T, F = past.shape
        # Get last position of each agent
        last = past[:, :, -1, :]
        # Expand ego vehicle position
        ego = last[:, 0:1, :].expand(-1, N, -1)
        # Calculate relative positions and apply mask
        rpe_feats = (ego - last) * mask.unsqueeze(-1).float()
        return self.rpe_generator(rpe_feats)

    def forward(self, past, agent_mask, return_cps=False):
        """Forward pass of the model"""
        B, N, T, F = past.shape
        
        # Ensure ego vehicle is always included in mask
        agent_mask = agent_mask.clone()
        agent_mask[:, 0] = True
        
        # Normalize features
        x = past.reshape(B*N*T, F)
        pos = x[:, :2]; pos = self.position_norm(pos)
        vel = x[:, 2:4]; vel = self.velocity_norm(vel)
        if F > 4 and self.other_norm is not None:
            oth = x[:, 4:]; oth = self.other_norm(oth)
            x_norm = torch.cat([pos, vel, oth], dim=1)
        else:
            x_norm = torch.cat([pos, vel], dim=1)
        
        # Reshape for temporal processing
        x = x_norm.reshape(B*N, T, F)
        
        # Apply input embedding
        x = self.input_embed(x)
        
        # Apply layer normalization to stabilize training
        x = x / (x.norm(dim=-1, keepdim=True) + 1e-6) * math.sqrt(self.d_model)
        
        # Apply positional encoding
        x = self.time_pos_enc(x)
        
        # Store features for skip connection
        temp_feats = x.clone()
        
        # Apply temporal encoders
        for layer in self.temporal_encoders:
            x = layer(x)
        
        # Extract feature from last timestep
        x = x[:, -1, :].reshape(B, N, self.d_model)
        
        # Compute relative position embedding
        rpe = self.compute_relative_position_embedding(past, agent_mask)
        
        # Extract ego features and prepare for social attention
        ego_feats = x[:, 0:1, :]  # (B, 1, D)
        
        # Apply social attention layers
        for i, (attn, norm) in enumerate(zip(self.social_encoders, self.social_norms)):
            # Create key padding mask for attention
            key_padding_mask = ~agent_mask  # (B, N)
            
            # Apply attention
            attn_output, _ = attn(
                query=ego_feats,  # (B, 1, D)
                key=x,            # (B, N, D)
                value=x,          # (B, N, D)
                key_padding_mask=key_padding_mask
            )
            
            # Apply normalization and residual connection
            ego_feats = norm(ego_feats + attn_output)
        
        # Extract final ego embedding
        ego_embed = ego_feats.squeeze(1)  # (B, D)
        
        # Apply skip connection from temporal features
        temp_skip = self.temporal_skip(temp_feats.reshape(B*N, T, self.d_model)[:, -1, :])
        temp_skip = temp_skip.reshape(B, N, self.d_model)[:, 0, :]
        ego_embed = ego_embed + temp_skip
        
        # Predict control points
        cps_flat = self.control_point_predictor(ego_embed)
        cps = cps_flat.reshape(B, self.polynomial_degree+1, 2)
        
        # Convert control points to trajectory
        traj = self.bernstein_layer(cps)
        
        if return_cps:
            return traj, cps
        return traj

In [7]:
def configure_optimizer(model, base_lr=3e-5):
    backbone_params, predictor_params, embedding_params, bernstein_params = [], [], [], []
    for name, param in model.named_parameters():
        if 'control_point_predictor' in name:
            predictor_params.append(param)
        elif 'bernstein_layer' in name:
            bernstein_params.append(param)
        elif 'input_embed' in name:
            embedding_params.append(param)
        else:
            backbone_params.append(param)
    return torch.optim.AdamW([
        {'params': backbone_params, 'lr': base_lr,    'weight_decay':1e-5},
        {'params': predictor_params,'lr': base_lr*1.5,'weight_decay':2e-5},
        {'params': embedding_params,'lr': base_lr*1.2,'weight_decay':1e-5},
        {'params': bernstein_params,'lr': base_lr*0.5,'weight_decay':1e-6},
    ])

In [8]:
def train(model, dataloader, optimizer, device, epoch,
          lr_scheduler=None, writer=None, global_step=0, lambda_cp=0.1):
    model.train()
    criterion = nn.SmoothL1Loss()
    scaler = GradScaler()
    total_loss = 0.0
    total_traj_loss = 0.0
    total_cp_loss = 0.0
    batch_count = 0

    # Implement curriculum learning for trajectory horizon
    if epoch <= 10:
        H = 10
    elif epoch <= 20:
        H = 20
    elif epoch <= 40:
        H = 40
    else:
        H = model.T_future

    for batch_idx, batch in enumerate(dataloader):
        past, mask, future = [x.to(device) for x in batch]
        
        # Skip problematic batches
        if torch.isnan(past).any() or torch.isnan(future).any():
            print(f"Warning: NaN values detected in batch {batch_idx}, skipping")
            continue
            
        optimizer.zero_grad()
        
        try:
            with autocast(device.type):
                # Forward pass with control points
                pred, cps_pred = model(past, mask, return_cps=True)
                
                # Trajectory loss over current horizon H
                loss_traj = criterion(pred[:, :H, :], future[:, :H, :])
                
                # Check if we have full future trajectory for CP loss
                if future.size(1) == model.T_future:
                    # Calculate control points for ground truth trajectory
                    cp_true = model.bernstein_layer.inverse(future)
                    loss_cp = F.mse_loss(cps_pred, cp_true)
                else:
                    loss_cp = torch.tensor(0.0, device=pred.device)
                
                # Combined loss with weighting
                loss = loss_traj + lambda_cp * loss_cp
                
            # Check for NaN loss
            if torch.isnan(loss):
                print(f"Warning: NaN loss detected in batch {batch_idx}, skipping")
                continue
                
            # Gradient scaling for mixed precision
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Step optimizer with scaling
            scaler.step(optimizer)
            scaler.update()
            
            # Update learning rate if scheduler provided
            if lr_scheduler is not None:
                lr_scheduler.step()
                
            # Log metrics
            if writer is not None and batch_idx % 20 == 0:
                writer.add_scalar('train/batch_loss', loss.item(), global_step)
                writer.add_scalar('train/batch_traj_loss', loss_traj.item(), global_step)
                if loss_cp.item() > 0:
                    writer.add_scalar('train/batch_cp_loss', loss_cp.item(), global_step)
                writer.add_scalar('train/batch_horizon', H, global_step)
                lr = optimizer.param_groups[0]['lr']
                writer.add_scalar('train/learning_rate', lr, global_step)
                
            # Update counters
            global_step += 1
            total_loss += loss.item()
            total_traj_loss += loss_traj.item()
            total_cp_loss += loss_cp.item() if loss_cp.item() > 0 else 0
            batch_count += 1
            
        except RuntimeError as e:
            print(f"Error in batch {batch_idx}: {e}")
            # Print tensor shapes for debugging
            print(f"past: {past.shape}, mask: {mask.shape}, future: {future.shape}")
            if "CUDA out of memory" in str(e):
                # Try to recover from OOM error
                torch.cuda.empty_cache()
                continue
            else:
                # For other errors, we might want to raise to debug
                raise e

    # Calculate average losses
    avg_loss = total_loss / max(batch_count, 1)
    avg_traj_loss = total_traj_loss / max(batch_count, 1)
    avg_cp_loss = total_cp_loss / max(batch_count, 1)
    
    # Log epoch metrics
    if writer is not None:
        writer.add_scalar('train/epoch_loss', avg_loss, epoch)
        writer.add_scalar('train/epoch_traj_loss', avg_traj_loss, epoch)
        writer.add_scalar('train/epoch_cp_loss', avg_cp_loss, epoch)
    
    return model, global_step, avg_loss

# Evaluation function with tensorboard logging
def evaluate(model, val_loader, device, writer=None, global_step=None):
    model.eval()
    total_loss = 0.0
    mse_criterion = nn.MSELoss(reduction='none')
    
    all_errors = []
    all_ades = []
    all_fdes = []
    
    eval_start_time = time.time()
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            past, mask, future = [x.to(device) for x in batch]
            
            # No need for autocast during evaluation
            pred = model(past, mask)
            
            # Calculate MSE loss per time step and sample
            mse = mse_criterion(pred, future)  # (B, T, 2)
            
            # Calculate ADE (Average Displacement Error)
            ade = torch.sqrt(mse.sum(dim=2)).mean(dim=1)  # (B,)
            all_ades.extend(ade.cpu().numpy())
            
            # Calculate FDE (Final Displacement Error)
            fde = torch.sqrt(mse[:, -1].sum(dim=1))  # (B,)
            all_fdes.extend(fde.cpu().numpy())
            
            # Store errors for reporting
            for i in range(len(ade)):
                all_errors.append({
                    'ade': ade[i].item(),
                    'fde': fde[i].item()
                })
            
            # Accumulate MSE loss
            batch_loss = mse.mean()
            total_loss += batch_loss.item() * past.size(0)
    
    # Calculate overall metrics
    avg_loss = total_loss / len(val_loader.dataset)
    avg_ade = np.mean(all_ades)
    avg_fde = np.mean(all_fdes)
    
    # Calculate additional metrics - percentiles
    ade_50 = np.percentile(all_ades, 50)  # median
    ade_90 = np.percentile(all_ades, 90)  # 90th percentile
    fde_50 = np.percentile(all_fdes, 50)  # median
    fde_90 = np.percentile(all_fdes, 90)  # 90th percentile
    
    eval_time = time.time() - eval_start_time
    
    # Log metrics to tensorboard
    if writer is not None and global_step is not None:
        writer.add_scalar('val/loss', avg_loss, global_step)
        writer.add_scalar('val/ade_mean', avg_ade, global_step)
        writer.add_scalar('val/fde_mean', avg_fde, global_step)
        writer.add_scalar('val/ade_50', ade_50, global_step)
        writer.add_scalar('val/ade_90', ade_90, global_step)
        writer.add_scalar('val/fde_50', fde_50, global_step)
        writer.add_scalar('val/fde_90', fde_90, global_step)
        writer.add_scalar('val/eval_time', eval_time, global_step)
        
        # Add histograms of ADE and FDE
        writer.add_histogram('val/ade_dist', np.array(all_ades), global_step)
        writer.add_histogram('val/fde_dist', np.array(all_fdes), global_step)
    
    return {
        'loss': avg_loss,
        'ade': avg_ade,
        'fde': avg_fde,
        'ade_50': ade_50,
        'ade_90': ade_90,
        'fde_50': fde_50,
        'fde_90': fde_90,
        'eval_time': eval_time,
        'detailed_errors': all_errors
    }

# Prediction function with optional tensorboard visualizations
def predict(model, test_loader, device, writer=None, visualize_samples=False):
    model.eval()
    all_preds = []
    inference_start_time = time.time()
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            past, mask = [x.to(device) for x in batch]
            
            # Forward pass
            pred = model(past, mask)
            all_preds.append(pred.cpu().numpy())
            
            # Visualize sample predictions (only if requested and writer is provided)
            if writer is not None and visualize_samples and batch_idx < 10:
                # In a real implementation, you would generate and save figures
                # writer.add_figure(f'test/trajectory_{batch_idx}', fig, 0)
                pass
    
    # Concatenate all predictions
    predictions = np.concatenate(all_preds, axis=0)
    
    # Log inference statistics if writer is provided
    if writer is not None:
        inference_time = time.time() - inference_start_time
        avg_time_per_sample = inference_time / len(predictions)
        writer.add_text('inference_stats', 
                      f"Total inference time: {inference_time:.2f}s, "
                      f"Samples: {len(predictions)}, "
                      f"Avg time per sample: {avg_time_per_sample*1000:.2f}ms")
        
        # Add histogram of prediction coordinates
        writer.add_histogram('test/pred_x', predictions[:, :, 0].flatten(), 0)
        writer.add_histogram('test/pred_y', predictions[:, :, 1].flatten(), 0)
    
    return predictions

In [9]:
train_input = 'data/train.npz'
test_input = 'data/test_input.npz'
output_csv = 'predictions.csv'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = f"runs/simpl_{timestamp}"
ckpt_dir = os.path.join("checkpoints", f"simpl_{timestamp}")
os.makedirs(ckpt_dir, exist_ok=True)
best_ckpt_path = os.path.join(ckpt_dir, "best_model.pt")

In [10]:
batch_size = 128
base_lr = 3e-5  # Lower learning rate for stability
epochs = 200
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Create tensorboard writer
writer = SummaryWriter(log_dir)

# Log hyperparameters
hparams = {
    'batch_size': batch_size,
    'base_learning_rate': base_lr,
    'epochs': epochs,
    'model_type': 'SIMPL',
    'd_model': 128,
    'nhead': 8,
    'num_layers_temporal': 3, 
    'num_layers_social': 2,
    'polynomial_degree': 5,  
    'dropout': 0.2,
    'weight_decay': 1e-5,
}
writer.add_text('hyperparameters', str(hparams))

In [12]:
print("Loading data...")
full_data = np.load(train_input)['data']
print(f"Full data shape: {full_data.shape}")
# Split into train and eval (7:3)
num_samples = len(full_data)
num_train = int(0.7 * num_samples)
perm = np.random.permutation(num_samples)
train_idx = perm[:num_train]
eval_idx = perm[num_train:]

train_data = full_data[train_idx]
eval_data = full_data[eval_idx]

# Calculate normalization statistics
# Only consider non-zero values for better statistics
mask = (full_data[..., :2] != 0).any(axis=-1)
pos_all = full_data[..., :2][mask]
vel_all = full_data[..., 2:4][mask]

pos_mean = pos_all.mean(axis=0)
pos_std = pos_all.std(axis=0) + 1e-6  # Add epsilon to avoid division by zero
vel_mean = vel_all.mean(axis=0)
vel_std = vel_all.std(axis=0) + 1e-6

print(f"Position mean: {pos_mean}, std: {pos_std}")
print(f"Velocity mean: {vel_mean}, std: {vel_std}")

train_ds = TrajectoryDataset(
    data=train_data,
    pos_mean=pos_mean,
    pos_std=pos_std,
    vel_mean=vel_mean,
    vel_std=vel_std)

eval_ds = TrajectoryDataset(
    data=eval_data,
    pos_mean=pos_mean,
    pos_std=pos_std,
    vel_mean=vel_mean,
    vel_std=vel_std
)
test_ds = TrajectoryDataset(
    input_path=test_input,
    is_test=True,
    pos_mean=pos_mean,
    pos_std=pos_std,
    vel_mean=vel_mean,
    vel_std=vel_std
)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Log dataset information
writer.add_text('dataset_info', f"Train samples: {len(train_ds)}, Eval samples: {len(eval_ds)}, Test samples: {len(test_ds)}")


Loading data...
Full data shape: (10000, 50, 110, 6)
Position mean: [2830.48216443 1083.72035943], std: [3191.10358747 1637.55340973]
Velocity mean: [-0.03593408 -0.02212871], std: [3.89127287 3.42483996]


In [13]:
print(f"Creating model on device: {device}")

# Important: use polynomial_degree=5 (not 6) to ensure compatibility
polynomial_degree = 5

model = SIMPLModel(
    feature_dim=6,
    d_model=128,
    nhead=8,
    num_layers_temporal=3,  # Reduced from 4
    num_layers_social=2,
    dim_feedforward=256,
    T_past=50,
    T_future=60,
    polynomial_degree=polynomial_degree,  # Match BernsteinLayer
    dropout=0.2
).to(device)

# Log model architecture and parameters
writer.add_text('model_architecture', str(model))
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
writer.add_text('model_params', f"Total trainable parameters: {total_params:,}")


Creating model on device: cuda
Bernstein basis shape: torch.Size([6, 60]), Pinv shape: torch.Size([6, 60])


In [14]:
optimizer = configure_optimizer(model, base_lr=base_lr)

# Create learning rate scheduler with warmup
from torch.optim.lr_scheduler import OneCycleLR

# One-cycle learning rate scheduler
lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=[base_lr*3, base_lr*4.5, base_lr*3.6, base_lr*1.5],  # Scale for each param group
    steps_per_epoch=len(train_loader),
    epochs=epochs,
    pct_start=0.1,  # Spend 10% of time warming up
    div_factor=25,  # Initial LR is max_lr/25
    final_div_factor=1000,  # Final LR is max_lr/1000
    anneal_strategy='cos'
)

In [15]:
start_epoch   = 1
best_val_loss = float('inf')
global_step   = 0

if os.path.exists(best_ckpt_path):
    ckpt = torch.load(best_ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    start_epoch   = ckpt['epoch'] + 1
    best_val_loss = ckpt.get('val_loss', best_val_loss)
    global_step   = ckpt.get('global_step', 0)
    hparams = ckpt.get('hparams', hparams)
    print(f"▶ Resuming from epoch {ckpt['epoch']} (val_loss={best_val_loss:.6f})")

# Record training start time
training_start_time = time.time()

# Log the training loop start
writer.add_text('training_info', f"Training started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

▶ Resuming from epoch 63 (val_loss=0.870934)


In [16]:
def save_checkpoint(path, epoch, model, optimizer, val_loss, val_metrics, global_step, hparams):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_metrics': val_metrics,
        'global_step': global_step,
        'hparams': hparams
    }, path)

In [17]:
best_val_metrics = {}
try:
    # Initialize early stopping counter
    patience = 50
    early_stop_counter = 0
    
    for epoch in range(start_epoch, epochs + 1):
        epoch_start_time = time.time()
        
        # Train for one epoch
        model, global_step, train_loss = train(
            model, train_loader, optimizer, device,
            epoch=epoch,
            lr_scheduler=lr_scheduler,
            writer=writer,
            global_step=global_step
        )
        
        # Evaluate on validation set
        val_metrics = evaluate(
            model, eval_loader, device, 
            writer=writer, global_step=global_step
        )
        val_loss = val_metrics['loss']
        
        epoch_time = time.time() - epoch_start_time
        
        # Print epoch summary
        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.6f} | "
              f"Val Loss: {val_loss:.6f} | "
              f"Val ADE: {val_metrics['ade']:.4f} | "
              f"Val FDE: {val_metrics['fde']:.4f} | "
              f"Time: {epoch_time:.2f}s")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_metrics = val_metrics
            save_checkpoint(
                best_ckpt_path, epoch, model, optimizer,
                val_loss, val_metrics, global_step, hparams
            )
            print(f"✅ Best model saved at epoch {epoch} (val loss: {best_val_loss:.6f})")
            writer.add_text('checkpoints', f"New best model at epoch {epoch} with val_loss: {best_val_loss:.6f}")
            # Reset early stopping counter
            early_stop_counter = 0
        else:
            # Increment early stopping counter
            early_stop_counter += 1
            print(f"No improvement for {early_stop_counter} epochs. Best val_loss: {best_val_loss:.6f}")
        
        # Save checkpoint every 20 epochs
        if epoch % 20 == 0:
            cur_ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch}.pt")
            save_checkpoint(
                cur_ckpt_path, epoch, model, optimizer,
                val_loss, val_metrics, global_step, hparams
            )
            print(f"🧪 Checkpoint saved at {cur_ckpt_path}")
            writer.add_text('checkpoints', f"Periodic checkpoint at epoch {epoch}")
        
        # Early stopping check
        if early_stop_counter >= patience:
            print(f"Early stopping triggered. No improvement for {patience} epochs.")
            writer.add_text('training_info', f"Early stopping at epoch {epoch}")
            break
            
except KeyboardInterrupt:
    print("Training interrupted by user")
    writer.add_text('training_info', f"Training interrupted at epoch {epoch}")

except Exception as e:
    print(f"Error during training: {e}")
    import traceback
    traceback.print_exc()
    writer.add_text('training_info', f"Training crashed with error: {str(e)}")

finally:
    # Calculate total training time
    total_training_time = time.time() - training_start_time
    hours, remainder = divmod(total_training_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    time_str = f"{int(hours)}h {int(minutes)}m {int(seconds)}s"
    
    print(f"Total training time: {time_str}")
    writer.add_text('training_info', f"Training completed/interrupted after {time_str}")
    
    # Save final model if different from best model
    if not os.path.exists(os.path.join(ckpt_dir, "final_model.pt")):
        final_path = os.path.join(ckpt_dir, "final_model.pt")
        # use best_val_loss/metrics if val_loss isn’t in scope
        vloss = locals().get('val_loss', best_val_loss)
        vmetrics = locals().get('val_metrics', best_val_metrics)
        try:
            save_checkpoint(final_path, epoch, model, optimizer,
                            vloss, vmetrics, global_step, hparams)
            print(f"Final model saved at {final_path}")
        except Exception as e:
            print(f"Failed to save final model: {e}")

Epoch 64/200 | Train Loss: 0.436582 | Val Loss: 0.867255 | Val ADE: 1.0623 | Val FDE: 1.0625 | Time: 22.57s
✅ Best model saved at epoch 64 (val loss: 0.867255)
Epoch 65/200 | Train Loss: 0.435433 | Val Loss: 0.866996 | Val ADE: 1.0638 | Val FDE: 1.0640 | Time: 20.79s
✅ Best model saved at epoch 65 (val loss: 0.866996)
Epoch 66/200 | Train Loss: 0.433848 | Val Loss: 0.866936 | Val ADE: 1.0632 | Val FDE: 1.0635 | Time: 20.87s
✅ Best model saved at epoch 66 (val loss: 0.866936)
Epoch 67/200 | Train Loss: 0.433048 | Val Loss: 0.866704 | Val ADE: 1.0638 | Val FDE: 1.0643 | Time: 20.90s
✅ Best model saved at epoch 67 (val loss: 0.866704)
Epoch 68/200 | Train Loss: 0.434678 | Val Loss: 0.865675 | Val ADE: 1.0625 | Val FDE: 1.0629 | Time: 20.87s
✅ Best model saved at epoch 68 (val loss: 0.865675)
Epoch 69/200 | Train Loss: 0.433444 | Val Loss: 0.867023 | Val ADE: 1.0641 | Val FDE: 1.0645 | Time: 21.19s
No improvement for 1 epochs. Best val_loss: 0.865675
Epoch 70/200 | Train Loss: 0.436056 | V

In [None]:
print("Loading best model for prediction...")
try:
    ckpt = torch.load(best_ckpt_path, map_location=device,weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    print(f"Loaded model from epoch {ckpt['epoch']} with val_loss={ckpt['val_loss']:.6f}")
except Exception as e:
    print(f"Failed to load best model: {e}")
    print("Using current model instead")

# Generate predictions
print("Generating predictions on test set...")
preds = predict(model, test_loader, device, writer=writer)

# Denormalize predictions if normalization was applied
if hasattr(test_ds, 'pos_mean') and test_ds.pos_mean is not None:
    print("Denormalizing predictions...")
    preds_denorm = preds * test_ds.pos_std + test_ds.pos_mean
else:
    preds_denorm = preds

# Reshape to format expected by the competition
preds_flat = preds_denorm.reshape(-1, 2)

# Create submission file
print(f"Creating submission file: {output_csv}")
import pandas as pd
df_preds = pd.DataFrame(preds_flat, columns=['x', 'y'])
df_preds.index.name = 'ID'
df_preds.to_csv(output_csv)
print(f"Predictions saved to {output_csv}")

# Close tensorboard writer
writer.close()
print("Training completed!")

Loading best model for prediction...
Loaded model from epoch 118 with val_loss=0.861630
Generating predictions on test set...
Denormalizing predictions...
Creating submission file: predictions.csv
Predictions saved to predictions.csv
Training completed!


: 