In [1]:
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.types import *
import pyspark.sql.functions as F

In [3]:
# Create a SparkSession, or reuse the one already active in this process.
# NOTE: despite the name, `sc` is a SparkSession, not a SparkContext.
sc = SparkSession.builder.getOrCreate()

# Data preparation for PitchNet

In [4]:
# Load the pitch data from a Parquet file; the path is relative to the
# notebook's working directory.
baseball = sc.read.parquet('./baseball_savant.parquet')

In [5]:
# Narrow the frame to the eight columns used in the rest of this notebook.
baseball = baseball.select(
    'pitch_type',
    'game_date',
    'pitcher',
    'p_throws',
    'player_name',
    'pfx_x',
    'pfx_z',
    'release_speed',
)

In [6]:
# Row count after column selection, before rows with nulls are dropped.
baseball.count()

3646300

In [7]:
# Discard every row that has a null in any of the selected columns.
baseball = baseball.dropna()

In [8]:
# Row count after rows containing nulls have been dropped.
baseball.count()

3629838

In [9]:
# Cast the game date to a real date, derive year/month from it, give the
# pfx_* movement columns descriptive names, and tag each row with an id.
# The ids from monotonically_increasing_id() are unique and increasing but
# not guaranteed to be consecutive.
baseball = (
    baseball
    .withColumn('game_date', F.col('game_date').cast(DateType()))
    .withColumn('game_year', F.year('game_date'))
    .withColumn('game_month', F.month('game_date'))
    .withColumnRenamed('pfx_x', 'horizontal_break')
    .withColumnRenamed('pfx_z', 'vertical_break')
    .withColumn('idx', F.monotonically_increasing_id())
)

In [10]:
# Preview the first rows of the prepared frame (show() prints 20 by default).
# NOTE(review): the saved output below includes `season_total_pitches`, which
# is only added in a later cell, so this cell was executed out of order.
# Re-run the notebook top to bottom to refresh the output.
baseball.show()

+----------+----------+-------+--------+-----------+-----------------+------------------+-------------+---------+----------+-----+--------------------+
|pitch_type| game_date|pitcher|p_throws|player_name| horizontal_break|    vertical_break|release_speed|game_year|game_month|  idx|season_total_pitches|
+----------+----------+-------+--------+-----------+-----------------+------------------+-------------+---------+----------+-----+--------------------+
|        FT|2017-07-15| 425492|       R|Ryan Madson| -1.2152526818938| 0.993839644932961|         96.1|     2017|         7|32953|                 929|
|        FT|2017-07-15| 425492|       R|Ryan Madson|-1.29308094077011| 0.910182884625976|         95.9|     2017|         7|32954|                 929|
|        FT|2017-07-15| 425492|       R|Ryan Madson|-1.19077934677773| 0.684373302559358|         94.2|     2017|         7|32955|                 929|
|        FT|2017-07-15| 425492|       R|Ryan Madson|-1.34445394610972| 0.505212550403798

In [12]:
# Number of rows recorded for each pitch type.
baseball.groupBy('pitch_type').count().show()

+----------+-------+
|pitch_type|  count|
+----------+-------+
|        FT| 401239|
|        SC|    113|
|        SL| 583237|
|        FC| 202561|
|        EP|    867|
|        FF|1296645|
|        FS|  54809|
|        PO|    630|
|        KC|  89087|
|        IN|   6390|
|        CH| 378308|
|        CU| 299183|
|        FO|    845|
|        UN|     20|
|        KN|  11453|
|        FA|     10|
|        SI| 304441|
+----------+-------+



In [13]:
# Keep only the more common pitch types; rare ones are dropped.
valid_pitch_type = ['CH', 'CU', 'FS', 'KC', 'SL', 'SI', 'FF', 'FC', 'FT']
baseball = baseball.where(F.col('pitch_type').isin(valid_pitch_type))

In [14]:
# Row count after restricting the data to the selected pitch types.
baseball.count()

3609510

In [15]:
# Add a column holding each pitcher's total number of pitches in a season:
# the row count of that row's (pitcher, game_year) partition.
window = Window.partitionBy(['pitcher', 'game_year'])
baseball = baseball.withColumn(
    'season_total_pitches',
    F.count('*').over(window),
)

In [16]:
# Print the column names, types and nullability of the prepared DataFrame.
baseball.printSchema()

root
 |-- pitch_type: string (nullable = true)
 |-- game_date: date (nullable = true)
 |-- pitcher: integer (nullable = true)
 |-- p_throws: string (nullable = true)
 |-- player_name: string (nullable = true)
 |-- horizontal_break: double (nullable = true)
 |-- vertical_break: double (nullable = true)
 |-- release_speed: double (nullable = true)
 |-- game_year: integer (nullable = true)
 |-- game_month: integer (nullable = true)
 |-- idx: long (nullable = false)
 |-- season_total_pitches: long (nullable = false)



In [20]:
# Writing the prepared data is currently disabled; uncomment the line below
# to save it as Parquet, overwriting any existing output at that path.
#baseball.write.mode('overwrite').parquet('pitch_prediction_data.parquet')