In [1]:
import os
import sys
from os.path import join
from tqdm import tqdm

import pandas as pd
import numpy as np
import nfl_data_py as nfl

# Make the project's py/ helper modules (e.g. util) importable from this notebook.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
PY_DIR = os.path.join(ROOT_DIR, 'py')
# Guard the insert so re-running this cell does not keep prepending duplicates to sys.path.
if PY_DIR not in sys.path:
    sys.path.insert(0, PY_DIR)

import util

# Show every row/column when a frame is displayed -- prefer .head() on large frames.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [2]:
# Define the path to the data folder and which week(s) of tracking data to load
WEEK = 7
DATA_DIR = "../data/"
WEEKS = range(WEEK, WEEK+1)

df_game = pd.read_csv(join(DATA_DIR, "games.csv"))
df_play = pd.read_csv(join(DATA_DIR, "plays.csv"))
df_player_play = pd.read_csv(join(DATA_DIR, "player_play.csv"))
df_player = pd.read_csv(join(DATA_DIR, "players.csv"))

tracking_dfs = []
for wk in tqdm(WEEKS, desc="Loading tracking files"):
    df = pd.read_csv(join(DATA_DIR, f'tracking_week_{wk}.csv'))
    if 'week' not in df.columns:
        # tag rows with their week so multi-week loads stay distinguishable
        df.insert(3, 'week', wk)
    tracking_dfs.append(df)

# Each weekly file has its own 0..n-1 index; ignore_index avoids duplicate
# row labels when more than one week is concatenated.
df_tracking = pd.concat(tracking_dfs, axis=0, ignore_index=True)

# Free the per-week frames (the last one is still bound to `df`).
del tracking_dfs, df

util.uncamelcase_columns(df_game)
util.uncamelcase_columns(df_player)
util.uncamelcase_columns(df_play)
util.uncamelcase_columns(df_player_play)
util.uncamelcase_columns(df_tracking)

# standardize direction to be offense moving right
df_tracking, df_play = util.standardize_direction(df_tracking, df_play)

# Restrict game-level tables to the loaded week(s). Filtering on WEEKS (not just
# WEEK) keeps this consistent with the tracking data if WEEKS is widened, and
# df_play is filtered too so later merges don't carry plays with no tracking.
weeks_loaded = list(WEEKS)
df_game = df_game.query('week in @weeks_loaded').reset_index(drop=True)
game_ids = df_game['game_id'].unique().tolist()
df_play = df_play.query('game_id in @game_ids').reset_index(drop=True)
df_player_play = df_player_play.query('game_id in @game_ids').reset_index(drop=True)

# Attach each tracked player's roster position (NaN for the football).
df_tracking = df_tracking.merge(df_player[['nfl_id','position']], on='nfl_id', how='left')

# nflverse team metadata (colours, logos, wordmarks) -- requires network access.
df_teams = nfl.import_team_desc()

Loading tracking files: 100%|██████████| 1/1 [00:06<00:00,  6.27s/it]


In [3]:
team_cols = ['team_abbr', 'team_color','team_color2','team_logo_wikipedia', 'team_wordmark']

# Attach nflverse colours / logos / wordmarks for the offense ('possession') and
# the defense ('defensive'). Each step is guarded so re-running the cell is a no-op.
for side in ('possession', 'defensive'):
    if f'{side}_team_color' not in df_play.columns:
        df_play = df_play.merge(
            right=df_teams[team_cols].rename(columns={
                'team_abbr': f'{side}_team',
                'team_color': f'{side}_team_color',
                'team_color2': f'{side}_team_color2',
                'team_logo_wikipedia': f'{side}_team_logo',
                'team_wordmark': f'{side}_team_wordmark',
            }),
            how='left',
            on=f'{side}_team',
        )

if 'home_team_abbr' not in df_play.columns:
    df_play = df_play.merge(
        right=df_game[['game_id','home_team_abbr','visitor_team_abbr']],
        how='left',
        on='game_id'
    ).rename(columns={
        'visitor_team_abbr':'away_team_abbr'
    })

# Re-express the offense/defense attributes as home/away for the scoreboard.
# Each column has its own guard so a partially-built frame is completed.
home_has_ball = df_play['home_team_abbr'] == df_play['possession_team']
for attr in ('wordmark', 'logo', 'color'):
    if f'home_team_{attr}' not in df_play.columns:
        df_play[f'home_team_{attr}'] = np.where(
            home_has_ball,
            df_play[f'possession_team_{attr}'],
            df_play[f'defensive_team_{attr}']
        )
    if f'away_team_{attr}' not in df_play.columns:
        df_play[f'away_team_{attr}'] = np.where(
            home_has_ball,
            df_play[f'defensive_team_{attr}'],
            df_play[f'possession_team_{attr}']
        )

if 'down_and_dist' not in df_play.columns:
    down_map = {
        1:'1st',
        2:'2nd',
        3:'3rd',
        4:'4th'
    }
    df_play['down_and_dist'] = df_play['down'].map(down_map) + ' & ' + df_play['yards_to_go'].astype(str)

if 'quarter_with_suffix' not in df_play.columns:
    # Quarter 5 is overtime in NFL play data; without it the scoreboard shows 'nan'.
    quarter_map = {
        1:'1st',
        2:'2nd',
        3:'3rd',
        4:'4th',
        5:'OT',
    }
    df_play['quarter_with_suffix'] = df_play['quarter'].map(quarter_map)


In [None]:
import io
import urllib
import urllib.request

import PIL
import PIL.Image
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Rectangle, Polygon
from matplotlib.font_manager import FontProperties
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.patches import Wedge
from matplotlib.colors import to_rgba
from IPython.display import HTML

#TODO: put this in different python file
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``'#RRGGBB'`` / ``'RRGGBB'`` and the CSS shorthand ``'#RGB'``
    (each digit doubled, e.g. ``'#f80'`` -> ``(255, 136, 0)``). Surrounding
    whitespace is ignored.

    Raises:
        ValueError: if the digits are not valid hexadecimal.
    """
    hex_color = hex_color.strip().lstrip('#')
    if len(hex_color) == 3:
        # Expand shorthand so the fixed two-digit slicing below works.
        hex_color = ''.join(ch * 2 for ch in hex_color)
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))

def luminance(rgb: tuple) -> float:
    """Return the WCAG relative luminance of an 8-bit ``(r, g, b)`` colour.

    Each channel is scaled to [0, 1], the sRGB gamma curve is removed, and the
    linear channels are combined with the Rec. 709 weights. Black gives 0.0,
    white gives 1.0.
    """
    def _to_linear(channel: float) -> float:
        # Piecewise sRGB -> linear conversion for one channel in [0, 1].
        if channel > 0.03928:
            return ((channel + 0.055) / 1.055) ** 2.4
        return channel / 12.92

    red, green, blue = (_to_linear(value / 255.0) for value in rgb)
    return 0.2126 * red + 0.7152 * green + 0.0722 * blue

def contrast_ratio(hex_color1: str, hex_color2: str) -> float:
    """WCAG contrast ratio between two hex colours.

    Order of the arguments does not matter. The result ranges from 1.0
    (identical luminance) to 21.0 (black against white).
    """
    lum_a, lum_b = (luminance(hex_to_rgb(color)) for color in (hex_color1, hex_color2))
    lighter, darker = max(lum_a, lum_b), min(lum_a, lum_b)
    # The 0.05 offset models ambient flare and keeps the ratio finite for black.
    return (lighter + 0.05) / (darker + 0.05)

def plot_image(ax: mpl.axes.Axes, x:float, y:float, imagebox:OffsetImage, ord:int) -> mpl.axes.Axes:
    """Place ``imagebox`` (e.g. a team logo) on ``ax`` at data coordinates ``(x, y)``.

    The image is drawn without a frame; ``ord`` is used as the artist's zorder.
    Returns ``ax`` so calls can be chained.
    """
    ax.add_artist(AnnotationBbox(imagebox, xy=(x, y), frameon=False, zorder=ord))
    return ax

class Scoreboard:
    """One-row scoreboard: away | home | quarter & game clock | play clock | down & distance.

    NOTE(review): NFLPlayAnimation.plot_scoreboard draws its own scoreboard and
    does not use this class in the visible notebook.
    """

    def __init__(self, ax, play_data, tracking_data, game_clocks, play_clocks, home_img, away_img, touchdown_frameid=None):
        """
        Args:
            ax: matplotlib Axes to draw on.
            play_data: play row (Series or dict) providing team abbreviations,
                team colours, pre-snap scores, ``quarter_with_suffix`` and
                ``down_and_dist``. Scores are updated in place after a touchdown.
            tracking_data: tracking rows for the play; only ``frame_id`` is read.
            game_clocks: mapping frame id -> game clock text.
            play_clocks: mapping frame id -> play clock seconds (int).
            home_img, away_img: logo images handed to ``add_logo``.
            touchdown_frameid: frame id of the touchdown event, or None.
        """
        self.ax = ax
        self.play_data = play_data
        self.tracking_data = tracking_data
        self.game_clocks = game_clocks
        self.play_clocks = play_clocks
        self.home_img = home_img
        self.away_img = away_img
        self.touchdown_frameid = touchdown_frameid
        self.scoreboard_height = 5
        self.y_limit_min = 0
        self.x_limit_max = 100
        # width (data units) of the play-clock cell
        self.play_clock_width = 4

    def update_scores(self, frame_id):
        """Add a touchdown's 6 points once, on the first frame after ``touchdown_frameid``.

        When the view sits at the low end zone (``y_limit_min < 10``) the team
        without possession is credited; otherwise the possession team is.
        Uses explicit ``None`` checks so frame id 0 is handled like any other.
        """
        if self.touchdown_frameid is not None and frame_id is not None and frame_id > self.touchdown_frameid:
            defense_scored = self.y_limit_min < 10
            home_has_ball = self.play_data['possession_team'] == self.play_data['home_team_abbr']
            # Home scores when it has the ball and the offense scored, or vice versa.
            score_key = 'pre_snap_home_score' if home_has_ball != defense_scored else 'pre_snap_visitor_score'
            self.play_data[score_key] += 6
            self.touchdown_frameid = None

    def draw_rectangle(self, x_start, width, color, zorder):
        """Add one scoreboard cell spanning ``[x_start, x_start + width]``."""
        rect = Rectangle(
            (x_start, self.y_limit_min),
            width,
            self.scoreboard_height,
            color=color,
            zorder=zorder
        )
        self.ax.add_patch(rect)

    def add_text(self, x_position, text, fontsize=20, zorder=7, color='white'):
        """Write bold, centred ``text`` at ``x_position`` on the scoreboard's vertical midline."""
        txt_height = self.y_limit_min + self.scoreboard_height / 2
        self.ax.text(
            x_position,
            txt_height,
            text,
            ha='center', va='center',
            fontsize=fontsize,
            fontweight='bold',
            color=color,
            zorder=zorder
        )

    def plot_scoreboard(self, frame_id=None):
        """Draw the scoreboard for ``frame_id`` (default: first frame in ``tracking_data``)."""
        x_interval = self.x_limit_max / 4
        if frame_id is None:
            frame_id = self.tracking_data.frame_id.min()

        self.update_scores(frame_id)

        # Cells tile [0, x_limit_max] exactly (no overlap, nothing past the right
        # edge); the play clock is carved out of the right end of the clock cell.
        pc_width = self.play_clock_width
        play_clock_color = 'red' if self.play_clocks[frame_id] <= 5 else 'grey'
        self.draw_rectangle(0, x_interval, self.play_data['away_team_color'], zorder=6)
        self.draw_rectangle(x_interval, x_interval, self.play_data['home_team_color'], zorder=6)
        self.draw_rectangle(x_interval * 2, x_interval - pc_width, '#1a1817', zorder=6)
        self.draw_rectangle(x_interval * 3 - pc_width, pc_width, play_clock_color, zorder=6)
        self.draw_rectangle(x_interval * 3, self.x_limit_max - x_interval * 3, self.play_data['possession_team_color'], zorder=6)

        # Add text for away team score, home team score, time, play clock, and down and distance
        self.add_text(x_interval / 2 + 2.5, f'{self.play_data["away_team_abbr"]} {self.play_data["pre_snap_visitor_score"]}')
        self.add_text(x_interval * 1.5 + 2.5, f'{self.play_data["home_team_abbr"]} {self.play_data["pre_snap_home_score"]}', zorder=7)
        self.add_text(x_interval * 2 + (x_interval / 2 - 2), f'{self.play_data["quarter_with_suffix"]} {self.game_clocks[frame_id]}')
        self.add_text(x_interval * 3 - pc_width / 2, f'{self.play_clocks[frame_id]:02}')
        self.add_text(x_interval * 3.5, self.play_data['down_and_dist'])

        # Add logos next to scores
        self.add_logo(3, self.y_limit_min + self.scoreboard_height / 2 - .3, self.away_img, zorder=7)
        self.add_logo(x_interval + 3, self.y_limit_min + self.scoreboard_height / 2 - .3, self.home_img, zorder=7)

    def add_logo(self, x, y, img, zorder=7):
        """Placeholder for drawing a team logo at ``(x, y)``; currently draws nothing."""
        pass

class NFLPlayAnimation:
    def __init__(self, tracking_data, play_data, show_scoreboard=True, clock_rolling=True, player_display_type='dots-team', show_player_legend=False):
        """Prepare team styling, per-frame clock lookups, images and the figure for one play.

        Args:
            tracking_data: tracking rows, assumed to cover a single play
                (``frame_id`` and ``event`` are read here).
            play_data: the play's row from ``df_play`` (Series or dict) with the
                team colour/logo/wordmark columns, ``game_clock`` ("MM:SS"),
                ``play_clock_at_snap``, ``absolute_yardline_number`` and the
                pre-snap scores. It is copied, so scoreboard score updates never
                mutate the caller's object.
            show_scoreboard: draw the scoreboard strip; requires a
                ``ball_snap`` event in ``tracking_data``.
            clock_rolling: if True the game clock also runs before the snap;
                otherwise it shows ``game_clock`` unchanged until the snap.
            player_display_type: 'dots-team', 'dots-positional', 'positions'
                or 'jerseys' ('jersey_numbers' is accepted as an alias).
            show_player_legend: draw a player legend right of the field.

        Raises:
            ValueError: unknown ``player_display_type``, or scoreboard requested
                without a ``ball_snap`` event.
        """
        # The drawing methods test for 'jerseys'; 'jersey_numbers' was the only
        # jersey option this check used to accept, so keep it as an alias.
        if player_display_type == 'jersey_numbers':
            player_display_type = 'jerseys'
        if player_display_type not in ('dots-team', 'dots-positional', 'positions', 'jerseys'):
            raise ValueError("Invalid player_display_type. Must be one of 'dots-team', 'dots-positional', 'positions', or 'jerseys'.")

        self.tracking_data = tracking_data
        # Copy: plot_scoreboard adds touchdown points to self.play_data in place.
        self.play_data = play_data.copy()
        self.show_scoreboard = show_scoreboard
        self.clock_rolling = clock_rolling
        self.player_display_type = player_display_type
        self.show_player_legend = show_player_legend

        self.aspect_ratio = 1
        self.y_delta = 35
        self.x_limit_min = 0
        self.x_limit_max = 53.3
        self.numbers_font = FontProperties(fname='../data/fonts/clarendon_bold.otf')
        self.scoreboard_height = 3
        self.legend_width = 12
        self.legend_txt_color = 'black'

        # Offense uses its own colours; the defense gets whichever of its two
        # colours contrasts more with the offense's primary colour.
        self.poss_tm_color = self.play_data['possession_team_color']
        self.poss_tm_edge_color = self.play_data['possession_team_color2']
        cr_1 = contrast_ratio(self.play_data['possession_team_color'], self.play_data['defensive_team_color'])
        cr_2 = contrast_ratio(self.play_data['possession_team_color'], self.play_data['defensive_team_color2'])
        self.def_tm_color = self.play_data['defensive_team_color'] if cr_1 > cr_2 else self.play_data['defensive_team_color2']
        self.def_tm_edge_color = self.play_data['defensive_team_color2'] if cr_1 > cr_2 else self.play_data['defensive_team_color']

        # Player fill colours for 'dots-positional', keyed by roster position.
        self.position_colors = {
            'QB': to_rgba('blue'),
            'T': to_rgba('purple'),
            'TE': to_rgba('green'),
            'WR': to_rgba('red'),
            'DE': to_rgba('pink'),
            'NT': to_rgba('darkred'),
            'SS': to_rgba('darkseagreen'),
            'FS': to_rgba('palegoldenrod'),
            'G': to_rgba('orange'),
            'OLB': to_rgba('darkcyan'),
            'DT': to_rgba('slategrey'),
            'CB': to_rgba('orchid'),
            'RB': to_rgba('dodgerblue'),
            'C': to_rgba('indigo'),
            'ILB': to_rgba('lime'),
            'MLB': to_rgba('gold'),
            'FB': to_rgba('darkviolet'),
            'DB': to_rgba('darkorange'),
            'LB': to_rgba('darkgreen'),
        }

        touchdown_frames = tracking_data.query('event=="touchdown"')['frame_id']
        self.touchdown_frameid = touchdown_frames.values[0] if touchdown_frames.shape[0] > 0 else None

        # NOTE(review): plot_field reads self.y_limit_min but nothing in __init__
        # set it, so init_animation could fail before any frame is drawn. Start
        # with the 120-yard field's window centred on the line of scrimmage;
        # later code that repositions the camera simply overwrites this.
        line_of_scrimmage = self.play_data['absolute_yardline_number']
        self.y_limit_min = min(max(line_of_scrimmage - self.y_delta / 2, 0), 120 - self.y_delta)

        if self.show_scoreboard:
            snap_frames = self.tracking_data.query('event=="ball_snap"')['frame_id']
            if snap_frames.shape[0] == 0:
                raise ValueError("show_scoreboard=True requires a 'ball_snap' event in tracking_data.")
            snap_frame_id = snap_frames.values[0]
            # The game clock counts down frame by frame, so frames must be in order.
            frame_ids = sorted(self.tracking_data.frame_id.unique())
            self.play_clocks = self._build_play_clocks(frame_ids, snap_frame_id)
            self.game_clocks = self._build_game_clocks(frame_ids, snap_frame_id)

        # Download images before creating the figure so a network error does
        # not leave an orphaned empty figure behind.
        self.home_img = self._load_scoreboard_logo(self.play_data['home_team_logo'])
        self.away_img = self._load_scoreboard_logo(self.play_data['away_team_logo'])
        wordmark = self._fetch_image(self.play_data['home_team_wordmark'])
        self.home_wordmark = OffsetImage(wordmark, zoom=1)
        self.home_wordmark_rotated = OffsetImage(wordmark.rotate(180), zoom=1)

        # Wider figure when the legend column is shown.
        figsize = (14, 8) if self.show_player_legend else (12, 8)
        self.fig, self.ax = plt.subplots(figsize=figsize)

    @staticmethod
    def _fetch_image(url):
        """Download ``url`` (closing the connection) and return it as a PIL image."""
        with urllib.request.urlopen(url) as response:
            data = response.read()
        return PIL.Image.open(io.BytesIO(data))

    def _load_scoreboard_logo(self, url, height=70, max_width=75, crop_top=10):
        """Fetch a logo, scale it to ``height`` px (width capped at ``max_width``), trim ``crop_top`` px off the top."""
        img = self._fetch_image(url)
        w, h = img.size
        w_new = min(int(w * (height / h)), max_width)
        return OffsetImage(img.resize((w_new, height)).crop((0, crop_top, w_new, height)))

    def _build_play_clocks(self, frame_ids, snap_frame_id):
        """Play clock (int seconds) per frame, treating frames as 0.1 s apart; frames after the snap show 40."""
        play_clock_at_snap = self.play_data['play_clock_at_snap']
        return {
            fid: int(play_clock_at_snap + (snap_frame_id - fid) / 10 - .1) if fid <= snap_frame_id else 40
            for fid in frame_ids
        }

    def _build_game_clocks(self, frame_ids, snap_frame_id):
        """Game clock text ("MM:SS") per frame, ticking down 0.1 s per frame and stopping at 00:00.

        With ``clock_rolling`` the countdown starts ``snap_frame_id`` tenths
        above ``game_clock`` (so, for frames numbered 1..snap, it reads
        ``game_clock`` at the snap); otherwise pre-snap frames show
        ``game_clock`` verbatim and the countdown starts at the snap.
        """
        parts = str(self.play_data['game_clock']).split(':')
        # Integer tenths of a second: repeated 0.1 s float steps would drift and
        # truncate to the wrong second at whole-second boundaries.
        tenths = (int(parts[0]) * 60 + int(parts[1])) * 10
        if self.clock_rolling:
            tenths += int(snap_frame_id)
        clocks = {}
        for fid in frame_ids:
            if fid < snap_frame_id and not self.clock_rolling:
                clocks[fid] = self.play_data['game_clock']
                continue
            tenths = max(tenths - 1, 0)
            total_seconds = tenths // 10
            clocks[fid] = f'{total_seconds // 60:02}:{total_seconds % 60:02}'
        return clocks

    def plot_field(self):
        """Draw the static field for the current view window.

        The field is vertical: x runs across the 53.3-yard width and y along the
        120-yard length, with end zones at 0-10 and 110-120. Draws yard lines,
        sidelines, hash marks, yard numbers with arrows toward the nearer goal
        line, the line of scrimmage (blue) and first-down line (yellow), and a
        shaded end zone with the home wordmark when one is in view. The visible
        window is ``[self.y_limit_min, self.y_limit_min + self.y_delta]``.
        """
        # x is fixed to the field width (plus the legend column when shown);
        # y follows the current camera window.
        if self.show_player_legend:
            self.ax.set_xlim(self.x_limit_min, self.x_limit_max + self.legend_width)
        else:
            self.ax.set_xlim(self.x_limit_min, self.x_limit_max)
        self.ax.set_ylim(self.y_limit_min, self.y_limit_min + self.y_delta)

        # Grid lines every 5 yards are the yard lines (5 and 115 fall inside the
        # end zones, so they're skipped); x ticks at the sidelines make the
        # x-grid draw the sidelines.
        self.ax.set_yticks([i for i in range(0, 121, 5) if i not in [5, 115]])
        self.ax.set_xticks([self.x_limit_min, self.x_limit_max])
        self.ax.grid(True, which='major', axis='y', color='white', linewidth=2)
        self.ax.grid(True, which='major', axis='x', color='white', linewidth=2)

        self.ax.set_facecolor('lightgray')
        for spine in self.ax.spines.values():
            spine.set_visible(False)
        # Hide tick marks and labels; only the grid lines are wanted.
        self.ax.tick_params(left=False, right=False, top=False, bottom=False, labelleft=False, labelbottom=False)

        # Line of scrimmage and first-down line.
        self.ax.axhline(y=self.play_data['absolute_yardline_number'], color='blue', linewidth=2)
        self.ax.axhline(y=self.play_data['absolute_yardline_number'] + self.play_data['yards_to_go'], color='yellow', linewidth=2)

        # Hash marks: four small ticks per yard between the goal lines, skipping
        # 5-yard lines (already drawn by the grid). Patches are only created and
        # added for yards that get marks, so nothing is added twice.
        centerfield = self.x_limit_max / 2
        hash_xs = (
            1/2,                                  # left outer
            centerfield - (37/12 + 1/3),          # left inner
            centerfield + (37/12 - 1/3),          # right inner
            self.x_limit_max - 7/6,               # right outer
        )
        for y in range(11, 110):
            if y % 5 == 0:
                continue
            for hash_x in hash_xs:
                self.ax.add_patch(Rectangle((hash_x, y - 0.05), 2/3, 0.04, color='white'))

        # Yard numbers on both sides (digits spaced like painted field numbers).
        yardline_labels = {20: "1 0", 30: "2 0", 40: "3 0", 50: "4 0", 60: "5 0", 70: "4 0", 80: "3 0", 90: "2 0", 100: "1 0"}
        for y, label in yardline_labels.items():
            self.ax.text(
                12, y,
                label,
                ha='center', va='center',
                fontsize=30,
                color='white',
                rotation=-90,
                fontproperties=self.numbers_font
            )
            self.ax.text(
                self.x_limit_max - 12, y,
                label,
                ha='center', va='center',
                fontsize=30,
                color='white',
                rotation=90,
                fontproperties=self.numbers_font
            )

            # Direction arrows point toward the nearer goal line (none at midfield).
            if y > 60:
                left_triangle = Polygon([[12, y + 1.8], [12.2, y + 2.55], [12.4, y + 1.8]], color='white')
                right_triangle = Polygon([[self.x_limit_max - 12, y + 1.8], [self.x_limit_max - 12.2, y + 2.55], [self.x_limit_max - 12.4, y + 1.8]], color='white')
                self.ax.add_patch(left_triangle)
                self.ax.add_patch(right_triangle)
            elif y < 60:
                left_triangle = Polygon([[12, y - 1.8], [12.2, y - 2.55], [12.4, y - 1.8]], color='white')
                right_triangle = Polygon([[self.x_limit_max - 12, y - 1.8], [self.x_limit_max - 12.2, y - 2.55], [self.x_limit_max - 12.4, y - 1.8]], color='white')
                self.ax.add_patch(left_triangle)
                self.ax.add_patch(right_triangle)

        # Shade whichever end zone is in view and place the home wordmark in it
        # (the 180-degree rotated copy in the low end zone).
        if self.y_limit_min + self.y_delta > 110:
            self.ax.add_patch(Rectangle((0, 110), self.x_limit_max, 10, color='#b8b8b8'))
            self.add_logo(self.x_limit_max / 2, 115, self.home_wordmark, ord=7)

        if self.y_limit_min < 10:
            self.ax.add_patch(Rectangle((0, 0), self.x_limit_max, 10, color='#b8b8b8'))
            self.add_logo(self.x_limit_max / 2, 5, self.home_wordmark_rotated, ord=7)
        
    def plot_scoreboard(self, frame_id=None):
        """Draw the scoreboard strip at the bottom of the current view.

        Cells, left to right: away team, home team, quarter + game clock, play
        clock (4 data units, red at 5 s or less), down & distance. The cells tile
        ``[0, x_limit_max]`` exactly, so nothing spills into the legend column.
        After ``touchdown_frameid`` the scoring team's score is raised by 6 once
        (in place on ``self.play_data``).

        Args:
            frame_id: frame to render; defaults to the first frame in ``tracking_data``.
        """
        x_interval = self.x_limit_max / 4
        if frame_id is None:
            frame_id = self.tracking_data.frame_id.min()

        if self.touchdown_frameid is not None and frame_id > self.touchdown_frameid:
            # View at the low end zone => the team without the ball scored.
            if self.y_limit_min < 10:
                if self.play_data['possession_team'] == self.play_data['home_team_abbr']:
                    self.play_data['pre_snap_visitor_score'] += 6
                else:
                    self.play_data['pre_snap_home_score'] += 6
            else:
                if self.play_data['possession_team'] == self.play_data['home_team_abbr']:
                    self.play_data['pre_snap_home_score'] += 6
                else:
                    self.play_data['pre_snap_visitor_score'] += 6
            self.touchdown_frameid = None

        play_clock_width = 4
        pc_color = 'red' if self.play_clocks[frame_id] <= 5 else 'grey'
        # (x_start, width, colour) for each cell; widths sum to x_limit_max.
        cells = [
            (0, x_interval, self.play_data['away_team_color']),
            (x_interval, x_interval, self.play_data['home_team_color']),
            (x_interval * 2, x_interval - play_clock_width, '#1a1817'),
            (x_interval * 3 - play_clock_width, play_clock_width, pc_color),
            (x_interval * 3, self.x_limit_max - x_interval * 3, self.play_data['possession_team_color']),
        ]
        for x_start, width, color in cells:
            self.ax.add_patch(Rectangle(
                (x_start, self.y_limit_min),
                width,
                self.scoreboard_height,
                color=color,
                zorder=6
            ))

        txt_height = self.y_limit_min + self.scoreboard_height / 2
        # (x, text, zorder): scores, quarter + game clock, play clock, down & distance.
        labels = [
            (x_interval / 2 + 2.5, f'{self.play_data["away_team_abbr"]}    {self.play_data["pre_snap_visitor_score"]}', 8),
            (x_interval * 1.5 + 2.5, f'{self.play_data["home_team_abbr"]}    {self.play_data["pre_snap_home_score"]}', 7),
            (x_interval * 2 + (x_interval / 2 - 2), f'{self.play_data["quarter_with_suffix"]}  {self.game_clocks[frame_id]}', 7),
            (x_interval * 3 - 2, f'{self.play_clocks[frame_id]:02}', 7),
            (x_interval * 3.5, self.play_data['down_and_dist'], 7),
        ]
        for x, text, zorder in labels:
            self.ax.text(
                x, txt_height,
                text,
                ha='center', va='center',
                fontsize=20,
                fontweight='bold',
                color='white',
                zorder=zorder,
            )

        # Team logos next to the scores.
        self.add_logo(3, txt_height - .3, self.away_img, ord=7)
        self.add_logo(x_interval + 3, txt_height - .3, self.home_img, ord=7)

    def plot_player_legend(self):
        """Draw a legend column to the right of the field listing every player in the play.

        Players are grouped by club and sorted by club, then jersey number. Each
        line reads ``"<jersey>: <name> (<position>)"`` for ``'jerseys'`` or
        ``"<position>: <name>"`` for ``'positions'``.

        Raises:
            ValueError: if ``player_display_type`` is neither 'jerseys' nor
                'positions'. The panel and title are already drawn by then.
        """
        legend_x = self.x_limit_max
        arial = {'family': 'Arial'}

        y = self.y_limit_min + self.y_delta
        if self.show_scoreboard:
            y -= self.scoreboard_height

        # Thin rule under the title, then the pale panel behind the legend.
        self.ax.add_patch(Rectangle((legend_x, y), self.legend_width, .15, color=self.legend_txt_color, zorder=7))
        panel_height = self.y_delta if self.show_scoreboard else self.y_delta + self.scoreboard_height
        self.ax.add_patch(Rectangle((legend_x, self.y_limit_min), self.legend_width, panel_height, color='#f0eee9', zorder=6))

        self.ax.text(
            legend_x + (self.legend_width / 2), y + 1.6, 'Player Legend',
            ha='center', va='center', fontsize=20, fontweight='bold',
            color=self.legend_txt_color, zorder=8, fontdict=arial,
        )

        label_column = {'jerseys': 'jersey_number', 'positions': 'position'}.get(self.player_display_type)
        if label_column is None:
            raise ValueError("Invalid player_display_type. Must be one of 'jerseys' or 'positions'.")

        roster = (
            self.tracking_data
            .query('club!="football"')
            .groupby(['club', 'nfl_id', 'jersey_number', 'position', 'display_name'])
            .size()
            .reset_index()
            .sort_values(['club', 'jersey_number'])
            .reset_index(drop=True)
        )

        row_step = 1.3
        club_shown = None
        for row_idx, player in roster.iterrows():
            if player['club'] != club_shown:
                club_shown = player['club']
                self.ax.text(
                    legend_x + .5, y - 1 - row_step * row_idx, f'{club_shown}',
                    ha='left', va='center', fontsize=14, fontweight='bold',
                    color=self.legend_txt_color, zorder=6, fontdict=arial,
                )
                # Shift everything below a club header down by one row.
                y -= row_step
            if label_column == 'jersey_number':
                entry = f'{int(player[label_column])}: {player.display_name} ({player.position})'
            else:
                entry = f'{player[label_column]}: {player.display_name}'
            self.ax.text(
                legend_x + 1.5, y - 1 - row_step * row_idx, entry,
                ha='left', va='center', fontsize=12, fontweight='normal',
                color=self.legend_txt_color, zorder=6, fontdict=arial,
            )
        
    def init_animation(self):
        """Initialize the animation (empty field)."""
        self.plot_field()
        if self.show_scoreboard: self.plot_scoreboard()
        if self.show_player_legend: self.plot_player_legend()
        return self.ax
    
    def update_frame(self, frame_id):
        """Update the plot for each frame."""
        self.ax.clear()

        # Plot the field
        self.plot_field()

        # Get data for the current frame
        frame_data = self.tracking_data[self.tracking_data['frame_id'] == frame_id]
        
        ball_y = None
        
        # Plot players and football
        if self.player_display_type in ['dots-positional', 'dots-team']:
            radius = 0.3
        else:
            radius = 0.7
        for club, group in frame_data.groupby('club'):
            if club == 'football':
                size = 140
                if self.player_display_type in ['dots-positional', 'dots-team']:
                    size = 80
                ball_y = group['y'].iloc[0]
                # Plot football as a regular circle
                self.ax.scatter(group['x'], group['y'], color='brown', marker='d', s=size, edgecolors='black', zorder=6)
                self.ax.scatter(group['x'], group['y'], color='white', marker='|', s=size / 3, zorder=6)
            else:
                # Assign teams different colors
                if club == self.play_data['possession_team']:
                    color = to_rgba(self.poss_tm_color)
                    ec = to_rgba(self.poss_tm_edge_color)
                else:
                    color = to_rgba(self.def_tm_color)
                    ec = to_rgba(self.def_tm_edge_color)

                # Plot players as circles with a flat side and a front "half-square"
                for _, player in group.iterrows():
                    if self.player_display_type == 'dots-positional':
                        color = self.position_colors[player['position']]
                        if player['club'] == self.play_data['possession_team']:
                            ec = 'black'
                        else:
                            ec = 'blue'
                    
                    orientation = player['o']  # Convert radians to degrees
                    x, y = player['x'], player['y']
                                            
                    # Create the Wedge for each player (flat side 180 degrees opposite orientation)
                    wedge = Wedge((x, y), radius, theta1=orientation-90, theta2=orientation+90, color=color, zorder=5, ec=ec)

                    # Add the wedge to the axis
                    self.ax.add_patch(wedge)

                    # Calculate the half-square vertices
                    square_length = radius  # Length of the half-square extension in front of the flat side of the circle
                    angle_rad = np.radians(orientation)

                    # Calculate the direction vector for the front of the player (where the square will extend)
                    dx = square_length * np.cos(angle_rad)
                    dy = square_length * np.sin(angle_rad)

                    # Compute points along the flat edge of the circle (aligned with the player's orientation)
                    left_edge_x = x + radius * np.cos(angle_rad - np.pi/2)  # Left side of the flat edge
                    left_edge_y = y + radius * np.sin(angle_rad - np.pi/2)
                    right_edge_x = x + radius * np.cos(angle_rad + np.pi/2)  # Right side of the flat edge
                    right_edge_y = y + radius * np.sin(angle_rad + np.pi/2)

                    # Define the four corners of the square, extending from the flat side
                    # These corners are along the flat edge and then extend forward in the direction of the player's orientation
                    corners = [
                        (right_edge_x, right_edge_y),  # Right side of flat edge
                        (left_edge_x, left_edge_y),    # Left side of flat edge
                        (left_edge_x - dx, left_edge_y - dy),  # Front-left (extend forward)
                        (right_edge_x - dx, right_edge_y - dy)  # Front-right (extend forward)
                    ]

                    # Create the half-square polygon and add it to the axis
                    half_square = Polygon(corners, closed=True, color=color, zorder=5, ec=ec)
                    self.ax.add_patch(half_square)

                    # add rectangular patch where circle and square meet
                    patch_radius = radius - 0.08
                    left_edge_x = x + patch_radius * np.cos(angle_rad - np.pi/2)  # Left side of the flat edge
                    left_edge_y = y + patch_radius * np.sin(angle_rad - np.pi/2)
                    right_edge_x = x + patch_radius * np.cos(angle_rad + np.pi/2)  # Right side of the flat edge
                    right_edge_y = y + patch_radius * np.sin(angle_rad + np.pi/2)
                    corners = [
                        (right_edge_x + .1 * dx, right_edge_y + .1 * dy),
                        (left_edge_x + .1 * dx, left_edge_y + .1 * dy),
                        (left_edge_x -.1 * dx, left_edge_y - .1 * dy),
                        (right_edge_x - .1 * dx, right_edge_y -.1 * dy)
                    ]
                    rect = Polygon(corners, closed=True, color=color, zorder=5)
                    self.ax.add_patch(rect)

                    # Plot the player's jersey number, centered at (x, y)
                    if self.player_display_type == 'jerseys':
                        jersey_number = int(player['jersey_number'])  # Convert float to int
                        self.ax.text(x, y, str(jersey_number), color='white', ha='center', va='center', fontweight='bold', fontsize=12, zorder=6)
                    elif self.player_display_type == 'positions':
                        position = player['position']  # Get the player's position
                        self.ax.text(x, y, position, color='white', ha='center', va='center', fontweight='bold', fontsize=9, zorder=6)
        
            # Dynamically adjust the y-axis limit based on the ball's y position
            if ball_y is not None:
                if ball_y < self.y_limit_min + 10:  # Ball near the bottom
                    self.y_limit_min = max(0, ball_y - 10)
                elif ball_y > self.y_limit_min + self.y_delta - 10:  # Ball near the top
                    self.y_limit_min = min(120 - self.y_delta, ball_y - self.y_delta + 10)
                self.ax.set_ylim(self.y_limit_min, self.y_limit_min + self.y_delta)

        if self.show_scoreboard: 
            self.plot_scoreboard(frame_id)

        if self.show_player_legend:
            self.plot_player_legend()

        return self.ax

    def animate_play(self):
        """Create the animation of the play.

        Sets the initial y-axis window from the football's y position at the
        ball snap, falling back to the football's earliest tracked frame when
        the tracking data has no ``ball_snap`` event. When the scoreboard is
        shown, the window is shifted down by ``self.scoreboard_height``. Then
        builds a ``FuncAnimation`` that calls ``self.update_frame`` once per
        frame id, in ascending frame order.

        Returns:
            The ``animation.FuncAnimation`` object; it is rendered later, e.g.
            via ``to_jshtml`` or ``save``.

        Raises:
            ValueError: if ``self.tracking_data`` contains no football rows.
        """
        ball = self.tracking_data[self.tracking_data['club'] == 'football']
        if ball.empty:
            raise ValueError("tracking_data has no football rows; cannot set the initial y-axis window")

        # Anchor the view on the ball at the snap. Some plays (or frame subsets)
        # carry no 'ball_snap' event, so fall back to the ball's first frame.
        at_snap = ball[ball['event'] == 'ball_snap']
        anchor = at_snap if not at_snap.empty else ball.sort_values('frame_id')
        self.y_limit_min = round(anchor['y'].iloc[0] - 10, 2)

        if self.show_scoreboard:
            # Leave room at the bottom of the view for the scoreboard.
            self.y_limit_min -= self.scoreboard_height

        # unique() keeps order of appearance; sort so frames play chronologically.
        frame_ids = np.sort(self.tracking_data['frame_id'].unique())

        # Generate the animation
        ani = animation.FuncAnimation(self.fig, self.update_frame, frames=frame_ids,
                                      init_func=self.init_animation, blit=False, repeat=False)
        return ani

# Play to animate. Other plays of interest, for quick swapping:
#   (2022102308, 185), (2022102302, 970), (2022102302, 2606)
game_id, play_id = 2022102307, 641

# Number of frames to animate, starting at the ball snap.
# Set to None to animate every tracked frame of the play.
N_FRAMES_AFTER_SNAP = 10
# Playback speed used for both the inline HTML animation and any saved video.
FPS = 10
# Optional video output, e.g. '../videos/jets_v_broncos.mp4' (requires ffmpeg).
# Leave as None to only render inline.
OUTPUT_PATH = None

# Tracking data for the selected play, restricted to the requested frames
play_tracking = df_tracking.query('game_id == @game_id & play_id == @play_id')
ball_snap_frameid = play_tracking.loc[play_tracking['event'] == 'ball_snap', 'frame_id'].iloc[0]
if N_FRAMES_AFTER_SNAP is None:
    frames = play_tracking['frame_id'].unique()
else:
    frames = range(ball_snap_frameid, ball_snap_frameid + N_FRAMES_AFTER_SNAP)
tracking_data = play_tracking[play_tracking['frame_id'].isin(frames)].copy()

# Play-level context (scoreboard, team colors/logos, down & distance) for the animation
play_cols = ['home_team_logo', 'away_team_logo', 'play_clock_at_snap', 'game_clock', 
             'absolute_yardline_number', 'yards_to_go', 'away_team_color', 
             'home_team_color', 'possession_team', 'defensive_team', 'down_and_dist', 
             'quarter_with_suffix', 'pre_snap_home_score', 'pre_snap_visitor_score',
             'possession_team_color', 'defensive_team_color', 'home_team_abbr', 'away_team_abbr',
             'home_team_wordmark', 'possession_team_color2', 'defensive_team_color2']
play_data = df_play.query('game_id == @game_id & play_id == @play_id')[play_cols].to_dict(orient='records')[0]

# Size cap (in MB) for the animation embedded by to_jshtml
mpl.rcParams['animation.embed_limit'] = 50

# Instantiate and generate the animation
play_anim = NFLPlayAnimation(tracking_data, play_data, show_scoreboard=True, clock_rolling=True, player_display_type='dots-team', show_player_legend=False)
ani = play_anim.animate_play()

# Remove all figure margins so the field fills the whole figure
play_anim.fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

# Set the figure background to light grey
play_anim.fig.patch.set_facecolor('lightgray')

# Close the figure so Jupyter does not also display a static copy below the animation
plt.close(play_anim.fig)

if OUTPUT_PATH:
    ani.save(OUTPUT_PATH, writer='ffmpeg', fps=FPS)

HTML(ani.to_jshtml(fps=FPS))