In [25]:
import os
import re
import glob
import math
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [26]:
# translation and rotation invariance

def align_future(future, center, theta):
    xy = future - center
    c, s = np.cos(-theta), np.sin(-theta)
    x_new = xy[...,0]*c - xy[...,1]*s
    y_new = xy[...,0]*s + xy[...,1]*c
    return np.stack([x_new, y_new], axis=-1)

# converts to a relative coordinate system, so model can focus on the patterns
def invariance_transform(past, accel_dt=None):
    """
    Convert a scene to an ego-centric, velocity-aligned frame.

    Args
    ----
    past : (A, T, 6) float array
        [:,:,0:2] = x, y
        [:,:,2:4] = vx, vy
        [:,:,4]   = heading  (rad)
        [:,:,5]   = type_id  (int)
    accel_dt : float or None
        Sampling period in seconds.  If given, an acceleration
        channel is added (Δv / Δt) so output has 8 features.
        If None, acceleration is omitted and output has 6 features.

    Returns
    -------
    aligned : (A, T, 7 or 9) float array
    center  : (2,)            ego’s last (x,y) in world frame
    theta   : float           ego’s last heading (rad)
    """
    A, T, F = past.shape
    assert F == 6, f"expected feat_dim = 6, got {F}"

    pos     = past[..., 0:2]          # (A,T,2)
    vel     = past[..., 2:4]          # (A,T,2)
    heading = past[..., 4]            # (A,T)
    obj_id  = past[..., 5].astype(int)  # keep as float for stacking

    # --- translate so ego’s last position is origin ------------------
    center = pos[0, -1].copy()        # (2,)
    pos_t  = pos - center             # broadcasting (A,T,2) - (2,)

    # --- rotate so ego’s last heading is +X --------------------------
    theta = heading[0, -1]            # scalar
    c, s  = np.cos(-theta), np.sin(-theta)

    # Rotate vectors
    R = np.array([[c, -s],
                  [s,  c]])           # 2×2
    pos_r = pos_t @ R.T               # (A,T,2)
    vel_r = vel   @ R.T               # (A,T,2)

    # --- optional acceleration --------------------------------------
    if accel_dt is not None:
        inv_dt = 1.0 / accel_dt
        acc_r = np.zeros_like(vel_r)
        acc_r[:, 1:] = (vel_r[:, 1:] - vel_r[:, :-1]) * inv_dt
        features = 9
    else:
        features = 7

    # --- relative heading -------------------------------------------
    heading_rel = ((heading - theta + np.pi) % (2*np.pi)) - np.pi  # (A,T)
    heading_cos = np.cos(heading_rel)  # (A,T)
    heading_sin = np.sin(heading_rel)  # (A,T)
    heading_vec = np.stack([heading_cos, heading_sin], axis=-1)  # (A,T,2)

    # --- stack output -----------------------------------------------
    aligned = np.zeros((A, T, features), dtype=past.dtype)
    aligned[..., 0:2] = pos_r
    aligned[..., 2:4] = vel_r
    if accel_dt is not None:
        aligned[..., 4:6] = acc_r
        aligned[..., 6:8]   = heading_vec
        aligned[..., 8]   = obj_id
    else:
        aligned[..., 4:6]   = heading_vec
        aligned[..., 6]   = obj_id

    return aligned, center, theta

# batch inverse transform for use outside in prediction, we only care about position inverse rotation and translation
def inverse_transform(pred, centers, thetas):
    """
    Bring aligned predictions back to world coordinates.

    Args
    ----
    pred    : (..., T, 2)  aligned positions
    centers : (..., 2)     translation(s) subtracted in forward pass
    thetas  : (...)        rotation angle(s) (rad), same leading dims as centers

    Returns
    -------
    world : (..., T, 2)  positions in world frame
    """
    pred    = np.asarray(pred)
    centers = np.asarray(centers)
    thetas  = np.asarray(thetas)

    # Bring everything to shape (..., T, 2)
    # Allow leading batch dims of arbitrary rank
    # Broadcasting handles scalars automatically.
    c = np.cos(thetas)[..., None]      # (..., 1)
    s = np.sin(thetas)[..., None]      # (..., 1)

    # Rotate back
    x, y = pred[..., 0], pred[..., 1]
    x_w  = x * c - y * s
    y_w  = x * s + y * c
    world = np.stack([x_w, y_w], axis=-1)  # (..., T, 2)

    # Translate back
    world += centers[..., None, :]         # broadcast center over T

    return world

In [27]:
## test
A, T = 5, 10
past = np.random.randn(A, T, 6).astype(np.float32)
past[..., 5] = np.random.randint(0, 4, size=(A, T))  # random type IDs

aligned, center, theta = invariance_transform(past, accel_dt=0.1)
pred_world = inverse_transform(aligned[0, :, :2], center, theta)  # ego track back to world

assert np.allclose(pred_world, past[0, :, :2], atol=1e-5)
print("✓ forward + inverse round-trip OK")

✓ forward + inverse round-trip OK


In [28]:
def collate_fn(batch):
    l = len(batch[0])
    if l == 3:
        # train: (past,mask,future)
        pasts, masks, futures = zip(*batch)
        return (
            torch.stack(pasts),
            torch.stack(masks),
            torch.stack(futures),
        )
    elif l == 4:
        # test: (past,mask,center,theta)
        pasts, masks, centers, thetas = zip(*batch)
        return (
            torch.stack(pasts),
            torch.stack(masks),
            torch.stack(centers),      # shape (B,2)
            torch.stack(thetas),       # shape (B,)
        )
    else:
        raise ValueError(f"Unrecognized sample of length {l}")
    

class TrajectoryDataset(Dataset):
    def __init__(self, input_path=None, data=None, T_past=50, T_future=60, is_test=False):
        if data is not None:
            self.data = data
        else:
            npz = np.load(input_path)
            self.data = npz['data']
        self.T_past = T_past
        self.T_future = T_future
        self.is_test = is_test

        # Calculate normalization statistics from the past data
        self.calculate_normalization_stats()

        
    def calculate_normalization_stats(self):
        """Calculate mean and std for efficient normalization"""
        #align past data
        all_pos = []
        all_vel = []
        for scene in self.data:
            past = scene[:, :self.T_past, :].copy()
            past_aligned, _, _ = invariance_transform(past)

            # collect positions & velocities across all agents & all time-steps
            all_pos.append(past_aligned[..., :2].reshape(-1, 2))
            all_vel.append(past_aligned[..., 2:4].reshape(-1, 2))

        all_pos = np.concatenate(all_pos, axis=0)
        all_vel = np.concatenate(all_vel, axis=0)

        # # now compute statistics on the aligned data
        # self.pos_mean = all_pos.mean(axis=0)
        # self.pos_std  = np.maximum(all_pos.std(axis=0), 1e-6)
        # self.vel_mean = all_vel.mean(axis=0)
        # self.vel_std  = np.maximum(all_vel.std(axis=0), 1e-6)

        # Only consider non-zero values for position and velocity
        #positions = self.data[..., :2]  # x, y positions
        mask = np.abs(all_pos).sum(axis=-1) > 0
        
        if mask.sum() > 0:
            valid_positions = all_pos[mask]
            self.pos_mean = valid_positions.mean(axis=0)
            self.pos_std = valid_positions.std(axis=0)
            
            # Ensure std is not zero to avoid division by zero
            self.pos_std = np.maximum(self.pos_std, 1e-6)
        else:
            self.pos_mean = np.zeros(2)
            self.pos_std = np.ones(2)
            
        # Same for velocities
        #velocities = self.data[..., 2:4]  # vx, vy velocities
        mask = np.abs(all_vel).sum(axis=-1) > 0
        
        if mask.sum() > 0:
            valid_velocities = all_vel[mask]
            self.vel_mean = valid_velocities.mean(axis=0)
            self.vel_std = valid_velocities.std(axis=0)
            self.vel_std = np.maximum(self.vel_std, 1e-6)
        else:
            self.vel_mean = np.zeros(2)
            self.vel_std = np.ones(2)
            
    def normalize_features(self, features):
        """Normalize features efficiently"""
        normalized = features.copy()
        # Normalize positions (x, y)
        normalized[..., 0:2] = (features[..., 0:2] - self.pos_mean) / self.pos_std
        # Normalize velocities (vx, vy)
        normalized[..., 2:4] = (features[..., 2:4] - self.vel_mean) / self.vel_std
        # Normalize acceleration (ax, ay) # it's not present RN
        # if features.shape[-1] >= 6:  # If acceleration is present
        #     normalized[..., 4:6] = (features[..., 4:6] - self.vel_mean) / self.vel_std
        return normalized
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]  # (num_agents, T, features) per scene calculations
        
        # Extract past trajectory
        past = scene[:, :self.T_past, :].copy()  # (num_agents, T_past, features)

        
        #TODO not only is overwriting heading and object features, but also its calculating acceleration based on velocity we are about to transform

        # # If acceleration is not already in the data, calculate it
        # if past.shape[-1] < 6:
        #     # Original features are likely x, y, vx, vy
        #     # Calculate acceleration from velocity if not already present
        #     num_agents, T, _ = past.shape
            
        #     # Create a new array with extra space for acceleration
        #     past_with_accel = np.zeros((num_agents, T, 6))
            
        #     # Copy existing features
        #     past_with_accel[:, :, :past.shape[-1]] = past
            
        #     # If velocity exists, calculate acceleration as the derivative of velocity
        #     if past.shape[-1] >= 4:  # If we have velocity
        #         # Calculate acceleration (dvx, dvy) by differentiating velocity
        #         accel = np.zeros((num_agents, T, 2))
        #         accel[:, 1:, :] = past[:, 1:, 2:4] - past[:, :-1, 2:4]  # Simple finite difference
        #         past_with_accel[:, :, 4:6] = accel
            
        #     past = past_with_accel

        # Apply translation + rotation invariance per scene 
        # (shifts ego → origin & rotates so ego’s heading is +x)
        # note acceleration is included in the past data this time
        past_aligned, center, theta = invariance_transform(past)
        
        # Normalize features
        past_aligned_normalized = self.normalize_features(past_aligned)
        
        # Create mask for valid agents (based on position)
        mask = np.sum(np.abs(past[:, :, :2]), axis=(1, 2)) > 0
        
        # For training data, also extract and normalize future trajectory of ego vehicle
        if not self.is_test and scene.shape[1] >= self.T_past + self.T_future:
            future_raw = scene[0, self.T_past:self.T_past+self.T_future, :2]  # Ego vehicle future (x, y)
            # align future ego to the same reference frame, then normalize
            future_aligned = align_future(future_raw, center, theta)
            future_aligned_normalized = (future_aligned - self.pos_mean) / self.pos_std
            
            return (
                torch.tensor(past_aligned_normalized, dtype=torch.float32),
                torch.tensor(mask, dtype=torch.bool),
                torch.tensor(future_aligned_normalized, dtype=torch.float32)
            )
        
        # For test data, only return aligned and normalized past
        return (
            torch.tensor(past_aligned_normalized, dtype=torch.float32),
            torch.tensor(mask, dtype=torch.bool),
            torch.tensor(center, dtype=torch.float32),
            torch.tensor(theta, dtype=torch.float32)
        )
    
    def denormalize_prediction(self, prediction):
        """Convert normalized predictions back to original scale"""
        return prediction * self.pos_std + self.pos_mean


In [29]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class AgentTypeEmbedding(nn.Module):
    def __init__(self, num_types=10, d_model=128):
        super().__init__()
        self.embedding = nn.Embedding(num_types, d_model)
        
    def forward(self, x):
        # Use default type if type information is not available
        if x.shape[-1] == 6:  # If we only have x, y, vx, vy, ax, ay, heading, object type
            # Create default type tensor (all zeros)
            obj_type = torch.zeros(x.shape[:-1], dtype=torch.long, device=x.device)
        else:
            obj_type = x[..., -1].long()
        return self.embedding(obj_type)

class ImprovedTrajectoryTransformer(nn.Module):
    def __init__(self, feature_dim=6, d_model=256, nhead=8,
                 num_layers=4, dim_feedforward=512, 
                 T_past=50, T_future=60, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.T_past = T_past
        self.T_future = T_future
        
        # Feature embedding for positions, velocities, accelerations
        self.feature_embed = nn.Linear(feature_dim, d_model)
        
        # Object type embedding
        self.type_embedding = AgentTypeEmbedding(num_types=10, d_model=d_model)
        
        # Positional encoding for timesteps
        self.pos_encoding = PositionalEncoding(d_model)
        
        # Layer normalization
        self.norm = nn.LayerNorm(d_model)
        
        # Transformer encoder for temporal relations
        temporal_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
            batch_first=False
        )
        self.temporal_encoder = nn.TransformerEncoder(
            temporal_encoder_layer, 
            num_layers=num_layers//2
        )
        
        # Transformer encoder for social relations
        social_encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
            batch_first=False
        )
        self.social_encoder = nn.TransformerEncoder(
            social_encoder_layer, 
            num_layers=num_layers//2
        )
        
        # Output MLP
        self.prediction_head = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, dim_feedforward // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward // 2, 2 * T_future)
        )
        
    def forward(self, past, agent_mask):
        B, N, T, F = past.shape  # Batch, Num_agents, Time, Features

        assert F >= 7, f"Expected at least 7 features, got {F}"
        
        # Embed all features directly
        features_flat = past.reshape(B * N * T, F)
        feature_embedding = self.feature_embed(features_flat) #project to higher space
        feature_embedding = feature_embedding.reshape(B, N, T, self.d_model)
        
        # Get object type embedding
        type_embedding = self.type_embedding(past)  # B, N, T, d_model
        
        # Combine embeddings
        combined_embedding = feature_embedding + type_embedding
        
        # Reshape for temporal transformer: (T, B*N, d_model)
        temporal_input = combined_embedding.permute(2, 0, 1, 3).reshape(T, B*N, self.d_model)
        
        # Add positional encoding
        temporal_input = self.pos_encoding(temporal_input)
        
        # Apply temporal transformer
        temporal_output = self.temporal_encoder(temporal_input)
        
        # Get the last temporal state for each agent
        agent_features = temporal_output[-1].reshape(B, N, self.d_model)  # B, N, d_model
        
        # Make sure there's at least one valid agent per batch
        if (~agent_mask).all(dim=1).any():
            fallback_mask = agent_mask.clone()
            fallback_mask[:, 0] = True  # At least use ego vehicle
            agent_mask = torch.where(agent_mask.sum(dim=1, keepdim=True) == 0, fallback_mask, agent_mask)
        
        # Prepare for social transformer: (N, B, d_model)
        social_input = agent_features.permute(1, 0, 2) #want agent features back in the first dim, not time
        
        # Apply social transformer with masking
        social_output = self.social_encoder(social_input, src_key_padding_mask=~agent_mask)
        
        # Extract ego vehicle embedding
        ego_embedding = social_output[0]  # B, d_model
        
        # Apply prediction head
        trajectory_flat = self.prediction_head(ego_embedding)  # B, 2*T_future
        
        # Reshape to (Batch, Time, XY)
        predictions = trajectory_flat.reshape(B, self.T_future, 2)
        
        return predictions

In [42]:
def train_epoch(model, dataloader, optimizer, device, clip_grad=.3):
    model.train()
    total_loss = 0.0
    criterion = nn.SmoothL1Loss()
    
    for batch in dataloader:
        past, mask, future = [x.to(device) for x in batch]
        
        optimizer.zero_grad()
        pred = model(past, mask)
        
        loss = criterion(pred, future)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
        
        optimizer.step()
        total_loss += loss.item() * past.size(0)
    
    return total_loss / len(dataloader.dataset)

def evaluate(model, val_loader, device):
    model.eval()
    total_loss = 0.0
    criterion = nn.MSELoss()
    
    with torch.no_grad():
        for batch in val_loader:
            past, mask, future = [x.to(device) for x in batch]
            pred = model(past, mask)
            loss = criterion(pred, future)
            total_loss += loss.item() * past.size(0)
    
    return total_loss / len(val_loader.dataset)

@torch.inference_mode() 
def  predict(model, test_loader, test_dataset, device):
    model.eval()
    model.to(device)
    all_preds = []
    
    for past, mask, centers, thetas in test_loader:
        past = past.to(device, non_blocking=True).float()
        mask = mask.to(device, non_blocking=True)
        
        #predict in normalized aligned space
        pred_norm = model(past, mask).cpu().numpy()

        #undo normalization (still aligned)
        pred_aligned = test_dataset.denormalize_prediction(pred_norm)

        #undo relative alignment -> output world coords
        pred_world = inverse_transform(
            pred_aligned, 
            centers.numpy(), 
            thetas.numpy()
        )
        
        all_preds.append(pred_world)
    
    return np.concatenate(all_preds, axis=0)

def get_latest_checkpoint(folder):
    files = glob.glob(os.path.join(folder, "ckpt_epoch_*.pt"))
    if not files:
        return None
    return max(files, key=lambda f: int(re.findall(r"ckpt_epoch_(\d+)", f)[0]))

In [35]:
train_input = 'data/train.npz'
test_input = 'data/test_input.npz'
output_csv = 'predictions.csv'
checkpoint_path = 'best_model.pt'
checkpoints_dir = 'checkpoints'

# Hyperparameters
batch_size = 64
lr = 1e-4
weight_decay = 1e-5
epochs = 1000
patience = 15  # Early stopping patience
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else 'cpu')

print(f"Using device: {device}")

Using device: cuda


In [36]:

# Load data
full_data = np.load(train_input)['data']

# Split into train and validation (7:3)
num_samples = len(full_data)
num_train = int(0.8 * num_samples)
perm = np.random.permutation(num_samples)
train_idx = perm[:num_train]
val_idx = perm[num_train:]

train_data = full_data[train_idx]
val_data = full_data[val_idx]

# Create datasets with normalization
train_ds = TrajectoryDataset(data=train_data)
val_ds = TrajectoryDataset(data=val_data)

# Create test dataset using the same normalization stats as training
test_ds = TrajectoryDataset(input_path=test_input, is_test=True)
# Copy normalization stats from train_ds
test_ds.pos_mean = train_ds.pos_mean
test_ds.pos_std = train_ds.pos_std
test_ds.vel_mean = train_ds.vel_mean
test_ds.vel_std = train_ds.vel_std

# Create data loaders
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [37]:
b = next(iter(train_loader))

print("First sample:", b[0])

First sample: tensor([[[[-1.4462e-02, -1.6054e-04,  1.3743e+00,  ...,  9.9999e-01,
           -5.2109e-03,  0.0000e+00],
          [-1.4305e-02, -1.6142e-04,  1.3743e+00,  ...,  9.9999e-01,
           -5.2162e-03,  0.0000e+00],
          [-1.4113e-02, -1.6244e-04,  1.3791e+00,  ...,  9.9999e-01,
           -5.2020e-03,  0.0000e+00],
          ...,
          [ 5.8963e-04, -1.3641e-04,  1.9003e+00,  ...,  1.0000e+00,
            6.1093e-04,  0.0000e+00],
          [ 9.6926e-04, -1.3563e-04,  1.9484e+00,  ...,  1.0000e+00,
            3.1902e-04,  0.0000e+00],
          [ 1.3501e-03, -1.3493e-04,  1.9449e+00,  ...,  1.0000e+00,
            0.0000e+00,  0.0000e+00]],

         [[ 1.2238e-02,  2.4020e-03, -1.5797e+00,  ..., -9.9999e-01,
            3.4182e-03,  4.0000e+00],
          [ 1.2084e-02,  2.3963e-03, -1.5661e+00,  ..., -1.0000e+00,
            2.6245e-03,  4.0000e+00],
          [ 1.1900e-02,  2.3905e-03, -1.5532e+00,  ..., -1.0000e+00,
            1.8207e-03,  4.0000e+00],
      

In [38]:
# Create model, optimizer, and scheduler
model = ImprovedTrajectoryTransformer(feature_dim=7, dropout=.3).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
warm_up_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warm_up_epochs, eta_min=1e-6)
warm_up_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: (epoch + 1) / warm_up_epochs if epoch < warm_up_epochs else 1)
os.makedirs(checkpoints_dir, exist_ok=True)

# Training setup
start_epoch = 1
best_val_loss = float('inf')
no_improve_epochs = 0

# Try to load checkpoint
latest_ckpt = get_latest_checkpoint(checkpoints_dir)
if latest_ckpt:
    print(f"Loading checkpoint: {latest_ckpt}")
    ckpt = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    best_val_loss = ckpt.get('val_loss', float('inf'))
    print(f"✅ Resumed from epoch {start_epoch - 1} with val_loss={best_val_loss:.6f}")

writer = SummaryWriter(log_dir="runs/exp1")   # creates runs/exp1/*

# Training loop
print(f"Starting training from epoch {start_epoch}")
for epoch in range(start_epoch, epochs + 1):
    # Train for one epoch
    train_loss = train_epoch(model, train_loader, optimizer, device)
    
    # Evaluate on validation set
    val_loss = evaluate(model, val_loader, device)
    
    # Update learning rate
    if epoch <= warm_up_epochs:
        warm_up_scheduler.step()
    else:
        scheduler.step()
    
    # Print progress
    print(f"Epoch {epoch}/{epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")
    writer.add_scalar("Loss/Train", train_loss, epoch)
    writer.add_scalar("Loss/Val",   val_loss,   epoch)
    
    # Save best model
    if val_loss <= best_val_loss and epoch > warm_up_epochs:
        best_val_loss = val_loss
        no_improve_epochs = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': best_val_loss
        }, checkpoint_path)
        print(f"✅ Best model saved at epoch {epoch} (val loss: {best_val_loss:.6f})")
    else:
        no_improve_epochs += 1
    
    # Save periodic checkpoint
    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss
        }, f'{checkpoints_dir}/ckpt_epoch_{epoch:04d}.pt')
        print(f"🧪 Checkpoint saved at {checkpoints_dir}/ckpt_epoch_{epoch:04d}.pt")
    
    # Early stopping
    # if no_improve_epochs >= patience:
    #     print(f"Early stopping triggered after {epoch} epochs")
    #     break
writer.close()



Starting training from epoch 1
Epoch 1/1000 | Train Loss: 0.000842 | Val Loss: 0.000570
Epoch 2/1000 | Train Loss: 0.000284 | Val Loss: 0.000513
Epoch 3/1000 | Train Loss: 0.000182 | Val Loss: 0.000488
Epoch 4/1000 | Train Loss: 0.000131 | Val Loss: 0.000501
Epoch 5/1000 | Train Loss: 0.000095 | Val Loss: 0.000493
Epoch 6/1000 | Train Loss: 0.000071 | Val Loss: 0.000477
✅ Best model saved at epoch 6 (val loss: 0.000477)
Epoch 7/1000 | Train Loss: 0.000055 | Val Loss: 0.000473
✅ Best model saved at epoch 7 (val loss: 0.000473)
Epoch 8/1000 | Train Loss: 0.000045 | Val Loss: 0.000470
✅ Best model saved at epoch 8 (val loss: 0.000470)
Epoch 9/1000 | Train Loss: 0.000036 | Val Loss: 0.000480
Epoch 10/1000 | Train Loss: 0.000029 | Val Loss: 0.000480
🧪 Checkpoint saved at checkpoints/ckpt_epoch_0010.pt
Epoch 11/1000 | Train Loss: 0.000024 | Val Loss: 0.000472
Epoch 12/1000 | Train Loss: 0.000019 | Val Loss: 0.000479
Epoch 13/1000 | Train Loss: 0.000015 | Val Loss: 0.000487
Epoch 14/1000 | Tr

KeyboardInterrupt: 

In [44]:
# Load best model for prediction
print("Loading best model for prediction...")
model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state_dict'])

# Generate predictions
print("Generating predictions...")
preds = predict(model, test_loader, test_ds, device)

# Flatten predictions to match submission format (2100*60, 2)
preds_flat = preds.reshape(-1, 2)

# Create ID column (required for submission)
ids = np.arange(len(preds_flat))

# Save predictions to CSV
output = np.column_stack((ids, preds_flat))
header = "index,x,y"
np.savetxt(output_csv, output, delimiter=',', header=header, comments='', fmt=['%d', '%.6f', '%.6f'])
print(f"Predictions saved to {output_csv}")

Loading best model for prediction...
Generating predictions...
Predictions saved to predictions.csv


In [68]:

# assume: TrajectoryDataset, invariance_transform, inverse_transform,
#         align_future, collate_fn, etc. are already defined above

ds = TrajectoryDataset(input_path='data/train.npz', is_test=False)

# pick one scene
idx = 0
scene = ds.data[idx]

# 1) PAST reconstruction test
# — slice ALL features, not just :2
raw_past_feats = scene[:, :ds.T_past, :].copy()           # (A, T_past, F>=4)
aligned_feats, center, theta = invariance_transform(raw_past_feats) #one aligned scene

# now inverse only the XY channels
recon_xy = inverse_transform(aligned_feats[..., :2], center, theta)
err_past = np.max(np.abs(recon_xy - raw_past_feats[..., :2]))
print(f'Past XY reconstruction max‐error: {err_past:.3e}')

if scene.shape[1] >= ds.T_past + ds.T_future:
    raw_fut = scene[0, ds.T_past:ds.T_past+ds.T_future, :2].copy()

    # 1) align → inverse (should reconstruct raw_fut up to numerical noise)
    fut_aln   = align_future(raw_fut, center, theta)
    recon_if  = inverse_transform(fut_aln, center, theta)
    err_if    = np.max(np.abs(recon_if - raw_fut))
    print(f'Future invariance max‐error: {err_if:.3e}')

    # 2) normalization → denormalization (should also be tiny)
    fut_norm  = (fut_aln - ds.pos_mean) / ds.pos_std
    fut_den   = ds.denormalize_prediction(fut_norm)
    err_scale = np.max(np.abs(fut_den - fut_aln))
    print(f'Future norm/denorm max‐error: {err_scale:.3e}')

# 3) BATCH shapes test
ds_test = TrajectoryDataset(input_path='data/train.npz', is_test=True)
loader  = DataLoader(ds_test, batch_size=8, collate_fn=collate_fn)
past_b, mask_b, centers_b, thetas_b = next(iter(loader))
print('Batch shapes:')
print('  past:',     past_b.shape)       # (8, A, T_past, F)
print('  mask:',     mask_b.shape)       # (8, A)
print('  centers:',  centers_b.shape)    # (8, 2)
print('  thetas:',   thetas_b.shape)     # (8,)

# quick pipeline shape‐check
pred_norm = np.random.randn(8, ds.T_future, 2)
pred_aln  = ds.denormalize_prediction(pred_norm)
pred_w    = inverse_transform(pred_aln, centers_b.numpy(), thetas_b.numpy())
print('  pred_world:', pred_w.shape)    # (8, T_future, 2)

Past XY reconstruction max‐error: 2.842e-14
Future invariance max‐error: 0.000e+00
Future norm/denorm max‐error: 7.105e-15
Batch shapes:
  past: torch.Size([8, 50, 50, 7])
  mask: torch.Size([8, 50])
  centers: torch.Size([8, 2])
  thetas: torch.Size([8])
  pred_world: (8, 60, 2)
