In [None]:
import sys
import logging
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import nflreadpy as nfl
import pymc as pm
import pymc_bart as pmb
from sklearn.model_selection import StratifiedShuffleSplit

sys.path.append('../py')
from preprocess import preprocess
from nflplotlib import nflplot as nfp

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

RANDOM_SEED = 2
np.random.seed(RANDOM_SEED)
N_WEEKS = 18

In [3]:
sup_data = pd.read_csv('../data/supplementary_data.csv')
tracking_input, tracking_output = pd.DataFrame(), pd.DataFrame()
for week in tqdm(range(1, N_WEEKS+1), desc="Loading weekly data"):
    tracking_input = pd.concat([tracking_input, pd.read_csv(f'../data/train/input_2023_w{week:02d}.csv')], axis=0)
    tracking_output = pd.concat([tracking_output, pd.read_csv(f'../data/train/output_2023_w{week:02d}.csv')], axis=0)
LOG.info(f'Tracking input shape: {tracking_input.shape}, output shape: {tracking_output.shape}')

Loading weekly data: 100%|██████████| 18/18 [00:08<00:00,  2.16it/s]
2025-10-30 12:54:03,012 - INFO - Tracking input shape: (4880579, 23), output shape: (562936, 6)


In [4]:
games, plays, players, tracking = preprocess.process_data(tracking_input, tracking_output, sup_data)
team_desc = preprocess.fetch_team_desc()

2025-10-30 12:54:09,979 - INFO - Joined input and output tracking data: 14108 unique plays, 1384 unique nfl_ids
2025-10-30 12:54:09,980 - INFO - Standardizing direction of play and players to be left to right
2025-10-30 12:54:11,164 - INFO - Approximating missing speed, acceleration and direction values
2025-10-30 12:54:13,564 - INFO - Correlation results for imputations: s_approx: speed R²=0.9966 | a_approx: accel R²=0.0831 | dir_approx: dir R²=0.0587
2025-10-30 12:54:14,114 - INFO - Joining supplemental data to plays DataFrame
2025-10-30 12:54:14,150 - INFO - Loading NFL PBP data for season 2023
2025-10-30 12:54:14,151 - INFO - Loading pbp from local parquet file
2025-10-30 12:54:17,351 - INFO - Mapping player IDs to nfl_id using seasonal rosters
2025-10-30 12:54:17,351 - INFO - Rosters for season 2023 already cached, loading from parquet
2025-10-30 12:54:30,373 - INFO - Defaulting passer to QB for play without a passer: 2023092406_3048
2025-10-30 12:54:31,747 - INFO - Defaulting pas

In [5]:
tracking.query('position.isin(["FS","SS","S"]) and pass_thrown').gpid.unique()[50:100]

array(['2023091000_842', '2023091000_927', '2023091001_1042',
       '2023091001_1232', '2023091001_1368', '2023091001_1438',
       '2023091001_1470', '2023091001_1574', '2023091001_1980',
       '2023091001_2033', '2023091001_2151', '2023091001_2351',
       '2023091001_2432', '2023091001_2523', '2023091001_2759',
       '2023091001_3086', '2023091001_3514', '2023091001_3879',
       '2023091001_3902', '2023091001_4018', '2023091001_407',
       '2023091001_4191', '2023091001_4216', '2023091001_4239',
       '2023091001_4322', '2023091001_4564', '2023091001_4589',
       '2023091001_4639', '2023091001_559', '2023091001_893',
       '2023091002_1209', '2023091002_1412', '2023091002_1464',
       '2023091002_1540', '2023091002_1974', '2023091002_2872',
       '2023091002_2942', '2023091002_3253', '2023091002_3314',
       '2023091002_3337', '2023091002_3409', '2023091002_3571',
       '2023091002_3812', '2023091002_3835', '2023091002_4036',
       '2023091002_444', '2023091002_767', '2

In [18]:
# gpid="2023091100_993"
# gpid="2023091003_410"
# gpid="2023091003_1706"
gpid="2023121711_2228"
nfp.animate_play(
    tracking.query('gpid==@gpid'),
    plays.query('gpid==@gpid'),
    games.query(f'game_id=={gpid.split("_")[0]}'),
    team_desc,
    # save_path='animation.gif',
    plot_positions=True,
    highlight_postpass_players=True,
    show_postpass_paths=True
)

2025-10-31 11:09:27,249 - INFO - Animation.save using <class 'matplotlib.animation.HTMLWriter'>


# i. Feature Engineering

In [7]:
def get_all_nearest_defenders(tracking, n_values=(1, 2, 3)):
    """
    Vectorized: compute the n-th nearest defenders to the receiver per frame for all plays.
    Returns one DataFrame with ['gpid','frame_id', defender_dist_n, defender_x_n, defender_y_n, ...].
    """
    defenders = tracking.query('player_side == "Defense"')[['gpid','frame_id','x','y']]
    receivers = tracking.query('is_receiver')[['gpid','frame_id','x','y']].rename(
        columns={'x':'receiver_x','y':'receiver_y'}
    )

    # Merge defenders with receiver positions per frame
    merged = defenders.merge(receivers, on=['gpid','frame_id'], how='inner')
    merged['dist'] = np.sqrt(
        (merged['x'] - merged['receiver_x'])**2 + 
        (merged['y'] - merged['receiver_y'])**2
    )

    # Rank defenders by distance *within each play and frame*
    merged['rank'] = merged.groupby(['gpid','frame_id'])['dist'].rank(method='first')

    # Build result incrementally for all n
    results = []
    for n in n_values:
        nth = (
            merged[merged['rank'] == n][['gpid','frame_id','x','y','dist']]
            .rename(columns={
                'x': f'defender_x_{n}',
                'y': f'defender_y_{n}',
                'dist': f'defender_dist_{n}'
            })
        )
        results.append(nth)

    # Combine all nth results into one wide frame
    out = results[0]
    for r in results[1:]:
        out = out.merge(r, on=['gpid','frame_id'], how='outer')

    return out


def get_ball_flight_pct(df):
    """
    Compute percent of ball flight for all plays in one pass (vectorized).
    Returns a copy with new 'ball_flight_pct' column.
    """
    df = df.sort_values(['gpid','frame_id']).copy()

    # Find throw frames and end frames per play
    throw_frame = (
        df.loc[df['pass_thrown'], ['gpid','frame_id']]
        .groupby('gpid')['frame_id']
        .min()
        .rename('throw_frame')
    )
    end_frame = df.groupby('gpid')['frame_id'].max().rename('end_frame')

    df = df.merge(throw_frame, on='gpid', how='left').merge(end_frame, on='gpid', how='left')

    # Compute flight pct
    df['ball_flight_pct'] = 0.0
    in_flight = df['frame_id'] >= df['throw_frame']
    df.loc[in_flight, 'ball_flight_pct'] = (
        (df.loc[in_flight, 'frame_id'] - df.loc[in_flight, 'throw_frame'])
        / (df.loc[in_flight, 'end_frame'] - df.loc[in_flight, 'throw_frame']).clip(lower=1)
    ) * 100

    return df.drop(columns=['throw_frame','end_frame'])

LOG.info("Preparing base data")

# Base ball + receiver merge
data = (
    tracking
    .query('position == "Ball"')[['gpid','frame_id','pass_thrown','x','y']]
    .drop_duplicates(subset=['gpid','frame_id'])
    .rename(columns={'x':'ball_x','y':'ball_y'})
    .merge(
        tracking.query('is_receiver')[['gpid','frame_id','x','y']]
        .rename(columns={'x':'receiver_x','y':'receiver_y'}),
        on=['gpid','frame_id'],
        how='left'
    )
    .assign(
        dist_ball_to_receiver=lambda df: np.sqrt(
            (df.ball_x - df.receiver_x)**2 + (df.ball_y - df.receiver_y)**2
        )
    )
)

LOG.info("Finding nearest defenders (1–3)")
nearest_defenders = get_all_nearest_defenders(tracking, n_values=(1,2,3))
data = data.merge(nearest_defenders, on=['gpid','frame_id'], how='left')

LOG.info("Calculating ball flight percentage")
data = get_ball_flight_pct(data)

LOG.info("Joining in play data")
play_cols = ['gpid','expected_points_added','yards_after_catch','pass_result','play_nullified_by_penalty',
             'ball_land_x','absolute_yardline_number','play_description']
data = (
    data
    .merge(plays[play_cols], on='gpid', how='left')
    .assign(
        interception=lambda df: (df.pass_result == 'IN').astype(int),
        completion=lambda df: (df.pass_result == 'C').astype(int)
    )
    .drop(columns=['pass_result'])
)

LOG.info("Joining start yardline of next play")
pbp = (
    nfl.load_pbp(seasons=[2023]).to_pandas()
    .query('play_type != "no_play"')
    [['old_game_id','play_id','yardline_100','touchdown','touchback']]
    .assign(
        gpid=lambda df: df['old_game_id'].str.replace('_','') + '_' + df['play_id'].astype(int).astype(str),
        yardline_100=lambda df: 110 - df['yardline_100']
    )
    .sort_values(['old_game_id','play_id'])
)
pbp['pid'] = pbp.groupby('old_game_id').cumcount() + 1
pbp_next = (
    pbp[['old_game_id', 'pid', 'play_id', 'yardline_100']]
    .rename(
        columns={'yardline_100':'next_play_start_yardline_100'}
    )
    .assign(
        pid=lambda df: df['pid'] - 1
    )
)
pbp = pbp.merge(pbp_next, on=['old_game_id','pid'], how='left').drop(columns=['pid'])
data = data.merge(pbp[['gpid','next_play_start_yardline_100','touchdown','touchback']], on='gpid', how='left')
data['interception_return_yards'] = np.where(
    data['interception'] == 1,
    np.where(
        data.touchdown,
        data.ball_land_x,
        np.where(
            data.touchback, 
            0,
            data['ball_land_x'] - (data['next_play_start_yardline_100'] + 10)
        )
    ),
    np.nan
)

LOG.info("Data preparation complete")

final_cols = [
    'gpid', 'frame_id', 'pass_thrown', 'ball_flight_pct',
    'dist_ball_to_receiver', 'defender_dist_1', 'defender_dist_2', 'defender_dist_3',
    'interception', 'completion', 'expected_points_added', 'yards_after_catch',
    'play_nullified_by_penalty', 'interception_return_yards'
]
data = data[final_cols]

2025-10-30 12:55:18,376 - INFO - Preparing base data
2025-10-30 12:55:18,760 - INFO - Finding nearest defenders (1–3)
2025-10-30 12:55:21,089 - INFO - Calculating ball flight percentage
2025-10-30 12:55:21,442 - INFO - Joining in play data
2025-10-30 12:55:21,707 - INFO - Joining start yardline of next play
2025-10-30 12:55:23,202 - INFO - Data preparation complete


In [8]:
data.head()

Unnamed: 0,gpid,frame_id,pass_thrown,ball_flight_pct,dist_ball_to_receiver,defender_dist_1,defender_dist_2,defender_dist_3,interception,completion,expected_points_added,yards_after_catch,play_nullified_by_penalty,interception_return_yards
0,2023090700_1001,1,False,0.0,6.931876,2.913589,4.664118,5.781228,0,1,1.195112,0.0,N,
1,2023090700_1001,2,False,0.0,6.983552,2.894305,4.621796,5.76184,0,1,1.195112,0.0,N,
2,2023090700_1001,3,False,0.0,7.057833,2.853086,4.550275,5.742752,0,1,1.195112,0.0,N,
3,2023090700_1001,4,False,0.0,7.122612,2.842006,4.506995,5.727347,0,1,1.195112,0.0,N,
4,2023090700_1001,5,False,0.0,7.201389,2.853384,4.461255,5.743953,0,1,1.195112,0.0,N,


In [9]:
data.columns

Index(['gpid', 'frame_id', 'pass_thrown', 'ball_flight_pct',
       'dist_ball_to_receiver', 'defender_dist_1', 'defender_dist_2',
       'defender_dist_3', 'interception', 'completion',
       'expected_points_added', 'yards_after_catch',
       'play_nullified_by_penalty', 'interception_return_yards'],
      dtype='object')

In [10]:
data.play_nullified_by_penalty.value_counts()

play_nullified_by_penalty
N    556363
Name: count, dtype: int64

# ii. P(interception | state)

In [11]:
data.drop_duplicates('gpid').interception.value_counts()

interception
0    13754
1      337
Name: count, dtype: int64

In [15]:
tracking.head()

Unnamed: 0,gpid,game_id,play_id,frame_id,nfl_id,pass_thrown,player_to_predict,player_side,player_role,position,x,y,s,a,dir,o,is_passer,is_receiver,is_interceptor
0,2023090700_1001,2023090700,1001,1,0,False,False,Ball,Ball,Ball,104.71,23.6,,,,,False,False,False
1,2023090700_1001,2023090700,1001,2,0,False,False,Ball,Ball,Ball,104.68,23.6,,,,,False,False,False
2,2023090700_1001,2023090700,1001,3,0,False,False,Ball,Ball,Ball,104.65,23.61,,,,,False,False,False
3,2023090700_1001,2023090700,1001,4,0,False,False,Ball,Ball,Ball,104.59,23.63,,,,,False,False,False
4,2023090700_1001,2023090700,1001,5,0,False,False,Ball,Ball,Ball,104.5,23.65,,,,,False,False,False


In [16]:
tracking.query('pass_thrown').drop_duplicates(['gpid','nfl_id']).gpid.value_counts()

gpid
2023091705_2480    10
2023110502_3580     9
2023121001_3613     9
2023101506_3156     9
2023101501_644      9
2024010706_783      9
2023120304_1700     9
2023121600_3220     9
2023121705_3372     9
2023121006_2207     9
2023121100_4506     9
2023123104_683      9
2023111600_1794     9
2023120302_877      9
2023122100_1450     9
2023122402_1054     9
2023121001_3568     9
2023121711_2228     9
2023121000_3022     9
2023100811_2626     9
2023100811_2698     9
2023121000_4426     9
2023091100_3167     9
2024010704_3564     9
2023101504_2658     9
2023123103_3929     9
2023111908_1187     8
2023121602_2774     8
2024010601_2866     8
2023091400_3333     8
2023121711_1314     8
2023102202_3858     8
2023100801_2104     8
2023120301_2032     8
2023092401_1938     8
2023100801_1591     8
2023111910_3061     8
2023092501_1235     8
2023100801_1148     8
2023111910_1031     8
2023120303_2994     8
2024010601_1418     8
2023123104_1769     8
2023092500_288      8
2023102202_4175     8
20231

In [12]:
play_col = 'gpid'

features = [
    'pass_thrown', 'ball_flight_pct',
    'dist_ball_to_receiver', 'defender_dist_1', 'defender_dist_2', 'defender_dist_3'
]
target = 'interception'

# ensure types numeric
df = data.copy()
df[features] = df[features].apply(pd.to_numeric, errors='coerce')
df[target] = df[target].astype(int)

# ---------- construct play-level labels for stratified split ----------
# For each play, we want to know if that play had any interception (so we stratify on that)
play_label = df.groupby(play_col)[target].max().rename('play_has_int')  # 1 if any frame in that play had int
play_index = play_label.index.to_numpy()
play_y = play_label.values  # 0/1 per play

# ---------- Stratified split at play level ----------
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_SEED)
train_play_ix, test_play_ix = next(sss.split(play_index, play_y))
train_plays = play_index[train_play_ix]
test_plays = play_index[test_play_ix]

train_df = df[df[play_col].isin(train_plays)].reset_index(drop=True)
test_df  = df[df[play_col].isin(test_plays)].reset_index(drop=True)

print(f"Plays: total {len(play_index)}, train plays {len(train_plays)}, test plays {len(test_plays)}")
print("Frame counts: train", len(train_df), "test", len(test_df))
print("Train interception frames:", train_df[target].sum(), "of", len(train_df))
print("Test interception frames:", test_df[target].sum(), "of", len(test_df))

# ---------- Standardize numeric features (helpful for BART stability) ----------
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(train_df[features])
X_test  = scaler.transform(test_df[features])
y_train = train_df[target].values
y_test  = test_df[target].values

# factorize play ids for hierarchical intercept (optional but recommended)
play_idx_train, play_levels = pd.factorize(train_df[play_col])
n_plays = len(play_levels)

# baseline logit init using global rate (helps with class imbalance)
base_rate = y_train.mean()
init_logit = np.log(base_rate / (1 - base_rate + 1e-12))

# ---------- PyMC BART model ----------
with pm.Model() as model:
    # hierarchical (varying) play intercept
    sigma_play = pm.Exponential("sigma_play", 1.0)
    play_offset = pm.Normal("play_offset", mu=0.0, sigma=1.0, shape=n_plays)
    play_effect = pm.Deterministic("play_effect", play_offset * sigma_play)

    # BART latent function: pass X and Y (Y required)
    # NOTE: pm.BART expects a continuous Y; we pass y_train.astype(float).
    # BART will learn a latent f that we treat as the logit contribution (plus play_effect).
    bart_f = pmb.BART("bart_f", X=X_train, Y=y_train.astype(float), m=50)  # m = number of trees

    # linear predictor = intercept + play_effect[play_idx] + bart_f
    intercept = pm.Normal("intercept", mu=init_logit, sigma=2.0)
    logit_p = intercept + play_effect[play_idx_train] + bart_f

    # probability via logistic
    p = pm.Deterministic("p", pm.math.sigmoid(logit_p))

    # observed Bernoulli
    y_obs = pm.Bernoulli("y_obs", p=p, observed=y_train)

    # sample (use fewer draws for testing; increase for production)
    idata = pm.sample(1000, tune=1000, chains=2, target_accept=0.9, cores=2)

2025-10-30 12:55:42,780 - INFO - Found 'auto' as default backend, checking available backends
2025-10-30 12:55:42,780 - INFO - Matplotlib is available, defining as default backend
2025-10-30 12:55:42,784 - INFO - arviz_base available, exposing its functions as part of arviz.preview
2025-10-30 12:55:49,930 - INFO - arviz_stats available, exposing its functions as part of arviz.preview
2025-10-30 12:55:49,931 - INFO - arviz_plots not installed


Plays: total 14091, train plays 11272, test plays 2819
Frame counts: train 445304 test 111059
Train interception frames: 12106 of 445304
Test interception frames: 3007 of 111059


2025-10-30 12:55:52,011 - INFO - Compiling new CVM
2025-10-30 12:55:54,224 - INFO - New version 0.31
2025-10-30 12:56:11,170 - INFO - Multiprocess sampling (2 chains in 2 jobs)
2025-10-30 12:56:11,171 - INFO - CompoundStep
2025-10-30 12:56:11,171 - INFO - >NUTS: [sigma_play, play_offset, intercept]
2025-10-30 12:56:11,171 - INFO - >PGBART: [bart_f]


Output()

ValueError: Not enough samples to build a trace.

# iii. P(completion | state) fit using only non-interception data

# iv. Return_yards | interception, state

# v. YAC | completion, state

# vi. Expected Points Model