# End-to-End Big Data Bowl Heliocentricity Transformer

This notebook provides a complete pipeline for:
1. **Data Processing**: Load and transform raw CSV data into ML-ready format
2. **Caching**: Save/load processed tensors to avoid reprocessing
3. **Model Training**: Train the Heliocentricity Transformer with CVAE
4. **Evaluation**: Evaluate predictions with metadata tracking
5. **Heliocentricity**: Calculate Heliocentricity scores with play/player cross-referencing

## Key Features
- **Automatic caching**: Processed data saved to `dataset/processed/processed_data.pt`
- **Metadata tracking**: Every prediction linked to game_id, play_id, and player_ids
- **Pretrained weights**: Model weights saved to `dataset/pretrained/best_heliocentricity_model.pt`
- **Cross-referencing**: Easy lookup of predictions by play and player

## Notebook Structure
1. Imports and hyperparameters
2. Data preprocessing functions
3. Model architecture (HeliocentricityTransformer)
4. Loss and inference functions
5. Custom dataset with padding and metadata
6. Data loading (with caching)
7. Training function
8. Evaluation function (with metadata)
9. Training execution
10. Evaluation execution
11. Heliocentricity calculation utilities
12. Heliocentricity computation and analysis

In [1]:
# === End-to-End Big Data Bowl Heliocentricity Transformer ===
# This notebook processes raw data, trains a model, and evaluates predictions

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import Adam
from sklearn.metrics import mean_squared_error

# --- Hyperparameters ---
# All values below are module-level constants read by the cells that follow.

# Data dimensions from Big Data Bowl.
# T_HIST / T_PRED / N_AGENTS are passed to FootballDataset as the pad/truncate
# targets and to the model as fixed layer sizes, so the two must agree.
T_HIST = 25         # Number of historical frames (max in dataset)
T_PRED = 25         # Number of frames to predict (max in dataset)
N_AGENTS = 9        # Actual number of agents per frame in data
# D_AGENT / D_GLOBAL must equal the number of feature columns produced by
# process_raw_data (numeric columns + one-hot columns) — presumably these were
# read off the processed tensors; re-check them if the raw categories change.
D_AGENT = 33        # Agent features: player_height, player_weight, s, a, dir, o, x_rel, y_rel + one-hot encoded position/side/role
D_GLOBAL = 18       # Global features: down, yards_to_go + one-hot encoded dropback_type, team_coverage_type

# Model architecture hyperparameters
D_MODEL = 128       # Transformer Embedding Dimension (also the feed-forward width / 4)
D_LATENT = 32       # Latent variable Z dimension
N_HEADS = 8         # Transformer Heads (PyTorch requires D_MODEL to be divisible by this)
N_LAYERS = 3        # Transformer Encoder Layers
KL_BETA = 0.01      # KL Loss Weight: total loss = reconstruction + KL_BETA * KL (needs tuning/annealing)

# Training hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

# File paths
PROCESSED_DATA_PATH = Path('dataset/processed/processed_data.pt')  # cache written by process_raw_data
PRETRAINED_WEIGHTS_PATH = Path('dataset/pretrained/best_heliocentricity_model.pt')

In [2]:
# === Data Preprocessing Functions ===

def standardize(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` in which every play is expressed in one common frame.

    Plays whose ``play_direction`` is ``'left'`` are rotated half a turn:
    x is mirrored about 120, y about 53.3, and the ``o`` / ``dir`` angles are
    shifted by 180 degrees and wrapped into [0, 360). The line of scrimmage
    (``absolute_yardline_number``) is mirrored the same way as x.

    Derived columns added (in this order):
    - ``x_rel``: x minus the line of scrimmage
    - ``ball_land_x_rel``: same for ``ball_land_x`` (only if that column exists)
    - ``y_rel``: y minus 26.65 (the field's centre line)
    - ``ball_land_y_rel``: same for ``ball_land_y`` (only if that column exists)
    - ``dist_to_ball``: Euclidean distance to the landing spot (only if both
      ``ball_land_*`` columns exist)

    The input DataFrame is left untouched.
    """
    out = df.copy()
    flipped = out['play_direction'] == 'left'
    has_land_x = 'ball_land_x' in out.columns
    has_land_y = 'ball_land_y' in out.columns

    def mirror(column, extent):
        # Reflect the flipped rows of one column about extent / 2.
        out.loc[flipped, column] = extent - out.loc[flipped, column]

    # Horizontal mirror (field length axis).
    mirror('x', 120)
    if has_land_x:
        mirror('ball_land_x', 120)

    # Vertical mirror (field width axis).
    mirror('y', 53.3)
    if has_land_y:
        mirror('ball_land_y', 53.3)

    # Mirroring both axes is a half-turn, so headings move by 180 degrees;
    # the modulo brings them back into [0, 360).
    for angle in ('o', 'dir'):
        out.loc[flipped, angle] = (out.loc[flipped, angle] - 180) % 360

    # The line of scrimmage lives on the x axis and must follow the mirror.
    mirror('absolute_yardline_number', 120)

    scrimmage = out['absolute_yardline_number']
    out['x_rel'] = out['x'] - scrimmage
    if has_land_x:
        out['ball_land_x_rel'] = out['ball_land_x'] - scrimmage

    out['y_rel'] = out['y'] - 26.65
    if has_land_y:
        out['ball_land_y_rel'] = out['ball_land_y'] - 26.65

    if has_land_x and has_land_y:
        dx = out['x'] - out['ball_land_x']
        dy = out['y'] - out['ball_land_y']
        out['dist_to_ball'] = np.sqrt(dx**2 + dy**2)

    return out


def height_to_inches(height_str):
    """Turn a 'feet-inches' string such as '6-2' into total inches (74).

    Missing values (anything ``pd.isna`` accepts) yield ``None``.
    """
    if pd.isna(height_str):
        return None
    feet, inches = map(int, height_str.split('-'))
    return 12 * feet + inches


def process_raw_data():
    """
    Process raw CSV data into PyTorch tensors with play/player metadata.

    Steps:
      1. Read every ``dataset/train/input*.csv`` file and concatenate them.
      2. Standardize coordinates via ``standardize``.
      3. Keep only the tracking columns used downstream.
      4. Left-join play-level context from ``dataset/supplementary_data.csv``.
      5. Convert heights to inches and one-hot encode every object-dtype column.
      6. Per (game_id, play_id): for each consecutive frame pair, collect the
         current frame's agent features and the next frame's (x_rel, y_rel).
      7. Convert to float32 tensors.
      8. Save everything to ``PROCESSED_DATA_PATH``.

    Returns:
        dict with keys
        - 'historical_agent_features': list of tensors, one per play,
          built from (frames - 1) per-frame agent-feature arrays
        - 'ground_truth_trajectories': list of tensors, one per play,
          holding the next-frame (x_rel, y_rel) of each agent
        - 'global_context_features': tensor with one row per play
        - 'play_metadata': list of dicts (game_id, play_id, player_ids,
          player_sides, n_frames, n_agents), index-aligned with the tensors

    Raises:
        FileNotFoundError: if no ``input*.csv`` file is found in ``dataset/train``.
    """
    print("=" * 60)
    print("PROCESSING RAW DATA")
    print("=" * 60)

    # 1. Load and concatenate all input files
    train_path = Path('dataset/train')
    input_files = sorted(train_path.glob('input*.csv'))
    if not input_files:
        # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
        raise FileNotFoundError(f"No input*.csv files found in {train_path}")

    print(f"\nLoading {len(input_files)} input files...")
    dfs = [pd.read_csv(file) for file in input_files]

    all_weeks = pd.concat(dfs, ignore_index=True)
    print(f"Total rows: {len(all_weeks):,}")

    # 2. Standardize coordinates
    print("\nStandardizing coordinates...")
    all_weeks_std = standardize(all_weeks)

    # 3. Filter features (keep what we need after standardization)
    play_features = [
        'game_id', 'play_id', 'frame_id', 'nfl_id', 'player_height', 'player_weight',
        'player_position', 'player_side', 'player_role', 's', 'a', 'dir', 'o',
        'x_rel', 'y_rel', 'ball_land_x_rel', 'ball_land_y_rel'
    ]
    all_weeks_std = all_weeks_std.filter(play_features)

    # 4. Merge with supplementary data
    print("Merging supplementary data...")
    supp = pd.read_csv('dataset/supplementary_data.csv')
    supp_features = ['game_id', 'play_id', 'down', 'yards_to_go', 'dropback_type', 'team_coverage_type']
    supp = supp.filter(supp_features)

    merged = pd.merge(left=all_weeks_std, right=supp, how='left', on=['game_id', 'play_id'])

    # 5. Convert height to inches and preserve player_side before encoding
    print("One-hot encoding categorical features...")
    # Height must be numeric *before* the object-dtype selection below,
    # otherwise it would be one-hot encoded as a category.
    merged['player_height'] = merged['player_height'].apply(height_to_inches)

    # Preserve player_side for metadata before one-hot encoding
    player_side_original = merged['player_side'].copy()

    categorical_cols = merged.select_dtypes(include=['object']).columns.tolist()
    encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    encoded_array = encoder.fit_transform(merged[categorical_cols])
    encoded_feature_names = encoder.get_feature_names_out(categorical_cols)
    encoded_df = pd.DataFrame(encoded_array, columns=encoded_feature_names, index=merged.index)

    merged_encoded = pd.concat([merged.drop(columns=categorical_cols), encoded_df], axis=1)
    # Add back the original player_side for metadata extraction
    merged_encoded['player_side_original'] = player_side_original
    print(f"Encoded shape: {merged_encoded.shape}")

    # 6. Transform to ML format with metadata tracking
    print("\nTransforming to ML format...")
    grouped = merged_encoded.groupby(['game_id', 'play_id'])

    # 'player_side_original' also starts with 'player_side_' and must be
    # excluded explicitly: it is a raw string column, not a feature.
    one_hot_agent_prefixes = ('player_position_', 'player_side_', 'player_role_')
    agent_feature_cols = ['player_height', 'player_weight', 's', 'a', 'dir', 'o',
                          'x_rel', 'y_rel'] + [
        col for col in merged_encoded.columns
        if col.startswith(one_hot_agent_prefixes) and col != 'player_side_original'
    ]

    global_feature_cols = ['down', 'yards_to_go'] + [
        col for col in merged_encoded.columns
        if col.startswith(('dropback_type_', 'team_coverage_type_'))
    ]

    trajectory_cols = ['x_rel', 'y_rel']

    historical_agent_features = []
    global_context_features = []
    ground_truth_trajectories = []
    play_metadata = []  # Store game_id, play_id, player mapping

    for (game_id, play_id), play_data in grouped:
        play_data = play_data.sort_values('frame_id')
        frames = play_data['frame_id'].unique()

        # A play needs at least one (current, next) frame pair.
        if len(frames) < 2:
            continue

        frame_data = []
        ground_truth_data = []
        player_ids = None
        player_sides = None

        for i in range(len(frames) - 1):
            current_frame = frames[i]
            next_frame = frames[i + 1]

            # Sorting by nfl_id keeps the agent axis in the same order in
            # every frame (assumes the same players appear in each frame of
            # a play — TODO confirm against the raw data).
            current_frame_players = play_data[play_data['frame_id'] == current_frame].sort_values('nfl_id')
            next_frame_players = play_data[play_data['frame_id'] == next_frame].sort_values('nfl_id')

            # Store player IDs and sides from first frame
            if player_ids is None:
                player_ids = current_frame_players['nfl_id'].astype(int).values
                # Binary encode player_side: 0 for offense, 1 for defense.
                # Compared case-insensitively: an exact match on 'defense'
                # silently labels every player as offense when the raw label
                # is capitalised. Caches built before this fix should be
                # regenerated to pick up corrected sides.
                player_sides = (
                    current_frame_players['player_side_original']
                    .astype(str)
                    .str.lower()
                    .eq('defense')
                    .astype(int)
                    .values
                )

            agent_features = current_frame_players[agent_feature_cols].values
            frame_data.append(agent_features)

            next_positions = next_frame_players[trajectory_cols].values
            ground_truth_data.append(next_positions)

        historical_agent_features.append(np.array(frame_data))
        ground_truth_trajectories.append(np.array(ground_truth_data))

        # Play-level context is constant across rows of a play, so the first row suffices.
        global_features = play_data[global_feature_cols].iloc[0].values
        global_context_features.append(global_features)

        # Store metadata for cross-referencing
        play_metadata.append({
            'game_id': int(game_id),
            'play_id': int(play_id),
            'player_ids': player_ids.tolist(),  # Convert to list for JSON compatibility
            'player_sides': player_sides.tolist(),  # Binary encoded: 0=offense, 1=defense
            'n_frames': len(frame_data),
            'n_agents': len(player_ids)
        })

    print(f"Processed {len(historical_agent_features)} plays")

    # 7. Convert to tensors
    print("\nConverting to PyTorch tensors...")
    print(f"DEBUG: historical_agent_features type: {type(historical_agent_features)}")
    if historical_agent_features:
        # Only inspect the first play when there is one.
        print(f"DEBUG: First element type: {type(historical_agent_features[0])}")
        print(f"DEBUG: First element dtype: {historical_agent_features[0].dtype}")
        print(f"DEBUG: First element shape: {historical_agent_features[0].shape}")
        print(f"DEBUG: Sample values from first element:\n{historical_agent_features[0][0, 0, :]}")

    historical_agent_features_tensors = [torch.tensor(arr, dtype=torch.float32) for arr in historical_agent_features]
    ground_truth_trajectories_tensors = [torch.tensor(arr, dtype=torch.float32) for arr in ground_truth_trajectories]
    global_context_features_tensor = torch.tensor(np.array(global_context_features), dtype=torch.float32)

    processed = {
        'historical_agent_features': historical_agent_features_tensors,
        'ground_truth_trajectories': ground_truth_trajectories_tensors,
        'global_context_features': global_context_features_tensor,
        'play_metadata': play_metadata  # Add metadata for cross-referencing
    }

    # 8. Save to disk
    # Derive the directory from the target path so the two cannot drift apart,
    # and create missing parents as well.
    PROCESSED_DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
    torch.save(processed, PROCESSED_DATA_PATH)

    print(f"\n✓ Saved processed data to {PROCESSED_DATA_PATH}")
    print(f"  - {len(historical_agent_features_tensors)} plays")
    print(f"  - {len(play_metadata)} plays with metadata")
    print("=" * 60)

    return processed


def load_or_process_data():
    """
    Load processed data if it exists, otherwise process raw data.

    Returns:
        The dictionary produced by ``process_raw_data`` — either read back
        from ``PROCESSED_DATA_PATH`` or freshly computed (which also writes
        the cache).
    """
    if PROCESSED_DATA_PATH.exists():
        print(f"✓ Loading cached data from {PROCESSED_DATA_PATH}")
        # The cache holds only tensors and plain lists/dicts/ints, so the
        # restricted unpickler is sufficient and avoids executing arbitrary
        # pickled code from the file.
        loaded_data = torch.load(PROCESSED_DATA_PATH, weights_only=True)
        print(f"  Loaded {len(loaded_data['historical_agent_features'])} plays")
        return loaded_data

    print(f"✗ Cached data not found at {PROCESSED_DATA_PATH}")
    print("  Processing raw data...")
    # Previously this branch fell through and returned None, so callers
    # crashed on the first subscript when no cache was present.
    return process_raw_data()

In [3]:
class HeliocentricityTransformer(nn.Module):
    """Transformer encoder with conditional-VAE heads for trajectory prediction.

    The historical agent features of every frame are embedded, prefixed with a
    per-frame token derived from the global play context, flattened into one
    sequence and encoded. The encoding of the very first token is used as the
    context vector C, from which a prior p(z|C), a recognition distribution
    q(z|C, Y) and an MLP decoder are driven.

    Required keyword arguments: T_HIST, T_PRED, N_AGENTS, D_AGENT, D_GLOBAL,
    D_MODEL, D_LATENT, N_HEADS, N_LAYERS, KL_BETA. Each one is stored as an
    attribute of the same name; a missing key raises ``KeyError``.
    """

    _CONFIG_KEYS = (
        'T_HIST', 'T_PRED', 'N_AGENTS',
        'D_AGENT', 'D_GLOBAL', 'D_MODEL',
        'D_LATENT', 'N_HEADS', 'N_LAYERS',
        'KL_BETA',
    )

    def __init__(self, **kwargs):
        super().__init__()

        # Copy the configuration onto the instance under the same names.
        for key in self._CONFIG_KEYS:
            setattr(self, key, kwargs[key])

        d_model = self.D_MODEL
        d_latent = self.D_LATENT
        # Size of one flattened trajectory: (x, y) for every agent and frame.
        traj_dim = self.T_PRED * self.N_AGENTS * 2

        # --- 1. Input embeddings ---
        self.agent_embed = nn.Linear(self.D_AGENT, d_model)
        self.global_embed = nn.Linear(self.D_GLOBAL, d_model)

        # --- 2. Transformer encoder ---
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=self.N_HEADS,
                dim_feedforward=d_model * 4,
                batch_first=True,
            ),
            num_layers=self.N_LAYERS,
        )

        # --- 3. CVAE heads, all conditioned on the context vector C ---
        # Prior p(z|C): emits [mu_prior, log_var_prior] concatenated.
        self.mlp_prior = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 2 * d_latent),
        )

        # Recognition q(z|C, Y_truth): sees C plus the flattened ground truth
        # and emits [mu_rec, log_var_rec] concatenated.
        self.mlp_recognition = nn.Sequential(
            nn.Linear(d_model + traj_dim, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 2 * d_latent),
        )

        # --- 4. Decoder: [C, Z] -> flattened trajectory ---
        self.mlp_decoder = nn.Sequential(
            nn.Linear(d_model + d_latent, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, traj_dim),
        )

    def reparameterize(self, mu, log_var):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(log_var / 2)."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, X_hist_agents, X_global, Y_truth=None):
        """Predict trajectories and return the latent statistics.

        Args:
            X_hist_agents: (B, T_HIST, N_agents, D_AGENT) historical features.
            X_global: (B, D_GLOBAL) play-level context.
            Y_truth: optional ground truth that flattens to
                (B, T_PRED * N_AGENTS * 2). When given, Z is sampled from the
                recognition network; when omitted, Z is sampled from the prior.

        Returns:
            (Y_pred, mu_rec, log_var_rec, mu_prior, log_var_prior) where
            Y_pred is (B, T_PRED, N_AGENTS, 2). Without ``Y_truth`` the
            recognition statistics are the prior statistics themselves.
        """
        batch = X_hist_agents.size(0)

        # (B, T_hist, N_agents, D_MODEL)
        agent_tokens = self.agent_embed(X_hist_agents)

        # (B, D_MODEL) -> one context token per frame: (B, T_hist, 1, D_MODEL)
        global_token = self.global_embed(X_global)
        frame_cls = global_token[:, None, None, :].expand(-1, self.T_HIST, -1, -1)

        # Put the context token in front of each frame's agents, then merge
        # the time and agent axes: (B, T_hist * (N_agents + 1), D_MODEL).
        # No positional/temporal encoding is added.
        tokens = torch.cat([frame_cls, agent_tokens], dim=2).view(batch, -1, self.D_MODEL)

        encoded = self.transformer_encoder(tokens)

        # Context vector C is the encoding of the first token (the context
        # token of frame 0): (B, D_MODEL).
        C = encoded[:, 0, :]

        # Prior p(z|C)
        mu_prior, log_var_prior = self.mlp_prior(C).chunk(2, dim=-1)

        if Y_truth is None:
            # Inference: latent statistics come from the prior; they are also
            # returned in the recognition slots as placeholders.
            mu_rec, log_var_rec = mu_prior, log_var_prior
        else:
            # Training: recognition network q(z|C, Y_truth)
            recognition_in = torch.cat([C, Y_truth.view(batch, -1)], dim=-1)
            mu_rec, log_var_rec = self.mlp_recognition(recognition_in).chunk(2, dim=-1)

        Z = self.reparameterize(mu_rec, log_var_rec)

        # Decode [C, Z] and unflatten to (B, T_pred, N_agents, 2)
        decoded = self.mlp_decoder(torch.cat([C, Z], dim=-1))
        Y_pred = decoded.view(batch, self.T_PRED, self.N_AGENTS, 2)

        return Y_pred, mu_rec, log_var_rec, mu_prior, log_var_prior

In [4]:
def vae_loss(Y_pred, Y_truth, mu_rec, log_var_rec, mu_prior, log_var_prior, KL_BETA):
    """Compute the CVAE objective: reconstruction + KL_BETA * KL(q || p).

    Reconstruction is the summed squared error divided by the batch size
    (MSE is optimised here; RMSE is the reporting metric).

    The KL term is the closed form for two diagonal Gaussians,
        0.5 * sum( log(var_p) - log(var_q) - 1 + (var_q + (mu_q - mu_p)^2) / var_p ),
    with q the recognition distribution and p the prior, also divided by the
    batch size.

    Returns:
        (total_loss tensor, reconstruction as float, KL as float)
    """
    batch_size = Y_pred.size(0)

    # 1. Reconstruction term, averaged over the batch only.
    recon = F.mse_loss(Y_pred, Y_truth, reduction='sum') / batch_size

    # 2. KL(q(z|C, Y) || p(z|C)); exp(log_var) recovers the variances.
    var_rec = torch.exp(log_var_rec)
    var_prior = torch.exp(log_var_prior)
    mean_gap_sq = (mu_rec - mu_prior).pow(2)
    kl_terms = log_var_prior - log_var_rec - 1 + (var_rec + mean_gap_sq) / var_prior
    kl = 0.5 * torch.sum(kl_terms) / batch_size

    # 3. Weighted sum.
    return recon + KL_BETA * kl, recon.item(), kl.item()

# --- Heliocentricity Inference Function (E Generator) ---
@torch.no_grad()
def generate_expected_trajectories(model, X_hist_agents, X_global, K=10):
    """
    Draw K sampled trajectory sets per input play.

    The model is switched to eval mode, every play is duplicated K times
    (copies of the same play are adjacent), and the model is called without
    ground truth so that it samples its latent from the prior.

    Returns a tensor shaped (B, K, model.T_PRED, model.N_AGENTS, 2).
    """
    model.eval()
    n_plays = X_hist_agents.size(0)

    # One forward pass over B * K rows instead of K separate passes.
    agents_rep = torch.repeat_interleave(X_hist_agents, K, dim=0)
    global_rep = torch.repeat_interleave(X_global, K, dim=0)

    # Y_truth=None -> the model samples Z from p(z|C).
    outputs = model(agents_rep, global_rep, Y_truth=None)
    samples, _, _, _, _ = outputs

    # (B * K, T_pred, N_agents, 2) -> (B, K, T_pred, N_agents, 2)
    return samples.view(n_plays, K, model.T_PRED, model.N_AGENTS, 2)

# Note: The final step of calculating Heliocentricity (H) based on 
# min separation distance (A vs E) is a NumPy/Pandas operation after this PyTorch step.

In [5]:
# === Custom Dataset with Padding and Metadata Tracking ===

class FootballDataset(Dataset):
    """Dataset with padding for variable-length sequences and metadata tracking.

    Each item is brought to a fixed shape by zero-padding (or truncating) the
    time axis and the agent axis, so samples can be stacked into batches.

    Args:
        hist_features: sequence of (T_hist_actual, N_agents_actual, D_agent) tensors.
        gt_trajectories: sequence of (T_pred_actual, N_agents_actual, 2) tensors.
        global_features: indexable of per-play global feature vectors.
        metadata: per-play metadata, index-aligned with the tensors.
        max_hist_len / max_pred_len / max_n_agents: target sizes; when omitted
            (or falsy) the largest size found in the data is used.
    """

    def __init__(self, hist_features, gt_trajectories, global_features, metadata,
                 max_hist_len=None, max_pred_len=None, max_n_agents=None):
        self.hist_features = hist_features
        self.gt_trajectories = gt_trajectories
        self.global_features = global_features
        self.metadata = metadata  # Play/player metadata for cross-referencing

        # Determine max lengths if not provided
        self.max_hist_len = max_hist_len or max(x.shape[0] for x in hist_features)
        self.max_pred_len = max_pred_len or max(y.shape[0] for y in gt_trajectories)
        self.max_n_agents = max_n_agents or max(x.shape[1] for x in hist_features)

    def __len__(self):
        return len(self.hist_features)

    @staticmethod
    def _fit_axis(tensor, dim, target):
        """Zero-pad at the end of ``dim`` or truncate so that it has ``target`` entries."""
        current = tensor.shape[dim]
        if current < target:
            pad_shape = list(tensor.shape)
            pad_shape[dim] = target - current
            # new_zeros inherits dtype *and* device from the source tensor,
            # so the concatenation also works for non-CPU tensors.
            return torch.cat([tensor, tensor.new_zeros(pad_shape)], dim=dim)
        return tensor.narrow(dim, 0, target)

    def __getitem__(self, idx):
        """Return (hist, global_feat, gt, hist_len, pred_len, idx) for one play.

        ``hist`` is (max_hist_len, max_n_agents, D_agent) and ``gt`` is
        (max_pred_len, max_n_agents, 2). ``hist_len`` / ``pred_len`` are the
        numbers of real (non-padding) frames, capped at the respective
        maximum. ``idx`` is returned so predictions can be mapped back to
        metadata via ``get_metadata``.
        """
        hist = self.hist_features[idx]  # (T_hist_actual, N_agents_actual, D_agent)
        gt = self.gt_trajectories[idx]  # (T_pred_actual, N_agents_actual, 2)
        global_feat = self.global_features[idx]  # (D_global,)

        # Valid lengths after possible truncation; used for loss masking.
        hist_len = min(hist.shape[0], self.max_hist_len)
        pred_len = min(gt.shape[0], self.max_pred_len)

        # Time axis first, then agent axis.
        hist_padded = self._fit_axis(hist, 0, self.max_hist_len)
        hist_padded = self._fit_axis(hist_padded, 1, self.max_n_agents)

        gt_padded = self._fit_axis(gt, 0, self.max_pred_len)
        gt_padded = self._fit_axis(gt_padded, 1, self.max_n_agents)

        # Return data + metadata index for cross-referencing
        return hist_padded, global_feat, gt_padded, hist_len, pred_len, idx

    def get_metadata(self, idx):
        """Get play/player metadata for a specific index."""
        return self.metadata[idx]

In [6]:
# === Load or Process Data ===

# Load cached data or process from scratch
loaded_data = load_or_process_data()

# Extract data and metadata
historical_agent_features = loaded_data['historical_agent_features']
ground_truth_trajectories = loaded_data['ground_truth_trajectories']
global_context_features = loaded_data['global_context_features']
play_metadata = loaded_data['play_metadata']

print(f"\nGlobal context shape: {global_context_features.shape}")
print(f"Sample play metadata: {play_metadata[0]}")

# Model configuration (the keys are exactly the ones the model constructor reads)
model_config = {
    'T_HIST': T_HIST, 'T_PRED': T_PRED, 'N_AGENTS': N_AGENTS, 'D_AGENT': D_AGENT, 
    'D_GLOBAL': D_GLOBAL, 'D_MODEL': D_MODEL, 'D_LATENT': D_LATENT, 'N_HEADS': N_HEADS, 
    'N_LAYERS': N_LAYERS, 'KL_BETA': KL_BETA
}

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

# Create dataset with metadata.
# The pad/truncate targets come from the model config because the model's
# layer sizes are fixed by the same three numbers.
dataset = FootballDataset(
    historical_agent_features, 
    ground_truth_trajectories, 
    global_context_features,
    play_metadata,
    max_hist_len=model_config['T_HIST'],
    max_pred_len=model_config['T_PRED'],
    max_n_agents=model_config['N_AGENTS']
)

# Split into train and test sets (80/20).
# The split is seeded so that it is identical across notebook sessions:
# otherwise a model reloaded from saved weights would be evaluated on a
# "test" set that overlaps the plays it was trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"\nDataset split: Train={train_size}, Test={test_size}")
print(f"Max hist length: {dataset.max_hist_len}, Max pred length: {dataset.max_pred_len}, Max agents: {dataset.max_n_agents}")

# Create DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize Model and Optimizer
model = HeliocentricityTransformer(**model_config).to(device)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")

✓ Loading cached data from dataset/processed/processed_data.pt


  loaded_data = torch.load(PROCESSED_DATA_PATH)


  Loaded 14108 plays

Global context shape: torch.Size([14108, 18])
Sample play metadata: {'game_id': 2023090700, 'play_id': 101, 'player_ids': [43290, 44930, 46137, 52546, 53487, 53541, 53959, 54486, 54527], 'player_sides': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'n_frames': 25, 'n_agents': 9}

Using device: cpu

Dataset split: Train=11286, Test=2822
Max hist length: 25, Max pred length: 25, Max agents: 9
Model initialized with 865,602 parameters
Model initialized with 865,602 parameters


In [7]:
# --- 7. The Training Loop with Masking ---

def create_mask(lengths, max_len, device):
    """Build a (batch, max_len) boolean mask from per-sample lengths.

    Entry [b, t] is True when t < lengths[b] (a real position) and False
    when it falls in the padding.
    """
    positions = torch.arange(max_len, device=device)
    # Broadcasting (1, max_len) against (batch, 1) yields (batch, max_len).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

def train_model(model, train_loader, optimizer, model_config, device, num_epochs=20):
    """
    Run the training loop for the Heliocentricity Transformer.

    For every batch the model is called with the ground truth (so the latent
    comes from the recognition network), padded prediction frames are zeroed
    out of both prediction and target, and the loss
    ``recon + KL_BETA * KL`` is back-propagated.

    Args:
        model: The HeliocentricityTransformer model
        train_loader: DataLoader yielding
            (X_agents, X_global, Y_truth, hist_lens, pred_lens, index) batches
        optimizer: Optimizer stepping the model's parameters
        model_config: Dictionary; 'T_PRED' and 'KL_BETA' are read here
        device: Device to train on (cuda/cpu)
        num_epochs: Number of passes over ``train_loader``

    Returns:
        Dictionary with per-epoch averages under 'total_loss', 'recon_loss'
        and 'kl_loss'
    """
    print(f"Starting training on {device}...")

    history = {name: [] for name in ('total_loss', 'recon_loss', 'kl_loss')}

    for epoch in range(num_epochs):
        model.train()
        running_total = 0
        running_recon = 0
        running_kl = 0

        for step, batch in enumerate(train_loader, start=1):
            # The last element is the dataset index, which training ignores.
            *tensors, _ = batch
            X_agents, X_global, Y_truth, hist_lens, pred_lens = (
                part.to(device) for part in tensors
            )

            optimizer.zero_grad()

            # Forward pass with ground truth -> recognition-network latent.
            Y_pred, mu_rec, log_var_rec, mu_prior, log_var_prior = model(
                X_agents, X_global, Y_truth=Y_truth
            )

            # Frame-validity mask, broadcast over the agent and (x, y) axes.
            valid = create_mask(pred_lens, model_config['T_PRED'], device)
            valid = valid[:, :, None, None].expand_as(Y_pred)

            # Squared error over valid frames only, normalised by the total
            # number of valid frames in the batch.
            recon = F.mse_loss(Y_pred * valid, Y_truth * valid, reduction='sum') / pred_lens.sum()

            # KL(q || p) for diagonal Gaussians, averaged over the batch.
            var_rec = torch.exp(log_var_rec)
            var_prior = torch.exp(log_var_prior)
            kl_terms = (
                log_var_prior - log_var_rec - 1
                + (var_rec + (mu_rec - mu_prior).pow(2)) / var_prior
            )
            kl = 0.5 * torch.sum(kl_terms) / X_agents.size(0)

            loss = recon + model_config['KL_BETA'] * kl

            loss.backward()
            optimizer.step()

            running_total += loss.item()
            running_recon += recon.item()
            running_kl += kl.item()

            # Progress line with running averages every 100 batches.
            if step % 100 == 0:
                print(f"  Batch {step}/{len(train_loader)} | Total Loss: {running_total / step:.4f} | Recon: {running_recon / step:.4f} | KL: {running_kl / step:.4f}")

        # --- End of epoch: record the averages over all batches ---
        n_batches = len(train_loader)
        epoch_total = running_total / n_batches
        epoch_recon = running_recon / n_batches
        epoch_kl = running_kl / n_batches

        history['total_loss'].append(epoch_total)
        history['recon_loss'].append(epoch_recon)
        history['kl_loss'].append(epoch_kl)

        print(f"\n--- Epoch {epoch+1}/{num_epochs} Complete ---")
        print(f"Average Total Loss: {epoch_total:.4f} | Recon: {epoch_recon:.4f} | KL: {epoch_kl:.4f}")

    return history

In [8]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import mean_squared_error

# To evaluate a previously trained model, load its saved weights first:
# model.load_state_dict(torch.load(PRETRAINED_WEIGHTS_PATH))

@torch.no_grad()
def evaluate_model(model, data_loader, dataset, device, num_samples=10):
    """
    Evaluate the model and return results with metadata for Heliocentricity calculation.

    For each batch this runs one forward pass conditioned on the ground truth
    (``Y_truth=Y_truth``), scores it with ``vae_loss`` and a per-batch RMSE, then
    draws ``num_samples`` trajectories with ``generate_expected_trajectories``.
    Every play in the batch is stored together with the metadata returned by
    ``dataset.get_metadata``.

    Args:
        model: The trained HeliocentricityTransformer model
        data_loader: DataLoader for evaluation data; each batch must yield
            (X_agents, X_global, Y_truth, hist_lens, pred_lens, indices)
        dataset: The FootballDataset instance to retrieve metadata
        device: Device to evaluate on (cuda/cpu)
        num_samples: Number of stochastic trajectories (K) drawn per play

    Returns:
        List of dictionaries with predictions and metadata for each play.
        The list is empty when ``data_loader`` yields no batches.
    """
    model.eval()
    batch_rmse = []
    # Accumulate plain floats so no device tensors are kept alive across batches
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    # Store data for final Heliocentricity calculation
    results_for_H_calc = []

    for batch_idx, batch in enumerate(data_loader):
        # The sequence lengths are part of the batch tuple but are not used here.
        # NOTE(review): the training loop masks padded prediction frames before
        # computing the loss; the loss and RMSE below are computed on the unmasked
        # tensors, so padded frames (if any) contribute to the reported numbers.
        X_agents, X_global, Y_truth, _hist_lens, _pred_lens, indices = batch

        # Move to device
        X_agents = X_agents.to(device)
        X_global = X_global.to(device)
        Y_truth = Y_truth.to(device)

        # 1. Deterministic Prediction (for Reconstruction Loss)
        # Passing Y_truth lets the model use its ground-truth-conditioned path
        Y_pred, mu_rec, log_var_rec, mu_prior, log_var_prior = model(
            X_agents, X_global, Y_truth=Y_truth
        )

        # Calculate Loss components (the combined loss is not reported)
        _, L_recon, L_KL = vae_loss(
            Y_pred, Y_truth,
            mu_rec, log_var_rec, mu_prior, log_var_prior,
            model.KL_BETA
        )
        # float() accepts both 0-dim tensors and plain Python numbers
        total_recon_loss += float(L_recon)
        total_kl_loss += float(L_KL)

        # 2. Root Mean Squared Error over every element of the batch
        Y_pred_np = Y_pred.cpu().numpy()
        Y_truth_np = Y_truth.cpu().numpy()
        batch_rmse.append(
            np.sqrt(mean_squared_error(Y_truth_np.reshape(-1, 1), Y_pred_np.reshape(-1, 1)))
        )

        # 3. Generate K Stochastic Predictions for Heliocentricity (E)
        Y_pred_K = generate_expected_trajectories(
            model, X_agents, X_global, K=num_samples
        ).cpu().numpy()

        # 4. Store results with metadata for each play in batch
        for truth_i, pred_i, samples_i, sample_index in zip(
            Y_truth_np, Y_pred_np, Y_pred_K, indices
        ):
            # Original dataset index of this sample, used to look up its metadata
            metadata = dataset.get_metadata(sample_index.item())

            # Store all data needed for Heliocentricity calculation
            results_for_H_calc.append({
                'Y_truth': truth_i,
                'Y_pred': pred_i,
                'Y_pred_K': samples_i,
                'game_id': metadata['game_id'],
                'play_id': metadata['play_id'],
                'player_ids': metadata['player_ids'],
                'player_sides': metadata['player_sides'],
                'n_agents': metadata['n_agents'],
                'star_idx': 4  # TODO: Identify actual star receiver from metadata
            })

        # Print update every 100 batches
        if (batch_idx + 1) % 100 == 0:
            print(f"Batch {batch_idx+1}/{len(data_loader)}")

    num_batches = len(batch_rmse)
    if num_batches == 0:
        # Nothing was evaluated: avoid dividing by zero when averaging
        print("\n--- Validation Results ---")
        print("No batches to evaluate.")
        return results_for_H_calc

    # Averages are taken per batch, not per sample
    avg_rmse = np.mean(batch_rmse)
    avg_recon = total_recon_loss / num_batches
    avg_kl = total_kl_loss / num_batches

    print("\n--- Validation Results ---")
    # NOTE(review): the "meters" label assumes the coordinates are in meters — confirm
    print(f"Avg Trajectory RMSE: {avg_rmse:.4f} meters")
    print(f"Avg Reconstruction Loss: {avg_recon:.4f}")
    print(f"Avg KL Divergence: {avg_kl:.4f}")

    return results_for_H_calc

In [9]:
# Set to a checkpoint path, e.g. Path('dataset/pretrained/best_heliocentricity_model.pt'),
# to load saved weights instead of training from scratch.
PRETRAINED_WGTS = None

if PRETRAINED_WGTS is None:
    print('Training model from scratch:')
    train_model(model,
                train_loader,
                optimizer,
                model_config,
                device,
                num_epochs=NUM_EPOCHS)
    # torch.save does not create missing parent directories, so make sure the
    # output folder exists before writing the weights.
    weights_path = Path('dataset/pretrained/best_heliocentricity_model.pt')
    weights_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), weights_path)
else:
    print('Loading model from pretrained weights:')
    # map_location lets weights saved on another device load onto `device`
    state_dict = torch.load(PRETRAINED_WGTS, map_location=device)
    model.load_state_dict(state_dict)

Training model from scratch:
Starting training on cpu...


KeyboardInterrupt: 

In [None]:
# === Evaluate Model ===
# Run the evaluation pass on the held-out loader; `results` holds one entry per play
results = evaluate_model(
    model=model,
    data_loader=test_loader,
    dataset=dataset,
    device=device,
)

Batch 100/177

--- Validation Results ---
Avg Trajectory RMSE: 2.8821 meters
Avg Reconstruction Loss: 14763.0471
Avg KL Divergence: 608.6204


In [None]:
# === Heliocentricity Calculation Utilities ===

def min_separation_distance(receiver_coords, defense_coords):
    """
    Distance from the receiver to the closest defender, frame by frame.

    Args:
        receiver_coords: (T_pred, 2) - receiver trajectory
        defense_coords: (T_pred, N_defenders, 2) - defense trajectories

    Returns:
        (T_pred,) array with the smallest receiver-defender distance at each timestep
    """
    # Add a defender axis so the receiver position broadcasts against every defender
    receiver_per_defender = receiver_coords[:, None, :]
    offsets = receiver_per_defender - defense_coords

    # Euclidean length of each offset -> (T_pred, N_defenders), then closest per frame
    return np.linalg.norm(offsets, axis=2).min(axis=1)


def compute_heliocentricity(play_data):
    """
    Compute Heliocentricity score for a single play WITH metadata tracking.

    The score compares two minimum-separation curves for the star receiver's
    actual trajectory: E, measured against the predicted defenders (averaged
    over the K samples), and A, measured against the actual defenders.
    H = E - A per frame; H_score is the mean over frames.

    Args:
        play_data: Dictionary with:
            - 'Y_truth': (T_pred, N_agents, 2) ground truth trajectories
            - 'Y_pred_K': (K, T_pred, N_agents, 2) stochastic predictions
            - 'star_idx': Index of the star receiver
            - 'player_sides': Binary array (0=offense, 1=defense)
            - 'game_id', 'play_id', 'player_ids': Metadata for cross-referencing

    Returns:
        Dictionary with H_score, H_frame_diff, game_id, play_id, player_ids and
        star_player_id (None when star_idx is outside player_ids)
    """
    Y_truth = play_data['Y_truth']
    Y_pred_K = play_data['Y_pred_K']
    star_idx = play_data['star_idx']
    player_ids = play_data['player_ids']
    player_sides = np.array(play_data['player_sides'])  # Binary encoded: 0=offense, 1=defense

    # Identify defensive players using binary encoding (1 = defense)
    def_indices = np.where(player_sides == 1)[0]

    if len(def_indices) == 0:
        # Fallback if no defense found: use all except star receiver
        all_indices = np.arange(len(player_sides))
        def_indices = all_indices[all_indices != star_idx]

    # Calculate Actual Attention (A)
    actual_R_coords = Y_truth[:, star_idx, :]
    actual_D_coords = Y_truth[:, def_indices, :]
    A = min_separation_distance(actual_R_coords, actual_D_coords)

    # Calculate Expected Coverage (E)
    # Select the defenders for all K samples with a single index array so the
    # agent axis stays in place: (K, T_pred, N_def, 2). Indexing as
    # Y_pred_K[k, :, def_indices, :] would combine the scalar k with the index
    # array across a slice, and NumPy then moves the indexed axis to the front,
    # giving (N_def, T_pred, 2) instead of the expected (T_pred, N_def, 2).
    predicted_D_coords_K = Y_pred_K[:, :, def_indices, :]
    E_K = [
        min_separation_distance(actual_R_coords, predicted_D_coords)
        for predicted_D_coords in predicted_D_coords_K
    ]
    E_mean = np.mean(np.stack(E_K, axis=0), axis=0)

    # Calculate Heliocentricity (H = E - A)
    # NOTE(review): the mean runs over every frame of the arrays; if trajectories
    # are padded to a fixed T_pred, padded frames are included — confirm.
    H_frame_diff = E_mean - A
    H_score = np.mean(H_frame_diff)

    return {
        'H_score': H_score,
        'H_frame_diff': H_frame_diff,
        'game_id': play_data['game_id'],
        'play_id': play_data['play_id'],
        'player_ids': player_ids,
        'star_player_id': player_ids[star_idx] if star_idx < len(player_ids) else None
    }


def compute_heliocentricity_for_all(results):
    """
    Compute Heliocentricity for all plays in evaluation results.

    Args:
        results: List of result dictionaries from evaluate_model

    Returns:
        Tuple ``(df, helio_results)``:
            - df: DataFrame with one row per play and the columns
              game_id, play_id, H_score, star_player_id. The columns are
              present even when ``results`` is empty.
            - helio_results: List of the full per-play dictionaries returned
              by compute_heliocentricity (these also carry H_frame_diff and
              player_ids)
    """
    summary_columns = ['game_id', 'play_id', 'H_score', 'star_player_id']

    helio_results = [compute_heliocentricity(play_data) for play_data in results]

    # Convert to DataFrame for easy analysis. Passing `columns` keeps the schema
    # stable for empty input, so column lookups downstream do not fail.
    df = pd.DataFrame(
        [{col: r[col] for col in summary_columns} for r in helio_results],
        columns=summary_columns,
    )

    return df, helio_results

In [None]:
# === Compute Heliocentricity Scores ===

# Score every evaluated play; keep both the summary table and the per-play details
helio_df, helio_results = compute_heliocentricity_for_all(results)

# Summary statistics of the score table
print("\n=== Heliocentricity Summary ===")
print(helio_df.describe())

# The five plays with the largest Heliocentricity score
top_plays = helio_df.nlargest(n=5, columns='H_score')
print("\n=== Top 5 Plays by Heliocentricity ===")
print(top_plays)

# Example: inspect the detailed (frame-level) result of the first play, if any
if helio_results:
    sample = helio_results[0]
    print("\n=== Sample Detailed Result ===")
    print(f"Game ID: {sample['game_id']}, Play ID: {sample['play_id']}")
    print(f"Star Player ID: {sample['star_player_id']}")
    print(f"Heliocentricity Score: {sample['H_score']:.4f}")
    print(f"Frame-by-frame values shape: {sample['H_frame_diff'].shape}")