# Exploring Convolutional Social Pooling Model

This notebook allows you to interactively explore the different components of the vehicle trajectory prediction model.

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import sys
import os

# Add parent directory to path to import our modules
sys.path.append(os.path.dirname(os.getcwd()))

from models.encoder_decoder import LSTMEncoder, LSTMDecoder, EncoderDecoder
from models.social_pooling import ConvolutionalSocialPooling, SpatialGrid
from models.trajectory_model import TrajectoryPredictionModel, MultiModalLoss
from data.ngsim_dataset import NGSIMDataset, get_ngsim_dataloaders, collate_fn

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


## 1. Load and Explore the Dataset

In [2]:
# Check for NGSIM data files
ngsim_files = [
    'data/raw/trajectories-0750am-0805am.txt',  # US-101
    'data/raw/trajectories-0400-0415.txt',       # I-80
    'data/raw/sample_trajectory.txt'             # Sample data
]

data_file = None
for file in ngsim_files:
    if os.path.exists(file):
        data_file = file
        print(f"Found data file: {file}")
        break

if data_file is None:
    # Create more comprehensive sample data for demonstration
    print("No NGSIM data found. Creating demonstration dataset...")
    os.makedirs('data/raw', exist_ok=True)
    data_file = 'data/raw/demo_trajectories.txt'
    
    import random
    random.seed(42)
    
    sample_lines = []
    vehicle_id = 1
    frame_offset = 0
    
    # Generate data for multiple vehicles with realistic highway behavior
    for scenario in range(5):  # 5 different traffic scenarios
        base_time = 1118847200.0 + scenario * 100
        
        # Generate 10-15 vehicles per scenario
        num_vehicles = random.randint(10, 15)
        
        for v in range(num_vehicles):
            # Each vehicle has 50-100 frames of data
            num_frames = random.randint(50, 100)
            
            # Initial position and velocity
            init_x = random.uniform(0, 50)
            init_y = random.choice([28.0, 35.0, 42.0])  # 3 lanes
            init_vel = random.uniform(25, 35)  # 25-35 ft/s
            
            # Behavior type: straight, lane change, acceleration
            behavior = random.choice(['straight', 'lane_change', 'accelerate'])
            
            for frame in range(num_frames):
                frame_id = frame_offset + frame + 1
                time = base_time + frame * 0.1
                
                if behavior == 'straight':
                    x = init_x + frame * init_vel * 0.1
                    y = init_y + random.gauss(0, 0.02)  # Small noise
                    vel = init_vel + random.gauss(0, 0.1)
                    acc = random.gauss(0, 0.2)
                
                elif behavior == 'lane_change':
                    x = init_x + frame * init_vel * 0.1
                    # Smooth lane change using sigmoid
                    if 20 <= frame <= 40:
                        t = (frame - 20) / 20
                        y = init_y + 7.0 * (3 * t**2 - 2 * t**3)  # Smooth transition
                    else:
                        y = init_y if frame < 20 else init_y + 7.0
                    vel = init_vel + random.gauss(0, 0.1)
                    acc = random.gauss(0, 0.2)
                
                else:  # accelerate
                    acc_rate = 0.5 if frame < 30 else -0.3
                    vel = init_vel + frame * acc_rate * 0.1
                    vel = max(20, min(40, vel))  # Clamp velocity
                    x = init_x + sum([(init_vel + i * acc_rate * 0.1) * 0.1 for i in range(frame)])
                    y = init_y + random.gauss(0, 0.02)
                    acc = acc_rate + random.gauss(0, 0.1)
                
                # Global coordinates (arbitrary offset)
                global_x = 6042300 + x
                global_y = 1873340 + y
                
                # Lane ID based on y position
                lane_id = 1 if y < 31.5 else (2 if y < 38.5 else 3)
                
                # Vehicle dimensions
                v_length = random.choice([14.9, 15.2, 16.0])
                v_width = random.choice([6.6, 6.8, 7.0])
                
                line = f"{vehicle_id},{frame_id},{num_frames},{time:.1f},{x:.3f},{y:.3f},"
                line += f"{global_x:.3f},{global_y:.3f},{v_length},{v_width},2,{vel:.2f},{acc:.2f},"
                line += f"{lane_id},0,0,0.0,0.0"
                sample_lines.append(line)
            
            vehicle_id += 1
        
        frame_offset += 100
    
    sample_data = '\n'.join(sample_lines)
    
    with open(data_file, 'w') as f:
        f.write(sample_data)
    
    print(f"Created demonstration dataset with {vehicle_id-1} vehicles")
    print(f"Total frames: {len(sample_lines)}")
    print(f"Saved to: {data_file}")

# Load dataset
print(f"\nLoading dataset from: {data_file}")
dataset = NGSIMDataset(
    data_path=data_file,
    hist_len=30,  # Standard history length
    pred_len=50,   # Standard prediction length
    skip=2,        # Skip every other frame for more samples
    train=True
)

print(f"Dataset size: {len(dataset)} samples")
print(f"History length: {dataset.hist_len} frames")
print(f"Prediction length: {dataset.pred_len} frames")

Found data file: data/raw/sample_trajectory.txt

Loading dataset from: data/raw/sample_trajectory.txt
Dataset size: 0 samples
History length: 30 frames
Prediction length: 50 frames


In [None]:
# Get a sample from the dataset
if len(dataset) > 0:
    sample = dataset[0]
    print("Sample keys:", sample.keys())
    print(f"History shape: {sample['hist'].shape}")
    print(f"Future shape: {sample['fut'].shape}")
    print(f"History velocity shape: {sample['hist_vel'].shape}")
    print(f"Number of neighbors: {len(sample['neighbors'])}")
    
    # Visualize the trajectory
    plt.figure(figsize=(10, 6))
    hist = sample['hist'].numpy()
    fut = sample['fut'].numpy()
    
    plt.plot(hist[:, 0], hist[:, 1], 'b-o', label='History', markersize=8)
    plt.plot(fut[:, 0], fut[:, 1], 'r-s', label='Future (Ground Truth)', markersize=8)
    plt.xlabel('X position (feet)')
    plt.ylabel('Y position (feet)')
    plt.title('Sample Vehicle Trajectory')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.show()
else:
    print("No samples in dataset. Please provide a larger data file.")

## 2. Explore the LSTM Encoder

In [None]:
# Create and test the encoder
encoder = LSTMEncoder(input_dim=2, hidden_dim=128, num_layers=1).to(device)

# Create dummy input
batch_size = 4
seq_len = 30
dummy_hist = torch.randn(batch_size, seq_len, 2).to(device)

# Forward pass
encoder_output, (hidden, cell) = encoder(dummy_hist)

print(f"Input shape: {dummy_hist.shape}")
print(f"Encoder output shape: {encoder_output.shape}")
print(f"Hidden state shape: {hidden.shape}")
print(f"Cell state shape: {cell.shape}")

# Visualize encoder outputs
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Plot hidden state activations
axes[0].imshow(hidden[0].detach().cpu().numpy(), aspect='auto', cmap='coolwarm')
axes[0].set_title('Hidden State Activations')
axes[0].set_xlabel('Hidden Dimension')
axes[0].set_ylabel('Batch Sample')
axes[0].colorbar = plt.colorbar(axes[0].images[0], ax=axes[0])

# Plot temporal evolution of encoder output
sample_output = encoder_output[0].detach().cpu().numpy()
axes[1].imshow(sample_output.T, aspect='auto', cmap='viridis')
axes[1].set_title('Encoder Output Over Time (Sample 0)')
axes[1].set_xlabel('Time Step')
axes[1].set_ylabel('Hidden Dimension')
axes[1].colorbar = plt.colorbar(axes[1].images[0], ax=axes[1])

plt.tight_layout()
plt.show()

## 3. Explore the Convolutional Social Pooling Layer

In [None]:
# Create social pooling layer
social_pooling = ConvolutionalSocialPooling(
    encoder_dim=128,
    grid_size=(13, 3),
    soc_conv_depth=64,
    conv_3x1_depth=16,
    conv_1x1_depth=32
).to(device)

# Create spatial grid
spatial_grid = SpatialGrid(grid_size=(13, 3), grid_extent=(90.0, 21.0))

print("Social Pooling Configuration:")
print(f"  Grid size: {social_pooling.grid_size}")
print(f"  Grid extent: {spatial_grid.grid_extent} feet")
print(f"  Cell size: {spatial_grid.cell_size} feet")
print(f"  Output dimension: {social_pooling.fc_dim}")

In [None]:
# Visualize the spatial grid
fig, ax = plt.subplots(figsize=(10, 4))

# Draw grid
grid_extent = spatial_grid.grid_extent
grid_size = spatial_grid.grid_size

# Create grid lines
for i in range(grid_size[0] + 1):
    x = -grid_extent[0]/2 + i * spatial_grid.cell_size[0]
    ax.axvline(x, color='gray', linestyle='-', alpha=0.3)

for j in range(grid_size[1] + 1):
    y = -grid_extent[1]/2 + j * spatial_grid.cell_size[1]
    ax.axhline(y, color='gray', linestyle='-', alpha=0.3)

# Mark ego vehicle
ax.scatter(0, 0, color='red', s=200, marker='s', label='Ego Vehicle', zorder=5)

# Add some example neighbor vehicles
neighbor_positions = [
    (15, 0),    # Vehicle ahead in same lane
    (30, 0),    # Vehicle further ahead
    (10, 7),    # Vehicle in left lane
    (-15, 0),   # Vehicle behind
    (20, -7),   # Vehicle in right lane
]

for pos in neighbor_positions:
    ax.scatter(pos[0], pos[1], color='blue', s=150, marker='o', alpha=0.7)

ax.set_xlim(-grid_extent[0]/2, grid_extent[0]/2)
ax.set_ylim(-grid_extent[1]/2, grid_extent[1]/2)
ax.set_xlabel('Longitudinal Distance (feet)')
ax.set_ylabel('Lateral Distance (feet)')
ax.set_title('Spatial Grid for Social Pooling (13x3 cells)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

In [None]:
# Test social pooling with dummy data
batch_size = 2
num_vehicles = 5

# Create dummy encoder outputs for all vehicles
dummy_encoder_outputs = torch.randn(num_vehicles, 128, 30).to(device)

# Create dummy neighbor indices (vehicle 0 has vehicles 1,2,3 as neighbors)
neighbor_indices = [
    [1, 2, 3],  # Neighbors for vehicle 0
    [0, 2, 4],  # Neighbors for vehicle 1
]

# Create dummy grid positions
dummy_positions = torch.tensor([
    [6, 1],   # Vehicle 0 position in grid
    [8, 1],   # Vehicle 1
    [10, 0],  # Vehicle 2
    [5, 2],   # Vehicle 3
    [3, 1],   # Vehicle 4
]).to(device)

# Forward pass through social pooling
social_features = social_pooling(
    dummy_encoder_outputs,
    neighbor_indices,
    dummy_positions
)

print(f"Social features shape: {social_features.shape}")
print(f"Social features stats:")
print(f"  Mean: {social_features.mean().item():.4f}")
print(f"  Std: {social_features.std().item():.4f}")
print(f"  Min: {social_features.min().item():.4f}")
print(f"  Max: {social_features.max().item():.4f}")

## 4. Explore the LSTM Decoder

In [None]:
# Create decoder
decoder = LSTMDecoder(
    output_dim=2,
    hidden_dim=128,
    num_layers=1,
    num_modes=6
).to(device)

# Use outputs from encoder
pred_len = 50
batch_size = 2

# Dummy social features and hidden states
dummy_social = torch.randn(batch_size, 64).to(device)
dummy_hidden = (torch.randn(1, batch_size, 128).to(device),
                torch.randn(1, batch_size, 128).to(device))

# Forward pass
predictions, mode_probs = decoder(dummy_social, dummy_hidden, pred_len)

print(f"Predictions shape: {predictions.shape}")
print(f"  (batch_size, pred_len, num_modes, output_dim)")
print(f"Mode probabilities shape: {mode_probs.shape}")
print(f"  (batch_size, pred_len, num_modes)")

# Visualize mode probabilities over time
sample_probs = mode_probs[0].detach().cpu().numpy()

plt.figure(figsize=(12, 4))
plt.imshow(sample_probs.T, aspect='auto', cmap='hot', interpolation='nearest')
plt.colorbar(label='Probability')
plt.xlabel('Time Step')
plt.ylabel('Mode')
plt.title('Mode Probabilities Over Prediction Horizon')
plt.yticks(range(6), [f'Mode {i+1}' for i in range(6)])
plt.show()

In [None]:
# Visualize multi-modal predictions
sample_pred = predictions[0].detach().cpu().numpy()
sample_probs = mode_probs[0].mean(dim=0).detach().cpu().numpy()

plt.figure(figsize=(10, 8))

colors = plt.cm.rainbow(np.linspace(0, 1, 6))

for mode in range(6):
    trajectory = sample_pred[:, mode, :]
    plt.plot(trajectory[:, 0], trajectory[:, 1], 
             color=colors[mode], linewidth=2,
             label=f'Mode {mode+1} (p={sample_probs[mode]:.3f})',
             alpha=0.5 + 0.5 * sample_probs[mode])
    plt.scatter(trajectory[-1, 0], trajectory[-1, 1],
                color=colors[mode], s=100, marker='*')

plt.scatter(0, 0, color='red', s=200, marker='s', label='Start Position', zorder=5)
plt.xlabel('X position (feet)')
plt.ylabel('Y position (feet)')
plt.title('Multi-Modal Trajectory Predictions')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

## 5. Complete Model Pipeline

In [None]:
# Create complete model
model = TrajectoryPredictionModel(
    input_dim=2,
    output_dim=2,
    encoder_dim=128,
    decoder_dim=128,
    num_layers=1,
    num_modes=6,
    grid_size=(13, 3),
    soc_conv_depth=64
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Print model architecture
print("\nModel Architecture:")
for name, module in model.named_children():
    print(f"  {name}: {module.__class__.__name__}")

In [None]:
# Test with a batch from dataloader
from torch.utils.data import DataLoader

if len(dataset) > 0:
    dataloader = DataLoader(
        dataset,
        batch_size=min(4, len(dataset)),
        shuffle=False,
        collate_fn=collate_fn
    )
    
    # Get a batch
    batch = next(iter(dataloader))
    
    hist = batch['hist'].to(device)
    fut = batch['fut'].to(device)
    neighbors = batch['neighbors']
    
    print(f"Batch history shape: {hist.shape}")
    print(f"Batch future shape: {fut.shape}")
    
    # Forward pass
    with torch.no_grad():
        predictions, mode_probs = model(hist, neighbors, pred_len=fut.size(1))
    
    print(f"\nModel output:")
    print(f"  Predictions shape: {predictions.shape}")
    print(f"  Mode probabilities shape: {mode_probs.shape}")
else:
    print("Not enough data for testing. Please provide a larger dataset.")

## 6. Loss Function Analysis

In [None]:
# Create loss function
criterion = MultiModalLoss(num_modes=6, regression_loss_weight=5.0)

if len(dataset) > 0 and 'predictions' in locals():
    # Calculate loss
    loss, metrics = criterion(predictions, mode_probs, fut)
    
    print("Loss Components:")
    print(f"  Total Loss: {loss.item():.4f}")
    print(f"  Classification Loss: {metrics['classification_loss']:.4f}")
    print(f"  Regression Loss: {metrics['regression_loss']:.4f}")
    print(f"\nMetrics:")
    print(f"  Min ADE: {metrics['min_ade']:.4f} feet")
    print(f"  Min FDE: {metrics['min_fde']:.4f} feet")
else:
    print("Please run the previous cells to generate predictions first.")

## 7. Gradient Flow Visualization

In [None]:
# Perform a backward pass to visualize gradients
if len(dataset) > 0 and 'loss' in locals():
    # Clear previous gradients
    model.zero_grad()
    
    # Backward pass
    loss.backward()
    
    # Collect gradient statistics
    grad_stats = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_stats[name] = {
                'mean': param.grad.mean().item(),
                'std': param.grad.std().item(),
                'max': param.grad.max().item(),
                'min': param.grad.min().item()
            }
    
    # Plot gradient magnitudes
    fig, ax = plt.subplots(figsize=(12, 6))
    
    names = list(grad_stats.keys())
    means = [abs(grad_stats[n]['mean']) for n in names]
    
    bars = ax.bar(range(len(names)), means)
    ax.set_xticks(range(len(names)))
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Absolute Mean Gradient')
    ax.set_title('Gradient Magnitudes Across Model Parameters')
    ax.set_yscale('log')
    
    plt.tight_layout()
    plt.show()
    
    print("Parameters with largest gradients:")
    sorted_grads = sorted(grad_stats.items(), key=lambda x: abs(x[1]['mean']), reverse=True)
    for name, stats in sorted_grads[:5]:
        print(f"  {name}: mean={stats['mean']:.6f}, std={stats['std']:.6f}")
else:
    print("Please run the previous cells to calculate loss first.")

## 8. Interactive Trajectory Prediction

In [None]:
def predict_trajectory(model, history_trajectory, neighbor_positions=None):
    """
    Make predictions for a custom trajectory
    
    Args:
        model: Trained model
        history_trajectory: numpy array of shape (hist_len, 2)
        neighbor_positions: list of (x, y) positions for neighbor vehicles
    """
    # Convert to tensor
    hist = torch.FloatTensor(history_trajectory).unsqueeze(0).to(device)
    
    # Create dummy neighbors
    if neighbor_positions is None:
        neighbors = [[]]
    else:
        neighbors = [[{'relative_pos': pos, 'velocity': 30.0} for pos in neighbor_positions]]
    
    # Predict
    with torch.no_grad():
        predictions, mode_probs = model(hist, neighbors, pred_len=50)
    
    return predictions[0].cpu().numpy(), mode_probs[0].cpu().numpy()

# Create a sample trajectory
t = np.linspace(0, 2, 30)
x = t * 15  # Moving forward at 15 ft/s
y = np.sin(t * 2) * 2  # Slight lane change motion
history = np.stack([x, y], axis=1)

# Predict future
pred, probs = predict_trajectory(model, history)

# Visualize
plt.figure(figsize=(12, 6))
plt.plot(history[:, 0], history[:, 1], 'b-o', linewidth=2, label='History', markersize=4)

colors = plt.cm.rainbow(np.linspace(0, 1, 6))
avg_probs = probs.mean(axis=0)

for mode in range(6):
    traj = pred[:, mode, :]
    plt.plot(traj[:, 0] + history[-1, 0], traj[:, 1] + history[-1, 1],
             color=colors[mode], alpha=0.3 + 0.7 * avg_probs[mode],
             linewidth=1.5, label=f'Mode {mode+1} (p={avg_probs[mode]:.3f})')

plt.xlabel('X position (feet)')
plt.ylabel('Y position (feet)')
plt.title('Interactive Trajectory Prediction')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

## 9. Model Parameter Distribution

In [None]:
# Analyze parameter distributions
param_stats = {}
for name, param in model.named_parameters():
    param_np = param.detach().cpu().numpy().flatten()
    param_stats[name] = {
        'mean': param_np.mean(),
        'std': param_np.std(),
        'min': param_np.min(),
        'max': param_np.max(),
        'values': param_np
    }

# Plot histograms for select layers
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

select_params = [
    'encoder_decoder.encoder.lstm.weight_ih_l0',
    'encoder_decoder.decoder.lstm.weight_ih_l0',
    'social_pooling.conv_3x1.weight',
    'social_pooling.conv_1x1.weight',
    'social_pooling.fc.weight',
    'encoder_decoder.decoder.output_layer.weight'
]

for idx, param_name in enumerate(select_params):
    if param_name in param_stats:
        values = param_stats[param_name]['values']
        axes[idx].hist(values, bins=50, alpha=0.7, color='blue', edgecolor='black')
        axes[idx].set_title(param_name.split('.')[-2] + '.' + param_name.split('.')[-1])
        axes[idx].set_xlabel('Parameter Value')
        axes[idx].set_ylabel('Count')
        axes[idx].axvline(values.mean(), color='red', linestyle='--', label=f'Mean: {values.mean():.3f}')
        axes[idx].legend()

plt.suptitle('Parameter Distributions Across Model Layers', fontsize=14)
plt.tight_layout()
plt.show()

## 10. Memory and Computation Analysis

In [None]:
# Analyze model memory usage
def get_model_size(model):
    param_size = 0
    buffer_size = 0
    
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_mb = (param_size + buffer_size) / 1024 / 1024
    return size_mb

model_size = get_model_size(model)
print(f"Model size: {model_size:.2f} MB")

# Analyze layer-wise parameter counts
layer_params = {}
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # Leaf module
        params = sum(p.numel() for p in module.parameters())
        if params > 0:
            layer_params[name] = params

# Sort by parameter count
sorted_layers = sorted(layer_params.items(), key=lambda x: x[1], reverse=True)

print("\nTop 10 layers by parameter count:")
for name, count in sorted_layers[:10]:
    print(f"  {name}: {count:,} parameters")

# Visualize parameter distribution
fig, ax = plt.subplots(figsize=(10, 6))
names = [n.split('.')[-1] for n, _ in sorted_layers[:15]]
counts = [c for _, c in sorted_layers[:15]]

bars = ax.bar(range(len(names)), counts)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=45, ha='right')
ax.set_ylabel('Number of Parameters')
ax.set_title('Parameter Count by Layer (Top 15)')
ax.set_yscale('log')

plt.tight_layout()
plt.show()