In [2]:
import shots_data_retriever
from shots_data_retriever import ShotsDataRetriever
import importlib
import warnings
import pandas as pd

# Filter dtype warnings. Column 10 can contain str, int, and float types 
# which spams a warning in the output
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)
importlib.reload(shots_data_retriever)

shotsDataRetriever = ShotsDataRetriever()

In [3]:
# Average shot rate
def get_league_avg_shot_rate_by_coordinate(year: str):
    df = shotsDataRetriever.get_season_shots(year)

    single_game_time = 1 # in hours, assumption from google doc
    total_games = df['game_id'].nunique()
    total_game_time = single_game_time * total_games

    shot_location = df.groupby(['x_coord', 'y_coord']).size().reset_index(name='shot_count')
    shot_location['shot_rate'] = shot_location['shot_count'] / total_game_time

    shot_location = shot_location.sort_values(by=['x_coord', 'y_coord']).reset_index(drop=True)

    return shot_location

In [4]:
# Shot rate by team
def get_team_avg_shot_rate_by_coordinate(year: str, team_id: int):
    df = shotsDataRetriever.get_season_shots_for_team(year, team_id)

    single_game_time = 1 # in hours, assumption from google doc
    total_games = df['game_id'].nunique()
    total_game_time = single_game_time * total_games

    shot_location = df.groupby(['x_coord', 'y_coord']).size().reset_index(name='shot_count')
    shot_location['shot_rate'] = shot_location['shot_count'] / total_game_time

    shot_location = shot_location.sort_values(by=['x_coord', 'y_coord']).reset_index(drop=True)

    return shot_location

In [5]:
def get_team_excess_shot_rate(year: str, team_id: int):
    team_df = get_team_avg_shot_rate_by_coordinate(year, team_id)
    league_df = get_league_avg_shot_rate_by_coordinate(year)

    result_df = pd.merge(league_df, team_df, on=['x_coord', 'y_coord'], how='left', suffixes=('_league', '_team'))

    result_df.loc[:, 'shot_count_team'] = result_df['shot_count_team'].fillna(0)
    result_df.loc[:, 'shot_rate_team'] = result_df['shot_rate_team'].fillna(0)

    result_df['team_shot_rate'] = result_df['shot_rate_team']
    result_df['league_shot_rate_per_side'] = result_df['shot_rate_league'] / 2 # divide by two to account for two teams playing per game

    result_df['excess_shot_rate'] = result_df['team_shot_rate'].sub(result_df['league_shot_rate_per_side'], fill_value=0)

    return result_df

In [6]:
import os
import numpy as np
from tqdm import tqdm
import plotly.graph_objects as go
from PIL import Image
from scipy.ndimage import gaussian_filter

In [7]:
def plot_shot_heatmap_plotly(shot_rate_map: pd.DataFrame):
    x = shot_rate_map['x_coord']
    y = shot_rate_map['y_coord']
    shot_rate = shot_rate_map['excess_shot_rate']

    # Create a 2D histogram (heatmap) of the shot rates - obtained using ChatGPT
    heatmap, xedges, yedges = np.histogram2d(
        x, 
        y, 
        bins=[np.linspace(0, 100, 150), 
              np.linspace(-42.5, 42.5, 150)], 
        weights=shot_rate)
    
    heatmap = gaussian_filter(heatmap, sigma=5)

    fig = go.Figure()

    fig.add_layout_image(
        dict(
            source=Image.open("../../figures/nhl_rink.png"),
            x=-100,
            y=42.5,
            xref="x",
            yref="y",
            sizex=200,
            sizey=85,
            opacity=1,
            sizing="stretch",
            xanchor="left",
            yanchor="top",
            layer="below"
        )
    )

    custom_colorscale = [[0, 'rgba(0,0,255,0.5)'], [0.5, 'rgba(255,255,255,0.5)'], [1, 'rgba(255,0,0,0.5)']]

    z_val = max(np.max(heatmap), abs(np.min(heatmap)))

    fig.add_trace(go.Heatmap(
        z=heatmap.T,
        x=xedges[:-1],
        y=yedges[:-1],
        colorscale=custom_colorscale,
        colorbar=dict(
            title='Excess Shot Rate<br>Per Hour',
            tickvals=[-z_val, 0, z_val],
            ticktext=[f'-{z_val:.2e}', '0', f'{z_val:.2e}'],
        ),        
        zmin=-z_val,
        zmax=z_val,
        hovertemplate='X: %{x:.1f}<br>Y: %{y:.1f}<br>Excess Shot Rate: %{z:.2e}<extra></extra>',
    ))

    fig.update_layout(
        xaxis=dict(title='X Coordinate (feet)', range=[0, 100]),
        yaxis=dict(title='Y Coordinate (feet)', range=[-42.5, 42.5]),
        showlegend=False,
        height=480,
        width=600,
        plot_bgcolor='white',
        paper_bgcolor='white',
        modebar=dict(remove=['zoom', 'pan', 'resetView', 'zoomIn', 'zoomOut', 'autoScale', 'resetScale'])
    )
    
    return fig

In [8]:
def generate_shot_rate_plots(year: str):
    teams = sorted(shotsDataRetriever.get_season_shots(year)['team_id'].unique())
    print(teams)

    for i, team_id in tqdm(enumerate(teams), desc="Generating Shot Rate Plots", total=len(teams)):
        if i == 0:
            fig = plot_shot_heatmap_plotly(get_team_excess_shot_rate(year, team_id))
            continue

        df = get_team_excess_shot_rate(year, team_id)
        heatmap_fig = plot_shot_heatmap_plotly(df)
        heatmap = heatmap_fig.data[0]
        heatmap.visible = False
        fig.add_trace(heatmap)
    
    fig.update_layout(
        title = f'Excess Shot Rate for Team {teams[0]} - {year}',
        updatemenus = [
            {
                "buttons": [
                    {
                        "label": f"Team: {team_id}",
                        "method": "update",
                        "args": [
                            {"visible": [i == j for j in range(len(teams))]},  # Show selected trace
                            {"title": f'Excess Shot Rate for Team {team_id} - {year}'},  # Update title
                        ],
                    } for i, team_id in enumerate(teams)
                ],
                "direction": "down",
                "showactive": True,
                "x": 1.1,
                "y": 1.15,
            }
        ],
    )

    dir = "plots"
    plotname = f"excess_shot_rates_{year}.html"
    path = os.path.join(dir, plotname)
    if not os.path.exists(dir):
        os.makedirs(dir)

    fig.write_html(path)
    fig.show()

generate_shot_rate_plots('2017')

for year in range(2016, 2020):
    generate_shot_rate_plots(str(year))

[np.int64(52)]


Generating Shot Rate Plots: 100%|██████████| 1/1 [00:37<00:00, 37.15s/it]


[np.int64(10)]


Generating Shot Rate Plots: 100%|██████████| 1/1 [00:27<00:00, 27.16s/it]


[np.int64(52)]


Generating Shot Rate Plots: 100%|██████████| 1/1 [00:27<00:00, 27.13s/it]


[np.int64(8)]


Generating Shot Rate Plots: 100%|██████████| 1/1 [00:24<00:00, 24.83s/it]


[np.int64(9)]


Generating Shot Rate Plots: 100%|██████████| 1/1 [00:26<00:00, 26.25s/it]
