In [1]:
import os
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px

# NOTE: the relative order of the wildcard imports is kept as-is on purpose,
# because a later star import shadows same-named objects from an earlier one.
from src.utils import *
from src.features.base import *
from src.features.swe import *
from src.data.base import *
from src.config import *
from src.models.postprocess import *
from src.models.lgb import *
from src.models.ensemble import *
from src.forecast import *

warnings.filterwarnings("ignore")

In [2]:
# Site metadata, plus the two "chosen sites" tables stacked into one frame.
df_meta = read_meta()
chosen_site_paths = (
    "data/meta/snotel_sites_basins/snotel_sites_basins_chosen.feather",
    "data/meta/cdec_sites_chosen.feather",
)
df_snotel_sites_chosen = pd.concat([pd.read_feather(path) for path in chosen_site_paths])

# Per site: number of distinct snotel_id values, ignoring rows whose
# `method` is 'around' or 'huc'.
smr_snotel_sites = (
    df_snotel_sites_chosen
    .query("method not in ['around','huc']")
    .groupby('site_id', as_index=False)
    .agg(n_sites=('snotel_id', 'nunique'))
)

# Sites listed in the metadata that have no row in the summary above.
ngm_sites = list(set(df_meta.site_id) - set(smr_snotel_sites.site_id))
ngm_sites

['colville_r_at_kettle_falls', 'skagit_ross_reservoir']

In [3]:
# Run-configuration names; they are used as sub-directory names under
# runs/main/ when the per-experiment predictions are loaded.
config_main, config_main_2 = "afs_base_syc", "afs_ngm_syc"

# NOTE(review): not referenced by the cells visible here - confirm whether it is still needed.
dir_main = "main"

# Ensemble

In [4]:
# Years 2004-2023 inclusive; later used to label rows as "val" (anything else is "test").
val_years = list(range(2004, 2024))

# Experiments to load, keyed by the run configuration they were trained under.
# Each entry maps to runs/main/<config>/<experiment>/pred.csv.
exp_dict = {
    config_main_2: [
        "ngm_ua",
        "ngm_pdsi_ua_s51",
        "ngm_pdsi_era5_s51",
        "ngm_pdsi_ua_era5_s51",
    ],
    config_main: [
        "base_swe",
        "swe_ua",
        "pdsi_swe_s51",
        "pdsi_swe_era5",
        "pdsi_swe_era5_s51",
    ]
}

groupby_cols = ["site_id", "year", "month", "day", "md_id"]

# Collect one frame per experiment. Loading is best-effort: an experiment that
# cannot be read is reported (with the reason) and skipped, not silently dropped.
pred_frames = []
for model_type, exp_list in exp_dict.items():
    for exp_name in exp_list:
        pred_path = f"runs/main/{model_type}/{exp_name}/pred.csv"
        try:
            df_pred = (
                pd.read_csv(pred_path)
                # Drop detroit_lake_inflow rows with md_id >= 24.
                .query('(site_id=="detroit_lake_inflow" & md_id>=24)==False')
                # The regression output is used as the median (q50) prediction.
                .assign(pred_volume_50=lambda d: d['pred_volume_reg'])
            )
        except Exception as err:
            # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still works.
            print(f"{exp_name}_{model_type}: skipped ({type(err).__name__}: {err})")
            continue
        pred_frames.append(df_pred.assign(exp_name=exp_name))

if not pred_frames:
    # Same exception type pd.concat would raise on an empty list, but with an
    # actionable message.
    raise ValueError("No prediction file could be loaded from runs/main/ - nothing to ensemble.")

df_pred_all = pd.concat(pred_frames)
df_pred_all = rearrange_prediction(df_pred_all)

In [5]:
# Ensemble member lists. `ens_all` is derived from the two groups so the three
# lists cannot drift apart when an experiment is added or removed.
ens_gm = [
    "base_swe",
    "swe_ua",
    "pdsi_swe_s51",
    "pdsi_swe_era5",
    "pdsi_swe_era5_s51",
]
ens_ngm = [
    "ngm_ua",
    "ngm_pdsi_ua_s51",
    "ngm_pdsi_era5_s51",
    "ngm_pdsi_ua_era5_s51",
]
ens_all = ens_gm + ens_ngm

groupby_cols = ["site_id", "year", "month", "day", "md_id"]

# Experiments that failed to load upstream are skipped there, which would make
# the ensembles below silently average over fewer members - say so explicitly.
loaded_members = set(df_pred_all["exp_name"])
missing_members = [name for name in ens_all if name not in loaded_members]
if missing_members:
    print(f"WARNING: ensemble members missing from df_pred_all: {missing_members}")

# Individual experiments + one ensemble per member list (order: all, ngm, gm).
ensemble_members = {"ens_all": ens_all, "ens_ngm": ens_ngm, "ens_gm": ens_gm}
df_pred_all_ens_reg = pd.concat(
    [df_pred_all]
    + [
        custom_ensemble(
            df_pred_all[df_pred_all["exp_name"].isin(members)],
            groupby_cols=groupby_cols,
        ).assign(exp_name=ens_name)
        for ens_name, members in ensemble_members.items()
    ]
)

# "ens_adj": take the ens_ngm rows for the sites in `ngm_sites` and the
# ens_all rows for every other site.
is_ngm_site = df_pred_all_ens_reg['site_id'].isin(ngm_sites)
is_ens_ngm = df_pred_all_ens_reg['exp_name'] == 'ens_ngm'
is_ens_all = df_pred_all_ens_reg['exp_name'] == 'ens_all'
df_ens_adj = pd.concat([
    df_pred_all_ens_reg[is_ngm_site & is_ens_ngm],
    df_pred_all_ens_reg[~is_ngm_site & is_ens_all],
]).assign(exp_name='ens_adj')
df_pred_all_ens_reg = pd.concat([df_pred_all_ens_reg, df_ens_adj])

df_test = read_train(is_forecast=True)

# Rows whose year is in `val_years` are labelled "val", all others "test".
df_pred_all_ens_reg = df_pred_all_ens_reg.assign(
    cat=lambda x: np.where(x["year"].isin(val_years), "val", "test")
)
# Project helpers; presumably each adds one grouping column used by the
# evaluation tables (lead / period / area / year type) - confirm in src.
df_pred_all_ens_reg = get_lead_cat(df_pred_all_ens_reg, df_meta)
df_pred_all_ens_reg = get_period_cat(df_pred_all_ens_reg)
df_pred_all_ens_reg = get_area_cat(df_pred_all_ens_reg)
df_pred_all_ens_reg = get_year_type(df_pred_all_ens_reg, df_test)
df_pred_all_ens_reg = rearrange_prediction(df_pred_all_ens_reg)

# Score table per experiment / ensemble, sorted by `mpl` ascending.
eval_agg(df_pred_all_ens_reg, ["exp_name"], is_include_mean_std=False).sort_values('mpl')

Unnamed: 0_level_0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
exp_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
ens_adj,14480.0,79.351025,49.429904,125.79845,62.824719,0.834323,255.595657,0.959053,20.394378,17.462082,901.393558,883.931476
ens_all,14480.0,79.486898,49.299923,126.197195,62.963575,0.837362,255.955869,0.958937,20.416237,17.994911,901.393558,883.398646
ens_gm,14480.0,80.594944,48.647901,128.892574,64.244357,0.802348,259.675101,0.957735,20.584534,13.373511,901.393558,888.020047
ens_ngm,14480.0,80.981901,49.770969,127.869535,65.3052,0.78529,259.525869,0.957784,20.774277,23.771662,901.393558,877.621895
pdsi_swe_era5,14480.0,85.536804,50.293253,129.902407,76.414754,0.616851,263.256683,0.956561,20.64104,19.376683,901.393558,882.016874
swe_ua,14480.0,85.927078,52.975001,135.173703,69.63253,0.675414,273.450725,0.953132,21.770751,6.13714,901.393558,895.256418
base_swe,14480.0,86.249608,52.318277,137.777392,68.653156,0.707597,274.433584,0.952794,22.441449,1.062784,901.393558,900.330774
pdsi_swe_era5_s51,14480.0,86.344326,50.824255,129.089388,79.119334,0.59482,261.169415,0.957247,20.388905,20.299893,901.393558,881.093665
pdsi_swe_s51,14480.0,86.604656,50.385383,130.646639,78.781945,0.596616,262.21208,0.956905,20.422836,19.991053,901.393558,881.402504
ngm_pdsi_ua_era5_s51,14480.0,88.554521,51.728206,129.508899,84.426458,0.588881,263.885326,0.956353,20.784324,30.240926,901.393558,871.152631


In [6]:
# Predictions carried forward as the final model: the rows labelled "ens_all".
df_pred_final = df_pred_all_ens_reg.query('exp_name=="ens_all"')

# One evaluation table per single-column grouping.
eval_groupers = [[col] for col in ("cat", "area", "period", "year_type", "year", "month")]
eval_all(df_pred_final, eval_groupers)

['cat']


Unnamed: 0_level_0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
cat,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
val,14480.0,79.486898,49.299923,126.197195,62.963575,0.837362,255.955869,0.958937,20.416237,17.994911,901.393558,883.398646




['area']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
california,1680.0,118.679646,78.234055,191.505802,86.299082,0.809524,302.811421,0.849066,28.402478,15.853396,930.035,914.181604
cascades,2160.0,51.416984,30.961699,81.364402,41.924853,0.817593,124.597295,0.919275,14.274299,0.844445,696.170804,695.326359
colorado,2240.0,16.52637,9.195258,26.08964,14.294212,0.858929,44.876084,0.88996,15.379047,1.555497,186.419813,184.864315
hard,1680.0,23.553933,11.329212,37.3382,21.994389,0.844643,88.489436,0.808243,28.335298,4.200007,139.839267,135.63926
others,6720.0,113.681267,70.821719,179.864518,90.357562,0.841667,332.600675,0.958113,20.093169,32.971471,1388.910808,1355.939337
0,2896.0,64.77164,40.108388,103.232512,50.97402,0.834471,178.674982,0.884931,21.296858,11.084963,668.275138,657.190175
0,2153.620208,48.741863,32.651186,78.159989,35.588053,0.020376,130.444383,0.058583,6.815702,13.645867,524.414242,512.556443




['period']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
long,5720.0,120.228767,70.445091,194.070853,96.170358,0.77972,343.254317,0.92579,33.282128,15.748029,899.162381,883.414352
short,8760.0,52.883759,35.492805,81.877867,41.280607,0.875,177.077868,0.980408,12.015221,19.462054,902.850445,883.388391
0,7240.0,86.556263,52.968948,137.97436,68.725483,0.82736,260.166093,0.953099,22.648675,17.605041,901.006413,883.401372
0,2149.604615,47.620112,24.714999,79.332421,38.812915,0.067373,117.504494,0.038621,15.037974,2.626213,2.607855,0.018358




['year_type']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
dry,3620.0,74.474101,38.34285,122.005745,63.073709,0.769613,238.864823,0.950603,29.282326,-78.593542,655.242564,733.836106
normal,6516.0,75.107457,47.132815,116.124282,62.065273,0.871087,253.886764,0.957624,18.559419,34.868226,875.331719,840.463494
wet,4344.0,90.233389,61.681479,144.79944,64.219249,0.843232,272.29265,0.962225,15.813056,73.175318,1145.612144,1072.436825
0,4826.666667,79.938316,49.052381,127.643156,63.119411,0.827977,255.014746,0.956817,21.218267,9.816667,892.062142,882.245475
0,1507.126184,8.921417,11.787131,15.146005,1.077715,0.052428,16.742436,0.005853,7.117402,78.924852,245.612523,173.123979




['year']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
2004,724.0,67.638508,25.378155,113.318734,64.218636,0.856354,230.137384,0.93237,23.312178,-87.135587,639.101182,726.23677
2005,724.0,74.919481,54.653155,114.928768,55.176519,0.867403,222.215392,0.958508,14.698369,57.312413,869.363586,812.051173
2006,724.0,86.973494,51.112902,139.559808,70.247771,0.870166,274.163442,0.961423,18.204559,25.9961,1093.336039,1067.339939
2007,724.0,66.655273,30.350688,112.268601,57.346531,0.857735,215.781102,0.970216,24.104076,-72.620368,695.840254,768.460622
2008,724.0,65.527336,41.218423,101.214338,54.149249,0.935083,187.67405,0.974608,12.122879,6.815862,939.324271,932.508408
2009,724.0,72.435623,50.630098,124.552495,42.124276,0.921271,236.379061,0.948462,13.597914,55.458766,886.712448,831.253682
2010,724.0,70.094339,53.037845,110.481842,46.763331,0.893646,201.272127,0.954057,13.587808,83.962376,836.593094,752.630717
2011,724.0,136.618216,92.711145,230.851537,86.291967,0.73895,401.988946,0.943823,18.394358,201.057631,1438.805785,1237.748154
2012,724.0,130.371512,59.372732,186.976696,144.765107,0.741713,490.395634,0.917223,28.511185,103.929361,955.773917,851.844556
2013,724.0,63.704584,30.885284,104.702051,55.526417,0.859116,218.208593,0.972615,26.46479,-41.481065,728.644657,770.125722




['month']


Unnamed: 0,n,mpl,mpl10,mpl50,mpl90,int_cvr,rmse,r2,mape,bias,actual_mean,pred_mean
1,2080.0,134.903311,78.461713,219.540277,106.707943,0.788462,381.187847,0.908482,38.124928,10.908673,899.162381,888.253708
2,2080.0,118.23501,69.024754,191.574232,94.106044,0.769231,335.212448,0.929227,32.482121,16.068741,899.162381,883.093639
3,2080.0,100.52277,60.374068,158.738752,82.455489,0.789904,289.730669,0.947129,27.079018,22.591542,899.162381,876.570839
4,2080.0,76.660443,50.261825,119.903637,59.815867,0.856731,221.260286,0.969166,19.426019,24.489926,899.162381,874.672455
5,2080.0,62.333253,43.120733,96.763763,47.115264,0.871635,200.400556,0.974705,13.593574,33.626497,899.162381,865.535884
6,2080.0,41.467183,27.898301,63.347564,33.155685,0.894231,139.675004,0.987712,7.929403,14.541731,899.162381,884.62065
7,2000.0,19.99829,14.624393,29.804738,15.56574,0.8935,77.887109,0.996309,3.63309,3.166965,915.316102,912.149137
0,2068.571429,79.160037,49.109398,125.667566,62.703147,0.83767,235.05056,0.958961,20.324022,17.913439,901.470055,883.556616
0,30.237158,41.434254,22.588312,68.546329,33.314224,0.053572,107.677037,0.032005,12.836411,9.944274,6.105533,14.682289






In [7]:
# Per-site scores of the final predictions, plus `nmpl` = mpl / actual_mean,
# sorted by `nmpl` ascending.
site_score_cols = ['mpl','mpl10','mpl50','mpl90','mape','rmse','r2','int_cvr','nmpl','actual_mean','pred_mean']
smr_site_score_main = (
    eval_agg(df_pred_final, ["site_id"], is_include_mean_std=False)
    .assign(nmpl=lambda x: x['mpl'] / x['actual_mean'])
    [site_score_cols]
    .sort_values("nmpl")
)

# Per-site relative skill of the final predictions against the predictions
# stored in pred_median_dp.feather, sorted by `mplss` descending.
df_pred_median_dp = pd.read_feather("data/sub/pred_median_dp.feather")
smr_site_score_rel = eval_relative_skill(
    df_pred_final,
    df_pred_median_dp,
    grouper=["site_id"],
).sort_values('mplss', ascending=False)

# Column-wise concat aligns the two tables on their index.
smr_site_score = pd.concat([smr_site_score_main, smr_site_score_rel], axis=1)
smr_site_score

Unnamed: 0_level_0,mpl,mpl10,mpl50,mpl90,mape,rmse,r2,int_cvr,nmpl,actual_mean,pred_mean,mpl_x,mpl_y,mplss
site_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
stehekin_r_at_stehekin,34.377176,21.263052,52.333761,29.534715,7.611364,77.491467,0.651065,0.878571,0.048746,705.2284,696.202834,34.377176,63.295294,0.456876
snake_r_nr_heise,203.278933,138.338402,333.436153,138.062243,10.772078,459.648564,0.773824,0.876786,0.062584,3248.1116,3178.864376,203.278933,418.697679,0.514497
libby_reservoir_inflow,375.968581,210.111996,570.899326,346.89442,10.449976,796.958521,0.589975,0.8375,0.068934,5454.07105,5371.421298,375.968581,565.418845,0.335062
skagit_ross_reservoir,90.347845,52.774478,141.238127,77.030929,12.090274,194.313956,0.584021,0.817857,0.069168,1306.2125,1338.419546,90.347845,147.333714,0.386781
hungry_horse_reservoir_inflow,142.496102,99.344535,224.22142,103.922352,11.403805,311.609177,0.628811,0.785714,0.070679,2016.10255,1957.684111,142.496102,233.226471,0.389023
weber_r_nr_oakley,8.425528,5.186299,13.62922,6.461065,14.316023,19.351835,0.814058,0.864286,0.078053,107.9468,107.426193,8.425528,18.691249,0.549226
ruedi_reservoir_inflow,10.401733,6.123653,16.801267,8.28028,13.848113,22.473889,0.695657,0.855357,0.080575,129.09345,127.820442,10.401733,18.730874,0.444674
boise_r_nr_boise,102.810388,65.428401,168.732034,74.270728,16.106005,249.380432,0.722831,0.867857,0.084993,1209.6298,1189.948203,102.810388,184.919538,0.444026
merced_river_yosemite_at_pohono_bridge,33.625292,22.764895,52.637606,25.473375,18.896333,85.155766,0.880563,0.841071,0.08755,384.07,379.105402,33.625292,93.480355,0.640296
animas_r_at_durango,33.142912,18.043907,51.04369,30.341139,15.362057,78.994561,0.694187,0.873214,0.0895,370.31165,366.874451,33.142912,60.734567,0.454299


In [10]:
# Build the hindcast submission from the final predictions, using the
# cross-validation submission format file as the second argument.
submission_format = pd.read_csv("data/raw/cross_validation_submission_format.csv")
generate_hindcast_submission(df_pred_final, submission_format, fname="ens_final", dirname="data/sub")

Unnamed: 0,site_id,issue_date,volume_10,volume_50,volume_90
0,hungry_horse_reservoir_inflow,2004-01-01,1368.863554,1985.589929,2602.540525
1,hungry_horse_reservoir_inflow,2004-01-08,1490.711493,1971.696788,2543.073908
2,hungry_horse_reservoir_inflow,2004-01-15,1472.042871,1880.979839,2435.534682
3,hungry_horse_reservoir_inflow,2004-01-22,1429.614784,1822.297442,2321.815187
4,hungry_horse_reservoir_inflow,2004-02-01,1607.775817,2080.638612,2619.052885
...,...,...,...,...,...
14475,owyhee_r_bl_owyhee_dam,2023-06-22,470.270205,513.294309,557.398878
14476,owyhee_r_bl_owyhee_dam,2023-07-01,517.700994,528.056522,535.756656
14477,owyhee_r_bl_owyhee_dam,2023-07-08,517.855931,527.486629,535.774409
14478,owyhee_r_bl_owyhee_dam,2023-07-15,517.242458,526.202038,536.172115
