# Model
In this notebook, we:
- Define the structure of our prediction model.
- Try different models and assess their performance.
- Predict on the 2024 March Madness bracket.

## Imports

In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
from sklearn.metrics import mean_squared_error, r2_score, log_loss, accuracy_score, confusion_matrix, classification_report
import xgboost as xgb

# display 100 rows and 100 columns
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 100)

# global random seed
SEED = 9

## Load Data

In [5]:
# load in features df
fcomp = pd.read_csv('data/processed/features_compact.csv')
# fdet = pd.read_csv('data/processed/features_detailed.csv')

## ML Model

In [9]:
fcomp

Unnamed: 0,A_1_pos_game_ratio,A_1_pos_losses,A_1_pos_win_ratio,A_1_pos_wins,A_Loc,A_PlayIn,A_Region,A_Seed,A_TeamID,A_avg_loss_diff,A_avg_pts_against_loss,A_avg_pts_against_win,A_avg_pts_for_loss,A_avg_pts_for_win,A_avg_win_diff,A_away_losses,A_away_win_ratio,A_away_wins,A_home_losses,A_home_win_ratio,A_home_wins,A_max_loss_diff,A_max_win_diff,A_neutral_losses,A_neutral_win_ratio,A_neutral_wins,A_num_losses,A_num_wins,A_ot_losses,A_ot_ratio,A_ot_win_ratio,A_ot_wins,A_recent_avg_pts_against,A_recent_avg_pts_for,A_recent_avg_score_diff,A_recent_losses,A_recent_pts_against_loss,A_recent_pts_against_win,A_recent_pts_for_loss,A_recent_pts_for_win,A_recent_win_ratio,A_recent_wins,A_std_loss_diff,A_std_pts_against_loss,A_std_pts_against_win,A_std_pts_for_loss,A_std_pts_for_win,A_std_win_diff,A_win_ratio,B_1_pos_game_ratio,...,B_1_pos_win_ratio,B_1_pos_wins,B_PlayIn,B_Region,B_Seed,B_TeamID,B_avg_loss_diff,B_avg_pts_against_loss,B_avg_pts_against_win,B_avg_pts_for_loss,B_avg_pts_for_win,B_avg_win_diff,B_away_losses,B_away_win_ratio,B_away_wins,B_home_losses,B_home_win_ratio,B_home_wins,B_max_loss_diff,B_max_win_diff,B_neutral_losses,B_neutral_win_ratio,B_neutral_wins,B_num_losses,B_num_wins,B_ot_losses,B_ot_ratio,B_ot_win_ratio,B_ot_wins,B_recent_avg_pts_against,B_recent_avg_pts_for,B_recent_avg_score_diff,B_recent_losses,B_recent_pts_against_loss,B_recent_pts_against_win,B_recent_pts_for_loss,B_recent_pts_for_win,B_recent_win_ratio,B_recent_wins,B_std_loss_diff,B_std_pts_against_loss,B_std_pts_against_win,B_std_pts_for_loss,B_std_pts_for_win,B_std_win_diff,B_win_ratio,NumOT,Season,score_diff,win
0,0.454545,12,0.200000,3,N,0,X,9,1116,8.083333,67.083333,58.619048,59.000000,68.952381,10.333333,8,0.333333,4,1,0.909091,10,20,35,3,0.700000,7,12,21,0,0.000000,0.0,0,61.500000,74.166667,12.666667,2.0,138.0,231.0,132.0,313.0,0.666667,4.0,7.501010,9.848473,8.570159,12.526336,10.589977,7.558659,0.636364,0.366667,...,0.090909,1,0,X,8,1234,5.500000,62.200000,57.800000,56.700000,76.250000,18.450000,6,0.333333,3,3,0.833333,15,20,49,1,0.666667,2,10,20,0,0.000000,0.0,0,63.833333,62.500000,-1.333333,4.0,251.0,132.0,218.0,157.0,0.333333,2.0,5.562773,8.024961,12.547174,6.147267,14.519950,13.359306,0.666667,0,1985,9,1
1,0.586207,11,0.352941,6,N,0,Z,11,1120,9.636364,76.181818,60.833333,66.545455,72.666667,11.833333,6,0.454545,5,4,0.666667,8,19,42,1,0.833333,5,11,18,1,0.068966,0.5,1,57.000000,63.166667,6.166667,1.0,78.0,264.0,73.0,306.0,0.833333,5.0,7.046598,13.242494,10.325582,11.129813,14.796661,12.015921,0.620690,0.400000,...,0.200000,2,0,Z,6,1345,14.750000,81.875000,57.529412,67.125000,70.058824,12.529412,4,0.600000,6,4,0.714286,10,43,31,0,1.000000,1,8,17,0,0.000000,0.0,0,68.833333,66.333333,-2.500000,2.0,172.0,241.0,111.0,287.0,0.666667,4.0,12.395276,8.288331,9.083761,12.368595,9.555565,8.397303,0.680000,0,1985,1,1
2,0.793103,18,0.217391,5,N,0,W,16,1250,10.833333,71.111111,68.727273,60.277778,74.727273,6.000000,7,0.363636,4,8,0.333333,4,31,18,3,0.500000,3,18,11,0,0.034483,1.0,1,70.000000,69.500000,-0.500000,1.0,75.0,345.0,58.0,359.0,0.833333,5.0,10.205247,9.310897,6.987001,7.168604,9.498325,6.115554,0.379310,0.148148,...,0.500000,2,0,W,1,1207,1.500000,65.500000,59.640000,64.000000,76.680000,17.040000,1,0.857143,6,1,0.923077,12,2,41,0,1.000000,7,2,25,0,0.000000,0.0,0,64.333333,83.666667,19.333333,0.0,0.0,386.0,0.0,502.0,1.000000,6.0,0.707107,0.707107,10.934959,1.414214,11.918893,10.043572,0.925926,0,1985,-25,0
3,0.407407,7,0.363636,4,N,0,Y,9,1229,9.428571,73.571429,62.850000,64.142857,74.200000,11.350000,4,0.600000,6,3,0.785714,11,13,31,0,1.000000,3,7,20,0,0.000000,0.0,0,64.833333,72.000000,7.166667,2.0,148.0,241.0,130.0,302.0,0.666667,4.0,4.353433,8.263517,11.202796,7.174691,11.445983,8.273452,0.740741,0.535714,...,0.400000,6,0,Y,8,1425,8.111111,66.888889,63.526316,58.777778,72.947368,9.421053,3,0.727273,8,5,0.642857,9,21,33,1,0.666667,2,9,19,0,0.000000,0.0,0,68.166667,68.833333,0.666667,3.0,217.0,192.0,193.0,220.0,0.500000,3.0,7.474029,8.950481,8.656202,9.162120,9.324244,8.591441,0.678571,0,1985,3,1
4,0.333333,7,0.300000,3,N,0,Z,3,1242,8.857143,76.285714,68.608696,67.428571,78.652174,10.043478,4,0.600000,6,0,1.000000,14,19,27,3,0.500000,3,7,23,0,0.000000,0.0,0,72.166667,76.833333,4.666667,1.0,75.0,358.0,59.0,402.0,0.833333,5.0,6.669047,16.660475,8.907296,14.842025,8.637175,7.003105,0.766667,0.370370,...,0.300000,3,0,Z,14,1325,7.571429,75.428571,58.650000,67.857143,67.450000,8.800000,5,0.583333,7,0,1.000000,10,17,23,2,0.600000,3,7,20,0,0.000000,0.0,0,58.166667,68.333333,10.166667,1.0,67.0,282.0,66.0,344.0,0.833333,5.0,5.028490,8.676734,8.845308,5.984106,10.625070,6.152449,0.740741,0,1985,11,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4029,0.290323,6,0.333333,3,A,0,W,2,3268,13.833333,83.333333,65.440000,69.500000,81.280000,15.840000,2,0.846154,11,2,0.846154,11,25,37,2,0.600000,3,6,25,0,0.000000,0.0,0,70.166667,79.500000,9.333333,1.0,89.0,332.0,84.0,393.0,0.833333,5.0,8.447879,10.308572,10.116323,11.256109,9.484725,10.922759,0.806452,0.000000,...,0.000000,0,0,W,1,3376,0.000000,0.000000,51.093750,0.000000,81.437500,30.343750,0,1.000000,13,0,1.000000,15,0,70,0,1.000000,4,0,32,0,0.062500,1.0,2,59.166667,76.166667,17.000000,0.0,0.0,355.0,0.0,457.0,1.000000,6.0,0.000000,0.000000,13.059070,0.000000,11.123754,18.836494,1.000000,0,2023,-11,0
4030,0.258065,4,0.500000,4,N,0,Z,1,3439,9.500000,67.500000,55.518519,58.000000,74.666667,19.148148,3,0.727273,8,1,0.933333,14,11,61,0,1.000000,5,4,27,0,0.000000,0.0,0,53.166667,67.000000,13.833333,0.0,0.0,319.0,0.0,402.0,1.000000,6.0,3.000000,6.454972,13.232740,6.055301,10.114727,16.280483,0.870968,0.281250,...,0.222222,2,0,Z,3,3326,18.142857,84.000000,63.840000,65.857143,84.920000,21.080000,2,0.833333,10,4,0.733333,11,36,56,1,0.800000,4,7,25,0,0.031250,1.0,1,75.166667,74.500000,-0.666667,2.0,181.0,270.0,146.0,301.0,0.666667,4.0,13.005493,10.801234,12.116380,7.425824,9.962095,16.247872,0.781250,0,2023,10,1
4031,0.000000,0,0.000000,0,N,0,W,1,3376,0.000000,0.000000,51.093750,0.000000,81.437500,30.343750,0,1.000000,13,0,1.000000,15,0,70,0,1.000000,4,0,32,0,0.062500,1.0,2,59.166667,76.166667,17.000000,0.0,0.0,355.0,0.0,457.0,1.000000,6.0,0.000000,0.000000,13.059070,0.000000,11.123754,18.836494,1.000000,0.250000,...,0.250000,2,0,X,2,3234,10.333333,89.500000,67.153846,79.166667,89.461538,22.307692,4,0.636364,7,1,0.937500,15,28,54,1,0.800000,4,6,26,0,0.062500,1.0,2,75.833333,82.833333,7.000000,1.0,96.0,359.0,68.0,429.0,0.833333,5.0,9.584710,4.722288,12.533770,6.177918,12.063932,16.600649,0.812500,0,2023,-4,0
4032,0.100000,2,0.333333,1,N,0,Y,3,3261,13.000000,78.500000,56.250000,65.500000,85.428571,29.178571,1,0.888889,8,0,1.000000,15,24,75,1,0.833333,5,2,28,0,0.033333,1.0,1,66.000000,77.500000,11.500000,1.0,69.0,327.0,67.0,398.0,0.833333,5.0,15.556349,13.435029,12.219065,2.121320,14.260613,21.060853,0.933333,0.258065,...,0.500000,4,0,Z,1,3439,9.500000,67.500000,55.518519,58.000000,74.666667,19.148148,3,0.727273,8,1,0.933333,14,11,61,0,1.000000,5,4,27,0,0.000000,0.0,0,53.166667,67.000000,13.833333,0.0,0.0,319.0,0.0,402.0,1.000000,6.0,3.000000,6.454972,13.232740,6.055301,10.114727,16.280483,0.870968,0,2023,7,1


In [11]:
# X and y
X = fcomp.drop(['A_Region', 'B_Region', 'A_TeamID', 'B_TeamID', 'A_away_losses', 'A_away_wins', 'A_home_losses', 'A_home_wins', 'A_neutral_losses', 'A_neutral_wins', 'B_away_losses', 
                'B_away_wins', 'B_home_losses', 'B_home_wins', 'B_neutral_losses', 'B_neutral_wins', 'A_num_losses', 'A_num_wins', 'B_num_losses', 'B_num_wins', 'A_ot_losses', 
                'A_ot_wins', 'B_ot_losses', 'B_ot_wins', 'A_recent_losses', 'A_recent_wins', 'B_recent_losses', 'B_recent_wins', 'NumOT', 'Season', 'score_diff', 'win'], axis=1)
y_regression = fcomp['score_diff']
y_binary = fcomp['win']

In [16]:
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4034 entries, 0 to 4033
Data columns (total 69 columns):
 #   Column                     Non-Null Count  Dtype  
---  ------                     --------------  -----  
 0   A_1_pos_game_ratio         4034 non-null   float64
 1   A_1_pos_losses             4034 non-null   int64  
 2   A_1_pos_win_ratio          4034 non-null   float64
 3   A_1_pos_wins               4034 non-null   int64  
 4   A_Loc                      4034 non-null   object 
 5   A_PlayIn                   4034 non-null   int64  
 6   A_Seed                     4034 non-null   int64  
 7   A_avg_loss_diff            4034 non-null   float64
 8   A_avg_pts_against_loss     4034 non-null   float64
 9   A_avg_pts_against_win      4034 non-null   float64
 10  A_avg_pts_for_loss         4034 non-null   float64
 11  A_avg_pts_for_win          4034 non-null   float64
 12  A_avg_win_diff             4034 non-null   float64
 13  A_away_win_ratio           4034 non-null   float

In [15]:
# split data
X_train, X_test, y_train, y_test = train_test_split(X, y_regression, test_size=0.2, random_state=SEED)

# scale data
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# linear regression
lr = LinearRegression()
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)

# metrics
print('Linear Regression')
print('MSE:', mean_squared_error(y_test, y_pred))
print('R2:', r2_score(y_test, y_pred))

ValueError: could not convert string to float: 'N'