In [1]:
import pandas as pd
import numpy as np
pd.set_option("display.max_columns", None)

from src.prepData import load_train_data, normalize_input_fields, normalize_output_fields

## Data Prep

#### Load Data

In [2]:
print("Loading data...")
input_df, output_df = load_train_data()
# input_df.to_pickle("data/personal/input_df.pkl")
# output_df.to_pickle("data/personal/output_df.pkl")

print(f"Loaded {len(input_df)} input rows, {len(output_df)} output rows")
print(
  f"Unique plays: {input_df[['game_id', 'play_id']].drop_duplicates().shape[0]}"
)

Loading data...
Loaded 4880579 input rows, 562936 output rows
Unique plays: 14108


#### Normalize fields

In [3]:
input_df = normalize_input_fields(input_df)
norm_helper = input_df[['game_id','play_id','play_direction','absolute_yardline_number']].drop_duplicates()
output_df = normalize_output_fields(output_df, norm_helper)

#### Create play-level features

In [4]:
distinct_plays = input_df[['game_id', 'play_id']].drop_duplicates()
distinct_plays.sort_values(['game_id','play_id']).head(3)

Unnamed: 0,game_id,play_id
0,2023090700,101
234,2023090700,194
650,2023090700,219


In [5]:
# Get max frame_id from input_df for each play (throw_frame_id baseline)
input_max_frames = (
    input_df
    .groupby(['game_id', 'play_id'])[['frame_id', "ball_land_x_std","ball_land_y_std"]]
    .max()
    .reset_index()
    .rename(columns={'frame_id': 'throw_frame_id'})
)

# Get max frame_id from output_df for each play (throw_land_frame_id baseline)
output_max_frames = (
    output_df
    .groupby(['game_id', 'play_id'])[['frame_id']]
    .max()
    .reset_index()
    .rename(columns={'frame_id': 'throw_land_frame_id'})
)

# Combine both into baseline frame info
baseline_frame_info = input_max_frames.merge(
    output_max_frames,
    on=['game_id', 'play_id'],
    how='outer'
)

print(f"Baseline frame info shape: {baseline_frame_info.shape}")
print(f"Unique plays: {baseline_frame_info.shape[0]}")
baseline_frame_info.head(2)

Baseline frame info shape: (14108, 6)
Unique plays: 14108


Unnamed: 0,game_id,play_id,throw_frame_id,ball_land_x_std,ball_land_y_std,throw_land_frame_id
0,2023090700,101,26,21.259998,-0.22,21
1,2023090700,194,32,4.059998,31.55,9


In [6]:
# Create all play-level features
qb_frame = input_df[input_df['player_role'] == 'Passer']
if qb_frame[['game_id', 'play_id']].drop_duplicates().shape[0] < len(distinct_plays):
    print(f"Warning: fewer plays with QB ({qb_frame[['game_id', 'play_id']].drop_duplicates().shape[0]}) than original plays ({len(distinct_plays)})")

# Get QB max frame for plays with a passer
qb_max_frame = (
    qb_frame
    .groupby(['game_id', 'play_id', 'nfl_id', 'player_role'])['frame_id']
    .max()
    .reset_index()
)

# Find plays without a passer
plays_with_qb = qb_max_frame[['game_id', 'play_id']].drop_duplicates()
plays_without_qb = (
    distinct_plays
    .merge(plays_with_qb, on=['game_id', 'play_id'], how='left', indicator=True)
    .query('_merge == "left_only"')
    .drop(columns=['_merge'])
)

# For plays without a passer, use the overall max frame_id
if len(plays_without_qb) > 0:
    print(f"Found {len(plays_without_qb)} plays without a Passer. Using overall max frame_id.")
    
    missing_max_frames = (
        input_df
        .merge(plays_without_qb, on=['game_id', 'play_id'])
        .groupby(['game_id', 'play_id'])['frame_id']
        .max()
        .reset_index()
    )
    
    # Add placeholder columns for nfl_id and player_role
    missing_max_frames['nfl_id'] = None
    missing_max_frames['player_role'] = None
    
    # Combine with QB frames
    qb_max_frame = pd.concat([qb_max_frame, missing_max_frames], ignore_index=True)

# Join back to input_df to get the full row data
qb_rows = pd.merge(
    input_df, 
    qb_max_frame, 
    on=['game_id', 'play_id', 'nfl_id', 'frame_id', 'player_role'], 
    how='inner'
)

# Start with qb_rows
qb_sub = qb_rows.copy()

# Calculate derived features
qb_sub['qb_throw_distance'] = np.sqrt((qb_sub['ball_land_x_std'] - qb_sub['x_std'])**2 + (qb_sub['ball_land_y_std'] - qb_sub['y_std'])**2)
qb_sub['qb_ball_dir'] = (90 - np.degrees(np.arctan2(
    qb_sub['ball_land_y_std'] - qb_sub['y_std'],
    qb_sub['ball_land_x_std'] - qb_sub['x_std']
))) % 360
qb_sub['qb_direction_diff'] = (qb_sub['o_std'] - qb_sub['qb_ball_dir'] + 180) % 360 - 180  # difference between -180 and 180

# Rename frame_id to be QB-specific
qb_sub.rename(columns={'frame_id':'throw_frame_id'}, inplace=True)

# Drop player_to_predict column (not needed for QB)
qb_sub = qb_sub.drop(columns=['player_to_predict'])

# Rename QB kinematic fields to have qb_ prefix
qb_kinematic_fields_rename = {
    "x_std": "qb_x_std",
    "y_std": "qb_y_std",
    "o_std": "qb_o_std",
    "dir_std": "qb_dir_std",
    "s": "qb_s",
    "a": "qb_a"
}
qb_sub = qb_sub.rename(columns=qb_kinematic_fields_rename)

qb_sub = qb_sub.drop(columns=["ball_land_x_std","ball_land_y_std"])

qb_sub.head(3)

Found 3 plays without a Passer. Using overall max frame_id.


Unnamed: 0,game_id,play_id,nfl_id,throw_frame_id,play_direction,absolute_yardline_number,player_name,player_height,player_weight,player_birth_date,player_position,player_side,player_role,x,y,qb_s,qb_a,dir,o,num_frames_output,ball_land_x,ball_land_y,week,absolute_yardline_number_std,qb_x_std,qb_y_std,qb_o_std,qb_dir_std,qb_throw_distance,qb_ball_dir,qb_direction_diff
0,2023090700,101,43290,26,right,42,Jared Goff,6-4,223,1994-10-14,QB,Offense,Passer,35.41,29.99,0.64,0.47,108.83,212.25,21,63.259998,-0.22,1,42,-6.59,29.99,212.25,108.83,41.08852,137.327657,74.922343
1,2023090700,194,44822,32,left,89,Patrick Mahomes,6-3,230,1995-09-17,QB,Offense,Passer,97.62,29.67,0.96,1.64,185.14,285.7,9,84.940002,21.75,1,31,-8.62,23.63,105.7,5.14,14.950209,58.010861,47.689139
2,2023090700,219,44822,17,left,79,Patrick Mahomes,6-3,230,1995-09-17,QB,Offense,Passer,85.87,22.97,1.49,2.76,133.64,245.38,8,75.849998,11.49,1,41,-6.87,30.33,65.38,313.64,15.237809,41.115185,24.264815


In [7]:
# Just prove only one player per output
input_unique_players = input_df[['game_id', 'play_id', 'nfl_id', 'player_role', 'player_side']].drop_duplicates()
output_unique_players = output_df[['game_id', 'play_id', 'nfl_id']].drop_duplicates()

a = output_unique_players.merge(input_unique_players[['game_id','play_id','nfl_id','player_role','player_side']], on=['game_id', 'play_id', 'nfl_id'], how='inner', indicator=True)
b = a.loc[a['player_side'] == 'Offense', ['game_id','play_id','nfl_id']].groupby(['game_id','play_id']).nunique().reset_index()
b['nfl_id'].value_counts()

nfl_id
1    14108
Name: count, dtype: int64

In [8]:
# Create all play-level features
qb_frame = input_df[input_df['player_role'] == 'Passer']
if qb_frame[['game_id', 'play_id']].drop_duplicates().shape[0] < len(distinct_plays):
    print(f"Warning: fewer plays with QB ({qb_frame[['game_id', 'play_id']].drop_duplicates().shape[0]}) than original plays ({len(distinct_plays)})")

# Get QB max frame for plays with a passer
qb_max_frame = (
    qb_frame
    .groupby(['game_id', 'play_id', 'nfl_id', 'player_role'])['frame_id']
    .max()
    .reset_index()
)

# Find plays without a passer
plays_with_qb = qb_max_frame[['game_id', 'play_id']].drop_duplicates()
plays_without_qb = (
    distinct_plays
    .merge(plays_with_qb, on=['game_id', 'play_id'], how='left', indicator=True)
    .query('_merge == "left_only"')
    .drop(columns=['_merge'])
)

# For plays without a passer, use the overall max frame_id
if len(plays_without_qb) > 0:
    print(f"Found {len(plays_without_qb)} plays without a Passer. Using overall max frame_id.")
    
    missing_max_frames = (
        input_df
        .merge(plays_without_qb, on=['game_id', 'play_id'])
        .groupby(['game_id', 'play_id'])['frame_id']
        .max()
        .reset_index()
    )
    
    # Add placeholder columns for nfl_id and player_role
    missing_max_frames['nfl_id'] = None
    missing_max_frames['player_role'] = None
    
    # Combine with QB frames
    qb_max_frame = pd.concat([qb_max_frame, missing_max_frames], ignore_index=True)

# Join back to input_df to get the full row data
qb_rows = pd.merge(
    input_df, 
    qb_max_frame, 
    on=['game_id', 'play_id', 'nfl_id', 'frame_id', 'player_role'], 
    how='inner'
)

# Start with qb_rows
qb_sub = qb_rows.copy()

# Calculate derived features
qb_sub['qb_throw_distance'] = np.sqrt((qb_sub['ball_land_x_std'] - qb_sub['x_std'])**2 + (qb_sub['ball_land_y_std'] - qb_sub['y_std'])**2)
qb_sub['qb_ball_dir'] = (90 - np.degrees(np.arctan2(
    qb_sub['ball_land_y_std'] - qb_sub['y_std'],
    qb_sub['ball_land_x_std'] - qb_sub['x_std']
))) % 360
qb_sub['qb_direction_diff'] = (qb_sub['o_std'] - qb_sub['qb_ball_dir'] + 180) % 360 - 180  # difference between -180 and 180

# Rename frame_id to be QB-specific
qb_sub.rename(columns={'frame_id':'throw_frame_id'}, inplace=True)

# Drop player_to_predict column (not needed for QB)
qb_sub = qb_sub.drop(columns=['player_to_predict'])

# Rename QB kinematic fields to have qb_ prefix
qb_kinematic_fields_rename = {
    "x_std": "qb_x_std",
    "y_std": "qb_y_std",
    "o_std": "qb_o_std",
    "dir_std": "qb_dir_std",
    "s": "qb_s",
    "a": "qb_a"
}
qb_sub = qb_sub.rename(columns=qb_kinematic_fields_rename)

qb_sub = qb_sub.drop(columns=["ball_land_x_std","ball_land_y_std"])

qb_sub.head(3)

Found 3 plays without a Passer. Using overall max frame_id.


Unnamed: 0,game_id,play_id,nfl_id,throw_frame_id,play_direction,absolute_yardline_number,player_name,player_height,player_weight,player_birth_date,player_position,player_side,player_role,x,y,qb_s,qb_a,dir,o,num_frames_output,ball_land_x,ball_land_y,week,absolute_yardline_number_std,qb_x_std,qb_y_std,qb_o_std,qb_dir_std,qb_throw_distance,qb_ball_dir,qb_direction_diff
0,2023090700,101,43290,26,right,42,Jared Goff,6-4,223,1994-10-14,QB,Offense,Passer,35.41,29.99,0.64,0.47,108.83,212.25,21,63.259998,-0.22,1,42,-6.59,29.99,212.25,108.83,41.08852,137.327657,74.922343
1,2023090700,194,44822,32,left,89,Patrick Mahomes,6-3,230,1995-09-17,QB,Offense,Passer,97.62,29.67,0.96,1.64,185.14,285.7,9,84.940002,21.75,1,31,-8.62,23.63,105.7,5.14,14.950209,58.010861,47.689139
2,2023090700,219,44822,17,left,79,Patrick Mahomes,6-3,230,1995-09-17,QB,Offense,Passer,85.87,22.97,1.49,2.76,133.64,245.38,8,75.849998,11.49,1,41,-6.87,30.33,65.38,313.64,15.237809,41.115185,24.264815


In [9]:
qb_features = ["qb_x_std", 
               "qb_y_std", 
               "qb_s", 
               "qb_a", 
               "qb_dir_std", 
               "qb_o_std", 
               "qb_throw_distance", 
               "qb_ball_dir"]

play_level_features = baseline_frame_info.merge(
  qb_sub[['game_id','play_id'] + qb_features], 
  how = 'left', 
  on = ['game_id','play_id'])

def impute_qb_features_safe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Fill missing QB features using ball trajectory (always available)
    This is 'safe' because ball_land_x/y are inputs, not targets
    """
    mask = df['qb_x_std'].isnull()
    
    if mask.sum() > 0:
        # Proxy: assume QB was ~10 yards behind ball landing
        df.loc[mask, 'qb_x_std'] = df.loc[mask, 'ball_land_x_std'] - 10
        df.loc[mask, 'qb_y_std'] = 26.7  # assume center of field
        
        # Proxy: assume QB was stationary (conservative)
        df.loc[mask, 'qb_s'] = 0.0
        df.loc[mask, 'qb_a'] = 0.0
      
        # Throw distance from imputed position
        df.loc[mask, 'qb_throw_distance'] = np.sqrt(
            (df.loc[mask, 'ball_land_x_std'] - df.loc[mask, 'qb_x_std'])**2 +
            (df.loc[mask, 'ball_land_y_std'] - df.loc[mask, 'qb_y_std'])**2
        )

        # Proxy: QB facing ball direction
        df.loc[mask, 'qb_o_std'] = (90 - np.degrees(np.arctan2(
            df.loc[mask, 'ball_land_y_std'] - df.loc[mask, 'qb_y_std'],
            df.loc[mask, 'ball_land_x_std'] - df.loc[mask, 'qb_x_std']
        ))) % 360
        df.loc[mask, 'qb_dir_std'] = df.loc[mask, 'qb_o_std']

        df.loc[mask, 'qb_ball_dir'] = (90 - np.degrees(np.arctan2(
            df.loc[mask, 'ball_land_y_std'] - df.loc[mask, 'qb_y_std'],
            df.loc[mask, 'ball_land_x_std'] - df.loc[mask, 'qb_x_std']
        ))) % 360
    
    return df

# Apply BEFORE split
play_level_features = impute_qb_features_safe(play_level_features)


In [10]:
x_data = baseline_frame_info[['game_id','play_id','throw_frame_id']].merge(
                                 input_df[input_df['player_to_predict'] == True],
                                 left_on = ['game_id','play_id','throw_frame_id'],
                                 right_on = ['game_id','play_id','frame_id'],
                                 how = 'inner')
player_level_features = ['player_height',
                         'player_weight',
                         'player_birth_date',
                         'player_position',
                         'player_side',
                         'player_role',
                         'x_std',
                         'y_std',
                         'o_std',
                         'dir_std',
                         's',
                         'a']
x_data = x_data[['game_id','play_id','nfl_id'] + player_level_features].copy()
x_data = x_data.merge(play_level_features, on = ['game_id','play_id'])


def height_to_inches(col):
    # col: pandas Series of "6-1" strings
    split_vals = col.str.split("-", expand=True)
    feet = split_vals[0].astype(float)
    inches = split_vals[1].astype(float)
    return feet * 12 + inches

x_data["height_in"] = height_to_inches(x_data["player_height"])
# Age in years (super rough)
x_data["birth_year"] = pd.to_datetime(x_data["player_birth_date"]).dt.year


# Encode angles as sin/cos
for col in ["dir_std", "o_std", "qb_o_std", "qb_dir_std", "qb_ball_dir"]:
    rad = np.deg2rad(x_data[col])
    x_data[col + "_sin"] = np.sin(rad)
    x_data[col + "_cos"] = np.cos(rad)


# Calculate speed in directions
o_rad = np.deg2rad(x_data['o_std'])
x_data["s_x_std"] = x_data['s'] * np.sin(o_rad)
x_data["s_y_std"] = x_data['s'] * np.cos(o_rad)


x_data.sort_values(['game_id','play_id','nfl_id'], inplace=True)
x_data.head(3)


Unnamed: 0,game_id,play_id,nfl_id,player_height,player_weight,player_birth_date,player_position,player_side,player_role,x_std,y_std,o_std,dir_std,s,a,throw_frame_id,ball_land_x_std,ball_land_y_std,throw_land_frame_id,qb_x_std,qb_y_std,qb_s,qb_a,qb_dir_std,qb_o_std,qb_throw_distance,qb_ball_dir,height_in,birth_year,dir_std_sin,dir_std_cos,o_std_sin,o_std_cos,qb_o_std_sin,qb_o_std_cos,qb_dir_std_sin,qb_dir_std_cos,qb_ball_dir_sin,qb_ball_dir_cos,s_x_std,s_y_std
2,2023090700,101,44930,6-3,196,1995-02-16,WR,Offense,Targeted Receiver,10.43,14.14,106.8,99.25,7.9,2.68,26,21.259998,-0.22,21,-6.59,29.99,0.64,0.47,108.83,212.25,41.08852,137.327657,75.0,1995,0.986996,-0.160743,0.957319,-0.289032,-0.533615,-0.845728,0.94648,-0.322761,0.677805,-0.735242,7.562824,-2.283351
0,2023090700,101,46137,6-1,204,1997-02-15,SS,Defense,Defensive Coverage,13.82,17.67,184.99,134.17,5.34,1.8,26,21.259998,-0.22,21,-6.59,29.99,0.64,0.47,108.83,212.25,41.08852,137.327657,73.0,1997,0.717276,-0.69679,-0.086982,-0.99621,-0.533615,-0.845728,0.94648,-0.322761,0.677805,-0.735242,-0.464483,-5.319761
1,2023090700,101,52546,6-1,193,1997-01-21,CB,Defense,Defensive Coverage,6.01,12.44,309.47,192.18,2.93,4.75,26,21.259998,-0.22,21,-6.59,29.99,0.64,0.47,108.83,212.25,41.08852,137.327657,73.0,1997,-0.210984,-0.97749,-0.771958,0.635674,-0.533615,-0.845728,0.94648,-0.322761,0.677805,-0.735242,-2.261836,1.862525


In [12]:
y_data = output_df.merge(
    baseline_frame_info[['game_id','play_id']], 
    on=['game_id','play_id']
)

y_data.sort_values(['game_id','play_id','nfl_id', 'frame_id'], inplace=True)

In [13]:
def hybrid_trajectory_interpolation(x_data, y_data, frame_rate=10, blend_factor=0.5):
    """
    Hybrid: blend velocity projection (early) with ball-directed (late)
    blend_factor: 0 = pure velocity, 1 = pure ball-directed
    """
    results = []
    
    for idx, row in x_data.iterrows():
        if idx % 10000 == 0:
            print(f"Processing row {idx}/{len(x_data)}")
        gid = row['game_id']
        pid = row['play_id']
        nid = row['nfl_id']
        
        x_throw = row['x_std']
        y_throw = row['y_std']
        vx = row['s_x_std']
        vy = row['s_y_std']
        x_land = row['ball_land_x_std']
        y_land = row['ball_land_y_std']
        throw_frame = row['throw_frame_id']
        
        traj_frames = y_data[
            (y_data['game_id'] == gid) &
            (y_data['play_id'] == pid) &
            (y_data['nfl_id'] == nid)
        ].sort_values('frame_id')
        
        if traj_frames.empty:
            continue
        
        frame_ids = traj_frames['frame_id'].values
        n_frames = len(frame_ids)
        
        for i, fid in enumerate(frame_ids):
            dt = (fid) / frame_rate
            t_norm = i / max(n_frames - 1, 1)  # 0 to 1
            
            # Velocity projection
            x_vel = x_throw + vx * dt
            y_vel = y_throw + vy * dt
            
            # Ball-directed interpolation
            x_ball = x_throw + t_norm * (x_land - x_throw)
            y_ball = y_throw + t_norm * (y_land - y_throw)
            
            # Blend: early frames favor velocity, late frames favor ball
            alpha = t_norm * blend_factor
            x_hybrid = (1 - alpha) * x_vel + alpha * x_ball
            y_hybrid = (1 - alpha) * y_vel + alpha * y_ball
            
            results.append({
                'game_id': gid,
                'play_id': pid,
                'nfl_id': nid,
                'frame_id': fid,
                'x_std_hybrid': x_hybrid,
                'y_std_hybrid': y_hybrid,
            })
    
    return pd.DataFrame(results)

# Generate hybrid trajectories
hybrid_traj = hybrid_trajectory_interpolation(x_data, y_data, blend_factor=0.7)
y_with_hybrid = y_data.merge(hybrid_traj, on=['game_id', 'play_id', 'nfl_id', 'frame_id'])

y_with_hybrid.shape

Processing row 0/46045
Processing row 10000/46045
Processing row 20000/46045
Processing row 30000/46045
Processing row 40000/46045


(562936, 14)

In [14]:
import numpy as np

def calculate_kaggle_rmse(df):
    """
    Calculate RMSE per Kaggle's formula
    df should have: x_std, y_std (actual), x_std_hybrid, y_std_hybrid (predicted)
    """
    # Calculate squared errors per frame
    squared_errors = (
        (df['x_std'] - df['x_std_hybrid'])**2 + 
        (df['y_std'] - df['y_std_hybrid'])**2
    )
    
    # RMSE = sqrt(mean of squared distances)
    rmse = np.sqrt(squared_errors.mean())
    
    return rmse

# Calculate overall RMSE
overall_rmse = calculate_kaggle_rmse(y_with_hybrid)
print(f"\n{'='*50}")
print(f"üèà Hybrid Baseline RMSE: {overall_rmse:.4f} yards")
print(f"{'='*50}\n")

# Calculate per-frame RMSE (to see if error grows over time)
# frame_rmse = y_with_hybrid.groupby('frame_id').apply(
#     lambda g: np.sqrt(((g['x_std'] - g['x_std_hybrid'])**2 + 
#                        (g['y_std'] - g['y_std_hybrid'])**2).mean())
# ).reset_index(name='rmse')

# print("RMSE by frame:")
# print(frame_rmse.head(15))

# # Calculate per-play RMSE (to identify hardest plays)
# play_rmse = y_with_hybrid.groupby(['game_id', 'play_id']).apply(
#     lambda g: np.sqrt(((g['x_std'] - g['x_std_hybrid'])**2 + 
#                        (g['y_std'] - g['y_std_hybrid'])**2).mean())
# ).reset_index(name='rmse')

# print(f"\nPlay-level RMSE statistics:")
# print(play_rmse['rmse'].describe())
# print(f"\nWorst 5 plays:")
# print(play_rmse.nlargest(5, 'rmse'))


üèà Hybrid Baseline RMSE: 3.8054 yards



In [15]:
y_with_hybrid['target_dx'] = y_with_hybrid['x_std_hybrid'] - y_with_hybrid['x_std']
y_with_hybrid['target_dy'] = y_with_hybrid['y_std_hybrid'] - y_with_hybrid['y_std']

y_with_hybrid.head(10)
y_data = y_with_hybrid[['game_id','play_id','nfl_id','frame_id','target_dx','target_dy']].copy()

In [16]:
interaction_features = ['x_std','y_std','s_x_std','s_y_std','height_in']

inv_numeric_features = [
    # Predicted player features
    "height_in", "player_weight", "birth_year",
    # Predicted player kinematics
    "x_std", "y_std",
    "s_x_std", "s_y_std",
    "a",  # if present
    "dir_std_sin", "dir_std_cos",
    "o_std_sin", "o_std_cos",
    
    # QB kinematics
    "qb_x_std", "qb_y_std", "qb_s", "qb_a",
    "qb_o_std_sin", "qb_o_std_cos",
    "qb_dir_std_sin", "qb_dir_std_cos",
    
    # Throw features - global
    "throw_frame_id", "throw_land_frame_id",
    "ball_land_x_std", "ball_land_y_std",
    # Time of throw - needs QB kinematics
    "qb_throw_distance", 
    "qb_ball_dir_sin", "qb_ball_dir_cos",
]

inv_categorical_features = [
    "player_position",
    "player_side",
    "player_role",
]

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder


preproc_invariant = ColumnTransformer(
    transformers=[
        ("num", "passthrough", inv_numeric_features),
        ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), inv_categorical_features),
    ]
)

preproc_invariant.fit(x_data[inv_numeric_features + inv_categorical_features])

0,1,2
,transformers,"[('num', ...), ('cat', ...)]"
,remainder,'drop'
,sparse_threshold,0.3
,n_jobs,
,transformer_weights,
,verbose,False
,verbose_feature_names_out,True
,force_int_remainder_cols,'deprecated'

0,1,2
,categories,'auto'
,drop,
,sparse_output,False
,dtype,<class 'numpy.float64'>
,handle_unknown,'ignore'
,min_frequency,
,max_categories,
,feature_name_combiner,'concat'


In [28]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

class PlayDataset(Dataset):
    def __init__(self, x_data, y_data, interaction_features,
                 inv_numeric_features, inv_categorical_features,
                 preproc_invariant, device="cpu"):
        """
        x_data: throw-frame dataframe (one row per (game, play, nfl_id) at throw)
        y_data: output dataframe with (game, play, nfl_id, frame_id, target_dx, target_dy)
        """
        self.device = device
        self.interaction_features = interaction_features
        self.inv_numeric_features = inv_numeric_features
        self.inv_categorical_features = inv_categorical_features
        self.preproc_invariant = preproc_invariant

        # Build list of plays
        self.plays = []
        self.samples = []
        for (gid, pid), play_df_all in tqdm(x_data.groupby(["game_id", "play_id"])):
            play_df = play_df_all.sort_values("nfl_id").reset_index(drop=True)
            nfl_ids = play_df["nfl_id"].tolist()

            # Gather output rows for each player
            frames_per_player = []
            targets_per_player = []
            T_max = 0

            for nid in nfl_ids:
                out_rows = (
                    y_data
                    .query("game_id == @gid and play_id == @pid and nfl_id == @nid")
                    .sort_values("frame_id")
                )
                if out_rows.empty:
                    continue
                frames = out_rows["frame_id"].to_numpy()
                targets = out_rows[["target_dx", "target_dy"]].to_numpy(dtype="float32")
                frames_per_player.append(frames)
                targets_per_player.append(targets)
                T_max = max(T_max, len(frames))

            if len(frames_per_player) == 0:
                continue

            # Normalize time 0..1 using max length in this play
            # Here we just use frame index within each player's sequence
            # (you can also use true time in seconds if you prefer)
            t_norm = torch.linspace(0.0, 1.0, steps=T_max, dtype=torch.float32)

            # We'll pad targets to (N, T_max, 2), with mask
            N = len(targets_per_player)
            targets_tensor = torch.zeros(N, T_max, 2, dtype=torch.float32)
            mask = torch.zeros(N, T_max, dtype=torch.bool)

            for i, targ in enumerate(targets_per_player):
                Ti = targ.shape[0]
                targets_tensor[i, :Ti, :] = torch.from_numpy(targ)
                mask[i, :Ti] = True

            # Store info for this play
            self.plays.append({
                "gid": gid,
                "pid": pid,
                "play_df": play_df,
                "targets": targets_tensor,
                "mask": mask,
                "t_norm": t_norm,
            })

            X_pair, X_inv = self._build_pairwise_and_invariant(play_df)
            self.samples.append((X_pair, X_inv, t_norm, targets_tensor, mask))
            # self.samples.append({"X_pair": X_pair, "X_inv": X_inv, "t_norm": t_norm, "targets": targets_tensor, "mask": mask})
            

    def __len__(self):
        return len(self.samples)

    def _build_pairwise_and_invariant(self, play_df: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor]:
        # ---- pairwise grid (F_int, N, N) ----
        import numpy as np

        X_int = play_df[self.interaction_features].to_numpy(dtype=np.float32)  # (N, F_int)
        N, F_int = X_int.shape
        feat_i = X_int[:, None, :]                # (N, 1, F_int)
        feat_j = X_int[None, :, :]                # (1, N, F_int)
        pair_diff = feat_j - feat_i               # (N, N, F_int)
        X_pair = np.transpose(pair_diff, (2, 0, 1))  # (F_int, N, N)

        # ---- invariant features (N, F_inv) ----
        X_inv = self.preproc_invariant.transform(
            play_df[self.inv_numeric_features + self.inv_categorical_features]
        )
        X_inv = X_inv.astype("float32")

        return torch.from_numpy(X_pair), torch.from_numpy(X_inv)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.samples[idx]

# Build dataset
full_dataset = PlayDataset(
    x_data=x_data,
    y_data=y_data,  # with proper residual targets!
    interaction_features=interaction_features,
    inv_numeric_features=inv_numeric_features,
    inv_categorical_features=inv_categorical_features,
    preproc_invariant=preproc_invariant,
)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14108/14108 [02:18<00:00, 102.11it/s]


In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PairwiseInteractionEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels=64, out_channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x: (B, F_int, N, N)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))   # (B, C, N, N)
        x = x.mean(dim=3)           # pool over "other player" j ‚Üí (B, C, N)
        x = x.permute(0, 2, 1)      # ‚Üí (B, N, C)
        return x

class TimeConditionedMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        # x: (..., in_dim)
        return self.net(x)

class FullModel(nn.Module):
    def __init__(self, in_channels, inv_dim, hidden_dim=128, enc_hidden=64, enc_out=64):
        super().__init__()
        self.encoder = PairwiseInteractionEncoder(
            in_channels=in_channels,
            hidden_channels=enc_hidden,
            out_channels=enc_out,
        )
        self.mlp = TimeConditionedMLP(
            in_dim=enc_out + inv_dim + 1,  # +1 for time feature
            hidden_dim=hidden_dim,
            out_dim=2,
        )

    def forward(self, X_pair, X_inv, t_norm, mask):
        """
        X_pair: (B, F_int, N, N)
        X_inv:  (B, N, F_inv)
        t_norm: (B, T_max)
        mask:   (B, N, T_max)  (bool) ‚Äì True where target is valid
        """
        B, F_int, N, _ = X_pair.shape
        _, N_inv, F_inv = X_inv.shape
        _, T_max = t_norm.shape

        assert N == N_inv, "Mismatch in N between pairwise and inv features"

        # --- Encode interactions ---
        z_int = self.encoder(X_pair)    # (B, N, C)

        # --- Prepare features over time ---
        # z_int:     (B, N, C)     ‚Üí (B, N, T, C)
        # X_inv:     (B, N, F_inv) ‚Üí (B, N, T, F_inv)
        # t_norm:    (B, T)        ‚Üí (B, 1, T, 1) broadcast to (B, N, T, 1)
        C = z_int.shape[-1]
        z_int_exp = z_int.unsqueeze(2).expand(B, N, T_max, C)          # (B, N, T, C)
        X_inv_exp = X_inv.unsqueeze(2).expand(B, N, T_max, F_inv)      # (B, N, T, F_inv)
        t_exp     = t_norm.unsqueeze(1).unsqueeze(-1).expand(
            B, N, T_max, 1
        )  # (B, N, T, 1)

        feat = torch.cat([z_int_exp, X_inv_exp, t_exp], dim=-1)        # (B, N, T, C+F_inv+1)

        # Flatten players and time to feed MLP
        feat_flat = feat.view(B * N * T_max, -1)       # (B*N*T, in_dim)
        out_flat  = self.mlp(feat_flat)                # (B*N*T, 2)
        out       = out_flat.view(B, N, T_max, 2)      # (B, N, T, 2)

        # Apply mask in loss outside (we return full out)
        return out

In [57]:
from torch.utils.data import DataLoader
import numpy as np

# For now, simple random split by index (you can do group splits by game_id if you like)
# Dataset is already at the play level, so this way of splitting is fine
n = len(full_dataset)
idxs = np.arange(n)
np.random.seed(42)
np.random.shuffle(idxs)

n_train = int(0.7 * n)
n_val   = int(0.15 * n)
train_idx = idxs[:n_train]
val_idx   = idxs[n_train:n_train+n_val]
test_idx  = idxs[n_train+n_val:]

from torch.utils.data import Subset

train_ds = Subset(full_dataset, train_idx)
val_ds   = Subset(full_dataset, val_idx)
test_ds  = Subset(full_dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Infer dims
F_int = len(interaction_features)
# Get one batch to determine inv_dim
X_pair0, X_inv0, t_norm0, targets0, mask0 = next(iter(train_loader))
inv_dim = X_inv0.shape[-1]

model = FullModel(
    in_channels=F_int,
    inv_dim=inv_dim,
    hidden_dim=256,
    enc_hidden=64,
    enc_out=64,
).to(device)

criterion = nn.MSELoss(reduction="sum")  # we'll divide by #valid later
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

def run_epoch(loader, train=True):
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    total_count = 0

    for X_pair, X_inv, t_norm, targets, mask in tqdm(loader):
        X_pair  = X_pair.to(device).float()          # (B=1, F_int, N, N)
        X_inv   = X_inv.to(device).float()           # (B=1, N, F_inv)
        t_norm  = t_norm.to(device).float()          # (B=1, T)
        targets = targets.to(device).float()         # (B=1, N, T, 2)
        mask    = mask.to(device)                    # (B=1, N, T)

        if train:
            optimizer.zero_grad()

        preds = model(X_pair, X_inv, t_norm, mask)   # (B, N, T, 2)

        # Only count valid frames
        mask_expanded = mask.unsqueeze(-1).expand_as(preds)  # (B, N, T, 2)
        diff = (preds - targets) * mask_expanded
        loss = criterion(diff, torch.zeros_like(diff))

        valid_count = mask.sum().item() * 2  # *2 because x and y
        if valid_count == 0:
            continue

        loss = loss / valid_count  # mean over valid coordinates

        if train:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_count += 1

    return total_loss / max(total_count, 1)

num_epochs = 90
best_val = float("inf")
best_state = None
current_ts_abbreviated = __import__('datetime').datetime.now().strftime("%Y%m%d_%H%M%S")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = run_epoch(train_loader, train=True)
    val_loss   = run_epoch(val_loader,   train=False)
    print(f"Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}")
    if val_loss < best_val:
        best_val = val_loss
        best_state = model.state_dict().copy()
        torch.save(best_state, f"best_model_{current_ts_abbreviated}.pth")
        print(f"  New best model saved with val loss {best_val:.4f}")



Epoch 1/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:12<00:00, 820.98it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4024.03it/s]


Epoch 1: train=4.0931, val=2.5732
  New best model saved with val loss 2.5732
Epoch 2/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:13<00:00, 756.32it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3972.16it/s]


Epoch 2: train=2.5796, val=2.1023
  New best model saved with val loss 2.1023
Epoch 3/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 855.35it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3927.74it/s]


Epoch 3: train=2.1832, val=2.1143
Epoch 4/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:10<00:00, 905.51it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4200.47it/s]


Epoch 4: train=1.9676, val=1.6854
  New best model saved with val loss 1.6854
Epoch 5/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 835.57it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4053.58it/s]


Epoch 5: train=1.8043, val=1.4874
  New best model saved with val loss 1.4874
Epoch 6/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:10<00:00, 912.42it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4034.58it/s]


Epoch 6: train=1.6338, val=1.2593
  New best model saved with val loss 1.2593
Epoch 7/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 852.25it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4163.07it/s]


Epoch 7: train=1.4725, val=1.1858
  New best model saved with val loss 1.1858
Epoch 8/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 880.59it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3843.55it/s]


Epoch 8: train=1.3633, val=1.1624
  New best model saved with val loss 1.1624
Epoch 9/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 867.99it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3806.83it/s]


Epoch 9: train=1.3163, val=1.1346
  New best model saved with val loss 1.1346
Epoch 10/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:12<00:00, 822.45it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4330.85it/s]


Epoch 10: train=1.2476, val=1.5804
Epoch 11/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:10<00:00, 906.89it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 4196.00it/s]


Epoch 11: train=1.1778, val=1.0929
  New best model saved with val loss 1.0929
Epoch 12/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:11<00:00, 855.32it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3423.65it/s]


Epoch 12: train=1.1163, val=1.0533
  New best model saved with val loss 1.0533
Epoch 13/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:12<00:00, 790.27it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3782.65it/s]


Epoch 13: train=1.1203, val=1.3516
Epoch 14/90


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9875/9875 [00:14<00:00, 696.80it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2116/2116 [00:00<00:00, 3155.30it/s]


Epoch 14: train=1.0615, val=0.8956
  New best model saved with val loss 0.8956
Epoch 15/90


 55%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 5476/9875 [00:06<00:05, 771.46it/s]

In [52]:
import time

def train_and_eval(model, train_loader, val_loader, num_epochs=10, patience=20, lr=5e-4):
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss(reduction="sum")

    best_val = float("inf")
    best_state = None
    bad_epochs = 0

    for epoch in range(num_epochs):
        # ----- train -----
        model.train()
        for X_pair, X_inv, t_norm, targets, mask in tqdm(train_loader):
            X_pair  = X_pair.to(device).float()
            X_inv   = X_inv.to(device).float()
            t_norm  = t_norm.to(device).float()
            targets = targets.to(device).float()
            mask    = mask.to(device)

            optimizer.zero_grad()
            preds = model(X_pair, X_inv, t_norm, mask)

            mask_expanded = mask.unsqueeze(-1).expand_as(preds)
            diff = (preds - targets) * mask_expanded
            loss = criterion(diff, torch.zeros_like(diff))
            valid_count = mask.sum().item() * 2
            if valid_count == 0:
                continue
            loss = loss / valid_count

            loss.backward()
            optimizer.step()

        # ----- validate -----
        model.eval()
        val_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for X_pair, X_inv, t_norm, targets, mask in tqdm(val_loader):
                X_pair  = X_pair.to(device).float()   
                X_inv   = X_inv.to(device).float()
                t_norm  = t_norm.to(device).float()
                targets = targets.to(device).float()
                mask    = mask.to(device)

                preds = model(X_pair, X_inv, t_norm, mask)
                mask_expanded = mask.unsqueeze(-1).expand_as(preds)
                diff = (preds - targets) * mask_expanded
                loss = criterion(diff, torch.zeros_like(diff))
                valid_count = mask.sum().item() * 2
                if valid_count == 0:
                    continue
                loss = loss / valid_count

                val_loss += loss.item()
                n_batches += 1

        val_loss /= max(n_batches, 1)
        print(f"Epoch {epoch+1}: val={val_loss:.4f}")

        # early stopping
        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    current_ts_abbreviated = time.time().__str__().replace('.', '')[-6:]
    torch.save(model.state_dict(), f"best_model_{current_ts_abbreviated}.pth")
    return best_val

In [53]:
all_idxs = np.arange(len(full_dataset))
np.random.shuffle(all_idxs)

subset_size = int(0.3 * len(all_idxs))
subset_idxs = all_idxs[:subset_size]
subset_train_idxs = subset_idxs[:int(0.7 * subset_size)]
subset_val_idxs   = subset_idxs[int(0.7 * subset_size):]

subset_train_ds = Subset(full_dataset, subset_train_idxs)
subset_val_ds   = Subset(full_dataset, subset_val_idxs)

subset_train_loader = DataLoader(subset_train_ds, batch_size=1, shuffle=True)
subset_val_loader   = DataLoader(subset_val_ds, batch_size=1, shuffle=False)


lrs = [1e-3, 5e-4, 2e-4]
hidden_dims = [64, 128]
enc_hidden = [32, 64]

results = []
for lr in lrs:
    for hd in hidden_dims:
        for eh in enc_hidden:
            model = FullModel(
                in_channels=len(interaction_features),
                inv_dim=inv_dim,
                hidden_dim=hd,
                enc_hidden=eh,
                enc_out=64,
        ).to(device)

        print(f"Testing lr={lr}, hidden_dim={hd}")
        val_loss = train_and_eval(
            model,
            subset_train_loader,
            subset_val_loader,
            num_epochs=8,
            patience=3,
            lr=lr,
        )
        results.append((lr, hd, val_loss))

print(sorted(results, key=lambda x: x[2]))

Testing lr=0.001, hidden_dim=64


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1067.78it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5605.11it/s]


Epoch 1: val=3.4395


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:03<00:00, 929.77it/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 4679.12it/s]


Epoch 2: val=2.7751


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1147.22it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5682.68it/s]


Epoch 3: val=2.5799


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1166.03it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5778.19it/s]


Epoch 4: val=2.3838


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1168.14it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5694.89it/s]


Epoch 5: val=2.3556


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1145.10it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5783.53it/s]


Epoch 6: val=3.7473


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1170.32it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5707.71it/s]


Epoch 7: val=2.1383


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1164.34it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 3925.84it/s]


Epoch 8: val=2.1492
Testing lr=0.001, hidden_dim=128


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1087.90it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5572.55it/s]


Epoch 1: val=3.1475


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1091.09it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5583.13it/s]


Epoch 2: val=3.1386


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1080.97it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5558.94it/s]


Epoch 3: val=2.6380


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1049.37it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 4632.26it/s]


Epoch 4: val=2.2430


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:03<00:00, 975.08it/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5573.28it/s]


Epoch 5: val=3.0534


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1090.60it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5573.60it/s]


Epoch 6: val=2.5408


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1097.28it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5606.82it/s]


Epoch 7: val=2.5444
Testing lr=0.0005, hidden_dim=64


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1172.00it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5742.91it/s]


Epoch 1: val=3.0357


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1166.94it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5781.52it/s]


Epoch 2: val=2.8009


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1171.59it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5845.75it/s]


Epoch 3: val=2.5530


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1178.30it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5829.62it/s]


Epoch 4: val=2.4146


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1121.00it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5841.34it/s]


Epoch 5: val=2.3966


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1161.28it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5512.46it/s]


Epoch 6: val=2.3337


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1178.79it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5841.58it/s]


Epoch 7: val=2.2719


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1042.90it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5679.81it/s]


Epoch 8: val=2.1696
Testing lr=0.0005, hidden_dim=128


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1089.46it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5598.29it/s]


Epoch 1: val=3.0112


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1083.43it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5553.49it/s]


Epoch 2: val=3.0530


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1094.04it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5600.61it/s]


Epoch 3: val=2.5582


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1058.35it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5628.41it/s]


Epoch 4: val=2.7617


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1089.35it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5596.19it/s]


Epoch 5: val=2.3149


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1088.26it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5637.78it/s]


Epoch 6: val=2.1667


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1073.96it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5542.10it/s]


Epoch 7: val=2.0895


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1098.87it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5601.55it/s]


Epoch 8: val=2.5958
Testing lr=0.0002, hidden_dim=64


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1142.22it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5801.21it/s]


Epoch 1: val=3.3533


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:03<00:00, 946.89it/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5633.95it/s]


Epoch 2: val=4.0406


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1177.71it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5789.78it/s]


Epoch 3: val=2.8816


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1096.25it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5827.18it/s]


Epoch 4: val=2.6020


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1173.94it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5847.04it/s]


Epoch 5: val=2.9643


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1182.17it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5715.13it/s]


Epoch 6: val=2.5461


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1181.52it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5817.81it/s]


Epoch 7: val=2.4730


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1166.98it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5833.45it/s]


Epoch 8: val=2.2459
Testing lr=0.0002, hidden_dim=128


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1033.11it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5530.25it/s]


Epoch 1: val=3.0283


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1097.53it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5591.37it/s]


Epoch 2: val=3.0438


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1094.81it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5379.41it/s]


Epoch 3: val=2.8579


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1028.91it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5046.63it/s]


Epoch 4: val=2.6156


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1034.98it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5453.46it/s]


Epoch 5: val=2.5986


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1093.52it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5565.52it/s]


Epoch 6: val=2.5691


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:02<00:00, 1091.65it/s]
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5144.89it/s]


Epoch 7: val=2.3651


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2962/2962 [00:03<00:00, 945.09it/s] 
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1270/1270 [00:00<00:00, 5862.21it/s]

Epoch 8: val=2.4478
[(0.0005, 128, 2.0894919444990205), (0.001, 64, 2.138300970046759), (0.0005, 64, 2.1696047066659556), (0.001, 128, 2.2430069523474834), (0.0002, 64, 2.2459094112606968), (0.0002, 128, 2.3650691633648058)]





In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PairwiseInteractionEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels=64, out_channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x: (B, F_int, N, N)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))   # (B, C, N, N)
        x = x.mean(dim=3)           # pool over "other player" j ‚Üí (B, C, N)
        x = x.permute(0, 2, 1)      # ‚Üí (B, N, C)
        return x

class TimeConditionedMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        # x: (..., in_dim)
        return self.net(x)

class FullModel(nn.Module):
    def __init__(self, in_channels, inv_dim, hidden_dim=128, enc_hidden=64, enc_out=64):
        super().__init__()
        self.encoder = PairwiseInteractionEncoder(
            in_channels=in_channels,
            hidden_channels=enc_hidden,
            out_channels=enc_out,
        )
        self.mlp = TimeConditionedMLP(
            in_dim=enc_out + inv_dim + 1,  # +1 for time feature
            hidden_dim=hidden_dim,
            out_dim=2,
        )

    def forward(self, X_pair, X_inv, t_norm, mask):
        """
        X_pair: (B, F_int, N, N)
        X_inv:  (B, N, F_inv)
        t_norm: (B, T_max)
        mask:   (B, N, T_max)  (bool) ‚Äì True where target is valid
        """
        B, F_int, N, _ = X_pair.shape
        _, N_inv, F_inv = X_inv.shape
        _, T_max = t_norm.shape

        assert N == N_inv, "Mismatch in N between pairwise and inv features"

        # --- Encode interactions ---
        z_int = self.encoder(X_pair)    # (B, N, C)

        # --- Prepare features over time ---
        # z_int:     (B, N, C)     ‚Üí (B, N, T, C)
        # X_inv:     (B, N, F_inv) ‚Üí (B, N, T, F_inv)
        # t_norm:    (B, T)        ‚Üí (B, 1, T, 1) broadcast to (B, N, T, 1)
        C = z_int.shape[-1]
        z_int_exp = z_int.unsqueeze(2).expand(B, N, T_max, C)          # (B, N, T, C)
        X_inv_exp = X_inv.unsqueeze(2).expand(B, N, T_max, F_inv)      # (B, N, T, F_inv)
        t_exp     = t_norm.unsqueeze(1).unsqueeze(-1).expand(
            B, N, T_max, 1
        )  # (B, N, T, 1)

        feat = torch.cat([z_int_exp, X_inv_exp, t_exp], dim=-1)        # (B, N, T, C+F_inv+1)

        # Flatten players and time to feed MLP
        feat_flat = feat.view(B * N * T_max, -1)       # (B*N*T, in_dim)
        out_flat  = self.mlp(feat_flat)                # (B*N*T, 2)
        out       = out_flat.view(B, N, T_max, 2)      # (B, N, T, 2)

        # Apply mask in loss outside (we return full out)
        return out

In [None]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

preproc_invariant = ColumnTransformer(
    transformers=[
        ("num", "passthrough", inv_numeric_features),
        ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=False), inv_categorical_features),
    ]
)

preproc_invariant.fit(x_data[inv_numeric_features + inv_categorical_features])

# merge in play_targets to get y per play
x_with_y = x_data.merge(
    y_data[['game_id','play_id','nfl_id','target_dx','target_dy']],
    on=['game_id','play_id','nfl_id'],
    how='inner',
    indicator=True
).query('_merge == "both"').drop(columns=['_merge'])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PairwiseInteractionEncoder(nn.Module):
    """
    Input:  (B, F_int, N, N)  pairwise features
    Output: (B, N, C)         per-player interaction embedding
    """
    def __init__(self, in_channels, hidden_channels=64, out_channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # x: (B, F_int, N, N)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))  # (B, C, N, N)

        # pool over "other player" axis (j), keep i:
        # assume dim 2 = i (row player), dim 3 = j (other)
        x = x.mean(dim=3)          # (B, C, N)

        # reshape to per-player embeddings (B, N, C)
        x = x.permute(0, 2, 1)     # (B, N, C)
        return x

In [24]:
def build_play_embeddings(play_df, encoder: PairwiseInteractionEncoder, device="cpu"):
    play_df = play_df.sort_values("nfl_id").reset_index(drop=True)
    N = len(play_df)

    # Interaction features -> pairwise grid
    X_int = play_df[interaction_features].to_numpy(dtype=np.float32)  # (N, F_int)
    N, F_int = X_int.shape

    feat_i = X_int[:, None, :]                # (N, 1, F_int)
    feat_j = X_int[None, :, :]                # (1, N, F_int)
    pair_diff = feat_j - feat_i               # (N, N, F_int)
    X_pair = np.transpose(pair_diff, (2, 0, 1)).astype(np.float32)  # (F_int, N, N)
    X_pair_t = torch.from_numpy(X_pair).unsqueeze(0).to(device)     # (1, F_int, N, N)

    # Invariant features
    X_inv = preproc_invariant.transform(
        play_df[inv_numeric_features + inv_categorical_features]
    )
    X_inv = np.asarray(X_inv, dtype=np.float32)                      # (N, F_inv)
    X_inv_t = torch.from_numpy(X_inv).to(device).unsqueeze(0)       # (1, N, F_inv)

    # Encode interactions
    # with torch.no_grad():  # (for now: treat encoder as fixed)
    z_int = encoder(X_pair_t)            # (1, N, C)

    # Concatenate interaction + invariant per player
    Z_play = torch.cat([z_int, X_inv_t], dim=-1)  # (1, N, D)

    return Z_play, play_df  # return df so we know which row is which

In [33]:
from tqdm import tqdm

device = "cpu"  # or "cuda" if available
encoder = PairwiseInteractionEncoder(
    in_channels=len(interaction_features),
    hidden_channels=128,
    out_channels=128,
).to(device)

X_list = []  # will hold [z_player || time_features]
y_list = []
play_ids = []  # ‚úÖ NEW: Track which play each sample belongs to


for (gid, pid), play_df_all in tqdm(x_data.groupby(["game_id", "play_id"])):
    play_df = play_df_all.copy()

    # Build embeddings for all players in this play
    Z_play, play_df_sorted = build_play_embeddings(play_df, encoder, device=device)
    Z_play = Z_play.squeeze(0)   # (N, D)

    play_df_sorted = play_df_sorted.reset_index(drop=True)
    N, D = Z_play.shape

    for i in range(N):
        row = play_df_sorted.iloc[i]
        nid = row["nfl_id"]

        # Get this player's future frames
        out_rows = (
            y_data
            .query(
                "game_id == @gid and play_id == @pid and nfl_id == @nid"
            )
            .sort_values("frame_id")
        )

        if out_rows.empty:
            continue
        
        T_i = len(out_rows)

        # Example time feature: normalized time 0..1
        t_norm = (np.arange(T_i, dtype=np.float32) / max(T_i - 1, 1)).reshape(-1, 1)  # (T_i, 1)
        # print(Z_play)
        # Player embedding (D,) -> repeat over T_i frames
        z_i = Z_play[i].detach().cpu().numpy()         # (D,)
        z_rep = np.repeat(z_i[None, :], T_i, axis=0)   # (T_i, D)

        # Concatenate [z_i || t_features]
        X_i_t = np.concatenate([z_rep, t_norm], axis=1)  # (T_i, D+1)

        # Targets: x_t, y_t for each frame
        y_i_t = out_rows[["target_dx", "target_dy"]].to_numpy(dtype=np.float32)  # (T_i, 2)

        X_list.append(X_i_t)
        y_list.append(y_i_t)

        play_ids.extend([f"{gid}_{pid}"] * T_i)

# Stack all (T_i, ‚Ä¶) chunks into one big (num_samples, ‚Ä¶)
X_all = np.concatenate(X_list, axis=0)  # (num_samples, D+1)
Y_all = np.concatenate(y_list, axis=0)  # (num_samples, 2)
play_ids = np.array(play_ids)  # ‚úÖ (num_samples,)

print(X_all.shape, Y_all.shape, play_ids.shape)

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 14108/14108 [02:20<00:00, 100.66it/s]


(562936, 177) (562936, 2) (562936,)


In [28]:
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import GroupShuffleSplit

# Create train/test split grouped by play
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(X_all, Y_all, groups=play_ids))

X_train_full = X_all[train_idx]
y_train_full = Y_all[train_idx]
play_ids_train = play_ids[train_idx]

X_test = X_all[test_idx]
y_test = Y_all[test_idx]
play_ids_test = play_ids[test_idx]

print(f"Train: {len(X_train_full)} samples from {len(np.unique(play_ids_train))} plays")
print(f"Test: {len(X_test)} samples from {len(np.unique(play_ids_test))} plays")

# ‚úÖ Verify no overlap
assert len(set(play_ids_train) & set(play_ids_test)) == 0, "Data leakage detected!"

# Second split: train -> train + val (also grouped)
splitter_val = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx2, val_idx = next(splitter_val.split(
    X_train_full, y_train_full, groups=play_ids_train
))

X_train = X_train_full[train_idx2]
y_train = y_train_full[train_idx2]

X_val = X_train_full[val_idx]
y_val = y_train_full[val_idx]

print(f"Final split:")
print(f"  Train: {len(X_train)} samples")
print(f"  Val:   {len(X_val)} samples")
print(f"  Test:  {len(X_test)} samples")

Train: 451172 samples from 11286 plays
Test: 111764 samples from 2822 plays
Final split:
  Train: 360785 samples
  Val:   90387 samples
  Test:  111764 samples


In [29]:
from torch.utils.data import TensorDataset, DataLoader

train_ds = TensorDataset(
    torch.from_numpy(X_train),
    torch.from_numpy(y_train),
)
val_ds = TensorDataset(
    torch.from_numpy(X_val),
    torch.from_numpy(y_val),
)
test_ds = TensorDataset(
    torch.from_numpy(X_test),
    torch.from_numpy(y_test),
)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False)

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class TimeConditionedMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
    def forward(self, x):
        return self.net(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_dim = X_all.shape[1]

def train_one_config(hidden_dim, lr, num_epochs=30, patience=5):
    model = TimeConditionedMLP(in_dim=in_dim, hidden_dim=hidden_dim, out_dim=2).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val = float("inf")
    best_state = None
    bad_epochs = 0

    for epoch in range(num_epochs):
        # ---- Train ----
        model.train()
        train_loss = 0.0
        n_train = 0

        for xb, yb in train_loader:
            xb = xb.to(device).float()
            yb = yb.to(device).float()

            optimizer.zero_grad()
            preds = model(xb)
            loss = criterion(preds, yb)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * xb.size(0)
            n_train += xb.size(0)

        train_loss /= n_train

        # ---- Validate ----
        model.eval()
        val_loss = 0.0
        n_val = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device).float()
                yb = yb.to(device).float()
                preds = model(xb)
                loss = criterion(preds, yb)
                val_loss += loss.item() * xb.size(0)
                n_val += xb.size(0)
        val_loss /= n_val

        print(f"[hd={hidden_dim}, lr={lr}] Epoch {epoch+1}: train={train_loss:.4f}, val={val_loss:.4f}")

        # ---- Early stopping tracking ----
        if val_loss < best_val:
            best_val = val_loss
            best_state = copy.deepcopy(model.state_dict())
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f"Early stopping (no val improvement for {patience} epochs).")
                break

    # Load best weights before returning
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, best_val

In [32]:
hidden_dims = [128, 256]
lrs = [1e-3, 5e-4]

best_cfg = None
best_val = float("inf")
best_model = None

# hidden_dims = [256]
# lrs = [5e-4]

for hd in hidden_dims:
    for lr in lrs:
        print(f"\n=== Training config: hidden_dim={hd}, lr={lr} ===")
        model, val_loss = train_one_config(hidden_dim=hd, lr=lr, num_epochs=75, patience=10)

        print(f"Config (hd={hd}, lr={lr}) finished with best val MSE={val_loss:.4f}")
        if val_loss < best_val:
            best_val = val_loss
            best_cfg = (hd, lr)
            best_model = model

print("\nBest config:", best_cfg, "with val MSE=", best_val)


=== Training config: hidden_dim=128, lr=0.001 ===
[hd=128, lr=0.001] Epoch 1: train=5.9389, val=3.7616
[hd=128, lr=0.001] Epoch 2: train=4.0389, val=3.5694
[hd=128, lr=0.001] Epoch 3: train=3.7739, val=3.8000
[hd=128, lr=0.001] Epoch 4: train=3.5165, val=3.1176
[hd=128, lr=0.001] Epoch 5: train=3.2068, val=2.8053
[hd=128, lr=0.001] Epoch 6: train=2.7851, val=2.3730
[hd=128, lr=0.001] Epoch 7: train=2.1014, val=1.8786
[hd=128, lr=0.001] Epoch 8: train=1.7370, val=1.5704
[hd=128, lr=0.001] Epoch 9: train=1.5280, val=1.4568
[hd=128, lr=0.001] Epoch 10: train=1.4522, val=1.5396
[hd=128, lr=0.001] Epoch 11: train=1.3617, val=1.3525
[hd=128, lr=0.001] Epoch 12: train=1.3262, val=1.4631
[hd=128, lr=0.001] Epoch 13: train=1.2690, val=1.4267
[hd=128, lr=0.001] Epoch 14: train=1.2444, val=1.2320
[hd=128, lr=0.001] Epoch 15: train=1.2066, val=1.2118
[hd=128, lr=0.001] Epoch 16: train=1.1723, val=1.2130
[hd=128, lr=0.001] Epoch 17: train=1.1372, val=1.1835
[hd=128, lr=0.001] Epoch 18: train=1.130

In [108]:
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

X_train, X_test, y_train, y_test = train_test_split(
    X_all, Y_all, test_size=0.2, random_state=42
)

train_ds = TensorDataset(
    torch.from_numpy(X_train),  # (N_samples, D+1)
    torch.from_numpy(y_train),  # (N_samples, 2)
)
test_ds = TensorDataset(
    torch.from_numpy(X_test),
    torch.from_numpy(y_test),
)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=256, shuffle=False)


class TimeConditionedMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
    def forward(self, x):
        return self.net(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_dim = X_all.shape[1]

model = TimeConditionedMLP(in_dim=in_dim, hidden_dim=128, out_dim=2).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.0025)

for epoch in range(75):
    model.train()
    total_loss = 0.0
    n = 0
    for xb, yb in train_loader:
        xb = xb.to(device).float()
        yb = yb.to(device).float()

        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * xb.size(0)
        n += xb.size(0)
    print(f"Epoch {epoch+1}: train MSE={total_loss/n:.4f}")

Epoch 1: train MSE=13.8809
Epoch 2: train MSE=5.3395
Epoch 3: train MSE=3.3607
Epoch 4: train MSE=2.4103
Epoch 5: train MSE=2.0325
Epoch 6: train MSE=1.8106
Epoch 7: train MSE=1.6705
Epoch 8: train MSE=1.5075
Epoch 9: train MSE=1.4319
Epoch 10: train MSE=1.3599
Epoch 11: train MSE=1.3028
Epoch 12: train MSE=1.2703
Epoch 13: train MSE=1.2543
Epoch 14: train MSE=1.2273
Epoch 15: train MSE=1.2061
Epoch 16: train MSE=1.1722
Epoch 17: train MSE=1.1794
Epoch 18: train MSE=1.1544
Epoch 19: train MSE=1.1563
Epoch 20: train MSE=1.1255
Epoch 21: train MSE=1.1455
Epoch 22: train MSE=1.1133
Epoch 23: train MSE=1.1113
Epoch 24: train MSE=1.0993
Epoch 25: train MSE=1.1084
Epoch 26: train MSE=1.0819
Epoch 27: train MSE=1.0758
Epoch 28: train MSE=1.0801
Epoch 29: train MSE=1.0694
Epoch 30: train MSE=1.0679
Epoch 31: train MSE=1.0719
Epoch 32: train MSE=1.0625
Epoch 33: train MSE=1.0572
Epoch 34: train MSE=1.0481
Epoch 35: train MSE=1.0404
Epoch 36: train MSE=1.0338
Epoch 37: train MSE=1.0303
Epoch 38: