In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm
import pickle

In [2]:
N_PREV = 1

GAME_ID = 2022100600
PLAY_ID = 90
FRAME_ID = 2

# Loading data

In [3]:
data = pd.read_feather('data/pivoted/tracking_week_5.feather')
data = data[(data.gameId == GAME_ID) & (data.playId == PLAY_ID)]
fixed_cols = ['gameId', 'playId', 'frameId']
player_cols = [col for col in data.columns if col.endswith(('_x', '_y'))]

data = pd.concat(
    [data[fixed_cols]] + 
    [data[player_cols].shift(i) for i in range(1, N_PREV + 1)]
, axis=1)

data.columns = fixed_cols + [f'{col}-{i}' for i in range(1, N_PREV + 1) for col in player_cols]
data.dropna(axis=0, how='any', inplace=True)
data.shape

(59, 49)

In [4]:
data.head(3)

Unnamed: 0,gameId,playId,frameId,p0_x-1,p1_x-1,p2_x-1,p3_x-1,p4_x-1,p5_x-1,p6_x-1,...,p13_y-1,p14_y-1,p15_y-1,p16_y-1,p17_y-1,p18_y-1,p19_y-1,p20_y-1,p21_y-1,p22_y-1
1,2022100600,90,2,85.410004,90.42,85.76,86.42,85.99,90.49,86.47,...,33.01,36.27,25.1,17.73,26.41,21.47,22.91,9.28,29.0,32.61
2,2022100600,90,3,85.410004,90.39,85.76,86.42,86.0,90.48,86.48,...,32.57,36.27,25.1,17.72,26.39,21.46,22.89,9.28,29.0,32.58
3,2022100600,90,4,85.410004,90.36,85.76,86.42,86.0,90.49,86.47,...,32.14,36.26,25.1,17.72,26.38,21.45,22.87,9.28,29.0,32.54


In [5]:
ball_pair = [item for sublist in [[f'p0_x-{j}', f'p0_y-{j}'] for j in range(1, N_PREV + 1)] for item in sublist]
player_pairs = []
for i in range(1, 23):
    pair = []
    for j in range(1, N_PREV + 1):
        pair += [f'p{i}_x-{j}', f'p{i}_y-{j}']
    player_pairs.append(pair)

In [6]:
def remove_pair(pairs, remove_id):
    return pairs[:remove_id] + pairs[remove_id + 1:]

In [7]:
X, Y, X_ball_flag = [], [], []
start_frame = data[data.frameId == FRAME_ID]

# Add ball
X.append(start_frame[sum([ball_pair] + player_pairs, [])].values)
X_ball_flag.append(np.ones((X[-1].shape[0], 1)))
# Y.append(start_frame[['p0_dx', 'p0_dy']].values)

# Add Players
for p_id in range(1, 23):
    X.append(start_frame[sum([ball_pair] + [player_pairs[p_id-1]] + remove_pair(player_pairs, p_id-1), [])].values)
    X_ball_flag.append(np.zeros((X[-1].shape[0], 1)))
    # Y.append(start_frame[[f'p{p_id}_dx', f'p{p_id}_dy']].values)

X = np.concatenate(X, axis=0)
X_ball_flag = np.concatenate(X_ball_flag, axis=0)
X = np.concatenate([X, X_ball_flag], axis=1)
# Y = np.concatenate(Y, axis=0).clip(-CLIP_POINT, CLIP_POINT)

In [8]:
pd.DataFrame(X)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,37,38,39,40,41,42,43,44,45,46
0,85.410004,23.83,90.42,23.74,85.76,23.79,86.42,22.04,85.99,25.49,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,1.0
1,85.410004,23.83,90.42,23.74,85.76,23.79,86.42,22.04,85.99,25.49,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
2,85.410004,23.83,85.76,23.79,90.42,23.74,86.42,22.04,85.99,25.49,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
3,85.410004,23.83,86.42,22.04,90.42,23.74,85.76,23.79,85.99,25.49,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
4,85.410004,23.83,85.99,25.49,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
5,85.410004,23.83,90.49,26.04,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
6,85.410004,23.83,86.47,26.96,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
7,85.410004,23.83,85.92,35.18,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
8,85.410004,23.83,87.97,31.92,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0
9,85.410004,23.83,86.88,9.01,90.42,23.74,85.76,23.79,86.42,22.04,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,0.0


# Infering

In [9]:
def sample_column(row):
    return np.random.choice(row.index, p=row.values)

In [10]:
model = pickle.load(open('model/rf_basics_small/model.pkl', 'rb'))
discretizer_a = pickle.load(open('model/rf_basics_small/discretizer_a.pkl', 'rb'))
discretizer_b = pickle.load(open('model/rf_basics_small/discretizer_b.pkl', 'rb'))

In [11]:
frame_columns = ['ball_x', 'ball_y'] + sum([[f'p{i}_x', f'p{i}_y'] for i in range(1, 23)], [])
frames = pd.DataFrame([X[:, 2:4].flatten()], columns=frame_columns)
for _ in tqdm(range(data.frameId.nunique() - 1)):
    pred = model.predict_proba(X)
    pred_dx = pd.DataFrame(pred[0])
    pred_dy = pd.DataFrame(pred[1])

    pred_dx['sampled_column'] = pred_dx.apply(sample_column, axis=1)
    pred_dy['sampled_column'] = pred_dy.apply(sample_column, axis=1)
    pred_dx['est_dx'] = discretizer_a.inverse_transform(pred_dx['sampled_column'].values.reshape(-1, 1))
    pred_dy['est_dy'] = discretizer_b.inverse_transform(pred_dy['sampled_column'].values.reshape(-1, 1))
    est_movement = pd.concat([pred_dx['est_dx'], pred_dy['est_dy']], axis=1)

    update_matrix = np.zeros(X.shape)
    update_matrix[:, 0] = np.ones(X.shape[0]) * est_movement.loc[0]['est_dx']
    update_matrix[:, 1] = np.ones(X.shape[0]) * est_movement.loc[0]['est_dy']

    update_matrix[:, 2] = est_movement['est_dx']
    update_matrix[:, 3] = est_movement['est_dy']

    for i in range(1, X.shape[0]-1):
        update_matrix[:, 2+2*i] = est_movement.loc[i]['est_dx']
        update_matrix[:, 3+2*i] = est_movement.loc[i]['est_dy']
        update_matrix[:i, 2+2*i] = est_movement.loc[i+1]['est_dx']
        update_matrix[:i, 3+2*i] = est_movement.loc[i+1]['est_dy']

    X = X + update_matrix
    frames = pd.concat([frames, pd.DataFrame([X[:, 2:4].flatten()], columns=frame_columns)])
frames.reset_index(drop=True, inplace=True)
frames['frameId'] = frames.index + 2

100%|██████████| 58/58 [00:01<00:00, 36.47it/s]


In [12]:
frames.head(3)

Unnamed: 0,ball_x,ball_y,p1_x,p1_y,p2_x,p2_y,p3_x,p3_y,p4_x,p4_y,...,p18_y,p19_x,p19_y,p20_x,p20_y,p21_x,p21_y,p22_x,p22_y,frameId
0,90.42,23.74,90.42,23.74,85.76,23.79,86.42,22.04,85.99,25.49,...,21.47,81.03,22.91,84.32,9.28,84.45,29.0,71.53,32.61,2
1,90.697228,23.660792,90.42,24.254851,85.601584,23.829604,86.657624,22.04,85.831584,25.173168,...,21.66802,80.990396,23.226832,84.082376,9.002772,84.45,28.920792,71.569604,32.253564,3
2,90.261584,23.779604,90.42,24.056832,85.641188,23.829604,86.61802,22.277624,85.831584,25.29198,...,21.311584,80.911188,23.147624,84.32,9.08198,84.410396,28.762376,71.648812,32.055545,4


In [13]:
data.head(3)

Unnamed: 0,gameId,playId,frameId,p0_x-1,p1_x-1,p2_x-1,p3_x-1,p4_x-1,p5_x-1,p6_x-1,...,p13_y-1,p14_y-1,p15_y-1,p16_y-1,p17_y-1,p18_y-1,p19_y-1,p20_y-1,p21_y-1,p22_y-1
1,2022100600,90,2,85.410004,90.42,85.76,86.42,85.99,90.49,86.47,...,33.01,36.27,25.1,17.73,26.41,21.47,22.91,9.28,29.0,32.61
2,2022100600,90,3,85.410004,90.39,85.76,86.42,86.0,90.48,86.48,...,32.57,36.27,25.1,17.72,26.39,21.46,22.89,9.28,29.0,32.58
3,2022100600,90,4,85.410004,90.36,85.76,86.42,86.0,90.49,86.47,...,32.14,36.26,25.1,17.72,26.38,21.45,22.87,9.28,29.0,32.54


# Animating

In [14]:
def create_football_field(ax):
    ax.set_facecolor('green')
   
    # Add yard lines
    for yard in range(10, 110, 10):
        ax.axvline(yard, color='white', linestyle='-', linewidth=2)
        if yard == 50:
            ax.text(yard, 5, str(yard), color='white', ha='center')
        elif yard < 50:
            ax.text(yard, 5, str(yard), color='white', ha='center')
            ax.text(120-yard, 5, str(yard), color='white', ha='center')
   
    # Add end zones
    ax.axvspan(0, 10, facecolor='blue', alpha=0.3)
    ax.axvspan(110, 120, facecolor='red', alpha=0.3)

def update_play(frame, play_data, home_scatter, away_scatter, ball_scatter, frame_text, ax, full_field_view_mode):
    frame_data = play_data.iloc[frame]
   
    # Update home team positions
    home_x = [frame_data[f'p{i}_x'] for i in range(1, 12)]
    home_y = [frame_data[f'p{i}_y'] for i in range(1, 12)]
    home_scatter.set_offsets(np.column_stack((home_x, home_y)))

    # Update away team positions
    away_x = [frame_data[f'p{i}_x'] for i in range(12, 23)]
    away_y = [frame_data[f'p{i}_y'] for i in range(12, 23)]
    away_scatter.set_offsets(np.column_stack((away_x, away_y)))

    # Update ball position
    ball_scatter.set_offsets([[frame_data['ball_x'], frame_data['ball_y']]])

    frame_text.set_text(f"Frame: {frame}")

    if full_field_view_mode:
        # Setting field limits
        padding = 5
        ax.set_xlim(-padding, 120+padding)
        ax.set_ylim(-padding, 53.3+padding)
    else:
        # Adjust the field of view
        all_x = home_x + away_x + [frame_data['ball_x']]
        all_y = home_y + away_y + [frame_data['ball_y']]
        x_min, x_max = min(all_x), max(all_x)
        y_min, y_max = min(all_y), max(all_y)
       
        # Add some padding
        padding = 10
        ax.set_xlim(max(0, x_min - padding), min(120, x_max + padding))
        ax.set_ylim(max(0, y_min - padding), min(53.3, y_max + padding))
   
    # Clear previous annotations
    for artist in ax.artists + ax.texts:
        if artist != frame_text:
            artist.remove()
   
    # Add jersey numbers
    for i in range(1, 23):
        x = frame_data[f'p{i}_x']
        y = frame_data[f'p{i}_y']
        ax.annotate(str(i), (x, y), xytext=(0, 5), textcoords='offset points', ha='center', fontsize=8)

    return [home_scatter, away_scatter, ball_scatter, frame_text]

def create_play_animation(play_data, output_file, full_field_view_mode=False):
    fig, ax = plt.subplots(figsize=(12, 6))
    create_football_field(ax)
   
    # Create initial scatter plots for home team, away team, and the ball
    home_scatter = ax.scatter([], [], s=100, color='blue', label='Home Team')
    away_scatter = ax.scatter([], [], s=100, color='red', label='Away Team')
    ball_scatter = ax.scatter([], [], color='brown', s=50, label='Ball')
   
    # Add frame text
    frame_text = ax.text(0.02, 0.95, '', fontsize=10, transform=ax.transAxes)
   
    # Create the animation
    anim = animation.FuncAnimation(
        fig, update_play, frames=len(play_data),
        fargs=(play_data, home_scatter, away_scatter, ball_scatter, frame_text, ax, full_field_view_mode),
        interval=100, blit=False
    )
   
    # Save the animation as a GIF
    anim.save(output_file, writer='pillow', fps=10)
    plt.close(fig)

In [15]:
create_play_animation(frames, 'football_animation.gif', full_field_view_mode=False)