In [1]:
import pickle
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from evaluation_util import sensitivity, load_and_prepare_usgs, calculate_diff, prepare_data_for_plotting
# Blanket filter: with no category/module arguments this hides every warning
# raised for the rest of the session, deprecation notices included.
# NOTE(review): consider narrowing this to the specific categories that are
# known to be noisy, so genuine problems stay visible.
warnings.filterwarnings("ignore")

In [2]:
# Number of 'short_range_<k>' entries to generate (k runs from 1 to `number`).
number = 11
# Names used further down to build the per-model pickle file names and passed
# on to `sensitivity`.
multi_models = ['short_range_%d' % member for member in range(1, number + 1)]

In [3]:
def _read_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file, so
    only point this at pickles produced by a trusted pipeline.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Stand-alone inputs, one pickle each.
bridge_gua = _read_pickle('bridge_gua.pkl')
raw = _read_pickle('gua_together_raw.pkl')
all_outputs_open = _read_pickle('gua_together_open.pkl')

# One pickle per entry of `multi_models`, keyed by the model name.
all_outputs_da = {
    model_name: _read_pickle(f'gua_together_da_usgs_{model_name}.pkl')
    for model_name in multi_models
}

In [4]:
# Keep only the entries of `bridge_gua` whose 'overflow' field is truthy.
filtered_bridge = dict(
    (entry_key, entry)
    for entry_key, entry in bridge_gua.items()
    if entry['overflow']
)

In [5]:
# Lookup table from a numeric string ID to a 'USGS-<site number>' string.
# NOTE(review): the keys look like stream-reach / feature IDs of the model
# network -- confirm against `load_and_prepare_usgs`, which consumes this table.

# San Antonio section.
_SAN_ANTONIO_GAUGES = {
    "10840818": "USGS-08178050",
    "10840824": "USGS-08178500",
    "10840502": "USGS-08178565",
    "10840230": "USGS-08178593",
    "10840232": "USGS-08178700",
    "10840488": "USGS-08178800",
    "10834470": "USGS-0817887350",
    "10833740": "USGS-08178880",
    "10834916": "USGS-08178980",
    "10835018": "USGS-08180586",
    "10836382": "USGS-08180640",
    "10835982": "USGS-08180700",
    "10836104": "USGS-08180800",
    "10835030": "USGS-08181400",
    "10836388": "USGS-08181480",
    "10836092": "USGS-08181500",
    "10840558": "USGS-08181725",
    "10840572": "USGS-08181800",
    "3836053": "USGS-08183200",
    "3838221": "USGS-08183500",
    "7850579": "USGS-08183900",
    "7850611": "USGS-08183978",
    "7850687": "USGS-08185065",
    "7851629": "USGS-08185100",
    "7851771": "USGS-08185500",
    "7852265": "USGS-08186000",
    "3838999": "USGS-08186500",
    "3839263": "USGS-08187500",
    "3839167": "USGS-08188060",
    "3840125": "USGS-08188500",
    "3840137": "USGS-08188570",
}

# Guadalupe section.
_GUADALUPE_GAUGES = {
    "3585678": "USGS-08165300",
    "3585620": "USGS-08165500",
    "3585554": "USGS-08166000",
    "3585626": "USGS-08166140",
    "3585724": "USGS-08166200",
    "3587616": "USGS-08166250",
    "3589508": "USGS-08167000",
    "3589062": "USGS-08167200",
    "3589120": "USGS-08167500",
    "1619595": "USGS-08167800",
    "1619637": "USGS-08168000",
    "1620031": "USGS-08168500",
    "1619663": "USGS-08168797",
    "1619647": "USGS-08168932",
    "1619649": "USGS-08169000",
    "1620877": "USGS-08169792",
    "1622735": "USGS-08169845",
    "1631099": "USGS-08170500",
    "1628253": "USGS-08170950",
    "1628219": "USGS-08170990",
    "1630223": "USGS-08171000",
    "1629555": "USGS-08171290",
    "1631129": "USGS-08171350",
    "1631195": "USGS-08171400",
    "1631387": "USGS-08172000",
    "1631087": "USGS-08172400",
    "1631587": "USGS-08173000",
    "1622713": "USGS-08173900",
    "1620735": "USGS-08174200",
    "1620703": "USGS-08174550",
    "1622763": "USGS-08174600",
    "1623075": "USGS-08174700",
    "1623207": "USGS-08175000",
    "1637437": "USGS-08175800",
    "1639225": "USGS-08176500",
    "1638559": "USGS-08176900",
    "1638907": "USGS-08177500",
}

# Merged in the original order: San Antonio entries first, then Guadalupe.
# The two tables share no keys, so nothing is overwritten by the merge.
GAUGES_USGS = {**_SAN_ANTONIO_GAUGES, **_GUADALUPE_GAUGES}

In [6]:
# Hourly timestamps for 2023-05-01, formatted as 'YYYYMMDDHH' strings.
# The range is generated end-inclusive and the trailing [:-1] drops the
# 2023-05-02 00:00 stamp, leaving the 24 hours 00..23 of 1 May.
# The step is given as an explicit Timedelta rather than the '1H' string
# alias: the uppercase 'H' alias is deprecated since pandas 2.2, while a
# Timedelta frequency is accepted by both older and newer pandas releases.
date_range = pd.date_range(
    '20230501', '20230502', freq=pd.Timedelta(hours=1)
).strftime('%Y%m%d%H')[:-1]

In [7]:
# Load the USGS gauge CSV through the project helper, which returns the
# prepared data together with a renaming dictionary.
# NOTE(review): the exact structure of both return values is defined in
# `evaluation_util` -- confirm there.
usgs, rename_dict = load_and_prepare_usgs(
    'usgs_gages_nudging.csv',
    GAUGES_USGS,
)
# Sites to evaluate: the values of the renaming dictionary, in dict order.
gage_list = [*rename_dict.values()]

In [8]:
# Forwarded unchanged as the second argument of every `sensitivity` call.
# NOTE(review): its meaning is defined inside `evaluation_util.sensitivity`
# -- document the intent here once confirmed.
factor: float = 0.9

In [9]:
# NOTE(review): this repeats the `gage_list` assignment made in the cell that
# loads the USGS data. It is idempotent, so re-running is harmless, but the
# duplicate cell could be dropped.
gage_list = [*rename_dict.values()]

In [10]:
# Run `sensitivity` once per site and collect its outputs as one record per
# site. The helper returns 14 values; they are unpacked positionally, so the
# order of names below must match the helper's return order.
results = []

for site in gage_list:
    (
        ensemble_weights_history,
        overtop_flow,
        timelagged_das,
        timelagged_raws,
        timelagged_opens,
        crps_das,
        crps_raws,
        crps_opens,
        timelagged_da_probs,
        timelagged_raw_probs,
        timelagged_open_probs,
        brier_raws,
        brier_das,
        brier_opens,
    ) = sensitivity(
        site,
        factor,
        raw,
        all_outputs_open,
        all_outputs_da,
        usgs,
        date_range,
        multi_models,
    )

    # Record keys mirror the unpacked names; 'site' identifies the record.
    site_record = dict(
        site=site,
        ensemble_weights_history=ensemble_weights_history,
        overtop_flow=overtop_flow,
        timelagged_das=timelagged_das,
        timelagged_raws=timelagged_raws,
        timelagged_opens=timelagged_opens,
        crps_das=crps_das,
        crps_raws=crps_raws,
        crps_opens=crps_opens,
        timelagged_da_probs=timelagged_da_probs,
        timelagged_raw_probs=timelagged_raw_probs,
        timelagged_open_probs=timelagged_open_probs,
        brier_raws=brier_raws,
        brier_das=brier_das,
        brier_opens=brier_opens,
    )
    results.append(site_record)

In [11]:
# Compare each CRPS series against the 'crps_opens' series, per site.
# NOTE(review): the plotting cell labels the resulting values as CRPSS, so
# `calculate_diff` presumably returns a skill score -- confirm in
# `evaluation_util`.
diff_da, diff_da_mean = calculate_diff(
    results, gage_list, 'crps_das', 'crps_opens'
)
diff_raw, diff_raw_mean = calculate_diff(
    results, gage_list, 'crps_raws', 'crps_opens'
)

# Long-format table consumed by the aggregation and plotting cells, which
# read its 'Index', 'SubIndex', 'Skill Type' and 'Skill Value' columns.
melted_data = prepare_data_for_plotting(diff_da, diff_raw)

In [12]:
# Aggregate 'Skill Value' per ('Index', 'Skill Type') pair, then pivot the
# skill type into columns so each column can be plotted as one line.
_skill_groups = melted_data.groupby(['Index', 'Skill Type'])['Skill Value']
median_values = _skill_groups.median().unstack()
mean_values = _skill_groups.mean().unstack()

In [None]:
# Skill value against 'Index' (labelled as lead time in hours on the x-axis):
# faint dotted traces for every 'SubIndex', with bold median and mean curves
# per skill type drawn on top.
fig, ax = plt.subplots(figsize=(10, 8))
skill_types = median_values.columns

# One fixed colour per skill type, taken from the tab10 palette.
color_map = {
    'DA Skill': plt.cm.tab10.colors[0],
    'Raw Skill': plt.cm.tab10.colors[1],
}

# The set of sub-indices does not depend on the skill type; look it up once.
sub_indices = melted_data['SubIndex'].unique()

for skill_type in skill_types:
    line_color = color_map[skill_type]
    of_this_type = melted_data['Skill Type'] == skill_type

    # Individual traces, drawn faintly behind the summary curves.
    for subindex in sub_indices:
        individual_values = melted_data[of_this_type & (melted_data['SubIndex'] == subindex)]
        ax.plot(
            individual_values['Index'],
            individual_values['Skill Value'],
            linestyle=':',
            color=line_color,
            alpha=0.2,
        )

    # Any skill type other than 'DA Skill' gets the NWM label.
    if skill_type == 'DA Skill':
        latex_label = r'$\text{CRPSS}_{\text{KF}}$'
    else:
        latex_label = r'$\text{CRPSS}_{\text{NWM}}$'

    ax.plot(
        median_values.index,
        median_values[skill_type],
        linestyle='-',
        marker='o',
        color=line_color,
        label=f'Median {latex_label}',
        linewidth=3,
    )
    ax.plot(
        mean_values.index,
        mean_values[skill_type],
        linestyle=':',
        marker='o',
        color=line_color,
        label=f'Mean {latex_label}',
        linewidth=3,
    )

ax.set_xticks(melted_data['Index'].unique())
ax.set_xlabel(r'Lead time (hr)')
ax.set_ylabel(r'CRPSS')

ax.legend(fontsize=20, loc='best', ncol=2, frameon=False)

ax.grid(True, which='both', linestyle='--', linewidth=0.5)
ax.set_ylim(-0.1, 1.01)

# Tighten the layout before rendering.
fig.tight_layout()
plt.show()