In [None]:
import numpy as np
from utils import get_data, get_table, get_predictions, other_stats, add_intervals_to_test
from plots import plot_rmse, plot_finish_groups, plot_interval_checks, plot_finish_age_gender
# Seed NumPy's global RNG before any data is loaded
# (presumably get_data samples through it — TODO confirm).
np.random.seed(2025)

# Settings passed unchanged to every get_data call below.
size = 125
size_test = 1000
save_val = True
train_yr, test_yr = [2021, 2022, 2023], [2024]


def _load_race(code):
    """Call get_data for one race code and return its (train, test) pair as a tuple."""
    train_part, test_part = get_data(
        racename=code,
        size_train=size,
        size_test=size_test,
        train_lis=train_yr,
        test_lis=test_yr,
        save=save_val,
    )
    return train_part, test_part


# Races are loaded in the order bos, nyc, chi, so the sequence of get_data calls is unchanged.
data = {code: _load_race(code) for code in ("bos", "nyc", "chi")}
train_bos, test_bos = data["bos"]
train_nyc, test_nyc = data["nyc"]
train_chi, test_chi = data["chi"]
# Disabled alternative (note: pandas is not imported in this cell):
# test_nyc = pd.read_csv("processed_data/test_nyc.csv")

In [2]:
# Race key used to pick the test split from `data` and the parameter CSV paths below.
race = "bos"
test = data[race][1]  # second element of the (train, test) pair stored for this race

# One entry per model: (column label, parameter file, feature names handed to get_predictions).
model_info = [
    (
        "model1",
        f"stan_results/model1/params_{race}.csv",
        ["alpha", "total_pace"],
    ),
    (
        "model2",
        f"stan_results/model2/params_{race}.csv",
        ["alpha", "total_pace", "curr_pace"],
    ),
    (
        "model3",
        f"stan_results/model3/params_{race}.csv",
        ["alpha", "total_pace", "curr_pace", "male", "age"],
    ),
]

# Predictions keyed by model label, requested with full=False.
mpreds = {
    model_name: get_predictions(test, params_path, feats_lis=feature_names, full=False)
    for model_name, params_path, feature_names in model_info
}

# Combine the test frame with the per-model predictions, then display the result.
test2 = get_table(test, mpreds)
test2

Unnamed: 0,id,dist,curr_pace,total_pace,finish,age,gender,year,prop,propleft,male,propxcurr,malexage,alpha,lvl,extrap,model1,model2,model3
0,74170,5K,3.720238,3.720238,3.401725,46,M,2024,0.118497,0.881503,1,0.440839,46,1,1,-17.699733,-9.536638,-9.521755,-8.430469
1,81272,5K,3.134796,3.134796,2.956903,69,M,2024,0.118497,0.881503,1,0.371465,69,1,1,-13.496583,1.008615,1.088788,6.100313
2,69740,5K,3.958828,3.958828,3.752001,47,M,2024,0.118497,0.881503,1,0.469111,47,1,1,-9.792383,-3.319702,-3.320852,-2.147468
3,72944,5K,4.022526,4.022526,3.480861,65,M,2024,0.118497,0.881503,1,0.476659,65,1,1,-27.205383,-21.124080,-21.128830,-18.212685
4,80672,5K,3.322259,3.322259,2.998508,42,F,2024,0.118497,0.881503,0,0.393679,0,1,1,-22.855083,-10.819678,-10.765733,-12.057475
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7995,81229,40K,2.598753,2.998501,2.959806,46,F,2024,0.947980,0.052020,0,2.463564,0,1,8,-3.066125,-2.636781,-1.614243,-1.664243
7996,87296,40K,1.921599,2.492833,2.425697,72,M,2024,0.947980,0.052020,1,1.821636,72,1,8,-7.807929,-7.250377,-4.907041,-4.519448
7997,68535,40K,3.921569,3.902058,3.901165,45,M,2024,0.947980,0.052020,1,3.717567,45,1,8,-0.041273,0.258678,-0.025340,0.033353
7998,75593,40K,2.893519,3.339734,3.314611,26,F,2024,0.947980,0.052020,0,2.742997,0,1,8,-1.596035,-1.226014,-0.199021,-0.399686


In [3]:
# Labels of the model columns in test2; "extrap" is an extra column passed alongside them.
models = ["model1", "model2", "model3"]

# plot_rmse receives the model columns followed by "extrap".
tbl = plot_rmse(test2, [*models, "extrap"], save_name=race, bar=True)

# other_stats receives "extrap" first, then the model columns, compared against "finish".
other_stats(test2[["extrap", *models]], test2["finish"])

File saved: analysis/bos_rmse_bar.png
File saved: analysis/bos_rmse.csv
File saved: analysis/bos_rmse2.csv


Unnamed: 0,extrap,model1,model2,model3,pcnt_model1,pcnt_model2,pcnt_model3
Overall RMSE,21.402621,15.017223,13.967748,13.808904,-,-,-
Overall R-squared,0.79423,0.898696,0.91236,0.914342,-,-,-


In [4]:
# Keyword arguments common to both plotting calls.
shared_opts = {"num": 4, "overall": True, "save_name": race}

# Each call gets its own fresh label_pair list, exactly as before.
plot_finish_groups(
    test2,
    label_pair=["extrap", "model2"],
    palette="inferno",
    **shared_opts,
)
plot_finish_age_gender(
    test2,
    label_pair=["extrap", "model2"],
    palette="crest",
    grouping="age",
    **shared_opts,
)

File saved: analysis/bos_rmse_groups.png
4 [0.0, 25.0, 50.0, 75.0] [18. 35. 44. 53.]
File saved: analysis/bos_rmse_gender_age.png


In [5]:
# NOTE(review): 42195 is presumably the marathon distance in metres and the /60 a
# seconds-to-minutes conversion, turning a predicted pace into a finish time — confirm units.
numerator = 42195 / 60

# Same models as model_info, this time requested with full=True and divided into `numerator`.
mpreds2 = {
    model_name: numerator / get_predictions(test, params_path, feats_lis=feature_names, full=True)
    for model_name, params_path, feature_names in model_info
}

# Attach intervals for each model to the table, then plot and collect the check/size tables.
intervals_tbl = add_intervals_to_test(test2, mpreds2, models)
i_check, i_sizes = plot_interval_checks(intervals_tbl, models, save_name=race)

File saved: analysis/bos_intervals.png
File saved: analysis/bos_int_sizes.csv
File saved: analysis/bos_int_checks.csv


In [6]:
# Display the first table returned by plot_interval_checks (one row per `dist` checkpoint).
i_check

Unnamed: 0_level_0,model1-size50,model2-size50,model3-size50,model1-size80,model2-size80,model3-size80,model1-size95,model2-size95,model3-size95
dist,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
5K,21.189292,21.193302,21.396361,40.620346,40.635551,41.049152,63.210318,63.243721,63.862231
10K,18.115772,16.364198,16.608186,34.657957,31.276176,31.803099,53.672321,48.387133,49.246647
15K,15.616947,13.75496,13.980408,29.846815,26.279763,26.688313,46.150954,40.581615,41.229922
20K,13.090103,11.885769,12.03701,24.987816,22.699642,22.955536,38.547961,34.95846,35.374462
25K,10.970173,8.615306,8.656534,20.927885,16.42544,16.50307,32.22264,25.290574,25.389192
30K,8.259072,5.849842,5.849874,15.759574,11.143147,11.143921,24.223662,17.11837,17.133255
35K,4.763873,2.644605,2.676472,9.070228,5.037936,5.098952,13.94273,7.743107,7.830663
40K,1.410892,0.875051,0.868875,2.687596,1.668031,1.654435,4.127912,2.5617,2.538813
