In [52]:
from ibis.interactive import *
from snowflake.snowpark import Session, functions as F, types as T
from snowflake.snowpark.functions import when, lit
from snowflake.ml.modeling.preprocessing import OneHotEncoder

import os

# Open the ibis Snowflake backend connection; `con` is what the cells below
# pass to `con.table(...)` and, later, hand to Snowpark via `con.con`.
# NOTE(review): no arguments are passed, so the connection settings presumably
# come from the environment / default connection config — confirm.
con = ibis.snowflake.connect()

In [53]:
def combine_stats(con, table_name: str, schema: str):
    """Stack the winner-side and loser-side columns of a game table into one table.

    The game table is read from ``schema.table_name``. Two helper columns are
    added first: ``W1COLUMN`` (a copy of ``WSCORE``) and ``W2COLUMN`` (a copy of
    ``LSCORE``). They are carried on both halves so later cells can compute
    ``W1COLUMN - W2COLUMN`` on every row.

    Args:
        con: ibis backend connection exposing ``.table(name, schema=...)``.
        table_name: Name of the game table.
        schema: Schema the table lives in.

    Returns:
        The union of two selections of the same table:
        * winner rows: ``SEASON``, ``DAYNUM`` and every ``W*`` column, ``WON=1``;
        * loser rows: ``SEASON``, ``DAYNUM``, ``WLOC``, ``W1COLUMN``,
          ``W2COLUMN`` and every ``L*`` column, ``WON=0``.
        On each side the leading ``W``/``L`` is stripped from the column names,
        except for ``WON``, ``WLOC``, ``W1COLUMN`` and ``W2COLUMN``.
    """

    def strip_prefix(side, prefix, keep):
        # ibis `rename` takes {new_name: old_name}; drop the first character of
        # every column that starts with `prefix` and is not listed in `keep`.
        renames = {}
        for name in side.columns:
            if name.startswith(prefix) and name not in keep:
                renames[name[1:]] = name
        return side.rename(renames)

    games = con.table(table_name, schema=schema).mutate(
        W1COLUMN=_.WSCORE, W2COLUMN=_.LSCORE
    )

    winner_rows = games.select("SEASON", "DAYNUM", s.startswith("W")).mutate(WON=1)
    loser_rows = games.select(
        "SEASON", "DAYNUM", "WLOC", "W1COLUMN", "W2COLUMN", s.startswith("L")
    ).mutate(WON=0)

    winner_rows = strip_prefix(
        winner_rows, "W", ["WON", "WLOC", "W1COLUMN", "W2COLUMN"]
    )
    loser_rows = strip_prefix(loser_rows, "L", ["WON", "WLOC"])
    return winner_rows.union(loser_rows)


def flatten_regions(con, table_name: str, schema: str):
    """Unpivot the ``Region*`` columns of a table into name/value rows.

    Args:
        con: ibis backend connection exposing ``.table(name, schema=...)``.
        table_name: Name of the table to read.
        schema: Schema the table lives in.

    Returns:
        The table in long form: the pivoted column name lands in ``Region``
        (with the literal text ``"Region"`` removed from it), the pivoted value
        lands in ``RegionName``, and the ``DayZero`` column is dropped.
    """
    long_form = con.table(table_name, schema=schema).pivot_longer(
        s.startswith("Region")
    )
    # pivot_longer's output columns are renamed here: name -> Region,
    # value -> RegionName (ibis `rename` takes {new_name: old_name}).
    renamed = long_form.rename({"Region": "name", "RegionName": "value"})
    stripped = renamed.mutate(Region=_.Region.replace("Region", ""))
    return stripped.drop("DayZero")

In [54]:
# One row per team per game, built from the MEN regular-season detailed results.
m_reg = combine_stats(
    con,
    table_name="M_REGULAR_SEASON_DETAILED_RESULTS",
    schema="MEN",
)

In [55]:
# Bare expression: lets the notebook display the stacked table as the cell output.
m_reg

In [56]:
def _margin_stats(team_games, won: int, prefix: str):
    """Median and mean score margin per (SEASON, TEAMID) for wins or losses.

    Args:
        team_games: Table with WON, W1COLUMN, W2COLUMN, SEASON and TEAMID.
        won: 1 to summarise games the team won, 0 for games it lost.
        prefix: Prefix of the output columns (``<prefix>MARGINMEDIAN`` and
            ``<prefix>MARGINMEAN``).

    Returns:
        One row per (SEASON, TEAMID) that has at least one matching game.
    """
    return (
        team_games.filter(_.WON == won)
        # W1COLUMN / W2COLUMN hold the winner's and loser's score on every row.
        .mutate(SCOREDIFF=_.W1COLUMN - _.W2COLUMN)
        .group_by(["SEASON", "TEAMID"])
        .agg(
            **{
                f"{prefix}MARGINMEDIAN": _.SCOREDIFF.median(),
                f"{prefix}MARGINMEAN": _.SCOREDIFF.mean(),
            }
        )
    )


w_margin = _margin_stats(m_reg, won=1, prefix="WIN")
l_margin = _margin_stats(m_reg, won=0, prefix="LOSE")

# Start from every team-season and LEFT join both summaries. Inner-joining
# w_margin with l_margin keeps only team-seasons that have at least one win
# AND at least one loss, so an undefeated or winless team would silently drop
# out of this table and of everything joined to it afterwards. With the left
# joins such a team keeps its row and the missing side's margins are NULL.
team_seasons = m_reg.select("SEASON", "TEAMID").distinct()
m_season_margin = (
    team_seasons.join(w_margin, ["SEASON", "TEAMID"], how="left")
    .drop(s.endswith("_right"))
    .join(l_margin, ["SEASON", "TEAMID"], how="left")
    .drop(s.endswith("_right"))
)

# The score helper columns are no longer needed past this point.
# NOTE: this rebinds m_reg, so re-running this cell without re-running the
# cell that builds m_reg will fail (the columns are already gone).
m_reg = m_reg.drop(['W1COLUMN','W2COLUMN'])

In [57]:
# Mean / median / standard deviation of every numeric column per team-season.
summary_fns = dict(MEAN=_.mean(), MEDIAN=_.median(), STDDEV=_.std())
per_team_season = m_reg.drop("DAYNUM").group_by(["SEASON", "TEAMID"])

# The across() also summarises WON, SEASON and TEAMID; those summary columns
# are removed again right away.
season_stats = per_team_season.agg(s.across(s.numeric(), summary_fns)).drop(
    s.startswith("WON_"), s.startswith("SEASON_"), s.startswith("TEAMID_")
)

In [58]:
# Wins per team-season split by WLOC value, one column per location.
wins_by_location = m_reg.group_by(["SEASON", "TEAMID", "WLOC"]).agg(
    WINCOUNT=_.WON.sum()
)
# Prefix the location code so the pivoted columns are named WLOC<code>.
labelled = wins_by_location.mutate(WLOC="WLOC" + _.WLOC)
hna = labelled.pivot_wider(names_from="WLOC", values_from="WINCOUNT").mutate(
    # A team with no rows for a location gets NULL from the pivot; make it 0.
    s.across(s.startswith("WLOC"), ibis.coalesce(_, 0))
)

In [59]:
# Combine the per-team-season summaries, location win counts and margins.
team_season_keys = ["SEASON", "TEAMID"]
season_joined = season_stats.join(hna, team_season_keys).join(
    m_season_margin, team_season_keys
)
season_joined = season_joined.drop(s.endswith("_right")).distinct()

# Replace any remaining NULLs with 0.
season_joined = season_joined.fillna(0)

In [60]:
# Spot-check one team-season row.
season_joined.filter((_.SEASON == 2021) & (_.TEAMID == 1288))

In [61]:
# Number the games of each (SEASON, CONFABBREV) from the latest DAYNUM down.
# ibis row_number() starts at 0, so ROWNUM == 0 is the latest game.
latest_game_first = ibis.row_number().over(
    group_by=["SEASON", "CONFABBREV"], order_by=_.DAYNUM.desc()
)

conf_tourney_games = con.table("M_CONFERENCE_TOURNEY_GAMES")
conf_wins = (
    conf_tourney_games.mutate(ROWNUM=latest_game_first)
    .filter(_.ROWNUM == 0)
    .drop(["DAYNUM", "ROWNUM", "LTEAMID", "CONFABBREV"])
    .mutate(WON_CONFERENCE=1)
    # The winner of that game becomes the join key TEAMID.
    .rename({"TEAMID": "WTEAMID"})
)

# Left join so teams without a matching row are kept with WON_CONFERENCE = 0.
with_conf_flag = season_joined.join(conf_wins, ["SEASON", "TEAMID"], how="left")
final = (
    with_conf_flag.mutate(WON_CONFERENCE=_.WON_CONFERENCE.fillna(0))
    .drop(s.endswith("_right"))
    .mutate(TOTAL_WINS=_.WLOCN + _.WLOCH + _.WLOCA)
)

In [62]:
# Spot-check one team-season row of the joined season table.
final.filter((_.SEASON == 2024) & (_.TEAMID == 1314))

In [63]:
# Deliberate hack: build the Snowpark session on top of the connection ibis
# already holds, so both libraries share one session (and its cached tables).
# It also avoids: SnowparkSessionException: (1409): More than one active
# session is detected. When you call...

@classmethod
def from_ibis(cls, con) -> Session:
    """Return a Snowpark session that reuses an ibis backend's connection.

    Args:
        con: ibis backend object; its ``con`` attribute is passed to the
            session builder as the ``"connection"`` option (presumably the
            underlying Snowflake connector connection — confirm).

    Returns:
        The session returned by ``builder...getOrCreate()``.
    """
    return cls.builder.config("connection", con.con).getOrCreate()

# Attach the constructor to Session so it reads like a built-in alternative.
Session.from_ibis = from_ibis

session = Session.from_ibis(con)

In [64]:
# Hand the ibis expression over to Snowpark as SQL.
# NOTE: `final` must still be the ibis expression at this point; a later cell
# rebinds the name `final` to a Snowpark DataFrame.
season_sql = ibis.to_sql(final)
season = session.sql(season_sql)

In [65]:
# Tournament seeds table (SEASON, SEED, TEAMID per the printed preview).
seeds = session.table("MEN.M_NCAATOURNEY_SEEDS")
seeds.show()

--------------------------------
|"SEASON"  |"SEED"  |"TEAMID"  |
--------------------------------
|1985      |W01     |1207      |
|1985      |W02     |1210      |
|1985      |W03     |1228      |
|1985      |W04     |1260      |
|1985      |W05     |1374      |
|1985      |W06     |1208      |
|1985      |W07     |1393      |
|1985      |W08     |1396      |
|1985      |W09     |1439      |
|1985      |W10     |1177      |
--------------------------------



### Kaggle had the play-in games wrong, so let's replace them

In [67]:
# In the 2024 seeds, replace team 1129 with team 1160; all other rows keep
# their TEAMID.
needs_swap = (seeds["TEAMID"] == 1129) & (seeds["SEASON"] == 2024)
seeds = seeds.with_column(
    "TEAMID",
    when(needs_swap, 1160).otherwise(seeds["TEAMID"]),
)

In [68]:
# Split the SEED code into its parts: the first character is the REGION,
# the remainder is the seed number, possibly followed by a lowercase letter.
seed_code = F.col("SEED")
seed_digits = F.substring(seed_code, 2, F.length(seed_code) - 1)

seed_value = seeds.select(
    "SEASON",
    "TEAMID",
    F.substring(seed_code, 1, 1).alias("REGION"),
    # Drop any lowercase letters, then store the seed as an integer.
    F.cast(F.regexp_replace(seed_digits, "[a-z]", ""), T.IntegerType()).alias(
        "SEED"
    ),
)

seed_value.show()

-------------------------------------------
|"SEASON"  |"TEAMID"  |"REGION"  |"SEED"  |
-------------------------------------------
|1985      |1207      |W         |1       |
|1985      |1210      |W         |2       |
|1985      |1228      |W         |3       |
|1985      |1260      |W         |4       |
|1985      |1374      |W         |5       |
|1985      |1208      |W         |6       |
|1985      |1393      |W         |7       |
|1985      |1396      |W         |8       |
|1985      |1439      |W         |9       |
|1985      |1177      |W         |10      |
-------------------------------------------



In [69]:
# Tournament results, reduced to the columns used below.
tourney = session.table("MEN.M_NCAATOURNEY_COMPACT_RESULTS").select(
    ["SEASON", "WTEAMID", "LTEAMID", "WSCORE", "LSCORE", "DAYNUM"]
)
tourney.show()

---------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"DAYNUM"  |
---------------------------------------------------------------------
|1985      |1116       |1234       |63        |54        |136       |
|1985      |1120       |1345       |59        |58        |136       |
|1985      |1207       |1250       |68        |43        |136       |
|1985      |1229       |1425       |58        |55        |136       |
|1985      |1242       |1325       |49        |38        |136       |
|1985      |1246       |1449       |66        |58        |136       |
|1985      |1256       |1338       |78        |54        |136       |
|1985      |1260       |1233       |59        |58        |136       |
|1985      |1314       |1292       |76        |57        |136       |
|1985      |1323       |1333       |79        |70        |136       |
---------------------------------------------------------------------



In [70]:
# Map the game's DAYNUM to a tournament round number (0 = play-in ... 6 = final).
# Round 4 spans two days (145 and 146), like rounds 0-3. Matching only day 145
# sent every day-146 game to the `otherwise` branch, labelling it round 6.
# NOTE(review): this mapping assumes the standard Kaggle DayNum calendar;
# a season played on a shifted schedule would need its own mapping — confirm.
daynum = tourney["DAYNUM"]
tourney_round = tourney.with_column(
    "ROUND",
    when((daynum >= 134) & (daynum <= 135), lit(0))
    .when((daynum >= 136) & (daynum <= 137), lit(1))
    .when((daynum >= 138) & (daynum <= 139), lit(2))
    .when((daynum >= 143) & (daynum <= 144), lit(3))
    .when((daynum >= 145) & (daynum <= 146), lit(4))
    .when(daynum == 152, lit(5))
    .otherwise(lit(6)),
).drop("DAYNUM")

tourney_round.show()

--------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"ROUND"  |
--------------------------------------------------------------------
|1985      |1116       |1234       |63        |54        |1        |
|1985      |1120       |1345       |59        |58        |1        |
|1985      |1207       |1250       |68        |43        |1        |
|1985      |1229       |1425       |58        |55        |1        |
|1985      |1242       |1325       |49        |38        |1        |
|1985      |1246       |1449       |66        |58        |1        |
|1985      |1256       |1338       |78        |54        |1        |
|1985      |1260       |1233       |59        |58        |1        |
|1985      |1314       |1292       |76        |57        |1        |
|1985      |1323       |1333       |79        |70        |1        |
--------------------------------------------------------------------



In [71]:
## Conference lookup: normalise the abbreviations (used for one-hot encoding
## later) and prefix the key columns so they don't clash in joins.
def fix_values(column):
    """Return the named column upper-cased, with every run of characters
    other than letters and digits replaced by a single underscore."""
    collapsed = F.regexp_replace(F.col(column), "[^a-zA-Z0-9]+", "_")
    return F.upper(collapsed)

conf = (
    session.table("MEN.M_TEAM_CONFERENCES")
    .with_column("CONFABBREV", fix_values("CONFABBREV"))
    .with_column_renamed("SEASON", "C_SEASON")
    .with_column_renamed("TEAMID", "C_TEAMID")
)

conf.show()

------------------------------------------
|"C_SEASON"  |"C_TEAMID"  |"CONFABBREV"  |
------------------------------------------
|1985        |1102        |WAC           |
|1985        |1103        |OVC           |
|1985        |1104        |SEC           |
|1985        |1106        |SWAC          |
|1985        |1108        |SWAC          |
|1985        |1109        |IND           |
|1985        |1110        |ECACS         |
|1985        |1111        |SOUTHERN      |
|1985        |1112        |PAC_TEN       |
|1985        |1113        |PAC_TEN       |
------------------------------------------



In [72]:
# Attach the winning team's conference for that season as W_CONF.
winner_conf_match = (tourney_round.col("WTEAMID") == conf.col("C_TEAMID")) & (
    tourney_round.col("SEASON") == conf.col("C_SEASON")
)
tourney_conf_w = tourney_round.join(conf, winner_conf_match)
tourney_conf_w = tourney_conf_w.drop("C_SEASON", "C_TEAMID").with_column_renamed(
    "CONFABBREV", "W_CONF"
)
tourney_conf_w.show()

-------------------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"ROUND"  |"W_CONF"  |
-------------------------------------------------------------------------------
|1985      |1104       |1112       |50        |41        |1        |SEC       |
|1985      |1104       |1433       |63        |59        |2        |SEC       |
|1985      |1116       |1234       |63        |54        |1        |SWC       |
|1985      |1120       |1345       |59        |58        |1        |SEC       |
|1985      |1120       |1242       |66        |64        |2        |SEC       |
|1985      |1130       |1403       |55        |53        |1        |BIG_EAST  |
|1985      |1130       |1181       |74        |73        |2        |BIG_EAST  |
|1985      |1181       |1337       |75        |62        |1        |ACC       |
|1985      |1207       |1250       |68        |43        |1        |BIG_EAST  |
|1985      |1207       |1396       |63  

In [73]:
# Attach the losing team's conference for that season as L_CONF.
# The join condition is built from the columns of the DataFrame actually being
# joined (tourney_conf_w) rather than from its ancestor tourney_round.
tourney_conf_round = (
    tourney_conf_w.join(
        conf,
        (tourney_conf_w.col("LTEAMID") == conf.col("C_TEAMID"))
        & (tourney_conf_w.col("SEASON") == conf.col("C_SEASON")),
    )
    .drop("C_SEASON", "C_TEAMID")
    .with_column_renamed("CONFABBREV", "L_CONF")
)
tourney_conf_round.show()

-------------------------------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"ROUND"  |"W_CONF"  |"L_CONF"   |
-------------------------------------------------------------------------------------------
|1985      |1104       |1112       |50        |41        |1        |SEC       |PAC_TEN    |
|1985      |1104       |1433       |63        |59        |2        |SEC       |SUN_BELT   |
|1985      |1116       |1234       |63        |54        |1        |SWC       |BIG_TEN    |
|1985      |1120       |1345       |59        |58        |1        |SEC       |BIG_TEN    |
|1985      |1120       |1242       |66        |64        |2        |SEC       |BIG_EIGHT  |
|1985      |1130       |1403       |55        |53        |1        |BIG_EAST  |SWC        |
|1985      |1130       |1181       |74        |73        |2        |BIG_EAST  |ACC        |
|1985      |1181       |1337       |75        |62        |1        |ACC       |W

In [74]:
# Print the first rows of tourney_conf_round (same call as at the end of the
# previous cell).
tourney_conf_round.show()

-------------------------------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"ROUND"  |"W_CONF"  |"L_CONF"   |
-------------------------------------------------------------------------------------------
|1985      |1104       |1112       |50        |41        |1        |SEC       |PAC_TEN    |
|1985      |1104       |1433       |63        |59        |2        |SEC       |SUN_BELT   |
|1985      |1116       |1234       |63        |54        |1        |SWC       |BIG_TEN    |
|1985      |1120       |1345       |59        |58        |1        |SEC       |BIG_TEN    |
|1985      |1120       |1242       |66        |64        |2        |SEC       |BIG_EIGHT  |
|1985      |1130       |1403       |55        |53        |1        |BIG_EAST  |SWC        |
|1985      |1130       |1181       |74        |73        |2        |BIG_EAST  |ACC        |
|1985      |1181       |1337       |75        |62        |1        |ACC       |W

In [75]:
def _attach_seed(games, team_col, side):
    """Join seed_value onto `games` for the team in `team_col`.

    The matching REGION and SEED come back as `<side>_REGION` / `<side>_SEED`;
    the duplicate season column (suffixed `_<side>`) and the seed table's
    TEAMID are dropped.
    """
    matched = games.join(
        seed_value,
        (
            (games["SEASON"] == seed_value["SEASON"])
            & (games[team_col] == seed_value["TEAMID"])
        ),
        rsuffix=f"_{side}",
    )
    return (
        matched.drop([f"SEASON_{side}", "TEAMID"])
        .with_column_renamed("REGION", f"{side}_REGION")
        .with_column_renamed("SEED", f"{side}_SEED")
    )


# Winner's seed first; the intermediate result is cached before seed_value is
# joined a second time for the loser.
w_t = _attach_seed(tourney_conf_round, "WTEAMID", "W").cache_result()

tourney_conf_round = _attach_seed(w_t, "LTEAMID", "L")

In [76]:
# Print the first rows now that the W_/L_ REGION and SEED columns are attached.
tourney_conf_round.show()

------------------------------------------------------------------------------------------------------------------------------------------
|"SEASON"  |"WTEAMID"  |"LTEAMID"  |"WSCORE"  |"LSCORE"  |"ROUND"  |"W_CONF"  |"L_CONF"  |"W_REGION"  |"W_SEED"  |"L_REGION"  |"L_SEED"  |
------------------------------------------------------------------------------------------------------------------------------------------
|1985      |1437       |1207       |66        |64        |6        |BIG_EAST  |BIG_EAST  |Z           |8         |W           |1         |
|1985      |1207       |1210       |60        |54        |4        |BIG_EAST  |ACC       |W           |1         |W           |2         |
|1985      |1210       |1228       |61        |53        |3        |ACC       |BIG_TEN   |W           |2         |W           |3         |
|1985      |1207       |1260       |65        |53        |3        |BIG_EAST  |MW_CITY   |W           |1         |W           |4         |
|1985      |1260       |137

In [77]:
def _with_prefix(frame, prefix):
    """Select every column of `frame`, renamed to `<prefix><original name>`."""
    return frame.select(
        [F.col(name).alias(f"{prefix}{name}") for name in frame.columns]
    )


# Two copies of the season stats: one to join on the winner, one on the loser.
season_w = _with_prefix(season, "W_")
season_l = _with_prefix(season, "L_")

In [78]:
# Attach the season stats of both teams to every tournament game.
winner_match = (tourney_conf_round.WTEAMID == season_w.W_TEAMID) & (
    tourney_conf_round.SEASON == season_w.W_SEASON
)
loser_match = (tourney_conf_round.LTEAMID == season_l.L_TEAMID) & (
    tourney_conf_round.SEASON == season_l.L_SEASON
)

with_winner_stats = tourney_conf_round.join(season_w, on=winner_match).drop(
    "W_TEAMID", "W_SEASON"
)
# NOTE: from here on `final` is a Snowpark DataFrame, no longer the ibis
# expression of the same name built earlier.
final = with_winner_stats.join(season_l, on=loser_match).drop(
    "L_TEAMID", "L_SEASON"
)

final.show()

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [79]:
# One-hot encode both conference columns in place: the output columns reuse the
# input names and the original string columns are dropped.
conf_cols = ["W_CONF", "L_CONF"]
OHE = OneHotEncoder(
    input_cols=list(conf_cols),
    output_cols=list(conf_cols),
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)

OHE.fit(final)
final = OHE.transform(final)

  success, nchunks, nrows, ci_output = write_pandas(
  success, nchunks, nrows, ci_output = write_pandas(


In [80]:
# Cast the six per-location win-count columns (W_/L_ x N/H/A) to LongType.
win_count_cols = [
    f"{side}_WLOC{loc}" for side in ("W", "L") for loc in ("N", "H", "A")
]
final = final.with_columns(
    win_count_cols,
    [F.col(name).cast(T.LongType()) for name in win_count_cols],
)

### This table is all season data joined with historic tournament data

In [81]:
# Persist the feature table, replacing any existing MEN.FINAL_FEATURES.
final.write.save_as_table("MEN.FINAL_FEATURES", mode="overwrite")

### Create season table for predicting 2024

In [82]:
# Attach each team's conference for the season as CONF ...
team_conf_match = (season.col("teamid") == conf.col("C_teamid")) & (
    season.col("season") == conf.col("C_season")
)
season = (
    season.join(conf, team_conf_match)
    .drop("C_SEASON", "C_TEAMID")
    .with_column_renamed("CONFABBREV", "CONF")
)

# ... and one-hot encode it in place (the string column is dropped).
OHE = OneHotEncoder(
    input_cols=["CONF"],
    output_cols=["CONF"],
    drop_input_cols=True,
    drop="first",
    handle_unknown="ignore",
)

OHE.fit(season)
season = OHE.transform(season)

  success, nchunks, nrows, ci_output = write_pandas(
  success, nchunks, nrows, ci_output = write_pandas(


In [83]:
# Seed REGION per team-season, with the keys renamed so they don't clash with
# the season table's own SEASON / TEAMID columns.
region = seed_value.select(
    F.col("SEASON").alias("SEASON_1"),
    F.col("TEAMID").alias("TEAMID_1"),
    "REGION",
)

# Inner join: only team-seasons that have a seed row are kept.
region_match = (season.season == region.season_1) & (
    season.teamid == region.teamid_1
)
season = season.join(region, on=region_match).drop("TEAMID_1", "SEASON_1")

# Persist the result, replacing any existing MEN.FINAL_SEASON_STATS.
season.write.save_as_table("MEN.FINAL_SEASON_STATS", mode="overwrite")