In [1]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
import xarray as xr

import json
from multiprocessing import Pool

import geopandas as gpd
from shapely.geometry import Point
import xyzservices.providers as xyz
from scipy.stats import linregress

from bokeh.plotting import figure, output_file, save, show
from bokeh.layouts import gridplot, row, column

# from bokeh.models import ColumnDataSource, LinearAxis, Range1d, HoverTool, Div
from bokeh.io import output_notebook
from bokeh.models import Div
# from bokeh.palettes import Sunset10, Vibrant7, Category20, Bokeh6, Bokeh7, Bokeh8, Greys256, Blues256

# from shapely.geometry import Polygon, Point
# from shapely.ops import unary_union
# from scipy.spatial import Voronoi

# from kde_estimator import KDEEstimator
from fdc_estimator_context import FDCEstimationContext
from fdc_data import StationData

import data_processing_functions as dpf

import xyzservices.providers as xyz
tiles = xyz['USGS']['USTopo']

output_notebook()

In [2]:
BASE_DIR = Path(os.getcwd())

attr_fpath = 'data/BCUB_watershed_attributes_updated_20250227.csv'
attr_df = pd.read_csv(attr_fpath, dtype={'Official_ID': str})
station_ids = sorted(attr_df['official_id'].unique().tolist())

# streamflow folder from (updated) HYSETS
HYSETS_DIR = Path('/home/danbot/code/common_data/HYSETS')
hs_df = pd.read_csv('data/HYSETS_watershed_properties.txt', sep=';')
hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]
hs_df.head(2)

Unnamed: 0,Watershed_ID,Source,Name,Official_ID,Centroid_Lat_deg_N,Centroid_Lon_deg_E,Drainage_Area_km2,Drainage_Area_GSIM_km2,Flag_GSIM_boundaries,Flag_Artificial_Boundaries,...,Land_Use_Wetland_frac,Land_Use_Water_frac,Land_Use_Urban_frac,Land_Use_Shrubs_frac,Land_Use_Crops_frac,Land_Use_Snow_Ice_frac,Flag_Land_Use_Extraction,Permeability_logk_m2,Porosity_frac,Flag_Subsoil_Extraction
846,847,HYDAT,CROWSNEST RIVER AT FRANK,05AA008,49.59732,-114.4106,402.6522,,0,0,...,0.0103,0.0065,0.0328,0.0785,0.0015,0.0002,1,-15.543306,0.170479,1
849,850,HYDAT,CASTLE RIVER NEAR BEAVER MINES,05AA022,49.48866,-114.1444,820.651,,0,0,...,0.0058,0.0023,0.0105,0.1156,0.0246,0.0,1,-15.929747,0.150196,1


In [3]:
# retrieve LSTM ensemble predictions
lstm_result_base_folder = Path('/home/danbot/code/neuralhydrology/data/')
results_revisions = ['20250514', '20250627']
lstm_result_files = os.listdir(lstm_result_base_folder / f'ensemble_results_{results_revisions[0]}')
lstm_result_stns = [e.split('_')[0] for e in lstm_result_files]

# filter for the common stations between BCUB region and LSTM-compatible (i.e. 1980-)
daymet_concurrent_stations = list(set(station_ids) & set(lstm_result_stns))
# assert '012414900' in daymet_concurrent_stations
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')


There are 723 monitored basins concurrent with LSTM ensemble results.


In [4]:

# create dicts for easier access to 'official_id': 'drainage area', geometry, regulation status
da_dict = attr_df[['official_id', 'drainage_area_km2']].copy().set_index('official_id').to_dict()['drainage_area_km2']

In [5]:
watershed_id_dict = {row['Watershed_ID']: row['Official_ID'] for _, row in hs_df.iterrows()}
# and the inverse
official_id_dict = {row['Official_ID']: row['Watershed_ID'] for _, row in hs_df.iterrows()}
# also for drainage areas
da_dict = {row['Official_ID']: row['Drainage_Area_km2'] for _, row in hs_df.iterrows()}

In [6]:
def load_and_filter_hysets_data(station_ids, hs_df):
    hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]

    # load the updated HYSETS data
    updated_filename = 'HYSETS_2023_update_QC_stations.nc'
    ds = xr.open_dataset(HYSETS_DIR / updated_filename)

    # Get valid IDs as a NumPy array
    selected_ids = hs_df['Watershed_ID'].values

    # Get boolean index where watershedID in selected_set
    # safely access watershedID as a variable first
    ws_ids = ds['watershedID'].data  # or .values if you prefer
    mask = np.isin(ws_ids, selected_ids)

    # Apply mask to data
    ds = ds.sel(watershed=mask)
    # Step 1: Promote 'watershedID' to a coordinate on the 'watershed' dimension
    ds = ds.assign_coords(watershedID=("watershed", ds["watershedID"].data))

    # Step 2: Set 'watershedID' as the index for the 'watershed' dimension
    return ds.set_index(watershed="watershedID")


ds = load_and_filter_hysets_data(station_ids, hs_df)

In [7]:
def set_grid(global_min, global_max, n_grid_points=2**12):
    # self.baseline_log_grid = np.linspace(np.log(adjusted_min_uar), np.log(max_uar), self.n_grid_points)
    baseline_log_grid = np.linspace(
        global_min, global_max, n_grid_points
    )
    baseline_lin_grid = np.exp(baseline_log_grid)
    log_dx = np.gradient(baseline_log_grid)
    max_step_size = baseline_lin_grid[-1] - baseline_lin_grid[-2]
    # print(f'    max step size = {max_step_size:.1f} L/s/km^2 for n={n_grid_points:.1e} grid points')
    min_step_size = baseline_lin_grid[1] - baseline_lin_grid[0]
    # print(f'    min step size = {min_step_size:.2e} L/s/km^2 for n={n_grid_points:.1e} grid points')
    return baseline_lin_grid, baseline_log_grid, log_dx

_, baseline_log_grid, log_dx = set_grid(np.log(6e-6), np.log(5e4), n_grid_points=2**12)
print(baseline_log_grid[0], baseline_log_grid[-1], log_dx[0], log_dx[-1])

-12.02375108873622 10.819778284410283 0.005578395451317775 0.005578395451315998


In [17]:
# load the complete years previously processed
import json
with open('data/complete_years.json', 'r') as f:
    complete_year_dict = json.load(f)

In [37]:
from collections import defaultdict
from kde_estimator import KDEEstimator
from scipy.stats import wasserstein_distance
from jax import numpy as jnp

def compute_kl_divergence(p, q):
    """Compute the KL divergence between two probability distributions."""
    mask = (p > 0) & (q > 0)  # Avoid log(0)
    return jnp.sum(jnp.where(mask, p * jnp.log(p / q), 0.0))


def compute_emd(p, q, baseline_log_grid):
    assert np.isclose(np.sum(p), 1, atol=1e-3), f'sum P = {np.sum(p)}'
    assert np.all(q >= 0), f'min q_i < 0: {np.min(q)}'
    linear_grid = np.exp(baseline_log_grid)
    emd = wasserstein_distance(linear_grid, linear_grid, p, q)
    return float(round(emd, 4))#, {'bias': None, 'unsupported_mass': None, 'pct_of_signal': None}


def compute_nse(obs, sim):
    """Compute the Nash-Sutcliffe Efficiency (NSE) between observed and simulated values."""
    assert not np.isnan(obs).any(), f'NaN values in obs: {obs}'
    assert not np.isnan(sim).any(), f'NaN values in sim: {sim}'
    assert (obs >= 0).all(), f'Negative values in obs: {obs}'
    assert (sim >= 0).all(), f'Negative values in sim: {sim}'
    # Compute the NSE
    numerator = np.sum((obs - sim) ** 2)
    denominator = np.sum((obs - obs.mean()) ** 2)
    nse = 1 - (numerator / denominator)
    return nse


def compute_relative_error(obs, sim):
    """Compute the relative error between observed and simulated values."""
    assert not np.isnan(obs).any(), f'NaN values in obs: {obs}'
    assert not np.isnan(sim).any(), f'NaN values in sim: {sim}'
    assert (obs >= 0).all(), f'Negative values in obs: {obs}'
    assert (sim >= 0).all(), f'Negative values in sim: {sim}'
    # Compute the relative error
    return (obs - sim) / obs


def compute_RMSE(obs, sim):
    """Compute the Root Mean Square Error (RMSE) between observed and simulated values."""
    assert not np.isnan(obs).any(), f'NaN values in obs: {obs}'
    assert not np.isnan(sim).any(), f'NaN values in sim: {sim}'
    assert (obs >= 0).all(), f'Negative values in obs: {obs}'
    assert (sim >= 0).all(), f'Negative values in sim: {sim}'
    # Compute the RMSE
    return np.sqrt(np.mean((obs - sim) ** 2))


def compute_KGE(obs, sim):
    """Compute the Kling-Gupta Efficiency (KGE) between observed and simulated values."""
    assert not np.isnan(obs).any(), f"NaN values in obs: {obs}"
    assert not np.isnan(sim).any(), f"NaN values in sim: {sim}"
    assert (obs >= 0).all(), f"Negative values in obs: {obs}"
    assert (sim >= 0).all(), f"Negative values in sim: {sim}"
    if obs.size == 0:
        return np.nan
    # Compute the KGE
    r = np.corrcoef(obs, sim)[0, 1]
    alpha = sim.mean() / obs.mean()
    beta = sim.std() / obs.std()
    kge = 1 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)
    return kge


def evaluate_fdc_metrics_from_pmf(pmf_est, baseline_pmf, baseline_log_grid):
    """
    Evaluate RMSE, relative error, NSE, and KGE between two FDCs represented by discrete PMFs.
    Note these are evaluated over the log_grid which is set in the context.

    Parameters
    ----------
    pmf_est : np.ndarray
        Discrete PMF representing the estimated FDC, over `log_grid`.
    log_grid : np.ndarray
        Grid of log-transformed flow values corresponding to PMF bins.

    Returns
    -------
    dict
        Dictionary of RMSE, RelativeError, NSE, and KGE computed over p=1,...,99 quantiles.
    """
    assert (
        len(baseline_pmf) == len(pmf_est) == len(baseline_log_grid)
    ), "Array length mismatch"

    # Convert log flow grid back to linear runoff space
    linear_grid = np.exp(baseline_log_grid)

    # Compute CDFs
    cdf_true = np.cumsum(baseline_pmf)
    cdf_true /= cdf_true[-1]
    cdf_est = np.cumsum(pmf_est)
    cdf_est /= cdf_est[-1]
    
    assert np.isfinite(cdf_true).all(), "Non-finite values in cdf_true"
    assert np.diff(cdf_true).sum() > 0, "cdf_true has no spread"

    # Percentiles (1 to 99)
    probs = np.linspace(0.01, 0.99, 99)

    # Interpolate inverse CDF (flow values at given probabilities)
    q_true = np.interp(
        probs, cdf_true, linear_grid, left=linear_grid[0], right=linear_grid[-1]
    )
    q_est = np.interp(
        probs, cdf_est, linear_grid, left=linear_grid[0], right=linear_grid[-1]
    )
    assert np.all(q_true > 0), "Zero or negative values in q_true — invalid for relative error"
    assert np.all(q_est > 0), "Zero or negative values in q_est — unexpected for flow"

    # Metrics
    rmse = np.sqrt(np.mean((q_true - q_est) ** 2))
    rel_error = np.mean(np.abs((q_est - q_true) / q_true))
    nse = compute_nse(q_true, q_est)
    kge = compute_KGE(q_true, q_est)

    kld = compute_kl_divergence(baseline_pmf, pmf_est)
    emd = compute_emd(baseline_pmf, pmf_est, baseline_log_grid)

    return {
        "FDC_RMSE": float(rmse), 
        "FDC_RelativeError": float(rel_error), 
        "FDC_NSE": float(nse), 
        "FDC_KGE": float(kge),
        "FDC_KLD": float(kld),
        "FDC_EMD": float(emd),
    }


def retrieve_timeseries_discharge(stn, ds):
    watershed_id = official_id_dict[stn]
    df = ds['discharge'].sel(watershed=str(watershed_id)).to_dataframe(name='discharge').reset_index()
    df = df.set_index('time')[['discharge']]
    df.dropna(inplace=True)
    # clip minimum flow to 1e-4
    df['discharge'] = np.clip(df['discharge'], 1e-4, None)
    df.rename(columns={'discharge': stn}, inplace=True)
    df[f'{stn}_uar'] = 1000 * df[stn] / da_dict[stn]
    return df




def compare_results_and_input(stn, sim_df, ds):
    """Compare the input streamflow timeseries with the observed streamflow timeseries.
    Check that the dates in the output match the common dates between Daymet and the input data.
    """
    input_df = retrieve_timeseries_discharge(stn, ds)
    # clip the 'discharge' column to 1e-4, convert to unit area runoff (L/s/km2), and take the log
    input_df = input_df[input_df.index >= '1980-01-01']

    df = pd.concat([input_df, sim_df], axis=1, join='inner')
    df.dropna(inplace=True)

    df['streamflow_obs'] = np.exp(df['streamflow_obs'])
    sim_cols = [c for c in sim_df.columns if c.startswith('streamflow_sim')]
    df[sim_cols] = np.exp(df[sim_cols])
    # assert that the 'log_obs' and the 'streamflow_obs' columns are approximately equal

    # set tolerance in the order of 1 L/s/km2
    if not np.allclose(df[f'{stn}_uar'], df['streamflow_obs'], atol=1): 
        max_diff = np.abs(df[f'{stn}_uar'] - df['streamflow_obs']).max()
        print(f'    Warning: {stn} has a max difference of {max_diff:.2f} between the input and output streamflow timeseries.')
    return df


def plot_ensemble_results(stn, folder):
    """Plot the ensemble results for a given station."""
    p = figure(title=f'Ensemble results for {stn}', x_axis_type='datetime', 
               y_axis_type='log', width=800, height=400)

    for date, clr in zip(['20250514', '20250627'], ['black', 'red']):
        fpath = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        df = pd.read_csv(fpath)
        df.rename(columns={'Unnamed: 0': 'time'}, inplace=True)
        df['time'] = pd.to_datetime(df['time'])
        df.set_index('time', inplace=True)
        df = np.exp(df)
        if 'streamflow_obs' in df.columns:
            p.line(df.index, df['streamflow_obs'], color=clr, legend_label=f'{date} Obs', line_width=2)
        
        sim_cols = [c for c in df.columns if c.startswith('streamflow_sim')]
        mean_sim = df[sim_cols].mean(axis=1)
        # compute the 5% quantiles on the simulation columns
        lb = df[sim_cols].quantile(0.05, axis=1)
        ub = df[sim_cols].quantile(0.95, axis=1)

        p.varea(df.index, lb, ub, color=clr, alpha=0.2, legend_label=f'{date} 90% CI')
        p.line(df.index, mean_sim, color=clr, legend_label=f'{date} Mean', line_dash='dashed', line_width=2)

    p.legend.location = 'top_left'
    p.xaxis.axis_label = 'Time'
    p.yaxis.axis_label = 'Streamflow (L/s/km2)'
    p.legend.click_policy= 'hide'
    return p


def filter_by_complete_years(stn, folder):
    all_dfs = []
    for date, clr in zip(['20250514', '20250627'], ['black', 'red']):
        fpath = folder / f'ensemble_results_{date}' / f'{stn}_ensemble.csv'
        if not os.path.exists(fpath):
            return pd.DataFrame()
        df = pd.read_csv(fpath)
        df.rename(columns={'Unnamed: 0': 'time'}, inplace=True)
        df['time'] = pd.to_datetime(df['time'])
        df.set_index('time', inplace=True)
        df.columns = [f'{c}_{date}' for c in df.columns]
        df = np.exp(df)
        all_dfs.append(df)
    result = pd.concat(all_dfs, axis=1, join='inner')
    result = result.dropna(how='any', axis=0)
    complete_years = complete_year_dict.get(stn, None).get('complete_years', [])
    print(f'    Found {len(complete_years)} complete years for {stn}: {complete_years}')
    return result[result.index.year.isin(complete_years)]


def get_original_timeseries(stn, ds):
    """Retrieve the original timeseries for a given station."""
    watershed_id = official_id_dict[stn]
    df = ds['discharge'].sel(watershed=str(watershed_id)).to_dataframe(name='discharge').reset_index()
    df = df.set_index('time')[['discharge']]
    df.dropna(inplace=True)
    # clip minimum flow to 1e-4
    df['discharge'] = np.clip(df['discharge'], 1e-4, None)
    df.rename(columns={'discharge': stn}, inplace=True)
    df[f'{stn}_uar'] = 1000 * df[stn] / da_dict[stn]
    return df


def compute_afdcs_from_sorted_flows(df, years, da, date):
    """Compute AFDCs from sorted daily flows, rather than PDFs."""
    
    obs_cols = [c for c in df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
    assert len(obs_cols) == 1, f'Expected one observed column, found {len(obs_cols)}'
    sim_cols = [c for c in df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    afdc_obs, afdc_sim = [], []

    for year in years:
        year_df = df[df.index.year == year]

        # Observed
        obs_values = year_df[obs_cols[0]].dropna().values
        if len(obs_values) > 0:
            sorted_obs = np.sort(obs_values)[::-1]  # descending
            afdc_obs.append(pd.Series(sorted_obs, name=f"{year}_obs"))

        # Simulated ensemble mean
        sim_ensemble = year_df[sim_cols].dropna(how='all')  # drop rows with all NaNs
        if not sim_ensemble.empty:
            sim_mean = sim_ensemble.mean(axis=1).dropna().values
            sorted_sim = np.sort(sim_mean)[::-1]
            afdc_sim.append(pd.Series(sorted_sim, name=f"{year}_sim"))

    # Align lengths: trim to shortest year
    min_len = min(len(s) for s in afdc_obs + afdc_sim)
    afdc_obs_trimmed = [s.iloc[:min_len].reset_index(drop=True) for s in afdc_obs]
    afdc_sim_trimmed = [s.iloc[:min_len].reset_index(drop=True) for s in afdc_sim]

    # Combine into DataFrames
    obs_df = pd.concat(afdc_obs_trimmed, axis=1)
    sim_df = pd.concat(afdc_sim_trimmed, axis=1)

    # Compute percentile summary
    afdc_summary = pd.DataFrame(index=np.arange(1, min_len + 1))
    afdc_summary[f'AFDC50_obs_{date}'] = obs_df.median(axis=1)
    afdc_summary[f'AFDC10_obs_{date}'] = obs_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_obs_{date}'] = obs_df.quantile(0.90, axis=1)

    afdc_summary[f'AFDC50_sim_{date}'] = sim_df.median(axis=1)
    afdc_summary[f'AFDC10_sim_{date}'] = sim_df.quantile(0.10, axis=1)
    afdc_summary[f'AFDC90_sim_{date}'] = sim_df.quantile(0.90, axis=1)

    afdc_summary.index.name = 'Rank'
    return afdc_summary


def compute_ensemble_pmfs(df, sim_cols, kde, da):
    """Compute the frequency mean PMF for the simulated ensemble."""
    sim_ensemble_pmfs = []
    for sim_col in sim_cols:
        sim_vals = df[sim_col].dropna().values
        assert len(sim_vals) > 0, f'No valid values found for {sim_col}'
        sim_pmf, _ = kde.compute(sim_vals, da=da)
        sim_ensemble_pmfs.append(pd.Series(sim_pmf, index=baseline_log_grid, name=sim_col))
    # concatenate all PMFs and compute the mean
    return pd.concat(sim_ensemble_pmfs, axis=1)


In [38]:
# import data_processing_functions as dpf
# from concurrent.futures import ThreadPoolExecutor
import jax
import jax.numpy as jnp
import numpy as np

# ---------- Kernel Functions ----------

@jax.jit
def gaussian_kernel(u):
    return jnp.exp(-0.5 * u**2) / jnp.sqrt(2 * jnp.pi)

@jax.jit
def epanechnikov_kernel(u):
    return jnp.where(jnp.abs(u) <= 1, 0.75 * (1 - u**2), 0.0)

@jax.jit
def top_hat_kernel(u):
    return jnp.where(jnp.abs(u) <= 1, 0.5, 0.0)

# ---------- Bandwidth Strategies ----------
def silverman_bandwidth(log_data: jnp.ndarray) -> float:
    q75, q25 = jnp.percentile(log_data, jnp.array((75, 25)))
    stdev = jnp.std(log_data)
    A = jnp.min(jnp.array([stdev, (q75 - q25) / 1.34]))
    return 1.06 * A / log_data.shape[0] ** 0.2


def measurement_error_bandwidth_function(x: jnp.ndarray) -> jnp.ndarray:
    error_points = jnp.array([1e-4, 1e-3, 1e-2, 1e-1, 1., 1e1, 1e2, 1e3, 1e4, 1e5])
    error_values = jnp.array([1.0, 0.5, 0.2, 0.1, 0.1, 0.1, 0.1, 0.15, 0.2, 0.25])
    return jnp.interp(x, error_points, error_values, left=1.0, right=0.25)


def adaptive_bandwidths(uar: jnp.ndarray, da: float) -> jnp.ndarray:
    flow_data = uar * da / 1000
    unique_q = jnp.unique(flow_data)
    
    # compute the measurement error informed bandwidth
    # units must be volumetric flow
    error_model = measurement_error_bandwidth_function(unique_q)
    unique_UAR = (1000 / da) * unique_q
    upper_err_UAR = unique_UAR * (1 + error_model)
    err_widths_UAR = jnp.log(upper_err_UAR) - jnp.log(unique_UAR)

    # compute the basic Silverman bandwidth
    # silverman_bw = silverman_bandwidth(jnp.log(unique_UAR))

    # if there are not enough unique values, add a small amount of noise to the data
    if len(unique_UAR) < 2:
        print(f'    not enough unique values in runoff data ({len(unique_UAR)}), adding noise to the data according to the measurement error model.')
        noise_bounds = (unique_UAR * (1 - error_model), unique_UAR * (1 + error_model))
        flow_data += np.random.uniform(*noise_bounds, size=flow_data.shape)
        unique_q = jnp.unique(flow_data)
        unique_UAR = (1000 / da) * unique_q

    # compute the log midpoints and bandwidths to address the issue
    # of sparse data points in the log space
    log_midpoints = jnp.log((unique_UAR[:-1] + unique_UAR[1:]) / 2)
    left_mirror = unique_UAR[0] - (log_midpoints[0] - unique_UAR[0])
    right_mirror = unique_UAR[-1] + (unique_UAR[-1] - log_midpoints[-1])
    log_midpoints = jnp.concatenate((jnp.array([left_mirror]), log_midpoints, jnp.array([right_mirror])))
    log_diffs = jnp.diff(log_midpoints)  / 2 / 1.15

    bw_vals = jnp.where(log_diffs > err_widths_UAR, log_diffs, err_widths_UAR)
    idx = jnp.searchsorted(unique_UAR, uar)
    return bw_vals[idx]


def kde_kernel(log_data, bw_values, log_grid):
    H = bw_values[:, None]  # (N, 1)
    U = (log_grid[None, :] - log_data) / H  # (N, M)
    K = jnp.exp(-0.5 * U**2) / (H * jnp.sqrt(2 * jnp.pi))
    return K.sum(axis=0) / log_data.shape[0]


class KDEEstimator:
    """
    Adaptive kernel density estimator using a measurement-error-informed bandwidth.

    Attributes
    ----------
    log_grid : jnp.ndarray
        Grid in log space over which to evaluate the KDE.
    dx : jnp.ndarray
        Spacing between grid points (gradient of log_grid).
    cache : dict
        Optional cache to store previously computed KDE results.
    """
    def __init__(self, log_grid, dx, cache=None):
        self.log_grid = jnp.asarray(log_grid, dtype=jnp.float32)
        self.dx = jnp.asarray(dx, dtype=jnp.float32)


    def compute(self, uar_data, da):
        """Compute the adaptive KDE and PMF for given unit area runoff data."""
        uar_data = jnp.asarray(uar_data)
        da = float(da)

        bw_values = adaptive_bandwidths(uar_data, da)
        log_data = jnp.log(uar_data)[:, None]
        pdf = kde_kernel(log_data, bw_values, self.log_grid)

        # Normalize PDF
        pdf /= jnp.trapezoid(pdf, x=self.log_grid)

        # Convert to PMF
        pmf = pdf * self.dx
        pmf /= jnp.sum(pmf)
        
        return pmf, pdf

In [39]:
def plot_observed_and_simulated_pdf(stn, pmf_dfs, og_df, date, pdf_plots=[]):
    """Plot the observed and simulated PDFs for a given station."""

    baseline_lin_grid = np.exp(baseline_log_grid)
    if date == '20250514':
        title = f'{stn}: LSTM PDFs Mean NSE Objective'
    elif date == '20250627':
        title = f'{stn}: LSTM PDFs 95% Quantile Objective'
    if len(pdf_plots) > 0:
        p = figure(title=title, x_axis_type='log',
            width=800, height=350, x_range=pdf_plots[0].x_range,
            y_range=pdf_plots[0].y_range)
    else:
        p = figure(title=title, x_axis_type='log',
            width=800, height=350)

    # plot the observed values as quad glyphs
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    observed_log_vals = np.log(observed_vals)
    min_q, max_q = observed_log_vals.min(), observed_log_vals.max()
    obs_log_dx = np.linspace(min_q - 0.1, max_q + 0.1, num=128)
    hist, edges = np.histogram(observed_log_vals, bins=obs_log_dx, density=True)
    edges = np.exp(edges)  # convert edges back to linear space
    # convert to probbility mass function (PMF)
    p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:],
            fill_color='dodgerblue', alpha=0.5, legend_label='Observed')
    
    df = pmf_dfs[date].copy()
    for i in range(10):
        sim_col = f'streamflow_sim_{i}_{date}'
        if sim_col in df.columns:
            p.line(baseline_lin_grid, df[sim_col].values / log_dx, color='grey', alpha=0.5, 
                    legend_label=f'LSTM Ensemble')

    # convert pmfs to pdfs
    p.line(baseline_lin_grid, df[f'POR_obs_{date}'].values / log_dx, 
           color='black', line_width=2.5, legend_label=f'POR Observed', line_dash='dotted')
    p.line(baseline_lin_grid, df[f'POR_sim_timeEnsemble_{date}'].values / log_dx, 
           line_width=2.5, color='green', legend_label=f'timeEnsemble')
    p.line(baseline_lin_grid, df[f'POR_sim_freqEnsemble_{date}'].values / log_dx, 
           line_width=2.5, color='red', legend_label=f'freqEnsemble')

    p.xaxis.axis_label = 'Log Unit Area Runoff (L/s/km2)'
    p.xaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    p.yaxis.axis_label = 'Probability Density'
    p.legend.location = 'top_left'
    p.legend.click_policy = 'hide'
    p.add_layout(p.legend[0], 'right')
    return p


def plot_observed_and_simulated_fdc(stn, lstm_df, og_df, date, fdc_plots=[]):
    """Plot the observed and simulated FDCs for a given station."""

    # baseline_lin_grid = np.exp(baseline_log_grid)
    if date == '20250514':
        title = f'{stn}: LSTM FDCs Mean NSE Objective'
    elif date == '20250627':
        title = f'{stn}: LSTM FDCs 95% Quantile Objective'

    if len(fdc_plots) > 0:
        fdc_plot = figure(title=title, #x_axis_type='log',
            width=400, height=350, x_range=fdc_plots[0].x_range,
            y_range=fdc_plots[0].y_range)
    else:
        fdc_plot = figure(title=title, #x_axis_type='log', 
                          width=800, height=350)
        
    # plot the observed duration curve
    pcts = np.linspace(0.01, 0.99, 99)
    observed_vals = og_df[f'{stn}_uar'].dropna().values
    obs_fdc = np.percentile(observed_vals, pcts * 100)[::-1]

    fdc_plot.line(pcts, obs_fdc, color='dodgerblue', line_width=2.5, legend_label=f'Observed')

    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected 10 simulated columns, found {len(sim_cols)}'

    sim_fdcs = pd.DataFrame(index=pcts)
    for i, sim_col in enumerate(sim_cols):
        sim_vals = lstm_df[sim_col].dropna().values
        sim_fdc = np.percentile(sim_vals, pcts * 100)[::-1]
        sim_fdcs[f'LSTM Simulation {i+1}'] = sim_fdc
        fdc_plot.line(pcts, sim_fdc, color='grey', alpha=0.5, 
                      legend_label=f'LSTM Simulation')
        
    # compute the temporal ensemble mean FDC
    temporal_mean_fdc = lstm_df[sim_cols].mean(axis=1).dropna().values
    fdc_plot.line(pcts, np.percentile(temporal_mean_fdc, pcts * 100)[::-1], 
                  color='green', line_width=2.5, 
                  legend_label=f'Temporal Ensemble Mean')
    
    # compute the frequency ensemble mean FDC
    freq_ensemble_fdc = sim_fdcs.mean(axis=1).values
    fdc_plot.line(pcts, freq_ensemble_fdc, color='red', line_width=2.5, 
                  legend_label=f'Frequency Ensemble Mean')
    
    fdc_plot.xaxis.axis_label = 'Exceedance Probability (%)'
    fdc_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    fdc_plot.legend.location = 'top_right'
    fdc_plot.legend.click_policy = 'hide'
    
    return fdc_plot

In [40]:
import sqlite3

def check_if_station_in_hydat(stn, conn):
    # Query to check if the station exists
    check_query = """
    SELECT STATION_NUMBER, STATION_NAME
    FROM STATIONS
    WHERE STATION_NUMBER = ?
    """
    # Run the query
    station_check = pd.read_sql_query(check_query, conn, params=(stn,))

    # Test if any result was returned
    if station_check.empty:
        return False
    else:
        return True
    
    
def query_data_symbols(conn):
    # Query all data symbols
    query = "SELECT SYMBOL_ID, SYMBOL_EN FROM DATA_SYMBOLS ORDER BY SYMBOL_ID"
    df = pd.read_sql_query(query, conn)
    return df.set_index('SYMBOL_ID')['SYMBOL_EN']


def reshape_hydat_wide(df):
    # First, ensure all FLOW_SYMBOL columns exist and are named correctly
    id_vars = ["STATION_NUMBER", "YEAR", "MONTH", "NO_DAYS"]
    
    # Melt flows
    flow_df = df.melt(id_vars=id_vars, 
                      value_vars=[f"FLOW{i}" for i in range(1, 32)],
                      var_name="day", value_name="flow")

    # Melt symbols
    sym_df = df.melt(id_vars=id_vars, 
                     value_vars=[f"FLOW_SYMBOL{i}" for i in range(1, 32)],
                     var_name="day", value_name="flow_symbol")

    # Extract day number
    flow_df["day"] = flow_df["day"].str.extract(r"(\d+)$").astype(int)
    sym_df["day"] = sym_df["day"].str.extract(r"(\d+)$").astype(int)

    # Merge on ID columns + day
    merged = pd.merge(flow_df, sym_df, on=id_vars + ["day"])

    # Construct date
    merged["date"] = pd.to_datetime(dict(year=merged["YEAR"], 
                                         month=merged["MONTH"], 
                                         day=merged["day"]), errors='coerce')

    # Filter out invalid days (e.g., day > NO_DAYS)
    merged = merged[merged["day"] <= merged["NO_DAYS"]]
    formatted_df = merged[["STATION_NUMBER", "date", "flow", "flow_symbol"]].dropna(subset=["flow"])
    formatted_df.set_index('date', inplace=True)
    return formatted_df


def query_hydat_database(stn):
    """Query the HYDAT database for a given station and date range."""
    hydat_path = Path('/home/danbot/code/common_data/HYDAT') / 'Hydat_20250415.sqlite3'
    # Connect to the database
    conn = sqlite3.connect(hydat_path)

    quality_symbols = query_data_symbols(conn)

    station_in_hydat = check_if_station_in_hydat(stn, conn)
    if station_in_hydat is False:
        print(f'Station {stn} not found in HYDAT database.')
        return pd.DataFrame(), quality_symbols
    
    base_columns = ["STATION_NUMBER", "YEAR", "MONTH"]
    flow_columns = [f"FLOW{i}, FLOW_SYMBOL{i}" for i in range(1, 32)]
    end_columns = ["NO_DAYS"]

    all_columns = base_columns + flow_columns + end_columns
    column_str = ",\n    ".join(all_columns)

    query = f"""
    SELECT
        {column_str}
    FROM DLY_FLOWS
    WHERE STATION_NUMBER = ?
    ORDER BY YEAR, MONTH;
    """
    df = pd.read_sql_query(query, conn, params=(stn,))

    if df.empty:
        print(f'No data found for {stn} in HYDAT.')
        return pd.DataFrame(), quality_symbols
    df = reshape_hydat_wide(df)
    return df, quality_symbols


def find_symbol_segments(symbol_df, target_symbol):
    """Return (start, end) date pairs for each continuous period of target_symbol."""
    # Filter for matching symbol only
    mask = (symbol_df['flow_symbol'] == target_symbol)
    dates = symbol_df['flow_symbol'].index[mask]

    if dates.empty:
        return []

    # Compute gaps in days between successive dates
    gaps = dates.to_series().diff().gt(pd.Timedelta(days=1)).fillna(True)

    # Group by contiguous regions (cumsum creates a new group after each gap)
    group_ids = gaps.cumsum()

    # Group by group ID and extract start and end of each contiguous block
    segments = [(group.min(), group.max()) for _, group in dates.to_series().groupby(group_ids)]

    return segments

In [48]:
def process_FDCs(df, stn, og_df, output_folder):

    kde = KDEEstimator(baseline_log_grid, log_dx)
    print(f'    Processing FDCs for {stn}')
    por_metrics = {}
    pmf_dfs = {}
    years = []
    for date in ['20250514', '20250627']:
        pmf_columns = []
        years = df.index.year.unique()
        da = da_dict[stn]

        pmf_fpath = f'data/results/lstm_pmfs/POR_{stn}_pmfs_{len(years)}_years_{date}.csv'
        # if os.path.exists(pmf_fpath):
        #     print(f'    PMFs for {stn} already exist, skipping.')
        #     continue

        # compute observed POR PMFs
        obs_cols = [c for c in df.columns if c.startswith('streamflow_obs') and c.endswith(date)]
        assert len(obs_cols) == 1, f'Expected exactly one observed column for {stn} on {date}, found {len(obs_cols)}'
        por_obs_vals = df[obs_cols[0]].dropna().values
        # print(f'{stn} POR obs: {min(por_obs_vals):.2f} - {max(por_obs_vals):.2f}')
        por_obs_pmf, por_obs_pdf = kde.compute(por_obs_vals, da=da)
        pmf_columns.append(pd.Series(por_obs_pmf, index=baseline_log_grid, name=f'POR_obs_{date}'))

        # compute simulated POR PMFs
        sim_cols = [c for c in df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
        # temporal mean of the simulated ensemble
        sim_vals = df[sim_cols].mean(axis=1).dropna().values
        sim_pmf, sim_pdf = kde.compute(sim_vals, da=da) # temporal mean ensemble PMF
        pmf_columns.append(pd.Series(sim_pmf, index=baseline_log_grid, name=f'POR_sim_timeEnsemble_{date}'))

        frequency_sim_pmfs = compute_ensemble_pmfs(df, sim_cols, kde, da)
        mean_pmf = frequency_sim_pmfs.mean(axis=1).rename('sim_freqEnsemble_mean')
        mean_pmf /= mean_pmf.sum()  # renormalize the PMF
        pmf_columns.append(pd.Series(mean_pmf, index=baseline_log_grid, name=f'POR_sim_freqEnsemble_{date}'))

        og_vals = og_df[f'{stn}_uar'].dropna().values
        # print(f'{stn}_uar: {min(og_vals):.2f} - {max(og_vals):.2f}')
        og_pmf, og_pdf = kde.compute(og_vals, da=da)
        pmf_columns.append(pd.Series(og_pmf, index=baseline_log_grid, name=f'{stn}_uar'))
        por_pmfs = pd.concat(pmf_columns, axis=1)    # do the same for the original timeseries
        pmfs = pd.concat([por_pmfs, frequency_sim_pmfs], axis=1)  # initialize PMFs DataFrame with observed PMFs
        pmfs.index.name = 'log_uar'  # set the index name for clarity
        # save the PMFs to a CSV file
        pmfs.to_csv(pmf_fpath)
        pmf_dfs[date] = pmfs

        # evaluate metrics
        for col in [f'POR_sim_timeEnsemble_{date}', f'POR_sim_freqEnsemble_{date}']:
            pmf = pmfs[col].values
            por_metrics[col] = evaluate_fdc_metrics_from_pmf(pmf, por_obs_pmf, baseline_log_grid)

        for col in frequency_sim_pmfs.columns:
            pmf = frequency_sim_pmfs[col].values
            por_metrics[col] = evaluate_fdc_metrics_from_pmf(pmf, por_obs_pmf, baseline_log_grid)

    mdf = pd.DataFrame(por_metrics).T
    # save to csv
    metric_fpath = f'data/results/lstm_metrics/{stn}_{len(years)}_years_metrics.csv'
    mdf.to_csv(metric_fpath)
    return mdf, pmf_dfs


In [49]:
def plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col):
    symbol_dict = quality_symbols.to_dict()
    symbol_colors = {
        'B': 'dodgerblue',  # Baseflow
        'D': 'firebrick',  # Dry weather flow
        'E': 'orange'  # Estimated
    }
    df['flow_symbol'] = hydat_df['flow_symbol'].reindex(df.index, method=None)
    uar_cols = obs_col + [c for c in df.columns if c.startswith('streamflow_sim')]

    for symbol in ['B', 'D', 'E']:
        description = symbol_dict.get(symbol, {})
        color = symbol_colors.get(symbol, 'gray')
        n_symbols = df['flow_symbol'].eq(symbol).sum()
        if n_symbols == 0:
            continue

        segments = find_symbol_segments(df[['flow_symbol']].copy(), symbol)

        for start, end in segments:
            runoff_plot.varea(
                x=pd.date_range(start, end),
                y1=0.98 * df[uar_cols].min().min(),  # get the min of the dataframe for the lower bound
                y2=1.02 * df[uar_cols].max().max(),  # a bit above max for visibility
                fill_color=color, fill_alpha=0.3,
                legend_label=f"{description} ({symbol})"
            )
    
    df['flow'] = hydat_df['flow'].reindex(df.index, method=None)
    df['hydat_uar'] = 1000 * df['flow'] / da_dict[stn]  # convert to unit area runoff (L/s/km2)
    runoff_plot.line(df.index, df['hydat_uar'],
                     color='dodgerblue', legend_label='HYDAT UAR', 
                     line_width=2, line_dash='dotted')
    return runoff_plot

In [50]:
def plot_runoff_timeseries(stn, lstm_df, date):

    obs_col = [c for c in lstm_df.columns if '_obs_' in c and c.endswith(date)]
    assert len(obs_col) == 1, f'Expected exactly one observed column for {stn}, found {len(obs_col)} {obs_col}'
    sim_cols = [c for c in lstm_df.columns if c.startswith('streamflow_sim') and c.endswith(date)]
    assert len(sim_cols) == 10, f'Expected ten simulation columns for {stn}, found {len(sim_cols)}'
    # plot th time series of the observed values
    runoff_plot = figure(title=f'{stn} Observed Unit Area Runoff', x_axis_type='datetime',
                         width=800, height=350, y_axis_type='log')
    
    hydat_df, quality_symbols = query_hydat_database(stn)

    # reindex to daily frequency and keep nans
    df = lstm_df.copy().reindex(pd.date_range(start=lstm_df.index.min(), end=lstm_df.index.max(), freq='D'))

    if not hydat_df.empty:
        runoff_plot = plot_quality_flag_periods(stn, df, hydat_df, quality_symbols, runoff_plot, obs_col)
    
    runoff_plot.line(df.index, df[obs_col[0]], color='dodgerblue',
                     legend_label='Observed UAR', line_width=2.)
    for col in sim_cols:
        runoff_plot.line(df.index, df[col], color='grey', alpha=0.5,
                         legend_label=f'LSTM ensemble')
    # compute the temporal mean of the simulated ensemble
    mean_sim = df[sim_cols].mean(axis=1)
    runoff_plot.line(df.index, mean_sim, color='black', legend_label='Ensemble Mean', line_width=2, line_dash='dashed')
    runoff_plot.xaxis.axis_label = 'Date'
    runoff_plot.yaxis.axis_label = 'Unit Area Runoff (L/s/km2)'
    runoff_plot.legend.location = 'top_left'
    runoff_plot.legend.click_policy = 'hide'
    # runoff_plot.add_layout(runoff_plot.legend[0], 'right')
    runoff_plot.legend.background_fill_alpha = 0.65
    return runoff_plot


def format_metrics_table(metric_df, stn, date):
    metric_df.index.name = 'series'
    metric_df = metric_df.reset_index()
    metric_df['date'] = metric_df.apply(lambda row: row['series'].split('_')[-1], axis=1)

    # lines = ['solid', 'dashed']
    metric_cols = [c for c in metric_df.columns if c.startswith('FDC_')]
    filtered_metrics = metric_df[metric_df['date'] == date].copy()
    
    time_df = filtered_metrics[filtered_metrics['series'].str.contains('timeEnsemble')].copy()
    freq_df = filtered_metrics[filtered_metrics['series'].str.contains('freqEnsemble')].copy()
    sim_df = filtered_metrics[filtered_metrics['series'].str.startswith('streamflow_sim_')].copy()
    sim_vals = pd.DataFrame(np.percentile(sim_df[metric_cols].values, [5, 50, 95], axis=0), 
                            index=['5%', '50%', '95%'], columns=metric_cols)   

    metric_table_df = pd.concat([time_df[metric_cols], freq_df[metric_cols], sim_vals], axis=0)
    metric_table_df[date] = ['timeEnsemble', 'freqEnsemble', '5%', '50%', '95%']
    metric_table_df.set_index(date, inplace=True)
    metric_table_df.columns = [c.split('_')[-1] for c in metric_table_df.columns]
    metric_table_df.columns = ['RE' if c == 'RelativeError' else c for c in metric_table_df.columns]
    # format the metrics table for display in a Div
    # metric_table_df = metric_table_df.round(3)
    metric_table_html = metric_table_df.style.format(precision=3).set_properties(**{'padding': '3px'}).to_html(classes='table table-striped', border=0,
                                                justify='center', index=True)
    div = Div(text=metric_table_html, width=1000, height=200)
    return div

In [52]:
# retrieve LSTM ensemble predictions
# filter for the common stations
common_stations = list(set(station_ids) & set(lstm_result_stns))
print(f'There are {len(common_stations)} monitored basins with LSTM ensemble results.')
attr_df = attr_df[attr_df['official_id'].isin(common_stations)]

output_folder = BASE_DIR / 'data' / 'results' /  'lstm_plots'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# plots = []
process_fdcs = True
if process_fdcs:
    for stn in common_stations:
        og_df = get_original_timeseries(stn, ds)
        # print(og_df[og_df.index >= '1982-06-01'].head())
        lstm_ensemble_df = filter_by_complete_years(stn, lstm_result_base_folder)

        if lstm_ensemble_df.empty:
            print(f'No complete years found for {stn}. Skipping.')
            continue
        og_df = og_df[og_df.index.isin(lstm_ensemble_df.index)]
        mdf, pmf_dfs = process_FDCs(lstm_ensemble_df, stn, og_df, output_folder)

        dates = list(pmf_dfs.keys())
        pdf_plots, fdc_plots, metric_tables = [], [], []
        ts_plot = plot_runoff_timeseries(stn, lstm_ensemble_df, dates[-1])
        for date in dates:
            metric_table = format_metrics_table(mdf, stn, date)
            metric_tables.append(metric_table)
            pdf_plot = plot_observed_and_simulated_pdf(stn, pmf_dfs, og_df, date, pdf_plots=pdf_plots)
            pdf_plots.append(pdf_plot)
            fdc_plot = plot_observed_and_simulated_fdc(stn, lstm_ensemble_df, og_df, date, fdc_plots=fdc_plots)
            fdc_plots.append(fdc_plot)

        layout = column(
            row([ts_plot, fdc_plots[1]]), 
            row(pdf_plots[0], metric_tables[0]), 
            row(pdf_plots[1], metric_tables[1]), 
            )
        # save the plot to an HTML file
        # show(layout)
        output_fname = output_folder / f'{stn}_fdc.html'
        output_file(output_fname, title=f'{stn} FDCs')
        save(layout)
        print(f'    Saved plot for {stn} to {output_fname}')

There are 723 monitored basins with LSTM ensemble results.
    Found 36 complete years for 05DC012: [1985, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2020, 2021, 2022]
    Processing FDCs for 05DC012
    Saved plot for 05DC012 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/05DC012_fdc.html
    Found 18 complete years for 12323710: [2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023]
    Processing FDCs for 12323710
Station 12323710 not found in HYDAT database.
    Saved plot for 12323710 to /home/danbot/code/distribution_estimation/docs/notebooks/data/results/lstm_plots/12323710_fdc.html
    Found 22 complete years for 08LA008: [1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1972, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1

KeyboardInterrupt: 

[1950, 1951]

In [None]:
def map_metric_to_dict(metric_table, series):
    return {
        'RMSE': metric_table.loc[series, 'RMSE'],
        'NSE': metric_table.loc[series, 'NSE'],
        'KGE': metric_table.loc[series, 'KGE'],
        'RE': metric_table.loc[series, 'RE'],
        'KLD': metric_table.loc[series, 'KLD'],
        'EMD': metric_table.loc[series, 'EMD'],
    }

In [None]:
# create a plot of CDFs by metric across all stations
rmses, nses, kges, rel_errors, dkls, emds = [], [], [], [], [], []
freq_ensemble_vals, time_ensemble_vals, low_bound_vals, high_bound_vals, median_vals = {}, {}, {}, {}, {}
dates = ['20250514', '20250627']  # 20250514 use the latest date for the metrics
metrics_files = sorted(os.listdir(BASE_DIR / 'data' / 'lstm_metrics'))
metric_dict = defaultdict(dict)
for stn in common_stations:
    # files = [f for f in metrcs_files if f'_{stn}_' in f]
    metric_files = [f for f in metrics_files if f.startswith(stn) and f.endswith('years_metrics.csv')]
    n_years = int(metric_files[0].split('_')[1]) if metric_files else 0
    if len(metric_files) == 0:
        print(f'No metrics file found for {stn}, skipping.')
        continue
    assert len(metric_files) == 1, f'Expected exactly one metrics file for {stn}, found {len(metric_files)}'
    mdf = pd.read_csv(BASE_DIR / 'data' / 'lstm_metrics' / metric_files[0])
    metric_cols = [c for c in mdf.columns if c.startswith('FDC_')]
    mdf.rename({'Unnamed: 0': 'series'}, axis=1, inplace=True)
    mdf['date'] = mdf.apply(lambda row: row['series'].split('_')[-1], axis=1)
    # date = files[0].split('_')[-1].split('.')[0]  # extract date from the file name

    for date in dates:
        filtered_metrics = mdf[mdf['date'] == date].copy()
        time_df = filtered_metrics[filtered_metrics['series'].str.contains('timeEnsemble')].copy()
        freq_df = filtered_metrics[filtered_metrics['series'].str.contains('freqEnsemble')].copy()
        sim_df = filtered_metrics[filtered_metrics['series'].str.startswith('streamflow_sim_')].copy()
        sim_vals = pd.DataFrame(np.percentile(sim_df[metric_cols].values, [5, 50, 95], axis=0), 
                                    index=['5%', '50%', '95%'], columns=metric_cols)     
        
        metric_table_df = pd.concat([time_df[metric_cols], freq_df[metric_cols], sim_vals], axis=0)
        metric_table_df[date] = ['timeEnsemble', 'freqEnsemble', '5%', '50%', '95%']
        metric_table_df.set_index(date, inplace=True)
        metric_table_df.columns = [c.split('_')[-1] for c in metric_table_df.columns]
        metric_table_df.columns = ['RE' if c == 'RelativeError' else c for c in metric_table_df.columns]
        
        for s in metric_table_df.index:
            val_dict = map_metric_to_dict(metric_table_df, s)
            #  append the value to the same metric
            
            for m, value in val_dict.items():
                if m not in metric_dict[date]:
                    metric_dict[date][m] = {}
                if s not in metric_dict[date][m]:
                    metric_dict[date][m][s] = [float(value)]
                else:
                    metric_dict[date][m][s].append(float(value))

No metrics file found for 12102190, skipping.


In [None]:
def compute_empirical_cdf(data):
    """Compute the empirical CDF of the data."""
    sorted_data = np.sort(data)
    n = len(sorted_data)
    cdf = np.arange(1, n + 1) / n
    return sorted_data, cdf

In [None]:
cdict = {'timeEnsemble': 'green', 'freqEnsemble': 'red', '5%': 'blue', '50%': 'black', '95%': 'navy'}

In [None]:
plots = []
for m in ['RMSE', 'NSE', 'KGE', 'RE', 'KLD', 'EMD']:
    mplt = []
    for date in dates:
        data = metric_dict[date][m]
        x_label = m
        if m in ['NSE', 'KGE']:
            x_label = f'1 - {m}'
        if len(mplt) > 0:
            p = figure(title=f'{m}: {date}', x_axis_label=x_label, y_axis_label='Density', 
                        width=800, height=400, x_axis_type='log', 
                        x_range=mplt[0].x_range, y_range=mplt[0].y_range)
        else:
            p = figure(title=f'{m}: {date}', x_axis_label=x_label, y_axis_label='Density', 
                        width=800, height=400, x_axis_type='log')
        for s, vals in data.items():
            line = 'solid'
            legend_label = s
            if s.startswith('95%'):
                line = 'dashed'
                legend_label = f'Sim Ensemble (95%)'
            elif s.startswith('5%'):
                line = 'dotted'
                legend_label = f'Sim Ensemble (5%)'
            elif s.startswith('50%'):
                line='dashdot'
                legend_label = f'Sim Ensemble (50%)'
            if m in ['NSE', 'KGE']:
                vals = [1-v for v in vals]
            x, y = compute_empirical_cdf(np.array(vals))
            print(s, line, cdict[s])
            p.line(x=x, y=y, legend_label=legend_label, color=cdict[s], line_width=2.5, line_dash=line)
        p.legend.location='top_left'
        p.legend.click_policy = 'hide'
        plots.append(p)
        mplt.append(p)

timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqEnsemble solid red
5% dotted blue
50% dashdot black
95% dashed navy
timeEnsemble solid green
freqE

In [None]:
layout = gridplot(plots, ncols=2, width=700, height=400)
# show(layout)
# same the plot html
output_fname = BASE_DIR / 'data' / 'LSTM_ensemble_metrics.html'
output_file(output_fname, title='LSTM Ensemble Metrics')
save(layout)