In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Load the data
project_root = Path(__file__).parent.parent
data_path = project_root / "data" / "processed" / "shot_transitions.csv"

df = pd.read_csv(data_path)

# Display the first few rows of the dataset
print("First few rows of the dataset:")
print(df.head())

# Basic statistics
print("\nBasic statistics:")
print(df.describe())

# Check for missing values
print("\nMissing values:")
print(df.isnull().sum())

# Unique values in key columns
print("\nUnique values:")
print(f"Last shot types: {df['last_shot_type'].nunique()}")
print(f"Last shot directions: {df['last_shot_direction'].nunique()}")
print(f"Shot types: {df['shot_type'].nunique()}")
print(f"Shot directions: {df['shot_direction'].nunique()}")

# Most common transitions
print("\nMost common transitions:")
most_common_transitions = (
    df.groupby(["last_shot_type", "last_shot_direction", "shot_type", "shot_direction"])
    .agg({"count": "sum"})
    .reset_index()
    .sort_values(by="count", ascending=False)
)
print(most_common_transitions.head(10))

# Plot the distribution of transition counts
plt.figure(figsize=(10, 6))
sns.histplot(df["count"], bins=50, kde=True)
plt.title("Distribution of Transition Counts")
plt.xlabel("Count")
plt.ylabel("Frequency")
plt.show()

# Heatmap of most common transitions (aggregated by shot types)
heatmap_data = (
    df.groupby(["last_shot_type", "shot_type"])
    .agg({"count": "sum"})
    .reset_index()
    .pivot("last_shot_type", "shot_type", "count")
    .fillna(0)
)

plt.figure(figsize=(12, 8))
sns.heatmap(heatmap_data, annot=False, cmap="Blues", cbar=True)
plt.title("Heatmap of Transitions (Last Shot Type -> Shot Type)")
plt.xlabel("Shot Type")
plt.ylabel("Last Shot Type")
plt.show()

# Top transitions for a specific shot type
specific_shot_type = "f"  # Example: forehand
print(f"\nTop transitions for last shot type '{specific_shot_type}':")
top_transitions = df[df["last_shot_type"] == specific_shot_type].sort_values(
    by="count", ascending=False
)
print(top_transitions.head(10))

NameError: name '__file__' is not defined