In [15]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import imageio
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import plotly.graph_objects as go

from data_loaders import get_tracking_and_plays
from model import PREPROCESS_STD, PREPROCESS_MEAN, load_expected_yards_model

## Model Predictions

In [16]:
test_x_arrays = []
test_y_arrays = []
test_r_arrays = []
test_pos_r_arrays = []
for i in range(4,5):
    print(i)
    cur_x = np.load(f'week_{i}_x_new2.npy')
    cur_y = np.load(f'week_{i}_y_new2.npy')
    cur_r = np.load(f'week_{i}_r_new2.npy')
    cur_pos_r = np.load(f'week_{i}_pos_r_new2.npy')
    test_x_arrays.append(cur_x)
    test_y_arrays.append(cur_y)
    test_r_arrays.append(cur_r)
    test_pos_r_arrays.append(cur_pos_r)
test_x_tensor = np.concatenate(test_x_arrays)
test_y_tensor = np.concatenate(test_y_arrays)
test_r_tensor = np.concatenate(test_r_arrays)
test_pos_r_tensor = np.concatenate(test_pos_r_arrays)

4


In [17]:
test_x = torch.tensor(test_x_tensor, dtype=torch.float)
test_pos_r = torch.tensor(test_pos_r_arrays, dtype=torch.double)

# Normalize data according to training set mean
test_x = (test_x - PREPROCESS_MEAN) / PREPROCESS_STD

In [18]:
test_dataset = TensorDataset(test_x, test_pos_r.squeeze())

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [19]:
model = load_expected_yards_model()
model.eval()
model.return_attention = True

# all_outputs = np.zeros((len(test_x), 4))
game_ids = np.zeros((len(test_x)))
play_ids = np.zeros((len(test_x)))
frame_ids = np.zeros((len(test_x)))
yard_preds = np.zeros((len(test_x)))
attention_maps = np.zeros((len(test_x), 11, 10))
attention_ref_def = np.zeros((len(test_x), 11))
attention_ref_off = np.zeros((len(test_x), 10))
cur = 0
with torch.no_grad():
    for cur_x, cur_r in tqdm(test_loader):
        cur_outputs = model(cur_x)
        for i in range(len(cur_outputs[0])):
            game_ids[cur] = cur_r[i][0][0][2]
            play_ids[cur] = cur_r[i][0][0][3]
            frame_ids[cur] = cur_r[i][0][0][4]
            attention_maps[cur] = cur_outputs[3][i] * cur_outputs[2][i] * cur_outputs[1][i]
            attention_ref_def[cur] = cur_r[i][:,0,0]
            attention_ref_off[cur] = cur_r[i][0,:,1]
            yard_preds[cur] = cur_outputs[0][i]
            cur+=1

100%|██████████| 2317/2317 [00:05<00:00, 399.58it/s]


In [20]:
preds = pd.DataFrame({'gameId': game_ids, 'playId': play_ids, 'frameId': frame_ids, 'ExpectedYards': yard_preds, 'AttentionMap': attention_maps.tolist(), 'AttentionRefDef': attention_ref_def.tolist(), 'AttentionRefOff': attention_ref_off.tolist()})

# preds['gameId'] = preds['gameId'].astype(int)
# preds['playId'] = preds['playId'].astype(int)
# preds['frameId'] = preds['frameId'].astype(int)

In [21]:
preds[preds['gameId'] == 2022092900][preds['playId'] == 57]['frameId'].unique()

  preds[preds['gameId'] == 2022092900][preds['playId'] == 57]['frameId'].unique()


array([16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
       29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
       42., 43., 44., 45.])

In [22]:
df = get_tracking_and_plays('tracking_week_4.csv')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  bc_coords['bc_x']=bc_coords['x']
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  bc_coords['bc_y']=bc_coords['y']


In [23]:
# Left join preds to df
df = df.merge(preds, how='left', on=['gameId', 'playId', 'frameId'])

In [24]:
df.head()

Unnamed: 0,gameId,playId,nflId,displayName,frameId,time,jerseyNumber,club,playDirection,x,...,is_on_defence,is_ballcarrier,bc_x,bc_y,frameId_start,frameId_end,ExpectedYards,AttentionMap,AttentionRefDef,AttentionRefOff
0,2022092900,57,42654.0,La'el Collins,16,2022-09-29 20:16:01.599999,71.0,CIN,left,34.13,...,False,False,30.36,24.303333,16.0,45.0,6.254379,"[[0.10176374763250351, 0.03634851053357124, 0....","[43321.0, 43327.0, 43503.0, 46142.0, 46312.0, ...","[42654.0, 43344.0, 43510.0, 46094.0, 46163.0, ..."
1,2022092900,57,42654.0,La'el Collins,17,2022-09-29 20:16:01.700000,71.0,CIN,left,34.15,...,False,False,30.58,24.463333,16.0,45.0,6.165869,"[[0.09672700613737106, 0.03968067094683647, 0....","[43321.0, 43327.0, 43503.0, 46142.0, 46312.0, ...","[42654.0, 43344.0, 43510.0, 46094.0, 46163.0, ..."
2,2022092900,57,42654.0,La'el Collins,18,2022-09-29 20:16:01.799999,71.0,CIN,left,34.16,...,False,False,30.84,24.583333,16.0,45.0,5.951019,"[[0.0908566415309906, 0.042631473392248154, 0....","[43321.0, 43327.0, 43503.0, 46142.0, 46312.0, ...","[42654.0, 43344.0, 43510.0, 46094.0, 46163.0, ..."
3,2022092900,57,42654.0,La'el Collins,19,2022-09-29 20:16:01.900000,71.0,CIN,left,34.17,...,False,False,31.14,24.643333,16.0,45.0,5.849377,"[[0.08307496458292007, 0.043974678963422775, 0...","[43321.0, 43327.0, 43503.0, 46142.0, 46312.0, ...","[42654.0, 43344.0, 43510.0, 46094.0, 46163.0, ..."
4,2022092900,57,42654.0,La'el Collins,20,2022-09-29 20:16:02.000000,71.0,CIN,left,34.19,...,False,False,31.48,24.673333,16.0,45.0,5.719326,"[[0.079816535115242, 0.04692747816443443, 0.05...","[43321.0, 43327.0, 43503.0, 46142.0, 46312.0, ...","[42654.0, 43344.0, 43510.0, 46094.0, 46163.0, ..."


## Animation and Attention Processing

In [25]:
team_colors = {
    'ARI':"#97233F",
    'ATL':"#A71930",
    'BAL':'#241773',
    'BUF':"#00338D",
    'CAR':"#0085CA",
    'CHI':"#C83803",    
    'CIN':"#FB4F14",
    'CLE':"#311D00",
    'DAL':'#003594',
    'DEN':"#FB4F14",
    'DET':"#0076B6",
    'GB':"#203731",
    'HOU':"#03202F",
    'IND':"#002C5F",
    'JAX':"#9F792C",
    'KC':"#E31837",
    'LA':"#0373fc",
    'LAC':"#007FC8",
    'LV':"#000000",
    'MIA':"#008E97",
    'MIN':"#4F2683",
    'NE':"#002244",
    'NO':"#D3BC8D",
    'NYG':"#0B2265",
    'NYJ':"#125740",
    'PHI':"#004C54",
    'PIT':"#FFB612",
    'SEA':"#69BE28",
    'SF':"#ff3643",
    'TB':'#D50A0A',
    'TEN':"#4B92DB",
    'WAS':"#5A1414"
    }

def animate_play(df, gameId, playId):
    games = pd.read_csv('games.csv')
    game = games[games['gameId'] == gameId].iloc[0]
    home_team = game['homeTeamAbbr']
    visitor_team = game['visitorTeamAbbr']

    def create_football_field():
      rect = patches.Rectangle((0, 0), 53.3, 120, linewidth=0.1,
                              facecolor='white', zorder=0)
      fig, ax = plt.subplots(figsize=(6, 14))
      ax.add_patch(rect)
      plt.plot([0, 0, 53.3, 53.3, 0, 0, 53.3, 53.3, 0, 0, 53.3, 53.3, 0, 0, 53.3,
                    53.3, 0, 0, 53.3, 53.3, 0, 0, 53.3, 53.3, 53.3, 0, 0, 53.3],
                [10, 10, 10, 20, 20, 30, 30, 40, 40, 50, 50, 60, 60, 70, 70, 80,
                    80, 90, 90, 100, 100, 110, 110, 120, 0, 0, 120, 120],
                  color='black')
      for y in range(20, 110, 10):
        numb = y
        if y > 50:
            numb = 120 - y
        plt.text(5, y-1.5, str(numb - 10),
                  horizontalalignment='center',
                  fontsize=20,  # fontname='Arial',
                  color='black', rotation=270)
        plt.text(53.3 - 5, y - 0.95, str(numb - 10),
                  horizontalalignment='center',
                  fontsize=20,  # fontname='Arial',
                  color='black', rotation=90)
      for y in range(11,110):
        ax.plot([0.7, 0.4],[y, y], color='black')
        ax.plot([53.0, 52.5],[y, y], color='black')
        ax.plot([22.91, 23.57],[y, y], color='black')
        ax.plot([29.73, 30.39],[y, y],  color='black')
        homeEndzone = patches.Rectangle((0, 0), 53.3, 10,
                                        linewidth=0.1,
                                        edgecolor='black',
                                        facecolor=team_colors[home_team],
                                        zorder=0)
        awayEndzone = patches.Rectangle((0, 110), 53.3, 10,
                                        linewidth=0.1,
                                        edgecolor='black',
                                        facecolor=team_colors[visitor_team],  # Use the team color
                                        zorder=0)
        ax.text(53.3 / 2, 2, home_team, horizontalalignment='center',
            fontsize=20, color='white', weight='bold', zorder=2)
        ax.text(53.3 / 2, 113, visitor_team, horizontalalignment='center',
            fontsize=20, color='white', weight='bold', zorder=2)
        ax.add_patch(homeEndzone)
        ax.add_patch(awayEndzone)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig, ax

    def plot_frame(frame_num, play_df):
        fig, ax = create_football_field()

        # Create a grey background for text for better readability
        textbox_height = 10
        text_box = patches.Rectangle((0, -textbox_height), 53.3, textbox_height,
                                 linewidth=1, edgecolor='none', facecolor='black',
                                 zorder=2, alpha=0.8)
        ax.add_patch(text_box)
        current_frame_df = play_df[play_df['frameId'] == frame_num]
        
        # yards_remaining = round(current_frame_df['yards_remaining'].iloc[0], 1)
        # predicted_yards = round(current_frame_df['predicted_yards_remaining'].iloc[0], 1)

        # # Add text elecreate_football_field                fontsize=10, color='white', zorder=3)
        # text_predicted_yards = ax.text(1, -9, f'Predicted Yards: {predicted_yards}',
        #                               fontsize=10, color='white', zorder=3)

        # PROCESS ATTENTION MAP

        def_player_idx = 3
        def_player_focus_id = current_frame_df['AttentionRefDef'].iloc[0][def_player_idx]
        attention_map = current_frame_df['AttentionMap'].iloc[0][def_player_idx]
        off_attention_ref = current_frame_df['AttentionRefOff'].iloc[0]
        print(attention_map)
        # Plot players and football
        for index, row in current_frame_df.iterrows():
            if row['nflId'] in off_attention_ref:
                attention_index = off_attention_ref.index(row['nflId'])
                cur_player_attention = attention_map[attention_index]
            else:
                cur_player_attention = 1

            if row['nflId'] == def_player_focus_id:
                ax.scatter(row['y'], row['x'], color='blue', s=60, zorder=2)
            elif cur_player_attention < 0.065:
                ax.scatter(row['y'], row['x'], color='red', s=60, zorder=2)
            elif row['is_ballcarrier']:
                ax.scatter(row['y'], row['x'], color='black', s=60, zorder=2)
            elif row['club'] == home_team:
                #ax.scatter(row['y'], row['x'], color=team_colors[home_team], s=40, zorder=2)
                ax.scatter(row['y'], row['x'], color='yellow', s=60, zorder=2)
            elif row['club'] == visitor_team:
                #ax.scatter(row['y'], row['x'], color=team_colors[visitor_team], s=40, zorder=2)
                ax.scatter(row['y'], row['x'], color='grey', s=60, zorder=2)
            elif row['club'] == 'football':
                ax.scatter(row['y'], row['x'], color='limegreen', s=60, zorder=2)

        # Capture the current state of the figure as an image
        fig.canvas.draw()
        image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
        image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        plt.close(fig)
        return image

    # Filter the DataFrame for the specified gameId and playId
    play_df = df[(df['gameId'] == gameId) & (df['playId'] == playId)]
    frame_nums = play_df['frameId'].unique()
    #play_description = play_df['playDescription'].iloc[0]

    # Generate and collect all frame images
    images = [plot_frame(frame_num, play_df) for frame_num in frame_nums]

    # Create the GIF
    gif_filename = f"gameId: {gameId}, playId: {playId}.gif"
    imageio.mimsave(gif_filename, images, fps=5)

In [26]:
animate_play(df, 2022100300	, 3097)

[0.11681576818227768, 0.0750303566455841, 0.05527348071336746, 0.0696493461728096, 0.10875052213668823, 0.14276368916034698, 0.06446758657693863, 0.06298879534006119, 0.05776660889387131, 0.07021621614694595]
[0.11858399212360382, 0.07189425081014633, 0.05238145962357521, 0.06768660992383957, 0.10199400782585144, 0.13782091438770294, 0.06395804136991501, 0.06113969907164574, 0.05668233707547188, 0.07500310987234116]
[0.12290668487548828, 0.07135001569986343, 0.04921581596136093, 0.06653289496898651, 0.09966599941253662, 0.13536910712718964, 0.06553894281387329, 0.05777255445718765, 0.05595209077000618, 0.07329864799976349]
[0.12741629779338837, 0.07225526869297028, 0.044196341186761856, 0.06631486117839813, 0.09946337342262268, 0.13668888807296753, 0.06715064495801926, 0.05451582744717598, 0.05633818730711937, 0.06663075089454651]
[0.1270589530467987, 0.07147213816642761, 0.042283665388822556, 0.06700728088617325, 0.09904389828443527, 0.1376551240682602, 0.06648463010787964, 0.05413307