# 04 - Modeling

In [2]:
from google.colab import drive
drive.mount('/content/drive')

import os
cur_path = "/content/drive/MyDrive/BDB 2025/"
os.chdir(cur_path)
!pwd

Mounted at /content/drive
/content/drive/MyDrive/BDB 2025


In [3]:
# !pip install pyspark

# The entry point to programming Spark with the DataFrame API.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("DataFrame").getOrCreate()

In [4]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

pd.set_option('display.max_columns', 500)

from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType,IntegerType
from pyspark.ml.feature import StringIndexer,IndexToString,OneHotEncoder,VectorAssembler
from pyspark.ml.classification import RandomForestClassifier,LogisticRegression,MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

import time

In [5]:
games = spark.read.option("header",'True').csv('./data/games.csv')
players = spark.read.option("header",'True').csv('./data/players.csv')
plays = spark.read.option("header",'True').csv('./data/plays.csv')
player_play = spark.read.option("header",'True').csv('./data/player_play.csv')
tracking = spark.read.option("header",'True').parquet('./data/games/tracking_week*.parquet')

In [None]:
# tracking.filter(tracking['gameId']==2022091108).filter(tracking['playId']==729).show(5)
# sample play used in the results notebook is in here

+----------+------+-----+-------+-------------+-----------+--------------------+------------+----+-------------+-----+-----+----+----+----+------+------+--------------------+--------------+----------+-----------+------+------------------+------------------+------------------+------------------+----+------------------+----------------------+--------------------+----------------+----------------+------------------+------------------+----------------------+--------------------+----------------+----------------+------------------+------------------+----------------------+--------------------+----------------+----------------+------------------+------------------+----------------------+--------------------+----------------+----------------+------------------+-----------------+---------------------+-------------------+---------------+---------------+-----------------+------------------+---------------------+-------------------+---------------+---------------+-----------------+---------------

In [None]:
# test weeks
# tracking = tracking.filter((tracking.week == '7') | (tracking.week == '9'))

In [6]:
tracking = tracking.join(plays.select('gameId','playId','quarter','down','yardsToGo','preSnapHomeScore','preSnapVisitorScore','preSnapHomeTeamWinProbability','preSnapVisitorTeamWinProbability','secRemainHalf','offenseFormation','receiverAlignment','absoluteYardlineNumber','playDescription'), on=['gameId','playId'], how='left')
tracking = tracking.join(games.select('gameId','homeTeamAbbr'), on=['gameId'], how='left')

In [7]:
# remove outlier formations (WILDCAT & JUMBO)
# tracking = tracking.filter((tracking.offenseFormation != 'WILDCAT') & (tracking.offenseFormation != 'JUMBO'))
# remove outlier receiver alignment (1x0, 3x0, 3x3)
tracking = tracking.filter((tracking.receiverAlignment != '1x0') & (tracking.receiverAlignment != '3x0') & (tracking.receiverAlignment != '3x3'))

In [None]:
# remove plays inside of 2 minutes
# tracking = tracking.withColumn('secRemainHalf', tracking['secRemainHalf'].cast(IntegerType()))
# tracking = tracking.filter(tracking.secRemainHalf > 120)
# remove no huddle plays
# tracking = tracking.where(~F.col('playDescription').contains('No Huddle'))

In [None]:
# get minimum of secRemainHalf
# tracking.select(min('secRemainHalf')).show()

+------------------+
|min(secRemainHalf)|
+------------------+
|               121|
+------------------+



In [8]:
# use formations that are under center
tracking = tracking.filter((tracking.offenseFormation == 'I_FORM') | (tracking.offenseFormation == 'SINGLEBACK'))

In [9]:
# add presnap win probability for possession team
tracking = tracking.withColumn('preSnapPossessionTeamWinProbability', when(col("possessionTeam") == col("homeTeamAbbr"), col("preSnapHomeTeamWinProbability")).otherwise(col("preSnapVisitorTeamWinProbability")))

In [10]:
# relative distance to LOS
tracking = tracking.withColumn('relativeDistLOS', when(col("playDirection") == 'left', col("x") - col('absoluteYardlineNumber')).otherwise(col("absoluteYardlineNumber") - col("x")))
tracking = tracking.withColumn('relativeDistLOS', floor(col('relativeDistLOS')))

In [11]:
# filter to offense, don't care about defensive tendencies
tracking = tracking.filter(tracking.isOnOffense==1)

# convert club to integer for one hot encoding
indexer = StringIndexer(inputCol="club", outputCol="clubId")
tracking = indexer.fit(tracking).transform(tracking)

# one hot encode club
encoder = OneHotEncoder(inputCols=["clubId"],
                        outputCols=["offenseVec"])
model = encoder.fit(tracking)
tracking = model.transform(tracking)
# tracking.show(5)
# 2 minutes

teamCount = tracking.groupBy('clubId').count().count()

In [12]:
# convert formation to integer for one hot encoding
formationindexer = StringIndexer(inputCol="offenseFormation", outputCol="offenseFormationId")
tracking = formationindexer.fit(tracking).transform(tracking)

# one hot encode formation
encoder = OneHotEncoder(inputCols=["offenseFormationId"],
                        outputCols=["formationVec"])
model = encoder.fit(tracking)
tracking = model.transform(tracking)

formationCount = tracking.groupBy('offenseFormationId').count().count()

In [13]:
# convert receiver alignment to integer for one hot encoding
alignmentindexer = StringIndexer(inputCol="receiverAlignment", outputCol="receiverAlignmentId")
tracking = alignmentindexer.fit(tracking).transform(tracking)

# one hot encode receiver alignment
encoder = OneHotEncoder(inputCols=["receiverAlignmentId"],
                        outputCols=["receiverAlignmentVec"])
model = encoder.fit(tracking)
tracking = model.transform(tracking)

alignmentCount = tracking.groupBy('receiverAlignmentId').count().count()

In [14]:
# vectorization calls for non-strings
tracking = tracking.withColumn('X_std', tracking['X_std'].cast(DoubleType())) \
                    .withColumn('Y_std', tracking['Y_std'].cast(DoubleType())) \
                    .withColumn('dir_std', tracking['dir_std'].cast(DoubleType())) \
                    .withColumn('o_std', tracking['o_std'].cast(DoubleType())) \
                    .withColumn('s', tracking['s'].cast(DoubleType())) \
                    .withColumn('a', tracking['a'].cast(DoubleType())) \
                    .withColumn('dis', tracking['dis'].cast(DoubleType())) \
                    .withColumn('10_def_first(dist)', tracking['10_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('10_def_first(dir_std2)', tracking['10_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('10_def_first(o_std2)', tracking['10_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('10_def_first(s2)', tracking['10_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('10_def_first(a2)', tracking['10_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('10_def_first(dis2)', tracking['10_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('10_off_first(dist)', tracking['10_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('10_off_first(dir_std2)', tracking['10_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('10_off_first(o_std2)', tracking['10_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('10_off_first(s2)', tracking['10_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('10_off_first(a2)', tracking['10_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('10_off_first(dis2)', tracking['10_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('11_def_first(dist)', tracking['11_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('11_def_first(dir_std2)', tracking['11_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('11_def_first(o_std2)', tracking['11_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('11_def_first(s2)', tracking['11_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('11_def_first(a2)', tracking['11_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('11_def_first(dis2)', tracking['11_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('1_def_first(dist)', tracking['1_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('1_def_first(dir_std2)', tracking['1_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('1_def_first(o_std2)', tracking['1_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('1_def_first(s2)', tracking['1_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('1_def_first(a2)', tracking['1_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('1_def_first(dis2)', tracking['1_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('1_off_first(dist)', tracking['1_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('1_off_first(dir_std2)', tracking['1_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('1_off_first(o_std2)', tracking['1_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('1_off_first(s2)', tracking['1_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('1_off_first(a2)', tracking['1_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('1_off_first(dis2)', tracking['1_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('2_def_first(dist)', tracking['2_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('2_def_first(dir_std2)', tracking['2_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('2_def_first(o_std2)', tracking['2_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('2_def_first(s2)', tracking['2_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('2_def_first(a2)', tracking['2_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('2_def_first(dis2)', tracking['2_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('2_off_first(dist)', tracking['2_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('2_off_first(dir_std2)', tracking['2_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('2_off_first(o_std2)', tracking['2_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('2_off_first(s2)', tracking['2_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('2_off_first(a2)', tracking['2_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('2_off_first(dis2)', tracking['2_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('3_def_first(dist)', tracking['3_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('3_def_first(dir_std2)', tracking['3_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('3_def_first(o_std2)', tracking['3_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('3_def_first(s2)', tracking['3_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('3_def_first(a2)', tracking['3_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('3_def_first(dis2)', tracking['3_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('3_off_first(dist)', tracking['3_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('3_off_first(dir_std2)', tracking['3_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('3_off_first(o_std2)', tracking['3_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('3_off_first(s2)', tracking['3_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('3_off_first(a2)', tracking['3_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('3_off_first(dis2)', tracking['3_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('4_def_first(dist)', tracking['4_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('4_def_first(dir_std2)', tracking['4_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('4_def_first(o_std2)', tracking['4_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('4_def_first(s2)', tracking['4_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('4_def_first(a2)', tracking['4_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('4_def_first(dis2)', tracking['4_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('4_off_first(dist)', tracking['4_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('4_off_first(dir_std2)', tracking['4_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('4_off_first(o_std2)', tracking['4_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('4_off_first(s2)', tracking['4_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('4_off_first(a2)', tracking['4_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('4_off_first(dis2)', tracking['4_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('5_def_first(dist)', tracking['5_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('5_def_first(dir_std2)', tracking['5_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('5_def_first(o_std2)', tracking['5_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('5_def_first(s2)', tracking['5_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('5_def_first(a2)', tracking['5_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('5_def_first(dis2)', tracking['5_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('5_off_first(dist)', tracking['5_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('5_off_first(dir_std2)', tracking['5_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('5_off_first(o_std2)', tracking['5_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('5_off_first(s2)', tracking['5_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('5_off_first(a2)', tracking['5_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('5_off_first(dis2)', tracking['5_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('6_def_first(dist)', tracking['6_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('6_def_first(dir_std2)', tracking['6_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('6_def_first(o_std2)', tracking['6_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('6_def_first(s2)', tracking['6_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('6_def_first(a2)', tracking['6_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('6_def_first(dis2)', tracking['6_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('6_off_first(dist)', tracking['6_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('6_off_first(dir_std2)', tracking['6_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('6_off_first(o_std2)', tracking['6_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('6_off_first(s2)', tracking['6_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('6_off_first(a2)', tracking['6_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('6_off_first(dis2)', tracking['6_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('7_def_first(dist)', tracking['7_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('7_def_first(dir_std2)', tracking['7_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('7_def_first(o_std2)', tracking['7_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('7_def_first(s2)', tracking['7_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('7_def_first(a2)', tracking['7_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('7_def_first(dis2)', tracking['7_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('7_off_first(dist)', tracking['7_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('7_off_first(dir_std2)', tracking['7_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('7_off_first(o_std2)', tracking['7_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('7_off_first(s2)', tracking['7_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('7_off_first(a2)', tracking['7_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('7_off_first(dis2)', tracking['7_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('8_def_first(dist)', tracking['8_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('8_def_first(dir_std2)', tracking['8_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('8_def_first(o_std2)', tracking['8_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('8_def_first(s2)', tracking['8_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('8_def_first(a2)', tracking['8_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('8_def_first(dis2)', tracking['8_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('8_off_first(dist)', tracking['8_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('8_off_first(dir_std2)', tracking['8_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('8_off_first(o_std2)', tracking['8_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('8_off_first(s2)', tracking['8_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('8_off_first(a2)', tracking['8_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('8_off_first(dis2)', tracking['8_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('9_def_first(dist)', tracking['9_def_first(dist)'].cast(DoubleType())) \
                    .withColumn('9_def_first(dir_std2)', tracking['9_def_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('9_def_first(o_std2)', tracking['9_def_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('9_def_first(s2)', tracking['9_def_first(s2)'].cast(DoubleType())) \
                    .withColumn('9_def_first(a2)', tracking['9_def_first(a2)'].cast(DoubleType())) \
                    .withColumn('9_def_first(dis2)', tracking['9_def_first(dis2)'].cast(DoubleType())) \
                    .withColumn('9_off_first(dist)', tracking['9_off_first(dist)'].cast(DoubleType())) \
                    .withColumn('9_off_first(dir_std2)', tracking['9_off_first(dir_std2)'].cast(DoubleType())) \
                    .withColumn('9_off_first(o_std2)', tracking['9_off_first(o_std2)'].cast(DoubleType())) \
                    .withColumn('9_off_first(s2)', tracking['9_off_first(s2)'].cast(DoubleType())) \
                    .withColumn('9_off_first(a2)', tracking['9_off_first(a2)'].cast(DoubleType())) \
                    .withColumn('9_off_first(dis2)', tracking['9_off_first(dis2)'].cast(DoubleType())) \
                    .withColumn('quarter', tracking['quarter'].cast(IntegerType())) \
                    .withColumn('down', tracking['down'].cast(IntegerType())) \
                    .withColumn('yardsToGo', tracking['yardsToGo'].cast(IntegerType())) \
                    .withColumn('preSnapHomeScore', tracking['preSnapHomeScore'].cast(IntegerType())) \
                    .withColumn('preSnapVisitorScore', tracking['preSnapVisitorScore'].cast(IntegerType())) \
                    .withColumn('preSnapHomeTeamWinProbability', tracking['preSnapHomeTeamWinProbability'].cast(DoubleType())) \
                    .withColumn('preSnapVisitorTeamWinProbability', tracking['preSnapVisitorTeamWinProbability'].cast(DoubleType())) \
                    .withColumn('secRemainHalf', tracking['secRemainHalf'].cast(IntegerType())) \
                    .withColumn('week', tracking['week'].cast(IntegerType())) \
                    .withColumn('preSnapPossessionTeamWinProbability', tracking['preSnapPossessionTeamWinProbability'].cast(DoubleType())) \
                    .withColumn('relativeDistLOS', tracking['relativeDistLOS'].cast(IntegerType()))

In [15]:
# designate target and features for the model
features = tracking.select(
                          # player specific metrics
                          'X_std','Y_std','dir_std','s',
                          # 'o_std','a','dis',
                          # '10_def_first(dist)','10_def_first(dir_std2)','10_def_first(s2)',
                          # #  '10_def_first(a2)','10_def_first(dis2)','10_def_first(o_std2)',
                          '10_off_first(dist)','10_off_first(dir_std2)','10_off_first(s2)',
                          # # '10_off_first(a2)','10_off_first(dis2)','10_off_first(o_std2)',
                          # '11_def_first(dist)','11_def_first(dir_std2)','11_def_first(s2)',
                          # '11_def_first(a2)','11_def_first(dis2)','11_def_first(o_std2)',
                          # '1_def_first(dist)','1_def_first(dir_std2)','1_def_first(s2)',
                          # '1_def_first(a2)','1_def_first(dis2)','1_def_first(o_std2)',
                          '1_off_first(dist)','1_off_first(dir_std2)','1_off_first(s2)',
                          # '1_off_first(a2)','1_off_first(dis2)','1_off_first(o_std2)',
                          # '2_def_first(dist)','2_def_first(dir_std2)','2_def_first(s2)',
                          # '2_def_first(a2)','2_def_first(dis2)','2_def_first(o_std2)',
                          '2_off_first(dist)','2_off_first(dir_std2)','2_off_first(s2)',
                          #  '2_off_first(a2)','2_off_first(dis2)','2_off_first(o_std2)',
                          # '3_def_first(dist)','3_def_first(dir_std2)','3_def_first(s2)',
                          # '3_def_first(a2)', '3_def_first(dis2)','3_def_first(o_std2)',
                          '3_off_first(dist)','3_off_first(dir_std2)','3_off_first(s2)',
                          # '3_off_first(a2)','3_off_first(dis2)','3_off_first(o_std2)',
                          #  '4_def_first(dist)','4_def_first(dir_std2)','4_def_first(s2)',
                          # '4_def_first(a2)','4_def_first(dis2)','4_def_first(o_std2)',
                          '4_off_first(dist)','4_off_first(dir_std2)','4_off_first(s2)',
                          # '4_off_first(a2)','4_off_first(dis2)','4_off_first(o_std2)',
                          # '5_def_first(dist)','5_def_first(dir_std2)','5_def_first(s2)',
                          # '5_def_first(a2)','5_def_first(dis2)','5_def_first(o_std2)',
                          '5_off_first(dist)','5_off_first(dir_std2)','5_off_first(s2)',
                          # '5_off_first(a2)','5_off_first(dis2)','5_off_first(o_std2)',
                          # '6_def_first(dist)','6_def_first(dir_std2)','6_def_first(s2)',
                          # # '6_def_first(a2)','6_def_first(dis2)','6_def_first(o_std2)',
                          '6_off_first(dist)','6_off_first(dir_std2)','6_off_first(s2)',
                          # # '6_off_first(a2)','6_off_first(dis2)','6_off_first(o_std2)',
                          # '7_def_first(dist)','7_def_first(dir_std2)','7_def_first(s2)',
                          # # '7_def_first(a2)','7_def_first(dis2)','7_def_first(o_std2)',
                          '7_off_first(dist)','7_off_first(dir_std2)','7_off_first(s2)',
                          # # '7_off_first(a2)','7_off_first(dis2)','7_off_first(o_std2)',
                          # '8_def_first(dist)','8_def_first(dir_std2)','8_def_first(s2)',
                          # # '8_def_first(a2)','8_def_first(dis2)','8_def_first(o_std2)',
                          '8_off_first(dist)','8_off_first(dir_std2)','8_off_first(s2)',
                          # #  '8_off_first(a2)','8_off_first(dis2)','8_off_first(o_std2)',
                          # '9_def_first(dist)','9_def_first(dir_std2)','9_def_first(s2)',
                          # # '9_def_first(a2)','9_def_first(dis2)','9_def_first(o_std2)',
                          '9_off_first(dist)','9_off_first(dir_std2)','9_off_first(s2)',
                          # '9_off_first(a2)','9_off_first(dis2)','9_off_first(o_std2)',
                           # game context
                          'quarter','down','yardsToGo','preSnapPossessionTeamWinProbability',#'preSnapHomeScore','preSnapVisitorScore'
                          'secRemainHalf','relativeDistLOS',
                          # one-hot vectors
                          'offenseVec','formationVec','receiverAlignmentVec'
                           ).columns
vector = VectorAssembler(inputCols=features, outputCol='features')
vector_df = vector.transform(tracking)
# vector_df.show(5)
# reduced feature set from 172 to 65

In [16]:
# convert routeRan to integer for the model
indexer = StringIndexer(inputCol="routeRan", outputCol="routeRanId").fit(vector_df)
vector_df = indexer.transform(vector_df)

In [17]:
# keep game and play to get EPA
# keep nflId to do player analysis
# keep frameId to determine how the important feautures act
# train = vector_df.filter((vector_df.week <= 7) & (vector_df.gameId == '2022102310') & (vector_df.playId == '1008') & (vector_df.frameId == '100')).select('gameId','playId','frameId','nflId','features','routeRanId')
train = vector_df.filter((vector_df.week <= 7)).select('gameId','playId','frameId','nflId','features','routeRanId','routeRan')
# test = vector_df.filter((vector_df.week > 7) & (vector_df.gameId == '2022110609') & (vector_df.playId == '1000') & (vector_df.frameId == '10')).select('gameId','playId','frameId','nflId','features','routeRanId')
test = vector_df.filter((vector_df.week > 7)).select('gameId','playId','frameId','nflId','features','routeRanId','routeRan')

In [18]:
vector_df.filter((vector_df.week > 7)).write.option("header",True).mode('overwrite').parquet(f"./data/test_df_underCenter.parquet")

In [None]:
# in the sample training set, we're looking at 1 frame
# we should see 5 rows at most because 5 players are covered up on the line and one must throw the ball
# train.show()

# Baseline Model

Insert random numbers between 0 and 1 into the rows and use a condition based on the baseline rates that each route is ran to predict the routes. Compare all models to this route to make sure they are more accurate than random guessing.

In [None]:
baselineModel = vector_df.select('routeRan').withColumn('randomPrediction', rand(seed=23))
# use route ran rate to create a baseline model
baselineModel = baselineModel.withColumn('randomPredictionLabel', when(col("randomPrediction") <= 0.18, "GO"). \
                                              when((col("randomPrediction") > 0.18) & (col("randomPrediction") <= 0.33), "HITCH"). \
                                              when((col("randomPrediction") > 0.33) & (col("randomPrediction") <= 0.48), "FLAT"). \
                                              when((col("randomPrediction") > 0.48) & (col("randomPrediction") <= 0.58), "OUT"). \
                                              when((col("randomPrediction") > 0.58) & (col("randomPrediction") <= 0.68), "CROSS"). \
                                              when((col("randomPrediction") > 0.68) & (col("randomPrediction") <= 0.76), "IN"). \
                                              when((col("randomPrediction") > 0.76) & (col("randomPrediction") <= 0.82), "POST"). \
                                              when((col("randomPrediction") > 0.82) & (col("randomPrediction") <= 0.87), "SLANT"). \
                                              when((col("randomPrediction") > 0.87) & (col("randomPrediction") <= 0.91), "CORNER"). \
                                              when((col("randomPrediction") > 0.91) & (col("randomPrediction") <= 0.95), "SCREEN"). \
                                              when((col("randomPrediction") > 0.95) & (col("randomPrediction") <= 0.99), "ANGLE"). \
                                              when((col("randomPrediction") > 0.99) & (col("randomPrediction") <= 1), "WHEEL"))
baselineModel = baselineModel.withColumn('correct', when(col("routeRan") == col("randomPredictionLabel"), 1).otherwise(0))
# baselineModel.agg({"correct": "avg"}).show()

In [None]:
# convert routeRan to integer for the model
baselineModelIndexer = StringIndexer(inputCol="routeRan", outputCol="routeRanId").fit(baselineModel)
baselineModel = baselineModelIndexer.transform(baselineModel)

baselineModelIndexer = StringIndexer(inputCol="randomPredictionLabel", outputCol="randomPredictionLabelId").fit(baselineModel)
baselineModel = baselineModelIndexer.transform(baselineModel)

evaluator = MulticlassClassificationEvaluator(predictionCol='randomPredictionLabelId', labelCol='routeRanId')

print('F1 Score:',evaluator.evaluate(baselineModel,{evaluator.metricName: 'f1'}))

F1 Score: 0.11487468973542875


Baseline model F1 score is 11.49, this is the number to beat.


# Random Forest

In [None]:
forest_model = RandomForestClassifier(featuresCol='features', labelCol='routeRanId',
                    predictionCol='prediction',maxDepth=30,
                    impurity='gini', subsamplingRate= .5).fit(train)
# 8 minutes on sample

rf_predictions = forest_model.transform(test)
#

Py4JJavaError: An error occurred while calling o1749.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 58.0 failed 1 times, most recent failure: Lost task 4.0 in stage 58.0 (TID 1331) (20b2117a6429 executor driver): java.lang.OutOfMemoryError: Java heap space
	at java.base/java.io.ObjectInputStream$HandleTable.grow(ObjectInputStream.java:4094)
	at java.base/java.io.ObjectInputStream$HandleTable.assign(ObjectInputStream.java:3900)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2219)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:738)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:737)
	at org.apache.spark.ml.tree.impl.RandomForest$.findBestSplits(RandomForest.scala:663)
	at org.apache.spark.ml.tree.impl.RandomForest$.runBagged(RandomForest.scala:208)
	at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:302)
	at org.apache.spark.ml.classification.RandomForestClassifier.$anonfun$train$1(RandomForestClassifier.scala:168)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:139)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:47)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:114)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:78)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at java.base/java.io.ObjectInputStream$HandleTable.grow(ObjectInputStream.java:4094)
	at java.base/java.io.ObjectInputStream$HandleTable.assign(ObjectInputStream.java:3900)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2219)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)
	at java.base/java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2496)
	at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2390)
	at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2228)
	at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1687)


In [None]:
# rf_predictions.show()
# route runners ran 0, 1, and 8
# model predicted 8 and 0
# this is good because the model is only predicting what it's learning
# also only seeing 5 rows in the test set which is good

+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|    gameId|playId|frameId|nflId|            features|routeRanId|       rawPrediction|         probability|prediction|
+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|2022110609|  1000|     10|41290|[31.5,32.1,24.99,...|       4.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|
|2022110609|  1000|     10|43399|[30.96,30.74,50.2...|       0.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|
|2022110609|  1000|     10|44881|[32.29,23.74,161....|       7.0|[10.0,0.0,0.0,0.0...|[0.55555555555555...|       0.0|
|2022110609|  1000|     10|42837|[31.29,31.26,-3.5...|       0.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|
|2022110609|  1000|     10|52465|[33.29,24.35,156....|       6.0|[9.0,2.0,0.0,0.0,...|[0.5,0.1111111111...|       0.0|
+----------+------+-------+-----+---------------

In [None]:
# convert predictions back to labels for further analysis
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",labels=indexer.labels)
rfpredictionLabels = labelConverter.transform(rf_predictions)

In [None]:
# rfpredictionLabels.show(5)

+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+--------------+
|    gameId|playId|frameId|nflId|            features|routeRanId|       rawPrediction|         probability|prediction|predictedLabel|
+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+--------------+
|2022110609|  1000|     10|41290|[31.5,32.1,24.99,...|       4.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|        SCREEN|
|2022110609|  1000|     10|43399|[30.96,30.74,50.2...|       0.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|        SCREEN|
|2022110609|  1000|     10|44881|[32.29,23.74,161....|       7.0|[10.0,0.0,0.0,0.0...|[0.55555555555555...|       0.0|            GO|
|2022110609|  1000|     10|42837|[31.29,31.26,-3.5...|       0.0|[8.0,1.0,0.0,0.0,...|[0.44444444444444...|       8.0|        SCREEN|
|2022110609|  1000|     10|52465|[33.29,24.35,156....|       6

In [None]:
evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='routeRanId')

print('F1 Score:',evaluator.evaluate(rf_predictions,{evaluator.metricName: 'f1'}))

In [None]:
print(len(features))
print(len(forest_model.featureImportances))
# one-hot encoding of the teams is increasing the number of feature importance scores

142
172


In [None]:
# feature importance
feature_importances = forest_model.featureImportances
print("Feature Importances:")
for feature_idx in range(len(feature_importances)-31):
  if feature_importances[feature_idx] > 0:
    print(f"Feature {features[feature_idx]}: {feature_importances[feature_idx]}")

Feature Importances:
Feature X_std: 0.22727272727272727
Feature Y_std: 0.09090909090909091
Feature 10_off_first(dir_std2): 0.09090909090909091
Feature 1_off_first(o_std2): 0.09090909090909091
Feature 4_off_first(dis2): 0.04545454545454546
Feature 6_def_first(dist): 0.04545454545454546
Feature 7_def_first(o_std2): 0.09090909090909091
Feature 7_off_first(s2): 0.04545454545454544
Feature 8_def_first(dir_std2): 0.09090909090909091
Feature 8_def_first(o_std2): 0.09090909090909091
Feature 9_def_first(a2): 0.09090909090909091


# Logistic Regression

In [None]:
lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
lr_train = train.withColumn('label', train['routeRanId'].cast(IntegerType()))

# Fit the model
lrModel = lr.fit(lr_train)
lr_predictions = lrModel.transform(test)
# lr_predictions.show()

+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|    gameId|playId|frameId|nflId|            features|routeRanId|       rawPrediction|         probability|prediction|
+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|2022110609|  1000|     10|41290|[31.5,32.1,24.99,...|       4.0|[-56.094256741778...|[6.83011841328299...|       8.0|
|2022110609|  1000|     10|43399|[30.96,30.74,50.2...|       0.0|[-56.131330524725...|[5.18508011646320...|       8.0|
|2022110609|  1000|     10|44881|[32.29,23.74,161....|       7.0|[-61.287510941645...|[4.30985559462618...|       8.0|
|2022110609|  1000|     10|42837|[31.29,31.26,-3.5...|       0.0|[-56.131330524725...|[5.99355099426969...|       8.0|
|2022110609|  1000|     10|52465|[33.29,24.35,156....|       6.0|[-48.358976661903...|[3.36499613582510...|       8.0|
+----------+------+-------+-----+---------------

In [None]:
# convert predictions back to labels for further analysis
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",labels=indexer.labels)
lrpredictionLabels = labelConverter.transform(lr_predictions)

In [None]:
evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='routeRanId')

print('F1 Score:',evaluator.evaluate(lr_predictions,{evaluator.metricName: 'f1'}))

F1 Score: 0.0


# Neural Network

[Multilayer Perceptron](https://spark.apache.org/docs/latest/ml-classification-regression.html#multilayer-perceptron-classifier)

In [19]:
start = time.time()

# specify layers for the neural network:
# input layer of size (features), two intermediate of size 5 and 4
# and output of size (classes)

###########################
## WR and 5 nearest defensive/offensive players
###########################
# layers = [len(features), 5, 4, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.115
# layers = [len(features), 30, 20, 25, 15, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.11677356218578241
# layers = [len(features), 30, 20, 25, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.124831036863954
# layers = [len(features), 30, 20, 25, 25, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.11349247702144448
# 4286.269951581955 seconds
# layers = [len(features), 30, 20, 25, 25, 30, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.052985702879573854
# 4569.943071365356
# layers = [len(features), 30, 20, 25, 25, 20, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.07277808983057245
# 4865.416504383087
# layers = [len(features), 30, 20, 25, 25, 20, 25, 15, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.10149990900055764
# 5599.684297800064

###################
## with 91 features, all 11 players on both sides of the ball
###################
# layers = [len(features), 70, 50, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12128216070864331
# 12503.670677185059

##################################
## with formation, alignment, and team one-hot encodings as well as pre snap win probability
##################################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 70, 50, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.1203704452481865
# 8125.783495426178
# add layers
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 70, 50, 60, 65, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12222390801000574
# 18002.280100107193

###########################
## with one-hot encodings for formation, alignment, and team as well as spatial data for all 22 players (188 features)
###########################
# add features to best performing model above, don't add hidden layers/nodes
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, 25, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.11405496414013641
# 9223.362916469574
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 90, 70, 60, 80, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.1139854620169429
# 24632.97947692871

#####################
## WR in question spatial data, game context, one-hot features (55 features)
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.07162395370718125
# 1849.9823768138885
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, 25, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.09835853805696759
# 2606.8532614707947
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 40, 50, 45, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.05293037753514732
# 11160.735531330109

######################
## all offenseive players, game context, and one-hot encoding (85 features)
######################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12124986352306948
# 5551.74765920639
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 25, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12148048362445743
# 5417.08841252327

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12451741660950025
# 6378.500684499741
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, 25, 25, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.11771552913935171
# 7888.775723934174

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## filter out outlier formations (WILDCAT & JUMBO)
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12392962113905438
# 5216.91431927681

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## filter out outlier receiver alignment (1x0, 3x0, 3x3)
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12423352133512806
# 5759.737510919571

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## filter out outlier receiver alignment and formation
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.1241107083751171
# 12670.27545261383

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## filter out last 2 minutes of each half
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12060441632533488
# 5095.0004506111145

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## filter out no huddle
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.12369144612155548
# 2577.95259308815

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## only shotgun plays
#####################
# layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.13224386449916575
# 4356.726322412491

#####################
## all offenseive players, game context, one-hot encoding, and relative distance to LOS (86 features)
## only under center plays
#####################
layers = [len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2), 30, 20, len(player_play.filter(player_play.wasRunningRoute=='1').select('routeRan').distinct().collect())]
# 0.15309871102350278
# 531.6402590274811

# create the trainer and set its parameters
nntrainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234)

# train the model
nn_train = train.withColumn('label', train['routeRanId'].cast(IntegerType()))
nnmodel = nntrainer.fit(nn_train)

# make predictions
nnpredictions = nnmodel.transform(test)

end = time.time()
print(f'total time elapsed: {end-start}')

total time elapsed: 798.2586584091187


In [None]:
len(features) + (teamCount-2) + (formationCount-2) + (alignmentCount-2)

86

In [None]:
# nnpredictions.show()

+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|    gameId|playId|frameId|nflId|            features|routeRanId|       rawPrediction|         probability|prediction|
+----------+------+-------+-----+--------------------+----------+--------------------+--------------------+----------+
|2022110609|  1000|     10|41290|[31.5,32.1,24.99,...|       4.0|[22.0520654771062...|[0.49994054974623...|       8.0|
|2022110609|  1000|     10|43399|[30.96,30.74,50.2...|       0.0|[22.0520626812761...|[0.49993897775921...|       8.0|
|2022110609|  1000|     10|44881|[32.29,23.74,161....|       7.0|[20.0217543265752...|[0.66665205617430...|       0.0|
|2022110609|  1000|     10|42837|[31.29,31.26,-3.5...|       0.0|[22.0478511252248...|[0.49757274709153...|       8.0|
|2022110609|  1000|     10|52465|[33.29,24.35,156....|       6.0|[20.0217543265752...|[0.66665205617430...|       0.0|
+----------+------+-------+-----+---------------

In [20]:
# convert predictions back to labels for further analysis
labelConverter = IndexToString(inputCol="prediction", outputCol="predictedLabel",labels=indexer.labels)
nnpredictionLabels = labelConverter.transform(nnpredictions)

In [21]:
evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='routeRanId')

print('F1 Score:',evaluator.evaluate(nnpredictionLabels,{evaluator.metricName: 'f1'}))

F1 Score: 0.13573223121586217


In [None]:
# nnpredictionLabels.show(5)

+----------+------+-------+-----+--------------------+----------+--------+--------------------+--------------------+----------+--------------+
|    gameId|playId|frameId|nflId|            features|routeRanId|routeRan|       rawPrediction|         probability|prediction|predictedLabel|
+----------+------+-------+-----+--------------------+----------+--------+--------------------+--------------------+----------+--------------+
|2022103006|  1001|      1|44879|[78.2,28.81,175.4...|       9.0|  SCREEN|[0.45701019800710...|[0.12118411115369...|       2.0|          FLAT|
|2022103006|  1001|     10|45244|[79.28,28.67,173....|       6.0|    POST|[0.81439116712676...|[0.17514325899909...|       0.0|            GO|
|2022103006|  1001|    100|46160|[82.35,16.41,167....|       0.0|      GO|[0.98702107115736...|[0.20549969187963...|       0.0|            GO|
|2022103006|  1001|    100|53098|[83.24,38.59,-5.3...|       8.0|  CORNER|[1.40003752774375...|[0.28893568838793...|       0.0|            GO|

# Write Out Results

Write the results from the best model to a parquet file.

In [22]:
# export nnpredictionLabels to make sure the labels are exported for the predictions
nnpredictionLabels.write.option("header",True).mode('overwrite').parquet(f"./data/nnpredictions_UnderCenter.parquet")

# Save Model

In [23]:
nnmodel.save("./modelUnderCenter")