# Step: Downloading the dataset

To begin, you must connect to runtime and have an api token from kaggle saved in your immediate directory of Google Drive in order to access the dataset directly through CLI

In [1]:
#@title Run this cell to mount Google Drive and get `kaggle.json` from personal directory

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#@title Run this cell to download the competition dataset to notebook directory

! pip install kaggle
! mkdir ~/.kaggle
! cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle competitions download nfl-big-data-bowl-2025

Downloading nfl-big-data-bowl-2025.zip to /content
 99% 1.13G/1.14G [00:11<00:00, 93.7MB/s]
100% 1.14G/1.14G [00:11<00:00, 102MB/s] 


### Install SportsTransformers_utils provided by SumerSports

In [3]:
file_location = '/content/nfl-big-data-bowl-2025.zip'

In [4]:
!unzip {file_location}

Archive:  /content/nfl-big-data-bowl-2025.zip
  inflating: games.csv               
  inflating: player_play.csv         
  inflating: players.csv             
  inflating: plays.csv               
  inflating: tracking_week_1.csv     
  inflating: tracking_week_2.csv     
  inflating: tracking_week_3.csv     
  inflating: tracking_week_4.csv     
  inflating: tracking_week_5.csv     
  inflating: tracking_week_6.csv     
  inflating: tracking_week_7.csv     
  inflating: tracking_week_8.csv     
  inflating: tracking_week_9.csv     


# Stage: Prep Data

## Load in the data files from the data set

In [13]:
from argparse import ArgumentParser
from pathlib import Path

import polars as pl
import os

import random

# TEMPORARY CHANGE
INPUT_DATA_DIR = Path("./")
OUT_DIR = Path("/content/drive/My Drive/bdb-2025/split_prepped_data")

OUT_DIR.mkdir(parents=True, exist_ok=True) # Create directory if it doesn't exist

In [14]:
def get_players_df() -> pl.DataFrame:
    """
    Load player-level data and preprocesses features.

    Returns:
        pl.DataFrame: Preprocessed player data with additional features.
    """
    return (
        pl.read_csv(INPUT_DATA_DIR / "players.csv", null_values=["NA", "nan", "N/A", "NaN", ""])
        .with_columns(
            height_inches=(
                pl.col("height").str.split("-").map_elements(lambda s: int(s[0]) * 12 + int(s[1]), return_dtype=int)
            )
        )
        .with_columns(
            weight_Z=(pl.col("weight") - pl.col("weight").mean()) / pl.col("weight").std(),
            height_Z=(pl.col("height_inches") - pl.col("height_inches").mean()) / pl.col("height_inches").std(),
        )
    )

def get_plays_df() -> pl.DataFrame:
    """
    Load play-level data and preprocesses features.

    Returns:
        pl.DataFrame: Preprocessed play data with additional features.
    """
    return pl.read_csv(INPUT_DATA_DIR / "plays.csv", null_values=["NA", "nan", "N/A", "NaN", ""]).with_columns(
        distanceToGoal=(
            pl.when(pl.col("possessionTeam") == pl.col("yardlineSide"))
            .then(100 - pl.col("yardlineNumber"))
            .otherwise(pl.col("yardlineNumber"))
        )
    )

def get_tracking_df() -> pl.DataFrame:
    """
    Load tracking data and preprocesses features. Notably, exclude rows representing the football's movement.

    Returns:
        pl.DataFrame: Preprocessed tracking data with additional features.
    """
    # don't include football rows for this project.
    # NOTE: Only processing week 1 for the sake of time.  Change "1" to "*" to process all weeks
    return pl.read_csv(INPUT_DATA_DIR / "tracking_week_1.csv", null_values=["NA", "nan", "N/A", "NaN", ""]).filter(
        pl.col("displayName") != "football"
    )

In [15]:
# Load in raw data
print("Load players")
players_df = get_players_df()
print("Load plays")
plays_df = get_plays_df()
print("Load tracking")
tracking_df = get_tracking_df()
print("tracking_df rows:", len(tracking_df))

Load players
Load plays
Load tracking
tracking_df rows: 6795800


In [16]:
players_df.head()

nflId,height,weight,birthDate,collegeName,position,displayName,height_inches,weight_Z,height_Z
i64,str,i64,str,str,str,str,i64,f64,f64
25511,"""6-4""",225,"""1977-08-03""","""Michigan""","""QB""","""Tom Brady""",76,-0.439612,0.658795
29550,"""6-4""",328,"""1982-01-22""","""Arkansas""","""T""","""Jason Peters""",76,1.740005,0.658795
29851,"""6-2""",225,"""1983-12-02""","""California""","""QB""","""Aaron Rodgers""",74,-0.439612,-0.096594
30842,"""6-6""",267,"""1984-05-19""","""UCLA""","""TE""","""Marcedes Lewis""",78,0.449164,1.414183
33084,"""6-4""",217,"""1985-05-17""","""Boston College""","""QB""","""Matt Ryan""",76,-0.608903,0.658795


In [17]:
plays_df.head()

gameId,playId,playDescription,quarter,down,yardsToGo,possessionTeam,defensiveTeam,yardlineSide,yardlineNumber,gameClock,preSnapHomeScore,preSnapVisitorScore,playNullifiedByPenalty,absoluteYardlineNumber,preSnapHomeTeamWinProbability,preSnapVisitorTeamWinProbability,expectedPoints,offenseFormation,receiverAlignment,playClockAtSnap,passResult,passLength,targetX,targetY,playAction,dropbackType,dropbackDistance,passLocationType,timeToThrow,timeInTackleBox,timeToSack,passTippedAtLine,unblockedPressure,qbSpike,qbKneel,qbSneak,rushLocationType,penaltyYards,prePenaltyYardsGained,yardsGained,homeTeamWinProbabilityAdded,visitorTeamWinProbilityAdded,expectedPointsAdded,isDropback,pff_runConceptPrimary,pff_runConceptSecondary,pff_runPassOption,pff_passCoverage,pff_manZone,distanceToGoal
i64,i64,str,i64,i64,i64,str,str,str,i64,str,i64,i64,str,i64,f64,f64,f64,str,str,i64,str,i64,f64,f64,bool,str,f64,str,f64,f64,f64,bool,bool,bool,i64,bool,str,i64,i64,i64,f64,f64,f64,bool,str,str,i64,str,str,i64
2022102302,2655,"""(1:54) (Shotgun) J.Burrow pass…",3,1,10,"""CIN""","""ATL""","""CIN""",21,"""01:54""",35,17,"""N""",31,0.982017,0.017983,0.719313,"""EMPTY""","""3x2""",10,"""C""",6.0,36.69,16.51,False,"""TRADITIONAL""",2.4,"""INSIDE_BOX""",2.99,2.99,,False,False,False,0,,,,9,9,0.004634,-0.004634,0.702717,True,,,0,"""Cover-3""","""Zone""",79
2022091809,3698,"""(2:13) (Shotgun) J.Burrow pass…",4,1,10,"""CIN""","""DAL""","""CIN""",8,"""02:13""",17,17,"""N""",18,0.424356,0.575644,0.607746,"""EMPTY""","""3x2""",9,"""C""",4.0,20.83,20.49,False,"""TRADITIONAL""",1.14,"""INSIDE_BOX""",1.836,1.836,,False,False,False,0,,,,4,4,0.002847,-0.002847,-0.240509,True,,,0,"""Quarters""","""Zone""",92
2022103004,3146,"""(2:00) (Shotgun) D.Mills pass …",4,3,12,"""HOU""","""TEN""","""HOU""",20,"""02:00""",3,17,"""N""",30,0.006291,0.993709,-0.291485,"""SHOTGUN""","""2x2""",12,"""C""",-4.0,26.02,17.56,False,"""TRADITIONAL""",3.2,"""INSIDE_BOX""",2.236,2.236,,False,False,False,0,,,,6,6,0.000205,-0.000205,-0.21848,True,,,0,"""Quarters""","""Zone""",80
2022110610,348,"""(9:28) (Shotgun) P.Mahomes pas…",1,2,10,"""KC""","""TEN""","""TEN""",23,"""09:28""",0,0,"""N""",33,0.884223,0.115777,4.249382,"""SHOTGUN""","""2x2""",11,"""C""",-6.0,38.95,14.19,False,"""TRADITIONAL""",3.02,"""INSIDE_BOX""",2.202,2.202,,False,False,False,0,,,,4,4,-0.001308,0.001308,-0.427749,True,,,0,"""Quarters""","""Zone""",23
2022102700,2799,"""(2:16) (Shotgun) L.Jackson up …",3,2,8,"""BAL""","""TB""","""TB""",27,"""02:16""",10,10,"""N""",37,0.410371,0.589629,3.928413,"""PISTOL""","""3x1""",8,,,,,True,"""DESIGNED_RUN""",2.03,,,,,,,,0,False,"""INSIDE_LEFT""",,-1,-1,0.027141,-0.027141,-0.638912,False,"""MAN""","""READ OPTION""",0,"""Cover-1""","""Man""",27


In [18]:
tracking_df.head()

gameId,playId,nflId,displayName,frameId,frameType,time,jerseyNumber,club,playDirection,x,y,s,a,dis,o,dir,event
i64,i64,i64,str,i64,str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,str
2022091200,64,35459,"""Kareem Jackson""",1,"""BEFORE_SNAP""","""2022-09-13 00:16:03.5""",22,"""DEN""","""right""",51.06,28.55,0.72,0.37,0.07,246.17,68.34,"""huddle_break_offense"""
2022091200,64,35459,"""Kareem Jackson""",2,"""BEFORE_SNAP""","""2022-09-13 00:16:03.6""",22,"""DEN""","""right""",51.13,28.57,0.71,0.36,0.07,245.41,71.21,
2022091200,64,35459,"""Kareem Jackson""",3,"""BEFORE_SNAP""","""2022-09-13 00:16:03.7""",22,"""DEN""","""right""",51.2,28.59,0.69,0.23,0.07,244.45,69.9,
2022091200,64,35459,"""Kareem Jackson""",4,"""BEFORE_SNAP""","""2022-09-13 00:16:03.8""",22,"""DEN""","""right""",51.26,28.62,0.67,0.22,0.07,244.45,67.98,
2022091200,64,35459,"""Kareem Jackson""",5,"""BEFORE_SNAP""","""2022-09-13 00:16:03.9""",22,"""DEN""","""right""",51.32,28.65,0.65,0.34,0.07,245.74,62.83,


In [19]:
def add_features_to_tracking_df(
    tracking_df: pl.DataFrame,
    players_df: pl.DataFrame,
    plays_df: pl.DataFrame,
) -> pl.DataFrame:
    """
    Consolidates play and player level data into the tracking data.

    Args:
        tracking_df (pl.DataFrame): Tracking data
        players_df (pl.DataFrame): Player data
        plays_df (pl.DataFrame): Play data

    Returns:
        pl.DataFrame: Tracking data with additional features.
    """
    # add `is_ball_carrier`, `team_indicator`, and other features to tracking data
    og_len = len(tracking_df)
    tracking_df = (
        tracking_df.join(
            plays_df.select(
                "gameId",
                "playId",
                "defensiveTeam"
            ),
            on=["gameId", "playId"],
            how="inner",
        )
        .join(
            players_df.select(["nflId", "displayName", "position"]).unique(), # select position column
            on=["nflId", "displayName"],
            how="left",
        )
        #.join(
        #    players_df.select(["nflId", "weight_Z", "height_Z"]).unique(),
        #    on="nflId",
        #    how="inner",
        #)
        .with_columns(
            isDefense=pl.when(pl.col("club") == pl.col("defensiveTeam"))
            .then(pl.lit(1))
            .otherwise(pl.lit(-1))
            .alias("isDefense"),
        )
        .drop(["defensiveTeam"])
        .drop(["event"])
    )

    assert len(tracking_df) == og_len, "Lost rows when joining tracking data with play/player data"

    return tracking_df

print("Add features to tracking")
tracking_df = add_features_to_tracking_df(tracking_df, players_df, plays_df)
del players_df

Add features to tracking


In [20]:
tracking_df.head()

gameId,playId,nflId,displayName,frameId,frameType,time,jerseyNumber,club,playDirection,x,y,s,a,dis,o,dir,position,isDefense
i64,i64,i64,str,i64,str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,str,i32
2022091200,64,35459,"""Kareem Jackson""",1,"""BEFORE_SNAP""","""2022-09-13 00:16:03.5""",22,"""DEN""","""right""",51.06,28.55,0.72,0.37,0.07,246.17,68.34,"""SS""",1
2022091200,64,35459,"""Kareem Jackson""",2,"""BEFORE_SNAP""","""2022-09-13 00:16:03.6""",22,"""DEN""","""right""",51.13,28.57,0.71,0.36,0.07,245.41,71.21,"""SS""",1
2022091200,64,35459,"""Kareem Jackson""",3,"""BEFORE_SNAP""","""2022-09-13 00:16:03.7""",22,"""DEN""","""right""",51.2,28.59,0.69,0.23,0.07,244.45,69.9,"""SS""",1
2022091200,64,35459,"""Kareem Jackson""",4,"""BEFORE_SNAP""","""2022-09-13 00:16:03.8""",22,"""DEN""","""right""",51.26,28.62,0.67,0.22,0.07,244.45,67.98,"""SS""",1
2022091200,64,35459,"""Kareem Jackson""",5,"""BEFORE_SNAP""","""2022-09-13 00:16:03.9""",22,"""DEN""","""right""",51.32,28.65,0.65,0.34,0.07,245.74,62.83,"""SS""",1


In [21]:
def convert_tracking_to_cartesian(tracking_df: pl.DataFrame) -> pl.DataFrame:
    """
    Convert polar coordinates to Unit-circle Cartesian format.

    Args:
        tracking_df (pl.DataFrame): Tracking data

    Returns:
        pl.DataFrame: Tracking data with Cartesian coordinates.
    """
    return (
        tracking_df.with_columns(
            dir=((pl.col("dir") - 90) * -1) % 360,
            o=((pl.col("o") - 90) * -1) % 360,
        )
        # convert polar vectors to cartesian ((s, dir) -> (vx, vy), (o) -> (ox, oy))
        .with_columns(
            vx=pl.col("s") * pl.col("dir").radians().cos(),
            vy=pl.col("s") * pl.col("dir").radians().sin(),
            ox=pl.col("o").radians().cos(),
            oy=pl.col("o").radians().sin(),
        )
    )

print("Convert tracking to cartesian")
tracking_df = convert_tracking_to_cartesian(tracking_df)

Convert tracking to cartesian


In [22]:
tracking_df.head()

gameId,playId,nflId,displayName,frameId,frameType,time,jerseyNumber,club,playDirection,x,y,s,a,dis,o,dir,position,isDefense,vx,vy,ox,oy
i64,i64,i64,str,i64,str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,str,i32,f64,f64,f64,f64
2022091200,64,35459,"""Kareem Jackson""",1,"""BEFORE_SNAP""","""2022-09-13 00:16:03.5""",22,"""DEN""","""right""",51.06,28.55,0.72,0.37,0.07,203.83,21.66,"""SS""",1,0.669161,0.265751,-0.914748,-0.404024
2022091200,64,35459,"""Kareem Jackson""",2,"""BEFORE_SNAP""","""2022-09-13 00:16:03.6""",22,"""DEN""","""right""",51.13,28.57,0.71,0.36,0.07,204.59,18.79,"""SS""",1,0.672161,0.228691,-0.909309,-0.416122
2022091200,64,35459,"""Kareem Jackson""",3,"""BEFORE_SNAP""","""2022-09-13 00:16:03.7""",22,"""DEN""","""right""",51.2,28.59,0.69,0.23,0.07,205.55,20.1,"""SS""",1,0.647975,0.237125,-0.902209,-0.431299
2022091200,64,35459,"""Kareem Jackson""",4,"""BEFORE_SNAP""","""2022-09-13 00:16:03.8""",22,"""DEN""","""right""",51.26,28.62,0.67,0.22,0.07,205.55,22.02,"""SS""",1,0.621126,0.251203,-0.902209,-0.431299
2022091200,64,35459,"""Kareem Jackson""",5,"""BEFORE_SNAP""","""2022-09-13 00:16:03.9""",22,"""DEN""","""right""",51.32,28.65,0.65,0.34,0.07,204.26,27.17,"""SS""",1,0.578276,0.296811,-0.91169,-0.410878


NOTE: there two current `tracking` dataframes. One is the consolidated one, the other is augmented the direction of all players to assume that they are all moving in the same direction/

## Explore model target --> Position of Masked player

* Target dim.: (1, 2)
  - first dim.: `displayName` of masked player
  - second dim.: x-y coordinate of masked player on the field at the frame

## Goal: Randomly select a player and predict their x-y coordinates by removing them from the input sequence and having the model estimate their motion based on the locations of the 21 other players.



In [23]:
# Randomly select a player from players_in_game_play to filter out from each frame in the game and play

import random

def get_masked_players(tracking_df):
    """
    Randomly selects a player from a game and play, and filters their data from each frame.

    Args:
        tracking_df: The tracking DataFrame.

    Returns:
        tuple: A tuple containing the filtered tracking DataFrame and the masked players DataFrame.
    """
    masked_players_df = pl.DataFrame()

    for game_id, play_id in tracking_df.select(["gameId", "playId"]).unique().rows():

        # The defensive players in a given game + play
        filtered_df = tracking_df.filter(
          (pl.col("gameId") == game_id)
          & (pl.col("playId") == play_id)
          & (pl.col("frameId") == 1)
          & (pl.col("isDefense") == 1)
        )
        assert len(filtered_df) == 11, "No players found for gameId: {}, playId: {}".format(game_id, play_id)

        # Retrieve masked player
        selected_player = random.choice(filtered_df["displayName"].to_list())

        masked_player_df = tracking_df.filter(
            (pl.col("gameId") == game_id)
            & (pl.col("playId") == play_id)
            & (pl.col("displayName") == selected_player)
        )
        masked_players_df = pl.concat([masked_players_df, masked_player_df])

        # Filter out the selected player from tracking_df
        filtered_df = tracking_df.filter(
            ~(
                (pl.col("gameId") == game_id)
                & (pl.col("playId") == play_id)
                & (pl.col("displayName") == selected_player)
            )
        )
        #assert len(tracking_df) - len(masked_player_df) == len(filtered_df), "Players other than the masked player were lost"
        print(f"Player masked for gameId {game_id} playId {play_id}")

    return filtered_df, masked_player_df

In [24]:
print("Generate target - maskedPlayers")
rel_tracking_df, maskedPlayers_df = get_masked_players(tracking_df)

Generate target - maskedPlayers
Player masked for gameId 2022091104 playId 1713
Player masked for gameId 2022091107 playId 1841
Player masked for gameId 2022091101 playId 1537
Player masked for gameId 2022091103 playId 2955
Player masked for gameId 2022091111 playId 1862
Player masked for gameId 2022091106 playId 2429
Player masked for gameId 2022091101 playId 2629
Player masked for gameId 2022091108 playId 2962
Player masked for gameId 2022091104 playId 4001
Player masked for gameId 2022091112 playId 2723
Player masked for gameId 2022091110 playId 1232
Player masked for gameId 2022091200 playId 3826
Player masked for gameId 2022091102 playId 3119
Player masked for gameId 2022091101 playId 317
Player masked for gameId 2022091110 playId 2845
Player masked for gameId 2022091106 playId 380
Player masked for gameId 2022091107 playId 1779
Player masked for gameId 2022091104 playId 865
Player masked for gameId 2022091111 playId 542
Player masked for gameId 2022091112 playId 3672
Player maske

In [25]:
maskedPlayers_df.head()

gameId,playId,nflId,displayName,frameId,frameType,time,jerseyNumber,club,playDirection,x,y,s,a,dis,o,dir,position,isDefense,vx,vy,ox,oy
i64,i64,i64,str,i64,str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,str,i32,f64,f64,f64,f64
2022091105,3345,52498,"""Jonathan Greenard""",1,"""BEFORE_SNAP""","""2022-09-11 19:35:28.5""",52,"""HOU""","""right""",86.07,16.88,5.01,0.98,0.53,69.07,61.3,"""DE""",1,2.40592,4.394502,0.357227,0.934018
2022091105,3345,52498,"""Jonathan Greenard""",2,"""BEFORE_SNAP""","""2022-09-11 19:35:28.6""",52,"""HOU""","""right""",86.31,17.33,5.01,1.13,0.51,69.07,61.2,"""DE""",1,2.413586,4.390296,0.357227,0.934018
2022091105,3345,52498,"""Jonathan Greenard""",3,"""BEFORE_SNAP""","""2022-09-11 19:35:28.7""",52,"""HOU""","""right""",86.56,17.79,4.99,1.44,0.52,70.92,61.34,"""DE""",1,2.393259,4.378631,0.326888,0.945063
2022091105,3345,52498,"""Jonathan Greenard""",4,"""BEFORE_SNAP""","""2022-09-11 19:35:28.8""",52,"""HOU""","""right""",86.81,18.24,4.96,1.67,0.52,72.3,61.68,"""DE""",1,2.353002,4.366347,0.304033,0.952661
2022091105,3345,52498,"""Jonathan Greenard""",5,"""BEFORE_SNAP""","""2022-09-11 19:35:28.9""",52,"""HOU""","""right""",87.06,18.71,5.05,1.41,0.53,71.55,61.65,"""DE""",1,2.398025,4.44432,0.316477,0.9486


In [26]:
tracking_df.head()

gameId,playId,nflId,displayName,frameId,frameType,time,jerseyNumber,club,playDirection,x,y,s,a,dis,o,dir,position,isDefense,vx,vy,ox,oy
i64,i64,i64,str,i64,str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,str,i32,f64,f64,f64,f64
2022091200,64,35459,"""Kareem Jackson""",1,"""BEFORE_SNAP""","""2022-09-13 00:16:03.5""",22,"""DEN""","""right""",51.06,28.55,0.72,0.37,0.07,203.83,21.66,"""SS""",1,0.669161,0.265751,-0.914748,-0.404024
2022091200,64,35459,"""Kareem Jackson""",2,"""BEFORE_SNAP""","""2022-09-13 00:16:03.6""",22,"""DEN""","""right""",51.13,28.57,0.71,0.36,0.07,204.59,18.79,"""SS""",1,0.672161,0.228691,-0.909309,-0.416122
2022091200,64,35459,"""Kareem Jackson""",3,"""BEFORE_SNAP""","""2022-09-13 00:16:03.7""",22,"""DEN""","""right""",51.2,28.59,0.69,0.23,0.07,205.55,20.1,"""SS""",1,0.647975,0.237125,-0.902209,-0.431299
2022091200,64,35459,"""Kareem Jackson""",4,"""BEFORE_SNAP""","""2022-09-13 00:16:03.8""",22,"""DEN""","""right""",51.26,28.62,0.67,0.22,0.07,205.55,22.02,"""SS""",1,0.621126,0.251203,-0.902209,-0.431299
2022091200,64,35459,"""Kareem Jackson""",5,"""BEFORE_SNAP""","""2022-09-13 00:16:03.9""",22,"""DEN""","""right""",51.32,28.65,0.65,0.34,0.07,204.26,27.17,"""SS""",1,0.578276,0.296811,-0.91169,-0.410878


## Splitting data into train, validation, and test sets.

In [27]:
# Provided from SportsTransformers-utils

def split_train_test_val(tracking_df: pl.DataFrame, target_df: pl.DataFrame) -> dict[str, pl.DataFrame]:
    """
    Split data into train, validation, and test sets.
    Split is 70-15-15 for train-test-val respectively. Notably, we split at the play levle and not frame level.
    This ensures no target contamination between splits.

    Args:
        tracking_df (pl.DataFrame): Tracking data
        target_df (pl.DataFrame): Target data

    Returns:
        dict: Dictionary containing train, validation, and test dataframes.
    """
    tracking_df = tracking_df.sort(["gameId", "playId", "frameId"])
    target_df = target_df.sort(["gameId", "playId"])

    print(
        f"Total set: {tracking_df.n_unique(['gameId', 'playId'])} plays,",
        f"{tracking_df.n_unique(['gameId', 'playId', 'frameId'])} frames",
    )

    test_val_ids = tracking_df.select(["gameId", "playId"]).unique(maintain_order=True).sample(fraction=0.3, seed=42)
    train_tracking_df = tracking_df.join(test_val_ids, on=["gameId", "playId"], how="anti")
    train_tgt_df = target_df.join(test_val_ids, on=["gameId", "playId"], how="anti")
    print(
        f"Train set: {train_tracking_df.n_unique(['gameId', 'playId'])} plays,",
        f"{train_tracking_df.n_unique(['gameId', 'playId', 'frameId'])} frames",
    )

    test_ids = test_val_ids.sample(fraction=0.5, seed=42)  # 70-15-15 split
    test_tracking_df = tracking_df.join(test_ids, on=["gameId", "playId"], how="inner")
    test_tgt_df = target_df.join(test_ids, on=["gameId", "playId"], how="inner")
    print(
        f"Test set: {test_tracking_df.n_unique(['gameId', 'playId'])} plays,",
        f"{test_tracking_df.n_unique(['gameId', 'playId', 'frameId'])} frames",
    )

    val_ids = test_val_ids.join(test_ids, on=["gameId", "playId"], how="anti")
    val_tracking_df = tracking_df.join(val_ids, on=["gameId", "playId"], how="inner")
    val_tgt_df = target_df.join(val_ids, on=["gameId", "playId"], how="inner")
    print(
        f"Validation set: {val_tracking_df.n_unique(['gameId', 'playId'])} plays,",
        f"{val_tracking_df.n_unique(['gameId', 'playId','frameId'])} frames",
    )

    return {
        "train_features": train_tracking_df,
        "train_targets": train_tgt_df,
        "test_features": test_tracking_df,
        "test_targets": test_tgt_df,
        "val_features": val_tracking_df,
        "val_targets": val_tgt_df,
    }

In [28]:
# Writing out splits to OUT_DIR

print("Split train/test/val")
split_dfs = split_train_test_val(rel_tracking_df, maskedPlayers_df)

out_dir = Path(OUT_DIR)
out_dir.mkdir(exist_ok=True, parents=True)

for key, df in split_dfs.items():
    sort_keys = ["gameId", "playId", "frameId"]
    df.sort(sort_keys).write_parquet(out_dir / f"{key}.parquet")

Split train/test/val
Total set: 1952 plays, 308900 frames
Train set: 1367 plays, 214793 frames
Test set: 292 plays, 46712 frames
Validation set: 293 plays, 47395 frames
