In [1]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
import xarray as xr
import math
import json
from multiprocessing import Pool

import geopandas as gpd
from shapely.geometry import Point
import xyzservices.providers as xyz
from scipy.stats import linregress
from collections import defaultdict

from bokeh.plotting import figure, output_file, save, show
from bokeh.layouts import gridplot, row, column

# from bokeh.models import ColumnDataSource, LinearAxis, Range1d, HoverTool, Div
from bokeh.io import output_notebook
from bokeh.models import Div
# from bokeh.palettes import Sunset10, Vibrant7, Category20, Bokeh6, Bokeh7, Bokeh8, Greys256, Blues256

# from shapely.geometry import Polygon, Point
# from shapely.ops import unary_union
# from scipy.spatial import Voronoi

from utils.kde_estimator import KDEEstimator
from utils.fdc_estimator_context import FDCEstimationContext
from utils.fdc_data import StationData
from utils.evaluation_metrics import EvaluationMetrics

import utils.data_processing_functions as dpf

import xyzservices.providers as xyz
tiles = xyz['USGS']['USTopo']

output_notebook()

In [2]:
BASE_DIR = Path(os.getcwd())

from utils.table_notes import notes_html
attr_fpath = 'data/BCUB_watershed_attributes_updated_20250227.csv'
attr_df = pd.read_csv(attr_fpath, dtype={'Official_ID': str})
station_ids = sorted(attr_df['official_id'].unique().tolist())

# streamflow folder from (updated) HYSETS
HYSETS_DIR = Path('/home/danbot/code/common_data/HYSETS')
hs_df = pd.read_csv('data/HYSETS_watershed_properties.txt', sep=';')
hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]
hs_df.head(2)

Unnamed: 0,Watershed_ID,Source,Name,Official_ID,Centroid_Lat_deg_N,Centroid_Lon_deg_E,Drainage_Area_km2,Drainage_Area_GSIM_km2,Flag_GSIM_boundaries,Flag_Artificial_Boundaries,...,Land_Use_Wetland_frac,Land_Use_Water_frac,Land_Use_Urban_frac,Land_Use_Shrubs_frac,Land_Use_Crops_frac,Land_Use_Snow_Ice_frac,Flag_Land_Use_Extraction,Permeability_logk_m2,Porosity_frac,Flag_Subsoil_Extraction
846,847,HYDAT,CROWSNEST RIVER AT FRANK,05AA008,49.59732,-114.4106,402.6522,,0,0,...,0.0103,0.0065,0.0328,0.0785,0.0015,0.0002,1,-15.543306,0.170479,1
849,850,HYDAT,CASTLE RIVER NEAR BEAVER MINES,05AA022,49.48866,-114.1444,820.651,,0,0,...,0.0058,0.0023,0.0105,0.1156,0.0246,0.0,1,-15.929747,0.150196,1


In [3]:
# retrieve LSTM ensemble predictions
lstm_result_base_folder = Path('/home/danbot/code/neuralhydrology/data/')
results_revisions = ['20250514', '20250627']
lstm_result_files = os.listdir(lstm_result_base_folder / f'ensemble_results_{results_revisions[0]}')
lstm_result_stns = [e.split('_')[0] for e in lstm_result_files]

# filter for the common stations between BCUB region and LSTM-compatible (i.e. 1980-)
daymet_concurrent_stations = list(set(station_ids) & set(lstm_result_stns))
# assert '012414900' in daymet_concurrent_stations
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')


There are 723 monitored basins concurrent with LSTM ensemble results.


In [4]:
watershed_id_dict = {row['Watershed_ID']: row['Official_ID'] for _, row in hs_df.iterrows()}
# and the inverse
official_id_dict = {row['Official_ID']: row['Watershed_ID'] for _, row in hs_df.iterrows()}
# also for drainage areas
da_dict = {row['Official_ID']: row['Drainage_Area_km2'] for _, row in hs_df.iterrows()}

In [5]:
def load_and_filter_hysets_data(station_ids, hs_df):
    hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]

    # load the updated HYSETS data
    updated_filename = 'HYSETS_2023_update_QC_stations.nc'
    ds = xr.open_dataset(HYSETS_DIR / updated_filename)

    # Get valid IDs as a NumPy array
    selected_ids = hs_df['Watershed_ID'].values

    # Get boolean index where watershedID in selected_set
    # safely access watershedID as a variable first
    ws_ids = ds['watershedID'].data  # or .values if you prefer
    mask = np.isin(ws_ids, selected_ids)

    # Apply mask to data
    ds = ds.sel(watershed=mask)
    # Step 1: Promote 'watershedID' to a coordinate on the 'watershed' dimension
    ds = ds.assign_coords(watershedID=("watershed", ds["watershedID"].data))

    # Step 2: Set 'watershedID' as the index for the 'watershed' dimension
    return ds.set_index(watershed="watershedID")


ds = load_and_filter_hysets_data(station_ids, hs_df)

In [6]:
def set_grid(global_min, global_max, n_grid_points=2**12):
    baseline_log_grid = np.linspace(global_min, global_max, n_grid_points)
    baseline_lin_grid = np.exp(baseline_log_grid)
    log_dx = np.gradient(baseline_log_grid)
    max_step_size = baseline_lin_grid[-1] - baseline_lin_grid[-2]
    # print(f'    max step size = {max_step_size:.1f} L/s/km^2 for n={n_grid_points:.1e} grid points')
    min_step_size = baseline_lin_grid[1] - baseline_lin_grid[0]
    # print(f'    min step size = {min_step_size:.2e} L/s/km^2 for n={n_grid_points:.1e} grid points')
    return baseline_lin_grid, baseline_log_grid, log_dx

_, baseline_log_grid, log_dx = set_grid(np.log(1e-6), np.log(1e4), n_grid_points=2**12)
print(baseline_log_grid[0], baseline_log_grid[-1], log_dx[0], log_dx[-1])

-13.815510557964274 9.210340371976184 0.005622918420009171 0.005622918420009171


In [7]:
def split_knn_label_col(df):
    # Split the string column
    # Determine format based on length
    print(df.columns)
    split_labels = df['Label'].str.split('_')
    df['n_parts'] = split_labels.str.len()

    assert len(set(df['n_parts'])) == 1, "Not all labels have the same number of parts"

    # Define expected column structures
    # format_a_cols = ["Official_ID", "k", "NN", 'concurrent', 'tree_type', 'dist', 'weighting', 'ensemble_method']
    format_cols = ["Official_ID", "k", "NN", 'tree_type', 'dist', 'ensemble_weight', 'ensemble_method']

    # Subset by format
    df_a = df[df['n_parts'] == len(format_cols)].copy()

    # Split and join with suffix to avoid conflicts
    df_a_split = df_a['Label'].str.split('_', expand=True)
    df_a_split.columns = format_cols
    merged = pd.concat([df_a.reset_index(drop=True), df_a_split.reset_index(drop=True)], axis=1)

    # Drop duplicates (if any) and update
    merged.drop(columns=['NN', 'dist', 'n_parts', 'minYears', 'minOverlapPct'], errors='ignore', inplace=True)
    merged = merged.loc[:, ~merged.columns.duplicated()]
    return merged

In [8]:
# retrieve LSTM ensemble predictions
LSTM_ensemble_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250514'# uses NSE mean as loss function
# LSTM_ensemble_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250627'# uses NSE 95% as loss function
lstm_result_files = os.listdir(LSTM_ensemble_result_folder)
lstm_result_stns = [e.split('_')[0] for e in lstm_result_files]

# filter for the common stations between BCUB region and LSTM-compatible (i.e. 1980-)
daymet_concurrent_stations = list(set(station_ids) & set(lstm_result_stns))
# assert '012414900' in daymet_concurrent_stations
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')


There are 723 monitored basins concurrent with LSTM ensemble results.


In [9]:
results_dfs = {}
lstm_rev_date = LSTM_ensemble_result_folder.split('_')[-1]
sub_folder = f'lstm_{lstm_rev_date}' 
# results_folder = '/media/danbot/Samsung_T5/fdc_estimation_results/'
results_folder = 'data/results/fdc_estimation_results'
completed_stns = [c.split('_')[0] for c in os.listdir(os.path.join(results_folder, sub_folder))]
print(f'Found {len(set(completed_stns))} completed stations in {sub_folder} results folder.')

for method in ['parametric', 'lstm', 'knn']:
    print(f'   Loading {method} results')
    method_results_fpath = os.path.join('data', 'results', f'{method}_all_results.csv')
    if method == 'lstm':
        rev_date = LSTM_ensemble_result_folder.split('_')[-1]
        method_results_fpath = os.path.join('data', 'results', f'{method}_all_results_{rev_date}.csv')
    if os.path.exists(method_results_fpath):
        results_dfs[method] = pd.read_csv(method_results_fpath, dtype={'Official_ID': str})
        print(f'   Loaded {len(results_dfs[method])} {method} results from {method_results_fpath}')
    else:
        print(f'   {method} results not found in {method_results_fpath}, loading from individual station files...')
        res_folder = os.path.join(results_folder, method)
        if method == 'lstm':
            res_folder = os.path.join(results_folder, f'{method}_{rev_date}')
        args = [(stn, res_folder, method) for stn in completed_stns]
        with Pool() as pool:
            results_list = pool.map(load_results, args)

        foo = pd.concat(results_list, ignore_index=True)
        bad_dkl = foo[foo['KLD'].isna() | (foo['KLD'] < 0)].copy()
        if not bad_dkl.empty:
            print(f'Warning: {len(bad_dkl)} {method} rows with NaN or negative DKL values.')
            bad_stns = bad_dkl['Official_ID'].values
            raise Exception(f'Results have {len(bad_stns)} NaN or negative DKL values: {bad_stns}')
        method_results = pd.concat(results_list, ignore_index=True)
        results_dfs[method] = method_results
        print(f'   Loaded {int(len(results_dfs[method])/len(set(completed_stns)))} station results for {method} results')
        method_results.to_csv(method_results_fpath, index=False)

Found 719 completed stations in lstm_20250514 results folder.
   Loading parametric results
   Loaded 2876 parametric results from data/results/parametric_all_results.csv
   Loading lstm results
   Loaded 1438 lstm results from data/results/lstm_all_results_20250514.csv
   Loading knn results
   Loaded 57520 knn results from data/results/knn_all_results.csv


In [10]:
fdc_df = pd.concat([results_dfs['parametric'], results_dfs['lstm']], axis=0)
# fdc_df = results_dfs['parametric'].copy()
np.unique(fdc_df['Label'].values)
results_dfs['parametric'].keys()
print(len(fdc_df))
model_labels = sorted(list(set(fdc_df['Label'])))
print(model_labels)
fdc_df.head()

4314
['MLE', 'PredictedLog', 'PredictedMOM', 'RandomDraw', 'frequency', 'time']


Unnamed: 0,Official_ID,Label,KLD,EMD,RMSE,MB,RB,MARE,NSE,KGE,VE,VB_PMF,VB_FDC,MEAN_FRAC_DIFF
0,08EE008,MLE,0.095325,3.1696,0.210544,0.118106,0.008756,0.175779,0.972858,0.973509,0.832637,-0.008756,-0.008756,0.049605
1,08EE008,PredictedLog,0.124773,7.2142,0.383349,5.601531,0.415272,0.423468,0.91002,0.827263,0.583876,-0.415272,-0.415272,0.049605
2,08EE008,PredictedMOM,0.365131,7.3755,0.746716,6.719347,0.498142,1.060132,0.658596,0.578968,0.501858,-0.498142,-0.498142,0.049605
3,08EE008,RandomDraw,0.40422,10.564,0.833516,9.627484,0.713738,1.274073,0.574611,0.540447,0.286262,-0.713738,-0.713738,0.049605
4,09AA013,MLE,0.180311,2.725,0.219398,-0.510697,-0.032154,0.182355,0.948784,0.961978,0.857567,0.032154,0.032154,0.035018


In [11]:
parametric_targets = list(set(results_dfs['parametric']['Label'].values))
results_dfs['knn'] = split_knn_label_col(results_dfs['knn'])

Index(['Official_ID', 'Label', 'KLD', 'EMD', 'RMSE', 'MB', 'RB', 'MARE', 'NSE',
       'KGE', 'VE', 'VB_PMF', 'VB_FDC', 'MEAN_FRAC_DIFF'],
      dtype='object')


In [12]:
def get_result_and_ids(label, metric):
    data = fdc_df[fdc_df['Label'] == label].copy()
    data = data.dropna(subset=[metric])
    values = data[metric].values
    if metric in ['NSE', 'KGE']:
        # for NSE and KGE, we want to plot the upper bound as the maximum value
        values = 1 - values
    return values, data['Official_ID']


def get_knn_group_results(tree_type='attribute', ensemble_type='freqEnsemble', weighting='ID2', k=7, which_set='knn'):
    data = results_dfs[which_set].copy()
    data = data[data['tree_type'] == 'attribute']
    data = data[data['ensemble_method'] == ensemble_type]
    data = data[data['ensemble_weight'] == weighting]
    data = data[data['k'] == str(k)]
    return data

In [13]:
main_result_vals = {}
all_metrics = ['KLD', 'EMD', 'RMSE', 'RB', 'MB', 'NSE', 'KGE', 'VE', 'VB_PMF', 'VB_FDC', 'MEAN_FRAC_DIFF']
tree_type = 'attribute'
# Define metrics that are naturally "higher is better" and need inversion to match "lower is better"
invert_metrics = {'NSE', 'KGE', 'VE'}

def invert_if_needed(values, dm):
    """Ensure metric values follow good = 0, bad = large positive"""
    values = np.asarray(values)
    if dm in invert_metrics:
        if np.nanmin(values) < 0.:  # assumes values ~ [-inf, 1] if not yet inverted
            return 1 - values
    return values

# Loop through each metric
for dm in all_metrics:    
    # Parametric LN MoM
    for label, name in [('PredictedMOM', 'LN MoM'), ('PredictedLog', 'LN Direct'), ('MLE', 'MLE')]:
        data, ids = get_result_and_ids(label, dm)
        data = invert_if_needed(data, dm)
        main_result_vals[f'{name} {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

    # kNN group results
    for k in [2, 4, 8]:
        knn_df = get_knn_group_results(k=k)
        data = invert_if_needed(knn_df[dm].values, dm)
        ids = knn_df['Official_ID'].values
        main_result_vals[f'{k} kNN {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

    # LSTM models
    for label, suffix in [('time', 'LSTM time'), ('frequency', 'LSTM dist.')]:
        subset = fdc_df[fdc_df['Label'] == label]
        data = invert_if_needed(subset[dm].values, dm)
        ids = subset['Official_ID'].values
        main_result_vals[f'{suffix} {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

In [14]:
# create a dataframe with all the model results indexed by station
all_results = []
for m in main_result_vals.keys():
    df = main_result_vals[m].copy()
    df.rename(columns={'values': m}, inplace=True)
    df.set_index('ids', inplace=True)
    all_results.append(df)
all_results_df = pd.concat(all_results, axis=1)
all_results_df.describe()

Unnamed: 0,LN MoM KLD,LN Direct KLD,MLE KLD,2 kNN KLD,4 kNN KLD,8 kNN KLD,LSTM time KLD,LSTM dist. KLD,LN MoM EMD,LN Direct EMD,...,LSTM time VB_FDC,LSTM dist. VB_FDC,LN MoM MEAN_FRAC_DIFF,LN Direct MEAN_FRAC_DIFF,MLE MEAN_FRAC_DIFF,2 kNN MEAN_FRAC_DIFF,4 kNN MEAN_FRAC_DIFF,8 kNN MEAN_FRAC_DIFF,LSTM time MEAN_FRAC_DIFF,LSTM dist. MEAN_FRAC_DIFF
count,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0,...,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0,719.0
mean,0.38182,0.231857,0.124272,0.268939,0.212758,0.188623,0.445734,0.245748,8.908878,10.345866,...,0.050034,0.01125,0.056395,0.056395,0.056395,0.056395,0.056395,0.056395,0.056395,0.056395
std,0.606189,0.299205,0.089947,0.437871,0.356477,0.287667,0.707612,0.406313,11.960709,13.857253,...,0.328589,0.353904,0.025241,0.025241,0.025241,0.025241,0.025241,0.025241,0.025241,0.025241
min,0.010143,0.011476,0.001988,0.001748,0.003461,0.003564,0.006952,0.004348,0.3732,0.3335,...,-2.573342,-2.829599,0.018139,0.018139,0.018139,0.018139,0.018139,0.018139,0.018139,0.018139
25%,0.126949,0.107299,0.052504,0.047649,0.046567,0.048828,0.07184,0.046764,2.88025,2.7192,...,-0.058961,-0.092156,0.039629,0.039629,0.039629,0.039629,0.039629,0.039629,0.039629,0.039629
50%,0.221149,0.175871,0.115462,0.119703,0.107185,0.106227,0.174699,0.106895,5.4391,5.5281,...,0.087525,0.065352,0.050275,0.050275,0.050275,0.050275,0.050275,0.050275,0.050275,0.050275
75%,0.388692,0.250882,0.177244,0.284624,0.226452,0.203985,0.475624,0.259679,9.6911,11.784,...,0.208378,0.184598,0.065667,0.065667,0.065667,0.065667,0.065667,0.065667,0.065667,0.065667
max,8.196112,3.672594,0.685356,3.285333,3.350589,3.252448,5.709942,3.874971,106.7319,114.1914,...,0.828491,0.826106,0.203445,0.203445,0.203445,0.203445,0.203445,0.203445,0.203445,0.203445


In [15]:
# load the complete years previously processed
import json
with open('data/complete_years.json', 'r') as f:
    complete_year_dict = json.load(f)

In [16]:
def retrieve_timeseries_discharge(stn, ds):
    watershed_id = official_id_dict[stn]
    df = ds['discharge'].sel(watershed=str(watershed_id)).to_dataframe(name='discharge').reset_index()
    df = df.set_index('time')[['discharge']]
    df.dropna(inplace=True)
    # clip minimum flow to 1e-4
    df['discharge'] = np.clip(df['discharge'], 1e-4, None)
    df.rename(columns={'discharge': stn}, inplace=True)
    df[f'{stn}_uar'] = 1000 * df[stn] / da_dict[stn]
    return df


def plot_ensemble_results(stn, folder):
    """Plot the ensemble results for a given station."""
    p = figure(title=f'Ensemble results for {stn}', x_axis_type='datetime', 
               y_axis_type='log', width=800, height=400)

    for date, clr in zip(['20250514', '20250627'], ['black', 'red']):
        fpath = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        df = pd.read_csv(fpath)
        df.rename(columns={'Unnamed: 0': 'time'}, inplace=True)
        df['time'] = pd.to_datetime(df['time'])
        df.set_index('time', inplace=True)
        df = np.exp(df)
        if 'streamflow_obs' in df.columns:
            p.line(df.index, df['streamflow_obs'], color=clr, legend_label=f'{date} Obs', line_width=2)
        
        sim_cols = [c for c in df.columns if c.startswith('streamflow_sim')]
        mean_sim = df[sim_cols].mean(axis=1)
        # compute the 5% quantiles on the simulation columns
        lb = df[sim_cols].quantile(0.05, axis=1)
        ub = df[sim_cols].quantile(0.95, axis=1)

        p.varea(df.index, lb, ub, color=clr, alpha=0.2, legend_label=f'{date} 90% CI')
        p.line(df.index, mean_sim, color=clr, legend_label=f'{date} Mean', line_dash='dashed', line_width=2)

    p.legend.location = 'top_left'
    p.xaxis.axis_label = 'Time'
    p.yaxis.axis_label = 'Streamflow (L/s/km2)'
    p.legend.click_policy= 'hide'
    return p


def filter_by_complete_years(stn, folder):
    all_dfs = []
    for date, clr in zip(['20250514', '20250627'], ['black', 'red']):
        fpath = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        if not os.path.exists(fpath):
            return pd.DataFrame()
        df = pd.read_csv(fpath)
        df.rename(columns={'Unnamed: 0': 'time'}, inplace=True)
        df['time'] = pd.to_datetime(df['time'])
        df.set_index('time', inplace=True)
        df.columns = [f'{c}_{date}' for c in df.columns]
        df = np.exp(df)
        all_dfs.append(df)
    result = pd.concat(all_dfs, axis=1, join='inner')
    result = result.dropna(how='any', axis=0)
    complete_years = complete_year_dict.get(stn, None).get('complete_years', [])
    # print(f'    Found {len(complete_years)} complete years for {stn}: {complete_years}')
    return result[result.index.year.isin(complete_years)]


def get_original_timeseries(stn, ds):
    """Retrieve the original timeseries for a given station."""
    watershed_id = official_id_dict[stn]
    df = ds['discharge'].sel(watershed=str(watershed_id)).to_dataframe(name='discharge').reset_index()
    df = df.set_index('time')[['discharge']]
    df.dropna(inplace=True)
    # clip minimum flow to 1e-4
    df['zero_flow_flag'] = df['discharge'] < 1e-4
    df['discharge'] = np.clip(df['discharge'], 1e-4, None)
    df.rename(columns={'discharge': stn}, inplace=True)
    df[f'{stn}_uar'] = 1000 * df[stn] / da_dict[stn]
    return df


def compute_afdcs_from_sorted_flows(df, years, da, date):
    """Compute AFDCs from sorted daily flows, rather than PDFs."""
    
    obs_cols = [c for c in df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
    assert len(obs_cols) == 1, f'Expected one observed column, found {len(obs_cols)}'
    sim_cols = [c for c in df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    afdc_obs, afdc_sim = [], []

    for year in years:
        year_df = df[df.index.year == year]

        # Observed
        obs_values = year_df[obs_cols[0]].dropna().values
        if len(obs_values) > 0:
            sorted_obs = np.sort(obs_values)[::-1]  # descending
            afdc_obs.append(pd.Series(sorted_obs, name=f"{year}_obs"))

        # Simulated ensemble mean
        sim_ensemble = year_df[sim_cols].dropna(how='all')  # drop rows with all NaNs
        if not sim_ensemble.empty:
            sim_mean = sim_ensemble.mean(axis=1).dropna().values
            sorted_sim = np.sort(sim_mean)[::-1]
            afdc_sim.append(pd.Series(sorted_sim, name=f"{year}_sim"))

    # Align lengths: trim to shortest year
    min_len = min(len(s) for s in afdc_obs + afdc_sim)
    afdc_obs_trimmed = [s.iloc[:min_len].reset_index(drop=True) for s in afdc_obs]
    afdc_sim_trimmed = [s.iloc[:min_len].reset_index(drop=True) for s in afdc_sim]

    # Combine into DataFrames
    obs_df = pd.concat(afdc_obs_trimmed, axis=1)
    sim_df = pd.concat(afdc_sim_trimmed, axis=1)

    # Compute percentile summary
    afdc_summary = pd.DataFrame(index=np.arange(1, min_len + 1))
    afdc_summary[f'AFDC50_obs_{date}'] = obs_df.median(axis=1)
    afdc_summary[f'AFDC10_obs_{date}'] = obs_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_obs_{date}'] = obs_df.quantile(0.90, axis=1)

    afdc_summary[f'AFDC50_sim_{date}'] = sim_df.median(axis=1)
    afdc_summary[f'AFDC10_sim_{date}'] = sim_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_sim_{date}'] = sim_df.quantile(0.90, axis=1)

    afdc_summary.index.name = 'Rank'
    return afdc_summary


def compute_ensemble_pmfs(df, sim_cols, kde, da):
    """Compute the frequency mean PMF for the simulated ensemble."""
    sim_ensemble_pmfs = []
    for sim_col in sim_cols:
        sim_vals = df[sim_col].dropna().values
        assert len(sim_vals) > 0, f'No valid values found for {sim_col}'
        sim_pmf, _ = kde.compute(sim_vals, da=da)
        sim_ensemble_pmfs.append(pd.Series(sim_pmf, index=baseline_log_grid, name=sim_col))
    # concatenate all PMFs and compute the mean
    return pd.concat(sim_ensemble_pmfs, axis=1)

def compute_series_range(pmf, grid, threshold=1e-4):
    # Find indices where pmf is above the threshold
    valid_idx = np.where(pmf >= threshold)[0]
    if len(valid_idx) > 0:
        left_bound = grid[valid_idx[0]]
        right_bound = grid[valid_idx[-1]]
    else:
        # Fallback if all values are below threshold
        left_bound = grid[0]
        right_bound = grid[-1]
    return left_bound, right_bound


In [17]:
def plot_observed_and_simulated_pdf(stn, pmf_dfs, og_df, date, pdf_plots=[]):
    """Plot the observed and simulated PDFs for a given station."""

    baseline_lin_grid = np.exp(baseline_log_grid)
    lbs, rbs = [], []
    if date == '20250514':
        title = f'{stn}: LSTM PDFs Mean NSE Objective'
    elif date == '20250627':
        title = f'{stn}: LSTM PDFs 95% Quantile Objective'
    if len(pdf_plots) > 0:
        p = figure(title=title, x_axis_type='log', toolbar_location='above',
            width=800, height=350, x_range=pdf_plots[0].x_range,
            y_range=pdf_plots[0].y_range)
    else:
        p = figure(title=title, x_axis_type='log',toolbar_location='above',
            width=800, height=350)

    # plot the observed values as quad glyphs
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    observed_log_vals = np.log(observed_vals)
    min_q, max_q = observed_log_vals.min(), observed_log_vals.max()
    obs_log_dx = np.linspace(min_q - 0.1, max_q + 0.1, num=128)
    hist, edges = np.histogram(observed_log_vals, bins=obs_log_dx, density=True)
    edges = np.exp(edges)  # convert edges back to linear space
    # convert to probbility mass function (PMF)
    p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:],
            fill_color='dodgerblue', alpha=0.5, legend_label='Observed')
    
    df = pmf_dfs[date].copy()
    for i in range(10):
        sim_col = f'streamflow_sim_{i}_{date}'
        if sim_col in df.columns:
            vals = df[sim_col].values / log_dx
            lb, rb = compute_series_range(vals, baseline_lin_grid, threshold=1e-4)
            lbs.append(lb)
            rbs.append(rb)
            p.line(baseline_lin_grid, vals, color='grey', alpha=0.5, 
                    legend_label=f'LSTM Ensemble')

    # convert pmfs to pdfs
    obs_vals = df[f'POR_obs_{date}'].values / log_dx
    lb, rb = compute_series_range(obs_vals, baseline_lin_grid, threshold=1e-4)
    lbs.append(lb)
    rbs.append(rb)
    p.line(baseline_lin_grid, obs_vals, 
           color='black', line_width=2.5, legend_label=f'POR Observed', line_dash='dotted')
    sim_time_vals = df[f'POR_sim_timeEnsemble_{date}'].values / log_dx
    lb, rb = compute_series_range(sim_time_vals, baseline_lin_grid, threshold=1e-4)
    lbs.append(lb)
    rbs.append(rb)
    p.line(baseline_lin_grid, sim_time_vals, 
           line_width=2.5, color='green', legend_label=f'timeEnsemble', line_dash='dashed')
    sim_dist_vals = df[f'POR_sim_distEnsemble_{date}'].values / log_dx
    lb, rb = compute_series_range(sim_dist_vals, baseline_lin_grid, threshold=1e-4)
    lbs.append(lb)
    rbs.append(rb)
    p.line(baseline_lin_grid, sim_dist_vals, 
           line_width=2.5, color='green', legend_label=f'distEnsemble')
    
    clrs = ['orange', 'purple']
    lss = ['solid', 'dashed', 'dotted']
    cols = list(pmf_dfs.keys())
    
    for i, lb in enumerate(['LN', 'KNN']):
        model_cols = [c for c in cols if c.startswith(lb)]
        for j, mc in enumerate(model_cols):
            mc_pdf = pmf_dfs[mc] / log_dx
            lb, rb = compute_series_range(mc_pdf, baseline_lin_grid, threshold=1e-4)
            lbs.append(lb)
            rbs.append(rb)
            p.line(baseline_lin_grid, mc_pdf, line_color=clrs[i], line_dash=lss[j]
                    , legend_label=f'{mc}', line_width=2.5)
    
    p.x_range.start = np.min(lbs)
    p.x_range.end = np.max(rbs)

    p.xaxis.axis_label = 'Log Unit Area Runoff (L/s/km2)'
    p.xaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    p.yaxis.axis_label = 'Probability Density'
    p.legend.location = 'top_left'
    p.legend.click_policy = 'hide'
    p.add_layout(p.legend[0], 'left')
    return p, baseline_log_grid, baseline_lin_grid, log_dx


def plot_observed_and_simulated_fdc(stn, pmf_dfs, baseline_lin_grid, lstm_df, og_df, date, fdc_plots=[]):
    """Plot the observed and simulated FDCs for a given station."""

    # baseline_lin_grid = np.exp(baseline_log_grid)
    if date == '20250514':
        title = f'{stn}: LSTM FDCs Mean NSE Objective'
    elif date == '20250627':
        title = f'{stn}: LSTM FDCs 95% Quantile Objective'

    if len(fdc_plots) > 0:
        fdc_plot = figure(title=title, y_axis_type='log',
            width=600, height=350, x_range=fdc_plots[0].x_range,
            y_range=fdc_plots[0].y_range, toolbar_location='above')
    else:
        fdc_plot = figure(title=title, toolbar_location='above', 
                          width=600, height=350, y_axis_type='log')
        
    # plot the observed duration curve
    pcts = np.linspace(0.01, 0.99, 99)
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    obs_fdc = np.percentile(observed_vals, pcts * 100)[::-1]
    fdc_plot.line(pcts, obs_fdc, color='dodgerblue', line_width=2.5, legend_label=f'Observed')

    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    sim_fdcs = pd.DataFrame(index=pcts)
    for i, sim_col in enumerate(sim_cols):
        sim_vals = lstm_df[sim_col].dropna().values
        sim_fdc = np.percentile(sim_vals, pcts * 100)[::-1]
        sim_fdcs[f'LSTM Simulation {i+1}'] = sim_fdc
        fdc_plot.line(pcts, sim_fdc, color='grey', alpha=0.5, 
                      legend_label=f'LSTM Simulation')

    # compute the temporal ensemble mean FDC
    temporal_mean_fdc = lstm_df[sim_cols].mean(axis=1).dropna().values
    time_ensemble_fdc = np.percentile(temporal_mean_fdc, pcts * 100)[::-1]
    fdc_plot.line(pcts, time_ensemble_fdc, color='green', 
                  line_width=2.5, line_dash='dashed',
                  legend_label=f'Temporal Ensemble Mean')
    
    # compute the frequency ensemble mean FDC
    freq_ensemble_fdc = sim_fdcs.mean(axis=1).values
    fdc_plot.line(pcts, freq_ensemble_fdc, color='green', line_width=2.5, 
                  legend_label=f'Dist. Ensemble Mean')
    
    clrs = ['orange', 'purple']
    lss = ['solid', 'dashed', 'dotted']
    cols = list(pmf_dfs.keys())
    fdc_metrics = {}
    fdc_metrics['distEnsemble'] = float((freq_ensemble_fdc - obs_fdc).sum())
    fdc_metrics['timeEnsemble'] = float((time_ensemble_fdc - obs_fdc).sum())
    for i, lb in enumerate(['LN', 'KNN']):
        model_cols = [c for c in cols if c.startswith(lb)]
        for j, mc in enumerate(model_cols):
            mc_pmf = pmf_dfs[mc]
            # compute the cdf
            mc_cdf = np.cumsum(mc_pmf) 
            # interpolate between the percentiles to get the 1, 99 percentile values
            model_vals = np.interp(pcts, mc_cdf, baseline_lin_grid)
            fdc_metrics[f'{mc}'] = float((model_vals - obs_fdc).sum())
            fdc_plot.line(pcts, model_vals[::-1], line_color=clrs[i], line_dash=lss[j],
                    legend_label=f'{mc}', line_width=2.5)
    
    fdc_plot.xaxis.axis_label = 'Exceedance Probability (%)'
    fdc_plot.legend.background_fill_alpha = 0.5
    fdc_plot.add_layout(fdc_plot.legend[0], 'right')
    fdc_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    fdc_plot.legend.location = 'top_right'
    fdc_plot.legend.click_policy = 'hide'
    
    return fdc_plot, fdc_metrics

In [18]:
import sqlite3

def check_if_station_in_hydat(stn, conn):
    # Query to check if the station exists
    check_query = """
    SELECT STATION_NUMBER, STATION_NAME
    FROM STATIONS
    WHERE STATION_NUMBER = ?
    """
    # Run the query
    station_check = pd.read_sql_query(check_query, conn, params=(stn,))

    # Test if any result was returned
    if station_check.empty:
        return False
    else:
        return True
    
    
def query_data_symbols(conn):
    # Query all data symbols
    query = "SELECT SYMBOL_ID, SYMBOL_EN FROM DATA_SYMBOLS ORDER BY SYMBOL_ID"
    df = pd.read_sql_query(query, conn)
    return df.set_index('SYMBOL_ID')['SYMBOL_EN']


def reshape_hydat_wide(df):
    # First, ensure all FLOW_SYMBOL columns exist and are named correctly
    id_vars = ["STATION_NUMBER", "YEAR", "MONTH", "NO_DAYS"]
    
    # Melt flows
    flow_df = df.melt(id_vars=id_vars, 
                      value_vars=[f"FLOW{i}" for i in range(1, 32)],
                      var_name="day", value_name="flow")

    # Melt symbols
    sym_df = df.melt(id_vars=id_vars, 
                     value_vars=[f"FLOW_SYMBOL{i}" for i in range(1, 32)],
                     var_name="day", value_name="flow_symbol")

    # Extract day number
    flow_df["day"] = flow_df["day"].str.extract(r"(\d+)$").astype(int)
    sym_df["day"] = sym_df["day"].str.extract(r"(\d+)$").astype(int)

    # Merge on ID columns + day
    merged = pd.merge(flow_df, sym_df, on=id_vars + ["day"])

    # Construct date
    merged["date"] = pd.to_datetime(dict(year=merged["YEAR"], 
                                         month=merged["MONTH"], 
                                         day=merged["day"]), errors='coerce')

    # Filter out invalid days (e.g., day > NO_DAYS)
    merged = merged[merged["day"] <= merged["NO_DAYS"]]
    formatted_df = merged[["STATION_NUMBER", "date", "flow", "flow_symbol"]].dropna(subset=["flow"])
    formatted_df.set_index('date', inplace=True)
    return formatted_df


def query_hydat_database(stn):
    """Query the HYDAT database for a given station and date range."""
    hydat_path = Path('/home/danbot/code/common_data/HYDAT') / 'Hydat_20250415.sqlite3'
    # Connect to the database
    conn = sqlite3.connect(hydat_path)

    quality_symbols = query_data_symbols(conn)

    station_in_hydat = check_if_station_in_hydat(stn, conn)
    if station_in_hydat is False:
        print(f'Station {stn} not found in HYDAT database.')
        return pd.DataFrame(), quality_symbols
    
    base_columns = ["STATION_NUMBER", "YEAR", "MONTH"]
    flow_columns = [f"FLOW{i}, FLOW_SYMBOL{i}" for i in range(1, 32)]
    end_columns = ["NO_DAYS"]

    all_columns = base_columns + flow_columns + end_columns
    column_str = ",\n    ".join(all_columns)

    query = f"""
    SELECT
        {column_str}
    FROM DLY_FLOWS
    WHERE STATION_NUMBER = ?
    ORDER BY YEAR, MONTH;
    """
    df = pd.read_sql_query(query, conn, params=(stn,))

    if df.empty:
        print(f'No data found for {stn} in HYDAT.')
        return pd.DataFrame(), quality_symbols
    df = reshape_hydat_wide(df)
    return df, quality_symbols


def find_symbol_segments(symbol_df, target_symbol):
    """Return (start, end) date pairs for each continuous period of target_symbol."""
    # Filter for matching symbol only
    mask = (symbol_df['flow_symbol'] == target_symbol)
    dates = symbol_df['flow_symbol'].index[mask]

    if dates.empty:
        return []

    # Compute gaps in days between successive dates
    gaps = dates.to_series().diff().gt(pd.Timedelta(days=1)).fillna(True)

    # Group by contiguous regions (cumsum creates a new group after each gap)
    group_ids = gaps.cumsum()

    # Group by group ID and extract start and end of each contiguous block
    segments = [(group.min(), group.max()) for _, group in dates.to_series().groupby(group_ids)]

    return segments

In [19]:
def plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col):
    symbol_dict = quality_symbols.to_dict()
    symbol_colors = {
        'B': 'dodgerblue',  # Baseflow
        'D': 'firebrick',  # Dry weather flow
        'E': 'orange',  # Estimated
    }
    df['flow_symbol'] = hydat_df['flow_symbol'].reindex(df.index, method=None)
    uar_cols = obs_col + [c for c in df.columns if c.startswith('streamflow_sim')]

    for symbol in ['B', 'D', 'E']:
        description = symbol_dict.get(symbol, {})
        color = symbol_colors.get(symbol, 'gray')
        n_symbols = df['flow_symbol'].eq(symbol).sum()
        if n_symbols == 0:
            continue

        segments = find_symbol_segments(df[['flow_symbol']].copy(), symbol)

        for start, end in segments:
            runoff_plot.varea(
                x=pd.date_range(start, end),
                y1=0.98 * df[uar_cols].min().min(),  # get the min of the dataframe for the lower bound
                y2=1.02 * df[uar_cols].max().max(),  # a bit above max for visibility
                fill_color=color, fill_alpha=0.3,
                legend_label=f"{description} ({symbol})"
            )
    
    df['flow'] = hydat_df['flow'].reindex(df.index, method=None)
    df['hydat_uar'] = 1000 * df['flow'] / da_dict[stn]  # convert to unit area runoff (L/s/km2)
    runoff_plot.line(df.index, df['hydat_uar'],
                     color='dodgerblue', legend_label='HYDAT UAR', 
                     line_width=2, line_dash='dotted')
    return runoff_plot

In [20]:
def plot_zero_flow_periods(stn, df, runoff_plot, obs_col):

    color = 'salmon'
    n_symbols = (df['zero_flow_flag'] == True).sum()
    if n_symbols == 0:
        return runoff_plot

    df.rename({'zero_flow_flag': 'flow_symbol'}, inplace=True, axis=1)

    segments = find_symbol_segments(df[['flow_symbol']].copy(), True)
    for start, end in segments:
        runoff_plot.varea(
            x=pd.date_range(start, end),
            y1=0.98 * df[f'{stn}_uar'].min(),  # get the min of the dataframe for the lower bound
            y2=1.02 * df[f'{stn}_uar'].max(),  # a bit above max for visibility
            fill_color=color, fill_alpha=0.3,
            legend_label=f"Q=0 replaced"
        )
    runoff_plot.line(df.index, df[f'{stn}_uar'],
                     color='purple', legend_label='UAR', 
                     line_width=2, line_dash='dotted')
    return runoff_plot

In [21]:
def process_FDCs(df, stn, og_df, output_folder, result_folder):
    pmf_dfs, por_metrics, other_metrics = {}, {}, {}
    # load the parametric result
    parametric_fpath = result_folder / 'parametric' / f'{stn}_fdc_results.json'
    # if not os.path.exists(parametric_fpath):
    #     raise xception(f'Parametric results for {stn} not found at {parametric_fpath}')
    with open(parametric_fpath, 'r') as f:
        parametric_results = json.load(f)
    other_metrics[f'LN_MLE'] = parametric_results['MLE']['eval']
    other_metrics[f'LN_PredictedLog'] = parametric_results['PredictedLog']['eval']
    other_metrics[f'LN_PredictedMOM'] = parametric_results['PredictedMOM']['eval']
    pmf_dfs['LN_MLE'] = parametric_results['MLE']['pmf']
    pmf_dfs['LN_PredictedLog'] = parametric_results['PredictedLog']['pmf']
    pmf_dfs['LN_PredictedMOM'] = parametric_results['PredictedMOM']['pmf']

    knn_fpath = result_folder / 'knn' / f'{stn}_fdc_results.json'
    with open(knn_fpath, 'r') as f:
        knn_results = json.load(f)
    knn_cols = list(knn_results.keys())
    nn2 = [c for c in knn_cols if '2_NN_attribute_dist_ID2_freqEnsemble' in c]
    nn4 = [c for c in knn_cols if '4_NN_attribute_dist_ID2_freqEnsemble' in c]
    nn8 = [c for c in knn_cols if '8_NN_attribute_dist_ID2_freqEnsemble' in c]
    nn2_pmf = knn_results[nn2[0]]['pmf']
    nn4_pmf = knn_results[nn4[0]]['pmf']
    nn8_pmf = knn_results[nn8[0]]['pmf']
    pmf_dfs['KNN_2_NN'] = nn2_pmf
    pmf_dfs['KNN_4_NN'] = nn4_pmf
    pmf_dfs['KNN_8_NN'] = nn8_pmf
    # for m in ['nse', 'kld', 'rmse', 'relative_error', 'emd', 'kge']:
    other_metrics[f'KNN_2_NN'] = knn_results[nn2[0]]['eval']
    other_metrics[f'KNN_4_NN'] = knn_results[nn4[0]]['eval']
    other_metrics[f'KNN_8_NN'] = knn_results[nn8[0]]['eval']

    kde = KDEEstimator(baseline_log_grid, log_dx)
    eval_object = EvaluationMetrics(baseline_log_grid, log_dx)
    print(f'    Processing FDCs for {stn}')
    
    years = []
    for date in ['20250514', '20250627']:
        pmf_columns = []
        years = df.index.year.unique()
        da = da_dict[stn]

        pmf_fpath = f'data/results/lstm_pmfs/POR_{stn}_pmfs_{len(years)}_years_{date}.csv'
        # if os.path.exists(pmf_fpath):
        #     print(f'    PMFs for {stn} already exist, skipping.')
        #     continue

        # compute observed POR PMFs 
        obs_cols = [c for c in df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
        assert len(obs_cols) == 1, f'Expected exactly one observed column for {stn} on {date}, found {len(obs_cols)}'
        por_obs_vals = df[obs_cols[0]].dropna().values
        # print(f'{stn} POR obs: {min(por_obs_vals):.2f} - {max(por_obs_vals):.2f}')
        por_obs_pmf, por_obs_pdf = kde.compute(por_obs_vals, da=da)
        pmf_columns.append(pd.Series(por_obs_pmf, index=baseline_log_grid, name=f'POR_obs_{date}'))

        # compute simulated POR PMFs
        sim_cols = [c for c in df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
        # temporal mean of the simulated ensemble
        sim_vals = df[sim_cols].mean(axis=1).dropna().values
        sim_pmf, sim_pdf = kde.compute(sim_vals, da=da) # temporal mean ensemble PMF
        pmf_columns.append(pd.Series(sim_pmf, index=baseline_log_grid, name=f'POR_sim_timeEnsemble_{date}'))

        frequency_sim_pmfs = compute_ensemble_pmfs(df, sim_cols, kde, da)
        mean_pmf = frequency_sim_pmfs.mean(axis=1).rename('sim_distEnsemble_mean')
        mean_pmf /= mean_pmf.sum()  # renormalize the PMF
        pmf_columns.append(pd.Series(mean_pmf, index=baseline_log_grid, name=f'POR_sim_distEnsemble_{date}'))

        og_vals = og_df[f'{stn}_uar'].dropna().values
        # print(f'{stn}_uar: {min(og_vals):.2f} - {max(og_vals):.2f}')
        og_pmf, og_pdf = kde.compute(og_vals, da=da)
        pmf_columns.append(pd.Series(og_pmf, index=baseline_log_grid, name=f'{stn}_uar'))
        por_pmfs = pd.concat(pmf_columns, axis=1)    # do the same for the original timeseries
        pmfs = pd.concat([por_pmfs, frequency_sim_pmfs], axis=1)  # initialize PMFs DataFrame with observed PMFs
        pmfs.index.name = 'log_uar'  # set the index name for clarity
        # save the PMFs to a CSV file
        pmfs.to_csv(pmf_fpath)
        pmf_dfs[date] = pmfs

        # evaluate metrics
        for col in [f'POR_sim_timeEnsemble_{date}', f'POR_sim_distEnsemble_{date}']:
            pmf = pmfs[col].values
            por_metrics[col] = eval_object._evaluate_fdc_metrics_from_pmf(pmf, por_obs_pmf)

        for col in frequency_sim_pmfs.columns:
            pmf = frequency_sim_pmfs[col].values
            por_metrics[col] = eval_object._evaluate_fdc_metrics_from_pmf(pmf, por_obs_pmf)

    mdf = pd.DataFrame(por_metrics).T
    odf = pd.DataFrame(other_metrics).T
    # save to csv
    metric_fpath = f'data/results/lstm_metrics/{stn}_{len(years)}_years_metrics.csv'
    mdf.to_csv(metric_fpath)
    return mdf, odf, pmf_dfs


### Dams

* 12398000 - Sullivan Lake upstream is a reservoir
* 12058800 - Dam!

### Other anomalous conditions

* 12143700 -  *Small catchment adjacent to dam* (USGS.gov) No regulation or diversion upstream from station. **Flow is mostly seepage from Chester Morse Lake.** U.S. Geological Survey satellite telemeter at station.  This represnts a very specific scenario that depends upon hydraulic conditions in neighbouring catchments but also the operation policy if the neighbouring catchment contains a large reservoir.


In [25]:
model_label_styles = {
    "timeEnsemble":       ("#2ca25f", "white"),  # green
    "distEnsemble":       ("#2ca25f", "white"),
    "LN MLE":             ("#fdae6b", "black"),  # orange
    "LN_PredictedLog":    ("#fdae6b", "black"),
    "LN_PredictedMOM":    ("#fdae6b", "black"),
    "KNN_2_NN":           ("#756bb1", "white"),  # purple
    "KNN_4_NN":           ("#756bb1", "white"),  # purple
    "KNN_8_NN":           ("#756bb1", "white"),
}


def key_mapper(mod,met):
    
    if met == 'pct_vol_bias':
        met = 'RB'
    elif met == 'mean_error':
        met = 'MB'
    elif met == 'pmf':
        met = 'VB_PMF'
    elif met == 'fdc':
        met = 'VB_FDC'
    elif met == 'diff':
        met = 'mean_frac_diff'

    met = met.upper()
    if mod.startswith('time'):
        key = f'LSTM time {met}'
    elif mod.startswith('dist'):
        key = f'LSTM dist. {met}'
    elif 'MOM' in mod:
        key = f'LN MoM {met}'
    elif 'PredictedLog' in mod:
        key = f'LN Direct {met}'
    elif 'MLE' in mod:
        key = f'MLE {met}'
    elif mod.endswith('_NN'):
        k = mod.split('_')[1]
        key = f'{k} kNN {met}'
    else:
        raise Exception(f'{mod} {met} combo not found')
    return key


def plot_runoff_timeseries(stn, lstm_df, og_df, date):

    obs_col = [c for c in lstm_df.columns if '_obs_' in c and c.endswith(date)]
    assert len(obs_col) == 1, f'Expected exactly one observed column for {stn}, found {len(obs_col)} {obs_col}'
    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected ten simulation columns for {stn}, found {len(sim_cols)}'
    # plot th time series of the observed values
    runoff_plot = figure(title=f'{stn} Observed Unit Area Runoff', x_axis_type='datetime',
                         width=800, height=300, y_axis_type='log', toolbar_location='above')
    
    hydat_df, quality_symbols = query_hydat_database(stn)

    # reindex to daily frequency and keep nans
    df = lstm_df.copy().reindex(pd.date_range(start=lstm_df.index.min(), end=lstm_df.index.max(), freq='D'))

    if not hydat_df.empty:
        runoff_plot = plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col)

    # plot zero flow segments on the runoff_plot
    runoff_plot = plot_zero_flow_periods(stn, og_df, runoff_plot, obs_col)
    # plot the adjusted line
    runoff_plot.line(df.index, df[obs_col[0]], color='dodgerblue',
                     legend_label='Observed UAR', line_width=2.)
    for col in sim_cols:
        runoff_plot.line(df.index, df[col], color='grey', alpha=0.5,
                         legend_label=f'LSTM ensemble')
    # compute the temporal mean of the simulated ensemble
    mean_sim = df[sim_cols].mean(axis=1)
    runoff_plot.line(df.index, mean_sim, color='green', legend_label='Ensemble Mean', line_width=3, line_dash='dashed')
    runoff_plot.xaxis.axis_label = 'Date'
    runoff_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    runoff_plot.legend.location = 'top_left'
    runoff_plot.legend.click_policy = 'hide'
    runoff_plot.add_layout(runoff_plot.legend[0], 'left')
    runoff_plot.legend.background_fill_alpha = 0.65
    return runoff_plot

def format_metrics_table(metric_df, odf, stn, date, fdc_metrics):
    # Prep LSTM metrics
    metric_df.index.name = 'series'
    metric_df = metric_df.reset_index()
    # reverse_cols = ['kge', 'nse', 've']
    
    lstm_df = metric_df[metric_df['series'].str.endswith(date)].copy()
    metric_cols = [c for c in lstm_df.columns if c not in ['series', 'date']]    
    time_df = lstm_df[lstm_df['series'].str.contains('timeEnsemble')][metric_cols]
    freq_df = lstm_df[lstm_df['series'].str.contains('distEnsemble')][metric_cols]
    sim_df = lstm_df[lstm_df['series'].str.startswith('streamflow_sim_')][metric_cols]
    # sim_perc = pd.DataFrame(np.percentile(sim_df.values, [5, 95], axis=0), index=['5%', '95%'], columns=metric_cols)
    
    table = pd.concat([time_df, freq_df])#, sim_perc])
    table.index = ['timeEnsemble', 'distEnsemble']#, '5%', '95%']
    # table = table.loc[~table.index.isin(['5%', '95%'])]  # drop percentiles
    # table.columns = [c.split('_')[-1] for c in table.columns]
    
    # Prep ODF
    mapper = {'mean_frac_diff': 'diff', 'vb_pmf': 'pmf', 'vb_fdc': 'fdc',
             'pct_vol_bias': 'rb', 'mean_error': 'mb'}
    odf = odf.rename(columns=mapper)
    table = table.rename(columns=mapper)
    max_cols = ['nse', 'kge', 've']
    min_abs_cols = ['pmf', 'fdc', 'diff', 'vb_pmf', 'vb_fdc', 'error', 'pct_vol_bias', 'mean_error', 
                   'mean_abs_rel_error', 'rb', 'mb']

    # Combine and compute percentiles
    df = pd.concat([table, odf], axis=0)
    df.drop(labels=['median_error'], axis=1, inplace=True)
        
    colors = [
        "#2166ac", "#4393c3", "#92c5de", "#d1e5f0", "#f7f7f7",
        "#fddbc7", "#f4a582", "#d6604d", "#b2182b", "#67001f"
    ]
    # Build HTML table with row labels
    html = '<table style="width:100%; border-collapse:collapse; margin-top:25px;">'
    html += '<thead><tr>'
    table_cols = {'pmf': 'pmfBias', 'fdc': 'fdcBias', 'diff': 'fdcpErr', 'rb': 'RB', 'mb': 'MB'}
    col_order = ['rb', 'mb', 'rmse', 'kld', 'emd', 'nse', 'kge', 'pmf', 'fdc', 'diff']
    df = df[col_order]
    for c in df.columns:
        label = c
        if c in table_cols:
            label = table_cols[c]
        html += f'<th>{label}</th>'

    html += '<th style="text-align:left;">Model</th></tr></thead><tbody>'      

    for model, row in df.iterrows():
        html += f'<tr>'
        for metric in df.columns:
            key = key_mapper(model, metric)
            val = df.at[model, metric]
            global_vals = all_results_df[key].dropna().values
            minv, maxv = np.min(global_vals), np.max(global_vals)
            
            if metric in min_abs_cols:
                rank = np.searchsorted(np.sort(np.abs(global_vals)), abs(val), side="right")
            else:
                global_vals = np.sort(global_vals)
                rank = np.searchsorted(global_vals, val, side="right")
            
            pct = rank / len(global_vals)
            if metric in max_cols:
                pct = 1 - pct
            idx = min(math.floor(pct * 10), 9)
            bg = colors[idx]
            font = "white" if idx in {0, 1, 8, 9} else "black"
            
            # html += f'<td style="background:{color};text-align:center;"><span style="color:{font};">{row[c]:.2f}</span></td>'
            html += f'<td style="background:{bg}; text-align:center; padding:4px 6px;"><span style="color:{font};">{row[metric]:.2f}</span></td>'
        
        row_bg, row_font = model_label_styles.get(model, ("#ffffff", "black"))
        html += f'<td style="background:{row_bg}; color:{row_font}; text-align:left; font-weight:bold;padding:4px 4px;">{model}</td>'
        # html += f'<td style="text-align:left; font-weight:bold; padding:4px 4px;">{model}</td>'
    html += '</tbody></table>'
    # Add color legend
    legend_html = '<table style="border-collapse:collapse; margin-bottom:4px;"><caption>Rank percentile colour map (N=720):</caption><tr>'
    for i, color in enumerate(colors):
        pct_label = f"{(i+1)*10}%"
        font = "white" if i in {0, 1, 8, 9} else "black"
        legend_html += f'<td style="background:{color}; padding:4px 4px; text-align:center;"><span style="color:{font}">{pct_label}</span></td>'
    legend_html += '</tr></table>'

    return Div(text=html), Div(text=legend_html)

In [27]:
# retrieve LSTM ensemble predictions
# filter for the common stations
common_stations = list(set(station_ids) & set(lstm_result_stns))
stations_to_process = ['08AA008', '12090500', '12102190', '12115700', '12115800',
       '12157025', '12201960', '10CD003', '10CD004', '12447383',
       '08HB075', '10ED009', '08HA026', '12202310', '10CD005', '12202420',
       '12210900', '12193500', '12036650', '07BB003']

# print(f'There are {len(common_stations)} monitored basins with LSTM ensemble results.')
attr_df = attr_df[attr_df['official_id'].isin(stations_to_process)]

output_folder = BASE_DIR / 'data' / 'results' /  'lstm_plots'
result_folder = BASE_DIR / 'data' / 'results' /'fdc_estimation_results'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# plots = []
# excluded = ['10AA002'] 
dam_sites = ['12398000', '12058800', '12143700', '12323760'] 
process_plots = True
if process_plots:
    for stn in common_stations:
        if stn in dam_sites:
            continue
        output_fname = output_folder / f'{stn}_fdc.html'
        # if stn in excluded:
        #     continue
        # if os.path.exists(output_fname):
        #     print(output_fname)
        #     continue
    
        og_df = get_original_timeseries(stn, ds)
        lstm_ensemble_df = filter_by_complete_years(stn, lstm_result_base_folder)

        if lstm_ensemble_df.empty:
            print(f'No complete years found for {stn}. Skipping.')
            continue
        og_df = og_df[og_df.index.isin(lstm_ensemble_df.index)]
        mdf, odf, pmf_dfs = process_FDCs(lstm_ensemble_df, stn, og_df, output_folder, result_folder)

        dates = list(pmf_dfs.keys())
        pdf_plots, fdc_plots, metric_tables, other_tables = [], [], [], []
        date = '20250514'
        
        pdf_plot, baseline_log_grid, baseline_lin_grid, log_dx = plot_observed_and_simulated_pdf(stn, pmf_dfs, og_df, date, pdf_plots=pdf_plots)
        pdf_plots.append(pdf_plot)
        
        fdc_plot, fdc_metrics = plot_observed_and_simulated_fdc(stn, pmf_dfs, baseline_lin_grid, lstm_ensemble_df, og_df, date, fdc_plots=fdc_plots)
        fdc_plots.append(fdc_plot)
        
        ts_plot = plot_runoff_timeseries(stn, lstm_ensemble_df, og_df, dates[-1])
        
        metric_table, legend_html = format_metrics_table(mdf, odf, stn, date, fdc_metrics)

        table_div = column(
            metric_table,
            legend_html,
        )
        notes_div = Div(text=notes_html, width=1000)
        layout = column(
            row([ts_plot, table_div]), 
            row(pdf_plots[0], fdc_plots[0]),
            row(notes_div),
            )
        # save the plot to an HTML file
        output_fname = output_folder / f'{stn}_fdc.html'
        output_file(output_fname, title=f'{stn} FDCs')
        save(layout)
        print(f'    Saved plot for {stn} to {output_fname}')
    

    Processing FDCs for 10FA002
    Saved plot for 10FA002 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/10FA002_fdc.html
    Processing FDCs for 08HE009
    Saved plot for 08HE009 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/08HE009_fdc.html
    Processing FDCs for 08NA012
    Saved plot for 08NA012 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/08NA012_fdc.html
    Processing FDCs for 15081580
Station 15081580 not found in HYDAT database.
    Saved plot for 15081580 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/15081580_fdc.html
    Processing FDCs for 12010000
Station 12010000 not found in HYDAT database.
    Saved plot for 12010000 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/12010000_fdc.html
    Processing FDCs for 10AA004
    Saved plot for 10AA004 to /home/danbot/code/distribution_estimatio