In [14]:
import torch
import torch.nn as nn
import math
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [15]:
def collate_fn(batch):
    if len(batch[0]) == 3:  # Training data with future
        pasts, masks, futures = zip(*batch)
        past = torch.stack(pasts)
        mask = torch.stack(masks)
        future = torch.stack(futures)
        return past, mask, future
    else:  # Test data without future
        pasts, masks = zip(*batch)
        past = torch.stack(pasts)
        mask = torch.stack(masks)
        return past, mask
    
class TrajectoryDataset(Dataset):
    def __init__(self, input_path=None, data=None, T_past=50, T_future=60, is_test=False):
        if data is not None:
            self.data = data
        else:
            npz = np.load(input_path)
            self.data = npz['data']
        self.T_past = T_past
        self.T_future = T_future
        self.is_test = is_test
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]  #(num_agents, T, 6)
        
        # past trajectory of all agents
        past = scene[:, :self.T_past, :]
        
        # create mask for valid agents, essentially checking if the agent has any past trajectory
        # no padding in the first two dimensions
        mask = np.sum(np.abs(past[..., :2]), axis=(1, 2)) > 0
        
        # for training data, also extract future trajectory of ego vehicle
        if not self.is_test and scene.shape[1] >= self.T_past + self.T_future:
            future = scene[0, self.T_past:self.T_past+self.T_future, :2]  # Ego vehicle future (x,y)
            return torch.tensor(past, dtype=torch.float32), torch.tensor(mask, dtype=torch.bool), torch.tensor(future, dtype=torch.float32)
        
        # for test data, only return past
        return torch.tensor(past, dtype=torch.float32), torch.tensor(mask, dtype=torch.bool)

In [16]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos_enc = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        exponent = torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        exponent = torch.clamp(exponent, min=-50.0)  # Prevent underflow
        div_term = torch.exp(exponent)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        pos_enc = pos_enc.unsqueeze(1)
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        seq_len = x.size(0)
        return x + self.pos_enc[:seq_len]

In [17]:
class TrajectoryTransformer(nn.Module):
    def __init__(self, feature_dim=6, d_model=128, nhead=8,
                 num_layers_temporal=2, num_layers_social=2,
                 dim_feedforward=256, T_past=50, T_future=60, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.T_past = T_past
        self.T_future = T_future

        self.input_embed = nn.Linear(feature_dim, d_model)
        self.time_pos_enc = PositionalEncoding(d_model, max_len=T_past)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers_temporal)

        social_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.social_encoder = nn.TransformerEncoder(social_layer, num_layers=num_layers_social)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, dim_feedforward // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward // 2, 2 * T_future)
        )

    def forward(self, past, agent_mask):
        B, N, T, F = past.shape
        x = past.view(B * N, T, F).permute(1, 0, 2)
        x = self.input_embed(x)
        x = x / (x.norm(dim=-1, keepdim=True) + 1e-6)  # Normalize
        x = x * math.sqrt(self.d_model)
        x = self.time_pos_enc(x)
        x = self.temporal_encoder(x)
        agent_feats = x[-1].view(B, N, self.d_model)

        if (~agent_mask).all(dim=1).any():
            fallback = agent_mask.clone()
            fallback[:, 0] = True
            agent_mask = torch.where(agent_mask.sum(dim=1, keepdim=True) == 0, fallback, agent_mask)

        scene = agent_feats.permute(1, 0, 2)
        scene = self.social_encoder(scene, src_key_padding_mask=~agent_mask)
        ego_embed = scene[0]
        out = self.mlp(ego_embed)
        preds = out.view(B, self.T_future, 2)
        return preds

In [18]:
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    criterion = nn.MSELoss()
    
    for batch in dataloader:
        past, mask, future = [x.to(device) for x in batch]
        
        assert not torch.isnan(future).any(), "NaNs in target future"
        assert not torch.isinf(future).any(), "Infs in target future"
        
        optimizer.zero_grad()
        pred = model(past, mask)
        
        assert not torch.isnan(pred).any(), "NaNs in prediction"
        assert not torch.isinf(pred).any(), "Infs in prediction"
        
        loss = criterion(pred, future)
        # Calculate loss against ground truth future
        try:
            loss.backward()
        except RuntimeError as e:
            print("Backward failed with loss:", loss)
            for name, param in model.named_parameters():
                if torch.isnan(param).any() or torch.isinf(param).any():
                    print(f"Param {name} has NaN or inf")
            raise e

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item() * past.size(0)
        total_loss += loss.item() * past.size(0)
        
    return total_loss / len(dataloader.dataset)


def evaluate(model, val_loader, device):
    model.eval()
    total_loss = 0.0
    criterion = nn.MSELoss()
    
    with torch.no_grad():
        for batch in val_loader:
            past, mask, future = [x.to(device) for x in batch]
            pred = model(past, mask)
            loss = criterion(pred, future)
            total_loss += loss.item() * past.size(0)
            
    return total_loss / len(val_loader.dataset)


def predict(model, test_loader, device):
    model.eval()
    all_preds = []
    
    with torch.no_grad():
        for batch in test_loader:
            past, mask = [x.to(device) for x in batch]
            pred = model(past, mask)
            all_preds.append(pred.cpu().numpy())
            
    return np.concatenate(all_preds, axis=0)

In [19]:
train_input = 'data/train.npz'
test_input = 'data/test_input.npz'
output_csv = 'predictions.csv'

batch_size = 32
lr = 5e-4
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [20]:
full_data = np.load(train_input)['data']

# Split into train and eval (7:3)
num_samples = len(full_data)
num_train = int(0.7 * num_samples)
perm = np.random.permutation(num_samples)
train_idx = perm[:num_train]
eval_idx = perm[num_train:]

train_data = full_data[train_idx]
eval_data = full_data[eval_idx]

train_ds = TrajectoryDataset(data=train_data)
eval_ds = TrajectoryDataset(data=eval_data)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(eval_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

test_ds = TrajectoryDataset(test_input)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

model = TrajectoryTransformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)



In [None]:
best_val_loss = float('inf')
for epoch in range(1, epochs+1):
    train_loss = train(model, train_loader, optimizer, device)
    val_loss = evaluate(model, eval_loader, device)
    
    print(f"Epoch {epoch}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"Saved best model with val loss: {best_val_loss:.4f}")

# Load best model for testing


Epoch 1/100, Train Loss: 18048059.9314, Val Loss: 5016223.8480
Saved best model with val loss: 5016223.8480
Epoch 2/100, Train Loss: 7529572.0177, Val Loss: 3274130.5310
Saved best model with val loss: 3274130.5310
Epoch 3/100, Train Loss: 6072606.9309, Val Loss: 2916850.4583
Saved best model with val loss: 2916850.4583
Epoch 4/100, Train Loss: 5782672.2326, Val Loss: 2898068.3980
Saved best model with val loss: 2898068.3980
Epoch 5/100, Train Loss: 5582924.3634, Val Loss: 2799922.7217
Saved best model with val loss: 2799922.7217
Epoch 6/100, Train Loss: 5280491.6400, Val Loss: 2516638.2267
Saved best model with val loss: 2516638.2267
Epoch 7/100, Train Loss: 4830185.0674, Val Loss: 2196735.8990
Saved best model with val loss: 2196735.8990
Epoch 8/100, Train Loss: 4386618.4777, Val Loss: 1935618.2093
Saved best model with val loss: 1935618.2093
Epoch 9/100, Train Loss: 4123113.7931, Val Loss: 1863087.8383
Saved best model with val loss: 1863087.8383
Epoch 10/100, Train Loss: 3982542.88

In [None]:
model.load_state_dict(torch.load('best_model.pt'))

<All keys matched successfully>

In [None]:
# Generate predictions on test set
test_preds = predict(model, test_loader, device)
B, T, D = test_preds.shape
flat_preds = test_preds.reshape(B * T, D)
np.savetxt(output_csv, flat_preds, delimiter=',')
print(f"Saved predictions to {output_csv}")

Saved predictions to predictions.csv
