In [None]:
import importlib
import os
import sys
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# ensure project root on path so `src` is importable
sys.path.append(os.path.abspath(".."))

from src.data_loader import (
    list_available_years,
    load_raw_shots,
    clean_shots,
    save_clean,
    load_clean,
    get_player_shots,
)
import src.viz as viz
importlib.reload(viz)
from src.viz import (
    set_modern_style,
    plot_shot_chart_modern,
    plot_hexbin_frequency,
    plot_fg_prob_kde_modern,
    plot_kde_heatmap_modern,
)

# Apply the project plot style before any figure is drawn.
set_modern_style()

# Cap both pandas display dimensions at the same limit.
for display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(display_option, 50)

# Seasons reported by the data loader; reused by the loading cell below.
available_years = list_available_years()
print("Available seasons:", available_years)
print("NOTE: coords auto-scale; FG% heatmap is smoothed KDE (not grid).")


In [None]:
# Load every season the data loader reported (n_rows=None: no row limit passed).
years = available_years
print(f"Loading seasons: {years}")

df_raw = load_raw_shots(years=years, n_rows=None)
print(f"Raw shape: {df_raw.shape}")

# Last expression of the cell: rendered as the cell output.
df_raw.head()


In [None]:
# Clean the raw shots, hand the cleaned frame to save_clean(), then work from
# whatever load_clean() returns so later cells use the reloaded data.
df_clean = clean_shots(df_raw)
print("Cleaned shape:", df_clean.shape)
print(f"Column count (raw -> clean): {len(df_raw.columns)} -> {len(df_clean.columns)}")

# Only the LAST expression of a cell is rendered automatically; a bare
# `df_clean.head()` in the middle of the cell is silently discarded, so the
# preview has to go through display().
display(df_clean.head())

save_clean(df_clean)

df = load_clean()
print("Reloaded shape:", df.shape)

# Show the two results separately: a tuple as the final expression would be
# rendered as one plain-text repr instead of a table plus a Series.
display(df.head())
df["YEAR"].value_counts().sort_index()


In [None]:
def summarize_dtypes(frame):
    """Count how many columns of ``frame`` share each dtype.

    Parameters
    ----------
    frame : pandas.DataFrame
        Frame whose column dtypes are tallied.

    Returns
    -------
    pandas.DataFrame
        Two columns: ``dtype`` (the dtype name as a string) and
        ``column_count`` (number of columns with that dtype), in the order
        produced by ``value_counts``.
    """
    # Compare dtypes by their string names so the tally is keyed on text.
    per_dtype = frame.dtypes.astype(str).value_counts()
    summary = per_dtype.rename_axis("dtype").reset_index(name="column_count")
    return summary

# Tally dtypes before and after cleaning so the two can be compared.
raw_dtype_summary = summarize_dtypes(df_raw)
clean_dtype_summary = summarize_dtypes(df_clean)

for heading, summary in (
    ("Raw dtype coverage:", raw_dtype_summary),
    ("\nClean dtype coverage (should match raw coverage):", clean_dtype_summary),
):
    print(heading)
    display(summary)

# Dtype names present in the raw frame but absent from the cleaned one.
missing_dtypes = set(raw_dtype_summary["dtype"]).difference(clean_dtype_summary["dtype"])
if not missing_dtypes:
    print("All raw dtypes preserved in cleaned data.")
else:
    print("⚠️ Missing dtype(s) after cleaning:", missing_dtypes)

print(f"Total columns preserved in cleaned data: {len(df_clean.columns)}")



In [None]:
# Align rim to y=0 for small-range coords (no scaling)
RIM_Y_OFFSET = 5.25       # amount subtracted from LOC_Y to put the rim at y=0
SMALL_RANGE_LIMIT = 30    # largest |LOC_X|/|LOC_Y| must be below this to trigger the shift

max_abs = float(max(df["LOC_X"].abs().max(), df["LOC_Y"].abs().max()))
if df.attrs.get("rim_centered", False):
    # Guard against re-running this cell: without it every execution subtracts
    # the offset again and overwrites the saved clean file with the doubly
    # shifted values.
    # NOTE(review): the flag lives in DataFrame.attrs of this in-memory frame;
    # whether save_clean/load_clean round-trip it is not visible here.
    print("Coordinates already rim-centered; skipping shift")
elif max_abs < SMALL_RANGE_LIMIT:
    print("Detected small coordinate range; centering rim at y=0")
    df["LOC_Y"] = df["LOC_Y"] - RIM_Y_OFFSET
    df.attrs["rim_centered"] = True
    save_clean(df)
    print("Re-saved clean file with rim-centered coords")

print("Post-fix coord summary:")
print(df[["LOC_X", "LOC_Y"]].describe())


In [None]:
def search_players(df, substring, n=20):
    """Return up to ``n`` distinct player names containing ``substring``.

    Matching is case-insensitive. Missing names are ignored and the result
    keeps the order in which names first appear in ``df["PLAYER_NAME"]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``PLAYER_NAME`` column.
    substring : str
        Text to look for inside each name.
    n : int, default 20
        Upper bound applied by slicing the list of matches.

    Returns
    -------
    list
        Matching names, at most ``n`` of them.
    """
    def _contains_substring(candidate):
        return substring.lower() in candidate.lower()

    distinct_names = df["PLAYER_NAME"].dropna().unique()
    return list(filter(_contains_substring, distinct_names))[:n]

# Non-null SHOT_MADE_FLAG entries per player, largest first.
attempts_per_player = df.groupby("PLAYER_NAME")["SHOT_MADE_FLAG"].count()
players_by_attempts = attempts_per_player.sort_values(ascending=False)
players_by_attempts.head(20)


In [None]:
# --- selection ---
player = "Stephen Curry"  # change to any name from players_by_attempts.index
years_filter = None       # e.g., [2023, 2024] or None for all seasons

# Resolve the requested name case-insensitively to the spelling used in the data.
names = df["PLAYER_NAME"].dropna().unique().tolist()
matches = [candidate for candidate in names if candidate.lower() == player.lower()]
if not matches:
    print(f"No exact match for '{player}'. Similar names: {search_players(df, player)}")
else:
    player = matches[0]

# Keep only the requested seasons that occur anywhere in df; None means "all seasons".
available_years_for_player = None
if years_filter is not None:
    requested_years = {int(y) for y in years_filter}
    years_in_data = set(df["YEAR"].unique())
    available_years_for_player = sorted(requested_years & years_in_data)
    if not available_years_for_player:
        print(f"No data for seasons {years_filter}; falling back to all seasons.")
        available_years_for_player = None

# --- player subset ---
df_p = get_player_shots(df, player_name=player, years=available_years_for_player)
if df_p.empty:
    raise ValueError(f"No shots found for {player} with current filters.")

player_total = len(df_p)
print(f"{player}: {player_total} shots across seasons {available_years_for_player or 'all'}")

# Per-season attempt counts and mean of SHOT_MADE_FLAG.
season_counts = df_p["YEAR"].value_counts().sort_index()
season_fg = df_p.groupby("YEAR")["SHOT_MADE_FLAG"].mean().sort_index()
print("\nAttempts by season:\n", season_counts)
print("\nFG% by season:\n", season_fg.round(3))


def _attempts_by_season(frame, label_col):
    """Count rows per (YEAR, label_col), sorted by season then descending attempts."""
    tallied = frame.groupby(["YEAR", label_col]).size().reset_index(name="ATTEMPTS")
    return tallied.sort_values(["YEAR", "ATTEMPTS"], ascending=[True, False])


# aggregated counts per season and action/shot type
action_counts = _attempts_by_season(df_p, "ACTION_TYPE")
shot_counts = _attempts_by_season(df_p, "SHOT_TYPE")

print("\nSample rows:")
df_p.head()


In [None]:
# Show the first 20 rows of each per-season count table under its caption.
for caption, counts_table in (
    ("Action type attempts (top 20 rows):", action_counts),
    ("Shot type attempts (top 20 rows):", shot_counts),
):
    print(caption)
    display(counts_table.head(20))


def plot_top_counts(df_counts, label_col, title, top_n=8):
    """Draw one horizontal bar panel per season with the most-attempted labels.

    Parameters
    ----------
    df_counts : pandas.DataFrame
        Frame with ``YEAR``, ``ATTEMPTS`` and ``label_col`` columns.
    label_col : str
        Column whose values label the bars; also used (underscores replaced
        by spaces, title-cased) as the y-axis label.
    title : str
        Figure-level title.
    top_n : int, default 8
        Number of rows kept per season after sorting by attempts descending.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``.
    """
    ordered = df_counts.sort_values(["YEAR", "ATTEMPTS"], ascending=[True, False])
    df_top = ordered.groupby("YEAR").head(top_n)

    # One facet per season, two per row, each with independent axes.
    facet_options = dict(
        col="YEAR",
        col_wrap=2,
        height=4,
        aspect=1.2,
        sharex=False,
        sharey=False,
    )
    grid = sns.catplot(
        data=df_top,
        x="ATTEMPTS",
        y=label_col,
        kind="bar",
        palette="crest",
        **facet_options,
    )
    grid.set_titles("Season {col_name}")
    grid.set_axis_labels("Attempts", label_col.replace("_", " ").title())
    grid.fig.suptitle(title, y=1.03)

    # Annotate every bar with its attempt count.
    for axis in grid.axes.flatten():
        for bars in axis.containers:
            axis.bar_label(bars, fmt="%.0f", padding=2, fontsize=8)

    plt.tight_layout()
    plt.show()

# One faceted chart per count table: action types first, then shot types.
top_count_charts = [
    (action_counts, "ACTION_TYPE", f"Top action types per season - {player}"),
    (shot_counts, "SHOT_TYPE", f"Top shot types per season - {player}"),
]
for counts_table, counts_label, chart_title in top_count_charts:
    plot_top_counts(counts_table, counts_label, chart_title)


In [None]:
# Keyword arguments shared by the plotting calls below.
court_units = {"units": "feet"}
fg_smoothing = {"bins": 80, "sigma": 1.2, "min_attempts": 0.5}

# modern aggregate visuals across selected seasons
plot_shot_chart_modern(df_p, title=f"{player} Shot Chart (Modern)", **court_units)
plot_hexbin_frequency(df_p, title=f"{player} Shot Frequency (Hexbin)", **court_units)
plot_fg_prob_kde_modern(
    df_p,
    title=f"{player} FG% Heatmap (Smoothed)",
    **fg_smoothing,
    **court_units,
)
plot_kde_heatmap_modern(df_p, title=f"{player} Density (KDE)", **court_units)

# per-season shot visuals
# Use the explicit season filter when one is active, otherwise every season in df_p.
seasons_to_plot = available_years_for_player or sorted(df_p["YEAR"].unique())
for y in seasons_to_plot:
    df_y = df_p[df_p["YEAR"] == y]
    if not df_y.empty:
        print(f"\nSeason {y}: {len(df_y)} shots")
        plot_shot_chart_modern(df_y, title=f"{player} Shot Chart {y} (Modern)", **court_units)
        plot_hexbin_frequency(df_y, title=f"{player} Shot Frequency {y} (Hexbin)", **court_units)
        plot_fg_prob_kde_modern(
            df_y,
            title=f"{player} FG% {y} (Smoothed)",
            **fg_smoothing,
            **court_units,
        )
        plot_kde_heatmap_modern(df_y, title=f"{player} Density {y} (KDE)", **court_units)
