# START TRAINING

In [49]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import when, lit

import numpy as np

# Suppress every UserWarning for the rest of the kernel session
# (presumably to keep library warning noise out of cell outputs).
warnings.simplefilter(action="ignore", category=UserWarning)

In [50]:
# Open (or reuse) a Snowpark session from the connection parameters that
# SnowflakeLoginOptions() resolves. NOTE(review): presumably the local
# SnowSQL config / environment variables — confirm for your setup.
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

In [51]:
# Load the engineered features and rename the team-id / score columns to the
# W_<x> / L_<x> prefix convention that add_losing_matches relies on.
_FINAL_RENAMES = {
    "WTEAMID": "W_TEAMID",
    "LTEAMID": "L_TEAMID",
    "WSCORE": "W_SCORE",
    "LSCORE": "L_SCORE",
}
final = session.table("MEN.FINAL_FEATURES")
for _old_name, _new_name in _FINAL_RENAMES.items():
    final = final.with_column_renamed(_old_name, _new_name)

In [52]:
# Preview a sample of rows of the renamed feature table.
final.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [53]:
# One-hot encode each side's conference. The encoded columns replace the
# inputs (drop_input_cols=True), drop="first" removes one redundant dummy per
# column, and handle_unknown="ignore" encodes unseen values as all zeros.
# NOTE(review): the encoder is fitted on `final`, i.e. on every season
# including the later hold-out seasons. This is harmless for a category
# vocabulary, but fit on the training seasons only if that matters.
_conf_cols = ["W_CONF", "L_CONF"]
OHE = OneHotEncoder(
    input_cols=list(_conf_cols),
    output_cols=list(_conf_cols),
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)
OHE.fit(final)
final = OHE.transform(final)
final.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [54]:
# Hyper-parameter grid for GridSearchCV. Only the number of boosting rounds
# (100, 200, ..., 500) is searched. Further candidates that can be added:
#   learning_rate:    [0.1, 0.2, 0.3, 0.4, 0.5]
#   max_depth:        list(range(3, 6))
#   min_child_weight: list(range(1, 6))
parameters = {"n_estimators": list(range(100, 501, 100))}

In [55]:
def add_losing_matches(df):
    """Return ``df`` followed by a mirrored copy with the W_/L_ sides swapped.

    Each row of ``df`` describes one game through ``W_``-prefixed and
    ``L_``-prefixed columns. The mirrored copy renames every ``W_<x>`` column
    to ``L_<x>`` and every ``L_<x>`` column to ``W_<x>``, so each game also
    appears with the other team in the ``W_`` slot. This gives the training
    set both outcomes for every game instead of only one.

    Parameters
    ----------
    df : pandas.DataFrame
        Games with ``W_``/``L_`` prefixed columns (e.g. W_TEAMID, L_SCORE).
        Columns without either prefix (e.g. SEASON) are copied unchanged.
        ``df`` itself is not modified.

    Returns
    -------
    pandas.DataFrame
        ``2 * len(df)`` rows: the original rows, then the mirrored rows, with
        a fresh 0..n-1 index.
    """
    # Build one mapping for both directions. DataFrame.rename applies it in a
    # single pass, so the two sides are swapped rather than collapsed.
    # (This covers W_TEAMID/L_TEAMID and W_SCORE/L_SCORE as well.)
    swap = {c: "L_" + c[2:] for c in df.columns if c.startswith("W_")}
    swap.update({c: "W_" + c[2:] for c in df.columns if c.startswith("L_")})

    mirrored = df.rename(columns=swap)

    # ignore_index: both halves carry the same index labels. Without it every
    # label would appear twice, which breaks label-based lookups downstream.
    return pd.concat([df, mirrored], axis=0, sort=False, ignore_index=True)

In [56]:
# Hold out the most recent seasons: games up to and including
# LAST_TRAIN_SEASON train the model; later seasons are only used for evaluation.
LAST_TRAIN_SEASON = 2021
train = final.filter(F.col("SEASON") <= LAST_TRAIN_SEASON)
test = final.filter(F.col("SEASON") > LAST_TRAIN_SEASON)

In [57]:
# Mirror every training game in pandas (add_losing_matches), then upload the
# doubled table back to Snowflake as a Snowpark DataFrame.
df_pd = train.to_pandas()
train_2 = session.create_dataframe(add_losing_matches(df_pd))

In [58]:
# Training label: 1 when the team in the W_ slot outscored the L_ team, else 0.
# NOTE(review): assuming W_SCORE is the winning score in the raw data, the
# original rows get 1 and their mirrored copies get 0.
_w_outscored_l = train_2["W_SCORE"] > train_2["L_SCORE"]
train_2 = train_2.withColumn("WIN_INDICATOR", F.when(_w_outscored_l, 1).otherwise(0))

In [72]:
# The test set is not mirrored (add_losing_matches is applied to train only),
# so every row keeps the W_ team in the W_ slot and the true label is always 1.
# Accuracy on this set is therefore the share of games in which the model
# picks the W_ team.
test = test.withColumn(
    "WIN_INDICATOR", F.lit(1))

In [60]:
# Switch to the MM_L warehouse before the grid search below.
# NOTE(review): presumably a larger warehouse sized for training — confirm.
session.use_warehouse('MM_L')

In [67]:
# Cast the six W_/L_ WLOC{N,H,A} indicator columns to LongType.
# NOTE(review): presumably needed because the pandas round-trip changed their
# type — confirm. The `test` frame is not cast the same way.
_wloc_cols = [f"{side}_WLOC{venue}" for side in ("W", "L") for venue in ("N", "H", "A")]
train_2 = train_2.with_columns(
    _wloc_cols,
    [F.col(name).cast(T.LongType()) for name in _wloc_cols],
)


In [68]:
# Identifier, outcome and bracket-metadata columns; every other column of
# train_2 is used as a model input.
_non_feature_cols = ['SEASON', 'WIN_INDICATOR', 'l_teamid', 'w_teamid', 'w_score', 'l_score', 'round', 'l_region', 'w_region']
_feature_cols = train_2.drop(_non_feature_cols).columns

# Grid-search an XGBoost classifier over `parameters`, scored on accuracy
# (n_jobs=-1: use all available workers).
all_rounds = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=_feature_cols,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
all_rounds.fit(train_2)

The version of package 'snowflake-snowpark-python' in the local environment is 1.13.0, which does not fit the criteria for the requirement 'snowflake-snowpark-python<2'. Your UDF might not work when the package version is different between the server and your local environment.
Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.
The version of package 'pyarrow' in the local environment is 15.0.1, which does not fit the criteria for the requirement 'pyarrow<14'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'cachetools' in the local environment is 5.3.3, which does not fit the criteria for the requirement 'cachetools<6'. Your UDF might not work when the package version is different between the server and your local environment.


<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fbc31df06a0>

In [None]:
#session.use_warehouse('wh_xs')

In [74]:
# Sanity check: number of round-1 games across the held-out test seasons.
result = all_rounds.predict(test).filter(F.col("ROUND") == 1)
result.count()

64

In [75]:
# Round-1 predictions for every held-out season.
result = all_rounds.predict(test).filter(F.col("ROUND") == 1)


def _win_accuracy(scored):
    """Accuracy of PRED_WIN_INDICATOR against WIN_INDICATOR on ``scored``."""
    return accuracy_score(
        df=scored,
        y_true_col_names="WIN_INDICATOR",
        y_pred_col_names="PRED_WIN_INDICATOR",
    )


accuracy_2022 = _win_accuracy(result.filter(result["SEASON"] == 2022))
print(f"Accuracy 2022: {accuracy_2022}")

accuracy_2023 = _win_accuracy(result.filter(result["SEASON"] == 2023))
print(f"Accuracy 2023: {accuracy_2023}")

accuracy = _win_accuracy(result)

print(f"Accuracy total: {accuracy}")

Accuracy 2022: 0.8125
Accuracy 2023: 0.6875
Accuracy total: 0.75


#### Register well-performing models

In [77]:
# Convert the fitted Snowpark GridSearchCV to its scikit-learn counterpart and
# take the refit best estimator (an XGBClassifier) for registration.
optimal_model = all_rounds.to_sklearn().best_estimator_

def check_and_update(df, model_name):
    """Return the version name to use for the next ``log_model`` of ``model_name``.

    Parameters
    ----------
    df : pandas.DataFrame
        Result of ``Registry.show_models()``. Expected columns: ``name`` (the
        model name) and ``versions`` (a list literal of version names such as
        ``'["V_1","V_2"]'``, or an already-parsed list).
    model_name : str
        Name of the model in the registry.

    Returns
    -------
    str
        ``"V_1"`` if ``df`` is empty, the model is not registered yet, or none
        of its versions ends in ``_<number>``. Otherwise the highest existing
        ``<prefix>_<n>`` version with ``n + 1``, e.g. ``V_10`` -> ``V_11``.
    """
    if df.empty:
        return "V_1"

    # Read the versions of *this* model; row 0 of the listing may belong to
    # another model.
    model_rows = df[df["name"] == model_name]
    if model_rows.empty:
        return "V_1"

    versions = model_rows["versions"].iloc[0]
    if isinstance(versions, str):
        # literal_eval parses the list literal without executing any code.
        versions = ast.literal_eval(versions)

    # Compare version numbers numerically. A lexical sort would rank "V_9"
    # above "V_10" and hand out a version name that already exists.
    numbered = []
    for version in versions:
        prefix, sep, number = str(version).rpartition("_")
        if sep and number.isdecimal():
            numbered.append((int(number), prefix))

    if not numbered:
        return "V_1"

    highest, prefix = max(numbered)
    return f"{prefix}_{highest + 1}"
# A sample of feature rows (the same non-feature columns dropped as for
# training) passed to log_model as sample_input_data.
X = train_2.drop(
    ['SEASON', 'WIN_INDICATOR', 'l_teamid', 'w_teamid', 'w_score', 'l_score', 'round', 'l_region', 'w_region']
).limit(100)

# Registry in the session's current database/schema (Registry also accepts
# database_name / schema_name to log elsewhere).
reg = Registry(session=session)
reg_df = reg.show_models()

# Model name (upper case) and the next free version for it.
model_name = "MARCHMADNESS"
model_version = check_and_update(reg_df, model_name)

mm_model = reg.log_model(
    model=optimal_model,
    model_name=model_name,
    version_name=model_version,
    sample_input_data=X,
)

# Attach the per-season round-1 accuracies to this model version.
for _metric_name, _metric_value in (
    ("accuracy_2023", accuracy_2023),
    ("accuracy_2022", accuracy_2022),
):
    mm_model.set_metric(metric_name=_metric_name, value=_metric_value)

reg.get_model(model_name).show_versions()

Unnamed: 0,created_on,name,comment,database_name,schema_name,module_name,is_default_version,functions,metadata,user_data
0,2024-03-20 14:41:44.382000-07:00,V_1,,MARCHMADNESS,MEN,MARCHMADNESS,True,"[""PREDICT_PROBA"",""PREDICT"",""APPLY""]","{""metrics"": {""accuracy_2023"": 0.6875, ""accurac...","{""snowpark_ml_data"":{""functions"":[{""name"":""APP..."


# Predicting the bracket and the Final Four (goal: correctly pick as many Final Four teams as possible)

## Results: play-in games

In [79]:
# 2023 play-in (ROUND 0) games: score them, then attach both teams' names.
# Filtering before predict scores only the rows that are kept.
result = all_rounds.predict(
    test.filter((F.col('ROUND') == 0) & (F.col("season") == 2023))
)

teams = session.table('m_teams')

result = result.select('season','l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','win_indicator','pred_win_indicator')

# Name the L_ team; FIRSTD1SEASON / LASTD1SEASON are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Materialise before joining m_teams a second time (presumably to avoid
# Snowpark self-join column ambiguity).
res_teamsl = res_teamsl.cache_result()

# Name the W_ team. The join key comes from res_teamsl, the frame actually
# being joined, rather than from `result`.
res_teams = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select('season','l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR')
)

res_teams.sort(F.col("w_team_name")).show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"   |"L_SEED"  |"L_REGION"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1305        |1113        |Nevada          |11        |Z           |Arizona St      |11        |Z           |1                |1.0                   |
|2023      |1411        |1192        |TX Southern     |16        |W           |F Dickinson     |16        |W           |1                |1.0                   |
|2023      |1280        |1338        |Mississippi St  |11        |Y           |Pittsburgh      |11        |Y           |1                |1.0                   |
|2023      |1369        |139

## Results: round 1

In [80]:
# 2023 round-1 (ROUND 1) games: score them, then attach both teams' names.
# Both filters are applied before predict, so only kept rows are scored.
result = all_rounds.predict(
    test.filter((F.col('ROUND') == 1) & (F.col("season") == 2023))
)

teams = session.table('m_teams')

result = result.select('season','l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','win_indicator','pred_win_indicator')

# Name the L_ team; FIRSTD1SEASON / LASTD1SEASON are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Materialise before joining m_teams a second time (presumably to avoid
# Snowpark self-join column ambiguity).
res_teamsl = res_teamsl.cache_result()

# Name the W_ team. The join key comes from res_teamsl, the frame actually
# being joined, rather than from `result`.
res_teams = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select('season','l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR')
)

# NOTE(review): these rows are the first-round games (32 games, 64 teams);
# the printed label may be meant as "32 games".
print('The round of 32')
res_teams.count()

The round of 32


32

In [81]:
# 2023 round-1 games with both teams' names, in the column layout used to
# derive the predicted round-1 winners.
result = all_rounds.predict(
    test.filter((F.col("ROUND") == 1) & (F.col("season") == 2023))
)

# m_teams is used as-is. Unquoted Snowflake identifiers already resolve case-
# insensitively, and a renaming loop named `col` would shadow the imported
# snowflake.snowpark.functions.col.
teams = session.table("m_teams")

result = (
    result.select(
        "season",
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "win_indicator",
        "pred_win_indicator",
    )
)

# Name the L_ team; FIRSTD1SEASON / LASTD1SEASON are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Materialise before joining m_teams a second time (presumably to avoid
# Snowpark self-join column ambiguity).
res_teamsl = res_teamsl.cache_result()

# Name the W_ team. The join key comes from res_teamsl, the frame actually
# being joined, rather than from `result`.
round_1_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "season",
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "WIN_INDICATOR",
        "PRED_WIN_INDICATOR",
    )
)

In [82]:
# Show all 32 round-1 games grouped by region.
round_1_results.sort('W_REGION').show(32)

------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"     |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1272        |1194        |Memphis           |8         |9         |W           |FL Atlantic    |W           |1                |1.0                   |
|2023      |1418        |1397        |Louisiana         |13        |4         |W           |Tennessee      |W           |1                |1.0                   |
|2023      |1344        |1246        |Providence        |11        |6         |W           |Kentucky       |W           |1                |1.0                   |
|2023      |1331      

In [83]:
# Resolve each round-1 game to the team the model picked: the W_ team when
# PRED_WIN_INDICATOR == 1, otherwise the L_ team.
_picked_w = round_1_results.PRED_WIN_INDICATOR == 1
round_1_winners = round_1_results.select(
    when(_picked_w, round_1_results.w_teamid)
    .otherwise(round_1_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(_picked_w, round_1_results.w_team_name)
    .otherwise(round_1_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(_picked_w, round_1_results.w_seed)
    .otherwise(round_1_results.l_seed)
    .alias("W_SEED"),
    when(_picked_w, round_1_results.w_region)
    .otherwise(round_1_results.l_region)
    .alias("W_REGION"),
)

# Round 1 winners

In [84]:
# Show the 32 predicted round-1 winners grouped by region.
round_1_winners.sort('W_REGION').show(32)

--------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
--------------------------------------------------------
|1194         |FL Atlantic     |9         |W           |
|1397         |Tennessee       |4         |W           |
|1246         |Kentucky        |6         |W           |
|1331         |Oral Roberts    |12        |W           |
|1425         |USC             |10        |W           |
|1266         |Marquette       |2         |W           |
|1345         |Purdue          |1         |W           |
|1243         |Kansas St       |3         |W           |
|1112         |Arizona         |2         |X           |
|1124         |Baylor          |3         |X           |
|1268         |Maryland        |8         |X           |
|1438         |Virginia        |4         |X           |
|1104         |Alabama         |1         |X           |
|1158         |Col Charleston  |12        |X           |
|1166         |Creighton       

In [85]:
# Round-of-32 games: within each region, pair the round-1 winners whose seed
# lines meet next. df1/df2 are two cached copies, presumably so the self-join
# has two distinct operands.
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

# (seeds that can occupy the first slot, seeds that can occupy the second slot)
_r32_slot_pairs = [
    ([1, 16], [8, 9]),
    ([4, 13], [5, 12]),
    ([3, 14], [6, 11]),
    ([2, 15], [7, 10]),
]
_r32_seed_match = None
for _first_seeds, _second_seeds in _r32_slot_pairs:
    _clause = df1.w_seed.isin(_first_seeds) & df2.w_seed.isin(_second_seeds)
    _r32_seed_match = _clause if _r32_seed_match is None else (_r32_seed_match | _clause)

second_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region) & _r32_seed_match,
    "inner",
).select(
    df1.W_TEAM_NAME.alias("W_team_name"),
    df1.W_REGION.alias("w_region"),
    df1.W_TEAM_ID.alias("WTeamID2"),
    df1.W_SEED.alias("w_seed"),
    df2.W_TEAM_NAME.alias("l_Team_name"),
    df2.W_REGION.alias("l_region"),
    df2.W_TEAM_ID.alias("lteamID2"),
    df2.W_SEED.alias("l_seed"),
)

second_round_matchups.sort('W_REGION').show(16)

------------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"   |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
------------------------------------------------------------------------------------------------------------
|Kansas St      |W           |1243        |3         |Kentucky        |W           |1246        |6         |
|Purdue         |W           |1345        |1         |FL Atlantic     |W           |1194        |9         |
|Tennessee      |W           |1397        |4         |Oral Roberts    |W           |1331        |12        |
|Marquette      |W           |1266        |2         |USC             |W           |1425        |10        |
|Alabama        |X           |1104        |1         |Maryland        |X           |1268        |8         |
|Virginia       |X           |1438        |4         |Col Charleston  |X           |1158        |12        |
|Arizona        |X 

# ROUND OF 32!!!

In [86]:
# Display the round-of-32 bracket. Both slots hold round-1 winners, so
# TEAM_1 / TEAM_2 are just bracket positions, not winner / loser.
second_round_matchups.select(
    *[
        F.col(source).alias(label)
        for source, label in (
            ("W_TEAM_NAME", "TEAM_1"),
            ("L_TEAM_NAME", "TEAM_2"),
            ("W_SEED", "SEED_1"),
            ("L_SEED", "SEED_2"),
            ("W_REGION", "REGION"),
        )
    ]
).sort("REGION").show(20)

-----------------------------------------------------------------
|"TEAM_1"     |"TEAM_2"        |"SEED_1"  |"SEED_2"  |"REGION"  |
-----------------------------------------------------------------
|Kansas St    |Kentucky        |3         |6         |W         |
|Purdue       |FL Atlantic     |1         |9         |W         |
|Tennessee    |Oral Roberts    |4         |12        |W         |
|Marquette    |USC             |2         |10        |W         |
|Alabama      |Maryland        |1         |8         |X         |
|Virginia     |Col Charleston  |4         |12        |X         |
|Arizona      |Missouri        |2         |7         |X         |
|Baylor       |Creighton       |3         |6         |X         |
|Houston      |Iowa            |1         |8         |Y         |
|Xavier       |Iowa St         |3         |6         |Y         |
|Texas        |Penn St         |2         |10        |Y         |
|Kent         |Miami FL        |13        |5         |Y         |
|Kansas   

In [90]:
# Preview the per-team season statistics that the bracket cells join onto
# each simulated matchup.
session.table('FINAL_SEASON_STATS').show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"TEAMID"  |"SCORE_MEAN"  |"FGM_MEAN"  |"FGA_MEAN"  |"FGM3_MEAN"  |"FGA3_MEAN"  |"FTM_MEAN"  |"FTA_MEAN"  |"OR_MEAN"  |"DR_MEAN"  |"AST_ME

In [93]:
# Round-of-32 feature rows: the two team ids and seeds of each matchup, joined
# to both teams' 2023 season statistics under W_ / L_ prefixes.
second_round = second_round_matchups.select('WTEAMID2', 'LTEAMID2', 'W_SEED', 'L_SEED')

# NOTE(review): FINAL_SEASON_STATS exposes SEASON (not W_SEASON) before the
# prefixes below are applied, so this drop looks like a no-op — confirm intent.
season = (
    session.table("FINAL_SEASON_STATS")
    .filter(F.col("season") == 2023)
    .drop("W_SEASON")
)

season_w, season_l = (
    season.select(*[F.col(name).alias(f"{side}_{name}") for name in season.columns])
    for side in ("W", "L")
)

df1 = second_round.join(season_w, season_w.w_teamid == second_round.wTEAMID2)
df2 = df1.join(season_l, season_l.l_teamid == second_round.lTEAMID2)

# Drop the matchup join keys; one row per game remains.
games_32 = df2.drop('WTEAMID2', 'lTEAMID2')

games_32.count()

16

In [94]:
# Score the round-of-32 games and attach both teams' names.
result = all_rounds.predict(games_32)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# Name the L_ team; FIRSTD1SEASON / LASTD1SEASON are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Materialise before joining m_teams a second time (presumably to avoid
# Snowpark self-join column ambiguity).
res_teamsl = res_teamsl.cache_result()

# Name the W_ team. The join key comes from res_teamsl, the frame actually
# being joined, rather than from `result`.
round_2_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [95]:
# Resolve each round-of-32 game to the team the model picked: the W_ team when
# PRED_WIN_INDICATOR == 1, otherwise the L_ team.
_picked_w = round_2_results.PRED_WIN_INDICATOR == 1
round_2_winners = round_2_results.select(
    when(_picked_w, round_2_results.w_teamid)
    .otherwise(round_2_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(_picked_w, round_2_results.w_team_name)
    .otherwise(round_2_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(_picked_w, round_2_results.w_seed)
    .otherwise(round_2_results.l_seed)
    .alias("W_SEED"),
    when(_picked_w, round_2_results.w_region)
    .otherwise(round_2_results.l_region)
    .alias("W_REGION"),
)

# Round 2 winners

In [96]:
# Show the predicted round-of-32 winners grouped by region.
round_2_winners.sort('W_REGION').show(20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1397         |Tennessee      |4         |W           |
|1266         |Marquette      |2         |W           |
|1246         |Kentucky       |6         |W           |
|1345         |Purdue         |1         |W           |
|1438         |Virginia       |4         |X           |
|1112         |Arizona        |2         |X           |
|1104         |Alabama        |1         |X           |
|1124         |Baylor         |3         |X           |
|1245         |Kent           |13        |Y           |
|1400         |Texas          |2         |Y           |
|1222         |Houston        |1         |Y           |
|1235         |Iowa St        |6         |Y           |
|1388         |St Mary's CA   |5         |Z           |
|1242         |Kansas         |1         |Z           |
|1417         |UCLA           |2         |Z     

In [97]:
# Sweet 16 games: within each region, the 1/16/8/9 quarter meets the
# 5/12/4/13 quarter, and the 6/11/3/14 quarter meets the 7/10/2/15 quarter.
# df1/df2 are two cached copies, presumably so the self-join has two distinct
# operands.
df1 = round_2_winners.cache_result()
df2 = round_2_winners.cache_result()

_s16_seed_match = (
    df1.w_seed.isin([1, 16, 8, 9]) & df2.w_seed.isin([5, 12, 4, 13])
) | (
    df1.w_seed.isin([6, 11, 3, 14]) & df2.w_seed.isin([7, 10, 2, 15])
)

third_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region) & _s16_seed_match,
    "inner",
).select(
    df1.W_TEAM_NAME.alias("W_team_name"),
    df1.W_REGION.alias("w_region"),
    df1.W_TEAM_ID.alias("WTeamID2"),
    df1.W_SEED.alias("w_seed"),
    df2.W_TEAM_NAME.alias("l_Team_name"),
    df2.W_REGION.alias("l_region"),
    df2.W_TEAM_ID.alias("lteamID2"),
    df2.W_SEED.alias("l_seed"),
)

third_round_matchups.sort("W_REGION").show(20)

-----------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
-----------------------------------------------------------------------------------------------------------
|Kentucky       |W           |1246        |6         |Marquette      |W           |1266        |2         |
|Purdue         |W           |1345        |1         |Tennessee      |W           |1397        |4         |
|Alabama        |X           |1104        |1         |Virginia       |X           |1438        |4         |
|Baylor         |X           |1124        |3         |Arizona        |X           |1112        |2         |
|Houston        |Y           |1222        |1         |Kent           |Y           |1245        |13        |
|Iowa St        |Y           |1235        |6         |Texas          |Y           |1400        |2         |
|Kansas         |Z          

# SWEET SIXTEEN!!!

In [98]:
# Display the Sweet 16 bracket. TEAM_1 / TEAM_2 are bracket positions, not
# winner / loser.
third_round_matchups.select(
    *[
        F.col(source).alias(label)
        for source, label in (
            ("W_TEAM_NAME", "TEAM_1"),
            ("L_TEAM_NAME", "TEAM_2"),
            ("W_SEED", "SEED_1"),
            ("L_SEED", "SEED_2"),
            ("W_REGION", "REGION"),
        )
    ]
).sort("REGION").show(20)

------------------------------------------------------------
|"TEAM_1"  |"TEAM_2"      |"SEED_1"  |"SEED_2"  |"REGION"  |
------------------------------------------------------------
|Kentucky  |Marquette     |6         |2         |W         |
|Purdue    |Tennessee     |1         |4         |W         |
|Alabama   |Virginia      |1         |4         |X         |
|Baylor    |Arizona       |3         |2         |X         |
|Houston   |Kent          |1         |13        |Y         |
|Iowa St   |Texas         |6         |2         |Y         |
|Kansas    |St Mary's CA  |1         |5         |Z         |
|Gonzaga   |UCLA          |3         |2         |Z         |
------------------------------------------------------------



In [99]:
# Sweet 16 feature rows: the two team ids and seeds of each matchup, joined to
# both teams' 2023 season statistics under W_ / L_ prefixes.
third_round = third_round_matchups.select('WTEAMID2', 'LTEAMID2', 'W_SEED', 'L_SEED')

# NOTE(review): FINAL_SEASON_STATS exposes SEASON (not W_SEASON) before the
# prefixes below are applied, so this drop looks like a no-op — confirm intent.
season = (
    session.table("FINAL_SEASON_STATS")
    .filter(F.col("season") == 2023)
    .drop("W_SEASON")
)

season_w, season_l = (
    season.select(*[F.col(name).alias(f"{side}_{name}") for name in season.columns])
    for side in ("W", "L")
)

df1 = third_round.join(season_w, season_w.w_teamid == third_round.wTEAMID2)
df2 = df1.join(season_l, season_l.l_teamid == third_round.lTEAMID2)

# Drop the matchup join keys; one row per game remains.
games_16 = df2.drop('WTEAMID2', 'lTEAMID2')

games_16.count()

8

In [100]:
# Score the Sweet 16 games and attach both teams' names.
result = all_rounds.predict(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# Name the L_ team; FIRSTD1SEASON / LASTD1SEASON are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Materialise before joining m_teams a second time (presumably to avoid
# Snowpark self-join column ambiguity).
res_teamsl = res_teamsl.cache_result()

# Name the W_ team. The join key comes from res_teamsl, the frame actually
# being joined, rather than from `result`.
sweet_16_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [101]:
# Resolve each Sweet 16 game to the team the model picked: the W_ team when
# PRED_WIN_INDICATOR == 1, otherwise the L_ team.
_picked_w = sweet_16_results.PRED_WIN_INDICATOR == 1
sweet_16_winners = sweet_16_results.select(
    when(_picked_w, sweet_16_results.w_teamid)
    .otherwise(sweet_16_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(_picked_w, sweet_16_results.w_team_name)
    .otherwise(sweet_16_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(_picked_w, sweet_16_results.w_seed)
    .otherwise(sweet_16_results.l_seed)
    .alias("W_SEED"),
    when(_picked_w, sweet_16_results.w_region)
    .otherwise(sweet_16_results.l_region)
    .alias("W_REGION"),
)

# Sweet 16 winners!

In [102]:
# Show the predicted Sweet 16 winners (the Elite 8) grouped by region.
sweet_16_winners.sort('W_REGION').show(20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1266         |Marquette      |2         |W           |
|1345         |Purdue         |1         |W           |
|1124         |Baylor         |3         |X           |
|1104         |Alabama        |1         |X           |
|1222         |Houston        |1         |Y           |
|1400         |Texas          |2         |Y           |
|1211         |Gonzaga        |3         |Z           |
|1388         |St Mary's CA   |5         |Z           |
-------------------------------------------------------



In [103]:
# Elite 8 games: within each region, the winner from the top half of the
# region (seeds 1/16/8/9/5/12/4/13) meets the winner from the bottom half
# (seeds 6/11/3/14/7/10/2/15). df1/df2 are two cached copies, presumably so
# the self-join has two distinct operands.
df1 = sweet_16_winners.cache_result()
df2 = sweet_16_winners.cache_result()

_top_half_seeds = [1, 16, 8, 9, 5, 12, 4, 13]
_bottom_half_seeds = [6, 11, 3, 14, 7, 10, 2, 15]

elite_8_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & df1.w_seed.isin(_top_half_seeds)
    & df2.w_seed.isin(_bottom_half_seeds),
    "inner",
).select(
    df1.W_TEAM_NAME.alias("W_team_name"),
    df1.W_REGION.alias("w_region"),
    df1.W_TEAM_ID.alias("WTeamID2"),
    df1.W_SEED.alias("w_seed"),
    df2.W_TEAM_NAME.alias("l_Team_name"),
    df2.W_REGION.alias("l_region"),
    df2.W_TEAM_ID.alias("lteamID2"),
    df2.W_SEED.alias("l_seed"),
)

elite_8_matchups.sort("W_REGION").show(20)

-----------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
-----------------------------------------------------------------------------------------------------------
|Purdue         |W           |1345        |1         |Marquette      |W           |1266        |2         |
|Alabama        |X           |1104        |1         |Baylor         |X           |1124        |3         |
|Houston        |Y           |1222        |1         |Texas          |Y           |1400        |2         |
|St Mary's CA   |Z           |1388        |5         |Gonzaga        |Z           |1211        |3         |
-----------------------------------------------------------------------------------------------------------



# Elite 8

In [104]:
# Display the Elite 8 bracket. TEAM_1 / TEAM_2 are bracket positions, not
# winner / loser.
elite_8_matchups.select(
    *[
        F.col(source).alias(label)
        for source, label in (
            ("W_TEAM_NAME", "TEAM_1"),
            ("L_TEAM_NAME", "TEAM_2"),
            ("W_SEED", "SEED_1"),
            ("L_SEED", "SEED_2"),
            ("W_REGION", "REGION"),
        )
    ]
).sort("REGION").show(20)

-------------------------------------------------------------
|"TEAM_1"      |"TEAM_2"   |"SEED_1"  |"SEED_2"  |"REGION"  |
-------------------------------------------------------------
|Purdue        |Marquette  |1         |2         |W         |
|Alabama       |Baylor     |1         |3         |X         |
|Houston       |Texas      |1         |2         |Y         |
|St Mary's CA  |Gonzaga    |5         |3         |Z         |
-------------------------------------------------------------



In [105]:
elite_8 = elite_8_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = (
    session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023).drop("W_SEASON")
)

# Prefix every stat column so both teams' stats can sit side by side in one row.
season_w = season.select(*[F.col(c).alias(f"W_{c}") for c in season.columns])

season_l = season.select(*[F.col(c).alias(f"L_{c}") for c in season.columns])

df1 = elite_8.join(season_w, season_w.w_teamid == elite_8.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == elite_8.lTEAMID2)

elite_8 = df2.drop('WTEAMID2','lTEAMID2')

# The inner joins above silently drop a game when a team has no 2023 stats row
# (or duplicate it when a team has several); fail loudly rather than simulate a
# bracket with missing or repeated games.
n_elite_8_games = elite_8.count()
assert n_elite_8_games == elite_8_matchups.count(), (
    "Elite 8 row count changed after joining 2023 stats (missing or duplicate team rows)"
)
n_elite_8_games

4

In [106]:
# Score every Elite 8 game with the trained model, then attach team names for
# both sides of each game.
result = all_rounds.predict(elite_8)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# L_* side names; materialized so the second join reads a fixed table.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# W_* side names. The join key is taken from res_teamsl -- the cached table that
# is actually being joined -- not from `result`, which is not part of this join.
elite_8_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [107]:
# Collapse each Elite 8 game to its predicted winner: take the W_* side when the
# model predicts 1, otherwise the L_* side.
elite_8_winners = elite_8_results.select(
    *[
        when(elite_8_results.PRED_WIN_INDICATOR == 1, elite_8_results.col(w_side))
        .otherwise(elite_8_results.col(l_side))
        .alias(output_name)
        for output_name, w_side, l_side in (
            ("W_TEAM_ID", "W_TEAMID", "L_TEAMID"),
            ("W_TEAM_NAME", "W_TEAM_NAME", "L_TEAM_NAME"),
            ("W_SEED", "W_SEED", "L_SEED"),
            ("W_REGION", "W_REGION", "L_REGION"),
        )
    ]
)

# Elite 8 winners

In [108]:
# Predicted Elite 8 winners, one per region.
elite_8_winners.order_by("W_REGION").show(n=20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1266         |Marquette      |2         |W           |
|1104         |Alabama        |1         |X           |
|1222         |Houston        |1         |Y           |
|1211         |Gonzaga        |3         |Z           |
-------------------------------------------------------



In [109]:
def _semifinal_pairing(region_a, region_b):
    """Cross-join the Elite 8 winner of `region_a` with the winner of `region_b`.

    The `region_a` team keeps the W_* column names; the `region_b` team's
    columns are renamed to TEAM_2 / W_TEAM_NAME_2 / SEED_2 / REGION_2.
    """
    side_a = elite_8_winners.filter(F.col("W_REGION") == region_a)
    side_b = elite_8_winners.filter(F.col("W_REGION") == region_b).select(
        F.col("W_TEAM_ID").alias("TEAM_2"),
        F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
        F.col("W_SEED").alias("SEED_2"),
        F.col("W_REGION").alias("REGION_2"),
    )
    return side_a.join(side_b)


# Semifinals pair region W with X and region Y with Z.
final_four_matchups = (
    _semifinal_pairing("W", "X")
    .union(_semifinal_pairing("Y", "Z"))
    .select(
        F.col('W_TEAM_NAME'),
        F.col('W_REGION'),
        F.col("W_TEAM_ID").alias("WTEAMID2"),
        F.col('W_SEED'),
        F.col("W_TEAM_NAME_2").alias("L_TEAM_NAME"),
        F.col("REGION_2").alias("L_REGION"),
        F.col("TEAM_2").alias("LTEAMID2"),
        F.col("SEED_2").alias("L_SEED"),
    )
)

# FINAL FOUR!

In [110]:
# Final Four matchups, relabelled for display and ordered by region.
(
    final_four_matchups
    .select(
        *[
            F.col(source).alias(label)
            for source, label in (
                ("W_TEAM_NAME", "TEAM_1"),
                ("L_TEAM_NAME", "TEAM_2"),
                ("W_SEED", "SEED_1"),
                ("L_SEED", "SEED_2"),
                ("W_REGION", "REGION"),
            )
        ]
    )
    .order_by("REGION")
    .show(n=20)
)

---------------------------------------------------------
|"TEAM_1"   |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
---------------------------------------------------------
|Marquette  |Alabama   |2         |1         |W         |
|Houston    |Gonzaga   |1         |3         |Y         |
---------------------------------------------------------



In [111]:
final_four = final_four_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = (
    session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023).drop("W_SEASON")
)

# Prefix every stat column so both teams' stats can sit side by side in one row.
season_w = season.select(*[F.col(c).alias(f"W_{c}") for c in season.columns])

season_l = season.select(*[F.col(c).alias(f"L_{c}") for c in season.columns])

df1 = final_four.join(season_w, season_w.w_teamid == final_four.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == final_four.lTEAMID2)

final_four = df2.drop('WTEAMID2','lTEAMID2')

# The inner joins above silently drop a game when a team has no 2023 stats row
# (or duplicate it when a team has several); fail loudly instead.
n_final_four_games = final_four.count()
assert n_final_four_games == final_four_matchups.count(), (
    "Final Four row count changed after joining 2023 stats (missing or duplicate team rows)"
)
n_final_four_games

2

In [112]:
# Score both semifinals with the trained model, then attach team names for
# both sides of each game.
result = all_rounds.predict(final_four)

teams = session.table("m_teams")


result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# L_* side names; materialized so the second join reads a fixed table.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# W_* side names. The join key is taken from res_teamsl -- the cached table that
# is actually being joined -- not from `result`, which is not part of this join.
final_four_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [113]:
# Collapse each semifinal to its predicted winner: take the W_* side when the
# model predicts 1, otherwise the L_* side.
final_four_winners = final_four_results.select(
    *[
        when(final_four_results.PRED_WIN_INDICATOR == 1, final_four_results.col(w_side))
        .otherwise(final_four_results.col(l_side))
        .alias(output_name)
        for output_name, w_side, l_side in (
            ("W_TEAM_ID", "W_TEAMID", "L_TEAMID"),
            ("W_TEAM_NAME", "W_TEAM_NAME", "L_TEAM_NAME"),
            ("W_SEED", "W_SEED", "L_SEED"),
            ("W_REGION", "W_REGION", "L_REGION"),
        )
    ]
)

In [114]:
# Predicted semifinal winners, i.e. the two finalists.
final_four_winners.order_by("W_REGION").show(n=20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1266         |Marquette      |2         |W           |
|1222         |Houston        |1         |Y           |
-------------------------------------------------------



In [115]:
# Two separately cached copies of the finalists, presumably so the self-join
# below does not hit ambiguous column references on a single DataFrame.
df1 = final_four_winners.cache_result()
df2 = final_four_winners.cache_result()

# The cross join yields the title game in both orientations (A vs B and B vs A)
# plus the two self-pairings. Keep exactly one row with a deterministic rule:
# the finalist from the later region letter takes the W_* side. An unordered
# `.limit(1)` could pick either orientation from run to run, which changes
# which team the model scores on the W_* side. The two finalists come from
# different semifinals (W/X vs Y/Z), so their regions always differ.
championship_matchup = df1.join(df2).select(
    df1.w_team_name.alias('W_TEAM_NAME'),
    df1.w_region.alias('W_REGION'),
    df1.w_team_id.alias('WTEAMID2'),
    df1.w_seed.alias('W_SEED'),
    df2.w_team_name.alias('L_TEAM_NAME'),
    df2.w_region.alias('L_REGION'),
    df2.w_team_id.alias('LTEAMID2'),
    df2.w_seed.alias('L_SEED'),
).filter(F.col("W_REGION") > F.col("L_REGION"))

# CHAMPIONSHIP!

In [116]:
# The title game, relabelled for display.
(
    championship_matchup
    .select(
        *[
            F.col(source).alias(label)
            for source, label in (
                ("W_TEAM_NAME", "TEAM_1"),
                ("L_TEAM_NAME", "TEAM_2"),
                ("W_SEED", "SEED_1"),
                ("L_SEED", "SEED_2"),
                ("W_REGION", "REGION"),
            )
        ]
    )
    .order_by("REGION")
    .show(n=20)
)

---------------------------------------------------------
|"TEAM_1"  |"TEAM_2"   |"SEED_1"  |"SEED_2"  |"REGION"  |
---------------------------------------------------------
|Houston   |Marquette  |1         |2         |Y         |
---------------------------------------------------------



In [117]:
championship = championship_matchup.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = (
    session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023).drop("W_SEASON")
)

# Prefix every stat column so both teams' stats can sit side by side in one row.
season_w = season.select(*[F.col(c).alias(f"W_{c}") for c in season.columns])

season_l = season.select(*[F.col(c).alias(f"L_{c}") for c in season.columns])

df1 = championship.join(season_w, season_w.w_teamid == championship.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == championship.lTEAMID2)

championship = df2.drop('WTEAMID2','lTEAMID2')

# The inner joins above silently drop the game when a finalist has no 2023
# stats row (or duplicate it when it has several); fail loudly instead.
n_championship_games = championship.count()
assert n_championship_games == championship_matchup.count(), (
    "Championship row count changed after joining 2023 stats (missing or duplicate team rows)"
)
n_championship_games

1

In [118]:
# Score the title game with the trained model, then attach team names for both
# sides.
result = all_rounds.predict(championship)

teams = session.table("m_teams")
# Upper-case every column name. The loop variable is deliberately not called
# `col`: a module-level loop would rebind the imported
# snowflake.snowpark.functions.col to a string for the rest of the notebook.
for teams_column in teams.columns:
    teams = teams.withColumnRenamed(teams_column, teams_column.upper())

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# L_* side names; materialized so the second join reads a fixed table.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

# W_* side names. The join key is taken from res_teamsl -- the cached table that
# is actually being joined -- not from `result`, which is not part of this join.
championship_results = (
    res_teamsl.join(teams, res_teamsl.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [119]:
# Raw model output for the title game, with team names attached.
championship_results.show(n=10)

----------------------------------------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"  |"W_REGION"  |"PRED_WIN_INDICATOR"  |
----------------------------------------------------------------------------------------------------------------------------------
|1266        |1222        |Marquette      |2         |1         |W           |Houston        |Y           |1.0                   |
----------------------------------------------------------------------------------------------------------------------------------



In [120]:
# Collapse the title game to its predicted winner: take the W_* side when the
# model predicts 1, otherwise the L_* side.
champion = championship_results.select(
    *[
        when(championship_results.PRED_WIN_INDICATOR == 1, championship_results.col(w_side))
        .otherwise(championship_results.col(l_side))
        .alias(output_name)
        for output_name, w_side, l_side in (
            ("W_TEAM_ID", "W_TEAMID", "L_TEAMID"),
            ("W_TEAM_NAME", "W_TEAM_NAME", "L_TEAM_NAME"),
            ("W_SEED", "W_SEED", "L_SEED"),
            ("W_REGION", "W_REGION", "L_REGION"),
        )
    ]
)

# CHAMPION!

In [121]:
# The predicted national champion.
champion.show(n=10)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1222         |Houston        |1         |Y           |
-------------------------------------------------------



# Predict today's play-in games

In [123]:
# Preview the 2024 season stats that feed the play-in predictions.
season_2024_filter = F.col("season") == 2024
season = session.table("FINAL_SEASON_STATS").filter(season_2024_filter)
season = season.drop("W_SEASON")
season.show(n=10)

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [124]:
def season_2024_team_stats(team_id):
    """Return a single 2024 stats row for `team_id` from MEN.FINAL_SEASON_STATS.

    The N/H/A and SEASON columns are dropped; `limit(1)` keeps only one row even
    if the table holds several for the team.
    """
    return (
        session.table("MEN.FINAL_SEASON_STATS").drop('N','H','A')
        .filter(F.col('SEASON') == 2024)
        .filter(F.col('TEAMID') == team_id)
        .drop('SEASON')
        .limit(1)
    )


grambling = season_2024_team_stats(1212)
montana_st = season_2024_team_stats(1286)
boise_st = season_2024_team_stats(1129)
colorado = season_2024_team_stats(1160)

# Grambling fills the W_* feature slot and Montana St the L_* slot.
for column in grambling.columns:
    grambling = grambling.withColumnRenamed(column, f"W_{column}")

for column in montana_st.columns:
    montana_st = montana_st.withColumnRenamed(column, f"L_{column}")

play_in = grambling.join(montana_st)

In [125]:
# Columns the test set has that the play-in row lacks.
play_in_existing_columns = set(play_in.columns)
columns_to_add = [name for name in test.columns if name not in play_in_existing_columns]
columns_to_add

['SEASON', 'W_SCORE', 'L_SCORE', 'ROUND', 'W_SEED', 'L_SEED', 'WIN_INDICATOR']

In [None]:
columns_to_add = [name for name in test.columns if name not in play_in.columns]

# Zero-fill every column of `test` that play_in lacks, so each test column also
# exists on play_in. The loop variable is not called `col`: a module-level loop
# would rebind the imported snowflake.snowpark.functions.col to a string.
for missing_column in columns_to_add:
    play_in = play_in.withColumn(missing_column, lit(0))

In [126]:
# Give both sides of this play-in game seed 16.
for seed_column in ('L_SEED', 'W_SEED'):
    play_in = play_in.withColumn(seed_column, lit(16))

In [129]:
# Local model's class-1 probability for the Grambling (W_*) side.
# Assumes class 1 means the W_* side wins -- TODO confirm against the training label.
play_in_results = all_rounds.predict_proba(play_in)
grambling_prob_col = F.col('"predict_proba_1"').alias("grambling_win_prob")
play_in_results.select(grambling_prob_col).show()

------------------------
|"GRAMBLING_WIN_PROB"  |
------------------------
|0.027495671063661575  |
------------------------



In [135]:
# Fresh, unprefixed 2024 stat rows for Boise St (1129) and Colorado (1160).
boise_st, colorado = (
    session.table("MEN.FINAL_SEASON_STATS")
    .drop('N','H','A')
    .filter(F.col('SEASON') == 2024)
    .filter(F.col('TEAMID') == team_id)
    .drop('SEASON')
    .limit(1)
    for team_id in (1129, 1160)
)

In [136]:
# Build the Boise St (W_*) vs Colorado (L_*) feature row. The prefixed copies
# are new frames; boise_st / colorado themselves are left untouched, so
# re-running this cell does not stack prefixes (W_W_..., L_L_...).
boise_st_w = boise_st.select(*[F.col(c).alias(f"W_{c}") for c in boise_st.columns])
colorado_l = colorado.select(*[F.col(c).alias(f"L_{c}") for c in colorado.columns])

play_in = boise_st_w.join(colorado_l)

# Informational only: these test-set columns are not zero-filled in this cell.
columns_to_add = [name for name in test.columns if name not in play_in.columns]


# Both sides of this play-in game get seed 10.
play_in = play_in.withColumn('L_SEED', lit(10))
play_in = play_in.withColumn('W_SEED', lit(10))

In [137]:
# Local model's class-1 probability for the Boise St (W_*) side.
# Assumes class 1 means the W_* side wins -- TODO confirm against the training label.
play_in_results = all_rounds.predict_proba(play_in)
boise_st_prob_col = F.col('"predict_proba_1"').alias("boise_st_win_prob")
play_in_results.select(boise_st_prob_col).show()

-----------------------
|"BOISE_ST_WIN_PROB"  |
-----------------------
|0.39441338181495667  |
-----------------------



In [127]:
# Disabled: remote inference through the Snowflake Model Registry (version V_1).
# As this cell's recorded output shows, running it raised a data-validation
# error: the W_/L_ WLOCN/WLOCH/WLOCA feature columns are VariantType here, while
# the registered model signature expects INT64. A later cell retries against
# version V_5 after dropping those columns.

# # Point to the registry

# reg = Registry(session=session)
# m = reg.get_model("MARCHMADNESS")
# m.default = 'V_1'

# # Get the default version of your model (Model with best accuracy in our case)

# mv = reg.get_model("MARCHMADNESS").default

# remote_prediction = mv.run(play_in, function_name="predict_proba")
# remote_prediction.drop('"output_feature_0"').with_column_renamed(
#     '"output_feature_1"', "grambling_win_prob"
# ).select('grambling_win_prob').show()

ValueError: (2112) 
Data Validation Error when validating your Snowpark DataFrame.
If using the normalized names from model signatures, there are the following errors:
[ValueError('(2112) Data Validation Error in feature W_WLOCN: Feature type DataType.INT64 is not met by column W_WLOCN because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature W_WLOCH: Feature type DataType.INT64 is not met by column W_WLOCH because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature W_WLOCA: Feature type DataType.INT64 is not met by column W_WLOCA because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCN: Feature type DataType.INT64 is not met by column L_WLOCN because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCH: Feature type DataType.INT64 is not met by column L_WLOCH because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCA: Feature type DataType.INT64 is not met by column L_WLOCA because of its original type VariantType()')]

If using the inferred names from model signatures, there are the following errors:
[ValueError('(2112) Data Validation Error in feature W_WLOCN: Feature type DataType.INT64 is not met by column W_WLOCN because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature W_WLOCH: Feature type DataType.INT64 is not met by column W_WLOCH because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature W_WLOCA: Feature type DataType.INT64 is not met by column W_WLOCA because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCN: Feature type DataType.INT64 is not met by column L_WLOCN because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCH: Feature type DataType.INT64 is not met by column L_WLOCH because of its original type VariantType()'), ValueError('(2112) Data Validation Error in feature L_WLOCA: Feature type DataType.INT64 is not met by column L_WLOCA because of its original type VariantType()')]


In [None]:
# Boise St vs Colorado scored by the registered model (remote inference).
# Rebuild both team rows from the table first: an earlier cell may already have
# renamed boise_st / colorado to W_* / L_*, and renaming them again would yield
# W_W_* / L_L_* columns that match no model feature.
boise_st = (
    session.table("MEN.FINAL_SEASON_STATS").drop('N','H','A')
).filter(F.col('SEASON') == 2024).filter(F.col('TEAMID') == 1129).drop('SEASON').limit(1)

colorado = (
    session.table("MEN.FINAL_SEASON_STATS").drop('N','H','A')
).filter(F.col('SEASON') == 2024).filter(F.col('TEAMID') == 1160).drop('SEASON').limit(1)

for column in boise_st.columns:
    boise_st = boise_st.withColumnRenamed(column, f"W_{column}")

for column in colorado.columns:
    colorado = colorado.withColumnRenamed(column, f"L_{column}")

play_in = boise_st.join(colorado)

columns_to_add = [name for name in test.columns if name not in play_in.columns]

# Zero-fill every column of `test` that play_in lacks. The loop variable is not
# called `col`, which would shadow the imported snowflake.snowpark.functions.col.
for missing_column in columns_to_add:
    play_in = play_in.withColumn(missing_column, lit(0))

# Drop columns this model version does not accept. The WLOC* columns are
# VariantType and failed INT64 signature validation in an earlier attempt; the
# CT* columns are presumably absent from the V_5 signature -- TODO confirm.
play_in = play_in.drop(['W_CTWINS','W_AVERAGECTSCORE','L_CTWINS','L_AVERAGECTSCORE'])
play_in = play_in.drop(['W_WLOCN','W_WLOCH','W_WLOCA','L_WLOCN','L_WLOCH','L_WLOCA']) #variants
# Both sides of this play-in game get seed 10.
play_in = play_in.withColumn('L_SEED', lit(10))
play_in = play_in.withColumn('W_SEED', lit(10))

# Point to the registry

reg = Registry(session=session)
m = reg.get_model("MARCHMADNESS")
# NOTE(review): this appears to change the registry's default version, which is
# shared state -- not a choice local to this notebook.
m.default = 'V_5'

# Get the default version of your model (Model with best accuracy in our case)

mv = reg.get_model("MARCHMADNESS").default

remote_prediction = mv.run(play_in, function_name="predict_proba")
remote_prediction.drop('"output_feature_0"').with_column_renamed(
    '"output_feature_1"', "boise_st_win_prob"
).select('boise_st_win_prob').show()