# In [28]:
import pandas as pd
import numpy as np
import pybaseball
from datetime import datetime, timedelta
import warnings
from typing import Dict, List, Tuple, Optional
import pickle
import os

# Suppress every warning for the rest of the process (no category, message or
# module filter is given). NOTE(review): this also hides library warnings that
# may point at real data problems; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

class StatcastProcessor:
    """
    Comprehensive Statcast data processor for pitch-level run value analysis.
    Implements both MVP (PA-level) and Pro (pitch-level) run expectancy methods
    with proper state transitions.
    """
    
    def __init__(self, use_cache: bool = True):
        self.use_cache = use_cache
        self.re24_table = self._initialize_standard_re24_table()  # Standard RE24 table
        self.re288_table = self._initialize_standard_re288_table()  # Standard RE288 table
        self.data = None
        
        # State transition mappings
        self.outcome_transitions = self._initialize_outcome_transitions()
        
    def fetch_statcast_data(self, 
                           start_date: str, 
                           end_date: str, 
                           pitcher_ids: Optional[List[int]] = None,
                           sample_pitchers: bool = True) -> pd.DataFrame:
        """
        Fetch Statcast data for specified date range and pitchers.
        
        Args:
            start_date: 'YYYY-MM-DD' format
            end_date: 'YYYY-MM-DD' format  
            pitcher_ids: List of pitcher IDs, if None will sample some pitchers
            sample_pitchers: Whether to sample a few pitchers for fast iteration
        """
        print(f"Fetching Statcast data from {start_date} to {end_date}...")
        
        if pitcher_ids is None and sample_pitchers:
            # Sample some well-known pitchers for testing (1 RHP, 1 LHP)
            pitcher_ids = [676979, 694973]
            
        if pitcher_ids:
            # Fetch specific pitchers
            all_data = []
            for pitcher_id in pitcher_ids:
                print(f"Fetching data for pitcher {pitcher_id}...")
                try:
                    pitcher_data = pybaseball.statcast_pitcher(start_date, end_date, pitcher_id)
                    if not pitcher_data.empty:
                        all_data.append(pitcher_data)
                except Exception as e:
                    print(f"Error fetching pitcher {pitcher_id}: {e}")
                    continue
            
            if all_data:
                data = pd.concat(all_data, ignore_index=True)
            else:
                print("No data retrieved for specified pitchers")
                return pd.DataFrame()
        else:
            # Fetch league-wide data (use with caution - very large)
            data = pybaseball.statcast(start_date, end_date)
            
        print(f"Retrieved {len(data)} pitches")
        return data
    
    def clean_and_subset_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Clean and subset raw Statcast data, keeping only relevant columns.
        """
        print("Cleaning and subsetting data...")
        
        # Define columns to keep - expanded to include state transition fields
        required_cols = [
            'game_date', 'pitcher', 'batter', 'p_throws', 'stand', 'pitch_type',
            'release_speed', 'spin_rate', 'plate_x', 'plate_z', 'balls', 'strikes',
            'outs_when_up', 'on_1b', 'on_2b', 'on_3b', 'inning', 'events', 
            'description', 'home_score', 'away_score', 'at_bat_number', 'pitch_number',
            'game_pk', 'game_year', 'inning_topbot', 'bat_score', 'fld_score',
            'post_away_score', 'post_home_score', 'post_bat_score', 'post_fld_score'
        ]
        
        # Keep only columns that exist in the data
        available_cols = [col for col in required_cols if col in data.columns]
        data = data[available_cols].copy()
        
        # Filter out unwanted pitch types
        exclude_pitches = ['PO', 'IN']  # Pitchouts, Intentional balls
        if 'description' in data.columns:
            data = data[~data['description'].isin(exclude_pitches)].copy()
            
        # Remove obvious tracking errors
        if 'plate_x' in data.columns and 'plate_z' in data.columns:
            data = data[
                (data['plate_x'].between(-3, 3)) & 
                (data['plate_z'].between(-1, 5))
            ].copy()
            
        # Convert game_date to datetime
        if 'game_date' in data.columns:
            data['game_date'] = pd.to_datetime(data['game_date'])
            
        # Fill missing base runners with 0 (no runner)
        base_cols = ['on_1b', 'on_2b', 'on_3b']
        for col in base_cols:
            if col in data.columns:
                data[col] = data[col].fillna(0)
        
        # Fill missing score fields with 0
        score_cols = ['home_score', 'away_score', 'bat_score', 'fld_score',
                     'post_home_score', 'post_away_score', 'post_bat_score', 'post_fld_score']
        for col in score_cols:
            if col in data.columns:
                data[col] = data[col].fillna(0)
        
        print(f"Cleaned data: {len(data)} pitches remaining")
        return data.sort_values(['game_pk', 'at_bat_number', 'pitch_number']).reset_index(drop=True)
    
    def _initialize_standard_re24_table(self) -> Dict[int, float]:
        """
        Initialize the standard RE24 table provided by the user.
        Run environment set at 4.15 runs per game.
        
        States are encoded as: bases * 3 + outs
        where bases = on_1b*4 + on_2b*2 + on_3b*1
        """
        # Standard RE24 matrix from user
        re_matrix = {
            # (runners, outs): expected_runs
            # Empty bases
            (0, 0): 0.461, (0, 1): 0.243, (0, 2): 0.095,
            # Runner on 1st
            (1, 0): 0.831, (1, 1): 0.489, (1, 2): 0.214,
            # Runner on 2nd  
            (2, 0): 1.068, (2, 1): 0.644, (2, 2): 0.305,
            # Runners on 1st and 2nd
            (3, 0): 1.373, (3, 1): 0.908, (3, 2): 0.343,
            # Runner on 3rd
            (4, 0): 1.426, (4, 1): 0.865, (4, 2): 0.413,
            # Runners on 1st and 3rd
            (5, 0): 1.798, (5, 1): 1.140, (5, 2): 0.471,
            # Runners on 2nd and 3rd
            (6, 0): 1.920, (6, 1): 1.352, (6, 2): 0.570,
            # Bases loaded
            (7, 0): 2.282, (7, 1): 1.520, (7, 2): 0.736
        }
        
        # Convert to state IDs (0-23)
        re24_table = {}
        for (runners, outs), re_value in re_matrix.items():
            state_id = runners * 3 + outs
            re24_table[state_id] = re_value
            
        return re24_table
    
    def _initialize_standard_re288_table(self) -> Dict[int, float]:
        """
        Initialize the standard RE288 table provided by the user.
        This table includes all 288 base-out-count combinations.
        
        States are encoded as: (base_state * 3 + outs) * 12 + (balls * 3 + strikes)
        where base_state = on_1b*1 + on_2b*2 + on_3b*4
        """
        # Standard RE288 matrix from user - organized by base-out state and count
        # Format: [base_state][outs][count] where count is ordered as shown in table
        re288_data = {
            # Empty bases (---)
            (0, 0): [0.51, 0.46, 0.41, 0.56, 0.5, 0.44, 0.62, 0.55, 0.47, 0.75, 0.68, 0.57],  # 0 outs
            (0, 1): [0.27, 0.24, 0.2, 0.3, 0.27, 0.22, 0.34, 0.3, 0.25, 0.42, 0.37, 0.32],   # 1 out
            (0, 2): [0.1, 0.09, 0.06, 0.12, 0.1, 0.07, 0.14, 0.12, 0.09, 0.17, 0.15, 0.13],  # 2 outs
            
            # Runner on 1st (1--)
            (1, 0): [0.89, 0.82, 0.75, 0.98, 0.87, 0.81, 1.09, 0.97, 0.87, 1.25, 1.15, 1.07],
            (1, 1): [0.54, 0.49, 0.42, 0.57, 0.52, 0.44, 0.62, 0.58, 0.5, 0.77, 0.66, 0.62],
            (1, 2): [0.22, 0.18, 0.13, 0.26, 0.22, 0.17, 0.29, 0.24, 0.19, 0.36, 0.32, 0.26],
            
            # Runner on 2nd (-2-)
            (2, 0): [1.14, 1.05, 0.95, 1.19, 1.08, 0.96, 1.27, 1.21, 1.08, 1.41, 1.34, 1.17],
            (2, 1): [0.69, 0.65, 0.59, 0.72, 0.66, 0.59, 0.78, 0.71, 0.68, 0.92, 0.87, 0.71],
            (2, 2): [0.33, 0.29, 0.2, 0.35, 0.3, 0.22, 0.38, 0.31, 0.25, 0.51, 0.38, 0.3],
            
            # Runners on 1st and 2nd (12-)
            (3, 0): [1.47, 1.4, 1.31, 1.53, 1.45, 1.33, 1.67, 1.5, 1.46, 2.0, 1.77, 1.62],
            (3, 1): [0.95, 0.88, 0.79, 1.04, 0.96, 0.81, 1.06, 1.01, 0.87, 1.16, 1.19, 1.04],
            (3, 2): [0.44, 0.38, 0.28, 0.46, 0.41, 0.33, 0.53, 0.46, 0.39, 0.67, 0.58, 0.5],
            
            # Runner on 3rd (--3)
            (4, 0): [1.4, 1.28, 1.13, 1.53, 1.49, 1.47, 1.48, 1.46, 1.34, 1.45, 1.71, 1.49],
            (4, 1): [0.89, 0.91, 0.79, 1.04, 0.98, 0.82, 1.17, 1.03, 0.87, 1.36, 1.19, 0.92],
            (4, 2): [0.37, 0.31, 0.24, 0.4, 0.34, 0.25, 0.44, 0.41, 0.3, 0.46, 0.41, 0.39],
            
            # Runners on 1st and 3rd (1-3)
            (5, 0): [1.76, 1.71, 1.6, 1.78, 1.66, 1.63, 1.94, 1.81, 1.64, 2.09, 2.1, 1.85],
            (5, 1): [1.2, 1.11, 1.01, 1.25, 1.15, 1.12, 1.26, 1.22, 1.16, 1.24, 1.36, 1.25],
            (5, 2): [0.5, 0.43, 0.35, 0.55, 0.47, 0.37, 0.57, 0.54, 0.4, 0.65, 0.66, 0.46],
            
            # Runners on 2nd and 3rd (-23)
            (6, 0): [1.93, 1.89, 1.62, 2.03, 1.88, 1.73, 2.3, 2.02, 1.78, 2.14, 2.06, 1.65],
            (6, 1): [1.37, 1.26, 1.06, 1.42, 1.31, 1.16, 1.45, 1.4, 1.26, 1.53, 1.41, 1.34],
            (6, 2): [0.58, 0.5, 0.34, 0.59, 0.56, 0.41, 0.61, 0.65, 0.42, 0.85, 0.68, 0.56],
            
            # Bases loaded (123)
            (7, 0): [2.28, 2.24, 2.28, 2.34, 2.26, 2.21, 2.35, 2.26, 2.15, 2.81, 2.55, 2.49],
            (7, 1): [1.51, 1.36, 1.18, 1.6, 1.49, 1.3, 1.79, 1.59, 1.5, 2.15, 1.88, 1.53],
            (7, 2): [0.74, 0.59, 0.41, 0.88, 0.68, 0.45, 1.11, 0.86, 0.65, 1.35, 1.09, 1.04]
        }
        
        # Convert to state IDs (0-287)
        re288_table = {}
        count_combinations = [
            (0,0), (0,1), (0,2),  # 0-0, 0-1, 0-2
            (1,0), (1,1), (1,2),  # 1-0, 1-1, 1-2
            (2,0), (2,1), (2,2),  # 2-0, 2-1, 2-2
            (3,0), (3,1), (3,2)   # 3-0, 3-1, 3-2
        ]
        
        for base_state in range(8):  # 0-7 for all base combinations
            for outs in range(3):   # 0-2 outs
                if (base_state, outs) in re288_data:
                    re_values = re288_data[(base_state, outs)]
                    for count_idx, (balls, strikes) in enumerate(count_combinations):
                        # Calculate state ID: (base_state * 3 + outs) * 12 + (balls * 3 + strikes)
                        base_out_state = base_state * 3 + outs
                        count_state = balls * 3 + strikes
                        state_id = base_out_state * 12 + count_state
                        
                        if count_idx < len(re_values):
                            re288_table[state_id] = re_values[count_idx]
                        else:
                            # Fallback for missing values
                            re288_table[state_id] = 0.5
                else:
                    # Fallback for missing base-out combinations
                    for count_idx in range(12):
                        base_out_state = base_state * 3 + outs
                        state_id = base_out_state * 12 + count_idx
                        re288_table[state_id] = 0.5
        
        return re288_table
    
    def _initialize_outcome_transitions(self) -> Dict[str, Dict]:
        """
        Initialize outcome-based state transitions.
        Maps play outcomes to base advancement and run scoring.
        """
        return {
            # Hits
            'single': {'base_advance': [1, 1, 1], 'runs_scored': [0, 0, 1]},  # 1B advances 1, 2B advances 1, 3B scores
            'double': {'base_advance': [2, 2, 2], 'runs_scored': [0, 1, 1]},  # 1B advances 2, 2B scores, 3B scores
            'triple': {'base_advance': [3, 3, 3], 'runs_scored': [1, 1, 1]},  # All score
            'home_run': {'base_advance': [4, 4, 4], 'runs_scored': [1, 1, 1]},  # All score + batter
            
            # Outs
            'strikeout': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'field_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'pop_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'fly_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'line_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'ground_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'force_out': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            
            # Walks and HBP
            'walk': {'base_advance': [1, 0, 0], 'runs_scored': [0, 0, 0]},  # Force advance only
            'hit_by_pitch': {'base_advance': [1, 0, 0], 'runs_scored': [0, 0, 0]},
            
            # Sacrifice plays
            'sac_fly': {'base_advance': [0, 1, 1], 'runs_scored': [0, 0, 1]},  # 3B scores
            'sac_bunt': {'base_advance': [1, 1, 0], 'runs_scored': [0, 0, 0]},  # Runners advance
            
            # Errors and other plays
            'field_error': {'base_advance': [1, 1, 1], 'runs_scored': [0, 0, 1]},  # Like single
            'catcher_interf': {'base_advance': [1, 0, 0], 'runs_scored': [0, 0, 0]},  # Like walk
            
            # Double plays (simplified)
            'double_play': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
            'grounded_into_double_play': {'base_advance': [0, 0, 0], 'runs_scored': [0, 0, 0]},
        }
    
    def _get_runners_state(self, row) -> int:
        """Convert base runners to runners state (0-7)"""
        return int(bool(row.get('on_1b', 0))) * 1 + \
               int(bool(row.get('on_2b', 0))) * 2 + \
               int(bool(row.get('on_3b', 0))) * 4

    def build_base_out_state(self, row) -> int:
        """Convert base-out situation to state ID (0-23 for RE24)"""
        runners = self._get_runners_state(row)
        outs = int(row.get('outs_when_up', 0))
        return runners * 3 + outs
    
    def build_base_out_count_state(self, row) -> int:
        """Convert base-out-count situation to state ID (0-287 for RE288)"""
        base_out_state = self.build_base_out_state(row)
        balls = int(row.get('balls', 0))
        strikes = int(row.get('strikes', 0))
        count_state = balls * 3 + strikes  # 0-11 for possible counts
        return base_out_state * 12 + count_state
    
    def _calculate_post_pa_state(self, pa_data: pd.DataFrame) -> Tuple[int, int]:
        """
        Calculate the post-PA base-out state and runs scored.
        
        Args:
            pa_data: DataFrame containing all pitches in the plate appearance
            
        Returns:
            Tuple of (post_pa_base_out_state, runs_scored_during_pa)
        """
        if pa_data.empty:
            return 0, 0
            
        last_pitch = pa_data.iloc[-1]
        event = last_pitch.get('events')
        
        if pd.isna(event) or event == '':
            # PA didn't end, return original state
            return self.build_base_out_state(pa_data.iloc[0]), 0
        
        # Get starting state
        first_pitch = pa_data.iloc[0]
        start_runners = [
            bool(first_pitch.get('on_1b', 0)),
            bool(first_pitch.get('on_2b', 0)), 
            bool(first_pitch.get('on_3b', 0))
        ]
        start_outs = int(first_pitch.get('outs_when_up', 0))
        
        # Try to use actual post-PA data from Statcast if available
        # Look for the next PA to see the resulting state
        next_pa_data = None
        current_game = first_pitch['game_pk']
        current_at_bat = first_pitch['at_bat_number']
        
        # For now, implement rule-based transitions
        # Initialize post-PA state
        post_runners = [False, False, False]  # 1B, 2B, 3B
        runs_scored = 0
        post_outs = start_outs
        
        # Apply outcome transition
        event_lower = event.lower() if isinstance(event, str) else ''
        
        # Handle hits first
        if event_lower == 'single':
            # Runners advance 1 base, 3B scores
            if start_runners[2]:  # Runner on 3B scores
                runs_scored += 1
            if start_runners[1]:  # Runner on 2B goes to 3B
                post_runners[2] = True
            if start_runners[0]:  # Runner on 1B goes to 2B
                post_runners[1] = True
            post_runners[0] = True  # Batter to 1B
            
        elif event_lower == 'double':
            # Runners advance 2+ bases, 2B and 3B score
            if start_runners[2]:  # Runner on 3B scores
                runs_scored += 1
            if start_runners[1]:  # Runner on 2B scores
                runs_scored += 1
            if start_runners[0]:  # Runner on 1B goes to 3B
                post_runners[2] = True
            post_runners[1] = True  # Batter to 2B
            
        elif event_lower == 'triple':
            # All runners score
            runs_scored += sum(start_runners)
            post_runners[2] = True  # Batter to 3B
            
        elif event_lower == 'home_run':
            # Everyone scores including batter
            runs_scored = sum(start_runners) + 1
            # No runners left on base
            
        elif event_lower in ['walk', 'hit_by_pitch']:
            # Force advancement only
            if start_runners[0]:  # Force advance from 1B
                if start_runners[1]:  # Force advance from 2B
                    if start_runners[2]:  # Force score from 3B
                        runs_scored += 1
                    else:
                        post_runners[2] = True
                else:
                    post_runners[1] = True
            # Copy non-forced runners
            if not start_runners[0]:
                if start_runners[1]:
                    post_runners[1] = True
                if start_runners[2]:
                    post_runners[2] = True
            post_runners[0] = True  # Batter to 1B
            
        elif event_lower == 'sac_fly':
            post_outs += 1
            # Usually scores runner from 3B
            if start_runners[2]:
                runs_scored += 1
            # Other runners stay (simplified)
            if start_runners[0]:
                post_runners[0] = True
            if start_runners[1]:
                post_runners[1] = True
                
        elif event_lower == 'sac_bunt':
            post_outs += 1
            # Advance runners (simplified - usually 1B to 2B)
            if start_runners[0]:
                post_runners[1] = True
            if start_runners[1]:
                post_runners[2] = True
            if start_runners[2]:
                runs_scored += 1
                
        elif event_lower == 'field_error':
            # Treat like a single
            if start_runners[2]:  
                runs_scored += 1
            if start_runners[1]:  
                post_runners[2] = True
            if start_runners[0]:  
                post_runners[1] = True
            post_runners[0] = True  
            
        elif event_lower in ['double_play', 'grounded_into_double_play']:
            post_outs += 2
            # Usually eliminates batter and lead runner
            # Other runners may advance (simplified)
            if start_runners[2]:
                post_runners[2] = True
            if start_runners[1] and not start_runners[0]:
                post_runners[1] = True
                
        else:
            # All other outs (strikeout, field_out, fly_out, etc.)
            post_outs += 1
            # Runners stay put (no advancement on regular outs)
            post_runners = start_runners.copy()
        
        # Handle 3rd out - inning ends
        if post_outs >= 3:
            # Inning over - no runners, but runs that scored during the play count
            post_runners = [False, False, False]
            post_outs = 0  # Next inning starts with 0 outs
            # Note: The RE24 table has states 21-23 for 3 outs, with RE=0
            # We'll use state 21 (empty bases, 3 outs) to represent inning end
            post_state = 21  # Empty bases, 3 outs (RE = 0.095)
        else:
            # Convert to state ID
            runners_state = (int(post_runners[0]) * 1 + 
                            int(post_runners[1]) * 2 + 
                            int(post_runners[2]) * 4)
            post_state = runners_state * 3 + post_outs
        
        # Try to get actual runs scored from score differential if available
        if ('post_bat_score' in last_pitch.index and 'bat_score' in last_pitch.index and 
            not pd.isna(last_pitch['post_bat_score']) and not pd.isna(last_pitch['bat_score'])):
            actual_runs = int(last_pitch['post_bat_score']) - int(last_pitch['bat_score'])
            if actual_runs >= 0:  # Use actual if it makes sense
                runs_scored = actual_runs
        
        return post_state, runs_scored
    
    def _calculate_post_pitch_state(self, current_row: pd.Series, next_row: pd.Series = None) -> Tuple[int, int]:
        """
        Calculate the post-pitch base-out-count state and runs scored.
        
        Args:
            current_row: Current pitch data
            next_row: Next pitch data (if available)
            
        Returns:
            Tuple of (post_pitch_base_out_count_state, runs_on_pitch)
        """
        description = current_row.get('description', '')
        
        # Start with current state
        balls = int(current_row.get('balls', 0))
        strikes = int(current_row.get('strikes', 0))
        outs = int(current_row.get('outs_when_up', 0))
        runners_state = self._get_runners_state(current_row)
        runs_on_pitch = 0
        
        # Apply pitch outcome
        if description in ['ball', 'blocked_ball', 'pitchout']:
            balls = min(balls + 1, 3)
            if balls == 4:  # Walk
                balls, strikes = 0, 0
                # Handle walk logic (simplified - would need force advancement logic)
                if runners_state & 1:  # Runner on 1B, force advancement
                    runners_state = min(runners_state + 1, 7)
                runners_state |= 1  # Batter to 1B
                
        elif description in ['called_strike', 'swinging_strike', 'foul_tip']:
            strikes = min(strikes + 1, 2)
            if strikes == 3:  # Strikeout
                outs += 1
                balls, strikes = 0, 0
                
        elif description in ['foul']:
            if strikes < 2:
                strikes += 1
                
        elif description in ['hit_into_play']:
            # Would need event outcome to determine exact result
            # For now, assume it results in an out (most common)
            outs += 1
            balls, strikes = 0, 0
            
        elif description == 'hit_by_pitch':
            balls, strikes = 0, 0
            # Similar to walk
            if runners_state & 1:
                runners_state = min(runners_state + 1, 7)
            runners_state |= 1
        
        # If we have next row data, use it for more accurate state
        if next_row is not None and not pd.isna(next_row.get('balls')):
            balls = int(next_row.get('balls', 0))
            strikes = int(next_row.get('strikes', 0))
            outs = int(next_row.get('outs_when_up', 0))
            runners_state = self._get_runners_state(next_row)
            
            # Check for runs scored between pitches
            if 'bat_score' in current_row.index and 'bat_score' in next_row.index:
                if not pd.isna(current_row['bat_score']) and not pd.isna(next_row['bat_score']):
                    runs_on_pitch = max(0, int(next_row['bat_score']) - int(current_row['bat_score']))
        
        # Ensure values are within bounds
        outs = min(outs, 3)
        
        # Convert to RE288 state ID
        base_out_state = runners_state * 3 + outs
        count_state = balls * 3 + strikes
        post_state = base_out_state * 12 + count_state
        
        return post_state, runs_on_pitch
    
    def calculate_re24_table(self, data: pd.DataFrame) -> Dict[int, float]:
        """
        DEPRECATED: Use standard RE24 table instead.
        This method is kept for backward compatibility but is no longer used.
        The standard RE24 table is initialized in __init__.
        """
        print("Using standard RE24 table instead of calculating from data...")
        return self.re24_table
    
    def calculate_re288_table(self, data: pd.DataFrame) -> Dict[int, float]:
        """
        DEPRECATED: Use standard RE288 table instead.
        This method is kept for backward compatibility but is no longer used.
        The standard RE288 table is initialized in __init__.
        """
        print("Using standard RE288 table instead of calculating from data...")
        return self.re288_table
    
    def calculate_mvp_run_values(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Calculate MVP (PA-terminal) run values using standard RE24 table with proper state transitions.
        Formula: RE24 = (End_RE + Runs_Scored) - Start_RE
        PITCHER PERSPECTIVE: Inverted so positive values are good for the pitcher.
        """
        print("Calculating MVP run values with state transitions (pitcher perspective)...")
            
        data_copy = data.copy()
        data_copy['run_value'] = 0.0
        data_copy['before_pa_re'] = np.nan
        data_copy['after_pa_re'] = np.nan
        data_copy['runs_during_pa'] = 0
        
        # Group by game and process each plate appearance
        processed_count = 0
        total_pas = 0
        
        for game_id in data_copy['game_pk'].unique():
            game_data = data_copy[data_copy['game_pk'] == game_id].copy()
            
            for at_bat in game_data['at_bat_number'].unique():
                pa_data = game_data[game_data['at_bat_number'] == at_bat].copy()
                if pa_data.empty:
                    continue
                    
                total_pas += 1
                pa_indices = pa_data.index
                first_pitch_idx = pa_indices[0]
                last_pitch_idx = pa_indices[-1]
                
                # Get before-PA state from first pitch
                first_pitch = pa_data.iloc[0]
                before_state = self.build_base_out_state(first_pitch)
                before_re = self.re24_table.get(before_state, 0.5)
                
                # Calculate after-PA state and runs scored using state transitions
                after_state, runs_during_pa = self._calculate_post_pa_state(pa_data)
                after_re = self.re24_table.get(after_state, 0.5)
                
                # Special handling for inning-ending scenarios
                after_outs = after_state % 3
                if after_outs == 0 and after_state >= 21:  # 3+ outs reached, inning ends
                    after_re = 0.0  # No more runs expected this inning
                elif after_state == 21:  # Specifically the "3 outs, empty bases" state
                    after_re = 0.0  # Inning over
                
                # Calculate PA run value using RE24 formula
                batter_run_value = (after_re + runs_during_pa) - before_re
                
                # INVERT for pitcher perspective (positive = good for pitcher)
                pa_run_value = -batter_run_value
                
                # Assign values to all pitches in PA, but run_value only to terminal pitch
                data_copy.loc[pa_indices, 'before_pa_re'] = before_re
                data_copy.loc[pa_indices, 'after_pa_re'] = after_re
                data_copy.loc[pa_indices, 'runs_during_pa'] = runs_during_pa
                data_copy.loc[last_pitch_idx, 'run_value'] = pa_run_value
                
                processed_count += 1
                
                # Debug: Print some examples
                if processed_count <= 10:
                    event = pa_data.iloc[-1].get('events', 'Unknown')
                    print(f"PA {processed_count}: {event}, Before: {before_re:.3f}, After: {after_re:.3f}, Runs: {runs_during_pa}, Pitcher RV: {pa_run_value:.3f}")
                
        print(f"Processed {processed_count} plate appearances out of {total_pas} total")
        return data_copy
    
    def calculate_pro_run_values(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Calculate Pro (pitch-level) run values using standard RE288 table with proper state transitions.
        Formula: Pitch Run Value = (PostPitch_RE + Runs_On_Pitch) - PrePitch_RE
        PITCHER PERSPECTIVE: Inverted so positive values are good for the pitcher.
        """
        print("Calculating Pro run values with state transitions (pitcher perspective)...")
            
        data_copy = data.copy()
        data_copy['run_value'] = 0.0
        data_copy['pre_pitch_re'] = np.nan
        data_copy['post_pitch_re'] = np.nan
        data_copy['runs_on_pitch'] = 0
        
        # Sort to ensure proper sequencing
        data_copy = data_copy.sort_values(['game_pk', 'at_bat_number', 'pitch_number']).reset_index(drop=True)
        
        processed_count = 0
        
        # Process each pitch
        for idx in data_copy.index:
            current_row = data_copy.loc[idx]
            
            # Get next row for state transition (if available and same PA)
            next_row = None
            if idx + 1 < len(data_copy):
                next_candidate = data_copy.loc[idx + 1]
                if (current_row['game_pk'] == next_candidate['game_pk'] and 
                    current_row['at_bat_number'] == next_candidate['at_bat_number']):
                    next_row = next_candidate
            
            # Pre-pitch state
            pre_state = self.build_base_out_count_state(current_row)
            pre_re = self.re288_table.get(pre_state, 0.5)
            
            # Post-pitch state with proper transitions
            post_state, runs_on_pitch = self._calculate_post_pitch_state(current_row, next_row)
            post_re = self.re288_table.get(post_state, 0.5)
            
            # Handle inning ending scenarios
            post_outs = (post_state // 12) % 3
            if post_outs >= 3:
                post_re = 0.0  # Inning over
            
            # Calculate pitch run value using RE288 formula
            batter_run_value = (post_re + runs_on_pitch) - pre_re
            
            # INVERT for pitcher perspective (positive = good for pitcher)
            pitch_run_value = -batter_run_value
            
            data_copy.loc[idx, 'run_value'] = pitch_run_value
            data_copy.loc[idx, 'pre_pitch_re'] = pre_re
            data_copy.loc[idx, 'post_pitch_re'] = post_re
            data_copy.loc[idx, 'runs_on_pitch'] = runs_on_pitch
            
            processed_count += 1
        
        print(f"Processed {processed_count} pitches with run values")
        return data_copy
    
    def engineer_features(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Engineer features for modeling.

        Adds the following columns (location / velocity / score features only
        when their source columns exist):
          - count: "balls-strikes" string, e.g. "1-2".
          - plate_x_bin / plate_z_bin: quantile bins (up to 7 horizontal and 5
            vertical; fewer if quantile edges coincide) over non-missing
            locations, NaN where the coordinate is missing.
          - loc_bin: "x|z" label built from the two bins (e.g. "2|1"), NaN when
            either bin is missing.
          - prev_pitch_type: pitch type of the previous pitch in the same game,
            pitcher and at-bat (NaN on the first pitch of each at-bat).
          - release_speed_z: release speed standardized over ``data``
            (0 when the speed has no spread).
          - score_diff: home_score - away_score.
          - base_out_state / base_out_count_state: ids from
            ``build_base_out_state`` / ``build_base_out_count_state``.
          - leverage_index: RE24 value of the base-out state divided by 0.5.
          - runners_on: number of non-empty on_1b/on_2b/on_3b entries;
            scoring_position: 1 if on_2b or on_3b is non-empty, else 0.

        Quantile edges and the velocity mean/std are computed from all of
        ``data``; running this before a train/test split lets test rows
        influence them.

        Args:
            data: Pitch-level DataFrame (not modified).

        Returns:
            A new DataFrame sorted by game_pk, pitcher, at_bat_number, pitch_number.
        """
        print("Engineering features...")

        data_copy = data.copy()

        # Count feature
        data_copy['count'] = data_copy['balls'].astype(str) + '-' + data_copy['strikes'].astype(str)

        # Location bins
        if 'plate_x' in data_copy.columns and 'plate_z' in data_copy.columns:
            def quantile_bins(values: pd.Series, q: int) -> pd.Series:
                # Bin only non-missing values; assigning the result back aligns
                # on the index, so rows with a missing coordinate end up NaN.
                present = values.dropna()
                if present.empty:
                    return pd.Series(np.nan, index=values.index)
                return pd.qcut(present, q=q, labels=False, duplicates='drop')

            data_copy['plate_x_bin'] = quantile_bins(data_copy['plate_x'], 7)
            data_copy['plate_z_bin'] = quantile_bins(data_copy['plate_z'], 5)

            # Any missing value makes the bin columns float, which would turn
            # labels into "2.0|1.0" (or "nan|1.0"). Format from nullable ints so
            # labels are stable, and leave loc_bin missing when either is missing.
            x_bin = data_copy['plate_x_bin'].astype('Int64')
            z_bin = data_copy['plate_z_bin'].astype('Int64')
            both_present = x_bin.notna() & z_bin.notna()
            data_copy['loc_bin'] = (x_bin.astype(str) + '|' + z_bin.astype(str)).where(both_present)

        # Previous pitch type (lag-1 within game, pitcher and at-bat)
        data_copy = data_copy.sort_values(['game_pk', 'pitcher', 'at_bat_number', 'pitch_number'])
        data_copy['prev_pitch_type'] = (
            data_copy.groupby(['game_pk', 'pitcher', 'at_bat_number'])['pitch_type'].shift(1)
        )

        # Standardized velocity
        if 'release_speed' in data_copy.columns:
            speed = data_copy['release_speed']
            speed_std = speed.std()
            if pd.notna(speed_std) and speed_std > 0:
                data_copy['release_speed_z'] = (speed - speed.mean()) / speed_std
            else:
                # No spread (constant or single value): use 0 instead of 0/0,
                # keeping NaN only where the speed itself is missing.
                data_copy['release_speed_z'] = speed.where(speed.isna(), 0.0)

        # Game context features
        if 'home_score' in data_copy.columns and 'away_score' in data_copy.columns:
            data_copy['score_diff'] = data_copy['home_score'] - data_copy['away_score']

        # Base-out state ID and descriptive features
        data_copy['base_out_state'] = data_copy.apply(self.build_base_out_state, axis=1)
        data_copy['base_out_count_state'] = data_copy.apply(self.build_base_out_count_state, axis=1)

        # Leverage index (simplified: RE24 of the base-out state relative to 0.5)
        data_copy['leverage_index'] = data_copy['base_out_state'].map(self.re24_table) / 0.5

        # Runner presence indicators (on_* hold a runner id when occupied)
        data_copy['runners_on'] = (data_copy[['on_1b', 'on_2b', 'on_3b']].fillna(0) > 0).sum(axis=1)
        data_copy['scoring_position'] = ((data_copy['on_2b'].fillna(0) > 0) | (data_copy['on_3b'].fillna(0) > 0)).astype(int)

        print(f"Feature engineering complete. Shape: {data_copy.shape}")
        return data_copy
    
    def create_train_test_split(self, data: pd.DataFrame, 
                               split_date: str = None,
                               train_pct: float = 0.7,
                               method: str = 'temporal') -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Create train/test split with flexible methods.

        Args:
            data: DataFrame with a ``game_date`` column (parsed to datetime on a copy).
            split_date: 'YYYY-MM-DD'. For the temporal method, games strictly before
                this date go to train and games on/after it go to test. If None, the
                date is chosen so that roughly ``train_pct`` of pitches fall before
                it (all pitches on the chosen date go to test).
            train_pct: Fraction for training (0.0-1.0). For 'temporal' and
                'random' it is a fraction of pitches; for 'stratified' it is a
                fraction of pitchers (groups). Values >= 1.0 on the temporal path
                put every pitch in train.
            method: 'temporal' (chronological), 'random' (row-level shuffle), or
                'stratified' - a grouped split by pitcher: each pitcher's pitches go
                entirely to train OR test, so no pitcher appears in both sets.
                Falls back to 'temporal' when there is no ``pitcher`` column.

        Returns:
            (train_data, test_data) copies; both empty when ``data`` is empty.

        Raises:
            ValueError: if ``game_date`` is missing or ``method`` is unknown.
        """
        print(f"Creating {method} train/test split...")

        if 'game_date' not in data.columns:
            raise ValueError("Data must have 'game_date' column for splitting")
        if method not in ('temporal', 'random', 'stratified'):
            raise ValueError("Method must be 'temporal', 'random', or 'stratified'")

        # Ensure game_date is datetime (on a copy; the caller's frame is untouched)
        data = data.copy()
        data['game_date'] = pd.to_datetime(data['game_date'])

        if data.empty:
            # Nothing to split; also avoids positional lookups and /0 below.
            print("Warning: No rows to split; returning empty train and test sets.")
            return data.copy(), data.copy()

        if method == 'temporal':
            if split_date is None:
                # Pick the date at the train_pct position of the sorted pitches.
                data_sorted = data.sort_values('game_date')
                split_idx = int(len(data_sorted) * train_pct)
                if split_idx >= len(data_sorted):
                    # train_pct >= 1: there is no row at split_idx, so split after
                    # the last game and keep everything in train.
                    split_date = data['game_date'].max() + pd.Timedelta(days=1)
                else:
                    split_date = data_sorted['game_date'].iloc[split_idx]
                print(f"Auto-calculated split date: {split_date.strftime('%Y-%m-%d')}")
            else:
                split_date = pd.to_datetime(split_date)
                print(f"Using provided split date: {split_date.strftime('%Y-%m-%d')}")

            train_data = data[data['game_date'] < split_date].copy()
            test_data = data[data['game_date'] >= split_date].copy()

        elif method == 'random':
            from sklearn.model_selection import train_test_split
            train_indices, test_indices = train_test_split(
                data.index, 
                train_size=train_pct, 
                random_state=42,
                shuffle=True
            )
            train_data = data.loc[train_indices].copy()
            test_data = data.loc[test_indices].copy()

        else:  # 'stratified'
            # Group split by pitcher: pitchers are disjoint between train and test.
            from sklearn.model_selection import GroupShuffleSplit

            if 'pitcher' not in data.columns:
                print("Warning: No 'pitcher' column found, falling back to temporal split")
                return self.create_train_test_split(data, split_date, train_pct, 'temporal')

            splitter = GroupShuffleSplit(n_splits=1, train_size=train_pct, random_state=42)
            train_indices, test_indices = next(splitter.split(data, groups=data['pitcher']))
            train_data = data.iloc[train_indices].copy()
            test_data = data.iloc[test_indices].copy()

        # Report split statistics
        total = len(data)
        print(f"Train set: {len(train_data)} pitches ({len(train_data)/total*100:.1f}%)")
        print(f"Test set: {len(test_data)} pitches ({len(test_data)/total*100:.1f}%)")

        if len(train_data) > 0 and len(test_data) > 0:
            train_date_range = f"{train_data['game_date'].min().strftime('%Y-%m-%d')} to {train_data['game_date'].max().strftime('%Y-%m-%d')}"
            test_date_range = f"{test_data['game_date'].min().strftime('%Y-%m-%d')} to {test_data['game_date'].max().strftime('%Y-%m-%d')}"
            print(f"Train date range: {train_date_range}")
            print(f"Test date range: {test_date_range}")

            if 'pitcher' in data.columns:
                print(f"Train pitchers: {train_data['pitcher'].nunique()}")
                print(f"Test pitchers: {test_data['pitcher'].nunique()}")
        else:
            print("Warning: One of the splits is empty!")

        return train_data, test_data
    
    def create_season_split_2025(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Split 2025 data chronologically at 2025-06-08.

        2025-06-08 lies 68 days after 2025-04-01, roughly the calendar midpoint
        of the 2025-04-01 .. 2025-08-15 window (136 days), so the split is about
        50/50 by date (not necessarily by pitch count). Games before the date
        are train, games on/after it are test; the work is delegated to
        ``create_train_test_split`` with ``method='temporal'``.
        """
        print("Creating 2025 season-specific train/test split...")

        midpoint = '2025-06-08'  # 2025-04-01 + 68 days
        return self.create_train_test_split(data, split_date=midpoint, method='temporal')
    
    def final_cleaning(self, data: pd.DataFrame, target_col: str = 'run_value',
                       clip_bound: float = 3.0) -> pd.DataFrame:
        """
        Final data cleaning and preparation.

        Drops rows missing any critical field (pitch_type, balls, strikes, and
        whichever of plate_x / plate_z are present) and clips ``target_col`` to
        [-clip_bound, clip_bound] when that column exists.

        Args:
            data: Input pitch-level DataFrame (not modified).
            target_col: Column whose extreme values are capped.
            clip_bound: Absolute cap applied to ``target_col`` (default 3.0).

        Returns:
            A cleaned copy of ``data``.
        """
        print("Performing final cleaning...")

        data_copy = data.copy()

        # Drop rows with missing critical fields. Location columns are checked
        # individually so a frame with plate_x but no plate_z does not KeyError.
        critical_fields = ['pitch_type', 'balls', 'strikes']
        critical_fields += [col for col in ('plate_x', 'plate_z') if col in data_copy.columns]

        initial_count = len(data_copy)
        data_copy = data_copy.dropna(subset=critical_fields)
        print(f"Dropped {initial_count - len(data_copy)} rows with missing critical fields")

        # Cap extreme target values
        if target_col in data_copy.columns:
            data_copy[target_col] = data_copy[target_col].clip(-clip_bound, clip_bound)

        print(f"Final dataset: {len(data_copy)} pitches")
        return data_copy
    
    def validate_state_transitions(self, data: pd.DataFrame, sample_size: int = 100,
                                   random_state: Optional[int] = None) -> pd.DataFrame:
        """
        Validate state transitions by sampling and showing before/after states.

        Rows with a non-zero ``run_value`` are used as a proxy for completed
        plate appearances. For each sampled row the pitcher-perspective run
        value is recomputed as -(after_pa_re + runs_during_pa - before_pa_re)
        and reported next to the stored ``run_value``; the number of rows where
        the two disagree (e.g. after ``final_cleaning`` clipped the target) is
        printed.

        Args:
            data: MVP-pipeline output with game_pk, at_bat_number, pitch_number,
                before_pa_re, after_pa_re, runs_during_pa and run_value
                (events optional; missing/NaN events are shown as 'Unknown').
            sample_size: Maximum number of rows to sample.
            random_state: Seed for reproducible sampling (None = random).

        Returns:
            One row per sampled pitch, or an empty DataFrame if there is
            nothing to validate.
        """
        print("Validating state transitions...")

        if 'run_value' not in data.columns:
            print("No 'run_value' column found; nothing to validate")
            return pd.DataFrame()

        # Sample some completed plate appearances for validation
        completed_pas = data[data['run_value'] != 0]
        if completed_pas.empty:
            print("No completed plate appearances found for validation")
            return pd.DataFrame()

        sample_pas = completed_pas.sample(min(sample_size, len(completed_pas)),
                                          random_state=random_state)

        if 'events' in sample_pas.columns:
            events = sample_pas['events'].fillna('Unknown')
        else:
            events = 'Unknown'

        validation_df = pd.DataFrame({
            'game_pk': sample_pas['game_pk'],
            'at_bat_number': sample_pas['at_bat_number'],
            'pitch_number': sample_pas['pitch_number'],
            'event': events,
            'before_pa_re': sample_pas['before_pa_re'],
            'after_pa_re': sample_pas['after_pa_re'],
            'runs_during_pa': sample_pas['runs_during_pa'],
            'run_value': sample_pas['run_value'],
            # Inverted for pitcher perspective (positive = good for pitcher)
            'calculated_rv': -(sample_pas['after_pa_re'] + sample_pas['runs_during_pa']
                               - sample_pas['before_pa_re']),
        }).reset_index(drop=True)

        mismatches = int((~np.isclose(validation_df['run_value'],
                                      validation_df['calculated_rv'])).sum())

        print(f"Validation sample: {len(validation_df)} plate appearances")
        print(f"Average run value: {validation_df['run_value'].mean():.4f}")
        print(f"Run value std: {validation_df['run_value'].std():.4f}")
        print(f"Rows where run_value differs from recomputed value: {mismatches}")

        return validation_df
    
    def run_full_pipeline(self, 
                         start_date: str,
                         end_date: str,
                         pitcher_ids: Optional[List[int]] = None,
                         method: str = 'mvp',  # 'mvp' or 'pro'
                         save_path: str = 'statcast_processed.parquet',
                         validate: bool = True) -> pd.DataFrame:
        """
        Run the full pipeline from raw data to analysis-ready dataset.

        Steps: fetch -> clean/subset -> run values (MVP or Pro) -> feature
        engineering -> final cleaning -> optional validation (MVP only) -> save.

        Args:
            start_date: 'YYYY-MM-DD' start passed to ``fetch_statcast_data``.
            end_date: 'YYYY-MM-DD' end passed to ``fetch_statcast_data``.
            pitcher_ids: Pitcher IDs to fetch; None lets ``fetch_statcast_data``
                pick its default sample pitchers.
            method: 'mvp' (PA-level) or 'pro' (pitch-level); case-insensitive.
            save_path: Output path; a '.parquet' suffix writes Parquet, anything
                else CSV. A falsy value skips saving.
            validate: Print a state-transition validation sample (MVP only).

        Returns:
            The final DataFrame (also stored on ``self.data``), or an empty
            DataFrame if nothing was fetched.

        Raises:
            ValueError: if ``method`` is neither 'mvp' nor 'pro'.
        """
        method_key = method.lower()
        # Reject typos up front (before any network fetch) instead of silently
        # falling through to the Pro method.
        if method_key not in ('mvp', 'pro'):
            raise ValueError(f"method must be 'mvp' or 'pro', got {method!r}")

        print(f"Starting full pipeline with {method} method and state transitions...")

        # Step 1: Fetch data
        raw_data = self.fetch_statcast_data(start_date, end_date, pitcher_ids)
        if raw_data.empty:
            print("No data fetched. Exiting.")
            return pd.DataFrame()

        # Step 2: Clean and subset
        clean_data = self.clean_and_subset_data(raw_data)

        # Step 3: Calculate run values with state transitions
        if method_key == 'mvp':
            data_with_rv = self.calculate_mvp_run_values(clean_data)
        else:  # 'pro'
            data_with_rv = self.calculate_pro_run_values(clean_data)

        # Step 4: Engineer features
        featured_data = self.engineer_features(data_with_rv)

        # Step 5: Final cleaning
        final_data = self.final_cleaning(featured_data)

        # Step 6: Validation (optional; relies on the MVP PA-level columns)
        if validate and method_key == 'mvp':
            validation_results = self.validate_state_transitions(final_data)
            if not validation_results.empty:
                print("\nValidation sample:")
                print(validation_results[['event', 'before_pa_re', 'after_pa_re', 'runs_during_pa', 'run_value']].head(10))

        # Step 7: Save
        if save_path:
            if save_path.endswith('.parquet'):
                final_data.to_parquet(save_path, index=False)
            else:
                final_data.to_csv(save_path, index=False)
            print(f"Saved processed data to {save_path}")

        self.data = final_data
        print("Pipeline complete with state transitions!")
        return final_data

# Example usage and testing
if __name__ == "__main__":
    # Initialize processor
    processor = StatcastProcessor()

    # Sample pitchers (one RHP, one LHP) and the 2025 window used for both runs
    sample_pitcher_ids = [676979, 694973]
    season_start = '2025-04-01'
    season_end = '2025-08-15'

    try:
        # MVP (PA-level) run values for 2025-04-01 through 2025-08-15
        processed_data = processor.run_full_pipeline(
            start_date=season_start,
            end_date=season_end,
            pitcher_ids=sample_pitcher_ids,
            method='mvp',
            save_path='statcast_2025_season_data.parquet',
            validate=True
        )

        if processed_data.empty:
            # run_full_pipeline returns a column-less frame when nothing was
            # fetched; splitting it would raise for the missing game_date column.
            print("\nMVP pipeline returned no data; skipping MVP summary and split.")
        else:
            print("\nSample of processed data:")
            print(processed_data.head())
            print(f"\nFeatures available: {list(processed_data.columns)}")
            print("\nRun value stats:")
            if 'run_value' in processed_data.columns:
                print(processed_data['run_value'].describe())

                # Show some examples of non-zero run values
                non_zero_rv = processed_data[processed_data['run_value'] != 0]
                if len(non_zero_rv) > 0:
                    print(f"\nNon-zero run values: {len(non_zero_rv)} out of {len(processed_data)}")
                    print("Sample outcomes with run values:")
                    print(non_zero_rv[['events', 'run_value', 'before_pa_re', 'after_pa_re', 'runs_during_pa']].head(10))

            # Chronological split at the 2025 calendar midpoint
            train_data, test_data = processor.create_season_split_2025(processed_data)

            print("\nTrain/test split complete:")
            print(f"Training set: {len(train_data)} pitches")
            print(f"Test set: {len(test_data)} pitches")

            # Show the actual date ranges
            if len(train_data) > 0:
                print(f"Training data date range: {train_data['game_date'].min()} to {train_data['game_date'].max()}")
            if len(test_data) > 0:
                print(f"Test data date range: {test_data['game_date'].min()} to {test_data['game_date'].max()}")

        # Pro (pitch-level) run values for the same pitchers and dates
        print("\nTesting Pro (pitch-level) method on full date range:")
        pro_data = processor.run_full_pipeline(
            start_date=season_start,
            end_date=season_end,
            pitcher_ids=sample_pitcher_ids,
            method='pro',
            save_path='statcast_2025_pro_full.parquet',
            validate=False
        )

        if 'run_value' in pro_data.columns:
            print("Pro method run value stats (full range):")
            print(pro_data['run_value'].describe())

            # Show non-zero run values for Pro method
            non_zero_rv_pro = pro_data[pro_data['run_value'] != 0]
            print(f"\nNon-zero run values in Pro method: {len(non_zero_rv_pro)} out of {len(pro_data)}")

        # Create train/test split for Pro method data too
        if len(pro_data) > 0:
            train_pro, test_pro = processor.create_season_split_2025(pro_data)
            print("\nPro method train/test split:")
            print(f"Training set: {len(train_pro)} pitches")
            print(f"Test set: {len(test_pro)} pitches")

    except Exception as e:
        print(f"Error running pipeline: {e}")
        import traceback
        traceback.print_exc()
        # pybaseball is imported when the module loads, so only suggest
        # installing it when the failure is actually an import problem.
        if isinstance(e, ImportError):
            print("Make sure you have pybaseball installed: pip install pybaseball")

Starting full pipeline with mvp method and state transitions...
Fetching Statcast data from 2025-04-01 to 2025-08-15...
Fetching data for pitcher 676979...
Gathering Player Data
Fetching data for pitcher 694973...
Gathering Player Data
Retrieved 4565 pitches
Cleaning and subsetting data...
Cleaned data: 4557 pitches remaining
Calculating MVP run values with state transitions (pitcher perspective)...
PA 1: home_run, Before: 0.461, After: 0.461, Runs: 1, Pitcher RV: -1.000
PA 2: field_out, Before: 0.461, After: 0.243, Runs: 0, Pitcher RV: 0.218
PA 3: field_out, Before: 0.243, After: 0.095, Runs: 0, Pitcher RV: 0.148
PA 4: single, Before: 0.095, After: 0.214, Runs: 0, Pitcher RV: -0.119
PA 5: strikeout, Before: 0.214, After: 0.000, Runs: 0, Pitcher RV: 0.214
PA 6: field_out, Before: 0.461, After: 0.243, Runs: 0, Pitcher RV: 0.218
PA 7: field_out, Before: 0.243, After: 0.095, Runs: 0, Pitcher RV: 0.148
PA 8: strikeout, Before: 0.095, After: 0.000, Runs: 0, Pitcher RV: 0.095
PA 9: field_out

In [29]:
# Load a saved Pro-method output file for inspection. The filename matches the
# save_path used for the Pro run above ('statcast_2025_pro_full.parquet').
# NOTE(review): absolute, machine-specific Windows path - consider using the
# relative filename so the notebook runs on other machines.
df = pd.read_parquet(r"C:\Users\dmari\Downloads\bbayes\notebooks\statcast_2025_pro_full.parquet")
# Display every column when rendering the wide frame.
pd.set_option('display.max_columns', None)
df

Unnamed: 0,game_date,pitcher,batter,p_throws,stand,pitch_type,release_speed,plate_x,plate_z,balls,strikes,outs_when_up,on_1b,on_2b,on_3b,inning,events,description,home_score,away_score,at_bat_number,pitch_number,game_pk,game_year,inning_topbot,bat_score,fld_score,post_away_score,post_home_score,post_bat_score,post_fld_score,run_value,pre_pitch_re,post_pitch_re,runs_on_pitch,count,plate_x_bin,plate_z_bin,loc_bin,prev_pitch_type,release_speed_z,score_diff,base_out_state,base_out_count_state,leverage_index,runners_on,scoring_position
0,2025-08-12,694973,686217,R,L,FF,98.2,-0.25,2.13,0,0,0,0.0,0.0,0.0,1,,called_strike,0,0,4,1,776771,2025,Bot,0,0,0,0,0,0,0.05,0.51,0.46,0,0-0,2,1,2|1,,1.023947,0,0,0,0.922,0,0
1,2025-08-12,694973,686217,R,L,FF,98.4,0.39,4.07,0,1,0,0.0,0.0,0.0,1,,ball,0,0,4,2,776771,2025,Bot,0,0,0,0,0,0,-0.04,0.46,0.50,0,0-1,4,4,4|4,FF,1.060448,0,0,1,0.922,0,0
2,2025-08-12,694973,686217,R,L,ST,82.6,-0.01,2.58,1,1,0,0.0,0.0,0.0,1,,foul,0,0,4,3,776771,2025,Bot,0,0,0,0,0,0,0.06,0.50,0.44,0,1-1,3,2,3|2,FF,-1.823078,0,0,4,0.922,0,0
3,2025-08-12,694973,686217,R,L,FF,98.2,-0.07,3.35,1,2,0,0.0,0.0,0.0,1,,foul,0,0,4,4,776771,2025,Bot,0,0,0,0,0,0,-0.00,0.44,0.44,0,1-2,3,4,3|4,ST,1.023947,0,0,5,0.922,0,0
4,2025-08-12,694973,686217,R,L,FF,97.7,-1.03,4.22,1,2,0,0.0,0.0,0.0,1,,ball,0,0,4,5,776771,2025,Bot,0,0,0,0,0,0,-0.03,0.44,0.47,0,1-2,0,4,0|4,FF,0.932697,0,0,5,0.922,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4552,2025-04-02,676979,622761,L,R,CH,88.6,1.25,2.19,0,1,0,602104.0,0.0,0.0,8,,ball,0,3,58,2,778483,2025,Bot,0,3,3,0,0,3,-0.05,0.82,0.87,0,0-1,6,2,6|2,FF,-0.728068,-3,3,37,1.662,1,0
4553,2025-04-02,676979,622761,L,R,FC,89.3,0.69,2.76,1,1,0,602104.0,0.0,0.0,8,,called_strike,0,3,58,3,778483,2025,Bot,0,3,3,0,0,3,0.06,0.87,0.81,0,1-1,5,3,5|3,CH,-0.600317,-3,3,40,1.662,1,0
4554,2025-04-02,676979,622761,L,R,FF,94.7,0.09,3.81,1,2,0,602104.0,0.0,0.0,8,field_out,hit_into_play,0,3,58,4,778483,2025,Bot,0,3,3,0,0,3,0.27,0.81,0.54,0,1-2,3,4,3|4,FC,0.385192,-3,3,41,1.662,1,0
4555,2025-04-02,676979,656775,L,L,SI,94.9,1.08,2.98,0,0,1,602104.0,0.0,0.0,8,,foul,0,3,59,1,778483,2025,Bot,0,3,3,0,0,3,0.05,0.54,0.49,0,0-0,6,3,6|3,,0.421692,-3,4,48,0.978,1,0
