In [21]:
import os
import re
import glob
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from datetime import datetime

In [22]:
def collate_fn(batch):
    """Stack per-sample tensors into batched tensors.

    Each sample is a tuple of tensors: ``(past, mask, future)`` for training
    data or ``(past, mask)`` for test data. The layout of the first sample
    decides how the whole batch is unpacked.

    Returns:
        ``(past, mask, future)`` or ``(past, mask)``; every field is stacked
        along a new leading batch dimension.
    """
    has_future = len(batch[0]) == 3

    if has_future:
        past_items, mask_items, future_items = zip(*batch)
        return (
            torch.stack(past_items),
            torch.stack(mask_items),
            torch.stack(future_items),
        )

    # Test data: no ground-truth future available.
    past_items, mask_items = zip(*batch)
    return torch.stack(past_items), torch.stack(mask_items)


class TrajectoryDataset(Dataset):
    """Dataset of multi-agent scenes split into a past window and an ego future.

    Every scene is indexed as ``(num_agents, T, features)``; agent 0 is the one
    whose future is used as the training target.

    Args:
        input_path: Path to an ``.npz`` archive holding the scenes under the
            key ``'data'``. Only read when ``data`` is not given.
        data: Already-loaded scene array; takes precedence over ``input_path``.
        T_past: Number of leading time steps returned as model input.
        T_future: Number of time steps after the past window used as target.
        is_test: When True, samples never include a future target.

    Raises:
        ValueError: If neither ``data`` nor ``input_path`` is provided.
    """

    def __init__(self, input_path=None, data=None, T_past=50, T_future=60, is_test=False):
        if data is not None:
            self.data = data
        elif input_path is not None:
            # The context manager closes the archive's file handle once the
            # array has been read into memory.
            with np.load(input_path) as npz:
                self.data = npz['data']
        else:
            raise ValueError("TrajectoryDataset requires either `input_path` or `data`.")
        self.T_past = T_past
        self.T_future = T_future
        self.is_test = is_test

    def __len__(self):
        """Number of scenes."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(past, mask, future)`` or, without a target, ``(past, mask)``.

        * ``past``: float32 tensor ``(num_agents, T_past, features)``.
        * ``mask``: bool tensor ``(num_agents,)``; True where the agent has any
          non-zero value in the first two features of the past window.
        * ``future``: float32 tensor ``(T_future, 2)`` holding the first two
          features of agent 0 after the past window.
        """
        scene = self.data[idx]  # (num_agents, T, 6)
        past = scene[:, :self.T_past, :]
        # An agent whose first two features are all zero is treated as padding.
        mask = np.sum(np.abs(past[..., :2]), axis=(1, 2)) > 0

        past_tensor = torch.tensor(past, dtype=torch.float32)
        mask_tensor = torch.tensor(mask, dtype=torch.bool)

        has_future = scene.shape[1] >= self.T_past + self.T_future
        if not self.is_test and has_future:
            future = scene[0, self.T_past:self.T_past + self.T_future, :2]
            return past_tensor, mask_tensor, torch.tensor(future, dtype=torch.float32)

        # NOTE(review): a training scene that is too short also lands here and
        # yields a 2-tuple, which cannot be collated together with 3-tuples.
        return past_tensor, mask_tensor

In [23]:
class BernsteinLayer(nn.Module):
    """Evaluate a 2-D Bezier curve from its control points.

    The Bernstein basis of degree ``n_degree`` is sampled once at ``T`` evenly
    spaced parameter values in ``[0, 1]`` and kept as the non-trainable buffer
    ``bernstein_values`` of shape ``(n_degree + 1, T)``.
    """

    def __init__(self, n_degree, T=60):
        super().__init__()
        self.n_degree = n_degree
        self.T = T
        self._precompute_bernstein_values()

    def _precompute_bernstein_values(self):
        """Fill the basis table: row k is C(n, k) * t**k * (1 - t)**(n - k)."""
        degree = self.n_degree
        ts = torch.linspace(0, 1, self.T)
        basis = torch.zeros(degree + 1, self.T)
        for k in range(degree + 1):
            coeff = math.comb(degree, k)
            basis[k] = coeff * (ts ** k) * ((1 - ts) ** (degree - k))
        # Buffer layout: (n+1, T)
        self.register_buffer('bernstein_values', basis)

    def forward(self, control_points):
        """Map control points ``(B, n+1, 2)`` to sampled trajectories ``(B, T, 2)``."""
        # (B, 2, n+1) @ (n+1, T) broadcasts over the batch -> (B, 2, T)
        coords_by_time = control_points.transpose(1, 2) @ self.bernstein_values
        return coords_by_time.transpose(1, 2)

In [24]:
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    Even feature indices receive ``sin(pos * w_k)`` and odd indices receive
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    stored as the non-trainable buffer ``pos_enc`` of shape
    ``(1, max_len, d_model)``.

    Args:
        d_model: Feature dimension; both even and odd values are supported.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos_enc = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pos_enc[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so drop the last frequency to keep the assignment
        # shape-compatible. For even d_model this slice is a no-op.
        pos_enc[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # shape (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer('pos_enc', pos_enc.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first ``seq_len`` positions to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
        """
        seq_len = x.size(1)
        return x + self.pos_enc[:, :seq_len, :]

In [25]:
class SymmetricAttention(nn.Module):
    """Cross-attention block (target attends to source) with optional RPE keys.

    The block applies linear query/key/value projections, multi-head
    attention, and a position-wise feed-forward network; each of the two
    sub-layers is followed by dropout, a residual connection and LayerNorm.
    When ``rpe`` is given, it is passed through a small MLP and added to the
    keys separately for every target position.
    """

    def __init__(self, d_model, nhead, dim_feedforward=256, dropout=0.1):
        super().__init__()
        # Submodule names and creation order are part of the checkpoint /
        # seeded-initialisation contract; keep them stable.
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.rpe_processor = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, source, target, rpe=None):
        """Update ``target`` by attending over ``source``.

        Args:
            source: ``(B, N_source, d_model)`` tensor providing keys and values.
            target: ``(B, N_target, d_model)`` tensor providing queries.
            rpe: Optional tensor that, after the RPE MLP, must broadcast
                against ``(B, N_target, N_source, d_model)``.

        Returns:
            Tensor of shape ``(B, N_target, d_model)``.
        """
        _batch, _num_sources, _dim = source.shape  # also enforces a 3-D source
        _, num_targets, _ = target.shape

        queries = self.query_proj(target)
        keys = self.key_proj(source)
        values = self.value_proj(source)

        if rpe is None:
            attended, _ = self.multihead_attn(queries, keys, values)
        else:
            rpe_bias = self.rpe_processor(rpe)
            # One biased copy of the keys per target position.
            keys_per_target = keys.unsqueeze(1).expand(-1, num_targets, -1, -1) + rpe_bias
            per_target = [
                self.multihead_attn(
                    queries[:, t:t + 1, :], keys_per_target[:, t, :, :], values
                )[0]
                for t in range(num_targets)
            ]
            attended = torch.cat(per_target, dim=1)

        hidden = self.norm1(target + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.ff(hidden)))

In [26]:
class SIMPLModel(nn.Module):
    """Ego-trajectory predictor with temporal encoding, social attention and a
    Bernstein-polynomial output head.

    Pipeline implemented by ``forward``:

    1. Every agent's past is embedded, rescaled to norm ``sqrt(d_model)``,
       given a sinusoidal time encoding and passed through the temporal
       Transformer encoder layers; the last time step is kept as the agent
       feature.
    2. The ego agent (index 0) attends over all agent features through the
       ``SymmetricAttention`` layers, with keys biased by an embedding of
       ``ego_last_state - agent_last_state``.
    3. An MLP maps the ego embedding to ``polynomial_degree + 1`` 2-D control
       points, which ``BernsteinLayer`` turns into ``T_future`` positions.

    Args:
        feature_dim: Number of features per agent per time step.
        d_model: Embedding width used throughout the model.
        nhead: Attention heads in the temporal and social layers.
        num_layers_temporal: Number of temporal encoder layers.
        num_layers_social: Number of social attention layers.
        dim_feedforward: Hidden width of feed-forward sub-layers and the head.
        T_past: Maximum number of past steps (size of the time encoding table).
        T_future: Number of predicted future steps.
        polynomial_degree: Degree of the predicted Bezier curve.
        dropout: Dropout probability.
    """

    def __init__(self, feature_dim=6, d_model=128, nhead=8,
                 num_layers_temporal=2, num_layers_social=2,
                 dim_feedforward=256, T_past=50, T_future=60,
                 polynomial_degree=5, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.T_past = T_past
        self.T_future = T_future
        self.polynomial_degree = polynomial_degree
        self.input_embed = nn.Linear(feature_dim, d_model)
        self.time_pos_enc = PositionalEncoding(d_model, max_len=T_past)
        self.temporal_encoders = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
            for _ in range(num_layers_temporal)
        ])
        self.rpe_generator = nn.Sequential(
            nn.Linear(feature_dim, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, d_model)
        )
        self.social_encoders = nn.ModuleList([
            SymmetricAttention(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers_social)
        ])
        self.control_point_predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, dim_feedforward // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward // 2, 2 * (polynomial_degree + 1))
        )
        self.bernstein_layer = BernsteinLayer(polynomial_degree, T_future)

    def compute_relative_position_embedding(self, past, mask):
        """Embed the ego-relative last state of every agent.

        Args:
            past: ``(B, N, T, F)`` past features.
            mask: ``(B, N)`` boolean mask; rows that are False get zeroed
                relative features before the embedding MLP.

        Returns:
            Tensor of shape ``(B, N, d_model)``.
        """
        _, num_agents, _, _ = past.shape
        last = past[:, :, -1, :]
        ego = last[:, 0:1, :].expand(-1, num_agents, -1)
        rpe_feats = ego - last
        rpe_feats = rpe_feats * mask.unsqueeze(-1).float()
        return self.rpe_generator(rpe_feats)

    def forward(self, past, agent_mask):
        """Predict the ego agent's future positions.

        Args:
            past: ``(B, N, T, F)`` past features; agent 0 is the ego agent.
            agent_mask: ``(B, N)`` boolean mask of agents considered present.

        Returns:
            Tensor of shape ``(B, T_future, 2)``.
        """
        # Local names avoid `F`, which the file also uses as the alias of
        # torch.nn.functional.
        batch_size, num_agents, num_steps, feat_dim = past.shape
        # always include ego agent (clone so the caller's mask is untouched)
        agent_mask = agent_mask.clone()
        agent_mask[:, 0] = True

        # temporal embedding; reshape (unlike view) also accepts
        # non-contiguous inputs such as sliced or permuted batches
        x = past.reshape(batch_size * num_agents, num_steps, feat_dim)
        x = self.input_embed(x)
        x = x / (x.norm(dim=-1, keepdim=True) + 1e-6) * math.sqrt(self.d_model)
        x = self.time_pos_enc(x)
        for layer in self.temporal_encoders:
            x = layer(x)
        # pool final state
        x = x[:, -1, :]
        agent_feats = x.reshape(batch_size, num_agents, self.d_model)

        # social interaction: the ego query attends over every agent
        # NOTE(review): agent_mask only zeroes the RPE input; masked agents
        # still take part in attention as keys/values - confirm this is intended.
        rpe = self.compute_relative_position_embedding(past, agent_mask)
        ego_feats = agent_feats[:, 0:1, :]
        others = agent_feats
        rpe_mat = rpe.unsqueeze(1)  # (B, 1, N, d_model): one row for the ego query
        for layer in self.social_encoders:
            ego_feats = layer(others, ego_feats, rpe_mat)

        ego_embed = ego_feats.squeeze(1)
        cps_flat = self.control_point_predictor(ego_embed)
        cps = cps_flat.reshape(batch_size, self.polynomial_degree + 1, 2)
        return self.bernstein_layer(cps)

In [27]:
def train(model, dataloader, optimizer, device, num_epochs=10, lr_scheduler=None, writer=None, global_step=0):
    """Train ``model`` for a single pass over ``dataloader``.

    Batches containing NaN/Inf in the inputs, targets or predictions are
    skipped with a warning. Gradients are clipped to a global norm of 1.0.

    Args:
        model: Module called as ``model(past, mask)`` returning ``(B, T, 2)``.
        dataloader: Iterable yielding ``(past, mask, future)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batch tensors are moved to.
        num_epochs: Unused; kept for backward compatibility. Every call runs
            exactly one epoch.
        lr_scheduler: Optional scheduler, stepped once at the end of the pass.
        writer: Optional TensorBoard-style writer (``add_scalar`` /
            ``add_histogram``).
        global_step: Running count of optimisation steps used for logging.

    Returns:
        ``(model, global_step)`` with ``global_step`` advanced by the number of
        batches that were actually optimised.
    """
    model.train()
    position_criterion = nn.SmoothL1Loss()

    epoch_start_time = time.time()
    total_loss = 0.0
    # Samples that contributed to total_loss. Skipped batches are excluded so
    # the reported epoch loss is a true per-sample average.
    num_samples = 0

    for batch_idx, batch in enumerate(dataloader):
        past, mask, future = [x.to(device) for x in batch]

        # Check for NaNs or Infs in inputs
        if not torch.isfinite(past).all():
            print("Warning: NaN or Inf detected in past input. Skipping batch.")
            continue

        if not torch.isfinite(future).all():
            print("Warning: NaN or Inf detected in future target. Skipping batch.")
            continue

        optimizer.zero_grad()

        # Forward pass
        pred = model(past, mask)

        # Check for NaNs or Infs in predictions
        if not torch.isfinite(pred).all():
            print("Warning: NaN or Inf detected in predictions. Skipping batch.")
            continue

        # Position loss (smooth L1)
        loss = position_criterion(pred, future)

        # ADE / FDE for monitoring only (no gradient needed)
        with torch.no_grad():
            displacement = torch.sqrt(torch.pow(pred - future, 2).sum(dim=2))  # (B, T)
            ade = displacement.mean(dim=1).mean()  # scalar
            fde = displacement[:, -1].mean()  # scalar

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        batch_size = past.size(0)
        total_loss += batch_loss * batch_size
        num_samples += batch_size

        # Log metrics to tensorboard every 20 batches
        if writer is not None and batch_idx % 20 == 0:
            writer.add_scalar('train/batch_loss', batch_loss, global_step)
            writer.add_scalar('train/batch_ade', ade.item(), global_step)
            writer.add_scalar('train/batch_fde', fde.item(), global_step)
            writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step)

            # Log histograms of model weights and gradients
            if batch_idx % 200 == 0:
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        writer.add_histogram(f'weights/{name}', param.data, global_step)
                        if param.grad is not None:
                            writer.add_histogram(f'gradients/{name}', param.grad, global_step)

        global_step += 1

    # Step learning rate scheduler if provided
    if lr_scheduler is not None:
        lr_scheduler.step()

    # NaN (rather than a ZeroDivisionError) when every batch was skipped or the
    # loader was empty.
    epoch_loss = total_loss / num_samples if num_samples > 0 else float('nan')
    epoch_time = time.time() - epoch_start_time

    if writer is not None:
        writer.add_scalar('train/epoch_loss', epoch_loss, global_step)
        writer.add_scalar('train/epoch_time', epoch_time, global_step)

    print(f"Training - Loss: {epoch_loss:.6f}, Time: {epoch_time:.2f}s")

    return model, global_step

# Evaluation function with tensorboard logging
def evaluate(model, val_loader, device, writer=None, global_step=None):
    """Evaluate ``model`` on a validation loader.

    Args:
        model: Module called as ``model(past, mask)`` returning ``(B, T, 2)``.
        val_loader: Iterable yielding ``(past, mask, future)`` batches.
        device: Device the batch tensors are moved to.
        writer: Optional TensorBoard-style writer.
        global_step: Step used for logging; nothing is logged when it is None.

    Returns:
        Dict with plain Python floats under ``'loss'`` (per-sample average of
        the batch-mean squared error), ``'ade'``, ``'fde'``, ``'ade_50'``,
        ``'ade_90'``, ``'fde_50'``, ``'fde_90'`` and ``'eval_time'``, plus
        ``'detailed_errors'``: one ``{'ade': float, 'fde': float}`` dict per
        sample. Plain floats keep the result safe to serialise (e.g. inside a
        checkpoint) without NumPy scalar types.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    mse_criterion = nn.MSELoss(reduction='none')

    ade_chunks = []
    fde_chunks = []
    total_loss = 0.0
    num_samples = 0

    eval_start_time = time.time()

    with torch.no_grad():
        for batch in val_loader:
            past, mask, future = [x.to(device) for x in batch]
            pred = model(past, mask)

            # Squared error per sample, time step and coordinate: (B, T, 2)
            mse = mse_criterion(pred, future)
            displacement = torch.sqrt(mse.sum(dim=2))  # (B, T)

            # ADE: mean displacement over time; FDE: displacement at last step
            ade_chunks.append(displacement.mean(dim=1).cpu())
            fde_chunks.append(displacement[:, -1].cpu())

            batch_size = past.size(0)
            total_loss += mse.mean().item() * batch_size
            num_samples += batch_size

            # Trajectory visualisation hook: figures could be logged here via
            # writer.add_figure(f'val/trajectory_{i}', fig, global_step).

    if num_samples == 0:
        raise ValueError("evaluate() received no validation samples.")

    # One device->host transfer per batch above; convert once here.
    all_ades = torch.cat(ade_chunks).numpy()
    all_fdes = torch.cat(fde_chunks).numpy()

    # Overall metrics (cast NumPy scalars to built-in floats)
    avg_loss = total_loss / num_samples
    avg_ade = float(np.mean(all_ades))
    avg_fde = float(np.mean(all_fdes))

    # Additional metrics - percentiles
    ade_50 = float(np.percentile(all_ades, 50))  # median
    ade_90 = float(np.percentile(all_ades, 90))  # 90th percentile
    fde_50 = float(np.percentile(all_fdes, 50))  # median
    fde_90 = float(np.percentile(all_fdes, 90))  # 90th percentile

    all_errors = [
        {'ade': sample_ade, 'fde': sample_fde}
        for sample_ade, sample_fde in zip(all_ades.tolist(), all_fdes.tolist())
    ]

    eval_time = time.time() - eval_start_time

    # Log metrics to tensorboard
    if writer is not None and global_step is not None:
        writer.add_scalar('val/loss', avg_loss, global_step)
        writer.add_scalar('val/ade_mean', avg_ade, global_step)
        writer.add_scalar('val/fde_mean', avg_fde, global_step)
        writer.add_scalar('val/ade_50', ade_50, global_step)
        writer.add_scalar('val/ade_90', ade_90, global_step)
        writer.add_scalar('val/fde_50', fde_50, global_step)
        writer.add_scalar('val/fde_90', fde_90, global_step)
        writer.add_scalar('val/eval_time', eval_time, global_step)

        # Add histograms of ADE and FDE
        writer.add_histogram('val/ade_dist', all_ades, global_step)
        writer.add_histogram('val/fde_dist', all_fdes, global_step)

    return {
        'loss': avg_loss,
        'ade': avg_ade,
        'fde': avg_fde,
        'ade_50': ade_50,
        'ade_90': ade_90,
        'fde_50': fde_50,
        'fde_90': fde_90,
        'eval_time': eval_time,
        'detailed_errors': all_errors
    }

# Prediction function with optional tensorboard visualizations
def predict(model, test_loader, device, writer=None, visualize_samples=False):
    """Run inference over ``test_loader`` and return all predictions.

    Args:
        model: Module called as ``model(past, mask)``.
        test_loader: Iterable yielding ``(past, mask)`` batches.
        device: Device the batch tensors are moved to.
        writer: Optional TensorBoard-style writer; when given, timing
            statistics and coordinate histograms are logged.
        visualize_samples: Reserved for per-batch trajectory figures; it
            currently has no effect.

    Returns:
        NumPy array with the per-batch outputs concatenated along axis 0.
    """
    model.eval()
    batch_outputs = []
    started_at = time.time()

    with torch.no_grad():
        for batch in test_loader:
            past, mask = (tensor.to(device) for tensor in batch)
            batch_outputs.append(model(past, mask).cpu().numpy())
            # Figures for the first few batches could be logged here with
            # writer.add_figure(f'test/trajectory_{batch_idx}', fig, 0).

    predictions = np.concatenate(batch_outputs, axis=0)

    if writer is None:
        return predictions

    # Inference statistics
    elapsed = time.time() - started_at
    sample_count = len(predictions)
    seconds_per_sample = elapsed / sample_count
    writer.add_text(
        'inference_stats',
        f"Total inference time: {elapsed:.2f}s, "
        f"Samples: {sample_count}, "
        f"Avg time per sample: {seconds_per_sample*1000:.2f}ms",
    )

    # Histograms of the predicted coordinates
    for axis_index, tag in enumerate(('test/pred_x', 'test/pred_y')):
        writer.add_histogram(tag, predictions[:, :, axis_index].flatten(), 0)

    return predictions

In [28]:
# Input data, written outputs and the best-model checkpoint
train_input = "data/train.npz"
test_input = "data/test_input.npz"
output_csv = "predictions.csv"
checkpoint_path = "simpl_checkpoint.pt"

# TensorBoard run directory, suffixed with the start time of this run
run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = f"runs/simpl_{run_stamp}"

In [29]:
# Training hyperparameters
batch_size = 64   # samples per DataLoader batch
lr = 1e-4         # initial optimizer learning rate
epochs = 1000     # upper bound on training epochs

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [30]:
# TensorBoard writer for this run (event files go to `log_dir`)
writer = SummaryWriter(log_dir)

# Hyperparameters recorded with the run and stored in every checkpoint.
# NOTE(review): the architecture values and weight_decay are repeated by hand
# in the model-construction and optimizer cells - keep them in sync.
hparams = dict(
    batch_size=batch_size,
    learning_rate=lr,
    epochs=epochs,
    model_type='SIMPL',
    d_model=128,
    nhead=8,
    num_layers_temporal=2,
    num_layers_social=2,
    polynomial_degree=5,
    dropout=0.1,
    weight_decay=1e-5,
)
writer.add_text('hyperparameters', str(hparams))

In [31]:
# Data preparation
print("Loading data...")
full_data = np.load(train_input)['data']

# Split into train and eval (7:3).
# The permutation is drawn from a fixed-seed generator so the split is the
# same on every run. Without this, restarting the kernel and resuming from the
# saved checkpoint would reshuffle the data and move samples the model was
# already trained on into the evaluation set.
split_seed = 42
num_samples = len(full_data)
num_train = int(0.7 * num_samples)
perm = np.random.default_rng(split_seed).permutation(num_samples)
train_idx = perm[:num_train]
eval_idx = perm[num_train:]

train_data = full_data[train_idx]
eval_data = full_data[eval_idx]

train_ds = TrajectoryDataset(data=train_data)
eval_ds = TrajectoryDataset(data=eval_data)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Test set: past windows only, no ground-truth future
test_ds = TrajectoryDataset(test_input, is_test=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Log dataset information
writer.add_text('dataset_info', f"Train samples: {len(train_ds)}, Eval samples: {len(eval_ds)}, Test samples: {len(test_ds)}")


Loading data...


In [None]:
# Create SIMPL model
print(f"Creating model on device: {device}")
model = SIMPLModel(
    feature_dim=6,
    d_model=128,
    nhead=8,
    num_layers_temporal=2,
    num_layers_social=2,
    dim_feedforward=256,
    T_past=50,
    T_future=60,
    polynomial_degree=5,
    dropout=0.1
).to(device)

# Log model architecture and parameters
writer.add_text('model_architecture', str(model))
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
writer.add_text('model_params', f"Total trainable parameters: {total_params}")

# Add model graph to tensorboard using a dummy batch:
# 2 scenes x 50 agents x 50 past steps x 6 features, all agents marked present.
dummy_input = (
    torch.zeros(2, 50, 50, 6, device=device),  # past
    torch.ones(2, 50, dtype=torch.bool, device=device)  # mask
)
try:
    writer.add_graph(model, dummy_input)
except Exception as e:
    # The graph is an optional visualisation, so any failure is reported and
    # skipped. `Exception` is caught on purpose: the "Tracing failed sanity
    # checks!" error seen with this model is raised by the tracer as its own
    # exception type (torch.jit.TracingCheckError) rather than a RuntimeError,
    # so a narrower handler would let it abort the cell.
    print("⚠️ Skipping add_graph():", e)

Creating model on device: cpu
Failed to add model graph to TensorBoard: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
	Graph diff:
		  graph(%self.1 : __torch__.SIMPLModel,
		        %past : Tensor,
		        %agent_mask.1 : Tensor):
		    %bernstein_layer : __torch__.BernsteinLayer = prim::GetAttr[name="bernstein_layer"](%self.1)
		    %control_point_predictor : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="control_point_predictor"](%self.1)
		    %social_encoders : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="social_encoders"](%self.1)
		    %_1.13 : __torch__.SymmetricAttention = prim::GetAttr[name="1"](%social_encoders)
		    %social_encoders.1 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="social_encoders"](%self.1)
		    %_0.9 : __torch__.SymmetricAttention = prim::GetAttr[name="0"](%social_encoders.1)
		    %rpe_generator : __torch__.torch.nn.modules.container.Sequential = prim::

In [33]:
# AdamW with light weight decay; the learning rate follows one cosine
# half-period over the full training budget, decaying from `lr` to `lr / 10`.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=epochs,
    eta_min=lr / 10,
)


In [None]:
# Training
print("Starting training...")
os.makedirs("checkpoints", exist_ok=True)
start_epoch = 1
best_val_loss = float('inf')
global_step = 0

# Early stopping: stop after this many consecutive epochs in which the
# validation loss failed to improve on the reference loss by at least the
# given relative margin.
early_stop_patience = 100
early_stop_min_rel_improvement = 0.01


def _checkpoint_payload(epoch, val_loss, val_metrics):
    """Build the dict written by torch.save for the current training state.

    Scalar metrics are cast to built-in floats so the checkpoint holds no
    NumPy scalar objects; the per-sample 'detailed_errors' list is kept as is.
    """
    stored_metrics = {
        key: (value if key == 'detailed_errors' else float(value))
        for key, value in val_metrics.items()
    }
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'val_loss': val_loss,
        'val_metrics': stored_metrics,
        'global_step': global_step,
        'hparams': hparams
    }


# Load checkpoint if exists
# NOTE: `checkpoint_path` holds the best model so far, so resuming continues
# from the best epoch rather than from the most recent one.
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # Older checkpoints were written without scheduler state; for those the
    # cosine schedule restarts from its first step.
    if 'scheduler_state_dict' in ckpt:
        lr_scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_val_loss = ckpt.get('val_loss', float('inf'))
    global_step = ckpt.get('global_step', 0)
    print(f"✅ Resumed from {checkpoint_path} (epoch {start_epoch - 1})")

# Record training start time
training_start_time = time.time()

# Log the training loop start
writer.add_text('training_info', f"Training started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Early-stopping bookkeeping
early_stop_ref_loss = best_val_loss
epochs_without_improvement = 0

# Defined up front so the interrupt handler can report it even when the
# interrupt arrives before the first iteration starts.
epoch = start_epoch - 1

# Training loop
try:
    for epoch in range(start_epoch, epochs + 1):
        epoch_start_time = time.time()

        # Train for one epoch
        model, global_step = train(
            model, train_loader, optimizer, device,
            num_epochs=1, lr_scheduler=lr_scheduler,
            writer=writer, global_step=global_step
        )

        # Evaluate on validation set
        val_metrics = evaluate(
            model, eval_loader, device,
            writer=writer, global_step=global_step
        )
        val_loss = val_metrics['loss']

        epoch_time = time.time() - epoch_start_time

        # The loss shown here is the validation loss; train() prints its own.
        print(f"Epoch {epoch}/{epochs} | "
                f"Val Loss: {val_metrics['loss']:.6f} | "
                f"Val ADE: {val_metrics['ade']:.4f} | "
                f"Val FDE: {val_metrics['fde']:.4f} | "
                f"Time: {epoch_time:.2f}s")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(_checkpoint_payload(epoch, best_val_loss, val_metrics), checkpoint_path)
            print(f"✅ Best model saved at epoch {epoch} (val loss: {best_val_loss:.6f})")
            writer.add_text('checkpoints', f"New best model at epoch {epoch} with val_loss: {best_val_loss:.6f}")

        # Save checkpoint every 50 epochs
        if epoch % 50 == 0:
            checkpoint_file = f'checkpoints/simpl_ckpt_epoch_{epoch:04d}.pt'
            torch.save(_checkpoint_payload(epoch, val_loss, val_metrics), checkpoint_file)
            print(f"🧪 Checkpoint saved at {checkpoint_file}")
            writer.add_text('checkpoints', f"Periodic checkpoint at epoch {epoch}")

        # Early stopping check - count epochs since the last significant
        # improvement. Comparing against a separate reference loss (instead of
        # best_val_loss, which was just updated above) keeps a fresh best from
        # being counted as "no improvement".
        if val_loss < early_stop_ref_loss * (1.0 - early_stop_min_rel_improvement):
            early_stop_ref_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stop_patience:
            print(f"Early stopping triggered. No significant improvement for 100 epochs.")
            writer.add_text('training_info', f"Early stopping at epoch {epoch}")
            break

except KeyboardInterrupt:
    print("Training interrupted by user")
    writer.add_text('training_info', f"Training interrupted at epoch {epoch}")

finally:
    # Calculate total training time
    total_training_time = time.time() - training_start_time
    print(f"Total training time: {total_training_time:.2f}s")
    writer.add_text('training_info', f"Training completed/interrupted after {total_training_time:.2f}s")

# Generate predictions using best model
print("Generating predictions using best model...")
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state_dict'])
else:
    # No best checkpoint was ever written (e.g. interrupted during the first
    # epoch); fall back to the weights currently in memory.
    print(f"⚠️ No checkpoint found at {checkpoint_path}; predicting with current model weights.")
preds = predict(model, test_loader, device)

# Save predictions
np.savetxt(output_csv, preds.reshape(-1, 2), delimiter=',')
print(f"Predictions saved to {output_csv}")

# Close tensorboard writer
writer.close()

Starting training...
