In [1]:
# Core
import logging
import time
import numpy as np
import pandas as pd
import requests, io, math
import polars as pl
from pytz import timezone
import asyncio
import aiohttp
import pandas as pd
import polars as pl
import requests
import numpy as np
from pytz import timezone
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from plotly.subplots import make_subplots
import random
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.stats.outliers_influence import variance_inflation_factor
import cx_Oracle
import pickle
import sys
import os


# Progress
from tqdm.auto import tqdm

# Scikit-learn
from sklearn.model_selection import KFold, train_test_split

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import RobustScaler
from sklearn.pipeline import make_pipeline
from sklearn.base import clone
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# LightGBM
from lightgbm import LGBMRegressor
import lightgbm as lgb  # for callbacks like lgb.early_stopping

# Optuna
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner  # (optional) SuccessiveHalvingPruner, HyperbandPruner

import logging
from numpy.random import default_rng
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score,
    precision_score, recall_score, f1_score
)
import warnings
import getpass

# Blanket suppression keeps notebook output tidy but also hides library deprecation warnings.
warnings.filterwarnings('ignore')

# Set up logging for debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Display options used for debugging. `display.max_colwidth` stays at its default: setting it to
# None and then immediately calling reset_option cancelled out, so only the reset is kept.
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
pd.reset_option('display.max_colwidth')

# Credentials are read from the environment instead of being committed with the notebook.
# Any values that were previously hard-coded here should be treated as leaked and rotated.
# PBP API key (Season/Career Level API Access Subscription); prompted for when not set.
api_key = os.environ.get('PBP_API_KEY') or getpass.getpass('PBP API key: ')
season = 2025

# Oracle connection details (used by fetch_model_from_oracle).
username = os.environ.get('ORACLE_USERNAME', 'ADMIN')
password = os.environ.get('ORACLE_PASSWORD')  # None unless ORACLE_PASSWORD is set
dsn = os.environ.get('ORACLE_DSN', 'ch2sockkby63dgzo_high')


Helper Functions

In [2]:
def init_oracle_client(lib_dir=None):
    """Initialize the cx_Oracle client library based on the platform.

    Parameters:
    - lib_dir (str | None): Instant Client directory to use. When None, a per-platform
      default is used: ``~/Downloads/instantclient_19_8`` on macOS and the
      ``Program Files/Oracle/instantclient_21_9`` directory on drive C: on Windows.
      On any other platform nothing is initialised unless ``lib_dir`` is given.

    Calling this again in a kernel where the client is already initialised is tolerated:
    a cx_Oracle ProgrammingError whose message says "already been initialized" is
    reported and ignored. Any other error propagates.
    """
    extra_kwargs = {}
    if sys.platform.startswith("darwin"):
        # expanduser avoids the TypeError raised by `None + str` when HOME is unset.
        default_dir = os.path.join(os.path.expanduser("~"), "Downloads", "instantclient_19_8")
        extra_kwargs["config_dir"] = ""
    elif sys.platform.startswith("win"):
        # Single backslashes: a raw string must not also escape them.
        default_dir = r"C:\Program Files\Oracle\instantclient_21_9"
    else:
        default_dir = None

    if lib_dir is None:
        lib_dir = default_dir
    if lib_dir is None:
        return

    try:
        cx_Oracle.init_oracle_client(lib_dir=lib_dir, **extra_kwargs)
    except cx_Oracle.ProgrammingError as e:
        # Re-running the cell that calls this in the same kernel hits this error.
        if "already been initialized" not in str(e):
            raise
        print(f"Oracle client already initialized: {e}")

def fetch_model_from_oracle(dsn, username, password, model_name):
    """Load a pickled model and its label encoders from the ``models`` table.

    Parameters:
    - dsn, username, password: connection details passed to cx_Oracle.connect.
    - model_name (str): value matched against ``models.model_name`` (bound, not interpolated).

    Returns:
    - (model, label_encoders) from the first matching row; label_encoders is None when
      that column is NULL.
    - None when no row matches, the row has no model data, or a cx_Oracle.DatabaseError
      occurs after the connection is opened.

    Raises:
    - Whatever cx_Oracle.connect raises if the connection itself cannot be opened.

    Security: both BLOBs are unpickled, which can execute arbitrary code. Only use this
    against a database whose contents you trust.
    """
    connection = cx_Oracle.connect(user=username, password=password, dsn=dsn)
    cursor = None

    select_sql = """
    SELECT model_data, label_encoders
    FROM models 
    WHERE model_name = :model_name 
    """

    try:
        # Created inside the try so the connection is still closed if this fails.
        cursor = connection.cursor()
        cursor.execute(select_sql, {'model_name': model_name})

        # Only the first row is used (one model per name is expected).
        row = cursor.fetchone()
        if row is None:
            print('No model found with the specified criteria.')
            return None

        model_lob, encoders_lob = row
        if model_lob is None:
            print('No model found with the specified criteria.')
            return None

        model = pickle.loads(model_lob.read())
        # A NULL label_encoders column comes back as None rather than a LOB.
        label_encoders = pickle.loads(encoders_lob.read()) if encoders_lob is not None else None
        print('Model fetched successfully.')
        return model, label_encoders

    except cx_Oracle.DatabaseError as e:
        print(f"Database error occurred: {e}")
        return None
    finally:
        if cursor is not None:
            cursor.close()
        connection.close()

def get_pitchers(season, timeout=30):
    """
    Get season-level data by pitcher for a specific season (since 2020).

    Only rows with a non-null ``arm_length`` are kept. The result is reduced to
    ``pitcher`` (renamed from ``pitcher_id``), ``pitcher_name``, ``season`` and ``height``.

    Parameters:
    - season (int): The season for which baseball data is requested (e.g., 2022).
    - timeout (float): Seconds to wait for each HTTP response (default: 30).

    Returns:
    - df (polars.DataFrame)

    Raises:
    - requests.HTTPError: the API answered with a non-2xx status.
    - ValueError: the API returned no rows for ``season``.

    Dependencies:
    - requests: used to make HTTP requests to the remote API.
    - polars: used for data manipulation and analysis.
    """
    url = f'https://g837e5a6fbcb0dd-ch2sockkby63dgzo.adb.us-chicago-1.oraclecloudapps.com/ords/admin/players/check/{season}'

    data_items = []
    while url:
        data_request = requests.get(url, timeout=timeout)
        data_request.raise_for_status()
        data_json = data_request.json()
        data_items.extend(data_json.get('items', []))

        # ORDS collection responses may be paged. Follow the "next" link while `hasMore`
        # is true; a single-page response simply ends the loop.
        url = None
        if data_json.get('hasMore'):
            url = next(
                (link.get('href') for link in data_json.get('links', []) if link.get('rel') == 'next'),
                None,
            )

    if not data_items:
        raise ValueError(f"No pitcher rows returned for season {season}.")

    # keep only rows with arm_length and return selected columns, renaming pitcher_id -> pitcher
    return (
        pl.DataFrame(data_items)
        .filter(pl.col("arm_length").is_not_null())
        .select([
            pl.col("pitcher_id").alias("pitcher"),
            pl.col("pitcher_name"),
            pl.col("season"),
            pl.col("height"),
        ])
    )

def get_schedule(year_input=[2025], sport_id=[1], game_type=['R'], timeout=30):
    """
    Retrieve the MLB Stats API schedule for the given seasons, sport ids and game types.

    Parameters:
    - year_input (list[int]): seasons to query (default: [2025]).
    - sport_id (list[int]): sport ids to query (default: [1]).
    - game_type (list[str]): game type codes (default: ['R']).
    - timeout (float): seconds to wait for the HTTP response (default: 30).
    The list defaults are only read, never mutated, so sharing them between calls is safe.

    Returns:
    - pandas.DataFrame with one row per unique game_pk, sorted by date (API order kept within
      a date). Columns: game_pk, time (US/Eastern, 'HH:MM AM/PM'), date, away, home, state,
      venue_id, venue_name.
    - The string 'Schedule Length of 0, please select different parameters.' when no games
      are found. This is kept for backward compatibility, so callers must check the type.

    Raises:
    - ValueError: an argument is not a list of the expected element type.
    - requests.HTTPError: the API answered with a non-2xx status.
    """
    # Type checks
    if not isinstance(year_input, list) or not all(isinstance(year, int) for year in year_input):
        raise ValueError("year_input must be a list of integers.")
    if not isinstance(sport_id, list) or not all(isinstance(sid, int) for sid in sport_id):
        raise ValueError("sport_id must be a list of integers.")
    if not isinstance(game_type, list) or not all(isinstance(gt, str) for gt in game_type):
        raise ValueError("game_type must be a list of strings.")

    empty_message = 'Schedule Length of 0, please select different parameters.'
    eastern = timezone('US/Eastern')

    # Convert input lists to comma-separated strings
    year_input_str = ','.join(map(str, year_input))
    sport_id_str = ','.join(map(str, sport_id))
    game_type_str = ','.join(game_type)

    url = f'https://statsapi.mlb.com/api/v1/schedule/?sportId={sport_id_str}&gameTypes={game_type_str}&season={year_input_str}&hydrate=lineup,players'
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    game_call = response.json()

    # A response without 'dates' means no games, not a KeyError.
    data = [
        {
            'game_pk': game['gamePk'],
            'time': game['gameDate'],
            'date': game['officialDate'],
            'away': game['teams']['away']['team']['name'],
            'home': game['teams']['home']['team']['name'],
            'state': game['status']['codedGameState'],
            'venue_id': game['venue']['id'],
            'venue_name': game['venue']['name']
        }
        for day in game_call.get('dates', []) for game in day['games']
    ]

    game_df = pd.DataFrame(data)
    if game_df.empty:
        return empty_message

    game_df['date'] = pd.to_datetime(game_df['date']).dt.date
    game_df['time'] = (
        pd.to_datetime(game_df['time'])
        .dt.tz_convert(eastern)
        .dt.strftime('%I:%M %p')
    )

    # drop_duplicates cannot empty a non-empty frame, so no second emptiness check is needed.
    # A stable sort keeps the API's within-day order for games that share a date.
    return game_df.drop_duplicates(subset='game_pk').sort_values('date', kind='stable')

def get_pbp_game(game_pk, api_key, max_retries=3, backoff_factor=1.0, timeout=30):
    """
    Get pitch-by-pitch data by game (since 2020) with retry logic.

    Parameters:
    - game_pk (int): The MLBAM game_pk for which baseball data is requested
    - api_key (str): The api_key required for authentication
    - max_retries (int): Maximum number of retry attempts (default: 3)
    - backoff_factor (float): Backoff factor for exponential retry delay (default: 1.0)
    - timeout (float): Seconds to wait for each response (default: 30)

    Returns:
    - pandas.DataFrame built from the response's ``items`` list. An empty DataFrame is
      returned after a non-retryable HTTP status (4xx other than 429), once retries are
      exhausted, or on an unexpected exception.

    Retried conditions: HTTP 429/500/502/503/504, ``requests`` exceptions, empty bodies,
    invalid JSON, and a missing or empty ``items`` list. Retry n waits
    ``backoff_factor * 2**(n - 1)`` seconds plus up to 1 s of random jitter.
    """
    url = f'https://g837e5a6fbcb0dd-ch2sockkby63dgzo.adb.us-chicago-1.oraclecloudapps.com/ords/admin/patreon/GET_PBP_GAME/{game_pk}/{api_key}'
    retryable_statuses = {429, 500, 502, 503, 504}

    # One retry layer only: mounting a urllib3 Retry adapter as well made every attempt of
    # this loop retry again internally. The context manager closes the session on exit.
    with requests.Session() as session:
        for attempt in range(max_retries + 1):
            if attempt > 0:
                # Exponential backoff with jitter so concurrent callers do not retry in lockstep.
                time.sleep(backoff_factor * (2 ** (attempt - 1)) + random.uniform(0, 1))
            try:
                response = session.get(url, timeout=timeout)

                # Client errors such as a bad key or unknown game will not fix themselves.
                if response.status_code >= 400 and response.status_code not in retryable_statuses:
                    print(f"Non-retryable HTTP {response.status_code} for game_pk {game_pk}; giving up.")
                    return pd.DataFrame()
                response.raise_for_status()

                if not response.text.strip():
                    raise ValueError(f"Empty response for game_pk {game_pk}")

                data_json = response.json()
                if not isinstance(data_json, dict) or not data_json.get('items'):
                    raise ValueError(f"Invalid response structure for game_pk {game_pk}")

                return pd.json_normalize(data_json['items'])

            except (requests.exceptions.RequestException, ValueError, KeyError) as e:
                if attempt == max_retries:
                    print(f"Failed to retrieve data for game_pk {game_pk} after {max_retries + 1} attempts: {e}")
                    return pd.DataFrame()
                print(f"Attempt {attempt + 1} failed for game_pk {game_pk}: {e}. Retrying...")

            except Exception as e:
                print(f"Unexpected error for game_pk {game_pk}: {e}")
                return pd.DataFrame()

    return pd.DataFrame()

# OPTIMIZATION 1: Async version for maximum speed with retries
async def get_pbp_game_async(session, game_pk, api_key, max_retries=3, backoff_factor=1.0):
    """
    Async version of get_pbp_game with retry logic for concurrent requests.

    Parameters:
    - session: an open aiohttp.ClientSession; request timeouts come from its configuration.
    - game_pk (int): The MLBAM game_pk to fetch.
    - api_key (str): The api_key required for authentication.
    - max_retries (int): Maximum number of retry attempts (default: 3).
    - backoff_factor (float): Backoff factor for exponential retry delay (default: 1.0).

    Returns:
    - pandas.DataFrame built from the response's ``items`` list. An empty DataFrame is
      returned after a non-retryable HTTP status (4xx other than 429), once retries are
      exhausted, or on an unexpected exception.

    Retried conditions: HTTP 429/500/502/503/504, aiohttp client errors, timeouts, empty
    bodies, invalid JSON, and a missing or empty ``items`` list.
    """
    import json

    url = f'https://g837e5a6fbcb0dd-ch2sockkby63dgzo.adb.us-chicago-1.oraclecloudapps.com/ords/admin/patreon/GET_PBP_GAME/{game_pk}/{api_key}'
    retryable_statuses = {429, 500, 502, 503, 504}

    for attempt in range(max_retries + 1):
        try:
            # Add exponential backoff with jitter
            if attempt > 0:
                await asyncio.sleep(backoff_factor * (2 ** (attempt - 1)) + random.uniform(0, 1))

            async with session.get(url) as response:
                # Client errors such as a bad key or unknown game will not fix themselves.
                if response.status >= 400 and response.status not in retryable_statuses:
                    print(f"Non-retryable HTTP {response.status} for game_pk {game_pk}; giving up.")
                    return pd.DataFrame()
                response.raise_for_status()
                text = await response.text()

            if not text.strip():
                raise ValueError(f"Empty response for game_pk {game_pk}")

            # Parse the body already read instead of reading and decoding it a second time.
            data_json = json.loads(text)
            if not isinstance(data_json, dict) or not data_json.get('items'):
                raise ValueError(f"Invalid response structure for game_pk {game_pk}")

            return pd.json_normalize(data_json['items'])

        except (aiohttp.ClientError, ValueError, KeyError, asyncio.TimeoutError) as e:
            if attempt == max_retries:
                print(f"Failed to retrieve data for game_pk {game_pk} after {max_retries + 1} attempts: {e}")
                return pd.DataFrame()
            print(f"Attempt {attempt + 1} failed for game_pk {game_pk}: {e}. Retrying...")

        except Exception as e:
            print(f"Unexpected error for game_pk {game_pk}: {e}")
            return pd.DataFrame()

    return pd.DataFrame()

async def get_combined_pbp_data_async(schedule: pd.DataFrame, api_key: str, max_concurrent=20, max_retries=3):
    """
    Fetch play-by-play rows for every distinct game in ``schedule`` concurrently (asyncio).

    At most ``max_concurrent`` requests are in flight at once. A semaphore and the
    connector's limits both cap concurrency. Each game is fetched with
    get_pbp_game_async using ``max_retries`` retries.

    Parameters:
    - schedule (pd.DataFrame): must contain a 'game_pk' column; duplicate game_pks are fetched once.
    - api_key (str): API key forwarded to the PBP endpoint.
    - max_concurrent (int): cap on simultaneous requests (default: 20).
    - max_retries (int): retry attempts per game (default: 3).

    Returns:
    - pd.DataFrame concatenating every non-empty per-game frame (index reset), or an empty
      DataFrame when nothing was retrieved. Empty per-game results are counted as failures.
    """
    game_pks = schedule['game_pk'].unique()
    limiter = asyncio.Semaphore(max_concurrent)

    async def _fetch(client, pk):
        async with limiter:
            return await get_pbp_game_async(client, pk, api_key, max_retries)

    # 60 s overall / 10 s connect per request; DNS answers cached for 5 minutes.
    session_kwargs = dict(
        timeout=aiohttp.ClientTimeout(total=60, connect=10),
        connector=aiohttp.TCPConnector(
            limit=max_concurrent,
            limit_per_host=max_concurrent,
            ttl_dns_cache=300,
            use_dns_cache=True,
        ),
    )

    frames = []
    async with aiohttp.ClientSession(**session_kwargs) as client:
        pending = [_fetch(client, pk) for pk in game_pks]
        progress = tqdm(asyncio.as_completed(pending), total=len(pending), desc="Fetching PBP Data", unit="game")
        for next_done in progress:
            frame = await next_done
            if not frame.empty:
                frames.append(frame)

    failed = len(pending) - len(frames)
    print(f"\nCompleted: {len(frames)} successful, {failed} failed requests")

    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

# OPTIMIZATION 2: Threading version with retries (easier to implement, still much faster)
def get_combined_pbp_data_threaded(schedule: pd.DataFrame, api_key: str, max_workers=10, max_retries=3):
    """
    Fetch play-by-play rows for every distinct game in ``schedule`` on a thread pool.

    Parameters:
    - schedule (pd.DataFrame): must contain a 'game_pk' column; duplicate game_pks are fetched once.
    - api_key (str): API key forwarded to get_pbp_game.
    - max_workers (int): number of worker threads (default: 10).
    - max_retries (int): retry attempts per game (default: 3).

    Returns:
    - pd.DataFrame concatenating every non-empty per-game frame (index reset), or an empty
      DataFrame when nothing was retrieved. Empty per-game results are counted as failures.
    """
    def _safe_fetch(pk):
        # Last-resort guard so one bad game cannot abort the whole pool.
        try:
            return get_pbp_game(pk, api_key, max_retries)
        except Exception as exc:
            print(f"Unexpected error retrieving data for game_pk {pk}: {exc}")
            return pd.DataFrame()

    collected = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(_safe_fetch, pk): pk for pk in schedule['game_pk'].unique()}
        total = len(futures)
        for done in tqdm(as_completed(futures), total=total, desc="Fetching PBP Data", unit="game"):
            frame = done.result()
            if not frame.empty:
                collected.append(frame)

    print(f"\nCompleted: {len(collected)} successful, {total - len(collected)} failed requests")

    if not collected:
        return pd.DataFrame()
    return pd.concat(collected, ignore_index=True)

# OPTIMIZATION 3: Direct Polars conversion with column dropping and retries
def get_combined_pbp_data_polars(schedule: pd.DataFrame, api_key: str, 
                                columns_to_drop=None, max_workers=10, max_retries=3):
    """
    Fetch play-by-play data on a thread pool, convert it to Polars, and drop unwanted columns.

    Parameters:
    - schedule (pd.DataFrame): A DataFrame containing game schedule data with a 'game_pk' column.
    - api_key (str): API key required to access the get_pbp_game API.
    - columns_to_drop (list | None): Columns to drop from the final DataFrame. Names that are
      absent, or repeated, are ignored. When None, the model-output columns listed below are dropped.
    - max_workers (int): Maximum number of worker threads (default: 10)
    - max_retries (int): Maximum number of retry attempts per request (default: 3)

    Returns:
    - combined_df (pl.DataFrame): Combined play-by-play data, or an empty pl.DataFrame when
      no game returned data.
    """
    if columns_to_drop is None:
        # Same list the notebook passes explicitly. The previous default repeated
        # 'stuff'/'stuff_plus' where 'pitching'/'pitching_plus' belonged and omitted
        # 'location'/'location_plus'.
        columns_to_drop = [
            'stuff', 'stuff_plus', 'xwhiff', 'xwhiff_plus',
            'pitching', 'pitching_plus', 'command', 'command_plus',
            'location', 'location_plus', 'pitch_group',
            'fastball_release_speed_diff', 'fastball_induced_vertical_acceleration_diff',
            'fastball_horizontal_acceleration_normalized_diff',
            'expected_vaa', 'vaa_vs_expected', 'expected_haa', 'haa_vs_expected',
        ]

    pandas_df = get_combined_pbp_data_threaded(schedule, api_key, max_workers, max_retries)
    if pandas_df.empty:
        return pl.DataFrame()

    polars_df = pl.from_pandas(pandas_df)

    # Keep only names that exist; dict.fromkeys de-duplicates while preserving order.
    present = set(polars_df.columns)
    existing_columns_to_drop = [col for col in dict.fromkeys(columns_to_drop) if col in present]
    if existing_columns_to_drop:
        polars_df = polars_df.drop(existing_columns_to_drop)

    return polars_df


In [3]:
# Initialise the Oracle Instant Client for cx_Oracle. Within this notebook only
# fetch_model_from_oracle talks to Oracle; the model used below is loaded from a local pickle.
init_oracle_client()

In [4]:
pitchers = get_pitchers(season)
schedule = get_schedule(year_input=[season])
# get_schedule returns a message string instead of a DataFrame when no games match.
if not isinstance(schedule, pd.DataFrame):
    raise ValueError(f"No schedule for season {season}: {schedule}")

columns_to_drop = ['stuff', 'stuff_plus', 'xwhiff', 'xwhiff_plus', 
                    'pitching', 'pitching_plus', 'command', 'command_plus', 
                    'location', 'location_plus','pitch_group', 
                    'fastball_release_speed_diff', 'fastball_induced_vertical_acceleration_diff', 
                    'fastball_horizontal_acceleration_normalized_diff',
                    'expected_vaa', 'vaa_vs_expected', 'expected_haa', 'haa_vs_expected']

savant_data = get_combined_pbp_data_polars(schedule, api_key, columns_to_drop, max_workers=20)
# An empty frame has no 'pitcher' column, so fail here with a clear message instead of in the join.
if savant_data.is_empty():
    raise ValueError("No pitch-by-pitch data was retrieved; check api_key and the log above.")

# Attach pitcher height and keep only the columns the arm-angle model and formula need.
df = pitchers.join(savant_data, on="pitcher", how="inner").select([
    pl.col("game_year"),
    pl.col("pitcher"),
    pl.col("pitcher_name"),
    pl.col("p_throws"),
    pl.col("height"),
    pl.col("release_extension"),
    pl.col("release_pos_x"),
    pl.col("release_pos_z"),
    pl.col("arm_angle"),
])
# Take head() in Polars first so only five rows are converted for display.
df.head().to_pandas()

Fetching PBP Data:   0%|          | 0/2430 [00:00<?, ?game/s]


Completed: 2430 successful, 0 failed requests


Unnamed: 0,game_year,pitcher,pitcher_name,p_throws,height,release_extension,release_pos_x,release_pos_z,arm_angle
0,2025,605280,Clay Holmes,R,77,6.0,-0.37,6.73,39.5
1,2025,605280,Clay Holmes,R,77,6.0,-0.49,6.78,41.9
2,2025,605280,Clay Holmes,R,77,6.0,-0.56,6.62,45.4
3,2025,605280,Clay Holmes,R,77,5.8,-0.42,6.72,39.5
4,2025,605280,Clay Holmes,R,77,6.0,-0.61,6.66,42.5


Arm Angle Functions

In [5]:
# Function to calculate arm angles
def calculate_arm_angles(p_throws, release_pos_x, release_pos_z, height, shoulder_ratio=0.70):
    """
    Geometric arm-angle estimate from release position and pitcher height.

    The shoulder is placed at ``height * shoulder_ratio`` inches. The angle is
    ``degrees(arctan2(|x|, z - shoulder))``, i.e. the shoulder-to-release vector measured
    from vertical: 0° is directly above the shoulder estimate and 90° is level with it.
    The sign is flipped for left-handers (``p_throws == 'L'``).

    Parameters:
    - p_throws: 'L' / 'R' (scalar or array-like).
    - release_pos_x, release_pos_z: release coordinates in feet (converted to inches here).
    - height: pitcher height in inches.
    - shoulder_ratio (float): fraction of height used as shoulder height (default: 0.70).

    Returns:
    - A NumPy float scalar for scalar inputs, otherwise an array broadcast from the inputs.
      Missing numeric inputs (None/NaN) yield NaN.
    """
    release_pos_x_inches = np.asarray(release_pos_x, dtype=float) * 12
    release_pos_z_inches = np.asarray(release_pos_z, dtype=float) * 12
    shoulder_pos_inches = np.asarray(height, dtype=float) * shoulder_ratio

    adjacent = release_pos_z_inches - shoulder_pos_inches  # vertical offset from the shoulder
    opposite = np.abs(release_pos_x_inches)                # horizontal offset from center
    angle_degrees = np.degrees(np.arctan2(opposite, adjacent))

    # Element-wise handedness flip, so array inputs work as well as scalars.
    angle_degrees = np.where(np.asarray(p_throws) == 'L', -angle_degrees, angle_degrees)

    # Unwrap 0-d results so scalar callers still get a plain NumPy float.
    return angle_degrees[()] if angle_degrees.ndim == 0 else angle_degrees

def _prepare_features_for_inference_polars(
    df_pl: pl.DataFrame,
    features: list[str],
    encoders: dict | None
) -> np.ndarray:
    """
    Build the model input matrix from ``df_pl`` in a single Polars ``select``.

    Parameters:
    - df_pl: frame that must contain every name in ``features``.
    - features: column names in the order the model expects.
    - encoders: optional per-column mappings for categorical features, applied with
      ``Expr.replace`` (presumably ``{raw_value: code}`` dicts — confirm against training).
      Values not found in a mapping become -1.

    Returns:
    - NumPy array with one column per feature, in ``features`` order. Encoded columns are
      Int32 and the rest are Float64 before conversion.

    Raises:
    - KeyError: one or more feature columns are missing from ``df_pl``.
    """
    encoders = encoders or {}

    missing = [f for f in features if f not in df_pl.columns]
    if missing:
        raise KeyError(f"Missing feature columns for inference: {missing}")

    exprs = [
        # Int32 rather than Int8: codes above 127 would overflow Int8 and make the cast fail.
        pl.col(f).replace(encoders[f], default=-1).cast(pl.Int32).alias(f)
        if f in encoders
        else pl.col(f).cast(pl.Float64).alias(f)
        for f in features
    ]

    return df_pl.select(exprs).to_numpy()

def model_arm_angles(
    df_pl: pl.DataFrame,
    model,
    label_encoders: dict | None,
    features: list[str],
    output_col: str | None = None,
) -> pl.DataFrame | np.ndarray:
    """
    Score every row of a Polars frame with a fitted model.

    Parameters:
    - df_pl: Polars frame holding the feature columns.
    - model: any object exposing ``predict`` (e.g. the stored pipeline).
    - label_encoders: optional per-column category mappings used during feature preparation.
    - features: feature column names, in model order.
    - output_col: name of the column to add; when None the raw predictions are returned.

    Returns:
    - Whatever ``model.predict`` returns when ``output_col`` is None; otherwise ``df_pl``
      with the predictions appended as ``output_col``.

    Raises:
    - TypeError: ``df_pl`` is not a Polars DataFrame.
    """
    if isinstance(df_pl, pl.DataFrame):
        feature_matrix = _prepare_features_for_inference_polars(df_pl, features, label_encoders)
        predictions = model.predict(feature_matrix)
        return (
            predictions
            if output_col is None
            else df_pl.with_columns(pl.Series(output_col, predictions))
        )
    raise TypeError("df_pl must be a Polars DataFrame")

In [10]:
# NOTE(review): pickle.load can execute arbitrary code — only load bundles from a trusted source.
with open('models/arm_angle/arm_angle_estimator.pkl', 'rb') as f:
    bundle = pickle.load(f)

missing_bundle_keys = {'model', 'label_encoders', 'features'} - set(bundle)
if missing_bundle_keys:
    raise KeyError(f"Model bundle is missing keys: {sorted(missing_bundle_keys)}")

# The evaluation cell further down rebinds `df` to the scored frame. Dropping the columns
# derived by this and later cells keeps a re-run from carrying stale predictions/errors forward.
derived_columns = ['arm_angle_pred', 'arm_angle_calc', 'err_pred', 'err_calc', 'ae_pred', 'ae_calc']
model_input = df.drop([c for c in derived_columns if c in df.columns])

savant = model_arm_angles(
    df_pl=model_input,
    model=bundle['model'],
    label_encoders=bundle['label_encoders'],
    features=bundle['features'],
    output_col="arm_angle_pred"
)

Add Angles

In [11]:
# Pull the geometric-formula inputs out of the Polars frame as NumPy arrays.
pth, rx, rz, ht = (
    savant[name].to_numpy()
    for name in ("p_throws", "release_pos_x", "release_pos_z", "height")
)

# np.frompyfunc wraps the scalar calculate_arm_angles so it can be applied element-wise.
# It still makes one Python call per row and yields an object array, hence the float cast.
ufunc = np.frompyfunc(calculate_arm_angles, 4, 1)
angles = np.asarray(ufunc(pth, rx, rz, ht), dtype=np.float64)

# Store the geometric estimate next to the model prediction.
savant = savant.with_columns(pl.Series("arm_angle_calc", angles))


In [12]:
# Per-pitcher sanity check of measured vs. predicted vs. geometric arm angles.
PITCHER_NAME = "Taylor Rogers"
# Filter in Polars first so only this pitcher's rows are copied to pandas.
pitcher_df = savant.filter(pl.col("pitcher_name") == PITCHER_NAME).to_pandas()
pitcher_df.describe()

Unnamed: 0,game_year,pitcher,height,release_extension,release_pos_x,release_pos_z,arm_angle,arm_angl_pred,arm_angle_calc,arm_angle_pred
count,894.0,894.0,894.0,893.0,894.0,894.0,894.0,894.0,894.0,894.0
mean,2025.0,573124.0,75.0,5.733147,1.912998,5.635716,-60.604698,-61.525932,-56.552197,-61.525932
std,0.0,0.0,0.0,0.139742,0.126792,0.080811,2.140303,1.568248,2.800976,1.568248
min,2025.0,573124.0,75.0,5.4,1.57,5.36,-67.7,-69.268775,-63.335391,-69.268775
25%,2025.0,573124.0,75.0,5.6,1.82,5.58,-61.8,-62.298306,-58.512531,-62.298306
50%,2025.0,573124.0,75.0,5.7,1.91,5.63,-60.7,-61.06274,-56.628456,-61.06274
75%,2025.0,573124.0,75.0,5.8,2.0,5.68,-59.6,-60.638248,-54.621196,-60.638248
max,2025.0,573124.0,75.0,6.7,2.3,5.9,-48.6,-55.229621,-48.625951,-55.229621


Evaluation

In [29]:
# ================================
# Arm Angle Model Evaluation
# Story: ML Model vs Geometric Calculation
# ================================
import polars as pl
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---------- CONFIG ----------
# Palette shared by every figure in this cell.
COLORS = dict(
    model='#2563eb',       # blue
    calc='#dc2626',        # red
    model_light='#93c5fd',
    calc_light='#fca5a5',
    bg='#fafafa',
    grid='#e5e7eb',
    text='#1f2937',
)

SAMPLE_N = 50_000   # maximum rows sampled for the scatter and histogram figures
EPS = 1e-12         # keeps the R² denominator away from zero

# ---------- DATA PREP ----------
# `df` is rebound to the scored frame here (materialised first if it is lazy).
df = savant.collect() if isinstance(savant, pl.LazyFrame) else savant

# Signed error (estimate - actual), then its magnitude, for both estimators.
df = (
    df.with_columns(
        err_pred=pl.col("arm_angle_pred") - pl.col("arm_angle"),
        err_calc=pl.col("arm_angle_calc") - pl.col("arm_angle"),
    )
    .with_columns(
        ae_pred=pl.col("err_pred").abs(),
        ae_calc=pl.col("err_calc").abs(),
    )
)

# ---------- COMPUTE KEY METRICS ----------
def compute_metrics(df: pl.DataFrame, pred_col: str, actual_col: str = "arm_angle") -> dict:
    """
    Regression metrics of ``pred_col`` against ``actual_col``.

    Only rows where both columns are non-null and finite are used, so every metric —
    including both sums inside R² — is computed over the same rows, and a single NaN
    cannot turn every metric into NaN.

    Returns:
    - dict with ``rmse``, ``mae``, ``bias`` (mean of pred - actual) and ``r2``
      (1 - SS_res / (SS_tot + EPS)). All four are NaN when no valid rows remain.
    """
    valid = df.filter(
        pl.col(pred_col).is_not_null()
        & pl.col(actual_col).is_not_null()
        & pl.col(pred_col).is_finite()
        & pl.col(actual_col).is_finite()
    )
    if valid.height == 0:
        nan = float("nan")
        return {"rmse": nan, "mae": nan, "bias": nan, "r2": nan}

    resid = pl.col(pred_col) - pl.col(actual_col)
    stats = valid.select(
        (resid ** 2).mean().sqrt().alias("rmse"),
        resid.abs().mean().alias("mae"),
        resid.mean().alias("bias"),
        (resid ** 2).sum().alias("ss_res"),
        ((pl.col(actual_col) - pl.col(actual_col).mean()) ** 2).sum().alias("ss_tot"),
    ).row(0, named=True)

    ss_res = stats.pop("ss_res")
    ss_tot = stats.pop("ss_tot")
    r2 = 1 - ss_res / (ss_tot + EPS)

    return {**stats, "r2": r2}

model_metrics = compute_metrics(df, "arm_angle_pred")
calc_metrics = compute_metrics(df, "arm_angle_calc")

# ---------- FIGURE 1: HERO METRICS COMPARISON ----------
fig_hero = make_subplots(
    rows=1, cols=4,
    subplot_titles=["R²", "MAE (°)", "RMSE (°)", "Bias (°)"],
    horizontal_spacing=0.08
)

metrics_order = ['r2', 'mae', 'rmse', 'bias']
metric_labels = ['R²', 'MAE', 'RMSE', 'Bias']
model_wins_all = True

for i, (key, label) in enumerate(zip(metrics_order, metric_labels), 1):
    model_val = model_metrics[key]
    calc_val = calc_metrics[key]

    # For R², higher is better; for the error metrics, closer to 0 is better.
    if key == 'r2':
        model_better, calc_better = model_val > calc_val, calc_val > model_val
    else:
        model_better, calc_better = abs(model_val) < abs(calc_val), abs(calc_val) < abs(model_val)
    model_wins_all = model_wins_all and model_better

    # The winning bar in each panel gets a check mark (no mark on a tie).
    for bar_name, val, color_key, is_winner in (
        ('Model', model_val, 'model', model_better),
        ('Calculated', calc_val, 'calc', calc_better),
    ):
        fig_hero.add_trace(go.Bar(
            x=[bar_name], y=[val],
            marker_color=COLORS[color_key],
            name=bar_name if i == 1 else None,
            showlegend=(i == 1),
            text=[f"{val:.3f}{' ✓' if is_winner else ''}"],
            textposition='outside',
            textfont=dict(size=14, color=COLORS[color_key]),
            hovertemplate=f"{label} ({bar_name}): %{{y:.3f}}<extra></extra>",
        ), row=1, col=i)

# The headline follows the numbers instead of always asserting that the model wins.
hero_title = (
    "<b>ML Model Outperforms Geometric Calculation</b>" if model_wins_all
    else "<b>ML Model vs Geometric Calculation</b>"
)

fig_hero.update_layout(
    title=dict(
        text=hero_title,
        font=dict(size=24, color=COLORS['text']),
        x=0.5,
    ),
    barmode='group',
    height=600,
    width=1100,
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(family="Inter, sans-serif", color=COLORS['text']),
    legend=dict(orientation='h', yanchor='top', y = 1.12, xanchor='center', x=0.5),
    margin=dict(t=120, b=60)
)
fig_hero.update_yaxes(showgrid=True, gridcolor=COLORS['grid'], zeroline=True, zerolinecolor=COLORS['grid'])
fig_hero.show()

# ---------- FIGURE 2: SCATTER COMPARISON (CLEANER) ----------
# Plot a reproducible random subset so the WebGL scatter stays responsive.
sample_df = df.sample(n=min(SAMPLE_N, len(df)), seed=42).to_pandas()

AXIS_RANGE = [-180, 180]
# One entry per panel: (x column, colour key, trace name, x-axis title).
scatter_panels = [
    ('arm_angle_pred', 'model', 'Model', 'Predicted'),
    ('arm_angle_calc', 'calc', 'Calculated', 'Calculated'),
]

fig_scatter = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=[
        f"<b>ML Model</b> (R² = {model_metrics['r2']:.3f})",
        f"<b>Geometric Calc</b> (R² = {calc_metrics['r2']:.3f})",
    ],
    horizontal_spacing=0.1,
)

# Points first (model panel, then calculated panel)...
for panel, panel_spec in enumerate(scatter_panels, start=1):
    x_col, color_key, trace_name = panel_spec[0], panel_spec[1], panel_spec[2]
    fig_scatter.add_trace(
        go.Scattergl(
            x=sample_df[x_col],
            y=sample_df['arm_angle'],
            mode='markers',
            marker=dict(color=COLORS[color_key], size=3, opacity=0.4),
            name=trace_name,
            showlegend=False,
        ),
        row=1,
        col=panel,
    )

# ...then a dashed y = x reference line in each panel.
for col in [1, 2]:
    fig_scatter.add_trace(
        go.Scatter(
            x=list(AXIS_RANGE),
            y=list(AXIS_RANGE),
            mode='lines',
            line=dict(color='#6b7280', dash='dash', width=1.5),
            showlegend=False,
        ),
        row=1,
        col=col,
    )

for panel, panel_spec in enumerate(scatter_panels, start=1):
    fig_scatter.update_xaxes(title_text=panel_spec[3], row=1, col=panel, range=AXIS_RANGE)
fig_scatter.update_yaxes(title_text="Actual Arm Angle", row=1, col=1, range=AXIS_RANGE)
fig_scatter.update_yaxes(title_text="", row=1, col=2, range=AXIS_RANGE)

fig_scatter.update_layout(
    title=dict(text="<b>Prediction Accuracy</b>", font=dict(size=20), x=0.5),
    height=500,
    width=1000,
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(family="Inter, sans-serif"),
)
fig_scatter.update_yaxes(showgrid=True, gridcolor=COLORS['grid'])
fig_scatter.update_xaxes(showgrid=True, gridcolor=COLORS['grid'])
fig_scatter.show()

# ---------- FIGURE 3: ERROR DISTRIBUTION ----------
ERR_RANGE = (-30, 30)
# Both histograms share one explicit bin grid (80 bins across ERR_RANGE). Overlaid plotly
# histograms are not forced into a common bingroup, so independent `nbinsx` settings gave
# each trace its own bin edges and the overlay compared counts over different bin widths.
# Errors outside ERR_RANGE fall outside the bins (the axis is clipped to this range anyway).
shared_xbins = dict(start=ERR_RANGE[0], end=ERR_RANGE[1], size=(ERR_RANGE[1] - ERR_RANGE[0]) / 80)

fig_error = go.Figure()

fig_error.add_trace(go.Histogram(
    x=sample_df['err_pred'],
    name='Model Error',
    marker_color=COLORS['model'],
    opacity=0.7,
    xbins=shared_xbins,
))

fig_error.add_trace(go.Histogram(
    x=sample_df['err_calc'],
    name='Calculated Error',
    marker_color=COLORS['calc'],
    opacity=0.7,
    xbins=shared_xbins,
))

fig_error.add_vline(x=0, line_dash="dash", line_color="#6b7280", line_width=2)

# The subtitle reports whichever estimator's errors are actually tighter around zero
# (root-mean-square error over the plotted sample).
rms_pred = float((sample_df['err_pred'] ** 2).mean() ** 0.5)
rms_calc = float((sample_df['err_calc'] ** 2).mean() ** 0.5)
if rms_pred < rms_calc:
    error_subtitle = "Model errors cluster tighter around zero"
elif rms_calc < rms_pred:
    error_subtitle = "Calculated errors cluster tighter around zero"
else:
    error_subtitle = "Model and calculated errors have a similar spread around zero"

fig_error.update_layout(
    title=dict(
        text=f"<b>Error Distribution</b><br><sup>{error_subtitle}</sup>",
        font=dict(size=20),
        x=0.5
    ),
    xaxis_title="Error (Predicted - Actual, degrees)",
    yaxis_title="Count",
    barmode='overlay',
    height=450,
    width=900,
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(family="Inter, sans-serif"),
    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='center', x=0.5),
)
fig_error.update_xaxes(showgrid=True, gridcolor=COLORS['grid'], range=list(ERR_RANGE))
fig_error.update_yaxes(showgrid=True, gridcolor=COLORS['grid'])
fig_error.show()

# ---------- FIGURE 4: IMPROVEMENT BY HANDEDNESS ----------
# Per-hand MAE for both estimators. improvement_pct is the relative MAE reduction of the
# model versus the formula (negative when the model is worse). Rows without p_throws are excluded.
hand_metrics = (
    df.filter(pl.col("p_throws").is_not_null())
    .group_by("p_throws")
    .agg(
        pl.col("ae_pred").mean().alias("MAE_model"),
        pl.col("ae_calc").mean().alias("MAE_calc"),
        pl.len().alias("n")
    )
    .with_columns(
        ((pl.col("MAE_calc") - pl.col("MAE_model")) / pl.col("MAE_calc") * 100).alias("improvement_pct")
    )
    .sort("p_throws")
    .to_pandas()
)

fig_hand = go.Figure()

fig_hand.add_trace(go.Bar(
    x=hand_metrics['p_throws'],
    y=hand_metrics['MAE_model'],
    name='Model',
    marker_color=COLORS['model'],
    text=[f"{v:.2f}°" for v in hand_metrics['MAE_model']],
    textposition='outside'
))

fig_hand.add_trace(go.Bar(
    x=hand_metrics['p_throws'],
    y=hand_metrics['MAE_calc'],
    name='Calculated',
    marker_color=COLORS['calc'],
    text=[f"{v:.2f}°" for v in hand_metrics['MAE_calc']],
    textposition='outside'
))

# Label each group with the model's relative change, worded by sign (and coloured by
# which estimator comes out ahead) instead of always saying "better".
for row in hand_metrics.itertuples(index=False):
    pct = row.improvement_pct
    model_ahead = pct >= 0
    fig_hand.add_annotation(
        x=row.p_throws,
        y=max(row.MAE_model, row.MAE_calc) + 0.8,  # sit just above the taller bar
        text=f"<b>{abs(pct):.1f}% {'better' if model_ahead else 'worse'}</b>",
        showarrow=False,
        font=dict(size=12, color=COLORS['model'] if model_ahead else COLORS['calc'])
    )

fig_hand.update_layout(
    title=dict(text="<b>MAE by Handedness</b>", font=dict(size=20), x=0.5),
    xaxis_title="Pitcher Handedness",
    yaxis_title="Mean Absolute Error (°)",
    barmode='group',
    height=450,
    width=600,
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(family="Inter, sans-serif"),
    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='center', x=0.5),
)
fig_hand.update_yaxes(showgrid=True, gridcolor=COLORS['grid'])
fig_hand.show()

# ---------- SUMMARY PRINT ----------
# (label, metrics key). R² is higher-is-better; the error metrics are closer-to-zero-is-better.
summary_rows = [
    ('R²', 'r2'),
    ('MAE (°)', 'mae'),
    ('RMSE (°)', 'rmse'),
    ('Bias (°)', 'bias'),
]

print("\n" + "="*60)
print("SUMMARY: ML Model vs Geometric Calculation")
print("="*60)
print(f"\n{'Metric':<15} {'Model':>12} {'Calculated':>12} {'Winner':>10}")
print("-"*50)
for label, key in summary_rows:
    m_val, c_val = model_metrics[key], calc_metrics[key]
    if key == 'r2':
        model_better, calc_better = m_val > c_val, c_val > m_val
    else:
        model_better, calc_better = abs(m_val) < abs(c_val), abs(c_val) < abs(m_val)
    # The winner is derived from the numbers rather than hard-coded.
    winner = 'Model ✓' if model_better else ('Calc ✓' if calc_better else 'Tie')
    print(f"{label:<15} {m_val:>12.4f} {c_val:>12.4f} {winner:>10}")

if calc_metrics['mae']:
    improvement = (calc_metrics['mae'] - model_metrics['mae']) / calc_metrics['mae'] * 100
    direction = "reduces" if improvement >= 0 else "increases"
    print(f"\n→ Model {direction} MAE by {abs(improvement):.1f}%")
else:
    # A calculated MAE of 0 leaves the relative change undefined.
    improvement = float('nan')
    print("\n→ Calculated MAE is 0; relative MAE change is undefined")
print("="*60)


SUMMARY: ML Model vs Geometric Calculation

Metric                 Model   Calculated     Winner
--------------------------------------------------
R²                    0.9862       0.9504    Model ✓
MAE (°)               4.3790       8.3470    Model ✓
RMSE (°)              5.6636      10.7419    Model ✓
Bias (°)              0.0014      -0.7227    Model ✓

→ Model reduces MAE by 47.5%
