# START TRAINING

In [1]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import when, lit

import numpy as np

warnings.simplefilter(action="ignore", category=UserWarning)

In [2]:
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [3]:
final = (
    session.table("MEN.FINAL_FEATURES")
    .with_column_renamed("WTEAMID", "W_TEAMID")
    .with_column_renamed("LTEAMID", "L_TEAMID")
    .with_column_renamed("WSCORE", "W_SCORE")
    .with_column_renamed("LSCORE", "L_SCORE")
)

In [4]:
OHE = OneHotEncoder(
    input_cols=["W_CONF","L_CONF"],
    output_cols=["W_CONF","L_CONF"],
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)

final = OHE.fit(final).transform(final)
final.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [5]:
parameters = {
    "n_estimators": [100, 200, 300, 400, 500],
    "learning_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "max_depth": list(range(3, 6, 1)),
    "min_child_weight": list(range(1, 6, 1)),
}

In [6]:
def add_losing_matches(df):
    lose_rename = {
        "W_TEAMID": "L_TEAMID", 
        "W_SCORE" : "L_SCORE", 
        "L_TEAMID" : "W_TEAMID",
        "L_SCORE": "W_SCORE",
    }
    lose_rename.update({c : "L_" + c[2:] for c in df.columns if c.startswith('W_')})
    lose_rename.update({c : "W_" + c[2:] for c in df.columns if c.startswith('L_')})
    
    win_df = df.copy()
    lose_df = df.copy()
    
    lose_df = lose_df.rename(columns=lose_rename)
    
    return pd.concat([win_df, lose_df], axis=0, sort=False)

In [7]:
train = final.filter(F.col('SEASON') <= 2023)
test = final.filter(F.col('SEASON') > 2021)

In [8]:
df_pd = train.to_pandas()
train_2 = add_losing_matches(df_pd)
train_2 = session.create_dataframe(train_2)

In [9]:
train_2 = train_2.withColumn(
    "WIN_INDICATOR", F.when(train_2["W_SCORE"] > train_2["L_SCORE"], 1).otherwise(0)
)

train_2 = train_2.with_columns(
    ["W_WLOCN", "W_WLOCH", "W_WLOCA", "L_WLOCN", "L_WLOCH", "L_WLOCA"],
    [
        F.col("W_WLOCN").cast(T.LongType()),
        F.col("W_WLOCH").cast(T.LongType()),
        F.col("W_WLOCA").cast(T.LongType()),
        F.col("L_WLOCN").cast(T.LongType()),
        F.col("L_WLOCH").cast(T.LongType()),
        F.col("L_WLOCA").cast(T.LongType()),
    ],
)

In [10]:
train_r0 = train_2.filter(F.col('ROUND') == 0)
train_r1 = train_2.filter(F.col('ROUND') == 1)
train_r2 = train_2.filter(F.col('ROUND') == 2)
train_r3 = train_2.filter(F.col('ROUND') == 3)
train_r4 = train_2.filter(F.col('ROUND') == 4)
train_r5 = train_2.filter(F.col('ROUND') == 5)
train_r6 = train_2.filter(F.col('ROUND') == 6)

In [11]:
session.use_warehouse('MM_L')

In [12]:
r1_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r1.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r1_model.fit(train_r1)

r2_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r2.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r2_model.fit(train_2)

r3_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r3.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r3_model.fit(train_r3)

r4_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r4.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r4_model.fit(train_r4)

r5_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r5.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r5_model.fit(train_r5)

r6_model = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_r6.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
r6_model.fit(train_r6)

The version of package 'snowflake-snowpark-python' in the local environment is 1.13.0, which does not fit the criteria for the requirement 'snowflake-snowpark-python<2'. Your UDF might not work when the package version is different between the server and your local environment.
Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.
The version of package 'pyarrow' in the local environment is 15.0.1, which does not fit the criteria for the requirement 'pyarrow<14'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'cachetools' in the local environment is 5.3.3, which does not fit the criteria for the requirement 'cachetools<6'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'snowflake-snowpark-python' in the loc

<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fdc50e092a0>

# Predicting the Bracket & final four

## 2024 Data

In [13]:
season_2024 = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024).drop('REGION')

In [14]:
tourney_2024 = session.table('COMMON.TOURNEY_SEEDS_2024').filter(F.col('TOURNAMENT') == 'M')

In [15]:
seed_value = (
    tourney_2024
    .with_column("REGION", F.substring(F.col("SEED"), 1, 1))
    .with_column(
        "SEED", F.substring(F.col("SEED"), 2, F.length(F.col("SEED")) - 1)
    )
    .select("TEAMID", "REGION", "SEED")
    .with_column(
        "SEED",
        F.cast(F.regexp_replace(F.col("SEED"), "[a-z]", ""), T.IntegerType()),
    )
).select(F.col('TEAMID').alias('TEAMID1'), 'REGION','SEED')

In [16]:
round_of_64 = season_2024.join(seed_value, season_2024.TEAMID == seed_value.TEAMID1).drop('TEAMID1')


In [17]:
df1 = round_of_64.select(
    *[F.col(col).alias(f"W_{col}") for col in round_of_64.columns]
)

df2 = round_of_64.select(
    *[F.col(col).alias(f"L_{col}") for col in round_of_64.columns]
)

In [18]:
first_round = df1.join(
    df2,
    (df1.w_region == df2.l_region)
    & (
        (df1.w_seed == 1) & (df2.l_seed == 16) |
        (df1.w_seed == 8) & (df2.l_seed == 9) |
        (df1.w_seed == 5) & (df2.l_seed == 12) |
        (df1.w_seed == 4) & (df2.l_seed == 13) |
        (df1.w_seed == 6) & (df2.l_seed == 11) |
        (df1.w_seed == 3) & (df2.l_seed == 14) |
        (df1.w_seed == 7) & (df2.l_seed == 10) |
        (df1.w_seed == 2) & (df2.l_seed == 15) 

    ),
    "inner",
)

first_round.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [19]:
result = r1_model.predict(first_round)

teams = session.table('m_teams')

result = result.select('l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','pred_win_indicator')

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_1_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select('l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','PRED_WIN_INDICATOR')
)

In [20]:
round_1_winners = round_1_results.with_column(
    "W_TEAM_ID",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_teamid)
    .otherwise(round_1_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_team_name)
    .otherwise(round_1_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_seed)
    .otherwise(round_1_results.l_seed),
).with_column(
    "W_REGION",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_region)
    .otherwise(round_1_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

round_1_winners.sort('W_REGION').show(32)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1140         |BYU             |6         |W           |
|1321         |Northwestern    |9         |W           |
|1361         |San Diego St    |5         |W           |
|1120         |Auburn          |4         |W           |
|1163         |Connecticut     |1         |W           |
|1450         |Washington St   |7         |W           |
|1235         |Iowa St         |2         |W           |
|1228         |Illinois        |3         |W           |
|1277         |Michigan St     |9         |X           |
|1173         |Dayton          |7         |X           |
|1124         |Baylor          |3         |X           |
|1155         |Clemson         |6         |X           |
|1104         |Alabama         |4         |X           |
|1112         |Arizona         |2         |X           |
|1388         |St Mary's CA    

In [21]:
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

second_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 16) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 4) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 13) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 3) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 14) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 2) & (df2.w_seed.isin([7,10]))
        | (df1.w_seed == 15) & (df2.w_seed.isin([7,10]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

second_round_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

-------------------------------------------------------------------
|"TEAM_1"        |"TEAM_2"       |"SEED_1"  |"SEED_2"  |"REGION"  |
-------------------------------------------------------------------
|Connecticut     |Northwestern   |1         |9         |W         |
|Illinois        |BYU            |3         |6         |W         |
|Iowa St         |Washington St  |2         |7         |W         |
|Auburn          |San Diego St   |4         |5         |W         |
|North Carolina  |Michigan St    |1         |9         |X         |
|Baylor          |Clemson        |3         |6         |X         |
|Alabama         |St Mary's CA   |4         |5         |X         |
|Arizona         |Dayton         |2         |7         |X         |
|Tennessee       |Texas          |2         |7         |Y         |
|Kansas          |Gonzaga        |4         |5         |Y         |
|Purdue          |TCU            |1         |9         |Y         |
|Creighton       |Oregon         |3         |11 

In [22]:
second_round = second_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = second_round.join(season_w, season_w.w_teamid == second_round.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == second_round.lTEAMID2)

games_32 = df2.drop('WTEAMID2','lTEAMID2')

games_32.count()

16

In [23]:
result = r2_model.predict(games_32)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [24]:
round_2_winners = round_2_results.with_column(
    "W_TEAM_ID",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_teamid)
    .otherwise(round_2_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_team_name)
    .otherwise(round_2_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_seed)
    .otherwise(round_2_results.l_seed),
).with_column(
    "W_REGION",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_region)
    .otherwise(round_2_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

# SWEET SIXTEEN

In [25]:
round_2_winners.sort('W_REGION').show(20)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1120         |Auburn          |4         |W           |
|1235         |Iowa St         |2         |W           |
|1163         |Connecticut     |1         |W           |
|1140         |BYU             |6         |W           |
|1314         |North Carolina  |1         |X           |
|1112         |Arizona         |2         |X           |
|1388         |St Mary's CA    |5         |X           |
|1124         |Baylor          |3         |X           |
|1211         |Gonzaga         |5         |Y           |
|1166         |Creighton       |3         |Y           |
|1345         |Purdue          |1         |Y           |
|1397         |Tennessee       |2         |Y           |
|1458         |Wisconsin       |5         |Z           |
|1246         |Kentucky        |3         |Z           |
|1266         |Marquette       

In [26]:
df1 = round_2_winners.cache_result()
df2 = round_2_winners.cache_result()

third_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 6) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 11) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 3) & (df2.w_seed.isin([7,10,2,15]))|
        (df1.w_seed == 14) & (df2.w_seed.isin([7,10,2,15]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

third_round_matchups.sort("W_REGION").show(20)

------------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"   |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
------------------------------------------------------------------------------------------------------------
|BYU             |W           |1140        |6         |Iowa St        |W           |1235        |2         |
|Connecticut     |W           |1163        |1         |Auburn         |W           |1120        |4         |
|Baylor          |X           |1124        |3         |Arizona        |X           |1112        |2         |
|North Carolina  |X           |1314        |1         |St Mary's CA   |X           |1388        |5         |
|Creighton       |Y           |1166        |3         |Tennessee      |Y           |1397        |2         |
|Purdue          |Y           |1345        |1         |Gonzaga        |Y           |1211        |5         |
|Houston         |Z

In [27]:
third_round = third_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = third_round.join(season_w, season_w.w_teamid == third_round.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == third_round.lTEAMID2)

games_16 = df2.drop('WTEAMID2','lTEAMID2')

games_16.count()

8

In [45]:
result = r3_model.predict(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

sweet_16_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [46]:
sweet_16_winners = sweet_16_results.with_column(
    "W_TEAM_ID",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_teamid)
    .otherwise(sweet_16_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_team_name)
    .otherwise(sweet_16_results.l_team_name),
).with_column(
    "W_SEED",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_seed)
    .otherwise(sweet_16_results.l_seed),
).with_column(
    "W_REGION",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_region)
    .otherwise(sweet_16_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

## ELITE 8

In [47]:
sweet_16_winners.sort('W_REGION').show(20)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1120         |Auburn          |4         |W           |
|1235         |Iowa St         |2         |W           |
|1112         |Arizona         |2         |X           |
|1314         |North Carolina  |1         |X           |
|1397         |Tennessee       |2         |Y           |
|1345         |Purdue          |1         |Y           |
|1222         |Houston         |1         |Z           |
|1266         |Marquette       |2         |Z           |
--------------------------------------------------------



In [48]:
df1 = sweet_16_winners.cache_result()
df2 = sweet_16_winners.cache_result()

elite_8_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 5) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 12) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 4) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 13) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) 
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

# ELITE 8

In [49]:
elite_8_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

---------------------------------------------------------------
|"TEAM_1"        |"TEAM_2"   |"SEED_1"  |"SEED_2"  |"REGION"  |
---------------------------------------------------------------
|Auburn          |Iowa St    |4         |2         |W         |
|North Carolina  |Arizona    |1         |2         |X         |
|Purdue          |Tennessee  |1         |2         |Y         |
|Houston         |Marquette  |1         |2         |Z         |
---------------------------------------------------------------



In [50]:
elite_8 = elite_8_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = elite_8.join(season_w, season_w.w_teamid == elite_8.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == elite_8.lTEAMID2)

elite_8 = df2.drop('WTEAMID2','lTEAMID2')

In [51]:
result = r4_model.predict(elite_8)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

elite_8_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [52]:
elite_8_winners = elite_8_results.with_column(
    "W_TEAM_ID",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_teamid)
    .otherwise(elite_8_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_team_name)
    .otherwise(elite_8_results.l_team_name),
).with_column(
    "W_SEED",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_seed)
    .otherwise(elite_8_results.l_seed),
).with_column(
    "W_REGION",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_region)
    .otherwise(elite_8_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [53]:
df1 = elite_8_winners.filter(F.col("W_REGION") == "W")

df2 = elite_8_winners.filter(F.col("W_REGION") == "X").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df3 = df1.join(df2)

df4 = elite_8_winners.filter(F.col("W_REGION") == "Y")

df5 = elite_8_winners.filter(F.col("W_REGION") == "Z").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df6 = df4.join(df5)

final_four_matchups = df3.union(df6).select(
    F.col('W_TEAM_NAME'),
    F.col('W_REGION'),
    F.col("W_TEAM_ID").alias("WTEAMID2"),
    F.col('W_SEED'),
    F.col("W_TEAM_NAME_2").alias("L_TEAM_NAME"),
    F.col("REGION_2").alias("L_REGION"),
    F.col("TEAM_2").alias("LTEAMID2"),
    F.col("SEED_2").alias("L_SEED"),
)

# FINAL FOUR

In [54]:
final_four_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show()

--------------------------------------------------------
|"TEAM_1"  |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------
|Iowa St   |Arizona   |2         |2         |W         |
|Purdue    |Houston   |1         |1         |Y         |
--------------------------------------------------------



In [55]:
final_four = final_four_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = final_four.join(season_w, season_w.w_teamid == final_four.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == final_four.lTEAMID2)

final_four = df2.drop('WTEAMID2','lTEAMID2')

final_four.count()

2

In [56]:
result = r5_model.predict(final_four)

teams = session.table("m_teams")


result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

final_four_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [57]:
final_four_winners = final_four_results.with_column(
    "W_TEAM_ID",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_teamid)
    .otherwise(final_four_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_team_name)
    .otherwise(final_four_results.l_team_name),
).with_column(
    "W_SEED",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_seed)
    .otherwise(final_four_results.l_seed),
).with_column(
    "W_REGION",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_region)
    .otherwise(final_four_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [58]:
df1 = final_four_winners.cache_result()
df2 = final_four_winners.cache_result()

championship_matchup = df1.join(df2).select(
    df1.w_team_name.alias('W_TEAM_NAME'),
    df1.w_region.alias('W_REGION'),
    df1.w_team_id.alias('WTEAMID2'),
    df1.w_seed.alias('W_SEED'),
    df2.w_team_name.alias('L_TEAM_NAME'),
    df2.w_region.alias('L_REGION'),
    df2.w_team_id.alias('LTEAMID2'),
    df2.w_seed.alias('L_SEED'),
).filter(F.col("WTEAMID2") != F.col("LTEAMID2")).limit(1)

# CHAMPIONSHIP

In [59]:
championship_matchup.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

--------------------------------------------------------
|"TEAM_1"  |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------
|Iowa St   |Houston   |2         |1         |W         |
--------------------------------------------------------



In [60]:
championship = championship_matchup.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = championship.join(season_w, season_w.w_teamid == championship.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == championship.lTEAMID2)

championship = df2.drop('WTEAMID2','lTEAMID2')

championship.count()

1

In [61]:
result = r6_model.predict(championship)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

championship_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [64]:
champion = championship_results.with_column(
    "W_TEAM_ID",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_teamid)
    .otherwise(championship_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_team_name)
    .otherwise(championship_results.l_team_name),
).with_column(
    "W_SEED",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_seed)
    .otherwise(championship_results.l_seed),
).with_column(
    "W_REGION",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_region)
    .otherwise(championship_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

# CHAMPION

In [65]:
champion.show()

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1235         |Iowa St        |2         |W           |
-------------------------------------------------------



# Use Average or Median for projected score

In [66]:
r6_model.predict(championship).select(
    "L_SCORE_MEDIAN", "W_SCORE_MEDIAN", "L_SCORE_MEAN", "W_SCORE_MEAN"
).show()

-------------------------------------------------------------------------
|"L_SCORE_MEDIAN"  |"W_SCORE_MEDIAN"  |"L_SCORE_MEAN"  |"W_SCORE_MEAN"  |
-------------------------------------------------------------------------
|75.0              |71.5              |73.029412       |75.558824       |
-------------------------------------------------------------------------

