In [1]:
import pandas as pd
import plotly.express as px

from src.utils import *
from src.features.base import *
from src.features.volume_obs import *
from src.features.swe import *
from src.data.base import *
from src.config import *
from src.models.postprocess import *
from src.models.lgb import *
from src.models.ensemble import *

import importlib

# Reference experiment whose config module supplies `cfg`.
EXP_NAME = 'swe_k9_s1_dp'
cfg = getattr(importlib.import_module("configs.forecast." + EXP_NAME), "cfg")

# Year splits used to partition the stored predictions for evaluation.
val_years = list(range(2020, 2023))  # 2020, 2021, 2022
test_years = [2023]

# Evaluation

In [2]:
# Experiments to ensemble. Runs that differ only by a seed suffix ("_s1024",
# "_s3024") map to the same model_id, so mean_ensemble averages the seed
# replicates before the cross-model custom_ensemble step.
exp_list = [
    "swe_k9_s1_dp",
    "swe_k9_s1_dp_s1024",
    "swe_k9_s1_dp_s3024",
    "swe_k5_s3_dp",
    "swe_k5_s3_dp_s1024",
    "swe_k5_s3_dp_s3024",
    "pdsi_swe_k9_s1_dp",
    "pdsi_swe_k9_s1_dp_s1024",
    "pdsi_swe_k9_s1_dp_s3024",
    "pdsi_swe_k5_s3_dp",
    "pdsi_swe_k5_s3_dp_s1024",
    "pdsi_swe_k5_s3_dp_s3024",
]
groupby_cols = ["site_id", "year", "month", "day", "md_id"]

df_pred_val_all = []
df_pred_test_all = []
for exp_name in exp_list:
    # Seed replicates share one model_id (computed once, used for both splits).
    model_id = exp_name.replace("_s1024", "").replace("_s3024", "")
    df_pred = (
        pd.read_csv(f"runs/forecast_v2_prd/{exp_name}/pred.csv")
        # Exclude detroit_lake_inflow rows with md_id >= 24.
        # NOTE(review): the reason for this cut-off is not recorded here — document it.
        .query('(site_id=="detroit_lake_inflow" & md_id>=24)==False')
        # Use pred_volume_reg as the median (q50) prediction. `.assign` returns a
        # new frame instead of writing into the filtered result of `.query()`,
        # which can raise SettingWithCopyWarning.
        .assign(pred_volume_50=lambda d: d["pred_volume_reg"])
    )

    df_pred_val = df_pred[df_pred["year"].isin(val_years)]
    df_pred_test = df_pred[df_pred["year"].isin(test_years)]

    df_pred_val_all.append(df_pred_val.assign(model_id=model_id))
    df_pred_test_all.append(df_pred_test.assign(model_id=model_id))

# Average seed replicates within each model_id, then apply the project's
# rearrange/clip post-processing to the per-model predictions.
df_pred_val_all = pd.concat(df_pred_val_all)
df_pred_test_all = pd.concat(df_pred_test_all)
df_pred_val_all = mean_ensemble(df_pred_val_all, groupby_cols=groupby_cols + ["model_id"])
df_pred_test_all = mean_ensemble(df_pred_test_all, groupby_cols=groupby_cols + ["model_id"])
df_pred_val_all = clip_prediction(rearrange_prediction(df_pred_val_all))
df_pred_test_all = clip_prediction(rearrange_prediction(df_pred_test_all))

# Combine the per-model predictions into the final ensemble.
df_pred_val_ens = custom_ensemble(df_pred_val_all, groupby_cols=groupby_cols)
df_pred_test_ens = custom_ensemble(df_pred_test_all, groupby_cols=groupby_cols)

# Attach observed volumes for the test years. Filter ground truth with the same
# `test_years` used to split the predictions, so the two cannot silently disagree.
df_test = read_train(is_forecast=True)
df_test = df_test[df_test["year"].isin(test_years)]
# NOTE(review): implicit inner join on all shared columns — consider an explicit `on=`.
df_pred_test_ens = pd.merge(
    df_pred_test_ens.drop(columns=["volume"]),
    df_test,
)

df_pred_val_ens = df_pred_val_ens.assign(cat="val")
df_pred_test_ens = df_pred_test_ens.assign(cat="test")

# Post-processed variants: apply use_previous_forecast_sites to the q10/q50
# columns for months 5, 6 and 7.
df_pred_val_ens_pp = use_previous_forecast_sites(df_pred_val_ens,
                                                 months=[5,6,7],
                                                 cols=["pred_volume_10", "pred_volume_50"])
df_pred_test_ens_pp = use_previous_forecast_sites(df_pred_test_ens,
                                                  months=[5,6,7],
                                                  cols=["pred_volume_10", "pred_volume_50"])

In [3]:
# Metrics for the ensemble before use_previous_forecast_sites, grouped by split, year and month.
eval_all(
    pd.concat([df_pred_val_ens, df_pred_test_ens]),
    [[grouping] for grouping in ("cat", "year", "month")],
)

['cat']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
test,724.0,81.133305,57.712249,129.91247,55.775196,0.819061,230.858854,0.954501,12.226386,40.25374,1039.918597,999.664857
val,2172.0,68.857836,40.667781,108.691744,57.213984,0.756446,202.070527,0.970596,27.960806,-35.227464,753.240214,788.467678
0,1448.0,74.995571,49.190015,119.302107,56.49459,0.787753,216.464691,0.962548,20.093596,2.513138,896.579405,894.066267
0,1023.890619,8.680067,12.052258,15.005319,1.017377,0.044276,20.356421,0.011381,11.125915,53.373272,202.712229,149.338957




['year']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
2020,724.0,63.163739,47.153755,97.313219,45.024244,0.803867,179.513837,0.980041,20.906127,47.369437,851.73058,804.361143
2021,724.0,73.624789,36.003935,121.09008,63.780352,0.691989,223.412734,0.944705,32.947493,-108.265007,599.50863,707.773637
2022,724.0,69.78498,38.845654,107.671933,62.837354,0.773481,200.895561,0.97497,30.028798,-44.786823,808.481431,853.268254
2023,724.0,81.133305,57.712249,129.91247,55.775196,0.819061,230.858854,0.954501,12.226386,40.25374,1039.918597,999.664857
0,724.0,71.926703,44.928898,113.996925,56.854287,0.772099,208.670246,0.963554,24.027201,-16.357163,824.909809,841.266973
0,0.0,7.506031,9.747219,14.398715,8.65798,0.056669,23.239457,0.016727,9.391328,74.209549,180.761091,121.677759




['month']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
1,416.0,118.923634,61.027925,201.317118,94.425859,0.814904,335.029585,0.916425,43.782025,-69.017448,823.096404,892.113852
2,416.0,94.001075,59.207199,146.846065,75.949961,0.783654,231.992667,0.959926,34.986797,-39.817803,823.096404,862.914207
3,416.0,89.263346,58.39356,142.058864,67.337614,0.747596,229.873016,0.960655,31.98484,-25.343237,823.096404,848.439641
4,416.0,70.269334,48.8398,103.64165,58.326552,0.735577,168.564017,0.978844,24.382405,-18.827776,823.096404,841.92418
5,416.0,60.299406,40.681903,92.844881,47.371433,0.786058,168.878863,0.978765,17.126902,25.014484,823.096404,798.08192
6,416.0,46.636112,30.577211,73.139814,36.19131,0.745192,148.014176,0.983688,10.168333,19.570324,823.096404,803.52608
7,400.0,22.18071,14.60852,35.095412,16.838198,0.7925,92.340663,0.993872,5.028381,-5.667548,836.22546,841.893008
0,413.714286,71.653374,44.762303,113.563401,56.634418,0.772212,196.384712,0.967454,23.922812,-16.298429,824.971983,841.270413
0,6.047432,32.270328,17.333877,54.680027,25.839787,0.029538,77.789491,0.02558,13.990337,32.982323,4.962317,32.648641






In [4]:
# Same metrics after use_previous_forecast_sites post-processing, for direct comparison.
eval_all(
    pd.concat([df_pred_val_ens_pp, df_pred_test_ens_pp]),
    [[grouping] for grouping in ("cat", "year", "month")],
)

['cat']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
test,724.0,79.530989,55.262987,127.554784,55.775196,0.812155,224.522698,0.956964,12.202903,33.60921,1039.918597,1006.309387
val,2172.0,68.824243,40.21433,109.044415,57.213984,0.751381,201.490988,0.970765,28.310692,-38.337803,753.240214,791.578017
0,1448.0,74.177616,47.738659,118.299599,56.49459,0.781768,213.006843,0.963864,20.256797,-2.364297,896.579405,898.943702
0,1023.890619,7.570813,10.641008,13.088808,1.017377,0.042973,16.285878,0.009759,11.389927,50.874221,202.712229,151.838008




['year']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
2020,724.0,62.522392,46.851819,95.691111,45.024244,0.798343,175.598337,0.980902,21.314246,44.623446,851.73058,807.107134
2021,724.0,73.79633,35.481094,122.127543,63.780352,0.69337,223.835717,0.944496,33.24938,-115.113277,599.50863,714.621907
2022,724.0,70.154007,38.310077,109.314589,62.837354,0.762431,202.135229,0.97466,30.36845,-44.52358,808.481431,853.005011
2023,724.0,79.530989,55.262987,127.554784,55.775196,0.812155,224.522698,0.956964,12.202903,33.60921,1039.918597,1006.309387
0,724.0,71.500929,43.976494,113.672007,56.854287,0.766575,206.522995,0.964255,24.283745,-20.35105,824.909809,845.26086
0,0.0,7.122244,8.943015,14.218852,8.65798,0.053112,23.089006,0.016623,9.524896,74.604704,180.761091,121.817956




['month']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
1,416.0,118.923634,61.027925,201.317118,94.425859,0.814904,335.029585,0.916425,43.782025,-69.017448,823.096404,892.113852
2,416.0,94.001075,59.207199,146.846065,75.949961,0.783654,231.992667,0.959926,34.986797,-39.817803,823.096404,862.914207
3,416.0,89.263346,58.39356,142.058864,67.337614,0.747596,229.873016,0.960655,31.98484,-25.343237,823.096404,848.439641
4,416.0,70.269334,48.8398,103.64165,58.326552,0.735577,168.564017,0.978844,24.382405,-18.827776,823.096404,841.92418
5,416.0,60.066366,37.632091,95.195575,47.371433,0.771635,163.985485,0.979977,18.360304,8.3434,823.096404,814.753004
6,416.0,44.33366,28.190209,68.619462,36.19131,0.723558,132.428604,0.986942,10.741456,7.724318,823.096404,815.372086
7,400.0,21.735016,13.367405,34.999446,16.838198,0.79,90.961503,0.994054,5.006973,-4.925514,836.22546,841.150974
0,413.714286,71.22749,43.808313,113.23974,56.634418,0.766703,193.262125,0.968118,24.177829,-20.266294,824.971983,845.238278
0,6.047432,32.703522,18.167764,55.147952,25.839787,0.032591,80.178631,0.026069,13.808518,27.761213,4.962317,26.984679






In [5]:
# Per-site validation metrics after post-processing. nmpl = mpl / actual_mean
# scales each site's loss by its actual_mean so sites of different size can be
# compared; rows are sorted from best (lowest nmpl) to worst.
(
    eval_agg(df_pred_val_ens_pp, ["site_id"], is_include_mean_std=False)
    .assign(nmpl=lambda metrics: metrics['mpl'] / metrics['actual_mean'])
    .loc[:, ['mpl', 'mpl10', 'mpl50', 'mpl90', 'int_cvr', 'nmpl', 'rmse']]
    .sort_values("nmpl")
)

Unnamed: 0_level_0,mpl,mpl10,mpl50,mpl90,int_cvr,nmpl,rmse
site_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
stehekin_r_at_stehekin,24.811914,21.456396,34.20376,18.775587,0.952381,0.034568,44.795357
libby_reservoir_inflow,274.201484,193.67989,424.182592,204.741971,0.940476,0.049805,531.614701
hungry_horse_reservoir_inflow,112.527225,88.166432,182.117735,67.297508,0.821429,0.052958,229.071366
green_r_bl_howard_a_hanson_dam,17.30562,11.636293,26.532763,13.747803,0.988095,0.069771,34.120816
snake_r_nr_heise,194.939166,105.745689,330.002123,149.069686,0.797619,0.073446,449.29786
skagit_ross_reservoir,109.204892,54.563422,180.372931,92.678322,0.797619,0.074694,223.725725
boise_r_nr_boise,80.854649,36.705581,129.033969,76.824396,0.880952,0.097233,189.494617
weber_r_nr_oakley,8.309057,4.967133,13.063954,6.896083,0.797619,0.109638,17.546594
missouri_r_at_toston,163.114658,106.317794,235.01488,148.011301,0.714286,0.112651,308.706909
fontenelle_reservoir_inflow,55.502444,26.034315,84.143038,56.329978,0.97619,0.114677,116.133856


In [6]:
# Distribution of validation mpl across years for each month (one box per month,
# one point per year). The eval_all tables above end with two unlabeled summary
# rows (mean/std); pass is_include_mean_std=False, as the per-site table does,
# so any such rows from eval_agg are not plotted as data points.
# NOTE(review): this uses the ensemble before use_previous_forecast_sites
# (df_pred_val_ens); switch to df_pred_val_ens_pp if post-processed results are intended.
px.box(
    eval_agg(df_pred_val_ens, ["year", "month"], is_include_mean_std=False).reset_index(),
    x='month',
    y=['mpl'],
    title="Validation mpl by month (distribution over years)",
)

In [7]:
# Validation loss (overall and per quantile) as a function of md_id.
# is_include_mean_std=False keeps any summary rows (like the unlabeled mean/std
# rows in the eval_all tables above) from being appended to the line.
# NOTE(review): uses df_pred_val_ens (before use_previous_forecast_sites).
px.line(
    eval_agg(df_pred_val_ens, ["md_id"], is_include_mean_std=False).reset_index(),
    x='md_id',
    y=['mpl', 'mpl10', 'mpl50', 'mpl90'],
    title="Validation mpl by md_id",
)