In [1]:
!pip install pybaseball

# General Packages
import numpy as np
import pandas as pd
import pybaseball as pyb
import seaborn as sns

# Machine Learning
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

import warnings
warnings.filterwarnings('ignore')

pitches = pyb.statcast('2018-06-01', '2018-07-01')

print('Our Dataset has {0} Pitches'.format(len(pitches)))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pybaseball
  Downloading pybaseball-2.2.1-py3-none-any.whl (415 kB)
[K     |████████████████████████████████| 415 kB 11.9 MB/s 
[?25hCollecting pygithub>=1.51
  Downloading PyGithub-1.55-py3-none-any.whl (291 kB)
[K     |████████████████████████████████| 291 kB 40.4 MB/s 
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Collecting pyjwt>=2.0
  Downloading PyJWT-2.4.0-py3-none-any.whl (18 kB)
Collecting pynacl>=1.4.0
  Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
[K     |████████████████████████████████| 856 kB 28.1 MB/s 
Installing collected packages: pynacl, pyjwt, deprecated, pygithub, pybaseball
Successfully installed deprecated-1.2.13 pybaseball-2.2.1 pygithub-1.55 pyjwt-2.4.0 pynacl-1.5.0
This is a large query, it may take a moment to complete


100%|██████████| 31/31 [00:48<00:00,  1.56s/it]


Our Dataset has 122689 Pitches


In [2]:
def decision_tree(data,fastball_group):
    data = data.loc[:,['pitch_name',
                       'release_speed',
                       'release_spin_rate',
                       'vx0','vy0','vz0',
                       'ax','ay','az']]
    
    data = data.dropna()
    
    def pitch_filter(x):
        if x=='2-Seam Fastball' or x=='4-Seam Fastball' or x=='Sinker':
            return 'Fastball_group'
        return x
    
    if fastball_group == True:
        data['pitch_name'] = data['pitch_name'].apply(pitch_filter)
    
    X = data.loc[:,['release_speed','vx0','vy0','vz0','ax','ay','az','release_spin_rate']]
    y = data.loc[:,['pitch_name']]
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)
    
    dt_model = tree.DecisionTreeClassifier(max_depth=10, min_samples_split=50)
    
    dt_model.fit(X_train, y_train)
    
    predictions = dt_model.predict(X_train)
    print('Training Score Accuracy {0}'.format(accuracy_score(predictions, y_train)))
    
    predictions = dt_model.predict(X_test)
    print('Test Score Accuracy {0}'.format(accuracy_score(predictions, y_test)))
    
    print(classification_report(predictions, y_test))

In [3]:
# Data = All of MLB over a one month span

decision_tree(pitches, fastball_group = False)

Training Score Accuracy 0.8384943137641119
Test Score Accuracy 0.8156825344169846
                 precision    recall  f1-score   support

4-Seam Fastball       0.91      0.90      0.90      8737
       Changeup       0.86      0.83      0.84      2624
      Curveball       0.72      0.75      0.73      1786
         Cutter       0.54      0.62      0.58      1363
         Eephus       0.67      0.44      0.53         9
       Fastball       0.00      0.00      0.00         0
  Knuckle Curve       0.36      0.55      0.44       460
    Knuckleball       0.00      0.00      0.00         0
      Pitch Out       0.00      0.00      0.00         0
      Screwball       0.00      0.00      0.00         1
         Sinker       0.84      0.83      0.83      4745
         Slider       0.85      0.75      0.79      4236
   Split-Finger       0.27      0.57      0.37       155

       accuracy                           0.82     24116
      macro avg       0.46      0.48      0.46     24116
   w

In [4]:
# Data = All of MLB over a one month span
# Treat Four-Seam, Two-Seam, and Sinker as a general fastball

decision_tree(pitches, fastball_group = True)

Training Score Accuracy 0.8876437109298059
Test Score Accuracy 0.8716619671587328
                precision    recall  f1-score   support

      Changeup       0.87      0.83      0.85      2655
     Curveball       0.75      0.75      0.75      1841
        Cutter       0.52      0.68      0.59      1208
        Eephus       0.67      0.44      0.53         9
      Fastball       0.00      0.00      0.00         0
Fastball_group       0.97      0.97      0.97     13332
 Knuckle Curve       0.39      0.57      0.46       472
   Knuckleball       0.00      0.00      0.00         0
     Pitch Out       0.00      0.00      0.00         0
     Screwball       0.00      0.00      0.00         2
        Slider       0.88      0.75      0.81      4393
  Split-Finger       0.33      0.53      0.41       204

      accuracy                           0.87     24116
     macro avg       0.45      0.46      0.45     24116
  weighted avg       0.88      0.87      0.88     24116

