In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from scipy.ndimage import gaussian_filter
import tqdm
from joblib import Parallel, delayed
from markov_functions import *
from itertools import product
import warnings

In [None]:
# Plate-appearance result codes treated as hits (each has an entry in tb_map).
hits = [
    '1B', '2B', '3B',
    'HR', 'IHR',
    'H',
]

# Walk and hit-by-pitch result codes.
bbs = {'IBB', 'uBB'}
hbps = {'HBP'}

# Result codes excluded from at-bats, judging by the name.
non_ab_results = {
    'SH', 'SF',
    'uBB', 'IBB', 'HBP',
    'IH', 'IR', 'ID',
}

# Total bases credited for each hit code.
tb_map = {
    '1B': 1, 'H': 1,
    '2B': 2,
    '3B': 3,
    'HR': 4, 'IHR': 4,
}

# Pitch-result tokens grouped as swings.
# NOTE(review): 'H' also appears in `hits` as a result code; confirm it is a
# valid pitch-level token as well.
_swing_tokens = {
    'SW', 'F', 'FT',
    'FOUL_BUNT', 'TRY_BUNT', 'BUNT',
    'H',
}

# Raw pitch-type codes bucketed into the three modelled families.
_fastball_tokens = {'FF', 'SI', 'FC'}  # four-seam, sinker, cutter
_offspeed_tokens = {'CH', 'FO', 'FS', 'KN', 'EP'}  # changeup, forkball, split-finger, knuckleball; 'EP' presumably eephus
_breaking_tokens = {'CU', 'SL'}  # curveball, slider

# Family names; other cells use the position in this list as the pitch-type
# index (e.g. the last axis of the density maps).
pitch_types = ['fastball', 'offspeed', 'breaking']

In [None]:
# Location of the plate-appearance / pitch table.
pa_pitches_filename = './data/paired_filtered.csv'  # or provide a full path

if not os.path.exists(pa_pitches_filename):
    # Fall back to a CSV in the working directory. The candidates are sorted
    # so the choice does not depend on the arbitrary order of os.listdir().
    csv_files = sorted(f for f in os.listdir('.') if f.lower().endswith('.csv'))
    if not csv_files:
        raise FileNotFoundError(f"{pa_pitches_filename!r} not found and no CSV files in the current directory.")
    filename = csv_files[0]
    print(f"No file named {pa_pitches_filename!r} found. Using first CSV in cwd: {filename}")
    # Actually switch to the fallback file; previously the missing path was
    # still passed to read_csv, so the fallback could never succeed.
    pa_pitches_filename = filename

# Load the table and discard the columns this notebook never uses. Dropping
# returns a new frame, so re-running the cell is idempotent.
pas = pd.read_csv(pa_pitches_filename).drop(
    columns=['pa_seq', 'bases', 'velocities_events', 'pitchCodes_events']
)
print(f"Loaded {len(pas)} rows and {len(pas.columns)} columns from {pa_pitches_filename}")


In [None]:
# Rows thrown by the pitcher being modelled.
pitcher = pas.loc[pas['pitcherName'].eq('古林睿煬')]

# Batter pool: every batter name present in the table.
# To study a single batter instead, filter on pas['batterName'] == '魔鷹'.
batter = pas.loc[pas['batterName'].isin(pas['batterName'].unique())]

# Split the pool by whether batter and pitcher handedness differ or match.
opposite_batter = batter.loc[batter['batterHand'].ne(batter['pitcherHand'])]
samehand_batter = batter.loc[batter['batterHand'].eq(batter['pitcherHand'])]

In [None]:
# Handedness matchup to model: True -> opposite-handed, False -> same-handed,
# anything else -> the unsplit batter pool.
opposite = False

# Ball/strike values forwarded to get_pitches_with_counts as keyword arguments.
counts = dict(
    ball=[0, 1, 2, 3],
    strike=[0, 1, 2],
)

pitcher_events, pitcher_event_list = get_pitches_with_counts(pitcher, opposite_hand=opposite, **counts)

# Choose the batter subset and the flag passed along with it.
if opposite == True:
    batter_pool, batter_hand_flag = opposite_batter, True
elif opposite == False:
    batter_pool, batter_hand_flag = samehand_batter, False
else:
    batter_pool, batter_hand_flag = batter, opposite

batter_events, batter_event_list = get_pitches_with_counts(batter_pool, opposite_hand=batter_hand_flag, **counts)


In [None]:
# Probability maps for the 0-0 count with no pitch-history conditioning.
(
    pitchtype_map,
    swing_map,
    whiff_map,
    inplay_map,
    soft_map,
    called_strike_zone,
) = counts_prob(
    '0-0',
    pitcher_event_list,
    batter_event_list,
    situation_params=None,
)

In [None]:
# Map to visualise, drawn as one panel per pitch family on a shared colour scale.
plot_map = pitchtype_map
shared_vmax = np.max(plot_map)

fig, axes = plt.subplots(1, 3, figsize=(16,6), tight_layout=True)
for i, pitch_type in enumerate(pitch_types):
    panel = axes[i]
    panel.pcolormesh(
        x_bound, y_bound, plot_map[:, :, i].T,
        shading='auto', cmap='Blues', vmin=0, vmax=shared_vmax,
    )
    panel.set_title(f'{pitch_type.capitalize()} Pitch Density')
    plotting_background(panel)

    # Label every non-empty cell with its value as a percentage.
    for idx_x, idx_y in product(range(len(x_centers)), range(len(y_centers))):
        density_value = plot_map[idx_x, idx_y, i]
        if density_value <= 0:
            continue
        panel.text(
            x_centers[idx_x], y_centers[idx_y], f'{density_value*100:.1f}',
            color='red', fontsize=8, ha='center', va='center',
        )
plt.show()


In [None]:
# Draw the pitch-type density for each family, then overlay sampled pitches
# coloured by simulated outcome: green = whiff, blue = contact, red = no swing.
fig, axes = plt.subplots(1, 3, figsize=(16,6), tight_layout=True)
density_vmax = np.max(pitchtype_map)  # shared colour scale across the panels
for i, pitch_type in enumerate(pitch_types):
    axes[i].pcolormesh(x_bound, y_bound, pitchtype_map[:, :, i].T, shading='auto', cmap='Blues', vmin=0, vmax=density_vmax)
    axes[i].set_title(f'{pitch_type.capitalize()} Pitch Density')
    plotting_background(axes[i])

n_samples = 1000
whiff_count = 0
contact_count = 0
noswing_count = 0
for i in range(n_samples):
    sampled_pitch = sample_pitch(pitchtype_map)
    # First tuple holds grid indices, second the sampled coordinates; the
    # pitch-type index appears in both and the second occurrence is kept.
    (x_idx, y_idx, pitchtype), (x_sampled, y_sampled, pitchtype) = sampled_pitch
    if prob_determine(swing_map, x_idx, y_idx, pitchtype):
        if prob_determine(whiff_map, x_idx, y_idx, pitchtype):
            axes[pitchtype].scatter(x_sampled, y_sampled, color='green', s=10, alpha=0.5)
            whiff_count += 1
        else:
            axes[pitchtype].scatter(x_sampled, y_sampled, color='blue', s=10, alpha=0.5)
            contact_count += 1
    else:
        axes[pitchtype].scatter(x_sampled, y_sampled, color='red', s=10, alpha=0.5)
        noswing_count += 1

# Whiff rate is per swing, the other two rates are per sampled pitch. Guard
# the per-swing denominator: with no swings the rate is undefined (nan)
# rather than a ZeroDivisionError.
swing_count = whiff_count + contact_count
whiff_rate = whiff_count / swing_count if swing_count else float('nan')
print(f'Total samples: {n_samples}, Whiffs: {whiff_count}, Contacts: {contact_count}, No Swings: {noswing_count}')
print(f'Whiff Rate: {whiff_rate:.3f}, Contact Rate: {contact_count/n_samples:.3f}, No Swing Rate: {noswing_count/n_samples:.3f}')
plt.show()

In [None]:
# Fresh plate appearance: no pitches thrown yet and a 0-0 count.
n_pitch = 0
ball = 0
strike = 0
counts = f'{ball}-{strike}'

# Termination flags for the plate appearance.
pa_end = False
good_ending = False

# Conditioning context passed to counts_prob, with every field unset.
# Judging by the key names it describes the previous two pitches.
situation_params_init = dict.fromkeys(
    [
        'pitch_type_last',
        'coords_quadrant_last',
        'swing_last',
        'whiff_last',
        'pitch_type_last2',
        'coords_quadrant_last2',
        'swing_last2',
        'whiff_last2',
    ],
    None,
)

# Probability maps for the opening count under the empty context.
(
    pitchtype_map,
    swing_map,
    whiff_map,
    inplay_map,
    soft_map,
    called_strike_zone,
) = counts_prob(counts, pitcher_event_list, batter_event_list, situation_params=situation_params_init)

In [None]:
# Working copy of the context, so the pristine initial dict is never rebound.
situation_params = dict(situation_params_init)


In [None]:
# Output file for the simulated plate appearances. Any earlier file is removed
# first because the simulation cell opens it in append mode.
store_result = 'simulation_results.csv'
if os.path.exists(store_result):
    os.remove(store_result)

In [None]:
def _maps_with_fallback(count_label, params):
    """Return the ``counts_prob`` maps for ``count_label`` with context back-off.

    ``RuntimeWarning`` is promoted to an exception while the maps are built.
    If one is raised, the call is retried with progressively less context:

    1. ``params`` as given,
    2. ``params`` with the four ``*_last2`` fields cleared,
    3. ``situation_params_init`` (no context at all).

    A ``RuntimeWarning`` raised by the last attempt propagates to the caller.
    """
    without_last2 = params.copy()
    for key in ('pitch_type_last2', 'coords_quadrant_last2', 'swing_last2', 'whiff_last2'):
        without_last2[key] = None

    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        for candidate in (params, without_last2):
            try:
                return counts_prob(count_label, pitcher_event_list, batter_event_list, situation_params=candidate)
            except RuntimeWarning:
                continue
        return counts_prob(count_label, pitcher_event_list, batter_event_list, situation_params=situation_params_init)


n_PA = 50
for i in tqdm.tqdm(range(n_PA)):
    # Reset all per-PA state at the top of the iteration, so each plate
    # appearance (including the first after an interrupted run) starts clean.
    n_pitch = 0
    strike = 0
    ball = 0
    counts = f'{ball}-{strike}'
    pa_end = False
    good_ending = False
    ending_type = None
    situation_params = situation_params_init.copy()
    pitch_coord_sequence = []
    pitch_types_sequence = []
    pitch_results_sequence = []

    # Rebuild the 0-0 maps for the new PA. Without this the first pitch of
    # every PA after the first would be drawn from the maps of the previous
    # PA's last count and pitch history.
    pitchtype_map, swing_map, whiff_map, inplay_map, soft_map, called_strike_zone = counts_prob(counts, pitcher_event_list, batter_event_list, situation_params=situation_params_init)

    while not pa_end:
        sampled_pitch = sample_pitch(pitchtype_map)
        # First tuple holds grid indices, second the sampled coordinates; the
        # pitch-type index appears in both and the second occurrence is kept.
        (x_idx, y_idx, pitchtype), (x_sampled, y_sampled, pitchtype) = sampled_pitch
        pitch_name = pitch_types[pitchtype]
        pitch_coord_sequence.append((x_sampled, y_sampled))
        pitch_types_sequence.append(pitch_name)
        n_pitch += 1

        if prob_determine(swing_map, x_idx, y_idx, pitchtype):  # swing or not
            if prob_determine(whiff_map, x_idx, y_idx, pitchtype):  # whiff or contact
                pitch_results_sequence.append('WHIFF')
                situation_params = write_situation(situation_params=situation_params, pitchtype=pitch_name, x=x_sampled, y=y_sampled, swing=True, whiff=True)
                if strike < 2:
                    strike += 1
                elif strike == 2:
                    # Swinging third strike.
                    pa_end = True
                    good_ending = True
                    ending_type = 'strikeout'
            elif prob_determine(inplay_map, x_idx, y_idx, pitchtype):  # in play or foul
                # A ball in play always ends the PA; soft contact is the
                # good outcome, hard contact the bad one.
                pa_end = True
                if prob_determine(soft_map, x_idx, y_idx, pitchtype):  # soft or hard
                    good_ending = True
                    pitch_results_sequence.append('SOFT-INPLAY')
                    ending_type = 'soft-inplay'
                else:
                    good_ending = False
                    pitch_results_sequence.append('HARD-INPLAY')
                    ending_type = 'hard-inplay'
            else:
                pitch_results_sequence.append('FOUL')
                situation_params = write_situation(situation_params=situation_params, pitchtype=pitch_name, x=x_sampled, y=y_sampled, swing=True, whiff=False)
                # A foul adds a strike only below two strikes.
                if strike < 2:
                    strike += 1
        else:
            if prob_determine(called_strike_zone, x_idx, y_idx, pitchtype):
                pitch_results_sequence.append('CALLED-STRIKE')
                situation_params = write_situation(situation_params=situation_params, pitchtype=pitch_name, x=x_sampled, y=y_sampled, swing=False, whiff=False)
                if strike < 2:
                    strike += 1
                elif strike == 2:
                    # Called third strike.
                    pa_end = True
                    good_ending = True
                    ending_type = 'strikeout'
            else:
                pitch_results_sequence.append('BALL')
                situation_params = write_situation(situation_params=situation_params, pitchtype=pitch_name, x=x_sampled, y=y_sampled, swing=False, whiff=False)
                if ball < 3:
                    ball += 1
                elif ball == 3:
                    # Ball four.
                    pa_end = True
                    good_ending = False
                    ending_type = 'walk'

        counts = f'{ball}-{strike}'

        if not pa_end:
            # Refresh the maps for the new count and updated pitch history.
            pitchtype_map, swing_map, whiff_map, inplay_map, soft_map, called_strike_zone = _maps_with_fallback(counts, situation_params)

    # Append one line per PA so partial results survive an interrupted run.
    # NOTE(review): the list reprs themselves contain commas, so a line cannot
    # be split back into its five fields on ',' alone; the format is kept
    # unchanged here for compatibility with existing readers.
    with open(store_result, 'a') as f:
        f.write(f"{pitch_coord_sequence},{pitch_types_sequence},{pitch_results_sequence},{ending_type},{good_ending}\n")