In [None]:


import os
import pandas as pd
import numpy as np


%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn

from modules.classifer_utils import NormalizedClassifierDataset, NormalizedClassifierDatasetMetadata, TrainingManager, GeneralNN



In [185]:
games_df = pd.read_parquet("games/statcast-2020.parquet")
print(f'starting with {len(games_df)} records on disk')

# first off, ditch rows without our target label
rows_to_drop = games_df[games_df.pitch_type.isna()].index
games_df.drop(rows_to_drop, inplace=True)

# temp make this a binary classifier
LABEL_COLUMN_NAME = "is_fastball"
games_df[LABEL_COLUMN_NAME] = games_df.pitch_type.str.startswith('F').astype(int)

print(f'target label breakdown\n{games_df[LABEL_COLUMN_NAME].value_counts()}')


# TODO figure out date as a categorical
DAY_OF_YEAR = 'day_of_year'
games_df[DAY_OF_YEAR] = games_df.game_date.dt.dayofyear

games_df = games_df.astype({
    'pitcher': 'int', 
    'batter': 'int',
    'on_1b': 'int',
    'on_2b': 'int',
    'on_3b': 'int'
}, errors='ignore')

ds_meta = NormalizedClassifierDatasetMetadata(LABEL_COLUMN_NAME)
ds_meta.set_ordinal_numeric_cols( [
    "bat_score", 
    "fld_score", 
    "balls", 
    "strikes", 
    "outs_when_up",   
    DAY_OF_YEAR,
    "at_bat_number",
    "pitch_number",
    "n_thruorder_pitcher",
    "age_pit",
    "age_bat",
    "pitcher_days_since_prev_game"
] )

ds_meta.set_categorical_map({
    col : list(games_df[col].unique()) for col in ['p_throws', 'stand']  #, 'if_fielding_alignment', 'of_fielding_alignment']
})

ds_meta.set_embedding_cols(["pitcher", "batter", "on_1b", "on_2b", "on_3b"])


normed_df = games_df[ ds_meta.get_columns() ].dropna()
overall_ds = NormalizedClassifierDataset(normed_df, ds_meta)


train_ds, test_ds = random_split(overall_ds, [.80, .20])

batch_size = int(len(train_ds) / 10)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=True, drop_last=True)

print(f'{len(train_ds)} batches with batch_size: {batch_size}, {len(test_ds)} batches for test.\n')
num_features = overall_ds.get_feature_count()
print(f'datasets have {num_features} features')




starting with 279660 records on disk
target label breakdown
is_fastball
0    156297
1    121669
Name: count, dtype: int64
5354 batches with batch_size: 535, 1338 batches for test.

datasets have 46 features


In [184]:

dropoutRate = 0.2

input_features = overall_ds.get_feature_count()
model = GeneralNN( input_features, [num_features*2,num_features*2,32,32,16,1], dropoutRate )


training_mgr = TrainingManager(model)
training_mgr.train(train_dataloader, 100)



  self.features_ndarray = False #df_features_tensor


Epoch [1/100], 2636 of 5350 correct 49.27 %
Epoch [2/100], 2489 of 5350 correct 46.52 %
Epoch [3/100], 2470 of 5350 correct 46.17 %
Epoch [4/100], 2544 of 5350 correct 47.55 %
Epoch [5/100], 2580 of 5350 correct 48.22 %
Epoch [6/100], 2640 of 5350 correct 49.35 %
Epoch [7/100], 2650 of 5350 correct 49.53 %
Epoch [8/100], 2495 of 5350 correct 46.64 %
Epoch [9/100], 2448 of 5350 correct 45.76 %
Epoch [10/100], 2462 of 5350 correct 46.02 %
Epoch [11/100], 2391 of 5350 correct 44.69 %
Epoch [12/100], 2418 of 5350 correct 45.20 %
Epoch [13/100], 2463 of 5350 correct 46.04 %
Epoch [14/100], 2508 of 5350 correct 46.88 %
Epoch [15/100], 2409 of 5350 correct 45.03 %
Epoch [16/100], 2423 of 5350 correct 45.29 %
Epoch [17/100], 2460 of 5350 correct 45.98 %
Epoch [18/100], 2428 of 5350 correct 45.38 %
Epoch [19/100], 2399 of 5350 correct 44.84 %
Epoch [20/100], 2424 of 5350 correct 45.31 %
Epoch [21/100], 2390 of 5350 correct 44.67 %
Epoch [22/100], 2492 of 5350 correct 46.58 %
Epoch [23/100], 246

In [180]:
training_mgr.eval(test_dataloader)

Average Test Loss: 0.7524
Test Accuracy: 0.5813
Test Precision: 0.0000
Test Recall: 0.0000
Test F1-Score: 0.0000


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
