In [1]:
import traceback
import pandas as pd
from connections import AWS
from pybaseball import statcast_pitcher, statcast

In [2]:
""" INITIALIZE AWS CONNECTION """
# Create the project AWS helper (imported from `connections`) and open its connection.
# Later cells use this same object for S3 reads/uploads and close it at the end of the notebook.
# The recorded output of this cell reports a free port 5433 and a connection to an RDS endpoint.
aws_connection = AWS()
aws_connection.connect()

[AWS]: Port 5433 is free.
[AWS]: Connected to RDS endpoint.


In [3]:
# Load the injured-pitcher cohort (a CSV object stored in S3).
# NOTE: new AWS module (v0.2.8) loads this directly as a dataframe
# Later cells read the `mlbamid` and `injury_date` columns from this frame.
cohort = aws_connection.load_s3_object('epidemiology/cohorts/injured/combined_0625.csv')

$\textbf{PyBaseball: Querying (Development)}$

In [8]:
def pull_statcast_data(
        pitcher_id: str,
        start_date: str,
        end_date: str,
        non_missing_cols: list = None
) -> pd.DataFrame:
    """
    Pulls Statcast data for a given pitcher within a specified date range. This essentially amounts to calling
    `statcast_pitcher` with extra postprocessing steps (dropping rows with missing tracking values).

    :param pitcher_id: The MLB ID of the pitcher. Non-string values (e.g. ints) are converted with `str()`.
    :param start_date: The start date for the data pull (YYYY-MM-DD).
    :param end_date: The end date for the data pull (YYYY-MM-DD).
    :param non_missing_cols: A list of columns that should not contain missing values. Rows with a missing value
        in ANY of these columns are dropped. Defaults to the release/spin/acceleration columns listed below.
    :return: A DataFrame containing the cleaned Statcast data for the pitcher. If the API returns no rows, an
        empty DataFrame is returned (no column check is performed in that case).
    :raises KeyError: (from pandas) if a non-empty result does not contain one of `non_missing_cols`.
    """

    # Build the default per call: a list literal in the signature is created once and shared by every call.
    if non_missing_cols is None:
        non_missing_cols = [
            'release_speed',
            'release_spin_rate',
            'spin_axis',
            'release_pos_x',
            'release_pos_z',
            'ax',
            'ay',
            'az'
        ]
    else:
        non_missing_cols = list(non_missing_cols)

    # check for str pitcher_id
    if not isinstance(pitcher_id, str):
        pitcher_id = str(pitcher_id)

    # load raw data from statcast API
    raw_data = statcast_pitcher(
        start_dt=start_date,
        end_dt=end_date,
        player_id=pitcher_id
    )

    # Nothing to clean. Returning early also avoids the KeyError that `dropna(subset=...)` raises when an
    # empty result comes back without the expected columns, so callers can simply test `.empty`.
    if raw_data.empty:
        return raw_data.copy()

    # remove NAs to clean data
    clean_data = raw_data.dropna(axis=0, subset=non_missing_cols, how='any')

    return clean_data


In [29]:
# initialize data storage
full_metadata = {}      # mlbamid -> metadata dict, for every pitcher processed without an exception
full_data = []          # per-pitcher Statcast frames (non-empty pulls only)
error_log = {}          # mlbamid -> str(exception) for pitchers whose processing raised

# iterate through all pitcher IDs (unique IDs, cohort ordered by most recent injury date first)
for pitcher_id in cohort.sort_values(by='injury_date', ascending=False)['mlbamid'].unique():

    try:
        # get relevant pitcher info
            # NOTE(review): `.values[0]` takes the first matching row in the cohort's stored order, not the
            # sorted order above -- confirm this is the intended injury for pitchers with several rows
        pitcher_inj_date = cohort.loc[cohort['mlbamid'] == pitcher_id, 'injury_date'].values[0]
        pitcher_inj_year = pitcher_inj_date.split('-')[0]   # extract year from injury date
        pitcher_inj_month = pitcher_inj_date.split('-')[1]  # extract month from injury date

        # first day of the month after the injury; a December injury must roll over to January of the
        # following year (naively adding 1 to the month would produce the invalid date YYYY-13-01)
        if int(pitcher_inj_month) == 12:
            end_year, end_month = int(pitcher_inj_year) + 1, 1
        else:
            end_year, end_month = int(pitcher_inj_year), int(pitcher_inj_month) + 1

        # pull data from Statcast API
            # date range starts at beginning of year of injury, ends at beginning of month after injury
        pitcher_statcast_data = pull_statcast_data(
            pitcher_id=pitcher_id,
            start_date=f'{pitcher_inj_year}-01-01',
            end_date=f'{end_year}-{str(end_month).zfill(2)}-01',
        )
        # NOTE(review): the pull window ends after the injury date, so this count can include pitches
        # thrown between the injury date and the end of that month -- confirm that is acceptable
        pitcher_metadata = {
            'mlbamid': pitcher_id,
            'injured': 1,
            'injury_date': pitcher_inj_date,
            'tracked_pitches_prior_to_injury': pitcher_statcast_data.shape[0],
        }

        # update data storage
        if not pitcher_statcast_data.empty:
            full_data.append(pitcher_statcast_data)

            # save to S3
            aws_connection.create_s3_folder(f'epidemiology/subjects/{pitcher_id}')          # create folder in epidemiology/subjects
            aws_connection.upload_to_s3(
                pitcher_statcast_data,
                s3_key=f'epidemiology/subjects/{pitcher_id}/ball_tracking.csv'
            )
            aws_connection.upload_to_s3(
                pd.DataFrame([pitcher_metadata]),
                s3_key=f'epidemiology/subjects/{pitcher_id}/metadata.csv'
            )

            # print progress
            print(f"Processed pitcher `{pitcher_id}` with {pitcher_statcast_data.shape[0]} pitches tracked prior to injury.")

        else:
            print(f"No data found for pitcher {pitcher_id} in the specified date range.")

        # add metadata to full dictionary (recorded even when no pitches were found)
        full_metadata[pitcher_id] = pitcher_metadata

    except Exception as e:
        # best-effort batch: report and record the failure, then continue with the next pitcher
        print(f"Error processing pitcher {pitcher_id}: {e}")
        traceback.print_exc()
        error_log[pitcher_id] = str(e)
    

Gathering Player Data
[AWS]: Folder s3://pitch-ml/epidemiology/subjects/669854/ already exists.
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/669854/ball_tracking.csv
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/669854/metadata.csv
Processed pitcher `669854` with 852 pitches tracked prior to injury.
Gathering Player Data
[AWS]: Created folder s3://pitch-ml/epidemiology/subjects/594902/
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/594902/ball_tracking.csv
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/594902/metadata.csv
Processed pitcher `594902` with 854 pitches tracked prior to injury.
Gathering Player Data
[AWS]: Created folder s3://pitch-ml/epidemiology/subjects/669203/
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/669203/ball_tracking.csv
[AWS]: Uploaded object to s3://pitch-ml/epidemiology/subjects/669203/metadata.csv
Processed pitcher `669203` with 1222 pitches tracked prior to injury.
Gathering Pla

In [None]:
# Stack every per-pitcher frame into one cohort-level table, and turn the per-pitcher metadata
# dictionary into a table with one row per pitcher (the pitcher-ID index is dropped; the ID is
# still available in the `mlbamid` column of each metadata record).
cohort_statcast_data = pd.concat(full_data, ignore_index=True)
cohort_metadata = (
    pd.DataFrame
    .from_dict(full_metadata, orient='index')
    .reset_index(drop=True)
)

# save both cohort-level tables to S3 under epidemiology/cohorts/injured/ (pitch data first, then metadata)
for _frame, _s3_key in (
    (cohort_statcast_data, 'epidemiology/cohorts/injured/statcast_data.csv'),
    (cohort_metadata, 'epidemiology/cohorts/injured/statcast_metadata.csv'),
):
    aws_connection.upload_to_s3(_frame, s3_key=_s3_key)


[AWS]: Uploaded object to s3://pitch-ml/epidemiology/cohorts/injured/statcast_metadata.csv


In [None]:
# TODO: release position --> arm angle model? for CB group

$\textbf{Close AWS Connection}$

In [35]:
# Release the AWS helper's resources once all uploads are done; the recorded output of this cell
# shows the database connection being closed and the SSH tunnel being stopped.
aws_connection.close()

[AWS]: Database connection closed.
[AWS]: SSH tunnel stopped.
