# START TRAINING

In [None]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import when, lit

import numpy as np

warnings.simplefilter(action="ignore", category=UserWarning)

In [None]:
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

In [None]:
final = (
    session.table("MEN.FINAL_FEATURES")
    .with_column_renamed("WTEAMID", "W_TEAMID")
    .with_column_renamed("LTEAMID", "L_TEAMID")
    .with_column_renamed("WSCORE", "W_SCORE")
    .with_column_renamed("LSCORE", "L_SCORE")
)

In [None]:
OHE = OneHotEncoder(
    input_cols=["W_CONF","L_CONF"],
    output_cols=["W_CONF","L_CONF"],
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)

final = OHE.fit(final).transform(final)
final.show()

In [None]:
parameters = {
    "n_estimators": [100, 200, 300, 400, 500],
    "learning_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "max_depth": list(range(3, 6, 1)),
    "min_child_weight": list(range(1, 6, 1)),
}

In [None]:
def add_losing_matches(df):
    lose_rename = {
        "W_TEAMID": "L_TEAMID", 
        "W_SCORE" : "L_SCORE", 
        "L_TEAMID" : "W_TEAMID",
        "L_SCORE": "W_SCORE",
    }
    lose_rename.update({c : "L_" + c[2:] for c in df.columns if c.startswith('W_')})
    lose_rename.update({c : "W_" + c[2:] for c in df.columns if c.startswith('L_')})
    
    win_df = df.copy()
    lose_df = df.copy()
    
    lose_df = lose_df.rename(columns=lose_rename)
    
    return pd.concat([win_df, lose_df], axis=0, sort=False)

In [None]:
train = final.filter(F.col('SEASON') <= 2023)
test = final.filter(F.col('SEASON') > 2021)

In [None]:
df_pd = train.to_pandas()
train_2 = add_losing_matches(df_pd)
train_2 = session.create_dataframe(train_2)

In [None]:
train_2 = train_2.withColumn(
    "WIN_INDICATOR", F.when(train_2["W_SCORE"] > train_2["L_SCORE"], 1).otherwise(0)
)

In [None]:
test = test.withColumn(
    "WIN_INDICATOR", F.lit(1))

In [None]:
session.use_warehouse('MM_L')

In [None]:
train_2 = train_2.with_columns(
    ["W_WLOCN", "W_WLOCH", "W_WLOCA", "L_WLOCN", "L_WLOCH", "L_WLOCA"],
    [
        F.col("W_WLOCN").cast(T.LongType()),
        F.col("W_WLOCH").cast(T.LongType()),
        F.col("W_WLOCA").cast(T.LongType()),
        F.col("L_WLOCN").cast(T.LongType()),
        F.col("L_WLOCH").cast(T.LongType()),
        F.col("L_WLOCA").cast(T.LongType()),
    ],
)


In [None]:
all_rounds = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_2.drop(
        [
            "SEASON",
            "WIN_INDICATOR",
            "l_teamid",
            "w_teamid",
            "w_score",
            "l_score",
            "round",
            "l_region",
            "w_region",
            "l_seed",
            "w_seed"
        ]
    ).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
all_rounds.fit(train_2)

In [None]:
#session.use_warehouse('wh_xs')

In [None]:
result = all_rounds.predict(test).filter(F.col("ROUND") == 1)
result.count()

In [None]:
accuracy_2022 = accuracy_score(
    df=result.filter(result.season == 2022),
    y_true_col_names="WIN_INDICATOR",
    y_pred_col_names="PRED_WIN_INDICATOR",
)
print(f"Accuracy 2022: {accuracy_2022}")

accuracy_2023 = accuracy_score(
    df=result.filter(result.season == 2023),
    y_true_col_names="WIN_INDICATOR",
    y_pred_col_names="PRED_WIN_INDICATOR",
)
print(f"Accuracy 2023: {accuracy_2023}")

accuracy = accuracy_score(
    df=result, y_true_col_names="WIN_INDICATOR", y_pred_col_names="PRED_WIN_INDICATOR"
)

print(f"Accuracy total: {accuracy}")

#### Register Models that are good

In [None]:
optimal_model = all_rounds.to_sklearn().best_estimator_

def check_and_update(df, model_name):
    if df.empty:
        return "V_1"
    elif df[df["name"] == model_name].empty:
        return "V_1"
    else:
        # Increment model_version if df is not a pandas Series
        lst = sorted(ast.literal_eval(df["versions"][0]))
        last_value = lst[-1]
        prefix, num = last_value.rsplit("_", 1)
        new_last_value = f"{prefix}_{int(num)+1}"
        lst[-1] = new_last_value
        return new_last_value
# Get sample input data to pass into the registry logging function
X = train_2.drop(['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']).limit(100)

# Create a registry and log the model
# You can specify a different DB and Schema if you'd like
# otherwise it uses the session context
reg = Registry(session=session)

reg_df = reg.show_models()

# Define model name and version (use uppercase for name)
model_name = "MARCHMADNESS"

model_version = check_and_update(reg_df, model_name)

mm_model = reg.log_model(
    model_name=model_name,
    version_name=model_version,
    model=optimal_model,
    sample_input_data=X,
)

# Add evaluation metric
mm_model.set_metric(
    metric_name="accuracy_2023",
    value= accuracy_2023
)

mm_model.set_metric(
    metric_name="accuracy_2022",
    value= accuracy_2022,
)
reg.get_model(model_name).show_versions()

# Predicting the Bracket & final four

## 2024 Data

In [None]:
season_2024 = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024).drop('REGION')

In [None]:
tourney_2024 = session.table('COMMON.TOURNEY_SEEDS_2024').filter(F.col('TOURNAMENT') == 'M')

In [None]:
seed_value = (
    tourney_2024
    .with_column("REGION", F.substring(F.col("SEED"), 1, 1))
    .with_column(
        "SEED", F.substring(F.col("SEED"), 2, F.length(F.col("SEED")) - 1)
    )
    .select("TEAMID", "REGION", "SEED")
    .with_column(
        "SEED",
        F.cast(F.regexp_replace(F.col("SEED"), "[a-z]", ""), T.IntegerType()),
    )
).select(F.col('TEAMID').alias('TEAMID1'), 'REGION','SEED')

In [None]:
round_of_64 = season_2024.join(seed_value, season_2024.TEAMID == seed_value.TEAMID1).drop('TEAMID1')


In [None]:
df1 = round_of_64.select(
    *[F.col(col).alias(f"W_{col}") for col in round_of_64.columns]
)

df2 = round_of_64.select(
    *[F.col(col).alias(f"L_{col}") for col in round_of_64.columns]
)

In [None]:
first_round = df1.join(
    df2,
    (df1.w_region == df2.l_region)
    & (
        (df1.w_seed == 1) & (df2.l_seed == 16) |
        (df1.w_seed == 8) & (df2.l_seed == 9) |
        (df1.w_seed == 5) & (df2.l_seed == 12) |
        (df1.w_seed == 4) & (df2.l_seed == 13) |
        (df1.w_seed == 6) & (df2.l_seed == 11) |
        (df1.w_seed == 3) & (df2.l_seed == 14) |
        (df1.w_seed == 7) & (df2.l_seed == 10) |
        (df1.w_seed == 2) & (df2.l_seed == 15) 

    ),
    "inner",
)

first_round.show()

In [24]:
result = all_rounds.predict(first_round)

teams = session.table('m_teams')

result = result.select('l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','pred_win_indicator')

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_1_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select('l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','PRED_WIN_INDICATOR')
)

In [25]:
result.show()

--------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"L_SEED"  |"L_REGION"  |"W_TEAMID"  |"W_SEED"  |"W_REGION"  |"PRED_WIN_INDICATOR"  |
--------------------------------------------------------------------------------------------------
|1389        |15        |Y           |1397        |2         |Y           |1.0                   |
|1270        |12        |Y           |1211        |5         |Y           |0.0                   |
|1103        |14        |Y           |1166        |3         |Y           |1.0                   |
|1179        |10        |W           |1450        |7         |W           |0.0                   |
|1253        |15        |X           |1112        |2         |X           |1.0                   |
|1224        |16        |X           |1314        |1         |X           |1.0                   |
|1355        |15        |W           |1235        |2         |W           |1.0                   |
|1129     

In [26]:
round_1_winners = round_1_results.with_column(
    "W_TEAM_ID",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_teamid)
    .otherwise(round_1_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_team_name)
    .otherwise(round_1_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_seed)
    .otherwise(round_1_results.l_seed),
).with_column(
    "W_REGION",
    when((round_1_results.PRED_WIN_INDICATOR == 1), round_1_results.w_region)
    .otherwise(round_1_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

round_1_winners.sort('W_REGION').show(32)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1140         |BYU             |6         |W           |
|1321         |Northwestern    |9         |W           |
|1412         |UAB             |12        |W           |
|1120         |Auburn          |4         |W           |
|1163         |Connecticut     |1         |W           |
|1179         |Drake           |10        |W           |
|1235         |Iowa St         |2         |W           |
|1228         |Illinois        |3         |W           |
|1277         |Michigan St     |9         |X           |
|1305         |Nevada          |10        |X           |
|1124         |Baylor          |3         |X           |
|1307         |New Mexico      |11        |X           |
|1104         |Alabama         |4         |X           |
|1112         |Arizona         |2         |X           |
|1213         |Grand Canyon    

In [27]:
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

second_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 16) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 4) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 13) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 3) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 14) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 2) & (df2.w_seed.isin([7,10]))
        | (df1.w_seed == 15) & (df2.w_seed.isin([7,10]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

second_round_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

--------------------------------------------------------------------
|"TEAM_1"        |"TEAM_2"        |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------------------
|Connecticut     |Northwestern    |1         |9         |W         |
|Illinois        |BYU             |3         |6         |W         |
|Iowa St         |Drake           |2         |10        |W         |
|Auburn          |UAB             |4         |12        |W         |
|North Carolina  |Michigan St     |1         |9         |X         |
|Baylor          |New Mexico      |3         |11        |X         |
|Alabama         |Grand Canyon    |4         |12        |X         |
|Arizona         |Nevada          |2         |10        |X         |
|Tennessee       |Texas           |2         |7         |Y         |
|Samford         |McNeese St      |13        |12        |Y         |
|Purdue          |TCU             |1         |9         |Y         |
|Creighton       |South Carolina  

In [28]:
second_round = second_round_matchups.select("WTEAMID2", "LTEAMID2", "W_SEED", "L_SEED")

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col("SEASON") == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = second_round.join(season_w, season_w.w_teamid == second_round.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == second_round.lTEAMID2)

games_32 = df2.drop("WTEAMID2", "lTEAMID2")

games_32.count()

16

In [29]:
result = all_rounds.predict(games_32)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [30]:
round_2_winners = round_2_results.with_column(
    "W_TEAM_ID",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_teamid)
    .otherwise(round_2_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_team_name)
    .otherwise(round_2_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_seed)
    .otherwise(round_2_results.l_seed),
).with_column(
    "W_REGION",
    when((round_2_results.PRED_WIN_INDICATOR == 1), round_2_results.w_region)
    .otherwise(round_2_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

# SWEET SIXTEEN

In [31]:
round_2_winners.sort('W_REGION').show(20)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1163         |Connecticut     |1         |W           |
|1235         |Iowa St         |2         |W           |
|1228         |Illinois        |3         |W           |
|1120         |Auburn          |4         |W           |
|1277         |Michigan St     |9         |X           |
|1104         |Alabama         |4         |X           |
|1112         |Arizona         |2         |X           |
|1124         |Baylor          |3         |X           |
|1376         |South Carolina  |6         |Y           |
|1270         |McNeese St      |12        |Y           |
|1345         |Purdue          |1         |Y           |
|1400         |Texas           |7         |Y           |
|1458         |Wisconsin       |5         |Z           |
|1266         |Marquette       |2         |Z           |
|1222         |Houston         

In [32]:
df1 = round_2_winners.cache_result()
df2 = round_2_winners.cache_result()

third_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 6) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 11) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 3) & (df2.w_seed.isin([7,10,2,15]))|
        (df1.w_seed == 14) & (df2.w_seed.isin([7,10,2,15]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

third_round_matchups.sort("W_REGION").show(20)

------------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"   |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
------------------------------------------------------------------------------------------------------------
|Illinois        |W           |1228        |3         |Iowa St        |W           |1235        |2         |
|Connecticut     |W           |1163        |1         |Auburn         |W           |1120        |4         |
|Michigan St     |X           |1277        |9         |Alabama        |X           |1104        |4         |
|Baylor          |X           |1124        |3         |Arizona        |X           |1112        |2         |
|Purdue          |Y           |1345        |1         |McNeese St     |Y           |1270        |12        |
|South Carolina  |Y           |1376        |6         |Texas          |Y           |1400        |7         |
|Houston         |Z

In [33]:
third_round = third_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = third_round.join(season_w, season_w.w_teamid == third_round.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == third_round.lTEAMID2)

games_16 = df2.drop('WTEAMID2','lTEAMID2')

games_16.count()

8

In [34]:
result = all_rounds.predict(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

sweet_16_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [35]:
sweet_16_winners = sweet_16_results.with_column(
    "W_TEAM_ID",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_teamid)
    .otherwise(sweet_16_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_team_name)
    .otherwise(sweet_16_results.l_team_name),
).with_column(
    "W_SEED",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_seed)
    .otherwise(sweet_16_results.l_seed),
).with_column(
    "W_REGION",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_region)
    .otherwise(sweet_16_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

## ELITE 8

In [36]:
sweet_16_winners.sort('W_REGION').show(20)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1163         |Connecticut     |1         |W           |
|1235         |Iowa St         |2         |W           |
|1112         |Arizona         |2         |X           |
|1104         |Alabama         |4         |X           |
|1345         |Purdue          |1         |Y           |
|1376         |South Carolina  |6         |Y           |
|1222         |Houston         |1         |Z           |
|1246         |Kentucky        |3         |Z           |
--------------------------------------------------------



In [37]:
df1 = sweet_16_winners.cache_result()
df2 = sweet_16_winners.cache_result()

elite_8_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 5) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 12) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 4) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 13) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) 
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

# ELITE 8

In [38]:
elite_8_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

-----------------------------------------------------------------
|"TEAM_1"     |"TEAM_2"        |"SEED_1"  |"SEED_2"  |"REGION"  |
-----------------------------------------------------------------
|Connecticut  |Iowa St         |1         |2         |W         |
|Alabama      |Arizona         |4         |2         |X         |
|Purdue       |South Carolina  |1         |6         |Y         |
|Houston      |Kentucky        |1         |3         |Z         |
-----------------------------------------------------------------



In [39]:
elite_8 = elite_8_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = elite_8.join(season_w, season_w.w_teamid == elite_8.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == elite_8.lTEAMID2)

elite_8 = df2.drop('WTEAMID2','lTEAMID2')

In [40]:
result = all_rounds.predict(elite_8)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

elite_8_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [41]:
elite_8_winners = elite_8_results.with_column(
    "W_TEAM_ID",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_teamid)
    .otherwise(elite_8_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_team_name)
    .otherwise(elite_8_results.l_team_name),
).with_column(
    "W_SEED",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_seed)
    .otherwise(elite_8_results.l_seed),
).with_column(
    "W_REGION",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_region)
    .otherwise(elite_8_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [42]:
df1 = elite_8_winners.filter(F.col("W_REGION") == "W")

df2 = elite_8_winners.filter(F.col("W_REGION") == "X").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df3 = df1.join(df2)

df4 = elite_8_winners.filter(F.col("W_REGION") == "Y")

df5 = elite_8_winners.filter(F.col("W_REGION") == "Z").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df6 = df4.join(df5)

final_four_matchups = df3.union(df6).select(
    F.col('W_TEAM_NAME'),
    F.col('W_REGION'),
    F.col("W_TEAM_ID").alias("WTEAMID2"),
    F.col('W_SEED'),
    F.col("W_TEAM_NAME_2").alias("L_TEAM_NAME"),
    F.col("REGION_2").alias("L_REGION"),
    F.col("TEAM_2").alias("LTEAMID2"),
    F.col("SEED_2").alias("L_SEED"),
)

# FINAL FOUR

In [43]:
final_four_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show()

--------------------------------------------------------------
|"TEAM_1"        |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------------
|Connecticut     |Arizona   |1         |2         |W         |
|South Carolina  |Kentucky  |6         |3         |Y         |
--------------------------------------------------------------



In [44]:
final_four = final_four_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = final_four.join(season_w, season_w.w_teamid == final_four.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == final_four.lTEAMID2)

final_four = df2.drop('WTEAMID2','lTEAMID2')

final_four.count()

2

In [45]:
result = all_rounds.predict(final_four)

teams = session.table("m_teams")


result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

final_four_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [46]:
final_four_winners = final_four_results.with_column(
    "W_TEAM_ID",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_teamid)
    .otherwise(final_four_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_team_name)
    .otherwise(final_four_results.l_team_name),
).with_column(
    "W_SEED",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_seed)
    .otherwise(final_four_results.l_seed),
).with_column(
    "W_REGION",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_region)
    .otherwise(final_four_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [47]:
df1 = final_four_winners.cache_result()
df2 = final_four_winners.cache_result()

championship_matchup = df1.join(df2).select(
    df1.w_team_name.alias('W_TEAM_NAME'),
    df1.w_region.alias('W_REGION'),
    df1.w_team_id.alias('WTEAMID2'),
    df1.w_seed.alias('W_SEED'),
    df2.w_team_name.alias('L_TEAM_NAME'),
    df2.w_region.alias('L_REGION'),
    df2.w_team_id.alias('LTEAMID2'),
    df2.w_seed.alias('L_SEED'),
).filter(F.col("WTEAMID2") != F.col("LTEAMID2")).limit(1)

# CHAMPIONSHIP

In [48]:
championship_matchup.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

-----------------------------------------------------------
|"TEAM_1"  |"TEAM_2"     |"SEED_1"  |"SEED_2"  |"REGION"  |
-----------------------------------------------------------
|Kentucky  |Connecticut  |3         |1         |Z         |
-----------------------------------------------------------



In [49]:
championship = championship_matchup.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = championship.join(season_w, season_w.w_teamid == championship.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == championship.lTEAMID2)

championship = df2.drop('WTEAMID2','lTEAMID2')

championship.count()

1

In [50]:
result = all_rounds.predict(championship)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

championship_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [52]:
champion = championship_results.with_column(
    "W_TEAM_ID",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_teamid)
    .otherwise(championship_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_team_name)
    .otherwise(championship_results.l_team_name),
).with_column(
    "W_SEED",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_seed)
    .otherwise(championship_results.l_seed),
).with_column(
    "W_REGION",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_region)
    .otherwise(championship_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

# CHAMPION

In [53]:
champion.show()

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1163         |Connecticut    |1         |W           |
-------------------------------------------------------



# Use Average and Median for predicting score

In [54]:
all_rounds.predict(championship).select(
    "L_SCORE_MEDIAN", "W_SCORE_MEDIAN", "L_SCORE_MEAN", "W_SCORE_MEAN"
).show()

-------------------------------------------------------------------------
|"L_SCORE_MEDIAN"  |"W_SCORE_MEDIAN"  |"L_SCORE_MEAN"  |"W_SCORE_MEAN"  |
-------------------------------------------------------------------------
|80.5              |90.0              |81.470588       |89.4375         |
-------------------------------------------------------------------------

