# START TRAINING

In [3]:
import ast
import json
import warnings

import pandas as pd
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.metrics import accuracy_score
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.modeling.preprocessing import OneHotEncoder
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.ml.registry import Registry
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
from snowflake.snowpark import Session
from snowflake.snowpark import types as T
from snowflake.snowpark.functions import col
from snowflake.snowpark import functions as F
from snowflake.snowpark.functions import when, lit
from snowflake.ml.modeling.metrics import mean_absolute_error

import numpy as np

warnings.simplefilter(action="ignore", category=UserWarning)

In [4]:
session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()

In [5]:
final = (
    session.table("MEN.FINAL_FEATURES")
    .with_column_renamed("WTEAMID", "W_TEAMID")
    .with_column_renamed("LTEAMID", "L_TEAMID")
    .with_column_renamed("WSCORE", "W_SCORE")
    .with_column_renamed("LSCORE", "L_SCORE")
)

In [6]:
parameters = {
    "n_estimators": [100, 200, 300, 400, 500],
    "learning_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
    "max_depth": list(range(3, 6, 1)),
    "min_child_weight": list(range(1, 6, 1)),
}

In [7]:
def add_losing_matches(df):
    lose_rename = {
        "W_TEAMID": "L_TEAMID", 
        "W_SCORE" : "L_SCORE", 
        "L_TEAMID" : "W_TEAMID",
        "L_SCORE": "W_SCORE",
    }
    lose_rename.update({c : "L_" + c[2:] for c in df.columns if c.startswith('W_')})
    lose_rename.update({c : "W_" + c[2:] for c in df.columns if c.startswith('L_')})
    
    win_df = df.copy()
    lose_df = df.copy()
    
    lose_df = lose_df.rename(columns=lose_rename)
    
    return pd.concat([win_df, lose_df], axis=0, sort=False)

In [8]:
final.write.save_as_table(
    "MEN.FINAL_MODEL_TRAIN", mode="overwrite"
)

In [9]:
train = final.filter(F.col('SEASON') <= 2021)
test = final.filter(F.col('SEASON') > 2021)

In [10]:
df_pd = train.to_pandas()
train_2 = add_losing_matches(df_pd)
train_2 = session.create_dataframe(train_2)

In [11]:
train_2 = train_2.withColumn("SPREAD", col("W_SCORE") - col("L_SCORE"))
test = test.withColumn("SPREAD", col("W_SCORE") - col("L_SCORE"))

In [12]:
session.use_warehouse('MM_L')

In [93]:
all_rounds = GridSearchCV(
    estimator=XGBRegressor(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="neg_mean_absolute_percentage_error",
    input_cols=train_2.drop(
        [
            "SEASON",
            "SPREAD",
            "l_teamid",
            "w_teamid",
            "w_score",
            "l_score",
            "round",
            "l_region",
            "w_region",
            "win_indicator"
        ]
    ).columns,
    label_cols="SPREAD",
    output_cols="PRED_SPREAD",
)

# Train
all_rounds.fit(train_2.filter(F.col("ROUND") == 3))

Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.
The version of package 'pyarrow' in the local environment is 15.0.1, which does not fit the criteria for the requirement 'pyarrow<14'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'cachetools' in the local environment is 5.3.3, which does not fit the criteria for the requirement 'cachetools<6'. Your UDF might not work when the package version is different between the server and your local environment.


<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fb2f1ced270>

In [94]:
result = all_rounds.predict(test).filter(F.col("ROUND") == 1)
result.count()

64

In [95]:
mae_2022 = mean_absolute_error(
    df=result.filter(result.season == 2022),
    y_true_col_names="SPREAD",
    y_pred_col_names="PRED_SPREAD",
)
print(f"MAE 2022: {mae_2022}")

mae_2023 = mean_absolute_error(
    df=result.filter(result.season == 2023),
    y_true_col_names="SPREAD",
    y_pred_col_names="PRED_SPREAD",
)
print(f"MAE 2023: {mae_2023}")

MAE = mean_absolute_error(
    df=result, y_true_col_names="SPREAD", y_pred_col_names="PRED_SPREAD"
)

print(f"MAE total: {MAE}")

MAE 2022: 13.131337739992887
MAE 2023: 10.176728605758399
MAE total: 11.654033172875643


# Predicting the Bracket & final four

## 2024 Data

In [16]:
season_2024 = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024).drop('REGION')

In [17]:
tourney_2024 = session.table('COMMON.TOURNEY_SEEDS_2024').filter(F.col('TOURNAMENT') == 'M')

In [18]:
tourney_2024 = session.table('MEN.M_NCAATOURNEY_SEEDS').filter(F.col('season') == 2024)

### Fix for the play in game error from Kaggle

In [19]:
tourney_2024 = tourney_2024.filter(~col("TEAMID").isin([1129, 1286, 1224, 1438]))

In [20]:
tourney_2024.count()

64

In [21]:
seed_value = (
    tourney_2024
    .with_column("REGION", F.substring(F.col("SEED"), 1, 1))
    .with_column(
        "SEED", F.substring(F.col("SEED"), 2, F.length(F.col("SEED")) - 1)
    )
    .select("TEAMID", "REGION", "SEED")
    .with_column(
        "SEED",
        F.cast(F.regexp_replace(F.col("SEED"), "[a-z]", ""), T.IntegerType()),
    )
).select(F.col('TEAMID').alias('TEAMID1'), 'REGION','SEED')

In [22]:
round_of_64 = season_2024.join(seed_value, season_2024.TEAMID == seed_value.TEAMID1).drop('TEAMID1')


In [23]:
df1 = round_of_64.select(
    *[F.col(col).alias(f"W_{col}") for col in round_of_64.columns]
)

df2 = round_of_64.select(
    *[F.col(col).alias(f"L_{col}") for col in round_of_64.columns]
)

In [24]:
first_round = df1.join(
    df2,
    (df1.w_region == df2.l_region)
    & (
        (df1.w_seed == 1) & (df2.l_seed == 16) |
        (df1.w_seed == 8) & (df2.l_seed == 9) |
        (df1.w_seed == 5) & (df2.l_seed == 12) |
        (df1.w_seed == 4) & (df2.l_seed == 13) |
        (df1.w_seed == 6) & (df2.l_seed == 11) |
        (df1.w_seed == 3) & (df2.l_seed == 14) |
        (df1.w_seed == 7) & (df2.l_seed == 10) |
        (df1.w_seed == 2) & (df2.l_seed == 15) 

    ),
    "inner",
)

In [25]:
result = all_rounds.predict(first_round).select(
    "l_teamID", "l_seed", "l_region", "w_teamid", "w_seed", "w_region", "PRED_SPREAD"
)

In [26]:
teams = session.table("m_teams")
res_teamsl = result.join(
    teams, result.col("L_TEAMID") == teams.col("TEAMID")
).with_column_renamed("teamname", "l_team_name")
res_teamsl = res_teamsl.cache_result()

round_1_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_seed",
        "w_region",
        "PRED_SPREAD",
    )
).distinct()
print(round_1_results.count())

round_1_results.select(
    F.col("w_region").alias("REGION"),
    F.col('L_SEED').alias('SEED2'),
    F.col("L_TEAM_NAME").alias("TEAM1"),
    F.col('W_SEED').alias('SEED1'),
    F.col("W_TEAM_NAME").alias("TEAM2"),
    "PRED_SPREAD",
).sort("REGION").show(32)

32
----------------------------------------------------------------------------------------
|"REGION"  |"SEED2"  |"TEAM1"         |"SEED1"  |"TEAM2"         |"PRED_SPREAD"        |
----------------------------------------------------------------------------------------
|W         |13       |Yale            |4        |Auburn          |11.611824989318848   |
|W         |14       |Morehead St     |3        |Illinois        |2.011875867843628    |
|W         |15       |S Dakota St     |2        |Iowa St         |17.530385971069336   |
|W         |9        |Northwestern    |8        |FL Atlantic     |-4.187599182128906   |
|W         |16       |Stetson         |1        |Connecticut     |22.989835739135746   |
|W         |11       |Duquesne        |6        |BYU             |7.984871864318848    |
|W         |10       |Drake           |7        |Washington St   |1.2472617626190186   |
|W         |12       |UAB             |5        |San Diego St    |3.252487897872925    |
|X         |13    

In [27]:
w_teams = [
    "Connecticut",
    "San Diego St",
    "Washington St",
    "Iowa St",
    "Illinois",
    "Dayton",
    "Arizona",
    "North Carolina",
    "Creighton",
    "Tennessee",
    "Texas",
    "Gonzaga",
    "Kansas",
    "Northwestern",
    "Baylor",
    "Clemson",
    "Alabama",
    "Purdue",
    "Texas",
    "Gonzaga",
    "Utah St ",
    "Duke",
    "Boise St",
    "Houston",
    "Marquette",
    "Utah St",
    "Washington St"
]  # Example list of teams

# Create the new column "winner"
round_1_results = round_1_results.withColumn(
    "WINNER", when(col("W_TEAM_NAME").isin(w_teams), 1).otherwise(0)
)

round_1_results.show(32)

----------------------------------------------------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"   |"L_SEED"  |"L_REGION"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |"PRED_SPREAD"        |"WINNER"  |
----------------------------------------------------------------------------------------------------------------------------------------------
|1158        |1104        |Col Charleston  |13        |X           |Alabama         |4         |X           |5.892940998077393    |1         |
|1463        |1120        |Yale            |13        |W           |Auburn          |4         |W           |11.611824989318848   |0         |
|1159        |1124        |Colgate         |14        |X           |Baylor          |3         |X           |11.083738327026367   |1         |
|1182        |1140        |Duquesne        |11        |W           |BYU             |6         |W           |7.984871864318848    |0         |

In [28]:
round_1_winners = round_1_results.with_column(
    "W_TEAMID",
    when((round_1_results.WINNER == 1), round_1_results.w_teamid)
    .otherwise(round_1_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_1_results.WINNER == 1), round_1_results.w_team_name)
    .otherwise(round_1_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_1_results.WINNER == 1), round_1_results.w_seed)
    .otherwise(round_1_results.l_seed),
).with_column(
    "W_REGION",
    when((round_1_results.WINNER == 1), round_1_results.w_region)
    .otherwise(round_1_results.l_region)
).select("W_TEAMID","W_TEAM_NAME","W_SEED","W_REGION")

round_1_winners.sort('W_REGION').show(32)

-------------------------------------------------------
|"W_TEAMID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1463        |Yale            |13        |W           |
|1228        |Illinois        |3         |W           |
|1235        |Iowa St         |2         |W           |
|1321        |Northwestern    |9         |W           |
|1163        |Connecticut     |1         |W           |
|1182        |Duquesne        |11        |W           |
|1450        |Washington St   |7         |W           |
|1361        |San Diego St    |5         |W           |
|1104        |Alabama         |4         |X           |
|1112        |Arizona         |2         |X           |
|1155        |Clemson         |6         |X           |
|1173        |Dayton          |7         |X           |
|1213        |Grand Canyon    |12        |X           |
|1314        |North Carolina  |1         |X           |
|1124        |Baylor          |3         |X     

In [29]:
df1 = round_1_winners.cache_result()
df2 = round_1_winners.cache_result()

second_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 16) & (df2.w_seed.isin([8,9]))
        | (df1.w_seed == 4) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 13) & (df2.w_seed.isin([5,12]))
        | (df1.w_seed == 3) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 14) & (df2.w_seed.isin([6,11]))
        | (df1.w_seed == 2) & (df2.w_seed.isin([7,10]))
        | (df1.w_seed == 15) & (df2.w_seed.isin([7,10]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAMID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAMID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

print(second_round_matchups.count())

second_round_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(32)

16
-------------------------------------------------------------------
|"TEAM_1"        |"TEAM_2"       |"SEED_1"  |"SEED_2"  |"REGION"  |
-------------------------------------------------------------------
|Yale            |San Diego St   |13        |5         |W         |
|Iowa St         |Washington St  |2         |7         |W         |
|Illinois        |Duquesne       |3         |11        |W         |
|Connecticut     |Northwestern   |1         |9         |W         |
|Arizona         |Dayton         |2         |7         |X         |
|Alabama         |Grand Canyon   |4         |12        |X         |
|North Carolina  |Michigan St    |1         |9         |X         |
|Baylor          |Clemson        |3         |6         |X         |
|Kansas          |Gonzaga        |4         |5         |Y         |
|Tennessee       |Texas          |2         |7         |Y         |
|Creighton       |Oregon         |3         |11        |Y         |
|Purdue          |Utah St        |1         |

In [30]:
second_round = second_round_matchups.select("WTEAMID2", "LTEAMID2", "W_SEED", "L_SEED")
print(second_round.count())

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col("SEASON") == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = second_round.join(season_w, season_w.w_teamid == second_round.wTEAMID2)
print(df1.count())

df2 = df1.join(season_l, season_l.l_teamid == second_round.lTEAMID2).distinct()
print(df2.count())
games_32 = df2.drop("WTEAMID2", "lTEAMID2")

games_32.count()

16
16
16


16

In [31]:
result = all_rounds.predict(games_32)

In [32]:
teams = session.table('m_teams')
res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
)
res_teamsl = res_teamsl.cache_result()

round_1_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select('l_teamid','w_teamid','L_TEAM_NAME','l_seed','l_region','W_TEAM_NAME','w_seed','w_region','PRED_SPREAD')
)

# Round 2 game predicted spreads

In [33]:
round_1_results.select(
    F.col("w_region").alias("REGION"),
    "L_TEAM_NAME",
    "l_seed",
    "W_TEAM_NAME",
    "w_seed",
    "PRED_SPREAD",
).sort("REGION").show(32)

------------------------------------------------------------------------------------------
|"REGION"  |"L_TEAM_NAME"  |"L_SEED"  |"W_TEAM_NAME"   |"W_SEED"  |"PRED_SPREAD"         |
------------------------------------------------------------------------------------------
|W         |San Diego St   |5         |Yale            |13        |-5.607921123504639    |
|W         |Duquesne       |11        |Illinois        |3         |-0.08841568976640701  |
|W         |Northwestern   |9         |Connecticut     |1         |9.200010299682615     |
|W         |Washington St  |7         |Iowa St         |2         |3.20786190032959      |
|X         |Clemson        |6         |Baylor          |3         |4.083593368530273     |
|X         |Grand Canyon   |12        |Alabama         |4         |6.477607727050781     |
|X         |Dayton         |7         |Arizona         |2         |6.6120781898498535    |
|X         |Michigan St    |9         |North Carolina  |1         |8.157055854797363     |

# Round 2 predict wins

In [34]:
train_2 = train_2.withColumn(
    "WIN_INDICATOR", F.when(train_2["W_SCORE"] > train_2["L_SCORE"], 1).otherwise(0)
)

In [71]:
all_rounds_class = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid=parameters,
    n_jobs=-1,
    scoring="accuracy",
    input_cols=train_2.drop(
        [
            "SEASON",
            'SPREAD',
            "WIN_INDICATOR",
            "l_teamid",
            "w_teamid",
            "w_score",
            "l_score",
            "round",
            "l_region",
            "w_region",
        ]
    ).columns,
    label_cols="WIN_INDICATOR",
    output_cols="PRED_WIN_INDICATOR",
)

# Train
all_rounds_class.fit(train_2.filter(F.col("ROUND") == 3))

Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.
The version of package 'pyarrow' in the local environment is 15.0.1, which does not fit the criteria for the requirement 'pyarrow<14'. Your UDF might not work when the package version is different between the server and your local environment.
The version of package 'cachetools' in the local environment is 5.3.3, which does not fit the criteria for the requirement 'cachetools<6'. Your UDF might not work when the package version is different between the server and your local environment.


<snowflake.ml.modeling.model_selection.grid_search_cv.GridSearchCV at 0x7fb2f1b38610>

In [None]:
session.use_warehouse('wh_xs')

In [72]:
result = all_rounds_class.predict(games_32)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [73]:
result = all_rounds_class.predict_proba(games_32)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        F.col('"predict_proba_1"').alias('W_PROB')
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results_proba = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        F.col("l_teamid").alias('teaml'),
        F.col("w_teamid").alias('teamw'),
        'W_PROB',
    )
)

# JOE START ANALYZINF HERE
Here are the results of the classification cross check them with the spreads and see if there are any differences in who they chose as the winner

In [74]:
round_2_results.sort('W_REGION').show(16)

-----------------------------------------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"   |"W_REGION"  |"PRED_WIN_INDICATOR"  |
-----------------------------------------------------------------------------------------------------------------------------------
|1361        |1463        |San Diego St   |5         |13        |W           |Yale            |W           |1.0                   |
|1182        |1228        |Duquesne       |11        |3         |W           |Illinois        |W           |1.0                   |
|1321        |1163        |Northwestern   |9         |1         |W           |Connecticut     |W           |1.0                   |
|1450        |1235        |Washington St  |7         |2         |W           |Iowa St         |W           |1.0                   |
|1155        |1124        |Clemson        |6         |3         |X          

### Here are the Spread predictions, with classification probabilities (2 different models)

In [75]:
round_2_results = round_1_results.select(
    "w_region",
    "l_region",
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "W_TEAM_NAME",
    "w_seed",
    "PRED_SPREAD",
).withColumn(
    "PRED_WIN_INDICATOR", F.when(round_1_results["PRED_SPREAD"] > 0, 1).otherwise(0)
).join(
    round_2_results_proba,
    (
        round_1_results.col('l_teamid') == round_2_results_proba.col('teaml')) & (round_1_results.col('w_teamid') == round_2_results_proba.col('teamw')
    )
).drop('teamw','teaml').sort(
    "W_REGION"
).sort(F.col('W_PROB').desc())

In [76]:
round_2_results.show()

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"W_REGION"  |"L_REGION"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_TEAM_NAME"   |"W_SEED"  |"PRED_SPREAD"         |"PRED_WIN_INDICATOR"  |"W_PROB"            |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|W           |W           |1321        |1163        |Northwestern   |9         |Connecticut     |1         |9.200010299682615     |1                     |0.9622525572776794  |
|Y           |Y           |1400        |1397        |Texas          |7         |Tennessee       |2         |1.66803777217865      |1                     |0.9568915367126464  |
|X           |X           |1173        |1112        |Dayton         |7         |Arizona         |2         |6.6120781898

I only see Creighton and Oregon are different in the two models, I like Creighton and <br>
honestly the spread model if you look at what it predicted actually outperformed the classification.  Take this and work your magic.

## STOP HERE FOR NOW

In [77]:
w_teams = [
    "Houston",
    "Connecticut",
    "Alabama",
    "Arizona",
    "Clemson",
    "Iowa St",
    "Purdue",
    "North Carolina",
    "Creighton",
    "Tennessee",
    "Marquette",
    "Duke",
    "Gonzaga",
    "San Diego St",
    "Illinois",
    "NC State",
]  # Example list of teams

# Create the new column "winner"
round_2_results = round_2_results.withColumn(
    "WINNER", when(col("W_TEAM_NAME").isin(w_teams), 1).otherwise(0)
)

round_2_results.show(32)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"W_REGION"  |"L_REGION"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_TEAM_NAME"   |"W_SEED"  |"PRED_SPREAD"         |"PRED_WIN_INDICATOR"  |"W_PROB"            |"WINNER"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|W           |W           |1321        |1163        |Northwestern   |9         |Connecticut     |1         |9.200010299682615     |1                     |0.9622525572776794  |1         |
|Y           |Y           |1400        |1397        |Texas          |7         |Tennessee       |2         |1.66803777217865      |1                     |0.9568915367126464  |1         |
|X           |X           |1173        |1112        |Dayton      

# SWEET SIXTEEN

In [78]:
round_2_winners = round_2_results.with_column(
    "W_TEAMID",
    when((round_2_results.WINNER == 1), round_2_results.w_teamid)
    .otherwise(round_2_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((round_2_results.WINNER == 1), round_2_results.w_team_name)
    .otherwise(round_2_results.l_team_name),
).with_column(
    "W_SEED",
    when((round_2_results.WINNER == 1), round_2_results.w_seed)
    .otherwise(round_2_results.l_seed),
).with_column(
    "W_REGION",
    when((round_2_results.WINNER == 1), round_2_results.w_region)
    .otherwise(round_2_results.l_region)
).select("W_TEAMID","W_TEAM_NAME","W_SEED","W_REGION")

round_2_winners.sort('W_REGION').show(32)

-------------------------------------------------------
|"W_TEAMID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1163        |Connecticut     |1         |W           |
|1361        |San Diego St    |5         |W           |
|1235        |Iowa St         |2         |W           |
|1228        |Illinois        |3         |W           |
|1104        |Alabama         |4         |X           |
|1112        |Arizona         |2         |X           |
|1314        |North Carolina  |1         |X           |
|1155        |Clemson         |6         |X           |
|1345        |Purdue          |1         |Y           |
|1397        |Tennessee       |2         |Y           |
|1166        |Creighton       |3         |Y           |
|1211        |Gonzaga         |5         |Y           |
|1301        |NC State        |11        |Z           |
|1266        |Marquette       |2         |Z           |
|1181        |Duke            |4         |Z     

In [79]:
round_2_winners.sort('W_REGION').show(20)

-------------------------------------------------------
|"W_TEAMID"  |"W_TEAM_NAME"   |"W_SEED"  |"W_REGION"  |
-------------------------------------------------------
|1228        |Illinois        |3         |W           |
|1163        |Connecticut     |1         |W           |
|1361        |San Diego St    |5         |W           |
|1235        |Iowa St         |2         |W           |
|1104        |Alabama         |4         |X           |
|1314        |North Carolina  |1         |X           |
|1112        |Arizona         |2         |X           |
|1155        |Clemson         |6         |X           |
|1397        |Tennessee       |2         |Y           |
|1345        |Purdue          |1         |Y           |
|1166        |Creighton       |3         |Y           |
|1211        |Gonzaga         |5         |Y           |
|1181        |Duke            |4         |Z           |
|1266        |Marquette       |2         |Z           |
|1222        |Houston         |1         |Z     

In [86]:
df1 = round_2_winners.cache_result()
df2 = round_2_winners.cache_result()

third_round_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([5, 12, 4, 13])) |
        (df1.w_seed == 6) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 11) & (df2.w_seed.isin([7,10,2,15])) |
        (df1.w_seed == 3) & (df2.w_seed.isin([7,10,2,15]))|
        (df1.w_seed == 14) & (df2.w_seed.isin([7,10,2,15]))
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAMID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAMID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

third_round_matchups.sort("W_REGION").show(20)

------------------------------------------------------------------------------------------------------------
|"W_TEAM_NAME"   |"W_REGION"  |"WTEAMID2"  |"W_SEED"  |"L_TEAM_NAME"  |"L_REGION"  |"LTEAMID2"  |"L_SEED"  |
------------------------------------------------------------------------------------------------------------
|Connecticut     |W           |1163        |1         |San Diego St   |W           |1361        |5         |
|Illinois        |W           |1228        |3         |Iowa St        |W           |1235        |2         |
|North Carolina  |X           |1314        |1         |Alabama        |X           |1104        |4         |
|Clemson         |X           |1155        |6         |Arizona        |X           |1112        |2         |
|Purdue          |Y           |1345        |1         |Gonzaga        |Y           |1211        |5         |
|Creighton       |Y           |1166        |3         |Tennessee      |Y           |1397        |2         |
|NC State        |Z

In [90]:
third_round = third_round_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = third_round.join(season_w, season_w.w_teamid == third_round.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == third_round.lTEAMID2)

games_16 = df2.drop('WTEAMID2','lTEAMID2')

games_16 = games_16.cache_result()
games_16.count()

8

In [96]:
result = all_rounds.predict(games_16)
result.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [97]:
result = all_rounds.predict(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "PRED_SPREAD",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results_m2 = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_SPREAD",
    )
)

In [98]:
round_2_results_m2.show()

----------------------------------------------------------------------------------------------------------------------------------
|"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_SEED"  |"L_REGION"  |"W_TEAM_NAME"   |"W_REGION"  |"PRED_SPREAD"        |
----------------------------------------------------------------------------------------------------------------------------------
|1112        |1155        |Arizona        |2         |6         |X           |Clemson         |X           |-11.12165069580078   |
|1361        |1163        |San Diego St   |5         |1         |W           |Connecticut     |W           |6.550592422485352    |
|1397        |1166        |Tennessee      |2         |3         |Y           |Creighton       |Y           |-6.730685710906982   |
|1181        |1222        |Duke           |4         |1         |Z           |Houston         |Z           |5.990898609161377    |
|1235        |1228        |Iowa St        |2         |3         |W           |Illin

In [99]:
result = all_rounds_class.predict_proba(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        F.col('"predict_proba_1"').alias('W_PROB')
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

round_2_results_proba_m2 = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        F.col("l_teamid").alias('teaml'),
        F.col("w_teamid").alias('teamw'),
        'W_PROB',
    )
)

In [100]:
round_2_results_proba.show()

------------------------------------------
|"TEAML"  |"TEAMW"  |"W_PROB"            |
------------------------------------------
|1213     |1104     |0.6433817744255066  |
|1173     |1112     |0.9471816420555116  |
|1155     |1124     |0.5533120036125183  |
|1321     |1163     |0.9622525572776794  |
|1332     |1166     |0.804239809513092   |
|1241     |1181     |0.2349315583705902  |
|1401     |1222     |0.6673325300216675  |
|1182     |1228     |0.9383527636528016  |
|1450     |1235     |0.6643685102462769  |
|1211     |1242     |0.6586493253707886  |
------------------------------------------



In [101]:
round_2_results_m2 = round_2_results_m2.select(
    "w_region",
    "l_region",
    "l_teamid",
    "w_teamid",
    "L_TEAM_NAME",
    "l_seed",
    "W_TEAM_NAME",
    "w_seed",
    "PRED_SPREAD",
).withColumn(
    "PRED_WIN_INDICATOR", F.when(round_1_results["PRED_SPREAD"] > 0, 1).otherwise(0)
).join(
    round_2_results_proba_m2,
    (
        round_1_results.col('l_teamid') == round_2_results_proba.col('teaml')) & (round_1_results.col('w_teamid') == round_2_results_proba.col('teamw')
    )
).drop('teamw','teaml').sort(
    "W_REGION"
).sort(F.col('W_PROB').desc())

In [105]:
round_2_results.show()

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"W_REGION"  |"L_REGION"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_TEAM_NAME"   |"W_SEED"  |"PRED_SPREAD"         |"PRED_WIN_INDICATOR"  |"W_PROB"            |"WINNER"  |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|W           |W           |1321        |1163        |Northwestern   |9         |Connecticut     |1         |9.200010299682615     |1                     |0.9622525572776794  |1         |
|Y           |Y           |1400        |1397        |Texas          |7         |Tennessee       |2         |1.66803777217865      |1                     |0.9568915367126464  |1         |
|X           |X           |1173        |1112        |Dayton      

In [102]:
round_2_results_m2 = round_2_results_m2.select(*[F.col(col).alias(f"{col}_model2") for col in round_2_results_m2.columns])

In [106]:
round_2_results_m2.sort(F.col('W_PROB').desc()).show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"W_REGION_MODEL2"  |"L_REGION_MODEL2"  |"L_TEAMID_MODEL2"  |"W_TEAMID_MODEL2"  |"L_TEAM_NAME_MODEL2"  |"L_SEED_MODEL2"  |"W_TEAM_NAME_MODEL2"  |"W_SEED_MODEL2"  |"PRED_SPREAD_MODEL2"  |"PRED_WIN_INDICATOR_MODEL2"  |"W_PROB_MODEL2"      |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|W                  |W                  |1361               |1163               |San Diego St          |5                |Connecticut           |1                |6.550592422485352     |1                            |0.937250316143036    |
|X                  |X                  |110

In [103]:
round_2_results.join(
    round_2_results_m2,
    (round_2_results.col("l_teamid") == round_2_results_m2.col("l_teamid_model2"))
    & (round_2_results.col("w_teamid") == round_2_results_m2.col("w_teamid_model2")),
).show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"W_REGION"  |"L_REGION"  |"L_TEAMID"  |"W_TEAMID"  |"L_TEAM_NAME"  |"L_SEED"  |"W_TEAM_NAME"  |"W_SEED"  |"PRED_SPREAD"  |"PRED_WIN_INDICATOR"  |"W_PROB"  |"WINNER"  |"W_REGION_MODEL2"  |"L_REGION_MODEL2"  |"L_TEAMID_MODEL2"  |"W_TEAMID_MODEL2"  |"L_TEAM_NAME_MODEL2"  |"L_SEED_MODEL2"  |"W_TEAM_NAME_MODEL2"  |"W_SEED_MODEL2"  |"PRED_SPREAD_MODEL2"  |"PRED_WIN_INDICATOR_MODEL2"  |"W_PROB_MODEL2"  |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
result = all_rounds.predict(games_16)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

sweet_16_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [None]:
sweet_16_winners = sweet_16_results.with_column(
    "W_TEAM_ID",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_teamid)
    .otherwise(sweet_16_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_team_name)
    .otherwise(sweet_16_results.l_team_name),
).with_column(
    "W_SEED",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_seed)
    .otherwise(sweet_16_results.l_seed),
).with_column(
    "W_REGION",
    when((sweet_16_results.PRED_WIN_INDICATOR == 1), sweet_16_results.w_region)
    .otherwise(sweet_16_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

## ELITE 8

In [None]:
sweet_16_winners.sort('W_REGION').show(20)

In [None]:
df1 = sweet_16_winners.cache_result()
df2 = sweet_16_winners.cache_result()

elite_8_matchups = df1.join(
    df2,
    (df1.w_region == df2.w_region)
    & (
        (df1.w_seed == 1) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 16) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 8) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 9) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 5) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 12) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 4) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) |
        (df1.w_seed == 13) & (df2.w_seed.isin([6,11,3,14,7,10,2,15])) 
    ),
    "inner",
).select(
    (df1.W_TEAM_NAME).alias("W_team_name"),
    (df1.W_REGION).alias("w_region"),
    (df1.W_TEAM_ID).alias("WTeamID2"),
    (df1.W_SEED).alias("w_seed"),
    (df2.W_TEAM_NAME).alias("l_Team_name"),
    (df2.W_REGION).alias("l_region"),
    (df2.W_TEAM_ID).alias("lteamID2"),
    (df2.W_SEED).alias("l_seed"),
)

# ELITE 8

In [None]:
elite_8_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

In [None]:
elite_8 = elite_8_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = elite_8.join(season_w, season_w.w_teamid == elite_8.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == elite_8.lTEAMID2)

elite_8 = df2.drop('WTEAMID2','lTEAMID2')

In [None]:
result = all_rounds.predict(elite_8)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

elite_8_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [None]:
elite_8_winners = elite_8_results.with_column(
    "W_TEAM_ID",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_teamid)
    .otherwise(elite_8_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_team_name)
    .otherwise(elite_8_results.l_team_name),
).with_column(
    "W_SEED",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_seed)
    .otherwise(elite_8_results.l_seed),
).with_column(
    "W_REGION",
    when((elite_8_results.PRED_WIN_INDICATOR == 1), elite_8_results.w_region)
    .otherwise(elite_8_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [None]:
df1 = elite_8_winners.filter(F.col("W_REGION") == "W")

df2 = elite_8_winners.filter(F.col("W_REGION") == "X").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df3 = df1.join(df2)

df4 = elite_8_winners.filter(F.col("W_REGION") == "Y")

df5 = elite_8_winners.filter(F.col("W_REGION") == "Z").select(
    F.col("W_TEAM_ID").alias("TEAM_2"),
    F.col("W_TEAM_NAME").alias("W_TEAM_NAME_2"),
    F.col("W_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION_2"),
)

df6 = df4.join(df5)

final_four_matchups = df3.union(df6).select(
    F.col('W_TEAM_NAME'),
    F.col('W_REGION'),
    F.col("W_TEAM_ID").alias("WTEAMID2"),
    F.col('W_SEED'),
    F.col("W_TEAM_NAME_2").alias("L_TEAM_NAME"),
    F.col("REGION_2").alias("L_REGION"),
    F.col("TEAM_2").alias("LTEAMID2"),
    F.col("SEED_2").alias("L_SEED"),
)

# FINAL FOUR

In [None]:
final_four_matchups.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show()

In [None]:
final_four = final_four_matchups.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = final_four.join(season_w, season_w.w_teamid == final_four.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == final_four.lTEAMID2)

final_four = df2.drop('WTEAMID2','lTEAMID2')

final_four.count()

In [None]:
result = all_rounds.predict(final_four)

teams = session.table("m_teams")


result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

final_four_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [None]:
final_four_winners = final_four_results.with_column(
    "W_TEAM_ID",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_teamid)
    .otherwise(final_four_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_team_name)
    .otherwise(final_four_results.l_team_name),
).with_column(
    "W_SEED",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_seed)
    .otherwise(final_four_results.l_seed),
).with_column(
    "W_REGION",
    when((final_four_results.PRED_WIN_INDICATOR == 1), final_four_results.w_region)
    .otherwise(final_four_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

In [None]:
df1 = final_four_winners.cache_result()
df2 = final_four_winners.cache_result()

championship_matchup = df1.join(df2).select(
    df1.w_team_name.alias('W_TEAM_NAME'),
    df1.w_region.alias('W_REGION'),
    df1.w_team_id.alias('WTEAMID2'),
    df1.w_seed.alias('W_SEED'),
    df2.w_team_name.alias('L_TEAM_NAME'),
    df2.w_region.alias('L_REGION'),
    df2.w_team_id.alias('LTEAMID2'),
    df2.w_seed.alias('L_SEED'),
).filter(F.col("WTEAMID2") != F.col("LTEAMID2")).limit(1)

# CHAMPIONSHIP

In [None]:
championship_matchup.select(
    F.col("W_TEAM_NAME").alias("TEAM_1"),
    F.col("L_TEAM_NAME").alias("TEAM_2"),
    F.col("W_SEED").alias("SEED_1"),
    F.col("L_SEED").alias("SEED_2"),
    F.col("W_REGION").alias("REGION"),
).sort("REGION").show(20)

In [None]:
championship = championship_matchup.select('WTEAMID2','LTEAMID2','W_SEED','L_SEED')

season = session.table("MEN.FINAL_SEASON_STATS").filter(F.col('SEASON') == 2024)

season_w = season.select(*[F.col(col).alias(f"W_{col}") for col in season.columns])

season_l = season.select(*[F.col(col).alias(f"L_{col}") for col in season.columns])

df1 = championship.join(season_w, season_w.w_teamid == championship.wTEAMID2)

df2 = df1.join(season_l, season_l.l_teamid == championship.lTEAMID2)

championship = df2.drop('WTEAMID2','lTEAMID2')

championship.count()

In [None]:
result = all_rounds.predict(championship)

teams = session.table("m_teams")

result = (
    result.select(
        "l_teamID",
        "l_seed",
        "w_seed",
        "l_region",
        "w_teamid",
        "w_region",
        "pred_win_indicator",
    )
)

res_teamsl = (
    result.join(teams, result.col("L_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "l_team_name")
    .drop("firstd1season", "lastd1season")
)
res_teamsl = res_teamsl.cache_result()

championship_results = (
    res_teamsl.join(teams, result.col("W_TEAMID") == teams.col("TEAMID"))
    .with_column_renamed("teamname", "w_team_name")
    .select(
        "l_teamid",
        "w_teamid",
        "L_TEAM_NAME",
        "l_seed",
        "w_seed",
        "l_region",
        "W_TEAM_NAME",
        "w_region",
        "PRED_WIN_INDICATOR",
    )
)

In [None]:
all_rounds.predict(championship).select(
    "L_SCORE_MEDIAN", "W_SCORE_MEDIAN", "L_SCORE_MEAN", "W_SCORE_MEAN"
).show()

In [None]:
champion = championship_results.with_column(
    "W_TEAM_ID",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_teamid)
    .otherwise(championship_results.l_teamID),
).with_column(
    "W_TEAM_NAME",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_team_name)
    .otherwise(championship_results.l_team_name),
).with_column(
    "W_SEED",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_seed)
    .otherwise(championship_results.l_seed),
).with_column(
    "W_REGION",
    when((championship_results.PRED_WIN_INDICATOR == 1), championship_results.w_region)
    .otherwise(championship_results.l_region)
).select("W_TEAM_ID","W_TEAM_NAME","W_SEED","W_REGION")

# CHAMPION

In [None]:
champion.show()

# Use Average and Median for predicting score

In [None]:
all_rounds.predict(championship).select(
    "L_SCORE_MEDIAN", "W_SCORE_MEDIAN", "L_SCORE_MEAN", "W_SCORE_MEAN"
).show()