### Generate Diagnostic Pages 

In this section:

1. We assemble prediction outputs, performance metrics, and metadata needed to generate station-level diagnostic pages.

2. We structure these data into standardized formats for plotting flow duration curves, prediction–observation comparisons, and model-specific diagnostics.

3. We create and export visual summaries for each catchment, highlighting where models align, where they diverge, and how these patterns relate to observed streamflow behaviour.

4. We prepare these diagnostic products for inclusion in the broader reporting framework and for use in station-level interpretation.

In [None]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
from multiprocessing import Pool

# import geopandas as gpd
# from shapely.geometry import Point
import math
import xyzservices.providers as xyz
import json
import sqlite3

from bokeh.plotting import figure, show, output_file, save
from bokeh.layouts import gridplot, row, column
from utils.fdc_estimator_context import FDCEstimationContext
# from utils.knn_estimator import kNNEstimator
from utils.kde_estimator import KDEEstimator
from utils.fdc_data import StationData
from utils.evaluation_metrics import EvaluationMetrics

from bokeh.io import output_notebook
from bokeh.models import Div

import utils.data_processing_functions as dpf

import xyzservices.providers as xyz
# Tile provider entry 'USGS' -> 'USTopo' from the xyzservices provider catalogue.
tiles = xyz['USGS']['USTopo']

# Send Bokeh output to the notebook.
output_notebook()

In [None]:
BASE_DIR = Path(os.getcwd())

from utils.table_notes import notes_html

# Catchment attributes for the study region.  Station identifiers must be read
# as text so that numeric-looking IDs keep their leading zeros.  The column
# used below is 'official_id'; the 'Official_ID' key is kept as well so either
# spelling of the header is covered (a dtype key with no matching column is
# ignored by pandas, which is how the original call behaved).
attr_fpath = 'data/BCUB_watershed_attributes_updated_20250227.csv'
attr_df = pd.read_csv(attr_fpath, dtype={'Official_ID': str, 'official_id': str})
station_ids = sorted(attr_df['official_id'].unique().tolist())

# streamflow folder from (updated) HYSETS
HYSETS_DIR = Path('/home/danbot/code/common_data/HYSETS')
# Read the HYSETS identifiers as text too, so the membership test below
# compares strings with strings.
hs_df = pd.read_csv('data/HYSETS_watershed_properties.txt', sep=';', dtype={'Official_ID': str})
hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]
hs_df.head(2)

In [None]:
# Locate the LSTM ensemble outputs; the station list is taken from the first
# entry in `results_revisions`.
lstm_result_base_folder = Path('/home/danbot/code/neuralhydrology/data/')
results_revisions = ['20250514', '20250627']
lstm_result_files = os.listdir(
    lstm_result_base_folder / f'ensemble_results_{results_revisions[0]}'
)
# The station ID is the part of each file name before the first underscore.
lstm_result_stns = [fname.split('_')[0] for fname in lstm_result_files]

# Keep only stations present in both the BCUB region and the LSTM results
# (LSTM-compatible, i.e. 1980-).
daymet_concurrent_stations = list(set(station_ids) & set(lstm_result_stns))
# assert '012414900' in daymet_concurrent_stations
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')


In [None]:
# Stations left out of the analysis (see notebook 1 for details on how these
# were identified).  Written as a whitespace-separated block and split into a
# list of ID strings.
exclude_stations = """
    08FA009 08GA037 08NC003 12052500 12090480 12107950 12108450 12119300
    12119450 12200684 12200762 12203000 12409500 15056070 15081510
    12323760 12143700 12143900 12398000 12058800 12137800 12100000
""".split()

In [None]:
# Lookup tables built column-wise from the filtered HYSETS properties table.
# Watershed_ID -> Official_ID
watershed_id_dict = dict(zip(hs_df['Watershed_ID'], hs_df['Official_ID']))
# Official_ID -> Watershed_ID (the inverse mapping)
official_id_dict = dict(zip(hs_df['Official_ID'], hs_df['Watershed_ID']))
# Official_ID -> drainage area (Drainage_Area_km2 column)
da_dict = dict(zip(hs_df['Official_ID'], hs_df['Drainage_Area_km2']))

In [None]:
# PMF settings: regularization flavour and quantization bitrate.
regularization_type = 'discrete'
bitrate = 8

baseline_folder = BASE_DIR.joinpath('data', 'baseline_distributions', f'{bitrate:02d}_bits')
# Select the baseline PMF file that matches the regularization type.
if regularization_type == 'kde':
    print('Using KDE regularization for PMF estimation.')
    fname = 'pmf_kde.csv'
else:
    fname = 'pmf_obs.csv'

pmf_path = baseline_folder / fname
pmf_obs_df = pd.read_csv(pmf_path)
# Drop excluded stations and stations with no column in the baseline PMF table.
daymet_concurrent_stations = [
    stn for stn in daymet_concurrent_stations
    if not (stn in exclude_stations or stn not in pmf_obs_df.columns)
]
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')

# log_edges = np.concatenate([pmf_obs_df['left_log_edges'].values[:1], pmf_obs_df['right_log_edges'].values])


In [None]:
# def load_and_filter_hysets_data(station_ids, hs_df):
#     hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]

#     # load the updated HYSETS data
#     updated_filename = 'HYSETS_2023_update_QC_stations.nc'
#     ds = xr.open_dataset(HYSETS_DIR / updated_filename)

#     # Get valid IDs as a NumPy array
#     selected_ids = hs_df['Watershed_ID'].values

#     # Get boolean index where watershedID in selected_set
#     # safely access watershedID as a variable first
#     ws_ids = ds['watershedID'].data  # or .values if you prefer
#     mask = np.isin(ws_ids, selected_ids)

#     # Apply mask to data
#     ds = ds.sel(watershed=mask)
#     # Step 1: Promote 'watershedID' to a coordinate on the 'watershed' dimension
#     ds = ds.assign_coords(watershedID=("watershed", ds["watershedID"].data))

#     # Step 2: Set 'watershedID' as the index for the 'watershed' dimension
#     return ds.set_index(watershed="watershedID")


# ds = load_and_filter_hysets_data(station_ids, hs_df)

In [None]:
def split_knn_label_col(df):
    """Expand the underscore-delimited ``Label`` column into one column per field.

    Every label is expected to split into the same number of parts, matching
    ``format_cols`` below.  The parsed fields are appended to a copy of ``df``
    (with a fresh 0..n-1 index); the helper fields ``NN`` and ``dist`` and the
    columns ``n_parts``, ``minYears`` and ``minOverlapPct`` are dropped when
    present.  If a parsed field duplicates an existing column name, the
    existing (first) column is kept.

    Unlike the previous version, the input frame is left untouched: no
    ``n_parts`` column is added to the caller's data.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a string column ``Label``.

    Returns
    -------
    pandas.DataFrame
        New frame with the additional parsed columns.  All parsed values are
        strings (including ``k``).

    Raises
    ------
    AssertionError
        If the labels do not all have the same number of parts.
    ValueError
        If the common number of parts differs from the expected layout.
    """
    # format_a_cols = ["Official_ID", "k", "NN", 'concurrent', 'tree_type', 'dist', 'weighting', 'ensemble_method']
    format_cols = ["Official_ID", "k", "NN", 'tree_type', 'dist', 'ensemble_weight', 'ensemble_method']

    # Work on a local Series rather than writing a helper column into `df`.
    n_parts = df['Label'].str.split('_').str.len()

    # Explicit raise (same exception type as the former assert) so the check
    # is not stripped when Python runs with -O.
    if len(set(n_parts)) != 1:
        raise AssertionError("Not all labels have the same number of parts")

    matches = n_parts == len(format_cols)
    if not matches.all():
        raise ValueError(
            f'Expected {len(format_cols)} label parts, found {n_parts.iloc[0]}'
        )

    df_a = df[matches]
    df_a_split = df_a['Label'].str.split('_', expand=True)
    df_a_split.columns = format_cols

    # Align positionally: both halves get a fresh 0..n-1 index.
    merged = pd.concat(
        [df_a.reset_index(drop=True), df_a_split.reset_index(drop=True)], axis=1
    )

    # Remove helper fields, then resolve name clashes by keeping the first column.
    merged = merged.drop(
        columns=['NN', 'dist', 'n_parts', 'minYears', 'minOverlapPct'], errors='ignore'
    )
    return merged.loc[:, ~merged.columns.duplicated()]

In [None]:
# Folder holding the LSTM ensemble predictions used for the benchmark.
# The 20250514 revision uses NSE mean as loss function.
LSTM_ensemble_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250514'
# Alternative revision, uses NSE 95% as loss function:
# LSTM_ensemble_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250627'
lstm_result_files = os.listdir(LSTM_ensemble_result_folder)
# Station ID = file-name text before the first underscore.
lstm_result_stns = [fname.partition('_')[0] for fname in lstm_result_files]


### Compute the NSE on the daily timeseries for evaluation over the sample

Compute the distribution of NSE values for the LSTM ensembles to report as a benchmark in the results.

In [None]:
def compute_NSE_for_lstm(obs, sim):
    """Return the Nash-Sutcliffe efficiency of ``sim`` against ``obs``.

    Computed as one minus the ratio of the summed squared errors to the summed
    squared deviations of the observations from their mean.
    """
    squared_error = np.sum(np.square(obs - sim))
    obs_variability = np.sum(np.square(obs - np.mean(obs)))
    return 1 - (squared_error / obs_variability)

In [None]:
# One NSE value per non-excluded station, computed on the daily series.
NSE_vals = []
for f in lstm_result_files:
    stn = f.split('_')[0]
    if stn not in exclude_stations:
        ldf = pd.read_csv(Path(LSTM_ensemble_result_folder) / f)
        # ensemble mean time series across all '_sim_' member columns
        ldf['qsim_mean'] = ldf[[c for c in ldf.columns if '_sim_' in c]].mean(axis=1)
        # keep only days where both observation and ensemble mean exist
        ldf = ldf.dropna(subset=['streamflow_obs', 'qsim_mean'])
        nse = compute_NSE_for_lstm(ldf['streamflow_obs'].values, ldf['qsim_mean'].values)
        NSE_vals.append(nse)

In [None]:
# Empirical CDF of the per-station NSE values, drawn with bokeh.
p = figure(
    title='Distribution of NSE values for LSTM ensemble',
    x_axis_label='NSE',
    y_axis_label='Frequency',
    width=600,
    height=400,
)
# sorted NSE values on x, cumulative fraction of stations on y
x = np.sort(NSE_vals)
assert np.isfinite(x).all()
y = np.arange(1, len(x) + 1) / len(x)
p.line(x, y, line_width=2, color='dodgerblue', legend_label='LSTM Ensemble')
# show(p)

In [None]:
# Summary statistics of the NSE sample: centre, 2.5-97.5 percentile range, and
# the share of stations with NSE <= 0.
mean_nse = np.mean(NSE_vals)
median_nse = np.median(NSE_vals)
ci_nse = (np.percentile(NSE_vals, 2.5), np.percentile(NSE_vals, 97.5))
n_failures = np.sum(np.array(NSE_vals) <= 0)
pct_failures = n_failures / len(NSE_vals) * 100
print(
    f'Mean NSE: {mean_nse:.2f}, Median NSE: {median_nse:.2f}, '
    f'95% CI: ({ci_nse[0]:.2f}, {ci_nse[1]:.2f}), '
    f'% Failures: {n_failures}/{len(NSE_vals)} {pct_failures:.2f}%'
)

In [None]:
# Load the per-method result tables.  Each method is read from a consolidated
# CSV when one exists; otherwise the per-station files are loaded in parallel,
# validated, and written out as that consolidated CSV for next time.
results_dfs = {}
lstm_rev_date = LSTM_ensemble_result_folder.split('_')[-1]
rev_date = lstm_rev_date  # kept under its former name for later cells
sub_folder = f'lstm_{lstm_rev_date}'
# results_folder = '/media/danbot/Samsung_T5/fdc_estimation_results/'
results_folder = f'data/results/fdc_estimation_results_{bitrate:02d}_bits'
if regularization_type == 'kde':
    results_folder = 'data/results/fdc_estimation_results_kde'

# De-duplicate station IDs: a station with several files in the folder must be
# loaded only once, otherwise its rows are repeated in the combined table.
# Sorting makes the load (and the saved CSV) order deterministic.
completed_stns = sorted(
    {c.split('_')[0] for c in os.listdir(os.path.join(results_folder, sub_folder))}
)
print(f'Found {len(set(completed_stns))} completed stations in {sub_folder} results folder.')

additional_results_dir = os.path.join('data', 'results', 'additional_results')
results_suffix = 'kde' if regularization_type == 'kde' else f'{bitrate:02d}_bits'

for method in ['parametric', 'lstm', 'knn']:
    print(f'   Loading {method} results')
    # LSTM caches are additionally tagged with the ensemble revision date.
    cache_stem = f'{method}_all_results_{results_suffix}'
    if method == 'lstm':
        cache_stem = f'{cache_stem}_{lstm_rev_date}'
    method_results_fpath = os.path.join(additional_results_dir, f'{cache_stem}.csv')

    if os.path.exists(method_results_fpath):
        results_dfs[method] = pd.read_csv(method_results_fpath, dtype={'Official_ID': str})
        print(f'   Loaded {len(results_dfs[method])} {method} results from {method_results_fpath}')
        continue

    print(f'   {method} results not found in {method_results_fpath}, loading from individual station files...')
    res_folder = os.path.join(results_folder, sub_folder if method == 'lstm' else method)
    args = [(stn, res_folder, method) for stn in completed_stns]
    with Pool() as pool:
        results_list = pool.map(dpf.load_results, args)

    # Concatenate once and validate the combined table before caching it.
    method_results = pd.concat(results_list, ignore_index=True)
    bad_dkl = method_results[method_results['KLD'].isna() | (method_results['KLD'] < 0)]
    if not bad_dkl.empty:
        print(f'Warning: {len(bad_dkl)} {method} rows with NaN or negative DKL values.')
        bad_stns = bad_dkl['Official_ID'].values
        raise Exception(f'Results have {len(bad_stns)} NaN or negative DKL values: {bad_stns}')

    results_dfs[method] = method_results
    print(f'   Loaded {int(len(results_dfs[method])/len(set(completed_stns)))} station results for {method} results')
    method_results.to_csv(method_results_fpath, index=False)

In [None]:
# Parametric and LSTM results stacked into one table (rows keep their
# original index values).
fdc_df = pd.concat([results_dfs['parametric'], results_dfs['lstm']], axis=0)
# fdc_df = results_dfs['parametric'].copy()
# Sorted distinct model labels present in the combined table.
model_labels = sorted(set(fdc_df['Label']))

parametric_targets = list(set(results_dfs['parametric']['Label'].values))
# Expand the kNN label string into separate descriptor columns.
results_dfs['knn'] = split_knn_label_col(results_dfs['knn'])

In [None]:
def get_result_and_ids(label, metric):
    """Return the metric values and station IDs for one model label in ``fdc_df``.

    Rows with a missing ``metric`` are discarded.  For 'NSE' and 'KGE' the
    returned values are ``1 - metric``; other metrics are returned unchanged.
    The IDs are returned as the ``Official_ID`` Series of the selected rows
    (original index preserved).
    """
    subset = fdc_df.loc[fdc_df['Label'] == label].dropna(subset=[metric])
    scores = subset[metric].values
    if metric == 'NSE' or metric == 'KGE':
        # flip so that the upper bound (1) maps to 0
        scores = 1 - scores
    return scores, subset['Official_ID']


def get_knn_group_results(tree_type='attribute', ensemble_type='freqEnsemble', weighting='ID2', k=7, which_set='knn'):
    """Select the rows of one kNN configuration from ``results_dfs[which_set]``.

    Parameters
    ----------
    tree_type : str
        Value to match in the ``tree_type`` column.  Previously this argument
        was ignored and 'attribute' was always used; the default keeps that
        behaviour for existing callers.
    ensemble_type : str
        Value to match in the ``ensemble_method`` column.
    weighting : str
        Value to match in the ``ensemble_weight`` column.
    k : int or str
        Compared as ``str(k)`` against the ``k`` column, which holds strings
        parsed from the label.
    which_set : str
        Key of the result table in ``results_dfs``.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the matching rows.
    """
    data = results_dfs[which_set]
    selected = (
        (data['tree_type'] == tree_type)
        & (data['ensemble_method'] == ensemble_type)
        & (data['ensemble_weight'] == weighting)
        & (data['k'] == str(k))
    )
    return data[selected].copy()

In [None]:
main_result_vals = {}
all_metrics = ['KLD', 'EMD', 'RMSE', 'PB', 'NSE', 'KGE', 'VE', 'NAE', 'MAPE']
tree_type = 'attribute'
# Define metrics that are naturally "higher is better" and need inversion to match "lower is better"
invert_metrics = {'NSE', 'KGE', 'VE'}

def invert_if_needed(values, dm, assume_raw=False):
    """Ensure metric values follow good = 0, bad = large positive.

    Parameters
    ----------
    values : array-like
        Metric values.
    dm : str
        Metric name; only members of ``invert_metrics`` are ever transformed.
    assume_raw : bool, default False
        When True the values are known to be un-inverted and ``1 - values``
        is always returned for an invertible metric.  When False the legacy
        heuristic is used: invert only if the smallest value is negative.
        That heuristic leaves a metric un-inverted whenever every value is
        >= 0, so callers that know their input is raw should pass True.
    """
    values = np.asarray(values)
    if dm not in invert_metrics:
        return values
    if assume_raw or np.nanmin(values) < 0.:  # heuristic assumes values ~ [-inf, 1] if not yet inverted
        return 1 - values
    return values

# Loop through each metric.  Every model family is read straight from the raw
# result tables and inverted explicitly, so NSE/KGE/VE have the same
# orientation for all models regardless of the sign of their values.
for dm in all_metrics:
    # Parametric models (rows with a missing metric value are dropped)
    for label, name in [('PredictedMOM', 'LN MoM'), ('PredictedLog', 'LN Direct'), ('MLE', 'MLE')]:
        subset = fdc_df[fdc_df['Label'] == label].dropna(subset=[dm])
        data = invert_if_needed(subset[dm].values, dm, assume_raw=True)
        ids = subset['Official_ID']
        main_result_vals[f'{name} {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

    # kNN group results
    for k in [2, 4, 8]:
        knn_df = get_knn_group_results(k=k)
        data = invert_if_needed(knn_df[dm].values, dm, assume_raw=True)
        ids = knn_df['Official_ID'].values
        main_result_vals[f'{k} kNN {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

    # LSTM models
    for label, suffix in [('time', 'LSTM time'), ('frequency', 'LSTM dist.')]:
        subset = fdc_df[fdc_df['Label'] == label]
        data = invert_if_needed(subset[dm].values, dm, assume_raw=True)
        ids = subset['Official_ID'].values
        main_result_vals[f'{suffix} {dm}'] = pd.DataFrame({'ids': ids, 'values': data})

In [None]:
# Wide table of every model/metric result: one column per entry of
# `main_result_vals`, rows keyed by station ID.
all_results = []
for m, df in main_result_vals.items():
    # name the value column after the model/metric key and index by station
    df = df.rename(columns={'values': m}).set_index('ids')
    all_results.append(df)
all_results_df = pd.concat(all_results, axis=1)


In [None]:
# load the complete years previously processed
# NOTE(security): allow_pickle=True lets numpy unpickle arbitrary Python
# objects from the file, which can execute code; only load this file from a
# trusted source.  `.item()` extracts the single stored object (presumably a
# dict keyed by station -- confirm against the notebook that wrote it).
complete_year_dict = np.load('data/complete_year_stats.npy', allow_pickle=True).item()

In [None]:
# Read the baseline PMF table again to obtain the shared log-space support
# grid (column 'log_x_uar').
if regularization_type == 'kde':
    print('Using KDE regularization for PMF estimation.')
    fname = 'pmf_kde.csv'
else:
    fname = 'pmf_obs.csv'
pmf_path = Path.cwd().joinpath('data', 'baseline_distributions', f'{bitrate:02d}_bits', fname)
pmf_obs_df = pd.read_csv(pmf_path)
baseline_log_grid = pmf_obs_df['log_x_uar'].values
# baseline_log_edges = np.concatenate([pmf_obs_df['left_log_edges'].values[:1], pmf_obs_df['right_log_edges'].values])
# baseline_log_w = np.diff(baseline_log_edges)


In [None]:
def filter_LSTM_outputs(stn_data, folder):
    """Load both LSTM ensemble revisions for a station and keep complete hydrological years.

    For each revision date the station's ``<stn>_ensemble.csv`` is read from
    ``folder / ensemble_results_<date>``, indexed by time, every column is
    suffixed with ``_<date>`` and the values are exponentiated.  The two
    revisions are joined on their common timestamps, rows with any missing
    value are dropped, and the first observed-flow column is copied to
    ``discharge``.  The result is passed to
    ``stn_data.filter_complete_hydrological_years`` and the second frame it
    returns is returned here.

    Returns an empty DataFrame when either revision file is missing.
    """
    stn, da = stn_data.target_stn, stn_data.target_da
    frames, obs_cols = [], []
    for date in ('20250514', '20250627'):
        fpath = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        if not os.path.exists(fpath):
            return pd.DataFrame()
        raw = pd.read_csv(fpath).rename(columns={'Unnamed: 0': 'time'})
        raw['time'] = pd.to_datetime(raw['time'])
        raw = raw.set_index('time')
        # tag every column with the revision it came from
        raw.columns = [f'{c}_{date}' for c in raw.columns]
        obs_cols.extend(c for c in raw.columns if c.startswith('streamflow_obs'))
        frames.append(np.exp(raw))

    merged = pd.concat(frames, axis=1, join='inner').dropna(how='any', axis=0)
    merged['discharge'] = merged[obs_cols[0]]
    # complete_years = complete_year_dict.get(stn, None).get('hyd_years', [])
    _, hyd_df = stn_data.filter_complete_hydrological_years(merged, da, min_days=20)
    # print(f'    Found {len(complete_years)} complete years for {stn}: {complete_years}')
    return hyd_df


def plot_ensemble_results(stn, folder):
    """Build a log-scale time-series figure of both LSTM ensemble revisions for ``stn``.

    For each revision the exponentiated observed series (when present), the
    5%-95% band across the simulated members and the dashed member mean are
    drawn in that revision's colour.  Returns the bokeh figure.
    """
    fig = figure(title=f'Ensemble results for {stn}', x_axis_type='datetime',
                 y_axis_type='log', width=800, height=400)

    revision_colours = {'20250514': 'black', '20250627': 'red'}
    for date, clr in revision_colours.items():
        csv_path = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        raw = pd.read_csv(csv_path).rename(columns={'Unnamed: 0': 'time'})
        raw['time'] = pd.to_datetime(raw['time'])
        flows = np.exp(raw.set_index('time'))

        if 'streamflow_obs' in flows.columns:
            fig.line(flows.index, flows['streamflow_obs'], color=clr, legend_label=f'{date} Obs', line_width=2)

        members = flows[[c for c in flows.columns if c.startswith('streamflow_sim')]]
        mean_sim = members.mean(axis=1)
        # 5% and 95% quantiles across the simulation columns, per time step
        lower = members.quantile(0.05, axis=1)
        upper = members.quantile(0.95, axis=1)

        fig.varea(flows.index, lower, upper, color=clr, alpha=0.2, legend_label=f'{date} 90% CI')
        fig.line(flows.index, mean_sim, color=clr, legend_label=f'{date} Mean', line_dash='dashed', line_width=2)

    fig.legend.location = 'top_left'
    fig.xaxis.axis_label = 'Time'
    fig.yaxis.axis_label = 'Streamflow (L/s/km2)'
    fig.legend.click_policy= 'hide'
    return fig


def compute_afdcs_from_sorted_flows(df, years, da, date):
    """Compute AFDCs from sorted daily flows, rather than PDFs.

    For every year in ``years`` the observed series and the ensemble-mean
    simulated series are sorted in descending order.  All yearly curves are
    trimmed to the shortest one, and the median, 10% and 90% quantiles across
    years are taken at each rank.

    Parameters
    ----------
    df : pandas.DataFrame
        Datetime-indexed frame with exactly one column starting with
        ``streamflow_obs`` and ten starting with ``streamflow_sim``, all
        ending with ``date``.
    years : iterable of int
        Calendar years (matched against ``df.index.year``).
    da :
        Unused; kept for interface compatibility.
    date : str
        Revision suffix selecting the columns and tagging the output columns.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``Rank`` (1 = largest flow) with columns
        ``AFDC{50,10,90}_{obs,sim}_<date>``.

    Raises
    ------
    AssertionError
        If the expected observed/simulated columns are not found.
    ValueError
        If no year yields any observed or any simulated values.
    """
    obs_cols = [c for c in df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
    assert len(obs_cols) == 1, f'Expected one observed column, found {len(obs_cols)}'
    sim_cols = [c for c in df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    afdc_obs, afdc_sim = [], []

    for year in years:
        year_df = df[df.index.year == year]

        # Observed
        obs_values = year_df[obs_cols[0]].dropna().values
        if len(obs_values) > 0:
            afdc_obs.append(pd.Series(np.sort(obs_values)[::-1], name=f"{year}_obs"))  # descending

        # Simulated ensemble mean
        sim_ensemble = year_df[sim_cols].dropna(how='all')  # drop rows with all NaNs
        if not sim_ensemble.empty:
            sim_mean = sim_ensemble.mean(axis=1).dropna().values
            afdc_sim.append(pd.Series(np.sort(sim_mean)[::-1], name=f"{year}_sim"))

    if not afdc_obs or not afdc_sim:
        raise ValueError('No observed or simulated values found for the requested years')

    # Align lengths: trim to shortest year
    min_len = min(len(s) for s in afdc_obs + afdc_sim)
    ranks = pd.Index(np.arange(1, min_len + 1), name='Rank')

    # Combine into DataFrames that share the 1-based rank index.  The summary
    # columns below are assigned by index alignment, so the yearly curves must
    # carry the same index as the summary frame; with a 0-based index the
    # largest flow would be dropped and the last rank left as NaN.
    obs_df = pd.concat([s.iloc[:min_len].reset_index(drop=True) for s in afdc_obs], axis=1)
    sim_df = pd.concat([s.iloc[:min_len].reset_index(drop=True) for s in afdc_sim], axis=1)
    obs_df.index = ranks
    sim_df.index = ranks

    # Compute percentile summary
    afdc_summary = pd.DataFrame(index=ranks)
    afdc_summary[f'AFDC50_obs_{date}'] = obs_df.median(axis=1)
    afdc_summary[f'AFDC10_obs_{date}'] = obs_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_obs_{date}'] = obs_df.quantile(0.90, axis=1)

    afdc_summary[f'AFDC50_sim_{date}'] = sim_df.median(axis=1)
    afdc_summary[f'AFDC10_sim_{date}'] = sim_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_sim_{date}'] = sim_df.quantile(0.90, axis=1)

    afdc_summary.index.name = 'Rank'
    return afdc_summary


def compute_ensemble_pmfs(df, sim_cols, stn_data):
    """Build one PMF per simulated ensemble member.

    Each column in ``sim_cols`` is stripped of missing values and passed to
    ``stn_data.build_pmf_from_timeseries``.  The resulting PMFs are returned
    side by side as the columns of a DataFrame indexed by the global
    ``baseline_log_grid``.  No averaging is performed here.
    """
    def _member_pmf(col):
        flows = df[col].dropna().values  # should be uar
        assert len(flows) > 0, f'No valid values found for {col}'
        pmf = stn_data.build_pmf_from_timeseries(flows, stn_data.min_measurable_log_uar, stn_data.target_da)
        return pd.Series(pmf, index=baseline_log_grid, name=col)

    return pd.concat([_member_pmf(col) for col in sim_cols], axis=1)


def compute_series_range(pmf, grid, threshold=1e-4):
    """Return the grid values at the first and last positions where ``pmf >= threshold``.

    If no entry reaches the threshold, the two ends of ``grid`` are returned.
    """
    above = np.where(pmf >= threshold)[0]
    if len(above) == 0:
        # nothing reaches the threshold: span the whole grid
        return grid[0], grid[-1]
    return grid[above[0]], grid[above[-1]]


In [None]:
# Scratch check: np.argmax on a boolean mask gives the index of the first True,
# i.e. the first element below 1.0 (index 3 for this array).
a = np.array([1, 1, 1, 0.5, 0.2, 0])
np.argmax(a < 1.0)

In [None]:
def adjust_pmf_for_zero_flows(log_uar_data, station_data):
    """Build a normalized PMF over the station's extended log-edge bins.

    Reads ``log_edges_extended``, ``zero_bin_index`` and ``min_measurable_uar``
    from ``station_data``.  Bins below ``zero_bin_index`` are left at zero;
    bins from ``zero_bin_index`` onward receive the histogram counts of
    ``log_uar_data``.  The array is divided by its sum before being returned.
    """
    log_edges_uar = station_data.log_edges_extended
    pmf_arr = np.zeros(len(log_edges_uar) - 1)

    # Count of values below log(min_measurable_uar), stored in the zero bin.
    # NOTE(review): this value is overwritten by the slice assignment further
    # down (`pmf_arr[z_index:] = counts[z_index:]`), so the below-threshold
    # count does not reach the returned PMF.  Confirm whether the intent was
    # to add it to the zero bin or to assign the histogram from z_index + 1.
    z_index = station_data.zero_bin_index
    pmf_arr[z_index] = (log_uar_data < np.log(station_data.min_measurable_uar)).sum()


    # Empirical (counts -> pmf/pdf)
    # Histogram of the log values over the extended edges; values outside the
    # outer edges are not counted by np.histogram.
    counts, edges = np.histogram(log_uar_data, bins=log_edges_uar, density=False)
    counts = counts.astype(int)

    # Copy the histogram counts into the zero bin and every bin above it.
    pmf_arr[z_index:] = counts[z_index:]

    # Normalize so the returned array sums to 1.
    return pmf_arr / np.sum(pmf_arr)


def plot_observed_and_simulated_pdf(stn, pmf_dfs, og_df, date, stn_data, pdf_plots=None):
    """Plot the observed and simulated PDFs for a given station.

    Draws, on a log x-axis: the observed histogram (as densities), each LSTM
    ensemble member found in ``pmf_dfs[date]``, the time- and
    distribution-ensemble curves, a shaded "zero-equivalent" band, and every
    entry of ``pmf_dfs`` whose key starts with 'LN' or 'KNN'.

    Parameters
    ----------
    stn : str
        Station ID; ``og_df`` must contain the column ``f'{stn}_uar'``.
    pmf_dfs : dict
        ``pmf_dfs[date]`` is a DataFrame of PMF columns on the global
        ``baseline_log_grid``; keys starting with 'LN'/'KNN' hold model PMFs.
    og_df : pandas.DataFrame
        Observed unit-area runoff.
    date : str
        LSTM revision tag; '20250514' and '20250627' get descriptive titles,
        any other value gets a generic title instead of failing.
    stn_data :
        Station object providing ``log_w``, ``target_da``, ``zero_bin_index``,
        ``log_edges_extended`` and ``zero_equiv_uar``.
    pdf_plots : list, optional
        Previously created figures; when non-empty the new figure shares the
        x/y ranges of the first one.

    Returns
    -------
    tuple
        ``(figure, baseline_log_grid, baseline_lin_grid, log_w)``.
    """
    # A None default avoids sharing one mutable list object between calls.
    pdf_plots = [] if pdf_plots is None else pdf_plots

    log_w = stn_data.log_w
    da = stn_data.target_da
    zero_idx = stn_data.zero_bin_index
    max_p = 0

    baseline_lin_grid = np.exp(baseline_log_grid)
    rbs = []

    if date == '20250514':
        title = f'{stn} ({da} km²): LSTM PDFs Mean NSE Objective'
    elif date == '20250627':
        title = f'{stn} ({da} km²): LSTM PDFs 95% Quantile Objective'
    else:
        title = f'{stn} ({da} km²): LSTM PDFs ({date})'
    if len(pdf_plots) > 0:
        p = figure(title=title, x_axis_type='log', toolbar_location='above',
            width=1000, height=350, x_range=pdf_plots[0].x_range,
            y_range=pdf_plots[0].y_range)
    else:
        p = figure(title=title, x_axis_type='log',toolbar_location='above',
            width=1000, height=350)

    # plot the observed values as quad glyphs
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    observed_log_vals = np.log(observed_vals)
    edges = np.exp(stn_data.log_edges_extended)  # convert edges back to linear space

    # PMF divided by the log bin width -> density
    hist = adjust_pmf_for_zero_flows(observed_log_vals, stn_data) / log_w

    p.quad(top=hist, bottom=0.99 * min(hist), left=edges[:-1], right=edges[1:],
            fill_color='dodgerblue', alpha=0.5, legend_label='Observed')
    max_p = max(max_p, np.max(hist))

    df = pmf_dfs[date].copy()

    # individual LSTM ensemble members (missing member columns are skipped)
    for i in range(10):
        sim_col = f'streamflow_sim_{i}_{date}'
        if sim_col in df.columns:
            vals = df[sim_col].values / log_w
            _, rb = compute_series_range(vals, baseline_lin_grid, threshold=1e-4)
            rbs.append(rb)
            p.line(baseline_lin_grid, vals, color='grey', alpha=0.5,
                    legend_label=f'LSTM Ensemble')
            max_p = max(max_p, np.max(vals))

    # Time-ensemble curve.
    # NOTE(review): the first assignment to the zero bin is overwritten by the
    # slice assignment that follows, so mass below the zero bin is discarded
    # rather than lumped into it.  Left unchanged pending confirmation of the
    # intended zero-flow convention.
    sim_time_vals = df[f'POR_sim_timeEnsemble_{date}'].values
    sim_time_pmf = np.zeros_like(sim_time_vals)
    sim_time_pmf[zero_idx] = sim_time_vals[:zero_idx].sum()
    sim_time_pmf[zero_idx:] = sim_time_vals[zero_idx:]
    sim_time_pdf = sim_time_pmf / log_w
    p.line(baseline_lin_grid, sim_time_pdf,
           line_width=2.5, color='green', legend_label=f'timeEnsemble', line_dash='dashed')

    # Distribution-ensemble curve (same zero-bin caveat as above).
    sim_dist_vals = df[f'POR_sim_distEnsemble_{date}'].values
    sim_dist_pmf = np.zeros_like(sim_dist_vals)
    sim_dist_pmf[zero_idx] = sim_dist_vals[:zero_idx].sum()
    sim_dist_pmf[zero_idx:] = sim_dist_vals[zero_idx:]
    sim_dist_pdf = sim_dist_pmf / log_w
    _, rb = compute_series_range(sim_dist_pdf, baseline_lin_grid, threshold=1e-4)
    rbs.append(rb)
    p.line(baseline_lin_grid, sim_dist_pdf,
           line_width=2.5, color='green', legend_label=f'distEnsemble')

    # Height of the shaded band: track the densities actually plotted.  The
    # ensemble-member maxima are already folded into max_p inside the loop, so
    # `vals` is not referenced here (it is undefined when no member column
    # exists).
    max_p = max(max_p, np.max(sim_time_pdf), np.max(sim_dist_pdf))
    le, re = np.exp(stn_data.log_edges_extended[0]), stn_data.zero_equiv_uar

    p.quad(top=max_p * 1.01, bottom=0, left=le, right=re,
           fill_color='black', line_color=None, fill_alpha=0.1, legend_label='Zero-equiv.')

    clrs = ['orange', 'purple']
    lss = ['solid', 'dashed', 'dotted']
    cols = list(pmf_dfs.keys())
    for i, prefix in enumerate(['LN', 'KNN']):
        model_cols = [c for c in cols if c.startswith(prefix)]
        for j, mc in enumerate(model_cols):
            # replace the zero bin by the total mass below it and clear the
            # bins underneath
            pmf = np.array(pmf_dfs[mc].copy())
            low_p = float(np.sum(pmf[:zero_idx]))
            pmf[:zero_idx] = 0.
            pmf[zero_idx] = low_p
            mc_pdf = pmf / log_w
            _, rb = compute_series_range(mc_pdf, baseline_lin_grid, threshold=1e-4)
            rbs.append(rb)
            # cycle through the dash styles if there are more models than styles
            p.line(baseline_lin_grid, mc_pdf, line_color=clrs[i], line_dash=lss[j % len(lss)],
                   legend_label=f'{mc}', line_width=2.5)

    # x-range: start two bins below the zero bin, end at the widest right bound
    l = max(zero_idx - 2, 0)
    p.x_range.start = baseline_lin_grid[l]
    p.x_range.end = np.max(rbs)

    p.xaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    p.yaxis.axis_label = 'Probability Density'
    p.legend.location = 'top_left'
    p.legend.click_policy = 'hide'
    p.add_layout(p.legend[0], 'left')
    return p, baseline_log_grid, baseline_lin_grid, log_w


def _fdc_plot_bounds(fdc):
    """Return ``(start, stop)`` slice bounds for the informative part of an FDC.

    ``start`` is one position before the first value below 1.0, clamped at 0
    so that a curve whose first bin already holds mass is not sliced from -1.
    ``stop`` is one past the first occurrence of the minimum value.
    """
    start = max(int(np.argmax(fdc < 1.0)) - 1, 0)
    stop = int(np.argmax(fdc == np.min(fdc))) + 1
    return start, stop


def plot_observed_and_simulated_fdc(stn, pmf_dfs, baseline_lin_grid, lstm_df, og_df, date, stn_data, fdc_plots=None):
    """Plot the observed and simulated FDCs for a given station.

    Each curve is ``1 - cumsum(pmf)`` on ``baseline_lin_grid``, drawn with
    exceedance on the x-axis and unit-area runoff on a log y-axis.  Curves
    drawn: observed, the ten LSTM members, the LSTM time ensemble (PMF of the
    member-mean series), the LSTM distribution ensemble (mean of the member
    FDCs) and every ``pmf_dfs`` entry whose key starts with 'LN' or 'KNN'.

    Parameters
    ----------
    stn : str
        Station ID; ``og_df`` must contain ``f'{stn}_uar'``.
    pmf_dfs : dict
        Model PMFs keyed by name.
    baseline_lin_grid : array-like
        Linear-space support grid shared by all PMFs.
    lstm_df : pandas.DataFrame
        Must contain exactly ten ``streamflow_sim*`` columns ending in ``date``.
    og_df : pandas.DataFrame
        Observed unit-area runoff.
    date : str
        LSTM revision tag used to select the simulated columns.
    stn_data :
        Station object providing ``zero_bin_index``,
        ``build_pmf_from_timeseries``, ``min_measurable_log_uar`` and
        ``target_da``.
    fdc_plots : list, optional
        Previously created figures; when non-empty the new figure shares the
        x/y ranges of the first one.

    Returns
    -------
    tuple
        ``(figure, None)``; the second slot is reserved for FDC metrics.
    """
    # A None default avoids sharing one mutable list object between calls.
    fdc_plots = [] if fdc_plots is None else fdc_plots

    zi = stn_data.zero_bin_index
    title = f'{stn}: FDCs'

    if len(fdc_plots) > 0:
        fdc_plot = figure(title=title, y_axis_type='log',
            width=600, height=350, x_range=fdc_plots[0].x_range,
            y_range=fdc_plots[0].y_range, toolbar_location='above')
    else:
        fdc_plot = figure(title=title, toolbar_location='above',
                          width=600, height=350, y_axis_type='log')

    # observed duration curve
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    obs_pmf = stn_data.build_pmf_from_timeseries(observed_vals, stn_data.min_measurable_log_uar, stn_data.target_da)
    obs_fdc = 1 - np.cumsum(obs_pmf)
    ci, ce = _fdc_plot_bounds(obs_fdc)
    fdc_plot.line(obs_fdc[ci:ce], baseline_lin_grid[ci:ce], color='dodgerblue', line_width=2.5, legend_label=f'Observed')

    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    # one FDC per ensemble member, kept for the distribution-ensemble mean
    sim_fdcs = pd.DataFrame(index=baseline_lin_grid)
    for i, sim_col in enumerate(sim_cols):
        sim_vals = lstm_df[sim_col].dropna().values
        sim_pmf = stn_data.build_pmf_from_timeseries(sim_vals, stn_data.min_measurable_log_uar, stn_data.target_da)
        sim_fdc = 1 - np.cumsum(sim_pmf)
        ci, ce = _fdc_plot_bounds(sim_fdc)
        sim_fdcs[f'LSTM Simulation {i+1}'] = sim_fdc
        fdc_plot.line(sim_fdc[ci:ce], baseline_lin_grid[ci:ce], color='grey', alpha=0.5,
                      legend_label=f'LSTM Simulation')

    # compute the temporal ensemble mean FDC
    temporal_mean_ts = lstm_df[sim_cols].mean(axis=1).dropna().values
    temporal_mean_pmf = stn_data.build_pmf_from_timeseries(temporal_mean_ts, stn_data.min_measurable_log_uar, stn_data.target_da)
    time_ensemble_fdc = 1 - np.cumsum(temporal_mean_pmf)
    ci, ce = _fdc_plot_bounds(time_ensemble_fdc)
    fdc_plot.line(time_ensemble_fdc[ci:ce],
                  baseline_lin_grid[ci:ce], color='green',
                  line_width=2.5, line_dash='dashed',
                  legend_label=f'LSTM Time')

    # compute the frequency ensemble mean FDC
    freq_ensemble_fdc = sim_fdcs.mean(axis=1).values
    ci, ce = _fdc_plot_bounds(freq_ensemble_fdc)
    fdc_plot.line(freq_ensemble_fdc[ci:ce],
                  baseline_lin_grid[ci:ce],
                  color='green', line_width=2.5,
                  legend_label=f'LSTM Dist.')

    clrs = ['orange', 'purple']
    lss = ['solid', 'dashed', 'dotted']
    cols = list(pmf_dfs.keys())
    # Signed summed differences against the observed curve; computed but not
    # currently returned (see the return statement).
    fdc_metrics = {}
    fdc_metrics['distEnsemble'] = float((freq_ensemble_fdc - obs_fdc).sum())
    fdc_metrics['timeEnsemble'] = float((time_ensemble_fdc - obs_fdc).sum())
    for i, prefix in enumerate(['LN', 'KNN']):
        model_cols = [c for c in cols if c.startswith(prefix)]
        for j, mc in enumerate(model_cols):
            # replace the zero bin by the total mass below it, clear the bins
            # underneath, then renormalize
            mc_pmf = np.array(pmf_dfs[mc].copy())
            low_p = float(np.sum(mc_pmf[:zi]))
            mc_pmf[:zi] = 0.
            mc_pmf[zi] = low_p
            mc_pmf /= np.sum(mc_pmf)
            assert np.isclose(np.sum(mc_pmf), 1.0), f'Model PMF for {mc} does not sum to 1 (sums to {np.sum(mc_pmf)})'
            fdc = 1 - np.cumsum(mc_pmf)
            ci, ce = _fdc_plot_bounds(fdc)
            # cycle through the dash styles if there are more models than styles
            fdc_plot.line(fdc[ci:ce], baseline_lin_grid[ci:ce],
                          line_color=clrs[i], line_dash=lss[j % len(lss)],
                          legend_label=f'{mc}', line_width=2.5)

    # NOTE(review): the plotted x values are fractions in [0, 1], while the
    # label says percent -- confirm which is intended.
    fdc_plot.xaxis.axis_label = 'Exceedance Probability (%)'
    fdc_plot.legend.background_fill_alpha = 0.5
    fdc_plot.add_layout(fdc_plot.legend[0], 'left')
    fdc_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    fdc_plot.legend.location = 'top_right'
    fdc_plot.legend.click_policy = 'hide'

    return fdc_plot, None#fdc_metrics

In [None]:
def check_if_station_in_hydat(stn, conn):
    """Report whether a station number has a row in the STATIONS table.

    Args:
        stn: Station number to look up; bound as the single query parameter.
        conn: Open database connection handed to ``pd.read_sql_query``.

    Returns:
        True when the lookup returns at least one row, otherwise False.
    """
    lookup_sql = """
    SELECT STATION_NUMBER, STATION_NAME
    FROM STATIONS
    WHERE STATION_NUMBER = ?
    """
    matches = pd.read_sql_query(lookup_sql, conn, params=(stn,))
    # A non-empty result means the station is present.
    return not matches.empty
    

def query_data_symbols(stn, hydat_path=None):
    """Load the DATA_SYMBOLS lookup (SYMBOL_ID -> SYMBOL_EN) from HYDAT.

    Args:
        stn: Unused; kept so existing callers keep working.
        hydat_path: Optional path to the HYDAT SQLite file. When omitted, the
            path this notebook has always used is taken.

    Returns:
        A Series of SYMBOL_EN values indexed by SYMBOL_ID. If the query
        fails, the error is printed and an empty DataFrame is returned.
    """
    if hydat_path is None:
        hydat_path = Path('/home/danbot/code/common_data/HYDAT') / 'Hydat_20250715.sqlite3'

    query = "SELECT SYMBOL_ID, SYMBOL_EN FROM DATA_SYMBOLS ORDER BY SYMBOL_ID"
    conn = sqlite3.connect(hydat_path)
    try:
        df = pd.read_sql_query(query, conn)
        return df.set_index('SYMBOL_ID')['SYMBOL_EN']
    except Exception as ex:
        # Best-effort lookup: report the problem and fall back to "no symbols".
        print(ex)
        return pd.DataFrame()
    finally:
        # Close explicitly so repeated calls do not accumulate open handles
        # on the database file.
        conn.close()


def reshape_hydat_wide(df):
    """Turn wide monthly rows (FLOW1..FLOW31 columns) into one row per day.

    Args:
        df: DataFrame with STATION_NUMBER, YEAR, MONTH, NO_DAYS and the
            FLOW{i} / FLOW_SYMBOL{i} columns for i = 1..31.

    Returns:
        DataFrame indexed by ``date`` with columns STATION_NUMBER, flow and
        flow_symbol. Days beyond NO_DAYS and rows with a missing flow are
        removed. Rows are not re-sorted after the melt.
    """
    key_cols = ["STATION_NUMBER", "YEAR", "MONTH", "NO_DAYS"]

    def _to_long(prefix, value_name):
        # Unpivot the 31 day-columns of one kind and pull the day number
        # out of the column name.
        long_df = df.melt(id_vars=key_cols,
                          value_vars=[f"{prefix}{d}" for d in range(1, 32)],
                          var_name="day", value_name=value_name)
        long_df["day"] = long_df["day"].str.extract(r"(\d+)$").astype(int)
        return long_df

    flows = _to_long("FLOW", "flow")
    symbols = _to_long("FLOW_SYMBOL", "flow_symbol")

    # Pair each flow value with its quality symbol for the same day.
    daily = pd.merge(flows, symbols, on=key_cols + ["day"])

    # Impossible dates (e.g. 30 February) become NaT instead of raising.
    date_parts = dict(year=daily["YEAR"], month=daily["MONTH"], day=daily["day"])
    daily["date"] = pd.to_datetime(date_parts, errors='coerce')

    # Keep only days that exist in the month, then drop missing flows.
    in_month = daily["day"] <= daily["NO_DAYS"]
    kept = daily.loc[in_month, ["STATION_NUMBER", "date", "flow", "flow_symbol"]]
    return kept.dropna(subset=["flow"]).set_index('date')


def query_hydat_database(stn, hydat_path=None):
    """Fetch all daily flows and flow symbols for one station from HYDAT.

    Args:
        stn: Station number; bound as a query parameter.
        hydat_path: Optional path to the HYDAT SQLite file. When omitted, the
            path this notebook has always used is taken.

    Returns:
        The output of ``reshape_hydat_wide`` for the station's DLY_FLOWS
        rows, or an empty DataFrame when the station is not in STATIONS or
        has no DLY_FLOWS rows (a message is printed in both cases).
    """
    if hydat_path is None:
        hydat_path = Path('/home/danbot/code/common_data/HYDAT') / 'Hydat_20250715.sqlite3'

    base_columns = ["STATION_NUMBER", "YEAR", "MONTH"]
    flow_columns = [f"FLOW{i}, FLOW_SYMBOL{i}" for i in range(1, 32)]
    end_columns = ["NO_DAYS"]
    column_str = ",\n    ".join(base_columns + flow_columns + end_columns)

    query = f"""
    SELECT
        {column_str}
    FROM DLY_FLOWS
    WHERE STATION_NUMBER = ?
    ORDER BY YEAR, MONTH;
    """

    conn = sqlite3.connect(hydat_path)
    try:
        if not check_if_station_in_hydat(stn, conn):
            print(f'Station {stn} not found in HYDAT database.')
            return pd.DataFrame()
        df = pd.read_sql_query(query, conn, params=(stn,))
    finally:
        # Release the database handle on every exit path; this function is
        # called once per station page.
        conn.close()

    if df.empty:
        print(f'No data found for {stn} in HYDAT.')
        return pd.DataFrame()
    return reshape_hydat_wide(df)


def find_symbol_segments(symbol_df, target_symbol):
    """List the (first, last) dates of each run of consecutive flagged days.

    Args:
        symbol_df: DataFrame with a date index and a ``flow_symbol`` column.
        target_symbol: Value to look for in ``flow_symbol``.

    Returns:
        A list of ``(start, end)`` tuples, one per run in which successive
        matching dates are no more than one day apart; ``[]`` if nothing
        matches.
    """
    flags = symbol_df['flow_symbol']
    hit_dates = flags.index[flags == target_symbol]
    if hit_dates.empty:
        return []

    stamps = hit_dates.to_series()
    # True wherever the step from the previous matching date exceeds one day.
    starts_new_run = stamps.diff().gt(pd.Timedelta(days=1)).fillna(True)
    # The running count of breaks labels each contiguous run.
    run_ids = starts_new_run.cumsum()

    segments = []
    for _, run in stamps.groupby(run_ids):
        segments.append((run.min(), run.max()))
    return segments

In [None]:
def plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col):
    """Shade HYDAT quality-flag periods and overlay the HYDAT series.

    Adds ``flow_symbol``, ``flow`` and ``hydat_uar`` columns to ``df`` in
    place (values from ``hydat_df`` aligned on ``df.index``), shades each
    contiguous run of the symbols B, D and E, and draws the HYDAT flow
    converted with ``1000 * flow / da_dict[stn]``.

    Args:
        stn: Station identifier; key into the module-level ``da_dict``.
        df: Date-indexed frame holding the observed column(s) in ``obs_col``
            and the ``streamflow_sim*`` columns.
        hydat_df: Date-indexed frame with ``flow`` and ``flow_symbol`` columns.
        quality_symbols: Mapping-like with ``to_dict()`` giving a description
            per symbol.
        runoff_plot: Figure providing ``varea`` and ``line``.
        obs_col: List of observed column names.

    Returns:
        ``runoff_plot`` with the extra glyphs added.
    """
    symbol_dict = quality_symbols.to_dict()
    # Legend text comes from the DATA_SYMBOLS lookup at runtime.
    # NOTE(review): HYDAT's symbol table is believed to list B as ice
    # conditions, D as dry and E as estimate; confirm against the lookup.
    symbol_colors = {
        'B': 'dodgerblue',
        'D': 'firebrick',
        'E': 'orange',
    }
    df['flow_symbol'] = hydat_df['flow_symbol'].reindex(df.index, method=None)
    uar_cols = obs_col + [c for c in df.columns if c.startswith('streamflow_sim')]

    # The shading spans the full range of the plotted series; the range does
    # not change between segments, so compute it once.
    lower = 0.98 * df[uar_cols].min().min()
    upper = 1.02 * df[uar_cols].max().max()

    for symbol, color in symbol_colors.items():
        if not df['flow_symbol'].eq(symbol).any():
            continue
        # Fall back to readable text so a missing lookup entry does not
        # render as "{} (B)" in the legend.
        description = symbol_dict.get(symbol, 'Unknown symbol')
        for start, end in find_symbol_segments(df[['flow_symbol']].copy(), symbol):
            runoff_plot.varea(
                x=pd.date_range(start, end),
                y1=lower,
                y2=upper,
                fill_color=color, fill_alpha=0.3,
                legend_label=f"{description} ({symbol})"
            )

    df['flow'] = hydat_df['flow'].reindex(df.index, method=None)
    df['hydat_uar'] = 1000 * df['flow'] / da_dict[stn]  # convert to unit area runoff (L/s/km2)
    runoff_plot.line(df.index, df['hydat_uar'],
                     color='dodgerblue', legend_label='HYDAT UAR',
                     line_width=2, line_dash='dotted')
    return runoff_plot

In [None]:
def plot_zero_flow_periods(stn, df, runoff_plot, obs_col):
    """Shade runs of flagged zero-flow days and overlay the station series.

    Args:
        stn: Station identifier; the plotted column is ``'{stn}_uar'``.
        df: Date-indexed frame with a boolean ``zero_flow_flag`` column and
            the ``'{stn}_uar'`` column. It is not modified.
        runoff_plot: Figure providing ``varea`` and ``line``.
        obs_col: Unused; kept for interface compatibility.

    Returns:
        ``runoff_plot``; returned untouched when no day is flagged.
    """
    if not df['zero_flow_flag'].eq(True).any():
        return runoff_plot

    # find_symbol_segments expects a 'flow_symbol' column. Rename on a
    # one-column copy so the caller's frame keeps 'zero_flow_flag' and the
    # function can be called again on the same frame.
    flags = df[['zero_flow_flag']].rename(columns={'zero_flow_flag': 'flow_symbol'})

    uar = df[f'{stn}_uar']
    lower = 0.98 * uar.min()  # just below the series minimum
    upper = 1.02 * uar.max()  # a bit above max for visibility

    for start, end in find_symbol_segments(flags, True):
        runoff_plot.varea(
            x=pd.date_range(start, end),
            y1=lower,
            y2=upper,
            fill_color='salmon', fill_alpha=0.3,
            legend_label="Q=0 replaced"
        )
    runoff_plot.line(df.index, uar,
                     color='purple', legend_label='UAR',
                     line_width=2, line_dash='dotted')
    return runoff_plot

In [None]:
def process_FDCs(stn, og_df, lstm_df, result_folder, station_data):
    """Collect stored FDC estimates for a station and score the LSTM ensembles.

    Loads the stored parametric (LN_*) and kNN (KNN_*) results, then, for
    each LSTM run date, builds period-of-record PMFs for the observed series,
    the time-averaged ensemble, the distribution-averaged ensemble and the
    original series, writes them to CSV and evaluates the simulated PMFs
    against the observed one.

    Relies on module-level ``bitrate``, ``da_dict``, ``baseline_log_grid``
    and ``compute_ensemble_pmfs``.

    Args:
        stn: Station identifier (file names, ``da_dict`` key, ``'{stn}_uar'``).
        og_df: Original series; must contain the ``'{stn}_uar'`` column.
        lstm_df: Date-indexed LSTM output with ``streamflow_obs*`` and
            ``streamflow_sim*`` columns whose names end with the run date.
        result_folder: Path holding ``parametric/`` and ``knn/`` subfolders
            with ``{stn}_fdc_results.json`` files.
        station_data: Provides ``log_x_extended``, ``min_measurable_log_uar``
            and ``build_pmf_from_timeseries``.

    Returns:
        ``(mdf, odf, pmf_dfs)``: LSTM metrics (one row per evaluated series),
        stored metrics of the LN/kNN models, and a dict holding the LN/kNN
        PMFs followed by one PMF DataFrame per run date. The run-date keys
        are inserted last.

    Raises:
        IndexError: if an expected kNN result key is missing.
        AssertionError: if the observed column is not unique for a run date
            or a PMF length differs from ``baseline_log_grid``.
    """
    pmf_dfs, por_metrics, other_metrics = {}, {}, {}

    # Stored parametric (log-normal) results.
    parametric_fpath = result_folder / 'parametric' / f'{stn}_fdc_results.json'
    with open(parametric_fpath, 'r') as f:
        parametric_results = json.load(f)
    for variant in ('MLE', 'PredictedLog', 'PredictedMOM'):
        other_metrics[f'LN_{variant}'] = parametric_results[variant]['eval']
        pmf_dfs[f'LN_{variant}'] = parametric_results[variant]['pmf']

    # Stored kNN results: first key matching each neighbour count.
    knn_fpath = result_folder / 'knn' / f'{stn}_fdc_results.json'
    with open(knn_fpath, 'r') as f:
        knn_results = json.load(f)
    for k in (2, 4, 8):
        tag = f'{k}_NN_attribute_dist_ID2_freqEnsemble'
        matches = [c for c in knn_results if tag in c]
        pmf_dfs[f'KNN_{k}_NN'] = knn_results[matches[0]]['pmf']
        other_metrics[f'KNN_{k}_NN'] = knn_results[matches[0]]['eval']

    eval_object = EvaluationMetrics(
        log_x=station_data.log_x_extended,
        bitrate=bitrate,
        min_measurable_log_uar=station_data.min_measurable_log_uar,
    )
    print(f'    Processing FDCs for {stn}')

    # These do not depend on the run date, so set them up once.
    years = lstm_df.index.year.unique()
    da = da_dict[stn]
    pmf_folder = Path('data/results/additional_results/lstm_pmfs/')
    pmf_folder.mkdir(parents=True, exist_ok=True)

    for date in ['20250514', '20250627']:
        pmf_fpath = pmf_folder / f'POR_{stn}_pmfs_{len(years)}_years_{date}.csv'

        # Observed period-of-record PMF.
        obs_cols = [c for c in lstm_df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
        assert len(obs_cols) == 1, f'Expected exactly one observed column for {stn} on {date}, found {len(obs_cols)}'
        por_obs_vals = lstm_df[obs_cols[0]].dropna().values
        por_obs_pmf = station_data.build_pmf_from_timeseries(
            por_obs_vals, station_data.min_measurable_log_uar, da)
        assert len(por_obs_pmf) == len(baseline_log_grid), f'PMF length mismatch for {stn} on {date} ({len(por_obs_pmf)} vs {len(baseline_log_grid)})'
        pmf_columns = [pd.Series(por_obs_pmf, index=baseline_log_grid, name=f'POR_obs_{date}')]

        # Time ensemble: average the members per day, then build one PMF.
        sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
        sim_vals = lstm_df[sim_cols].mean(axis=1).dropna().values
        sim_pmf = station_data.build_pmf_from_timeseries(
            sim_vals, station_data.min_measurable_log_uar, da)
        assert len(sim_pmf) == len(baseline_log_grid), f'PMF length mismatch for {stn} on {date}'
        pmf_columns.append(pd.Series(sim_pmf, index=baseline_log_grid, name=f'POR_sim_timeEnsemble_{date}'))

        # Distribution ensemble: one PMF per member, then average the PMFs.
        frequency_sim_pmfs = compute_ensemble_pmfs(lstm_df, sim_cols, station_data)
        mean_pmf = frequency_sim_pmfs.mean(axis=1).rename('sim_distEnsemble_mean')
        mean_pmf /= mean_pmf.sum()  # renormalize the PMF
        pmf_columns.append(pd.Series(mean_pmf, index=baseline_log_grid, name=f'POR_sim_distEnsemble_{date}'))

        # Same treatment for the original time series.
        og_vals = og_df[f'{stn}_uar'].dropna().values
        og_pmf = station_data.build_pmf_from_timeseries(
            og_vals, station_data.min_measurable_log_uar, da)
        pmf_columns.append(pd.Series(og_pmf, index=baseline_log_grid, name=f'{stn}_uar'))

        por_pmfs = pd.concat(pmf_columns, axis=1)
        pmfs = pd.concat([por_pmfs, frequency_sim_pmfs], axis=1)
        pmfs.index.name = 'log_uar'  # set the index name for clarity
        pmfs.to_csv(pmf_fpath)
        pmf_dfs[date] = pmfs

        # Score both ensemble PMFs and every member PMF against the observed PMF.
        for col in [f'POR_sim_timeEnsemble_{date}', f'POR_sim_distEnsemble_{date}']:
            por_metrics[col] = eval_object._evaluate_fdc_metrics_from_pmf(
                pmfs[col].values, por_obs_pmf,
                min_measurable_log_uar=station_data.min_measurable_log_uar)

        for col in frequency_sim_pmfs.columns:
            por_metrics[col] = eval_object._evaluate_fdc_metrics_from_pmf(
                frequency_sim_pmfs[col].values, por_obs_pmf,
                min_measurable_log_uar=station_data.min_measurable_log_uar)

    mdf = pd.DataFrame(por_metrics).T
    odf = pd.DataFrame(other_metrics).T

    # Make sure the metrics folder exists so the write cannot fail after all
    # the work above has been done.
    metric_folder = Path('data/results/lstm_metrics')
    metric_folder.mkdir(parents=True, exist_ok=True)
    mdf.to_csv(metric_folder / f'{stn}_{len(years)}_years_metrics.csv')
    return mdf, odf, pmf_dfs


### Streamflow monitoring stations that we found to be regulated or influenced by regulation by looking at results plots

A number of stations in the large sample were found to be regulated or influenced by regulation by looking at results plots.  These are listed below with notes.  These stations should be excluded from the sample for future analysis unless the particular application calls for time series of regulated streams.

* 12398000 - Sullivan Lake upstream is a reservoir
* 12058800 - Dam!

### Other anomalous conditions

* 12143700 -  *Small catchment adjacent to dam* (USGS.gov) No regulation or diversion upstream from station. **Flow is mostly seepage from Chester Morse Lake.** U.S. Geological Survey satellite telemeter at station.  This represents a very specific scenario that depends not only upon hydraulic conditions in neighbouring catchments but also, where a neighbouring catchment contains a large reservoir, upon that reservoir's operating policy. 
* 12143900 -  *Just downstream from 12143700* Heavily affected by the same condition.


In [None]:
# (background, font) colours for the model-name cell of each metrics-table
# row. Keys must match the row labels of the table: the LSTM ensemble names
# and the LN_* / KNN_* labels.
model_label_styles = {
    "timeEnsemble":       ("#2ca25f", "white"),  # green
    "distEnsemble":       ("#2ca25f", "white"),
    "LN_MLE":             ("#fdae6b", "black"),  # orange
    "LN MLE":             ("#fdae6b", "black"),  # legacy spelling, kept for compatibility
    "LN_PredictedLog":    ("#fdae6b", "black"),
    "LN_PredictedMOM":    ("#fdae6b", "black"),
    "KNN_2_NN":           ("#756bb1", "white"),  # purple
    "KNN_4_NN":           ("#756bb1", "white"),  # purple
    "KNN_8_NN":           ("#756bb1", "white"),
}


def key_mapper(mod, met):
    """Build the results-table column name for a model/metric pair.

    Args:
        mod: Model row label (e.g. ``'timeEnsemble'``, ``'LN_PredictedMOM'``,
            ``'KNN_4_NN'``).
        met: Metric name; ``'pct_vol_bias'`` is abbreviated to ``PB`` and
            every metric is upper-cased.

    Returns:
        The column name, e.g. ``'LSTM time PB'`` or ``'4 kNN EMD'``.

    Raises:
        Exception: if ``mod`` matches none of the known model patterns.
    """
    metric_tag = ('PB' if met == 'pct_vol_bias' else met).upper()

    # Checks run in a fixed order; the first matching pattern wins.
    if mod.startswith('time'):
        return f'LSTM time {metric_tag}'
    if mod.startswith('dist'):
        return f'LSTM dist. {metric_tag}'
    if 'MOM' in mod:
        return f'LN MoM {metric_tag}'
    if 'PredictedLog' in mod:
        return f'LN Direct {metric_tag}'
    if 'MLE' in mod:
        return f'MLE {metric_tag}'
    if mod.endswith('_NN'):
        neighbours = mod.split('_')[1]
        return f'{neighbours} kNN {metric_tag}'
    raise Exception(f'{mod} {metric_tag} combo not found')


def plot_runoff_timeseries(stn, lstm_df, og_df, date):
    """Plot observed and LSTM-simulated unit area runoff for one run date.

    Args:
        stn: Station identifier.
        lstm_df: Date-indexed LSTM output; must hold exactly one ``*_obs_*``
            column and ten ``streamflow_sim*`` columns ending with ``date``.
        og_df: Original station series, passed to the zero-flow overlay. It
            is not modified.
        date: Run-date suffix selecting the LSTM columns.

    Returns:
        A log-scale, datetime-axis figure with the observed series, the
        ensemble members, their daily mean, and (when HYDAT data and the
        symbol lookup are both available) the quality-flag shading.

    Raises:
        AssertionError: if the expected observed/simulated columns are not
            found.
    """
    obs_col = [c for c in lstm_df.columns if '_obs_' in c and c.endswith(date)]
    assert len(obs_col) == 1, f'Expected exactly one observed column for {stn}, found {len(obs_col)} {obs_col}'
    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected ten simulation columns for {stn}, found {len(sim_cols)}'

    runoff_plot = figure(title=f'{stn} Observed Unit Area Runoff', x_axis_type='datetime',
                         width=1000, height=300, y_axis_type='log', toolbar_location='above')

    hydat_df = query_hydat_database(stn)
    quality_symbols = query_data_symbols(stn)

    # Reindex to daily frequency so missing days stay as NaN gaps in the lines.
    full_range = pd.date_range(start=lstm_df.index.min(), end=lstm_df.index.max(), freq='D')
    df = lstm_df.copy().reindex(full_range)

    # Work on a copy: og_df is typically a filtered slice owned by the caller,
    # and writing a column into it would change the caller's frame.
    # NOTE(review): every flag is set to False here, so the zero-flow overlay
    # below never shades anything; confirm this is intended.
    zero_flow_df = og_df.copy()
    zero_flow_df['zero_flow_flag'] = False

    if not hydat_df.empty and not quality_symbols.empty:
        runoff_plot = plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col)

    # Zero-flow segments, then the observed line.
    runoff_plot = plot_zero_flow_periods(stn, zero_flow_df, runoff_plot, obs_col)
    runoff_plot.line(df.index, df[obs_col[0]], color='dodgerblue',
                     legend_label='Observed UAR', line_width=2.)

    # All members share one legend entry so they toggle together.
    for col in sim_cols:
        runoff_plot.line(df.index, df[col], color='grey', alpha=0.5,
                         legend_label='LSTM ensemble')

    # Temporal mean of the simulated ensemble.
    mean_sim = df[sim_cols].mean(axis=1)
    runoff_plot.line(df.index, mean_sim, color='green', legend_label='Ensemble Mean', line_width=3, line_dash='dashed')

    runoff_plot.xaxis.axis_label = 'Date'
    runoff_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    runoff_plot.legend.location = 'top_left'
    runoff_plot.legend.click_policy = 'hide'
    runoff_plot.add_layout(runoff_plot.legend[0], 'right')
    runoff_plot.legend.background_fill_alpha = 0.65
    return runoff_plot
    

def format_metrics_table(metric_df, odf, stn, date, fdc_metrics):
    """Render the per-station metrics as a colour-coded HTML table.

    Each cell is coloured by where the station's value ranks among the
    values in the matching column of the module-level ``all_results_df``
    (column found through ``key_mapper``). For pb/mape/nae the ranking uses
    absolute values; for nse/kge/ve the rank is inverted.

    Args:
        metric_df: LSTM metrics, one row per series; rows whose label ends
            with ``date`` and contains ``timeEnsemble`` / ``distEnsemble``
            are used (exactly one of each is expected).
        odf: Metrics of the LN/kNN models, one row per model.
        stn: Unused; kept for interface compatibility.
        date: Run-date suffix selecting the LSTM rows.
        fdc_metrics: Unused; kept for interface compatibility.

    Returns:
        ``(table_div, legend_div)``: two ``Div`` objects holding the metrics
        table and the colour legend.

    Raises:
        AssertionError: if a mapped key is missing from ``all_results_df``
            or has no non-null values.
    """
    # rename_axis + reset_index leaves the caller's frame untouched.
    metric_df = metric_df.rename_axis('series').reset_index()

    lstm_rows = metric_df[metric_df['series'].str.endswith(date)].copy()
    metric_cols = [c for c in lstm_rows.columns if c not in ['series', 'date']]
    time_df = lstm_rows[lstm_rows['series'].str.contains('timeEnsemble')][metric_cols]
    freq_df = lstm_rows[lstm_rows['series'].str.contains('distEnsemble')][metric_cols]

    table = pd.concat([time_df, freq_df])
    table.index = ['timeEnsemble', 'distEnsemble']

    # Shorten the long metric names used by the evaluation code.
    mapper = {'pct_vol_bias': 'pb', 'mean_abs_pct_error': 'mape',
              'norm_abs_error': 'nae', }
    odf = odf.rename(columns=mapper)
    table = table.rename(columns=mapper)

    max_cols = ['nse', 'kge', 've']       # rank is inverted for these
    min_abs_cols = ['pb', 'mape', 'nae']  # ranked on absolute value

    df = pd.concat([table, odf], axis=0)

    colors = [
        "#2166ac", "#4393c3", "#92c5de", "#d1e5f0", "#f7f7f7",
        "#fddbc7", "#f4a582", "#d6604d", "#b2182b", "#67001f"
    ]
    table_cols = {'pb': 'PB', 'mape': 'MAPE', 'nae': 'NAE'}
    col_order = ['pb', 'nae', 'mape', 'rmse', 'kld', 'emd', 'nse', 'kge']
    df = df[col_order]

    # Collect fragments and join once instead of repeated string +=.
    parts = ['<table style="width:100%; border-collapse:collapse; margin-top:25px;">',
             '<thead><tr>']
    parts.extend(f'<th>{table_cols.get(c, c)}</th>' for c in df.columns)
    parts.append('<th style="text-align:left;">Model</th></tr></thead><tbody>')

    for model, metric_row in df.iterrows():
        parts.append('<tr>')
        for metric in df.columns:
            val = metric_row[metric]

            key = key_mapper(model, metric)
            assert key in all_results_df.columns, f'{key} not found in allresults df columns: {list(all_results_df.columns)}'

            global_vals = all_results_df[key].dropna().values
            assert not len(global_vals) == 0, f'Global vals is empty: key={key}, {list(all_results_df.columns)}'

            if metric in min_abs_cols:
                rank = np.searchsorted(np.sort(np.abs(global_vals)), abs(val), side="right")
            else:
                rank = np.searchsorted(np.sort(global_vals), val, side="right")

            pct = rank / len(global_vals)
            if metric in max_cols:
                pct = 1 - pct
            # Ten colour bins; pct == 1.0 falls into the last one.
            idx = min(math.floor(pct * 10), 9)
            bg = colors[idx]
            font = "white" if idx in {0, 1, 8, 9} else "black"

            parts.append(f'<td style="background:{bg}; text-align:center; padding:4px 6px;"><span style="color:{font};">{val:.2f}</span></td>')

        row_bg, row_font = model_label_styles.get(model, ("#ffffff", "black"))
        parts.append(f'<td style="background:{row_bg}; color:{row_font}; text-align:left; font-weight:bold;padding:4px 4px;">{model}</td>')
        # Close the row so the generated markup is well-formed.
        parts.append('</tr>')
    parts.append('</tbody></table>')
    html = ''.join(parts)

    # Colour legend. NOTE(review): the caption's N=720 is hard-coded.
    legend_parts = ['<table style="border-collapse:collapse; margin-bottom:4px;"><caption>Rank percentile colour map (N=720):</caption><tr>']
    for i, color in enumerate(colors):
        pct_label = f"{(i+1)*10}%"
        font = "white" if i in {0, 1, 8, 9} else "black"
        legend_parts.append(f'<td style="background:{color}; padding:4px 4px; text-align:center;"><span style="color:{font}">{pct_label}</span></td>')
    legend_parts.append('</tr></table>')
    legend_html = ''.join(legend_parts)

    return Div(text=html), Div(text=legend_html)

In [None]:
# stn = '12143900'
# og_df = get_original_timeseries(stn, ds)
# lstm_ensemble_df = filter_by_complete_years(stn, lstm_result_base_folder)
# # boxley_df = retrieve_timeseries_discharge('12143700', ds)
# boxley_df = retrieve_timeseries_discharge('12143900', ds)
# # fig = figure(width=800, height=350, x_axis_type='datetime')
# # fig.line(boxley_df.index, boxley_df['12143700_uar'], line_width=2, color='blue', legend_label='Boxley Creek UAR')
# date = '20250514'
# ts = plot_runoff_timeseries(stn, lstm_ensemble_df, og_df, date)
# ts.line(boxley_df.index, boxley_df['12143900_uar'], line_width=2, color='purple',
#         line_dash='dashed', legend_label='12143900 UAR')
# # change the legend labels
# # Rename a legend label
# for legend in ts.legend:
#     for item in legend.items:
#         if item.label['value'] == "Observed UAR":
#             print(item.label)
#             item.label.value = "12143700 UAR"
#             # item.label['value'] = "12143700 UAR"
#             # break # Exit inner loop once found
#         elif item.label['value'] == "Observed UAR":
#             item.label.value = "12143900 UAR"
#             # item.label['value'] = "12143900 UAR"
#             # break # Exit inner loop once found

# show(ts)

Above we can see that the simulated flows are dramatically different from the observed flows. As it turns out, this catchment is heavily influenced by seepage from Chester Morse Lake, which is a reservoir on the adjacent catchment.  The flow here is not natural: it depends not only upon hydraulic conditions in the neighbouring catchment but also upon the reservoir operation policy.  This represents a very specific scenario that is not representative of most catchments, so the station should be excluded from the sample.

In [None]:
def process_diagnostic_plot_layout(stn_data):
    """Assemble the full diagnostic page layout for one station.

    Reads the module-level ``BASE_DIR``, ``bitrate``, ``regularization_type``,
    ``lstm_result_base_folder`` and ``notes_html``.

    Args:
        stn_data: Station object providing ``target_stn``, ``hyd_df`` and
            ``min_measurable_uar``.

    Returns:
        A column layout (FDC plot + metrics table, runoff time series, PDF
        plot, notes), or a single ``Div`` with a message when no LSTM
        records are available for the station.
    """
    results_root = BASE_DIR / 'data' / 'results'
    bits_folder = results_root / f'fdc_estimation_results_{bitrate:02d}_bits'
    result_folder = (
        results_root / 'fdc_estimation_results_kde'
        if regularization_type == 'kde'
        else bits_folder
    )

    stn = stn_data.target_stn
    og_df = stn_data.hyd_df.copy()
    lstm_ensemble_df = filter_LSTM_outputs(stn_data, lstm_result_base_folder)

    if lstm_ensemble_df.empty:
        print(f'No complete years found for {stn}. Skipping.')
        return Div(text=f'<h2>Insufficient records found for {stn}. Skipping.</h2>', width=800)

    # Restrict the original series to the dates covered by the LSTM output.
    og_df = og_df[og_df.index.isin(lstm_ensemble_df.index)]
    mdf, odf, pmf_dfs = process_FDCs(stn, og_df, lstm_ensemble_df, result_folder, stn_data)

    # The time-series plot uses the last key of pmf_dfs as its run date.
    # NOTE(review): the other plots and the table use the fixed date below;
    # confirm the difference is intended.
    last_key = list(pmf_dfs.keys())[-1]
    date = '20250514'
    pdf_plots, fdc_plots = [], []

    pdf_plot, _, baseline_lin_grid, _ = plot_observed_and_simulated_pdf(
        stn, pmf_dfs, og_df, date, stn_data, pdf_plots=pdf_plots)
    pdf_plots.append(pdf_plot)

    min_measurable = stn_data.min_measurable_uar
    assert min_measurable < 1e3, f'Min measurable UAR too high: {min_measurable}'

    fdc_plot, fdc_metrics = plot_observed_and_simulated_fdc(
        stn, pmf_dfs, baseline_lin_grid, lstm_ensemble_df, og_df, date, stn_data,
        fdc_plots=fdc_plots)
    fdc_plots.append(fdc_plot)

    ts_plot = plot_runoff_timeseries(stn, lstm_ensemble_df, og_df, last_key)

    metric_table, legend_html = format_metrics_table(mdf, odf, stn, date, fdc_metrics)

    table_div = column(metric_table, legend_html)
    notes_div = Div(text=notes_html, width=1000)
    return column(
        row(fdc_plots[0], table_div),
        row(ts_plot),
        row(pdf_plots[0]),
        row(notes_div),
    )

In [None]:
# Configuration cell: gathers the paths, thresholds and precomputed inputs
# and builds the shared FDCEstimationContext.
# NOTE: `bitrate`, `LSTM_ensemble_result_folder` and
# `daymet_concurrent_stations` are read below but not assigned in this cell;
# they must already exist from an earlier cell.
global_min_uar = 5e-5   # see Notebook 1: data
global_max_uar = 1e4    # see Notebook 1: data


np.random.seed(42)

target_cols = [
    'mean_uar', 'sd_uar', 
    'mean_logx', 'sd_logx', 
]

regularization_type = 'discrete'  # 'discrete' or 'kde'
# previously set here: bitrate = 10 if regularization_type == 'discrete' else 10

# set paths for catchment attributes and meteorological forcing data
attr_df_fpath = os.path.join('data', f'catchment_attributes_with_runoff_stats.csv')
LSTM_forcings_folder = '/home/danbot/neuralhydrology/data/BCUB_catchment_mean_met_forcings_20250320'
baseline_distribution_folder = os.path.join('data', 'baseline_distributions', f'{bitrate:02d}_bits')
# older location: os.path.join('data', 'parameter_prediction_results')

parameter_prediction_results_folder = os.path.join('data', 'results', 'parameter_prediction_results', )
predicted_params_fpath   = os.path.join(parameter_prediction_results_folder, 'OOS_parameter_predictions.csv')
rdf = pd.read_csv(predicted_params_fpath, index_col=['official_id'], dtype={'official_id': str})
predicted_param_dict = rdf.to_dict(orient='index')

# load the pre-computed dictionary of complete years of record for each station
# (allow_pickle=True unpickles the file contents; only load files you trust)
complete_year_stats_fpath = os.path.join('data', 'complete_year_stats.npy')
complete_year_stats = np.load(complete_year_stats_fpath, allow_pickle=True).item()

# collect stations with at least 5 complete hydrological years; report the rest
meet_min_hyd_years = []
for stn in complete_year_stats.keys():
    if len(complete_year_stats[stn]['hyd_years']) >= 5:
        meet_min_hyd_years.append(stn)
    else:
        print(f'Station {stn} has {len(complete_year_stats[stn]["hyd_years"])} complete hydrological years of data.')

# set which (subset of) methods to run
methods = ('knn','parametric', 'lstm',)
k_nearest = 10
include_pre_1980_data = True  # when True, the start date below is moved back from 1980 to 1950
daymet_start_date = '1980-01-01'  # default start date for Daymet data
if include_pre_1980_data:
    daymet_start_date = '1950-01-01'


# keyword arguments for FDCEstimationContext
input_data = {
    'attr_df_fpath': attr_df_fpath,
    'LSTM_forcings_folder': LSTM_forcings_folder,
    'LSTM_ensemble_result_folder': LSTM_ensemble_result_folder,
    'include_pre_1980_data': include_pre_1980_data,  # presumably False restricts to data concurrent with Daymet (1980-present) - TODO confirm
    'predicted_param_dict': predicted_param_dict,
    'eps': 1e-12,
    'min_record_length': 5, # minimum record length (years)
    'minimum_days_per_month': 20, # minimum number of days with valid data per month
    'parametric_target_cols': target_cols,
    'all_station_ids': daymet_concurrent_stations,
    'baseline_distribution_folder': baseline_distribution_folder,
    'delta': 0.001, # maximum uncertainty (by KL divergence) added to the predicted PMF by the uniform mixture ratio
    'regularization_type': regularization_type, # use 'kde' or 'discrete'.  if discrete, bitrate must be specified
    'bitrate': bitrate,
    'complete_year_stats': complete_year_stats,
    'year_type': 'hydrological',  # 'calendar' or 'hydrological'
    'zero_flow_threshold': 1e-4,  # threshold below which flow is indistinguishable from zero
    'global_min_uar': global_min_uar,
    'global_max_uar': global_max_uar,
}

context = FDCEstimationContext(**input_data)

In [None]:
# Driver cell: build and save one diagnostic HTML page per station that has
# processed kNN results.
# NOTE: `station_ids`, `lstm_result_stns`, `context`, `bitrate` and
# `regularization_type` come from earlier cells.

# stations present in both the attribute table and the LSTM results
common_stations = list(set(station_ids) & set(lstm_result_stns))
stations_to_process = ['08AA008', '12090500', '12102190', '12115700', '12115800',
       '12157025', '12201960', '10CD003', '10CD004', '12447383',
       '08HB075', '10ED009', '08HA026', '12202310', '10CD005', '12202420',
       '12210900', '12193500', '12036650', '07BB003']

attr_df = attr_df[attr_df['official_id'].isin(stations_to_process)]

output_folder = BASE_DIR / 'data' / 'results' / 'lstm_plots'
output_folder.mkdir(parents=True, exist_ok=True)

# Station ids are the part of each result file name before the first
# underscore; dict.fromkeys drops repeats while keeping the listing order.
processed_kde_folder = 'data/results/fdc_estimation_results_kde/knn'
processed_kde_result_files = list(dict.fromkeys(
    e.split('_')[0] for e in os.listdir(processed_kde_folder)
))

excluded = ['12137800']
dam_sites = ['12398000', '12058800', '12143700', '12323760']
to_check = ['05BF018']
process_plots = True
if process_plots:
    # Build the lookups once rather than on every iteration.
    skip_stations = set(excluded) | set(dam_sites)
    known_ids = set(context.attr_gdf['official_id'].values)

    for stn in processed_kde_result_files:
        # Apply the cheap filters before constructing StationData.
        if stn in skip_stations or stn not in known_ids:
            continue

        if regularization_type == 'kde':
            output_fname = output_folder / f'{stn}_fdc_kde.html'
        else:
            output_fname = output_folder / f'{stn}_fdc_{bitrate:02d}_bits.html'

        station_data = StationData(context, stn)
        station_data.kde = KDEEstimator(station_data.log_edges_extended)

        layout = process_diagnostic_plot_layout(station_data)

        # save the plot to an HTML file
        output_file(output_fname, title=f'{stn} FDCs')
        save(layout)
        print(f'    Saved plot for {stn} to {output_fname}')
    