# START TRAINING

In [98]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import when, lit

import numpy as np

# Suppress every UserWarning-category warning for the rest of the kernel session.
warnings.simplefilter(action="ignore", category=UserWarning)

In [99]:
# Build (or reuse) the Snowpark session from the configured Snowflake login options.
connection_options = SnowflakeLoginOptions()
session = Session.builder.configs(connection_options).getOrCreate()

In [100]:
# Load the training table, insert an underscore after the W/L prefix of the team-id
# and score columns, and discard ROUND_NUMBER.
prefixed_names = {
    "WTEAMID": "W_TEAMID",
    "LTEAMID": "L_TEAMID",
    "WSCORE": "W_SCORE",
    "LSCORE": "L_SCORE",
}

final = session.table("MEN.FINAL_TRAIN")
for old_name, new_name in prefixed_names.items():
    final = final.with_column_renamed(old_name, new_name)
final = final.drop('ROUND_NUMBER')

In [101]:
# Columns with null values and their respective counts
# Columns with null values and their respective counts.
# All per-column null counts are computed in ONE aggregate query instead of issuing a
# separate count() query for every column.
null_count_row = final.select(
    [
        F.sum(F.when(F.col(col_name).is_null(), 1).otherwise(0)).alias(col_name)
        for col_name in final.columns
    ]
).collect()[0]

# SUM over zero rows yields NULL (None), so treat a missing count as "no nulls".
{
    col_name: null_count
    for col_name, null_count in zip(final.columns, null_count_row)
    if null_count is not None and null_count > 0
}

{'W_CTWINS': 1302,
 'W_AVERAGECTSCORE': 1302,
 'L_CTWINS': 1302,
 'L_AVERAGECTSCORE': 1302}

In [102]:
# Remove the four columns reported as containing nulls, plus the W/L location columns.
null_columns = ['W_CTWINS','W_AVERAGECTSCORE','L_CTWINS','L_AVERAGECTSCORE']
variant_columns = ['W_WLOCN','W_WLOCH','W_WLOCA','L_WLOCN','L_WLOCH','L_WLOCA']  # variants
final = final.drop(null_columns + variant_columns)

In [103]:
# NOTE(review): this repeats the location-column drop already performed in the
# previous cell; it looks redundant — confirm and consider deleting this cell.
final = final.drop(['W_WLOCN','W_WLOCH','W_WLOCA','L_WLOCN','L_WLOCH','L_WLOCA']) #variants

In [104]:
# Hyper-parameter grid explored by the grid search (5 * 5 * 3 * 5 = 375 combinations).
parameters = dict(
    n_estimators=list(range(100, 501, 100)),
    learning_rate=[0.1, 0.2, 0.3, 0.4, 0.5],
    max_depth=[3, 4, 5],
    min_child_weight=[1, 2, 3, 4, 5],
)

In [105]:
# Seasons 2010-2021 (inclusive) are used for training; later seasons are held out.
# WIN_INDICATOR is removed from the training rows only.
in_training_window = (F.col('SEASON') <= 2021) & (F.col('SEASON') >= 2010)
train = final.filter(in_training_window).drop('WIN_INDICATOR').cache_result()

held_out_seasons = F.col('SEASON') > 2021
test = final.filter(held_out_seasons).cache_result()

### Swap the W and L teams

In [106]:
# Swap the W_* and L_* values on a random ~half of the training rows, then derive the
# label from the (possibly swapped) scores.
# presumably the W_ side always holds the winner in the source table, so without the
# swap the label would be constant — TODO confirm against MEN.FINAL_TRAIN.

# A fixed seed makes the shuffled training set reproducible across kernel restarts.
TRAIN_SWAP_SEED = 42
train_swap_rng = np.random.default_rng(TRAIN_SWAP_SEED)

train_pd = train.to_pandas()

# True for the rows whose W_/L_ values are exchanged.
train_swap_mask = train_swap_rng.random(len(train_pd)) > 0.5

# Column-wise swap: every W_<name> column is paired with its L_<name> counterpart.
# A missing counterpart raises KeyError, as the row-by-row version did.
for w_name in [name for name in train_pd.columns if name.startswith("W_")]:
    l_name = "L_" + w_name[2:]
    w_values = train_pd[w_name].copy()
    l_values = train_pd[l_name].copy()
    # Series.where keeps the value where the condition is True, so ~mask keeps the
    # untouched rows and takes the other side's value on the swapped rows.
    train_pd[w_name] = w_values.where(~train_swap_mask, l_values)
    train_pd[l_name] = l_values.where(~train_swap_mask, w_values)

df = session.create_dataframe(train_pd)
train = df.withColumn("WIN_INDICATOR", F.when(df["W_SCORE"] > df["L_SCORE"], 1).otherwise(0))


In [107]:
# Switch the session to the MM_L warehouse for the statements that follow.
session.use_warehouse('MM_L')

In [108]:
# Everything in `train` except these identifier / outcome columns is a model input.
non_feature_cols = ['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']
feature_cols = train.drop(non_feature_cols).columns

# Exhaustive search over `parameters`, scored by accuracy, using all available workers.
all_rounds = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    scoring="accuracy",
    n_jobs=-1,
    input_cols=feature_cols,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
all_rounds.fit(train)



<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fea698018d0>

In [109]:
#session.use_warehouse('wh_xs')

In [110]:
# Score the whole hold-out set, then keep only the games with ROUND == 1.
test_predictions = all_rounds.predict(test)
result = test_predictions.filter(F.col("ROUND") == 1)
result.count()

64

In [111]:
# Accuracy of the ROUND == 1 predictions: per hold-out season, then overall.
label_and_prediction = {
    "y_true_col_names": "WIN_INDICATOR",
    "y_pred_col_names": "PRED_WIN_INDICATOR",
}

result_2022 = result.filter(result.season == 2022)
accuracy_2022 = accuracy_score(df=result_2022, **label_and_prediction)
print(f"Accuracy 2022: {accuracy_2022}")

result_2023 = result.filter(result.season == 2023)
accuracy_2023 = accuracy_score(df=result_2023, **label_and_prediction)
print(f"Accuracy 2023: {accuracy_2023}")

accuracy = accuracy_score(df=result, **label_and_prediction)
print(f"Accuracy total: {accuracy}")

Accuracy 2022: 0.625
Accuracy 2023: 0.6875
Accuracy total: 0.65625


In [112]:
# Swap the W_* and L_* values on a random ~half of the hold-out rows and recompute
# WIN_INDICATOR from the (possibly swapped) scores, so the evaluation set contains
# both label values.

# A fixed seed makes the swapped evaluation set (and the accuracies computed from it)
# reproducible across kernel restarts.
TEST_SWAP_SEED = 2023
test_swap_rng = np.random.default_rng(TEST_SWAP_SEED)

test_pd = test.to_pandas()

# True for the rows whose W_/L_ values are exchanged.
test_swap_mask = test_swap_rng.random(len(test_pd)) > 0.5

# Column-wise swap: every W_<name> column is paired with its L_<name> counterpart.
# A missing counterpart raises KeyError, as the row-by-row version did.
for w_name in [name for name in test_pd.columns if name.startswith("W_")]:
    l_name = "L_" + w_name[2:]
    w_values = test_pd[w_name].copy()
    l_values = test_pd[l_name].copy()
    # Series.where keeps the value where the condition is True, so ~mask keeps the
    # untouched rows and takes the other side's value on the swapped rows.
    test_pd[w_name] = w_values.where(~test_swap_mask, l_values)
    test_pd[l_name] = l_values.where(~test_swap_mask, w_values)

df = session.create_dataframe(test_pd)
test2 = df.withColumn("WIN_INDICATOR", F.when(df["W_SCORE"] > df["L_SCORE"], 1).otherwise(0))
# Score the swapped hold-out set, keep the ROUND == 1 games, and report accuracy per
# season and overall.
swapped_predictions = all_rounds.predict(test2)
result = swapped_predictions.filter(F.col("ROUND") == 1)

score_columns = {
    "y_true_col_names": "WIN_INDICATOR",
    "y_pred_col_names": "PRED_WIN_INDICATOR",
}

season_2022_rows = result.filter(result.season == 2022)
accuracy_2022 = accuracy_score(df=season_2022_rows, **score_columns)
print(f"Accuracy 2022: {accuracy_2022}")

season_2023_rows = result.filter(result.season == 2023)
accuracy_2023 = accuracy_score(df=season_2023_rows, **score_columns)
print(f"Accuracy 2023: {accuracy_2023}")

accuracy = accuracy_score(df=result, **score_columns)
print(f"Accuracy total: {accuracy}")


Accuracy 2022: 0.59375
Accuracy 2023: 0.6875
Accuracy total: 0.640625


#### Register Models that are good

In [15]:
# Convert the fitted grid search to its scikit-learn object and keep its best estimator.
optimal_model = all_rounds.to_sklearn().best_estimator_

def check_and_update(df, model_name):
    """Return the version name to use for the next logged version of ``model_name``.

    Args:
        df: pandas DataFrame listing registered models, with a ``name`` column and a
            ``versions`` column. A ``versions`` value may be a string holding a list
            literal (e.g. '["V_1", "V_2"]') or an already-parsed sequence of names.
        model_name: Name of the model whose next version name is wanted.

    Returns:
        "V_1" when ``df`` is empty, the model is not listed, or none of its versions
        ends in ``_<number>``; otherwise the version with the highest numeric suffix,
        with that suffix incremented by one (e.g. "V_10" -> "V_11").
    """
    if df.empty:
        return "V_1"

    model_rows = df[df["name"] == model_name]
    if model_rows.empty:
        return "V_1"

    # Read the versions of the requested model, not of whichever model sits in row 0.
    raw_versions = model_rows["versions"].iloc[0]
    if isinstance(raw_versions, str):
        versions = ast.literal_eval(raw_versions)
    else:
        versions = list(raw_versions)

    # Rank by numeric suffix: a plain string sort would place "V_9" after "V_10" and
    # hand back a version name that already exists.
    numbered = []
    for version in versions:
        prefix, _, suffix = str(version).rpartition("_")
        if prefix and suffix.isdigit():
            numbered.append((int(suffix), prefix))
    if not numbered:
        return "V_1"

    highest, prefix = max(numbered)
    return f"{prefix}_{highest + 1}"
# Get sample input data to pass into the registry logging function
excluded_from_sample = ['SEASON','WIN_INDICATOR','l_teamid','w_teamid','w_score','l_score','round','l_region','w_region']
X = train.drop(excluded_from_sample).limit(100)

# Create a registry and log the model
# You can specify a different DB and Schema if you'd like
# otherwise it uses the session context
reg = Registry(session=session)

# Define model name and version (use uppercase for name)
model_name = "MARCHMADNESS"

reg_df = reg.show_models()
model_version = check_and_update(reg_df, model_name)

mm_model = reg.log_model(
    model=optimal_model,
    model_name=model_name,
    version_name=model_version,
    sample_input_data=X,
)

# Add evaluation metrics (most recently computed per-season accuracies)
evaluation_metrics = {
    "accuracy_2023": accuracy_2023,
    "accuracy_2022": accuracy_2022,
}
for metric_name, metric_value in evaluation_metrics.items():
    mm_model.set_metric(metric_name=metric_name, value=metric_value)

reg.get_model(model_name).show_versions()

Unnamed: 0,created_on,name,comment,database_name,schema_name,module_name,is_default_version,functions,metadata,user_data
0,2024-03-19 05:33:27.323000-07:00,V_1,,MARCHMADNESS,MEN,MARCHMADNESS,False,"[""PREDICT_PROBA"",""PREDICT"",""APPLY""]","{""metrics"": {""accuracy_2023"": 0.78125, ""accura...","{""snowpark_ml_data"":{""functions"":[{""name"":""APP..."
1,2024-03-19 07:53:12.886000-07:00,V_2,,MARCHMADNESS,MEN,MARCHMADNESS,True,"[""PREDICT_PROBA"",""PREDICT"",""APPLY""]","{""metrics"": {""accuracy_2023"": 0.8125, ""accurac...","{""snowpark_ml_data"":{""functions"":[{""name"":""APP..."
2,2024-03-19 08:04:18.334000-07:00,V_3,,MARCHMADNESS,MEN,MARCHMADNESS,False,"[""PREDICT_PROBA"",""PREDICT"",""APPLY""]","{""metrics"": {""accuracy_2023"": 0.59375, ""accura...","{""snowpark_ml_data"":{""functions"":[{""name"":""APP..."
3,2024-03-19 13:47:49.766000-07:00,V_4,,MARCHMADNESS,MEN,MARCHMADNESS,False,"[""PREDICT_PROBA"",""PREDICT"",""APPLY""]","{""metrics"": {""accuracy_2023"": 0.78125, ""accura...","{""snowpark_ml_data"":{""functions"":[{""name"":""APP..."


# Predicting the bracket & Final Four (we want as many Final Four teams as possible)

## Results: play-in games

In [16]:
# Predict the hold-out set and keep the 2023 games with ROUND == 0.
holdout_predictions = all_rounds.predict(test)
result = holdout_predictions.filter(F.col('ROUND') == 0).filter(F.col("season") == 2023)

# Team lookup table with every column name upper-cased.
teams = session.table('mteams')
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = ['season','l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','win_indicator','pred_win_indicator']
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the display columns.
display_cols = ['season','l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR']
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
res_teams = w_named.with_column_renamed("teamname", "w_team_name").select(display_cols)

res_teams.sort(F.col("w_team_name")).show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"   |"L_SEED"  |"L_REGION"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1305        |1113        |Nevada          |11        |Z           |Arizona St      |11        |Z           |1                |1.0                   |
|2023      |1411        |1192        |TX Southern     |16        |W           |F Dickinson     |16        |W           |1                |0.0                   |
|2023      |1280        |1338        |Mississippi St  |11        |Y           |Pittsburgh      |11        |Y           |1                |1.0                   |
|2023      |1369        |139

## Results: round 1

In [17]:
# Predict only the ROUND == 1 hold-out games and keep the 2023 season.
round_1_games = test.filter(F.col('ROUND') == 1)
result = all_rounds.predict(round_1_games).filter(F.col("season") == 2023)

# Team lookup table with every column name upper-cased.
teams = session.table('mteams')
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = ['season','l_teamID','l_seed','l_region','w_teamid','w_seed','w_region','win_indicator','pred_win_indicator']
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the display columns.
display_cols = ['season','l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','WIN_INDICATOR','PRED_WIN_INDICATOR']
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
res_teams = w_named.with_column_renamed("teamname", "w_team_name").select(display_cols)

print('The round of 32')
res_teams.count()

The round of 32


32

In [18]:
# Predict the 2023 ROUND == 1 games and decorate each row with both team names.
round_1_games = test.filter(F.col("ROUND") == 1)
result = all_rounds.predict(round_1_games).filter(F.col("season") == 2023)

# Team lookup table with every column name upper-cased.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = [
    "season",
    "l_teamID",
    "l_seed",
    "w_seed",
    "l_region",
    "w_teamid",
    "w_region",
    "win_indicator",
    "pred_win_indicator",
]
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the columns used downstream.
round_1_cols = [
    "season",
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "w_seed",
    "l_region",
    "W_TEAM_NAME",
    "w_region",
    "WIN_INDICATOR",
    "PRED_WIN_INDICATOR",
]
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
round_1_results = w_named.with_column_renamed("teamname", "w_team_name").select(round_1_cols)

In [19]:
# Display up to 32 first-round rows, ordered by W_REGION.
round_1_results.sort('W_REGION').show(32)

------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"     |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"  |"W_REGION"  |"WIN_INDICATOR"  |"PRED_WIN_INDICATOR"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------
|2023      |1272        |1194        |Memphis           |8         |9         |W           |FL Atlantic    |W           |1                |1.0                   |
|2023      |1418        |1397        |Louisiana         |13        |4         |W           |Tennessee      |W           |1                |1.0                   |
|2023      |1344        |1246        |Providence        |11        |6         |W           |Kentucky       |W           |1                |0.0                   |
|2023      |1331      

In [20]:
# For every first-round game pick the predicted winner: the W-slot team when
# PRED_WIN_INDICATOR == 1, otherwise the L-slot team.
w_slot_predicted = round_1_results.PRED_WIN_INDICATOR == 1

round_1_winners = round_1_results.select(
    when(w_slot_predicted, round_1_results.w_teamid)
    .otherwise(round_1_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(w_slot_predicted, round_1_results.w_team_name)
    .otherwise(round_1_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(w_slot_predicted, round_1_results.w_seed)
    .otherwise(round_1_results.l_seed)
    .alias("W_SEED"),
    when(w_slot_predicted, round_1_results.w_region)
    .otherwise(round_1_results.l_region)
    .alias("W_REGION"),
)

# Winners round 1

In [21]:
# Display up to 32 predicted first-round winners, ordered by W_REGION.
round_1_winners.sort('W_REGION').show(32)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1194         |FL Atlantic    |9         |W           |
|1397         |Tennessee      |4         |W           |
|1344         |Providence     |11        |W           |
|1181         |Duke           |5         |W           |
|1425         |USC            |10        |W           |
|1266         |Marquette      |2         |W           |
|1345         |Purdue         |1         |W           |
|1243         |Kansas St      |3         |W           |
|1112         |Arizona        |2         |X           |
|1124         |Baylor         |3         |X           |
|1268         |Maryland       |8         |X           |
|1202         |Furman         |13        |X           |
|1104         |Alabama        |1         |X           |
|1361         |San Diego St   |5         |X           |
|1166         |Creighton      |6         |X     

In [22]:
# Two independently cached copies of the winners so they can be joined to each other.
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

# Within a region, the left team comes from one first-round game and the right team
# from the neighbouring game: {1,16} v {8,9}, {4,13} v {5,12}, {3,14} v {6,11},
# {2,15} v {7,10}.
same_region = df1.w_region == df2.w_region
adjacent_games = (
    (df1.w_seed.isin([1, 16]) & df2.w_seed.isin([8, 9]))
    | (df1.w_seed.isin([4, 13]) & df2.w_seed.isin([5, 12]))
    | (df1.w_seed.isin([3, 14]) & df2.w_seed.isin([6, 11]))
    | (df1.w_seed.isin([2, 15]) & df2.w_seed.isin([7, 10]))
)

matchup_cols = [
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
]
second_round_matchups = df1.join(df2, same_region & adjacent_games, "inner").select(*matchup_cols)

second_round_matchups.sort('W_REGION').show(16)

-----------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
-----------------------------------------------------------------------------------------------------------
|Kansas St      |W           |1243        |3         |Providence     |W           |1344        |11        |
|Purdue         |W           |1345        |1         |FL Atlantic    |W           |1194        |9         |
|Tennessee      |W           |1397        |4         |Duke           |W           |1181        |5         |
|Marquette      |W           |1266        |2         |USC            |W           |1425        |10        |
|Alabama        |X           |1104        |1         |Maryland       |X           |1268        |8         |
|Furman         |X           |1202        |13        |San Diego St   |X           |1361        |5         |
|Arizona        |X          

# ROUND OF 32!!!

In [23]:
# Compact view of the second-round matchups: team names, seeds and the W-side region,
# ordered by region (up to 20 rows).
second_round_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

---------------------------------------------------------------
|"TEAM_1"     |"TEAM_2"      |"SEED_1"  |"SEED_2"  |"REGION"  |
---------------------------------------------------------------
|Kansas St    |Providence    |3         |11        |W         |
|Purdue       |FL Atlantic   |1         |9         |W         |
|Tennessee    |Duke          |4         |5         |W         |
|Marquette    |USC           |2         |10        |W         |
|Alabama      |Maryland      |1         |8         |X         |
|Furman       |San Diego St  |13        |5         |X         |
|Arizona      |Missouri      |2         |7         |X         |
|Baylor       |Creighton     |3         |6         |X         |
|Houston      |Auburn        |1         |9         |Y         |
|Xavier       |Pittsburgh    |3         |11        |Y         |
|Texas        |Texas A&M     |2         |7         |Y         |
|Indiana      |Miami FL      |4         |5         |Y         |
|Kansas       |Illinois      |1         

In [24]:
# Build model-input rows for the second-round matchups by attaching the 2023 season
# statistics of both teams.
second_round = second_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

# NOTE(review): "W_SEASON" is dropped before any W_ prefix is added — confirm that
# FINAL_SEASON_STATS really has a column with that name.
season_2023 = session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023)
season = season_2023.drop("W_SEASON")

# The same statistics twice: once prefixed W_, once prefixed L_.
stat_names = season.columns
season_w = season.select(*[F.col(name).alias("W_" + name) for name in stat_names])
season_l = season.select(*[F.col(name).alias("L_" + name) for name in stat_names])

df1 = second_round.join(season_w, season_w.w_teamid == second_round.wTEAMID2)
df2 = df1.join(season_l, season_l.l_teamid == second_round.lTEAMID2)

games_32 = df2.drop('WTEAMID2','lTEAMID2')

games_32.count()

16

In [25]:
# Predict the second-round matchups and decorate each row with both team names.
result = all_rounds.predict(games_32)

# Team lookup table with every column name upper-cased.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = [
    "l_teamID",
    "l_seed",
    "w_seed",
    "l_region",
    "w_teamid",
    "w_region",
    "pred_win_indicator",
]
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the columns used downstream.
round_2_cols = [
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "w_seed",
    "l_region",
    "W_TEAM_NAME",
    "w_region",
    "PRED_WIN_INDICATOR",
]
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
round_2_results = w_named.with_column_renamed("teamname", "w_team_name").select(round_2_cols)

In [26]:
# For every second-round game pick the predicted winner: the W-slot team when
# PRED_WIN_INDICATOR == 1, otherwise the L-slot team.
w_slot_predicted = round_2_results.PRED_WIN_INDICATOR == 1

round_2_winners = round_2_results.select(
    when(w_slot_predicted, round_2_results.w_teamid)
    .otherwise(round_2_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(w_slot_predicted, round_2_results.w_team_name)
    .otherwise(round_2_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(w_slot_predicted, round_2_results.w_seed)
    .otherwise(round_2_results.l_seed)
    .alias("W_SEED"),
    when(w_slot_predicted, round_2_results.w_region)
    .otherwise(round_2_results.l_region)
    .alias("W_REGION"),
)

# Round 2 winners

In [27]:
# Display up to 20 predicted second-round winners, ordered by W_REGION.
round_2_winners.sort('W_REGION').show(20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1397         |Tennessee      |4         |W           |
|1344         |Providence     |11        |W           |
|1266         |Marquette      |2         |W           |
|1345         |Purdue         |1         |W           |
|1124         |Baylor         |3         |X           |
|1112         |Arizona        |2         |X           |
|1104         |Alabama        |1         |X           |
|1361         |San Diego St   |5         |X           |
|1231         |Indiana        |4         |Y           |
|1400         |Texas          |2         |Y           |
|1222         |Houston        |1         |Y           |
|1462         |Xavier         |3         |Y           |
|1163         |Connecticut    |4         |Z           |
|1211         |Gonzaga        |3         |Z           |
|1417         |UCLA           |2         |Z     

In [28]:
# Two independently cached copies of the winners so they can be joined to each other.
df1 = round_2_winners.cache_result()
df2 = round_2_winners.cache_result()

# Within a region the left team comes from the {1,16,8,9} or {6,11,3,14} seed group
# and the right team from the group it is paired with.
same_region = df1.w_region == df2.w_region
paired_seed_groups = (
    (df1.w_seed.isin([1, 16, 8, 9]) & df2.w_seed.isin([5, 12, 4, 13]))
    | (df1.w_seed.isin([6, 11, 3, 14]) & df2.w_seed.isin([7, 10, 2, 15]))
)

matchup_cols = [
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
]
third_round_matchups = df1.join(df2, same_region & paired_seed_groups, "inner").select(*matchup_cols)

third_round_matchups.sort("W_REGION").show(20)

-----------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
-----------------------------------------------------------------------------------------------------------
|Providence     |W           |1344        |11        |Marquette      |W           |1266        |2         |
|Purdue         |W           |1345        |1         |Tennessee      |W           |1397        |4         |
|Baylor         |X           |1124        |3         |Arizona        |X           |1112        |2         |
|Alabama        |X           |1104        |1         |San Diego St   |X           |1361        |5         |
|Houston        |Y           |1222        |1         |Indiana        |Y           |1231        |4         |
|Xavier         |Y           |1462        |3         |Texas          |Y           |1400        |2         |
|Gonzaga        |Z          

# SWEET SIXTEEN!!!

In [29]:
# Compact view of the third-round matchups: team names, seeds and the W-side region,
# ordered by region (up to 20 rows).
third_round_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

--------------------------------------------------------------
|"TEAM_1"    |"TEAM_2"      |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------------
|Providence  |Marquette     |11        |2         |W         |
|Purdue      |Tennessee     |1         |4         |W         |
|Baylor      |Arizona       |3         |2         |X         |
|Alabama     |San Diego St  |1         |5         |X         |
|Houston     |Indiana       |1         |4         |Y         |
|Xavier      |Texas         |3         |2         |Y         |
|Gonzaga     |UCLA          |3         |2         |Z         |
|Kansas      |Connecticut   |1         |4         |Z         |
--------------------------------------------------------------



In [30]:
# Build model-input rows for the third-round matchups by attaching the 2023 season
# statistics of both teams.
third_round = third_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

# NOTE(review): "W_SEASON" is dropped before any W_ prefix is added — confirm that
# FINAL_SEASON_STATS really has a column with that name.
season_2023 = session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023)
season = season_2023.drop("W_SEASON")

# The same statistics twice: once prefixed W_, once prefixed L_.
stat_names = season.columns
season_w = season.select(*[F.col(name).alias("W_" + name) for name in stat_names])
season_l = season.select(*[F.col(name).alias("L_" + name) for name in stat_names])

df1 = third_round.join(season_w, season_w.w_teamid == third_round.wTEAMID2)
df2 = df1.join(season_l, season_l.l_teamid == third_round.lTEAMID2)

games_16 = df2.drop('WTEAMID2','lTEAMID2')

games_16.count()

8

In [31]:
# Predict the third-round matchups and decorate each row with both team names.
result = all_rounds.predict(games_16)

# Team lookup table with every column name upper-cased.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = [
    "l_teamID",
    "l_seed",
    "w_seed",
    "l_region",
    "w_teamid",
    "w_region",
    "pred_win_indicator",
]
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the columns used downstream.
sweet_16_cols = [
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "w_seed",
    "l_region",
    "W_TEAM_NAME",
    "w_region",
    "PRED_WIN_INDICATOR",
]
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
sweet_16_results = w_named.with_column_renamed("teamname", "w_team_name").select(sweet_16_cols)

In [32]:
# For every third-round game pick the predicted winner: the W-slot team when
# PRED_WIN_INDICATOR == 1, otherwise the L-slot team.
w_slot_predicted = sweet_16_results.PRED_WIN_INDICATOR == 1

sweet_16_winners = sweet_16_results.select(
    when(w_slot_predicted, sweet_16_results.w_teamid)
    .otherwise(sweet_16_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(w_slot_predicted, sweet_16_results.w_team_name)
    .otherwise(sweet_16_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(w_slot_predicted, sweet_16_results.w_seed)
    .otherwise(sweet_16_results.l_seed)
    .alias("W_SEED"),
    when(w_slot_predicted, sweet_16_results.w_region)
    .otherwise(sweet_16_results.l_region)
    .alias("W_REGION"),
)

# Sweet 16 winners!

In [33]:
# Display up to 20 predicted third-round winners, ordered by W_REGION.
sweet_16_winners.sort('W_REGION').show(20)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1266         |Marquette      |2         |W           |
|1345         |Purdue         |1         |W           |
|1112         |Arizona        |2         |X           |
|1104         |Alabama        |1         |X           |
|1222         |Houston        |1         |Y           |
|1400         |Texas          |2         |Y           |
|1417         |UCLA           |2         |Z           |
|1242         |Kansas         |1         |Z           |
-------------------------------------------------------



In [34]:
# Two independently cached copies of the winners so they can be joined to each other.
df1 = sweet_16_winners.cache_result()
df2 = sweet_16_winners.cache_result()

# Within a region the left team comes from the seed group {1,16,8,9,5,12,4,13} and
# the right team from the seed group {6,11,3,14,7,10,2,15}.
left_group_seeds = [1, 16, 8, 9, 5, 12, 4, 13]
right_group_seeds = [6,11,3,14,7,10,2,15]

same_region = df1.w_region == df2.w_region
opposite_groups = df1.w_seed.isin(left_group_seeds) & df2.w_seed.isin(right_group_seeds)

matchup_cols = [
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
]
elite_8_matchups = df1.join(df2, same_region & opposite_groups, "inner").select(*matchup_cols)

elite_8_matchups.sort("W_REGION").show(20)

-----------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"  |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
-----------------------------------------------------------------------------------------------------------
|Purdue         |W           |1345        |1         |Marquette      |W           |1266        |2         |
|Alabama        |X           |1104        |1         |Arizona        |X           |1112        |2         |
|Houston        |Y           |1222        |1         |Texas          |Y           |1400        |2         |
|Kansas         |Z           |1242        |1         |UCLA           |Z           |1417        |2         |
-----------------------------------------------------------------------------------------------------------



# Elite 8

In [35]:
# Compact view of the fourth-round matchups: team names, seeds and the W-side region,
# ordered by region (up to 20 rows).
elite_8_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

---------------------------------------------------------
|"TEAM_1"  |"TEAM_2"   |"SEED_1"  |"SEED_2"  |"REGION"  |
---------------------------------------------------------
|Purdue    |Marquette  |1         |2         |W         |
|Alabama   |Arizona    |1         |2         |X         |
|Houston   |Texas      |1         |2         |Y         |
|Kansas    |UCLA       |1         |2         |Z         |
---------------------------------------------------------



In [36]:
# Build model-input rows for the fourth-round matchups by attaching the 2023 season
# statistics of both teams. `elite_8` first holds the matchup ids/seeds and is then
# rebound to the joined feature rows.
elite_8 = elite_8_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

# NOTE(review): "W_SEASON" is dropped before any W_ prefix is added — confirm that
# FINAL_SEASON_STATS really has a column with that name.
season_2023 = session.table("FINAL_SEASON_STATS").filter(F.col("season") == 2023)
season = season_2023.drop("W_SEASON")

# The same statistics twice: once prefixed W_, once prefixed L_.
stat_names = season.columns
season_w = season.select(*[F.col(name).alias("W_" + name) for name in stat_names])
season_l = season.select(*[F.col(name).alias("L_" + name) for name in stat_names])

df1 = elite_8.join(season_w, season_w.w_teamid == elite_8.wTEAMID2)
df2 = df1.join(season_l, season_l.l_teamid == elite_8.lTEAMID2)

elite_8 = df2.drop('WTEAMID2','lTEAMID2')

elite_8.count()

4

In [37]:
# Predict the fourth-round matchups and decorate each row with both team names.
result = all_rounds.predict(elite_8)

# Team lookup table with every column name upper-cased.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

prediction_cols = [
    "l_teamID",
    "l_seed",
    "w_seed",
    "l_region",
    "w_teamid",
    "w_region",
    "pred_win_indicator",
]
result = result.select(prediction_cols)

# Attach the name of the team in the L slot and materialize the intermediate result.
l_named = result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
l_named = l_named.with_column_renamed("teamname", "l_team_name")
res_teamsl = l_named.drop("firstd1season", "lastd1season")
res_teamsl = res_teamsl.cache_result()

# Attach the name of the team in the W slot and keep the columns used downstream.
elite_8_cols = [
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "w_seed",
    "l_region",
    "W_TEAM_NAME",
    "w_region",
    "PRED_WIN_INDICATOR",
]
w_named = res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
elite_8_results = w_named.with_column_renamed("teamname", "w_team_name").select(elite_8_cols)

In [38]:
# For every fourth-round game pick the predicted winner: the W-slot team when
# PRED_WIN_INDICATOR == 1, otherwise the L-slot team.
w_slot_predicted = elite_8_results.PRED_WIN_INDICATOR == 1

elite_8_winners = elite_8_results.select(
    when(w_slot_predicted, elite_8_results.w_teamid)
    .otherwise(elite_8_results.l_teamID)
    .alias("W_TEAM_ID"),
    when(w_slot_predicted, elite_8_results.w_team_name)
    .otherwise(elite_8_results.l_team_name)
    .alias("W_TEAM_NAME"),
    when(w_slot_predicted, elite_8_results.w_seed)
    .otherwise(elite_8_results.l_seed)
    .alias("W_SEED"),
    when(w_slot_predicted, elite_8_results.w_region)
    .otherwise(elite_8_results.l_region)
    .alias("W_REGION"),
)

# Elite 8 winners

In [39]:
# Print the advancing teams ordered by region (up to 20 rows).
(
    elite_8_winners
    .sort("W_REGION")
    .show(20)
)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1345         |Purdue         |1         |W           |
|1104         |Alabama        |1         |X           |
|1222         |Houston        |1         |Y           |
|1417         |UCLA           |2         |Z           |
-------------------------------------------------------



In [40]:
def _as_second_team(winners):
    """Rename the winner columns so the frame can be the second team of a matchup."""
    return winners.select(
        F.col("W_TEAM_ID").alias("TEAM_2"),
        F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
        F.col("W_SEED").alias("SEED_2"),
        F.col("W_REGION").alias("REGION_2"),
    )


# Region W winner is paired with region X winner.
df1 = elite_8_winners.filter(F.col("W_REGION") == "W")

df2 = _as_second_team(elite_8_winners.filter(F.col("W_REGION") == "X"))

# No join condition: every row of df1 is paired with every row of df2.
df3 = df1.join(df2)

# Region Y winner is paired with region Z winner.
df4 = elite_8_winners.filter(F.col("W_REGION") == "Y")

df5 = _as_second_team(elite_8_winners.filter(F.col("W_REGION") == "Z"))

df6 = df4.join(df5)

# Stack both pairings and rename to the W_/L_ layout used by the feature-join cells.
final_four_matchups = (
    df3.union(df6)
    .select(
        F.col("W_TEAM_NAME"),
        F.col("W_REGION"),
        F.col("W_TEAM_ID").alias("WTEAMID2"),
        F.col("W_SEED"),
        F.col("W_TEAM_NAME_2").alias("L_TEAM_NAME"),
        F.col("REGION_2").alias("L_REGION"),
        F.col("TEAM_2").alias("LTEAMID2"),
        F.col("SEED_2").alias("L_SEED"),
    )
)

# FINAL FOUR!

In [41]:
# Print the Final Four pairings with display-friendly column names.
(
    final_four_matchups
    .select(
        *[
            F.col(source).alias(shown_as)
            for source, shown_as in [
                ("W_TEAM_NAME", "TEAM_1"),
                ("L_TEAM_NAME", "TEAM_2"),
                ("W_SEED", "SEED_1"),
                ("L_SEED", "SEED_2"),
                ("W_REGION", "REGION"),
            ]
        ]
    )
    .sort("REGION")
    .show(20)
)

--------------------------------------------------------
|"TEAM_1"  |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------
|Purdue    |Alabama   |1         |1         |W         |
|Houston   |UCLA      |1         |2         |Y         |
--------------------------------------------------------



In [42]:
# Keep only the matchup keys and seeds; the per-team features are joined in below.
final_four = final_four_matchups.select("WTEAMID2", "LTEAMID2", "W_SEED", "L_SEED")

# 2023 rows of the season statistics table.
season = (
    session.table("FINAL_SEASON_STATS")
    .filter(F.col("season") == 2023)
    .drop("W_SEASON")
)

# Two copies of the same statistics: one with every column prefixed W_, one with L_.
season_w = season.select(*(F.col(name).alias("W_" + name) for name in season.columns))

season_l = season.select(*(F.col(name).alias("L_" + name) for name in season.columns))

# Attach the W-side statistics first, then the L-side statistics, matching on team id.
df1 = final_four.join(season_w, season_w.w_teamid == final_four.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == final_four.lTEAMID2)

# The matchup-only id columns are redundant once the prefixed team ids are present.
final_four = df2.drop("WTEAMID2", "lTEAMID2")

final_four.count()

2

In [43]:
# Run `all_rounds.predict` over the Final Four feature rows.
result = all_rounds.predict(final_four)

# Team lookup table with upper-cased column names. The loop variable is not
# called `col`: a module-level `for col in ...` would rebind the `col` function
# imported from snowflake.snowpark.functions for every later cell.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

# Keep only the ids, seeds, regions and the prediction column.
result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# Attach the L-side team name and drop the lookup columns that are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Cache the intermediate join before joining the lookup table a second time.
res_teamsl = res_teamsl.cache_result()

# Attach the W-side team name and fix the output column order.
final_four_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [44]:
# True where the prediction is 1, i.e. the W_ side's values are kept; otherwise
# the L_ side's values are used.
final_four_w_advances = final_four_results.PRED_WIN_INDICATOR == 1

final_four_winners = (
    final_four_results
    .with_column(
        "W_TEAM_ID",
        when(final_four_w_advances, final_four_results.w_teamid)
        .otherwise(final_four_results.l_teamID),
    )
    .with_column(
        "W_TEAM_NAME",
        when(final_four_w_advances, final_four_results.w_team_name)
        .otherwise(final_four_results.l_team_name),
    )
    .with_column(
        "W_SEED",
        when(final_four_w_advances, final_four_results.w_seed)
        .otherwise(final_four_results.l_seed),
    )
    .with_column(
        "W_REGION",
        when(final_four_w_advances, final_four_results.w_region)
        .otherwise(final_four_results.l_region),
    )
    # Only the advancing team's id, name, seed and region are carried forward.
    .select("W_TEAM_ID", "W_TEAM_NAME", "W_SEED", "W_REGION")
)

In [45]:
# Print the advancing teams ordered by region (up to 20 rows).
(
    final_four_winners
    .sort("W_REGION")
    .show(20)
)

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1104         |Alabama        |1         |X           |
|1222         |Houston        |1         |Y           |
-------------------------------------------------------



In [46]:
# Two cached copies of the Final Four winners so the frame can be paired with itself.
df1 = final_four_winners.cache_result()
df2 = final_four_winners.cache_result()

# Pair every winner with every winner, then keep a single ordered pair.
# The pairing yields each matchup twice (A vs B and B vs A). Filtering with `!=`
# and taking an unordered `limit(1)` would return either of those rows, so the
# team placed in the W_ slot (and therefore the model input) could change
# between runs. Requiring WTEAMID2 > LTEAMID2 keeps exactly one of the two
# mirrored rows; which team gets the W_ slot is arbitrary but now fixed.
championship_matchup = (
    df1.join(df2)
    .select(
        df1.w_team_name.alias('W_TEAM_NAME'),
        df1.w_region.alias('W_REGION'),
        df1.w_team_id.alias('WTEAMID2'),
        df1.w_seed.alias('W_SEED'),
        df2.w_team_name.alias('L_TEAM_NAME'),
        df2.w_region.alias('L_REGION'),
        df2.w_team_id.alias('LTEAMID2'),
        df2.w_seed.alias('L_SEED'),
    )
    .filter(F.col("WTEAMID2") > F.col("LTEAMID2"))
    # Kept as a guard so the result never holds more than one matchup.
    .limit(1)
)

# CHAMPIONSHIP!

In [47]:
# Print the championship pairing with display-friendly column names.
(
    championship_matchup
    .select(
        *[
            F.col(source).alias(shown_as)
            for source, shown_as in [
                ("W_TEAM_NAME", "TEAM_1"),
                ("L_TEAM_NAME", "TEAM_2"),
                ("W_SEED", "SEED_1"),
                ("L_SEED", "SEED_2"),
                ("W_REGION", "REGION"),
            ]
        ]
    )
    .sort("REGION")
    .show(20)
)

--------------------------------------------------------
|"TEAM_1"  |"TEAM_2"  |"SEED_1"  |"SEED_2"  |"REGION"  |
--------------------------------------------------------
|Houston   |Alabama   |1         |1         |Y         |
--------------------------------------------------------



In [48]:
# Keep only the matchup keys and seeds; the per-team features are joined in below.
championship = championship_matchup.select("WTEAMID2", "LTEAMID2", "W_SEED", "L_SEED")

# 2023 rows of the season statistics table.
season = (
    session.table("FINAL_SEASON_STATS")
    .filter(F.col("season") == 2023)
    .drop("W_SEASON")
)

# Two copies of the same statistics: one with every column prefixed W_, one with L_.
season_w = season.select(*(F.col(name).alias("W_" + name) for name in season.columns))

season_l = season.select(*(F.col(name).alias("L_" + name) for name in season.columns))

# Attach the W-side statistics first, then the L-side statistics, matching on team id.
df1 = championship.join(season_w, season_w.w_teamid == championship.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == championship.lTEAMID2)

# The matchup-only id columns are redundant once the prefixed team ids are present.
championship = df2.drop("WTEAMID2", "lTEAMID2")

championship.count()

1

In [49]:
# Run `all_rounds.predict` over the championship feature row.
result = all_rounds.predict(championship)

# Team lookup table with upper-cased column names. The loop variable is not
# called `col`: a module-level `for col in ...` would rebind the `col` function
# imported from snowflake.snowpark.functions for every later cell.
teams = session.table("mteams")
for column_name in teams.columns:
    teams = teams.withColumnRenamed(column_name, column_name.upper())

# Keep only the ids, seeds, regions and the prediction column.
result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

# Attach the L-side team name and drop the lookup columns that are not needed.
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
# Cache the intermediate join before joining the lookup table a second time.
res_teamsl = res_teamsl.cache_result()

# Attach the W-side team name and fix the output column order.
championship_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [50]:
# Print the championship prediction rows (both teams plus PRED_WIN_INDICATOR).
championship_results.show()

----------------------------------------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"  |"W_REGION"  |"PRED_WIN_INDICATOR"  |
----------------------------------------------------------------------------------------------------------------------------------
|1104        |1222        |Alabama        |1         |1         |X           |Houston        |Y           |1.0                   |
----------------------------------------------------------------------------------------------------------------------------------



In [51]:
# True where the prediction is 1, i.e. the W_ side's values are kept; otherwise
# the L_ side's values are used.
championship_w_advances = championship_results.PRED_WIN_INDICATOR == 1

champion = (
    championship_results
    .with_column(
        "W_TEAM_ID",
        when(championship_w_advances, championship_results.w_teamid)
        .otherwise(championship_results.l_teamID),
    )
    .with_column(
        "W_TEAM_NAME",
        when(championship_w_advances, championship_results.w_team_name)
        .otherwise(championship_results.l_team_name),
    )
    .with_column(
        "W_SEED",
        when(championship_w_advances, championship_results.w_seed)
        .otherwise(championship_results.l_seed),
    )
    .with_column(
        "W_REGION",
        when(championship_w_advances, championship_results.w_region)
        .otherwise(championship_results.l_region),
    )
    # Only the selected team's id, name, seed and region are kept.
    .select("W_TEAM_ID", "W_TEAM_NAME", "W_SEED", "W_REGION")
)

# CHAMPION!

In [52]:
# Print the team selected by the championship prediction.
champion.show()

-------------------------------------------------------
|"W_TEAM_ID"  |"W_TEAM_NAME"  |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1222         |Houston        |1         |Y           |
-------------------------------------------------------



# Predict today's play-in games

In [66]:
def _team_stats_2024(team_id):
    """Return the 2024 season-stats row for ``team_id`` as a Snowpark DataFrame.

    Reads MEN.SEASON_STATS_2024, drops the N / H / A columns, keeps rows with
    SEASON == 2024 and TEAMID == ``team_id``, drops SEASON and limits the
    result to one row.
    """
    return (
        session.table("MEN.SEASON_STATS_2024")
        .drop('N', 'H', 'A')
        .filter(F.col('SEASON') == 2024)
        .filter(F.col('TEAMID') == team_id)
        .drop('SEASON')
        .limit(1)
    )


def _prefix_columns(frame, prefix):
    """Return ``frame`` with ``prefix`` prepended to every column name."""
    for column in frame.columns:
        frame = frame.withColumnRenamed(column, f"{prefix}{column}")
    return frame


# One stats row per play-in team, looked up by TEAMID.
howard = _team_stats_2024(1224)

wagner = _team_stats_2024(1447)

colst = _team_stats_2024(1161)

va = _team_stats_2024(1438)

# `va` takes the W_ slot and `colst` the L_ slot of the matchup row.
va = _prefix_columns(va, "W_")

colst = _prefix_columns(colst, "L_")

# No join condition: the two one-row frames are combined into a single matchup row.
play_in = va.join(colst)

In [67]:
# Columns present in `test` but missing from the play-in matchup row.
columns_to_add = [name for name in test.columns if name not in play_in.columns]

# Add each missing column to `play_in` with a constant value of 0.
# The loop variable is not called `col`, which would rebind the `col` function
# imported from snowflake.snowpark.functions for every later cell.
for missing_column in columns_to_add:
    play_in = play_in.withColumn(missing_column, lit(0))

In [68]:
# Display the names of the columns that were filled with the constant 0.
columns_to_add

['SEASON',
 'W_SCORE',
 'L_SCORE',
 'W_SEED',
 'L_SEED',
 'W_REGION',
 'L_REGION',
 'WIN_INDICATOR',
 'ROUND']

In [69]:
# Drop the same column groups that were removed from the training frame,
# then overwrite both seeds with 10 for this matchup.
play_in = (
    play_in
    .drop(["W_CTWINS", "W_AVERAGECTSCORE", "L_CTWINS", "L_AVERAGECTSCORE"])
    # Labelled "variants" by the author (presumably VARIANT-typed columns; confirm).
    .drop(["W_WLOCN", "W_WLOCH", "W_WLOCA", "L_WLOCN", "L_WLOCH", "L_WLOCA"])
    .with_column("L_SEED", lit(10))
    .with_column("W_SEED", lit(10))
)

In [56]:
# Open the model registry and make version V_4 the default for MARCHMADNESS.
reg = Registry(session=session)
m = reg.get_model("MARCHMADNESS")
m.default = "V_4"

# Fetch the default version (described by the author as the model with the best accuracy).
mv = reg.get_model("MARCHMADNESS").default

# Score the play-in row with predict_proba.
remote_prediction = mv.run(play_in, function_name="predict_proba")

# Keep only "output_feature_1" and show it under the name va_win_prob.
(
    remote_prediction
    .drop('"output_feature_0"')
    .with_column_renamed('"output_feature_1"', "va_win_prob")
    .select("va_win_prob")
    .show()
)

----------------------
|"VA_WIN_PROB"       |
----------------------
|0.7778103947639465  |
----------------------



In [72]:
# Build prefixed copies instead of renaming `wagner` / `howard` in place. Both
# frames are created in an earlier cell, so an in-place rename would stack a
# second W_ / L_ prefix every time this cell is re-run, and the real feature
# columns would then be silently re-created below with the constant 0.
wagner_w = wagner
for column in wagner.columns:
    wagner_w = wagner_w.withColumnRenamed(column, f"W_{column}")

howard_l = howard
for column in howard.columns:
    howard_l = howard_l.withColumnRenamed(column, f"L_{column}")

# No join condition: the two one-row frames are combined into a single matchup row.
play_in = wagner_w.join(howard_l)

# Columns present in `test` but missing from the play-in matchup row.
columns_to_add = [name for name in test.columns if name not in play_in.columns]

# Add each missing column to `play_in` with a constant value of 0.
# The loop variable is not called `col`, which would rebind the `col` function
# imported from snowflake.snowpark.functions for every later cell.
for missing_column in columns_to_add:
    play_in = play_in.withColumn(missing_column, lit(0))

play_in = play_in.drop(['W_CTWINS','W_AVERAGECTSCORE','L_CTWINS','L_AVERAGECTSCORE'])
play_in = play_in.drop(['W_WLOCN','W_WLOCH','W_WLOCA','L_WLOCN','L_WLOCH','L_WLOCA'])  # variants
# Both seeds are overwritten with 16 for this matchup.
play_in = play_in.withColumn('L_SEED', lit(16))
play_in = play_in.withColumn('W_SEED', lit(16))

# Point to the registry

reg = Registry(session=session)
m = reg.get_model("MARCHMADNESS")
m.default = 'V_4'

# Get the default version of your model (Model with best accuracy in our case)

mv = reg.get_model("MARCHMADNESS").default

# Score the matchup row; "output_feature_1" is shown as wagner_win_prob.
remote_prediction = mv.run(play_in, function_name="predict_proba")
remote_prediction.drop('"output_feature_0"').with_column_renamed(
    '"output_feature_1"', "wagner_win_prob"
).select('wagner_win_prob').show()

----------------------
|"WAGNER_WIN_PROB"   |
----------------------
|0.5396665334701538  |
----------------------



In [71]:
# `play_in` holds the matchup with Wagner in the W_ slot and Howard in the L_
# slot. To report a probability labelled for Howard, the two sides must be
# swapped on every run. Swapping only when an unseeded random draw exceeds 0.5
# would label Wagner's probability as Howard's on the no-swap branch and make
# the output non-reproducible.
play_in_pd = play_in.to_pandas()

# Map every W_<x> column name to L_<x> and vice versa; other columns are untouched.
swapped_names = {
    name: ("L_" if name.startswith("W_") else "W_") + name[2:]
    for name in play_in_pd.columns
    if name.startswith(("W_", "L_"))
}

# Renaming and then re-selecting the original column order exchanges the W_ and
# L_ values while keeping the column layout. A W_/L_ column without a
# counterpart on the other side raises KeyError here.
play_in_swapped_pd = play_in_pd.rename(columns=swapped_names)[list(play_in_pd.columns)]

df = session.create_dataframe(play_in_swapped_pd)

# Score the swapped row; "output_feature_1" is shown as howard_win_prob.
remote_prediction = mv.run(df, function_name="predict_proba")
remote_prediction.drop('"output_feature_0"').with_column_renamed(
    '"output_feature_1"', "howard_win_prob"
).select('howard_win_prob').show()

----------------------
|"HOWARD_WIN_PROB"   |
----------------------
|0.5555117130279541  |
----------------------

