In [1]:
from fastai.tabular.all import *
import numpy as np
import pandas as pd
from pathlib import Path

In [2]:
DATA_PATH = Path('../../data')

In [3]:
df = pd.read_parquet(DATA_PATH.joinpath('processed', 'game_logs_standings_v2.parquet'))

In [4]:
df.columns.tolist()

['Gm#',
 'Year',
 'Date',
 'Tm',
 'Opp',
 'W/L',
 'R',
 'RA',
 'W-L',
 'HomeTeam_Rank',
 'Win',
 'Loss',
 'Time',
 'D/N',
 'HomeTeam_cLI',
 'Streak',
 'VisitingTeam',
 'HomeTeam',
 'HomeTeam_W',
 'HomeTeam_L',
 'HomeTeam_Streak_count',
 'VisitingTeam_W',
 'VisitingTeam_L',
 'VisitingTeam_Rank',
 'VisitingTeam_cLI',
 'VisitingTeam_Streak_count',
 'Attendance_TRUTH_y',
 'NumberofGames',
 'DayofWeek',
 'VisitingTeamLeague',
 'VisitingTeamGameNumber',
 'HomeTeamLeague',
 'HomeTeamGameNumber',
 'VistingTeamScore',
 'HomeTeamScore',
 'NumberofOuts',
 'DayNight',
 'BallParkID',
 'LengthofGame',
 'VisitingTeam_LineScore',
 'HomeTeam_LineScore',
 'VisitingTeamOffense_AtBats',
 'VisitingTeamOffense_Hits',
 'VisitingTeamOffense_Doubles',
 'VisitingTeamOffense_Triples',
 'VisitingTeamOffense_Homeruns',
 'VisitingTeamOffense_RBIs',
 'VisitingTeamOffense_SacrificeHits',
 'VisitingTeamOffense_SacrificeFlies',
 'VisitingTeamOffense_HitbyPitch',
 'VisitingTeamOffense_Walks',
 'VisitingTeamOffense_Inten

In [5]:
CONT_FEATURES = [
        "avg_attendance_1_yr_ago",
        "avg_attendance_2_yr_ago",
        "avg_attendance_3_yr_ago",
        "is_holiday",
        "Year",
        "Month",
        "Week",
        "DayNight",
        "Dayofyear",
        "Is_month_end",
        "Is_month_start",
        "Is_quarter_end",
        "Is_quarter_start",
        "Is_year_end",
        "Is_year_start",
        "Stadium_Capacity",
        "HomeTeam_cLI",
        "HomeTeam_Rank",
        "HomeTeam_W",
        "HomeTeam_Streak_count",
        "HomeTeamGameNumber",
        "VisitingTeam_cLI",
        "VisitingTeam_Rank",
        "VisitingTeam_L",
        "VisitingTeam_Streak_count",
        "VisitingTeamGameNumber",
]
# also look at WAR
CAT_FEATURES = ["BallParkID",
                "Dayofweek"]

In [6]:
splits = RandomSplitter(valid_pct=0.2)(range_of(df))

In [7]:
to = TabularPandas(df, procs=[Categorify, FillMissing, Normalize],
                   cat_names = CAT_FEATURES,
                   cont_names = CONT_FEATURES,
                   y_names="Attendance_TRUTH_y",
                   splits=splits)

In [8]:
to.xs.iloc[:2]

Unnamed: 0,BallParkID,Dayofweek,HomeTeam_Player9_Position,VisitingTeam_Player9_Position,avg_attendance_1_yr_ago,avg_attendance_2_yr_ago,avg_attendance_3_yr_ago,is_holiday,Year,Month,...,HomeTeam_cLI,HomeTeam_Rank,HomeTeam_W,HomeTeam_Streak_count,HomeTeamGameNumber,VisitingTeam_cLI,VisitingTeam_Rank,VisitingTeam_L,VisitingTeam_Streak_count,VisitingTeamGameNumber
49213,30,7,8,4,-0.675772,-0.879641,-1.369568,-0.132412,-1.644811,-1.48385,...,0.604392,-1.329915,-1.469444,1.092704,-1.619126,0.440247,0.029698,-1.510981,-1.079063,-1.619114
23725,19,4,1,1,-0.239227,-0.045054,-0.013027,-0.132412,-1.644811,-0.325162,...,-0.667521,0.738687,-0.529344,0.307292,-0.442371,0.733894,-0.65971,-0.650072,-0.300195,-0.506505


In [9]:
dls = to.dataloaders(bs=64)

In [10]:
dls.show_batch()

Unnamed: 0,BallParkID,Dayofweek,HomeTeam_Player9_Position,VisitingTeam_Player9_Position,avg_attendance_1_yr_ago,avg_attendance_2_yr_ago,avg_attendance_3_yr_ago,is_holiday,Year,Month,Week,DayNight,Dayofyear,Is_month_end,Is_month_start,Is_quarter_end,Is_quarter_start,Is_year_end,Is_year_start,Stadium_Capacity,HomeTeam_cLI,HomeTeam_Rank,HomeTeam_W,HomeTeam_Streak_count,HomeTeamGameNumber,VisitingTeam_cLI,VisitingTeam_Rank,VisitingTeam_L,VisitingTeam_Streak_count,VisitingTeamGameNumber,Attendance_TRUTH_y
0,SAN02,4,1,1,38093.499875,24892.75023,23915.666123,-2.830768e-10,2006.0,8.0,31.0,1.0,216.0,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,40845.000063,1.9,1.0,56.0,-1.0,109.0,0.23,5.0,59.999999,1.0,109.000001,36538.0
1,NYC20,6,1,1,30082.234359,30384.818363,32547.810593,-2.830768e-10,2011.0,6.0,24.0,-2.352776e-08,170.0,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,41921.999978,0.81,3.0,35.0,-1.0,72.0,0.87,3.0,38.0,1.0,73.0,36213.0
2,CHI11,2,1,1,38247.109144,39258.199326,34288.499886,-2.830768e-10,2006.0,5.0,22.0,1.0,150.999999,1.0,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,41915.000075,0.26,5.0,20.0,-1.0,52.000001,1.21,2.0,24.000001,1.0,52.999999,39810.0
3,KAN06,2,6,6,22732.000001,23615.000174,18503.11143,-2.830768e-10,2013.0,8.0,32.0,1.0,218.999998,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,37903.000112,0.72,3.0,58.000001,1.0,110.999999,0.01,4.0,62.0,-1.0,110.999999,20198.0
4,BAL12,4,6,2,44191.109064,42749.800318,41973.777289,-2.830768e-10,2008.0,8.0,34.0,1.0,235.0,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,45970.999974,0.02,5.0,62.0,-1.0,128.000001,0.52,3.0,59.999999,1.0,128.0,43543.0
5,DET05,1,8,6,38976.687421,33835.890571,36230.24976,-2.830768e-10,2014.0,9.0,37.0,1.0,252.000003,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,41083.000017,4.2,1.0,80.000001,3.0,145.000002,4.11,2.0,64.999999,-2.0,144.000003,32603.0
6,NYC21,3,9,2,35328.624844,36410.4451,43853.554459,-2.830768e-10,2016.0,9.0,36.0,1.0,252.000003,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,54251.000122,1.27,4.0,74.0,5.0,139.000003,3.027732e-08,5.0,79.999999,-1.0,139.000003,27631.0
7,ARL02,3,8,2,23340.000259,29138.999975,19497.428079,-2.830768e-10,2011.0,9.0,37.0,1.0,257.999999,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,49114.999913,1.83,1.0,85.999999,4.0,150.000001,3.027732e-08,2.0,74.999999,-3.0,147.000001,44242.0
8,DET05,2,2,4,19508.999556,14534.999748,21604.000211,-2.830768e-10,2004.0,4.0,18.0,1.0,118.999999,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,41083.000017,1.13,3.0,12.0,1.0,21.0,1.24,1.0,8.999999,-1.0,21.000002,17175.0
9,PHO01,1,1,1,32547.810604,32810.867188,31671.187567,-2.830768e-10,2009.0,6.0,26.0,1.0,174.0,1.054686e-09,1.305402e-09,-1.262277e-10,1.807785e-10,0.0,0.0,48518.999864,0.11,5.0,30.0,1.0,71.0,1.45,1.0,32.0,-5.0,69.000001,21379.0


In [11]:
learn = tabular_learner(dls, metrics=mae)

In [12]:
learn.fit_one_cycle(10)

epoch,train_loss,valid_loss,mae,time
0,1057051456.0,1048770240.0,30617.255859,00:08
1,1030241920.0,1045777216.0,30586.882812,00:07
2,1010356416.0,1037157120.0,30484.558594,00:07
3,1014460032.0,1027110208.0,30367.849609,00:07
4,1012610624.0,1013784512.0,30203.681641,00:07
5,986152960.0,997486848.0,29984.294922,00:07
6,980194240.0,990828096.0,29917.976562,00:07
7,981999296.0,986917184.0,29869.458984,00:07
8,969039232.0,982042560.0,29799.226562,00:07
9,990008960.0,986383232.0,29871.601562,00:07


In [13]:
#function to embed features ,obtained from fastai forums
def embed_features(learner, xs):
    xs = xs.copy()
    for i, feature in enumerate(learner.dls.cat_names):
        emb = learner.model.embeds[i]
        new_feat = pd.DataFrame(emb(tensor(xs[feature], dtype=torch.int64, device='mps')), index=xs.index, columns=[f'{feature}_{j}' for j in range(emb.embedding_dim)])
        xs.drop(columns=feature, inplace=True)
        xs = xs.join(new_feat)
    return xs

embeddings = embed_features(learn, to.all_cols)

In [14]:
embeddings.to_parquet(DATA_PATH.joinpath('processed', 'train_v2.parquet'))