# Completion Probability Model - Route-Aware with Counterfactual Prediction

## Key Enhancements:
1. **Route type awareness** - Learns route-specific completion patterns via embeddings
2. **Receiver momentum features** - Captures receiver trajectory and adjustment capability
3. **Training only on release frames** - No post-release contamination
4. **Counterfactual prediction wrapper** - Enables "what if" scenario analysis

## Part 1: Imports and Configuration

In [None]:
# Standard library
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import pickle
import time

# Data processing
import numpy as np
import pandas as pd
import polars as pl

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Metrics
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report
)
from sklearn.calibration import calibration_curve

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Progress bars
from tqdm import tqdm

# Silence every warning for the rest of the session.
# NOTE(review): this blanket 'ignore' filter also hides deprecation and
# numerical warnings raised by the libraries imported above; consider
# narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

# Seed the NumPy and PyTorch RNGs (CPU, plus every CUDA device when CUDA is
# available) so repeated runs start from the same random state.
# Python's built-in `random` module is not seeded here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

In [None]:

# Configure Kaggle credentials.
# Writes a credentials file to /root/.kaggle/kaggle.json so the `kaggle` CLI
# invoked below can find it.
# NOTE(review): "username" and "key" are empty placeholders and must be filled
# in before running; as written this cell OVERWRITES any existing kaggle.json
# with empty credentials. Never commit real credentials to the notebook.
import os
import json

os.makedirs('/root/.kaggle', exist_ok=True)
kaggle_json = {
    "username": "",
    "key": ""
}

with open('/root/.kaggle/kaggle.json', 'w') as f:
    json.dump(kaggle_json, f)

# Restrict the credentials file to owner read/write only.
os.chmod('/root/.kaggle/kaggle.json', 0o600)
print("✓ Kaggle credentials configured!")

# Download and unpack the `nfl-big-data-bowl-2026-analytics` competition files
# into ./data. These two lines are notebook shell magics (`!`), not Python.
!kaggle competitions download -c nfl-big-data-bowl-2026-analytics -p ./data
!unzip -o -q ./data/nfl-big-data-bowl-2026-analytics.zip -d ./data
print("✓ Data downloaded and extracted!")

# Path configuration
# Adjust these paths based on your environment (local vs Kaggle).
# INPUT_DATA_DIR is globbed for `input_*.csv` by load_input_data();
# SUPPLEMENTARY_DATA_PATH is read by load_supplementary_data().

# Local paths (active)
INPUT_DATA_DIR = Path("data/raw/analytics/train")
SUPPLEMENTARY_DATA_PATH = INPUT_DATA_DIR.parent / "supplementary_data.csv"

# Kaggle paths (alternative; uncomment to use instead of the local paths)
# NOTE(review): the download cell above extracts into ./data — confirm which
# of the two layouts matches the extracted archive.
# INPUT_DATA_DIR = Path("data/114239_nfl_competition_files_published_analytics_final/train")
# SUPPLEMENTARY_DATA_PATH = INPUT_DATA_DIR.parent / "supplementary_data.csv"

# Output directory (active); created here, including missing parents
OUT_DIR = Path("visualizations/model")
OUT_DIR.mkdir(exist_ok=True, parents=True)

# Kaggle output directory (alternative)
# OUT_DIR = Path("/kaggle/working/split_prepped_data/")
# OUT_DIR.mkdir(exist_ok=True, parents=True)

# Report the resolved paths and whether they exist, so a wrong environment
# choice is visible before any loading is attempted.
print(f"Input data directory: {INPUT_DATA_DIR}")
print(f"Supplementary data path: {SUPPLEMENTARY_DATA_PATH}")
print(f"Input data exists: {INPUT_DATA_DIR.exists()}")
print(f"Supplementary data exists: {SUPPLEMENTARY_DATA_PATH.exists()}")

## Part 2: Data Preparation Functions

In [None]:
def load_supplementary_data() -> pl.DataFrame:
    """
    Read the supplementary CSV and derive per-play labels and route types.

    Plays without a ``pass_result`` are discarded. ``completion`` is 1 when
    ``pass_result`` equals "C" and 0 otherwise; ``route_type`` is
    ``route_of_targeted_receiver`` with missing values replaced by "UNKNOWN".
    Summary statistics are printed along the way.

    Returns:
        pl.DataFrame: Columns ``game_id``, ``play_id``, ``completion``
        (Int32) and ``route_type``.
    """
    raw = pl.read_csv(
        SUPPLEMENTARY_DATA_PATH,
        null_values=["NA", "nan", "N/A", "NaN", ""],
    )

    # Only plays with a recorded pass result can be labelled
    labelled = raw.filter(pl.col("pass_result").is_not_null())

    print(f"Loaded supplementary data with {len(labelled)} plays")
    print(f"Pass result distribution:")
    print(labelled['pass_result'].value_counts().sort('pass_result'))

    # Binary label and route type are independent, so derive both at once
    is_complete = pl.col("pass_result") == "C"
    plays = labelled.with_columns([
        pl.when(is_complete).then(1).otherwise(0).cast(pl.Int32).alias("completion"),
        pl.col("route_of_targeted_receiver").fill_null("UNKNOWN").alias("route_type"),
    ]).select(["game_id", "play_id", "completion", "route_type"])

    completion = plays['completion']
    completion_rate = completion.mean()

    print(f"\n=== Completion Distribution ===")
    print(f"Total plays: {len(plays)}")
    print(f"Completions: {completion.sum()} ({100*completion_rate:.1f}%)")
    print(f"Incompletions: {(completion == 0).sum()} ({100*(1-completion_rate):.1f}%)")

    print(f"\n=== Route Distribution ===")
    print(plays['route_type'].value_counts().sort('route_type'))

    return plays

In [None]:
def create_route_encoding(df: pl.DataFrame) -> dict:
    """
    Assign a consecutive integer ID to every distinct route name.

    IDs follow the sorted order of the route names, so the mapping is
    deterministic for a given set of routes. The mapping and the number of
    plays per route are printed.

    Args:
        df: DataFrame with a ``route_type`` column

    Returns:
        dict: {"route_name": id, ...}
    """
    routes = df['route_type']

    ordered_routes = routes.unique().to_list()
    ordered_routes.sort()
    route_to_id = dict(zip(ordered_routes, range(len(ordered_routes))))

    print(f"\n=== Route Encoding ===")
    print(f"Total unique routes: {len(route_to_id)}")
    # enumerate() over the sorted names walks the mapping in ID order
    for idx, route in enumerate(ordered_routes):
        count = (routes == route).sum()
        print(f"  {idx:2d}: {route:20s} ({count:4d} plays)")

    return route_to_id

In [None]:
def load_input_data() -> pl.DataFrame:
    """
    Read every ``input_*.csv`` file under ``INPUT_DATA_DIR`` into one frame.

    The strings "NA", "nan", "N/A", "NaN" and "" are parsed as nulls. Row
    count, number of distinct plays and the column list are printed.

    Returns:
        pl.DataFrame: Raw tracking data with all fields from the input files.
    """
    missing_markers = ["NA", "nan", "N/A", "NaN", ""]
    glob_pattern = INPUT_DATA_DIR / "input_*.csv"

    tracking = pl.read_csv(str(glob_pattern), null_values=missing_markers)

    print(f"Loaded {len(tracking)} rows from input files")
    print(f"Unique plays: {tracking.n_unique(['game_id', 'play_id'])}")
    print(f"Columns: {tracking.columns}")
    return tracking

In [None]:
def convert_tracking_to_cartesian(tracking_df: pl.DataFrame) -> pl.DataFrame:
    """
    Turn the polar ``dir`` / ``o`` / ``s`` columns into Cartesian components.

    Both angles are first remapped with ``(90 - angle) mod 360`` so they follow
    the unit-circle convention, then:

    - ``vx``, ``vy``: speed ``s`` times cos / sin of the adjusted ``dir``
    - ``ox``, ``oy``: cos / sin of the adjusted ``o``

    Args:
        tracking_df: Tracking data with ``dir``, ``o`` and ``s`` columns
            (angles in degrees)

    Returns:
        pl.DataFrame: Input data with ``vx``, ``vy``, ``ox``, ``oy`` added.
    """
    def _to_unit_circle(angle_col):
        # Same arithmetic for both angles: negate the offset from 90 degrees, wrap to [0, 360)
        return ((pl.col(angle_col) - 90) * -1) % 360

    with_adjusted = tracking_df.with_columns(
        dir_adjusted=_to_unit_circle("dir"),
        o_adjusted=_to_unit_circle("o"),
    )

    heading = pl.col("dir_adjusted").radians()
    facing = pl.col("o_adjusted").radians()
    speed = pl.col("s")

    with_components = with_adjusted.with_columns(
        vx=speed * heading.cos(),
        vy=speed * heading.sin(),
        ox=facing.cos(),
        oy=facing.sin(),
    )

    # The adjusted angles were only scaffolding for the trig above
    return with_components.drop(["dir_adjusted", "o_adjusted"])

In [None]:
def standardize_tracking_directions(tracking_df: pl.DataFrame) -> pl.DataFrame:
    """
    Re-express every play as if the offense were moving left to right.

    Rows whose ``play_direction`` is "right" are left untouched. For all other
    rows, positions (player and ball landing) are mirrored through the field
    (``120 - x``, ``53.3 - y``) and the velocity / orientation components are
    negated. The ``play_direction`` column is dropped afterwards.

    Args:
        tracking_df: Tracking data with a ``play_direction`` column plus
            ``x``, ``y``, ``vx``, ``vy``, ``ox``, ``oy``, ``ball_land_x``,
            ``ball_land_y``

    Returns:
        pl.DataFrame: Tracking data with standardized directions.
    """
    moving_right = pl.col("play_direction") == "right"

    def _mirror(name, extent):
        # Reflect a coordinate across the field for plays not moving right
        return (
            pl.when(moving_right)
            .then(pl.col(name))
            .otherwise(extent - pl.col(name))
            .alias(name)
        )

    def _negate(name):
        # Flip the sign of a vector component for plays not moving right
        return (
            pl.when(moving_right)
            .then(pl.col(name))
            .otherwise(-1 * pl.col(name))
            .alias(name)
        )

    standardized = tracking_df.with_columns([
        _mirror("x", 120),
        _mirror("y", 53.3),
        _negate("vx"),
        _negate("vy"),
        _negate("ox"),
        _negate("oy"),
        _mirror("ball_land_x", 120),
        _mirror("ball_land_y", 53.3),
    ])
    return standardized.drop("play_direction")

In [None]:
def engineer_ngs_features(tracking_df: pl.DataFrame) -> pl.DataFrame:
    """
    Attach NGS-style completion probability features to every tracking row.

    The caller is responsible for filtering the data to the frames of interest
    first; every per-play quantity here is derived from the rows passed in.

    Columns added:
    - qb_x, qb_y, qb_speed: position and speed of the "Passer" in the same frame
    - air_distance: distance from the QB to the ball landing location
    - target_separation: smallest distance between the "Targeted Receiver"
      and any "Defensive Coverage" player in the same frame
    - sideline_separation: distance from the ball landing y-coordinate to the
      nearer sideline (field width 53.3)
    - snap_frame: smallest frame_id present for the play in ``tracking_df``
    - time_to_throw: (frame_id - snap_frame) / 10.0, i.e. frame_id used as a
      proxy for time since the snap

    Missing values are filled with defaults: 0.0 for air_distance, qb_speed
    and time_to_throw, 999.0 for target_separation, 26.65 for
    sideline_separation.

    Args:
        tracking_df: Tracking rows with player_role, x, y, s, nfl_id,
            ball_land_x, ball_land_y, game_id, play_id, frame_id

    Returns:
        pl.DataFrame: Input rows with the columns above appended.
    """
    frame_key = ["game_id", "play_id", "frame_id"]
    play_key = ["game_id", "play_id"]

    def _distance(x1, y1, x2, y2):
        # Euclidean distance between the column pairs (x1, y1) and (x2, y2)
        return (
            (pl.col(x1) - pl.col(x2))**2 +
            (pl.col(y1) - pl.col(y2))**2
        )**0.5

    def _rows_for(role):
        return tracking_df.filter(pl.col("player_role") == role)

    # --- QB position and speed, broadcast to every row of the same frame ---
    qb_data = _rows_for("Passer").select(frame_key + [
        pl.col("x").alias("qb_x"),
        pl.col("y").alias("qb_y"),
        pl.col("s").alias("qb_speed"),
    ])
    tracking_df = tracking_df.join(qb_data, on=frame_key, how="left")

    # --- FEATURE 1: Air distance (QB to ball landing location) ---
    tracking_df = tracking_df.with_columns([
        _distance("ball_land_x", "ball_land_y", "qb_x", "qb_y").alias("air_distance")
    ])

    # --- FEATURE 2: Target separation ---
    defenders = _rows_for("Defensive Coverage").select(frame_key + [
        pl.col("x").alias("def_x"),
        pl.col("y").alias("def_y"),
        pl.col("nfl_id").alias("def_nfl_id"),
    ])
    targets = _rows_for("Targeted Receiver").select(frame_key + [
        pl.col("x").alias("tgt_x"),
        pl.col("y").alias("tgt_y"),
    ])

    # Pair each defender with the target of the same frame, keep the closest
    defender_target_pairs = defenders.join(targets, on=frame_key, how="inner")
    target_separation_df = (
        defender_target_pairs
        .with_columns([
            _distance("def_x", "def_y", "tgt_x", "tgt_y").alias("separation_dist")
        ])
        .group_by(frame_key)
        .agg([
            pl.col("separation_dist").min().alias("target_separation")
        ])
    )
    tracking_df = tracking_df.join(target_separation_df, on=frame_key, how="left")

    # --- FEATURE 3: Sideline separation (from the ball landing y) ---
    landing_y = pl.col("ball_land_y")
    tracking_df = tracking_df.with_columns([
        pl.min_horizontal(landing_y, 53.3 - landing_y).alias("sideline_separation")
    ])

    # --- FEATURE 4 & 5: QB speed (already joined) and time to throw ---
    # The earliest frame present for the play stands in for the snap frame
    snap_frames = (
        tracking_df
        .group_by(play_key)
        .agg(pl.col("frame_id").min().alias("snap_frame"))
    )
    tracking_df = tracking_df.join(snap_frames, on=play_key, how="left")

    frames_elapsed = pl.col("frame_id") - pl.col("snap_frame")
    tracking_df = tracking_df.with_columns([
        (frames_elapsed / 10.0).alias("time_to_throw")
    ])

    # --- Defaults for rows where a feature could not be computed ---
    null_defaults = [
        ("air_distance", 0.0),
        ("target_separation", 999.0),
        ("sideline_separation", 26.65),
        ("qb_speed", 0.0),
        ("time_to_throw", 0.0),
    ]
    tracking_df = tracking_df.with_columns([
        pl.col(name).fill_null(default) for name, default in null_defaults
    ])

    print(f"\n=== NGS Features Engineered ===")
    print(f"  - Air distance: QB to ball landing location")
    print(f"  - Target separation: Closest defender to target")
    print(f"  - Sideline separation: Target to nearest sideline")
    print(f"  - QB speed: QB speed at frame")
    print(f"  - Time to throw: Seconds from snap")

    return tracking_df

In [None]:
def engineer_receiver_momentum_features(tracking_df: pl.DataFrame) -> pl.DataFrame:
    """
    Attach targeted-receiver trajectory features to every tracking row.

    The "Targeted Receiver" row of each frame is joined back onto all rows of
    that frame, then compared against the ball landing location.

    Columns added:
    1. receiver_x, receiver_y: Receiver position
    2. receiver_vx, receiver_vy: Receiver velocity components
    3. receiver_speed: Receiver speed (the ``s`` column)
    4. receiver_ox, receiver_oy: Receiver orientation components
    5. receiver_momentum_alignment: Velocity projected onto the direction
       from the receiver to the ball landing location
    6. receiver_orientation_alignment: Orientation projected onto that same
       direction
    7. receiver_to_ball_distance: Distance from receiver to ball landing
       location (plus a 1e-6 epsilon)

    Nulls are filled with 0.0, except receiver_to_ball_distance which
    defaults to 10.0. receiver_x/y and receiver_ox/oy are not filled.

    Args:
        tracking_df: Tracking data with positions, velocities, roles

    Returns:
        Tracking data with momentum features added
    """
    frame_key = ["game_id", "play_id", "frame_id"]

    # Source column -> receiver-prefixed column carried onto every row
    receiver_columns = {
        "x": "receiver_x",
        "y": "receiver_y",
        "vx": "receiver_vx",
        "vy": "receiver_vy",
        "s": "receiver_speed",
        "ox": "receiver_ox",
        "oy": "receiver_oy",
    }
    target_data = (
        tracking_df
        .filter(pl.col("player_role") == "Targeted Receiver")
        .select(frame_key + [pl.col(src).alias(dst) for src, dst in receiver_columns.items()])
    )
    tracking_df = tracking_df.join(target_data, on=frame_key, how="left")

    # Vector from the receiver to the ball landing location
    tracking_df = tracking_df.with_columns([
        (pl.col("ball_land_x") - pl.col("receiver_x")).alias("target_vector_x"),
        (pl.col("ball_land_y") - pl.col("receiver_y")).alias("target_vector_y"),
    ])

    # Its length; the epsilon keeps the divisions below away from zero
    tracking_df = tracking_df.with_columns([
        ((pl.col("target_vector_x")**2 + pl.col("target_vector_y")**2)**0.5 + 1e-6).alias("target_vector_mag")
    ])

    def _toward_target(comp_x, comp_y):
        # Dot product with the target vector, divided by the target vector length
        return (
            (pl.col(comp_x) * pl.col("target_vector_x") +
             pl.col(comp_y) * pl.col("target_vector_y")) /
            pl.col("target_vector_mag")
        )

    # The three derived features only depend on columns that already exist
    tracking_df = tracking_df.with_columns([
        _toward_target("receiver_vx", "receiver_vy").alias("receiver_momentum_alignment"),
        _toward_target("receiver_ox", "receiver_oy").alias("receiver_orientation_alignment"),
        pl.col("target_vector_mag").alias("receiver_to_ball_distance"),
    ])

    # Defaults for frames without a targeted-receiver row
    null_defaults = [
        ("receiver_vx", 0.0),
        ("receiver_vy", 0.0),
        ("receiver_speed", 0.0),
        ("receiver_momentum_alignment", 0.0),
        ("receiver_orientation_alignment", 0.0),
        ("receiver_to_ball_distance", 10.0),
    ]
    tracking_df = tracking_df.with_columns([
        pl.col(name).fill_null(default) for name, default in null_defaults
    ])

    # Drop the scaffolding columns
    tracking_df = tracking_df.drop(["target_vector_x", "target_vector_y", "target_vector_mag"])

    print(f"\n=== Receiver Momentum Features Engineered ===")
    print(f"  - receiver_vx, receiver_vy: Velocity components")
    print(f"  - receiver_speed: Velocity magnitude")
    print(f"  - receiver_momentum_alignment: Velocity toward target")
    print(f"  - receiver_orientation_alignment: Body facing target")
    print(f"  - receiver_to_ball_distance: Distance to ball landing")

    return tracking_df

In [None]:
def prepare_tracking_data(
    tracking_df: pl.DataFrame,
    supplementary_df: pl.DataFrame
) -> pl.DataFrame:
    """
    Prepare tracking data with all features for completion probability prediction.

    IMPORTANT: Filters to LAST FRAME ONLY (release frame) for training.

    Steps:
    1. Drop rows without a ball landing location.
    2. Keep only plays present in ``supplementary_df``.
    3. Record each play's first frame, then keep only its last frame.
    4. Engineer NGS and receiver momentum features on the release-frame rows.
    5. Attach ``completion`` and ``route_type`` from ``supplementary_df``.

    ``time_to_throw`` (and ``snap_frame``) are computed from the first frame
    recorded in step 3. ``engineer_ngs_features`` derives the snap frame from
    the rows it receives, and by the time it is called only the release frame
    is left, so relying on its value would make ``time_to_throw`` 0.0 for
    every play.

    Args:
        tracking_df: Raw tracking data
        supplementary_df: Supplementary data with completion labels and route types

    Returns:
        pl.DataFrame: Prepared features with completion labels and route types
    """
    play_key = ["game_id", "play_id"]

    # Filter out rows where ball_land_x or ball_land_y are null
    tracking_df = tracking_df.filter(
        pl.col("ball_land_x").is_not_null() & pl.col("ball_land_y").is_not_null()
    )
    print(f"After filtering nulls: {len(tracking_df)} rows")
    print(f"Unique plays: {tracking_df.n_unique(['game_id', 'play_id'])}")

    # Filter to plays that have supplementary data
    tracking_df = tracking_df.join(
        supplementary_df.select(play_key),
        on=play_key,
        how="inner"
    )
    print(f"After filtering to plays with labels: {len(tracking_df)} rows")
    print(f"Unique plays: {tracking_df.n_unique(['game_id', 'play_id'])}")

    # ===== FILTER TO LAST FRAME (RELEASE FRAME) ONLY =====
    # Capture the first frame of every play in the same pass: once the earlier
    # frames are gone it can no longer be recovered.
    frame_bounds = (
        tracking_df
        .group_by(play_key)
        .agg([
            pl.col("frame_id").min().alias("_snap_frame_id"),
            pl.col("frame_id").max().alias("max_frame"),
        ])
    )
    tracking_df = tracking_df.join(frame_bounds, on=play_key, how="inner")
    tracking_df = tracking_df.filter(pl.col("frame_id") == pl.col("max_frame"))
    tracking_df = tracking_df.drop("max_frame")
    print(f"After filtering to release frame: {len(tracking_df)} rows")
    print(f"Unique plays (should match rows/22): {tracking_df.n_unique(['game_id', 'play_id'])}")

    # Add is_release_frame marker
    tracking_df = tracking_df.with_columns([
        pl.lit(True).alias("is_release_frame")
    ])

    # Engineer NGS features
    tracking_df = engineer_ngs_features(tracking_df)

    # Overwrite the snap frame / time to throw with values based on the true
    # first frame of the play (same frame_id / 10.0 proxy as
    # engineer_ngs_features), then drop the helper column.
    tracking_df = tracking_df.with_columns([
        pl.col("_snap_frame_id").alias("snap_frame"),
        ((pl.col("frame_id") - pl.col("_snap_frame_id")) / 10.0).alias("time_to_throw"),
    ]).drop("_snap_frame_id")

    # Engineer receiver momentum features
    tracking_df = engineer_receiver_momentum_features(tracking_df)

    # Merge with supplementary data (completion labels and route types)
    tracking_df = tracking_df.join(
        supplementary_df.select(["game_id", "play_id", "completion", "route_type"]),
        on=play_key,
        how="inner"
    )

    print(f"\n=== Final Training Data ===")
    print(f"Total rows: {len(tracking_df)}")
    print(f"Unique plays: {tracking_df.n_unique(['game_id', 'play_id'])}")
    print(f"Completion rate: {tracking_df['completion'].mean():.1%}")

    return tracking_df

In [None]:
def split_train_val_test(
    df: pl.DataFrame,
    train_frac: float = 0.7,
    val_frac: float = 0.15,
    seed: int = 42
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Split data into train, validation, and test sets at the PLAY level.

    All rows of a (game_id, play_id) pair land in the same split. The test
    fraction is whatever remains after ``train_frac`` and ``val_frac``.
    Split sizes and per-play completion rates are printed.

    Args:
        df: Full dataset with game_id, play_id and completion columns
        train_frac: Fraction for training (default 0.7)
        val_frac: Fraction for validation (default 0.15)
        seed: Random seed used for both sampling steps

    Returns:
        Tuple of (train_df, val_df, test_df)

    Raises:
        ValueError: If ``train_frac`` is outside [0, 1), ``val_frac`` is
            negative, or the two fractions sum to more than 1.
    """
    # Reject fractions that would otherwise surface as a ZeroDivisionError or
    # a negative sampling fraction further down.
    if not 0.0 <= train_frac < 1.0:
        raise ValueError(f"train_frac must be in [0, 1), got {train_frac}")
    if val_frac < 0.0 or train_frac + val_frac > 1.0 + 1e-9:
        raise ValueError(
            f"val_frac must be >= 0 and train_frac + val_frac <= 1, "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )

    play_key = ["game_id", "play_id"]

    # Get unique plays
    plays = df.select(play_key).unique(maintain_order=True)

    # Sample for test+val
    test_val_plays = plays.sample(fraction=1.0 - train_frac, seed=seed)
    train_plays = plays.join(test_val_plays, on=play_key, how="anti")

    # Split test_val into test and val; clamp away float round-off so the
    # sampling fraction always stays inside [0, 1]
    test_frac_of_remaining = (1.0 - train_frac - val_frac) / (1.0 - train_frac)
    test_frac_of_remaining = min(1.0, max(0.0, test_frac_of_remaining))
    test_plays = test_val_plays.sample(fraction=test_frac_of_remaining, seed=seed)
    val_plays = test_val_plays.join(test_plays, on=play_key, how="anti")

    # Split dataframes
    train_df = df.join(train_plays, on=play_key, how="inner")
    val_df = df.join(val_plays, on=play_key, how="inner")
    test_df = df.join(test_plays, on=play_key, how="inner")

    print(f"\n=== Data Split ===")
    print(f"Train: {train_df.n_unique(['game_id', 'play_id'])} plays ({len(train_df)} rows)")
    print(f"Val:   {val_df.n_unique(['game_id', 'play_id'])} plays ({len(val_df)} rows)")
    print(f"Test:  {test_df.n_unique(['game_id', 'play_id'])} plays ({len(test_df)} rows)")

    # Check completion rates in each split.
    # Deduplicate per PLAY (one label per play); deduplicating the label
    # column on its own would just average the distinct values {0, 1}.
    for name, split_df in [("Train", train_df), ("Val", val_df), ("Test", test_df)]:
        per_play = split_df.select(play_key + ["completion"]).unique()
        comp_rate = per_play["completion"].mean()
        if comp_rate is None:
            # mean() of an empty column is None, which cannot be formatted as a percentage
            print(f"{name} completion rate: n/a (empty split)")
        else:
            print(f"{name} completion rate: {comp_rate:.1%}")

    return train_df, val_df, test_df

## Part 3: Load and Prepare Data

In [None]:
# Load data: per-play labels and route types first, then the raw per-frame
# tracking rows. Both loaders read from the paths configured in Part 1.
print("Loading supplementary data...")
supp_df = load_supplementary_data()

print("\nLoading tracking data...")
tracking_df = load_input_data()

In [None]:
# Convert to cartesian and standardize.
# Order matters: standardize_tracking_directions flips vx/vy/ox/oy, which only
# exist after convert_tracking_to_cartesian has run.
print("Converting to cartesian coordinates...")
tracking_df = convert_tracking_to_cartesian(tracking_df)

print("Standardizing play directions...")
tracking_df = standardize_tracking_directions(tracking_df)

In [None]:
# Prepare full tracking data with all features: filters to labelled plays and
# to each play's last frame, engineers features, and attaches completion
# labels and route types from supp_df.
print("Preparing tracking data with features...")
prepared_df = prepare_tracking_data(tracking_df, supp_df)

In [None]:
# Create route encoding (name -> id), its inverse (id -> name), and the size
# of the route vocabulary used for the model's embedding table.
# The encoding is built from supp_df (every labelled play), not prepared_df,
# so it may contain routes that have no rows in the prepared data.
route_to_id = create_route_encoding(supp_df)
id_to_route = {v: k for k, v in route_to_id.items()}
NUM_ROUTES = len(route_to_id)

In [None]:
# Split data at the play level using the function defaults
# (train_frac=0.7, val_frac=0.15, remainder for test, seed=42).
train_df, val_df, test_df = split_train_val_test(prepared_df)

In [None]:
# Verify columns: print the schema of the prepared data so the column names
# the dataset class indexes by (player, momentum, NGS, label) can be checked.
print("\nColumns in prepared data:")
print(prepared_df.columns)

## Part 4: Dataset Class

In [None]:
class RouteAwareCompletionDataset(Dataset):
    """
    Enhanced dataset that includes route types and momentum features.

    One item corresponds to one play (a distinct game_id / play_id pair).
    The row positions of every play are indexed once at construction time, so
    fetching an item does not rescan the whole frame.

    Each item is a dict with:
        - player_features: (max_players, 8) tensor of player kinematics
          (x, y, vx, vy, ox, oy, s, a); plays with fewer rows are zero-padded,
          plays with more rows are truncated to the first ``max_players`` rows
        - route_id: Integer route type ID (0 when the route is not in
          ``route_to_id``)
        - momentum_features: (7,) tensor of receiver trajectory features
          (includes receiver_to_ball_distance and time_to_throw)
        - ngs_features: (5,) tensor of NGS completion probability features
          (time_to_throw appears here as well)
        - label: Binary completion label as float32

    Scalar features and the label are read from the first row of the play.
    NaNs in all feature tensors are replaced by 0.0.
    """

    # Column layouts of the three feature groups
    PLAYER_COLS = ['x', 'y', 'vx', 'vy', 'ox', 'oy', 's', 'a']
    MOMENTUM_COLS = [
        'receiver_vx', 'receiver_vy', 'receiver_speed',
        'receiver_momentum_alignment', 'receiver_orientation_alignment',
        'receiver_to_ball_distance',  # CRITICAL: Distance receiver must travel
        'time_to_throw'
    ]
    NGS_COLS = [
        'air_distance', 'target_separation', 'sideline_separation',
        'qb_speed', 'time_to_throw'
    ]

    def __init__(
        self,
        df: "pl.DataFrame",
        route_to_id: dict,
        max_players: int = 22
    ):
        """
        Args:
            df: Prepared data, one row per player per play. A polars frame is
                converted to pandas; a pandas frame is used as-is.
            route_to_id: Mapping from route name to integer ID
            max_players: Fixed number of player rows per item
        """
        super().__init__()

        self.route_to_id = route_to_id
        self.max_players = max_players

        # Convert to pandas for indexing (pandas input is accepted unchanged)
        self.df = df.to_pandas() if hasattr(df, "to_pandas") else df

        # Get unique plays, in order of first appearance
        self.plays = self.df[['game_id', 'play_id']].drop_duplicates().reset_index(drop=True)

        # Index the row positions of every play once; positions are collected
        # in frame order, so each play keeps its original row order.
        rows_by_play = {}
        row_keys = zip(self.df['game_id'].tolist(), self.df['play_id'].tolist())
        for position, key in enumerate(row_keys):
            rows_by_play.setdefault(key, []).append(position)

        play_keys = zip(self.plays['game_id'].tolist(), self.plays['play_id'].tolist())
        self._play_rows = [
            np.asarray(rows_by_play[key], dtype=np.intp) for key in play_keys
        ]

        print(f"Dataset initialized with {len(self.plays)} plays")

    def __len__(self):
        """Number of plays in the dataset."""
        return len(self.plays)

    def __getitem__(self, idx):
        """
        Build the feature tensors and label for the play at position ``idx``.

        Returns:
            dict with keys player_features, route_id, momentum_features,
            ngs_features, label (see the class docstring).
        """
        # All rows (players) of this play, via the precomputed positions
        play_df = self.df.iloc[self._play_rows[idx]]

        # Extract player features: x, y, vx, vy, ox, oy, s, a
        player_features = play_df[self.PLAYER_COLS].values.astype(np.float32)

        # Pad or truncate to max_players
        num_players = len(player_features)
        if num_players < self.max_players:
            padding = np.zeros(
                (self.max_players - num_players, len(self.PLAYER_COLS)),
                dtype=np.float32
            )
            player_features = np.vstack([player_features, padding])
        elif num_players > self.max_players:
            player_features = player_features[:self.max_players]

        # Per-play scalars are identical on every row; read them from the first
        first_row = play_df.iloc[0]

        # Extract route type (unknown names fall back to ID 0)
        route_id = self.route_to_id.get(first_row['route_type'], 0)

        # Extract receiver momentum features (7) and NGS features (5)
        momentum_features = play_df[self.MOMENTUM_COLS].iloc[0].values.astype(np.float32)
        ngs_features = play_df[self.NGS_COLS].iloc[0].values.astype(np.float32)

        # Extract label
        label = first_row['completion']

        # Handle NaN values. `nan=` is passed by keyword: the second
        # positional parameter of np.nan_to_num is `copy`, not the fill value.
        player_features = np.nan_to_num(player_features, nan=0.0)
        momentum_features = np.nan_to_num(momentum_features, nan=0.0)
        ngs_features = np.nan_to_num(ngs_features, nan=0.0)

        return {
            'player_features': torch.from_numpy(player_features),
            'route_id': torch.tensor(route_id, dtype=torch.long),
            'momentum_features': torch.from_numpy(momentum_features),
            'ngs_features': torch.from_numpy(ngs_features),
            'label': torch.tensor(label, dtype=torch.float32)
        }


In [None]:
# Create datasets: one per split, all sharing the same route_to_id mapping so
# route IDs mean the same thing in train, validation and test.
print("Creating datasets...")
train_dataset = RouteAwareCompletionDataset(train_df, route_to_id)
val_dataset = RouteAwareCompletionDataset(val_df, route_to_id)
test_dataset = RouteAwareCompletionDataset(test_df, route_to_id)

In [None]:
# Create dataloaders. Only the training loader shuffles; validation and test
# keep a fixed order. num_workers=0 loads batches in the main process.
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Number of batches per epoch for each split
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Test dataset: fetch a single item (one play, not a batch) directly from the
# dataset and print the shape of every tensor in it. Values without a
# `shape` attribute are printed as-is.
sample = train_dataset[0]
print("Sample batch shapes:")
for key, value in sample.items():
    print(f"  {key}: {value.shape if hasattr(value, 'shape') else value}")

## Part 5: Model Architecture

In [None]:
class RouteAwareCompletionModel(nn.Module):
    """
    Enhanced completion probability model with route awareness.

    Key Components:
    1. Route type embedding layer (learns route-specific patterns)
    2. Player transformer encoder (captures spatial relationships)
    3. Receiver momentum encoder (captures trajectory information)
    4. NGS feature encoder (scalar completion probability features)
    5. Combined prediction head

    The four encodings are concatenated and passed through an MLP that emits
    one logit per play.

    NOTE(review): the player branch has no padding mask. Rows that are
    all-zero padding take part in the BatchNorm statistics, in self-attention
    and in the mean pool — confirm this is intended.
    """

    def __init__(
        self,
        num_routes: int,
        route_embedding_dim: int = 32,
        player_feature_dim: int = 256,
        player_input_dim: int = 8,  # x, y, vx, vy, ox, oy, s, a
        momentum_feature_dim: int = 7,
        ngs_feature_dim: int = 5,
        hidden_dim: int = 512,
        num_transformer_layers: int = 4,
        num_attention_heads: int = 8,
        dropout: float = 0.1
    ):
        """
        Args:
            num_routes: Size of the route vocabulary (embedding rows)
            route_embedding_dim: Width of the route embedding
            player_feature_dim: Transformer model width (d_model)
            player_input_dim: Number of raw features per player
            momentum_feature_dim: Number of receiver momentum features
            ngs_feature_dim: Number of NGS features
            hidden_dim: Transformer feed-forward width
            num_transformer_layers: Number of encoder layers
            num_attention_heads: Attention heads per encoder layer
            dropout: Dropout probability used throughout
        """
        super().__init__()

        self.num_routes = num_routes
        self.route_embedding_dim = route_embedding_dim
        self.player_feature_dim = player_feature_dim

        # Component 1: Route type embedding
        self.route_embedding = nn.Embedding(
            num_embeddings=num_routes,
            embedding_dim=route_embedding_dim
        )

        # Component 2: Player feature normalization and projection
        # BatchNorm1d normalizes per feature channel, so forward() feeds it
        # the input with the feature axis moved to dim 1.
        self.player_norm = nn.BatchNorm1d(player_input_dim)
        self.player_input_projection = nn.Sequential(
            nn.Linear(player_input_dim, player_feature_dim),
            nn.ReLU(),
            nn.LayerNorm(player_feature_dim),
            nn.Dropout(dropout)
        )

        # Component 3: Transformer encoder for player interactions
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=player_feature_dim,
            nhead=num_attention_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )
        self.player_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers
        )

        # Component 4: Receiver momentum encoder
        self.momentum_encoder = nn.Sequential(
            nn.Linear(momentum_feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        # Component 5: NGS feature encoder
        self.ngs_encoder = nn.Sequential(
            nn.Linear(ngs_feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        # Component 6: Completion prediction head
        combined_dim = (
            player_feature_dim +  # From player encoder
            route_embedding_dim +  # From route embedding
            32 +  # From momentum encoder
            32    # From NGS encoder
        )

        self.completion_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights: Xavier-uniform for Linear layers (zero bias),
        normal(0, 0.1) for Embedding layers. Other modules keep their defaults."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.1)

    def forward(
        self,
        player_features: torch.Tensor,  # (batch, 22, 8)
        route_ids: torch.Tensor,  # (batch,)
        momentum_features: torch.Tensor,  # (batch, 7) - includes receiver_to_ball_distance
        ngs_features: torch.Tensor  # (batch, 5)
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            player_features: Per-player kinematics, (batch, players, features)
            route_ids: Integer route IDs, (batch,)
            momentum_features: Receiver momentum features, (batch, 7)
            ngs_features: NGS features, (batch, 5)

        Returns:
            Logits (not probabilities) of shape (batch,)
        """
        # 1. Encode route type
        route_embedded = self.route_embedding(route_ids)  # (batch, route_embedding_dim)

        # 2. Encode player features
        # BatchNorm1d expects (batch, channels, length): move features to
        # dim 1 for the normalization, then restore (batch, players, features)
        player_normed = self.player_norm(
            player_features.permute(0, 2, 1)
        ).permute(0, 2, 1)

        # Project to model dimension
        player_projected = self.player_input_projection(player_normed)  # (batch, players, player_feature_dim)

        # Transform with attention
        player_encoded = self.player_encoder(player_projected)  # (batch, players, player_feature_dim)

        # Pool over players (mean pooling)
        player_pooled = player_encoded.mean(dim=1)  # (batch, player_feature_dim)

        # 3. Encode receiver momentum
        momentum_encoded = self.momentum_encoder(momentum_features)  # (batch, 32)

        # 4. Encode NGS features
        ngs_encoded = self.ngs_encoder(ngs_features)  # (batch, 32)

        # 5. Combine all features
        combined = torch.cat([
            player_pooled,
            route_embedded,
            momentum_encoded,
            ngs_encoded
        ], dim=1)  # (batch, combined_dim)

        # 6. Predict completion (logits); drop the trailing size-1 dimension
        logits = self.completion_head(combined).squeeze(-1)  # (batch,)

        return logits

    def predict_proba(self, *args, **kwargs) -> torch.Tensor:
        """
        Get probabilities instead of logits.

        Runs forward() with the given arguments and applies a sigmoid. It does
        not switch the module to eval mode or disable gradients; the caller
        decides both.
        """
        logits = self.forward(*args, **kwargs)
        return torch.sigmoid(logits)

In [None]:
# Initialize model. The embedding table is sized by NUM_ROUTES from the route
# encoding; the remaining arguments repeat the constructor defaults
# explicitly, and everything not listed (e.g. hidden_dim) uses its default.
model = RouteAwareCompletionModel(
    num_routes=NUM_ROUTES,
    route_embedding_dim=32,
    player_feature_dim=256,
    num_transformer_layers=4,
    num_attention_heads=8,
    dropout=0.1
)

# Count trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {num_params:,}")

# Move to device
model = model.to(DEVICE)
print(f"Model moved to {DEVICE}")

In [None]:
# Test forward pass: push one training batch through the untrained model and
# check the output shape and logit range. Gradients are disabled.
# NOTE(review): model.eval() is not called before this, so the module is
# presumably still in training mode here (dropout active, BatchNorm running
# statistics updated by this batch) — confirm whether that is intended.
sample_batch = next(iter(train_loader))
with torch.no_grad():
    logits = model(
        sample_batch['player_features'].to(DEVICE),
        sample_batch['route_id'].to(DEVICE),
        sample_batch['momentum_features'].to(DEVICE),
        sample_batch['ngs_features'].to(DEVICE)
    )
print(f"Output shape: {logits.shape}")
print(f"Output range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")

## Part 6: Training Loop

In [None]:
def train_route_aware_model(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: RouteAwareCompletionModel,
    num_epochs: int = 50,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    device: str = DEVICE,
    patience: int = 10,
    pos_weight: Optional[float] = 0.65
):
    """
    Train the route-aware completion probability model.

    Features:
    - Class imbalance handling with pos_weight
    - Learning rate scheduling (ReduceLROnPlateau on validation loss)
    - Early stopping on validation loss
    - Gradient clipping (max norm 1.0)

    Args:
        train_loader: Batches of dicts with keys 'player_features', 'route_id',
            'momentum_features', 'ngs_features' and 'label'.
        val_loader: Validation batches with the same keys.
        model: Model to train; it is moved to ``device`` and updated in place.
        num_epochs: Maximum number of epochs.
        learning_rate: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        device: Torch device string.
        patience: Number of consecutive non-improving epochs before stopping.
        pos_weight: Positive-class weight for ``BCEWithLogitsLoss``. ``None``
            derives it from the training labels as ``num_neg / num_pos``
            (1.0 when there are no positives).

    Returns:
        ``(model, history)`` where ``model`` carries the weights of the epoch
        with the lowest validation loss and ``history`` maps
        'train_loss', 'val_loss', 'train_acc', 'val_acc', 'train_auc',
        'val_auc' to per-epoch lists.
    """
    model = model.to(device)

    def safe_auc(y_true, y_score):
        # roc_auc_score raises ValueError when only one class is present
        if len(np.unique(y_true)) < 2:
            return float('nan')
        return roc_auc_score(y_true, y_score)

    # One pass over the training data to report the class balance
    num_pos = sum(batch['label'].sum().item() for batch in train_loader)
    num_neg = len(train_loader.dataset) - num_pos
    if pos_weight is None:
        pos_weight_value = num_neg / num_pos if num_pos > 0 else 1.0
    else:
        pos_weight_value = pos_weight
    pos_weight_tensor = torch.tensor([pos_weight_value], dtype=torch.float32).to(device)
    print(f"Class imbalance - Pos: {num_pos}, Neg: {num_neg}, Weight: {pos_weight_tensor.item():.3f}")

    # Loss function with class weighting
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Training tracking
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_auc': [], 'val_auc': []
    }

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = []
        train_preds = []
        train_labels = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Train")
        for batch in pbar:
            # Move to device
            player_features = batch['player_features'].to(device)
            route_ids = batch['route_id'].to(device)
            momentum_features = batch['momentum_features'].to(device)
            ngs_features = batch['ngs_features'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            optimizer.zero_grad()
            logits = model(player_features, route_ids, momentum_features, ngs_features)

            # Compute loss
            loss = criterion(logits, labels)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Track metrics
            train_losses.append(loss.item())
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            train_preds.extend(probs)
            train_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Validation phase
        model.eval()
        val_losses = []
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Val"):
                player_features = batch['player_features'].to(device)
                route_ids = batch['route_id'].to(device)
                momentum_features = batch['momentum_features'].to(device)
                ngs_features = batch['ngs_features'].to(device)
                labels = batch['label'].to(device)

                logits = model(player_features, route_ids, momentum_features, ngs_features)
                loss = criterion(logits, labels)

                val_losses.append(loss.item())
                probs = torch.sigmoid(logits).cpu().numpy()
                val_preds.extend(probs)
                val_labels.extend(labels.cpu().numpy())

        # Calculate metrics
        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)

        train_preds_arr = np.array(train_preds)
        val_preds_arr = np.array(val_preds)
        train_labels_arr = np.array(train_labels)
        val_labels_arr = np.array(val_labels)

        train_acc = accuracy_score(train_labels_arr, train_preds_arr > 0.5)
        val_acc = accuracy_score(val_labels_arr, val_preds_arr > 0.5)

        train_auc = safe_auc(train_labels_arr, train_preds_arr)
        val_auc = safe_auc(val_labels_arr, val_preds_arr)

        # Store history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['train_auc'].append(train_auc)
        history['val_auc'].append(val_auc)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train AUC: {train_auc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val AUC:   {val_auc:.4f}")

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live tensors, and dict.copy()
            # is shallow, so each tensor must be cloned or the snapshot would keep
            # changing as training continues.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
            print(f"  -> New best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  -> No improvement ({patience_counter}/{patience})")

        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    # Load best model (absent when no epoch ran or the validation loss was never finite)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print("Warning: no best checkpoint was recorded; returning the model as-is")

    return model, history

In [None]:
# Run training with a short schedule and a tight early-stopping patience
trained_model, history = train_route_aware_model(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    num_epochs=10,
    learning_rate=1e-4,
    patience=3,
)

In [None]:
# Plot training history: one panel per tracked metric, train vs. validation
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (train key, val key, y-axis label, panel title)
history_panels = [
    ('train_loss', 'val_loss', 'Loss', 'Loss over Epochs'),
    ('train_acc', 'val_acc', 'Accuracy', 'Accuracy over Epochs'),
    ('train_auc', 'val_auc', 'AUC-ROC', 'AUC-ROC over Epochs'),
]

for panel_ax, (train_key, val_key, y_label, panel_title) in zip(axes, history_panels):
    panel_ax.plot(history[train_key], label='Train')
    panel_ax.plot(history[val_key], label='Val')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUT_DIR / 'training_history.png', dpi=150)
plt.show()

## Part 7: Model Evaluation

In [None]:
def evaluate_model(model, test_loader, device=DEVICE):
    """
    Score every batch in ``test_loader`` and print a classification summary.

    The model is switched to eval mode and run without gradients. Hard
    predictions for the printed metrics use a 0.5 probability threshold.

    Returns:
        ``(preds, labels, route_ids)`` as NumPy arrays: predicted
        probabilities, the batch labels, and the batch route ids.
    """
    model.eval()
    collected_probs, collected_labels, collected_routes = [], [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            player_features = batch['player_features'].to(device)
            route_ids = batch['route_id'].to(device)
            momentum_features = batch['momentum_features'].to(device)
            ngs_features = batch['ngs_features'].to(device)
            labels = batch['label'].to(device)

            batch_logits = model(player_features, route_ids, momentum_features, ngs_features)
            batch_probs = torch.sigmoid(batch_logits)

            collected_probs.extend(batch_probs.cpu().numpy())
            collected_labels.extend(labels.cpu().numpy())
            collected_routes.extend(route_ids.cpu().numpy())

    preds = np.array(collected_probs)
    labels = np.array(collected_labels)
    route_ids = np.array(collected_routes)

    # Thresholded predictions shared by all hard-label metrics below
    hard_preds = preds > 0.5
    banner = "=" * 50

    # Overall metrics
    print(banner)
    print("OVERALL METRICS")
    print(banner)
    print(f"Accuracy:  {accuracy_score(labels, hard_preds):.4f}")
    print(f"AUC-ROC:   {roc_auc_score(labels, preds):.4f}")
    print(f"Precision: {precision_score(labels, hard_preds):.4f}")
    print(f"Recall:    {recall_score(labels, hard_preds):.4f}")
    print(f"F1 Score:  {f1_score(labels, hard_preds):.4f}")

    print("\nClassification Report:")
    print(classification_report(labels, hard_preds, target_names=['Incomplete', 'Complete']))

    return preds, labels, route_ids

In [None]:
# Score the test set with the model returned by the training run
test_preds, test_labels, test_route_ids = evaluate_model(
    model=trained_model,
    test_loader=test_loader,
)

In [None]:
# Per-route breakdown of test-set performance
print("\n" + "=" * 50)
print("ROUTE-SPECIFIC PERFORMANCE")
print("=" * 50)

route_metrics = []
for route_id in np.unique(test_route_ids):
    in_route = test_route_ids == route_id
    route_name = id_to_route[route_id]
    route_preds, route_labels = test_preds[in_route], test_labels[in_route]

    # AUC is only defined when both outcomes occur for this route
    has_both_classes = len(np.unique(route_labels)) > 1
    auc = roc_auc_score(route_labels, route_preds) if has_both_classes else np.nan

    route_metrics.append({
        'route': route_name,
        'n_plays': len(route_labels),
        'completion_rate': route_labels.mean(),
        'accuracy': accuracy_score(route_labels, route_preds > 0.5),
        'auc': auc
    })

# Most frequently thrown routes first
route_df = pd.DataFrame(route_metrics).sort_values('n_plays', ascending=False)
print(route_df.to_string(index=False))

In [None]:
# Calibration diagnostics: reliability curve plus class-wise score histograms
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
curve_ax, hist_ax = axes

# Left panel: observed positive fraction vs. mean predicted probability per bin
prob_true, prob_pred = calibration_curve(test_labels, test_preds, n_bins=10)
curve_ax.plot(prob_pred, prob_true, 's-', label='Model')
curve_ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
curve_ax.set_xlabel('Mean Predicted Probability')
curve_ax.set_ylabel('Fraction of Positives')
curve_ax.set_title('Calibration Curve')
curve_ax.legend()
curve_ax.grid(True, alpha=0.3)

# Right panel: predicted-probability histogram for each true class
for class_value, class_name in ((0, 'Incomplete'), (1, 'Complete')):
    hist_ax.hist(test_preds[test_labels == class_value], bins=30, alpha=0.5,
                 label=class_name, density=True)
hist_ax.set_xlabel('Predicted Probability')
hist_ax.set_ylabel('Density')
hist_ax.set_title('Prediction Distribution by Class')
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUT_DIR / 'calibration.png', dpi=150)
plt.show()

In [None]:
# Confusion matrix at the 0.5 decision threshold
cm = confusion_matrix(test_labels, test_preds > 0.5)
class_names = ['Incomplete', 'Complete']

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig(OUT_DIR / 'confusion_matrix.png', dpi=150)
plt.show()

### Evaluate the pretrained model

In [None]:
# Restore a previously saved model from disk
import torch

# Rebuild the architecture with the same hyperparameters used for training
pretrained_config = dict(
    num_routes=NUM_ROUTES,
    route_embedding_dim=32,
    player_feature_dim=256,
    num_transformer_layers=4,
    num_attention_heads=8,
    dropout=0.1,
)
model = RouteAwareCompletionModel(**pretrained_config)

# Read the checkpoint file.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects; only use this with checkpoint files from a trusted source.
MODEL_PATH = 'recap_model.pt'
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Only the weights entry is needed to restore the network
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)

print(f"Model loaded from {MODEL_PATH}")

# Optional metadata that may be stored alongside the weights
if 'route_to_id' in checkpoint:
    route_to_id = checkpoint['route_to_id']
    print(f"Loaded route mapping with {len(route_to_id)} routes")

if 'num_routes' in checkpoint:
    print(f"Number of routes: {checkpoint['num_routes']}")

if 'history' in checkpoint:
    print("Training history available in checkpoint")

In [None]:
# Re-score the test set with the model restored from the checkpoint
test_preds, test_labels, test_route_ids = evaluate_model(
    model=model,
    test_loader=test_loader,
)

In [None]:
# Enhanced Calibration Plot - Patriots Colors!
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib font settings
plt.rcParams.update({'font.family': 'sans-serif', 'font.size': 11})

# Two side-by-side panels on a light grey canvas
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.patch.set_facecolor('#fafafa')

# Patriots color palette
colors = dict(
    model='#002244',       # Patriots Navy
    perfect='#C60C30',     # Patriots Red
    incomplete='#C60C30',  # Red for incomplete
    complete='#002244',    # Blue for complete
)

# Keyword bundles shared by both panels
axis_label_style = dict(fontsize=12, fontweight='600')
title_style = dict(fontsize=14, fontweight='700', pad=15)
legend_style = dict(frameon=True, shadow=True, fontsize=11)
grid_style = dict(alpha=0.25, linestyle='-', linewidth=0.5)
hist_style = dict(bins=35, alpha=0.65, density=True, edgecolor='white', linewidth=0.5)

ax1, ax2 = axes
for panel in (ax1, ax2):
    panel.set_facecolor('white')

# ===== LEFT: Calibration Curve =====
prob_true, prob_pred = calibration_curve(test_labels, test_preds, n_bins=10)

ax1.plot(prob_pred, prob_true, 'o-', color=colors['model'],
         linewidth=2.5, markersize=8, label='Model',
         markeredgecolor='white', markeredgewidth=1.5)
ax1.plot([0, 1], [0, 1], '--', color=colors['perfect'],
         linewidth=2, alpha=0.7, label='Perfect Calibration')

# Shade the gap between the model curve and the diagonal
ax1.fill_between(prob_pred, prob_true, prob_pred, alpha=0.1, color=colors['model'])

ax1.set_xlabel('Mean Predicted Probability', **axis_label_style)
ax1.set_ylabel('Fraction of Positives', **axis_label_style)
ax1.set_title('Calibration Curve', **title_style)
ax1.legend(loc='upper left', **legend_style)
ax1.grid(True, **grid_style)
ax1.set_xlim(-0.02, 1.02)
ax1.set_ylim(-0.02, 1.02)

# ===== RIGHT: Prediction Distribution =====
# One density histogram per true class
n_incomplete, bins_incomplete, _ = ax2.hist(
    test_preds[test_labels == 0], label='Incomplete',
    color=colors['incomplete'], **hist_style
)

n_complete, bins_complete, _ = ax2.hist(
    test_preds[test_labels == 1], label='Complete',
    color=colors['complete'], **hist_style
)

# Mark the 0.5 decision threshold
ax2.axvline(x=0.5, color='#B0B7BC', linestyle='--', linewidth=2,
            alpha=0.8, label='Threshold (0.5)')

ax2.set_xlabel('Predicted Probability', **axis_label_style)
ax2.set_ylabel('Density', **axis_label_style)
ax2.set_title('Prediction Distribution by Class', **title_style)
ax2.legend(loc='upper center', **legend_style)
ax2.grid(True, axis='y', **grid_style)
ax2.set_xlim(-0.02, 1.02)

# Soften the frame of both panels
for panel in (ax1, ax2):
    for spine in panel.spines.values():
        spine.set_edgecolor('#cccccc')
        spine.set_linewidth(1.2)

plt.tight_layout(pad=2.5)
plt.savefig(OUT_DIR / 'enhanced_calibration.png', dpi=300, bbox_inches='tight', facecolor='#fafafa')
plt.show()

# Summary numbers; the "Calibration Error" printed here is the absolute gap
# between the mean predicted probability and the observed positive rate
mean_predicted = test_preds.mean()
positive_rate = test_labels.mean()
print("\n" + "="*50)
print("CALIBRATION STATISTICS")
print("="*50)
print(f"Mean Predicted Probability: {mean_predicted:.4f}")
print(f"Actual Positive Rate: {positive_rate:.4f}")
print(f"Calibration Error: {abs(mean_predicted - positive_rate):.4f}")

In [None]:
# Enhanced Prediction Distribution - Patriots Colors!
import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib font settings
plt.rcParams.update({'font.family': 'sans-serif', 'font.size': 11})

# Single panel on a light grey canvas with a white plotting area
fig, ax = plt.subplots(figsize=(8, 6))
fig.patch.set_facecolor('#fafafa')
ax.set_facecolor('white')

# Patriots color palette
colors = dict(
    incomplete='#C60C30',  # Red for incomplete
    complete='#002244',    # Blue for complete
    threshold='#B0B7BC',   # Silver for threshold
)

# Styling shared by both class histograms
hist_style = dict(bins=35, alpha=0.65, density=True, edgecolor='white', linewidth=0.5)

# One density histogram per true class
n_incomplete, bins_incomplete, _ = ax.hist(
    test_preds[test_labels == 0], label='Incomplete',
    color=colors['incomplete'], **hist_style
)

n_complete, bins_complete, _ = ax.hist(
    test_preds[test_labels == 1], label='Complete',
    color=colors['complete'], **hist_style
)

# Mark the 0.5 decision threshold
ax.axvline(x=0.5, color=colors['threshold'], linestyle='--', linewidth=2,
           alpha=0.8, label='Threshold (0.5)')

axis_label_style = dict(fontsize=12, fontweight='600')
ax.set_xlabel('Predicted Probability', **axis_label_style)
ax.set_ylabel('Density', **axis_label_style)
ax.set_title('Prediction Distribution by Class', fontsize=14, fontweight='700', pad=15)
ax.legend(loc='upper center', frameon=True, shadow=True, fontsize=11)
ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.5, axis='y')
ax.set_xlim(-0.02, 1.02)

# Soften the axes frame
for spine in ax.spines.values():
    spine.set_edgecolor('#cccccc')
    spine.set_linewidth(1.2)

plt.tight_layout(pad=1.5)
plt.savefig(OUT_DIR / 'prediction_distribution.png', dpi=300, bbox_inches='tight', facecolor='#fafafa')
plt.show()

# Summary numbers for the distribution
mean_predicted = test_preds.mean()
positive_rate = test_labels.mean()
share_above_threshold = (test_preds > 0.5).mean()
print("\n" + "="*50)
print("DISTRIBUTION STATISTICS")
print("="*50)
print(f"Mean Predicted Probability: {mean_predicted:.4f}")
print(f"Actual Positive Rate: {positive_rate:.4f}")
print(f"Predictions > 0.5: {share_above_threshold:.4f}")

In [None]:
# Enhanced Prediction Distribution
# NOTE: saves to 'prediction_distribution.png', the same path the
# "Patriots Colors" distribution cell writes to, so the file is overwritten.
import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib font settings
plt.rcParams.update({'font.family': 'sans-serif', 'font.size': 11})

# Single panel on a light grey canvas with a white plotting area
fig, ax = plt.subplots(figsize=(8, 6))
fig.patch.set_facecolor('#fafafa')
ax.set_facecolor('white')

# Patriots color palette
colors = dict(
    incomplete='#C60C30',  # Red for incomplete
    complete='#002244',    # Blue for complete
    threshold='#B0B7BC',   # Silver for threshold
)

# Styling shared by both class histograms
hist_style = dict(bins=35, alpha=0.65, density=True, edgecolor='white', linewidth=0.5)
scores_incomplete = test_preds[test_labels == 0]
scores_complete = test_preds[test_labels == 1]

# One density histogram per true class
n_incomplete, bins_incomplete, _ = ax.hist(
    scores_incomplete, label='Incomplete', color=colors['incomplete'], **hist_style
)

n_complete, bins_complete, _ = ax.hist(
    scores_complete, label='Complete', color=colors['complete'], **hist_style
)

# Mark the 0.5 decision threshold
ax.axvline(x=0.5, color=colors['threshold'], linestyle='--', linewidth=2,
           alpha=0.8, label='Threshold (0.5)')

axis_label_style = dict(fontsize=12, fontweight='600')
ax.set_xlabel('Predicted Probability', **axis_label_style)
ax.set_ylabel('Density', **axis_label_style)
ax.set_title('Prediction Distribution by Class', fontsize=14, fontweight='700', pad=15)
ax.legend(loc='upper center', frameon=True, shadow=True, fontsize=11)
ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.5, axis='y')
ax.set_xlim(-0.02, 1.02)

# Soften the axes frame
for spine in ax.spines.values():
    spine.set_edgecolor('#cccccc')
    spine.set_linewidth(1.2)

plt.tight_layout(pad=1.5)
plt.savefig(OUT_DIR / 'prediction_distribution.png', dpi=300, bbox_inches='tight', facecolor='#fafafa')
plt.show()

# Summary numbers for the distribution
mean_predicted = test_preds.mean()
positive_rate = test_labels.mean()
share_above_threshold = (test_preds > 0.5).mean()
print("\n" + "="*50)
print("DISTRIBUTION STATISTICS")
print("="*50)
print(f"Mean Predicted Probability: {mean_predicted:.4f}")
print(f"Actual Positive Rate: {positive_rate:.4f}")
print(f"Predictions > 0.5: {share_above_threshold:.4f}")

In [None]:
# Confusion matrix for the current test predictions, thresholded at 0.5
predicted_complete = test_preds > 0.5
cm = confusion_matrix(test_labels, predicted_complete)
outcome_names = ['Incomplete', 'Complete']

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=outcome_names,
    yticklabels=outcome_names,
)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.savefig(OUT_DIR / 'confusion_matrix.png', dpi=150)
plt.show()

## Part 8: Counterfactual Prediction Wrapper

In [None]:
class CounterfactualCompletionPredictor:
    """
    Wrapper for making counterfactual predictions using the trained model.

    Enables "what if" scenario analysis:
    - What if the QB threw earlier/later?
    - What if the QB threw to a different location?
    """

    def __init__(
        self,
        model: RouteAwareCompletionModel,
        route_to_id: dict,
        device: str = DEVICE
    ):
        """
        Args:
            model: Trained completion model; moved to ``device`` and set to eval mode.
            route_to_id: Mapping from route-type name to the integer id fed to the model.
            device: Torch device string used for the model and all input tensors.
        """
        self.model = model.to(device)
        self.model.eval()
        self.route_to_id = route_to_id
        self.device = device

    def predict_at_frame(
        self,
        play_data: pd.DataFrame,
        hypothetical_frame: int,
        target_x: float,
        target_y: float
    ) -> Dict[str, float]:
        """
        Predict completion probability if ball thrown at hypothetical_frame to (target_x, target_y).

        Player positions and velocities are taken as recorded at
        ``hypothetical_frame``; nothing is projected forward in time.

        Args:
            play_data: Full play tracking data (all frames). Must contain
                'frame_id', 'player_role', 'x', 'y', 'vx', 'vy', 'ox', 'oy',
                's' and 'a'; 'route_type' is optional.
            hypothetical_frame: Frame ID where we imagine the throw happens
            target_x: Ball landing x-coordinate
            target_y: Ball landing y-coordinate

        Returns:
            dict with 'completion_prob', 'air_distance', 'target_separation',
            'time_to_throw' and 'momentum_alignment'

        Raises:
            ValueError: If the frame is not in ``play_data`` or it has no passer.
        """
        # Get player positions at hypothetical frame
        frame_data = play_data[play_data['frame_id'] == hypothetical_frame].copy()

        if len(frame_data) == 0:
            raise ValueError(f"Frame {hypothetical_frame} not found in play data")

        # Get QB data
        qb_data = frame_data[frame_data['player_role'] == 'Passer']
        if len(qb_data) == 0:
            raise ValueError("No passer found in frame data")
        qb_data = qb_data.iloc[0]

        qb_x, qb_y = qb_data['x'], qb_data['y']

        # Straight-line distance from the passer to the target point
        air_distance = np.sqrt((target_x - qb_x)**2 + (target_y - qb_y)**2)

        # Get receiver data
        receiver_data = frame_data[frame_data['player_role'] == 'Targeted Receiver']
        if len(receiver_data) == 0:
            # No targeted receiver in this frame: use a stationary placeholder
            # located at the target point
            receiver_vx, receiver_vy = 0.0, 0.0
            receiver_speed = 0.0
            receiver_x, receiver_y = target_x, target_y
            receiver_ox, receiver_oy = 1.0, 0.0
        else:
            receiver_data = receiver_data.iloc[0]
            receiver_x, receiver_y = receiver_data['x'], receiver_data['y']
            receiver_vx, receiver_vy = receiver_data['vx'], receiver_data['vy']
            receiver_speed = receiver_data['s']
            receiver_ox = receiver_data.get('ox', 1.0)
            receiver_oy = receiver_data.get('oy', 0.0)

        # Vector from the receiver to the target; epsilon avoids division by zero
        target_vec_x = target_x - receiver_x
        target_vec_y = target_y - receiver_y
        target_vec_mag = np.sqrt(target_vec_x**2 + target_vec_y**2) + 1e-6

        # Velocity / orientation components along the receiver-to-target direction
        momentum_alignment = (receiver_vx * target_vec_x + receiver_vy * target_vec_y) / target_vec_mag
        orientation_alignment = (receiver_ox * target_vec_x + receiver_oy * target_vec_y) / target_vec_mag
        receiver_to_ball_distance = target_vec_mag  # How far receiver must travel to catch

        # Distance from the target point to the nearest coverage defender
        defenders = frame_data[frame_data['player_role'] == 'Defensive Coverage']
        if len(defenders) > 0:
            def_dists = np.sqrt(
                (defenders['x'] - target_x)**2 +
                (defenders['y'] - target_y)**2
            )
            target_separation = def_dists.min()
        else:
            target_separation = 10.0

        # Distance to the nearer sideline
        # NOTE(review): 53.3 is presumably the field width in yards; confirm
        sideline_separation = min(target_y, 53.3 - target_y)

        # The earliest frame in play_data is treated as the snap
        # NOTE(review): the /10.0 assumes 10 tracking frames per second; confirm
        snap_frame = play_data['frame_id'].min()
        time_to_throw = (hypothetical_frame - snap_frame) / 10.0

        # Route type; routes missing from the mapping fall back to id 0
        route_type = play_data['route_type'].iloc[0] if 'route_type' in play_data.columns else 'UNKNOWN'
        route_id = self.route_to_id.get(route_type, 0)

        # Player features, one row per tracked player in this frame
        player_cols = ['x', 'y', 'vx', 'vy', 'ox', 'oy', 's', 'a']
        player_features = frame_data[player_cols].values.astype(np.float32)

        # Zero-pad or truncate to exactly 22 player rows
        if len(player_features) < 22:
            padding = np.zeros((22 - len(player_features), 8), dtype=np.float32)
            player_features = np.vstack([player_features, padding])
        elif len(player_features) > 22:
            player_features = player_features[:22]

        # nan= must be a keyword: the second positional argument of nan_to_num is `copy`
        player_features = np.nan_to_num(player_features, nan=0.0)

        # Momentum features
        momentum_features = np.array([
            receiver_vx, receiver_vy, receiver_speed,
            momentum_alignment, orientation_alignment, receiver_to_ball_distance,  # Match training!
            time_to_throw
        ], dtype=np.float32)
        momentum_features = np.nan_to_num(momentum_features, nan=0.0)

        # NGS features
        ngs_features = np.array([
            air_distance,
            target_separation,
            sideline_separation,
            qb_data['s'],
            time_to_throw
        ], dtype=np.float32)
        ngs_features = np.nan_to_num(ngs_features, nan=0.0)

        # Convert to tensors with a leading batch dimension of 1
        player_features = torch.tensor(player_features, dtype=torch.float32).unsqueeze(0).to(self.device)
        route_id_tensor = torch.tensor([route_id], dtype=torch.long).to(self.device)
        momentum_features = torch.tensor(momentum_features, dtype=torch.float32).unsqueeze(0).to(self.device)
        ngs_features = torch.tensor(ngs_features, dtype=torch.float32).unsqueeze(0).to(self.device)

        # Predict
        with torch.no_grad():
            logits = self.model(player_features, route_id_tensor, momentum_features, ngs_features)
            completion_prob = torch.sigmoid(logits).item()

        return {
            'completion_prob': completion_prob,
            'air_distance': air_distance,
            'target_separation': target_separation,
            'time_to_throw': time_to_throw,
            'momentum_alignment': momentum_alignment
        }

    def analyze_play_timing(
        self,
        play_data: pd.DataFrame,
        target_x: Optional[float] = None,
        target_y: Optional[float] = None,
        window: int = 10
    ) -> pd.DataFrame:
        """
        Analyze completion probability across different throw timings.

        The last frame in ``play_data`` is treated as the actual release.
        Frames from ``max(first_frame + 5, last_frame - window)`` through the
        last frame are scored; frames that fail are reported and skipped.

        Args:
            play_data: Full play tracking data
            target_x, target_y: Target location (defaults to the first row's
                'ball_land_x' / 'ball_land_y')
            window: Number of frames before the last frame to analyze

        Returns:
            DataFrame with columns 'frame_id', 'time_since_snap',
            'completion_prob' and 'is_actual_release' (empty if every frame failed)
        """
        # Get actual target location if not provided
        if target_x is None:
            target_x = play_data['ball_land_x'].iloc[0]
        if target_y is None:
            target_y = play_data['ball_land_y'].iloc[0]

        # Get frame range
        min_frame = play_data['frame_id'].min()
        max_frame = play_data['frame_id'].max()

        # Analyze frames
        results = []
        for frame_id in range(max(min_frame + 5, max_frame - window), max_frame + 1):
            try:
                pred = self.predict_at_frame(play_data, frame_id, target_x, target_y)
                results.append({
                    'frame_id': frame_id,
                    'time_since_snap': (frame_id - min_frame) / 10.0,
                    'completion_prob': pred['completion_prob'],
                    'is_actual_release': frame_id == max_frame
                })
            except Exception as e:
                print(f"Warning: Could not analyze frame {frame_id}: {e}")

        return pd.DataFrame(results)

    def analyze_play_spatial(
        self,
        play_data: pd.DataFrame,
        release_frame: Optional[int] = None,
        x_range: Optional[tuple] = None,
        y_range: tuple = (5, 48),
        grid_size: float = 3.0
    ) -> Dict:
        """
        Create spatial heatmap of completion probabilities.

        Args:
            play_data: Full play tracking data
            release_frame: Frame to analyze (default: last frame)
            x_range: (min_x, max_x) for grid (default: 15 either side of the
                ball landing x, clamped to [10, 110])
            y_range: (min_y, max_y) for grid
            grid_size: Grid spacing in yards

        Returns:
            dict with 'heatmap' (rows follow y_grid, columns follow x_grid; NaN
            where prediction failed), 'x_grid', 'y_grid', the best cell as
            'optimal_x' / 'optimal_y' / 'optimal_prob', and the recorded
            landing point as 'actual_x' / 'actual_y'
        """
        if release_frame is None:
            release_frame = play_data['frame_id'].max()

        # Default x_range around ball landing
        actual_x = play_data['ball_land_x'].iloc[0]
        actual_y = play_data['ball_land_y'].iloc[0]

        if x_range is None:
            x_range = (max(10, actual_x - 15), min(110, actual_x + 15))

        # Create grid
        x_grid = np.arange(x_range[0], x_range[1], grid_size)
        y_grid = np.arange(y_range[0], y_range[1], grid_size)

        heatmap = np.zeros((len(y_grid), len(x_grid)))

        for i, y in enumerate(tqdm(y_grid, desc="Generating spatial heatmap")):
            for j, x in enumerate(x_grid):
                try:
                    pred = self.predict_at_frame(play_data, release_frame, x, y)
                    heatmap[i, j] = pred['completion_prob']
                except Exception:
                    # Best effort: mark the cell as missing. Catching Exception
                    # (not a bare except) keeps Ctrl-C able to stop the loop.
                    heatmap[i, j] = np.nan

        # Find optimal location among the cells that produced a prediction
        valid_mask = ~np.isnan(heatmap)
        if valid_mask.any():
            max_idx = np.unravel_index(np.nanargmax(heatmap), heatmap.shape)
            optimal_x = x_grid[max_idx[1]]
            optimal_y = y_grid[max_idx[0]]
            optimal_prob = heatmap[max_idx]
        else:
            optimal_x, optimal_y, optimal_prob = actual_x, actual_y, np.nan

        return {
            'heatmap': heatmap,
            'x_grid': x_grid,
            'y_grid': y_grid,
            'optimal_x': optimal_x,
            'optimal_y': optimal_y,
            'optimal_prob': optimal_prob,
            'actual_x': actual_x,
            'actual_y': actual_y
        }

In [None]:
# Wrap the trained model for "what if" completion-probability queries
predictor = CounterfactualCompletionPredictor(
    model=trained_model,
    route_to_id=route_to_id,
)

## Part 9: Visualization Functions

In [None]:
def plot_temporal_analysis(results: pd.DataFrame, save_path: Optional[Path] = None):
    """
    Plot completion probability over time.

    Draws the probability curve, marks the row flagged as the actual release
    in red and the highest-probability row in green, then prints both.

    Args:
        results: DataFrame with columns 'time_since_snap', 'completion_prob'
            and 'is_actual_release'.
        save_path: If given, the figure is also written to this path.

    Returns:
        None. When ``results`` is empty or has no non-NaN probability, a
        message is printed and nothing is drawn.
    """
    # An empty frame has no columns, and an all-NaN column has no maximum;
    # either would raise below, so bail out early.
    if results.empty or results['completion_prob'].isna().all():
        print("No completion probabilities to plot.")
        return

    fig, ax = plt.subplots(figsize=(12, 6))

    ax.plot(results['time_since_snap'], results['completion_prob'],
            linewidth=2, marker='o', label='Completion Probability')

    # Mark actual release
    actual_release = results[results['is_actual_release']]
    if len(actual_release) > 0:
        ax.scatter(actual_release['time_since_snap'], actual_release['completion_prob'],
                  s=200, c='red', marker='*', label='Actual Release', zorder=5)

    # Mark optimal release (row with the highest predicted probability)
    optimal_idx = results['completion_prob'].idxmax()
    optimal = results.loc[optimal_idx]
    ax.scatter(optimal['time_since_snap'], optimal['completion_prob'],
              s=200, c='green', marker='*', label='Optimal Release', zorder=5)

    ax.set_xlabel('Time Since Snap (seconds)', fontsize=12)
    ax.set_ylabel('Completion Probability', fontsize=12)
    ax.set_title('Temporal Analysis: When Should QB Throw?', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150)
    plt.show()

    if len(actual_release) > 0:
        print(f"Actual Release: {actual_release['time_since_snap'].iloc[0]:.2f}s, "
              f"Prob: {actual_release['completion_prob'].iloc[0]:.1%}")
    print(f"Optimal Release: {optimal['time_since_snap']:.2f}s, "
          f"Prob: {optimal['completion_prob']:.1%}")

In [None]:
"""
Interactive Completion Probability Heatmap Viewer - Colab Compatible

Creates an interactive widget to scroll through completion probability heatmaps across frames.
Shows where the QB should throw at each moment in the play.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ipywidgets import IntSlider, Layout, VBox, HBox, Button, Output, Checkbox, HTML, FloatSlider
from IPython.display import display, clear_output
import time


def _grid_axis(lower: float, upper: float, step: float) -> np.ndarray:
    """Return points from *lower* in increments of *step*, ending exactly on *upper*."""
    points = np.arange(lower, upper, step, dtype=float)
    # Drop a point that lands (numerically) on the upper bound so that the
    # bound itself is appended exactly once.
    points = points[points < upper - 1e-9 * step]
    return np.append(points, float(upper))


def compute_completion_heatmap(
    predictor,
    play_data: pd.DataFrame,
    frame_id: int,
    grid_size: float = 2.0,
    x_range: tuple = None,
    y_range: tuple = None
) -> dict:
    """
    Compute completion probability heatmap for a given frame.

    Every grid location is passed to ``predictor.predict_at_frame`` as a
    hypothetical ball landing spot. The grid never leaves the requested
    range: the last row/column sits exactly on the upper bound, so the final
    cell may be narrower than ``grid_size``.

    Args:
        predictor: Object exposing
            ``predict_at_frame(play_data, frame_id, x, y)`` that returns a
            mapping with a ``'completion_prob'`` entry
        play_data: Full play tracking data; must contain ``ball_land_x``,
            ``ball_land_y`` and ``frame_id`` columns
        frame_id: Frame to analyze
        grid_size: Grid spacing in yards (must be positive)
        x_range: (min_x, max_x) - defaults to (0, 120)
        y_range: (min_y, max_y) - defaults to (0, 53.3)

    Returns:
        dict with keys ``heatmap`` (rows follow ``y_grid``, columns follow
        ``x_grid``), ``x_grid``, ``y_grid``, ``optimal_x``, ``optimal_y``,
        ``optimal_prob``, ``actual_x``, ``actual_y``, ``frame_data`` (rows of
        ``play_data`` for ``frame_id``) and ``n_failed`` (number of grid
        locations whose prediction raised and was recorded as 0.0).

    Raises:
        ValueError: If ``grid_size`` is not positive.
    """
    if grid_size <= 0:
        raise ValueError(f"grid_size must be positive, got {grid_size}")

    # Default to full field
    if x_range is None:
        x_range = (0, 120)
    if y_range is None:
        y_range = (0, 53.3)

    # Get actual ball landing location
    actual_x = play_data['ball_land_x'].iloc[0]
    actual_y = play_data['ball_land_y'].iloc[0]

    # Create grid (clamped to the requested range)
    x_grid = _grid_axis(x_range[0], x_range[1], grid_size)
    y_grid = _grid_axis(y_range[0], y_range[1], grid_size)

    heatmap = np.zeros((len(y_grid), len(x_grid)))

    # Compute completion probability for each grid location. A failing
    # prediction is still recorded as 0.0 (best effort), but it is counted so
    # an all-zero heatmap caused by errors does not go unnoticed.
    n_failed = 0
    first_error = None
    for i, y in enumerate(y_grid):
        for j, x in enumerate(x_grid):
            try:
                pred = predictor.predict_at_frame(play_data, frame_id, x, y)
                heatmap[i, j] = pred['completion_prob']
            except Exception as exc:
                heatmap[i, j] = 0.0
                n_failed += 1
                if first_error is None:
                    first_error = exc

    if n_failed:
        print(f"Warning: {n_failed}/{heatmap.size} grid predictions failed for "
              f"frame {frame_id} (first error: {first_error!r}); "
              f"those cells are set to 0.0")

    # Find optimal location
    max_idx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    optimal_y = y_grid[max_idx[0]]
    optimal_x = x_grid[max_idx[1]]
    optimal_prob = heatmap[max_idx]

    # Get frame data for player positions
    frame_data = play_data[play_data['frame_id'] == frame_id]

    return {
        'heatmap': heatmap,
        'x_grid': x_grid,
        'y_grid': y_grid,
        'optimal_x': optimal_x,
        'optimal_y': optimal_y,
        'optimal_prob': optimal_prob,
        'actual_x': actual_x,
        'actual_y': actual_y,
        'frame_data': frame_data,
        'n_failed': n_failed
    }


def _yard_marker_label(x: int) -> str:
    """Return the yard-line label for field x-coordinate *x* (end zones are 10 yards deep)."""
    yards_from_goal = min(x, 120 - x) - 10
    return "G" if yards_from_goal <= 0 else str(yards_from_goal)


def _scatter_players(ax, players, **style):
    """Scatter one group of players with a white outline; empty groups are skipped."""
    if len(players) > 0:
        ax.scatter(players['x'], players['y'], edgecolors='white', **style)


def plot_completion_heatmap_static(
    heatmap_results: dict,
    frame_id: int,
    show_players: bool = True,
    show_optimal: bool = True,
    show_actual: bool = True
):
    """
    Plot completion probability heatmap for a single frame.

    The field is drawn on a 0-120 by 0-53.3 coordinate system. Yard labels
    assume the first and last 10 units of x are the end zones, so x=10 and
    x=110 are the goal lines ("G") and x=60 is the 50-yard line.

    Args:
        heatmap_results: Dict from compute_completion_heatmap
        frame_id: Frame ID for title
        show_players: Whether to overlay player positions (also controls
            whether the legend is drawn)
        show_optimal: Whether to mark optimal throw location
        show_actual: Whether to mark actual throw location

    Returns:
        matplotlib figure
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    fig.patch.set_facecolor('#2F5233')

    # Plot heatmap
    im = ax.contourf(
        heatmap_results['x_grid'],
        heatmap_results['y_grid'],
        heatmap_results['heatmap'],
        levels=20,
        cmap='RdYlGn',
        alpha=0.7
    )

    # Add player positions
    if show_players and 'frame_data' in heatmap_results:
        frame_data = heatmap_results['frame_data']

        # Offensive players (blue)
        offense = frame_data[frame_data['player_side'] == 'Offense']
        is_target = offense['player_role'] == 'Targeted Receiver'
        is_passer = offense['player_role'] == 'Passer'
        _scatter_players(ax, offense[is_target],
                         s=500, c='darkblue', marker='*',
                         label='Targeted Receiver', zorder=10, linewidths=2)
        _scatter_players(ax, offense[is_passer],
                         s=400, c='blue', marker='^',
                         label='Passer', zorder=10, linewidths=2)
        _scatter_players(ax, offense[~is_target & ~is_passer],
                         s=200, c='lightblue', marker='o',
                         label='Other Offense', zorder=9, linewidths=1.5)

        # Defensive players (red)
        defense = frame_data[frame_data['player_side'] == 'Defense']
        _scatter_players(ax, defense[defense['player_role'] == 'Defensive Coverage'],
                         s=200, c='red', marker='D',
                         label='Coverage', zorder=9, linewidths=1.5)
        _scatter_players(ax, defense[defense['player_role'] == 'Pass Rush'],
                         s=150, c='darkred', marker='s',
                         label='Pass Rush', zorder=8, linewidths=1.5)

    # Mark optimal location
    if show_optimal:
        ax.scatter(heatmap_results['optimal_x'], heatmap_results['optimal_y'],
                  s=600, c='lime', marker='*',
                  label=f"Optimal ({heatmap_results['optimal_prob']:.1%})",
                  zorder=15, edgecolors='black', linewidths=3)

    # Mark actual throw
    if show_actual:
        ax.scatter(heatmap_results['actual_x'], heatmap_results['actual_y'],
                  s=600, c='yellow', marker='X',
                  label='Actual Target', zorder=15,
                  edgecolors='black', linewidths=3)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Completion Probability', fontsize=14, weight='bold', color='white')
    cbar.ax.yaxis.set_tick_params(color='white')
    plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')

    # Field markings: dashed sidelines, solid lines at both ends of the field
    ax.axhline(y=0, color='white', linewidth=4, linestyle='--', alpha=0.5)
    ax.axhline(y=53.3, color='white', linewidth=4, linestyle='--', alpha=0.5)
    ax.axvline(x=0, color='white', linewidth=4, linestyle='-', alpha=0.8)
    ax.axvline(x=120, color='white', linewidth=4, linestyle='-', alpha=0.8)

    # Yard lines: strong every 10 yards, faint on the 5s in between
    for x in range(10, 120, 10):
        ax.axvline(x=x, color='white', linewidth=2, alpha=0.6)

    for x in range(5, 120, 5):
        if x % 10 != 0:
            ax.axvline(x=x, color='white', linewidth=1, alpha=0.3)

    # Yard markers, written below and above the field
    for x in range(10, 120, 10):
        label = _yard_marker_label(x)
        ax.text(x, -2, label, ha='center', va='top', fontsize=12,
                color='white', weight='bold')
        ax.text(x, 55.3, label, ha='center', va='bottom', fontsize=12,
                color='white', weight='bold')

    # Styling
    ax.set_xlim(-2, 122)
    ax.set_ylim(-3, 56.3)
    ax.set_aspect('equal')
    ax.set_facecolor('#2F5233')

    ax.set_xlabel('Field Position (yards)', fontsize=16, weight='bold', color='white')
    ax.set_ylabel('Field Width (yards)', fontsize=16, weight='bold', color='white')

    # Title with stats
    title = (f'Completion Probability Heatmap: Frame {frame_id}\n'
             f'Optimal: {heatmap_results["optimal_prob"]:.1%} at '
             f'({heatmap_results["optimal_x"]:.1f}, {heatmap_results["optimal_y"]:.1f})')
    ax.set_title(title, fontsize=18, weight='bold', color='white', pad=20)

    # Legend
    if show_players:
        legend = ax.legend(fontsize=11, loc='upper center', ncol=4,
                          bbox_to_anchor=(0.5, -0.05),
                          framealpha=0.9, edgecolor='white')
        legend.get_frame().set_facecolor('#2F5233')
        for text in legend.get_texts():
            text.set_color('white')

    plt.tight_layout()
    return fig


class CompletionHeatmapViewer:
    """
    Interactive viewer for completion probability heatmaps across frames.

    Wires a frame slider, a play/pause button, a speed slider and two
    marker checkboxes to an output area that shows the heatmap of the
    selected frame. Heatmaps are computed on demand and cached per frame.
    """

    def __init__(self, predictor, play_data: pd.DataFrame, grid_size: float = 2.0):
        """
        Initialize the viewer.

        Args:
            predictor: CounterfactualCompletionPredictor instance
            play_data: Full play tracking data (needs a ``frame_id`` column
                with at least one row)
            grid_size: Grid spacing in yards
        """
        self.predictor = predictor
        self.play_data = play_data
        self.grid_size = grid_size

        # Get frame range
        self.frames = sorted(play_data['frame_id'].unique())
        self.min_frame = self.frames[0]
        self.max_frame = self.frames[-1]

        # Cache computed heatmaps (frame_id -> result dict)
        self.heatmap_cache = {}

        # Animation state
        self.playing = False

        # Create widgets
        self._create_widgets()

    def _create_widgets(self):
        """Create the control widgets and connect their change handlers."""
        # Frame slider
        self.frame_slider = IntSlider(
            value=self.max_frame,  # Start at the last frame (presumably the release frame)
            min=self.min_frame,
            max=self.max_frame,
            step=1,
            description='Frame:',
            style={'description_width': '60px'},
            layout=Layout(width='80%')
        )

        # Play/Pause button
        self.play_button = Button(
            description='▶ Play',
            button_style='success',
            layout=Layout(width='100px')
        )
        self.play_button.on_click(self._toggle_play)

        # Speed slider (the per-frame delay is 1 / value seconds)
        self.speed_slider = IntSlider(
            value=5,
            min=1,
            max=10,
            step=1,
            description='Speed:',
            style={'description_width': '60px'},
            layout=Layout(width='300px')
        )

        # Show optimal checkbox
        self.optimal_checkbox = Checkbox(
            value=True,
            description='Show Optimal',
            layout=Layout(width='150px')
        )

        # Show actual checkbox
        self.actual_checkbox = Checkbox(
            value=True,
            description='Show Actual',
            layout=Layout(width='150px')
        )

        # Frame info label
        self.info_label = HTML(
            value=f"<b>Frame {self.max_frame} / {self.max_frame}</b>",
            layout=Layout(width='200px')
        )

        # Output for plot
        self.output = Output()

        # Connect observers
        self.frame_slider.observe(self._on_frame_change, names='value')
        self.optimal_checkbox.observe(self._on_display_change, names='value')
        self.actual_checkbox.observe(self._on_display_change, names='value')

    def _compute_heatmap(self, frame_id: int):
        """Return the heatmap results for a frame, computing them on first use."""
        if frame_id not in self.heatmap_cache:
            print(f"Computing heatmap for frame {frame_id}...")
            self.heatmap_cache[frame_id] = compute_completion_heatmap(
                self.predictor,
                self.play_data,
                frame_id=frame_id,
                grid_size=self.grid_size
            )
        return self.heatmap_cache[frame_id]

    def _on_frame_change(self, change):
        """Redraw for the frame the slider moved to."""
        frame_id = change['new']
        self._update_plot(frame_id)

    def _on_display_change(self, change):
        """Redraw the current frame after a checkbox was toggled."""
        frame_id = self.frame_slider.value
        self._update_plot(frame_id)

    def _update_plot(self, frame_id: int):
        """Replace the content of the output area with the plot for a frame."""
        with self.output:
            clear_output(wait=True)

            # Compute heatmap
            heatmap_results = self._compute_heatmap(frame_id)

            # Create and display plot
            fig = plot_completion_heatmap_static(
                heatmap_results,
                frame_id=frame_id,
                show_players=True,
                show_optimal=self.optimal_checkbox.value,
                show_actual=self.actual_checkbox.value
            )

            # Update info label
            self.info_label.value = f"<b>Frame {frame_id} / {self.max_frame}</b>"

            plt.show()
            # Close the figure so repeated redraws do not accumulate figures
            plt.close(fig)

    def _toggle_play(self, button):
        """
        Toggle play/pause animation.

        NOTE(review): the animation loop runs inside this click callback, so
        a "Pause" click is presumably only handled once the loop has
        returned - confirm in the target front end.
        """
        if not self.playing:
            # Start playing
            self.playing = True
            self.play_button.description = '⏸ Pause'
            self.play_button.button_style = 'warning'
            self._play_animation()
        else:
            # Stop playing
            self.playing = False
            self.play_button.description = '▶ Play'
            self.play_button.button_style = 'success'

    def _play_animation(self):
        """
        Step the slider through all frames after its current position.

        The slider moves in steps of 1 between the first and last frame id,
        so it can rest on an id that does not occur in the data. The frames
        to play are therefore selected by comparison instead of looking the
        slider value up in ``self.frames``, which would raise ValueError for
        such an id.
        """
        current_frame = self.frame_slider.value
        upcoming_frames = [frame for frame in self.frames if frame > current_frame]

        for frame_id in upcoming_frames:
            if not self.playing:
                break

            # Update slider (triggers plot update)
            self.frame_slider.value = frame_id

            # Delay based on speed
            delay = 1.0 / self.speed_slider.value
            time.sleep(delay)

        # Animation finished
        if self.playing:
            self.playing = False
            self.play_button.description = '▶ Play'
            self.play_button.button_style = 'success'

    def display(self):
        """Lay out the controls, draw the last frame, and show the viewer."""
        # Layout controls
        controls_top = HBox([
            self.play_button,
            self.speed_slider,
            self.optimal_checkbox,
            self.actual_checkbox,
            self.info_label
        ])

        controls_bottom = HBox([self.frame_slider])

        # Combine everything
        viewer = VBox([
            controls_top,
            controls_bottom,
            self.output
        ])

        # Initial plot (matches the slider's initial value)
        self._update_plot(self.max_frame)

        # Display
        display(viewer)


# =============================================================================
# USAGE FUNCTION
# =============================================================================

def create_completion_heatmap_viewer(
    predictor,
    play_data: pd.DataFrame,
    grid_size: float = 2.0
):
    """
    Build a CompletionHeatmapViewer for one play, show it, and return it.

    Args:
        predictor: CounterfactualCompletionPredictor instance
        play_data: DataFrame with tracking data for the play
        grid_size: Grid spacing in yards (default: 2.0)

    Returns:
        The displayed CompletionHeatmapViewer, so the caller can keep a
        handle on its widgets and heatmap cache.

    Example:
        >>> sample_play_info = test_plays.iloc[0]
        >>> sample_play_data = raw_tracking.filter(
        ...     (pl.col('game_id') == sample_play_info['game_id']) &
        ...     (pl.col('play_id') == sample_play_info['play_id'])
        ... ).to_pandas()
        >>> viewer = create_completion_heatmap_viewer(
        ...     predictor=predictor,
        ...     play_data=sample_play_data,
        ...     grid_size=2.0
        ... )

    Controls offered by the viewer:
        - Frame slider to scroll through the frames of the play
        - "Play" button to step through the remaining frames automatically
        - Speed slider (1-10) that sets the pause between animated frames
        - Checkboxes to toggle the optimal / actual throw markers
    """
    heatmap_viewer = CompletionHeatmapViewer(
        predictor=predictor,
        play_data=play_data,
        grid_size=grid_size,
    )
    heatmap_viewer.display()
    return heatmap_viewer

## Part 10: Sample Play Analysis

In [None]:
# Get a sample play for analysis
# We need the full tracking data (not just release frame) for counterfactual analysis

# Load the raw tracking data and apply the same coordinate preprocessing steps
raw_tracking = load_input_data()
raw_tracking = convert_tracking_to_cartesian(raw_tracking)
raw_tracking = standardize_tracking_directions(raw_tracking)

# Get a sample play that exists in both raw tracking and test set.
# polars' unique() does not guarantee row order, so the keys are sorted to
# make the positional pick below select the same play on every run.
test_plays = (
    test_df.select(['game_id', 'play_id'])
    .unique()
    .sort(['game_id', 'play_id'])
    .to_pandas()
)
sample_play_info = test_plays.iloc[2]

print(f"Analyzing play: game_id={sample_play_info['game_id']}, play_id={sample_play_info['play_id']}")

In [None]:
# Row filter that matches the selected play in both tables
same_play = (
    (pl.col('game_id') == sample_play_info['game_id'])
    & (pl.col('play_id') == sample_play_info['play_id'])
)

# Full tracking data for this play
sample_play_data = raw_tracking.filter(same_play).to_pandas()

# Route type and outcome come from the supplementary data
play_supp = supp_df.filter(same_play).to_pandas()

if len(play_supp) > 0:
    route_label = play_supp['route_type'].iloc[0]
    completion_flag = play_supp['completion'].iloc[0]
else:
    # No supplementary row for this play: use placeholders
    route_label = 'UNKNOWN'
    completion_flag = np.nan

sample_play_data['route_type'] = route_label
sample_play_data['completion'] = completion_flag

# Quick sanity printout of what was loaded
frame_ids = sample_play_data['frame_id']
print(f"Play frames: {frame_ids.min()} to {frame_ids.max()}")
print(f"Route type: {sample_play_data['route_type'].iloc[0]}")
print(f"Actual completion: {sample_play_data['completion'].iloc[0]}")
print(f"Ball landing: ({sample_play_data['ball_land_x'].iloc[0]:.1f}, {sample_play_data['ball_land_y'].iloc[0]:.1f})")

In [None]:
# Score the sample play over time with the counterfactual predictor
timing_window = 15
timing_results = predictor.analyze_play_timing(sample_play_data, window=timing_window)
print(timing_results)

In [None]:
# Plot the timing curve and keep a copy in the output directory
temporal_plot_path = OUT_DIR / 'temporal_analysis.png'
plot_temporal_analysis(timing_results, save_path=temporal_plot_path)

In [None]:
# Get a sample play
sample_play_info = test_plays.iloc[4]

# Get full tracking data
sample_play_data = raw_tracking.filter(
    (pl.col('game_id') == sample_play_info['game_id']) &
    (pl.col('play_id') == sample_play_info['play_id'])
).to_pandas()

# Add route type
play_supp = supp_df.filter(
    (pl.col('game_id') == sample_play_info['game_id']) &
    (pl.col('play_id') == sample_play_info['play_id'])
).to_pandas()

if len(play_supp) > 0:
    sample_play_data['route_type'] = play_supp['route_type'].iloc[0]
    sample_play_data['completion'] = play_supp['completion'].iloc[0]
else:
    # Same placeholders as the temporal-analysis cell, so the columns always
    # exist even when the play has no supplementary row
    sample_play_data['route_type'] = 'UNKNOWN'
    sample_play_data['completion'] = np.nan

# Create interactive viewer
print("Creating completion probability heatmap viewer...")
heatmap_viewer = create_completion_heatmap_viewer(
    predictor=predictor,  # Your trained predictor
    play_data=sample_play_data,
    grid_size=3.0  # 3-yard grid for faster computation
)

## Part 11: Save Model

In [None]:
# Bundle the trained weights with the metadata stored alongside them
# (route vocabulary, number of routes, training history)
model_save_path = OUT_DIR / 'recap_model_new.pt'
checkpoint = {
    'model_state_dict': trained_model.state_dict(),
    'route_to_id': route_to_id,
    'num_routes': NUM_ROUTES,
    'history': history
}
torch.save(checkpoint, model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
# Final report: model size, data split sizes, and headline metrics
rule = "=" * 60
print("\n" + rule)
print("COMPLETION PROBABILITY MODEL V3 - SUMMARY")
print(rule)

print("\nModel: RouteAwareCompletionModel")
print(f"Parameters: {num_params:,}")
print(f"Device: {DEVICE}")

print("\nTraining:")
print(f"  - Train plays: {len(train_dataset)}")
print(f"  - Val plays: {len(val_dataset)}")
print(f"  - Test plays: {len(test_dataset)}")
print(f"  - Epochs trained: {len(history['train_loss'])}")

# Each validation figure is the best value over all epochs, taken per metric
print("\nBest Validation Performance:")
print(f"  - Loss: {min(history['val_loss']):.4f}")
print(f"  - Accuracy: {max(history['val_acc']):.4f}")
print(f"  - AUC-ROC: {max(history['val_auc']):.4f}")

# Test accuracy thresholds the predicted probabilities at 0.5
print("\nTest Set Performance:")
print(f"  - Accuracy: {accuracy_score(test_labels, test_preds > 0.5):.4f}")
print(f"  - AUC-ROC: {roc_auc_score(test_labels, test_preds):.4f}")

print("\nKey Enhancements:")
print(f"  - Route type embeddings: {NUM_ROUTES} routes")
print("  - Receiver momentum features")
print("  - Counterfactual prediction capability")