In [1]:
import os
import pandas as pd
import numpy as np
import json
from time import time
import geopandas as gpd

import xgboost as xgb
# `xgb.config_context(...)` only builds a context manager; unless it is entered
# with a `with` statement it changes nothing.  `set_config` applies the
# verbosity to the global XGBoost configuration, which is what a bare
# top-level call is meant to do.
xgb.set_config(verbosity=2)
from scipy.stats import norm, laplace, genextreme

from collections import defaultdict
from multiprocessing import Pool, cpu_count

import jax
import jax.numpy as jnp

from kde_estimator import KDEEstimator
from fdc_estimator_context import FDCEstimationContext 
from fdc_data import StationData
from evaluation_metrics import EvaluationMetrics

import data_processing_functions as dpf

from pathlib import Path
BASE_DIR = os.getcwd()

In [2]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.layouts import gridplot

import xyzservices.providers as xyz
tiles = xyz['USGS']['USTopo']
output_notebook()


In [3]:
# Load the catchment attribute table (one row per monitored basin).
fname = 'catchment_attributes_with_runoff_stats.csv'
attr_df = pd.read_csv(
    os.path.join('data', fname),
    dtype={'official_id': str, 'drainage_area_km2': float},
)
# Work with lower-case column labels from here on.
attr_df.columns = [c.lower() for c in attr_df.columns]

# Derived attributes: log drainage area and the midpoint of tmin / tmax.
attr_df['log_drainage_area_km2'] = np.log(attr_df['drainage_area_km2'])
attr_df['tmean'] = (attr_df['tmin'] + attr_df['tmax']) / 2.0

# Station identifiers are kept as strings (see the dtype above).
station_ids = attr_df['official_id'].values

print(f'There are {len(station_ids)} monitored basins in the attribute set.')



There are 1098 monitored basins in the attribute set.


In [4]:
# streamflow folder from (updated) HYSETS
# NOTE(review): absolute, machine-specific path; it is not used in this cell.
# Consider moving it to a configurable data directory.
HYSETS_DIR = Path('/home/danbot/code/common_data/HYSETS')
# STREAMFLOW_DIR = HYSETS_DIR / 'streamflow'

# Read the HYSETS watershed properties table (semicolon-delimited) and keep
# 'Official_ID' as a string so it can be compared with `station_ids`.
hs_df = pd.read_csv('data/HYSETS_watershed_properties.txt', sep=';', dtype={'Official_ID': str})
# Keep only the watersheds whose id appears in the attribute set loaded above.
hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]
hs_df.head(2)

Unnamed: 0,Watershed_ID,Source,Name,Official_ID,Centroid_Lat_deg_N,Centroid_Lon_deg_E,Drainage_Area_km2,Drainage_Area_GSIM_km2,Flag_GSIM_boundaries,Flag_Artificial_Boundaries,...,Land_Use_Wetland_frac,Land_Use_Water_frac,Land_Use_Urban_frac,Land_Use_Shrubs_frac,Land_Use_Crops_frac,Land_Use_Snow_Ice_frac,Flag_Land_Use_Extraction,Permeability_logk_m2,Porosity_frac,Flag_Subsoil_Extraction
846,847,HYDAT,CROWSNEST RIVER AT FRANK,05AA008,49.59732,-114.4106,402.6522,,0,0,...,0.0103,0.0065,0.0328,0.0785,0.0015,0.0002,1,-15.543306,0.170479,1
849,850,HYDAT,CASTLE RIVER NEAR BEAVER MINES,05AA022,49.48866,-114.1444,820.651,,0,0,...,0.0058,0.0023,0.0105,0.1156,0.0246,0.0,1,-15.929747,0.150196,1


In [5]:
# Load the baseline PMFs written by the previous notebook; the first CSV
# column becomes the index and the remaining column labels are station ids.
pmf_path = Path.cwd() / 'data' / 'results' / 'baseline_distributions' / 'bcub_pmfs.csv'
pmf_df = pd.read_csv(pmf_path, index_col=0)
pmf_stations = pmf_df.columns

# Restrict the working station list to stations that also have a baseline PMF.
# Note: this rebinds `station_ids` from an array to a list.
station_ids = list(set(station_ids) & set(pmf_stations))
print(len(station_ids))

1097


In [6]:
# Retrieve the LSTM ensemble predictions.  The station id is taken to be the
# part of each result file name that precedes the first underscore.
# NOTE(review): absolute, machine-specific path.
lstm_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250514'
lstm_result_files = os.listdir(lstm_result_folder)
lstm_result_stns = [fn.split('_')[0] for fn in lstm_result_files]
assert '12414900' in lstm_result_stns

# Report every LSTM station id that has no exact match in `station_ids`.
for stn in lstm_result_stns:
    if stn in station_ids:
        continue
    # ids that end with the LSTM id hint at a dropped prefix / leading zeros
    ending_in = [sid for sid in station_ids if sid.endswith(stn)]
    if ending_in:
        print(stn, 'matches', ending_in)
    # try left-padding with zeros to 8 characters
    modified_stn = stn.zfill(8)
    if modified_stn not in station_ids:
        print(f'Warning: {stn} is in LSTM results but not in the station attributes.')
    else:
        print(f'Found modified station id: {modified_stn} for {stn}')

# Stations common to the attribute set, the LSTM results (i.e. 1980-) and the
# baseline PMFs.
daymet_concurrent_stations = list(set(station_ids) & set(lstm_result_stns) & set(pmf_stations))
print(f'There are {len(daymet_concurrent_stations)} monitored basins concurrent with LSTM ensemble results.')
print(f'There are {len(pmf_stations)} monitored basins with baseline PMFs.')

There are 723 monitored basins concurrent with LSTM ensemble results.
There are 1097 monitored basins with baseline PMFs.


In [7]:
# import updated catchment polygons
# poly_fpath = os.path.join(os.path.join('data', f'BCUB_watershed_attributes_updated_{rev_date}.csv'))
# catchment_gdf = pd.read_csv(poly_fpath)
# catchment
# catchment_gdf = catchment_gdf[catchment_gdf['Official_ID'].isin(station_ids)]
# print(len(catchment_gdf), 'catchments in the polygon set')

# import the license water extraction points
# dam_gdf = gpd.read_file('data/Dam_Points_20240103.gpkg')
# assert dam_gdf.crs == catchment_gdf.crs, "Catchment and dam geometries must have the same CRS"
# joined = gpd.sjoin(catchment_gdf, dam_gdf, how="inner", predicate="contains")
# Create a new boolean column 'contains_dam' in catchment_gdf.
# If a polygon's index appears in the joined result, it means it contains at least one point.
# regulated = joined['Official_ID'].values
# catchment_gdf["contains_dam"] = catchment_gdf['Official_ID'].apply(lambda x: x in regulated)
# n_regulated = catchment_gdf['contains_dam'].sum()
# print(f'{n_regulated}/{len(catchment_gdf)} catchments contain withdrawal licenses')

# # create dicts for easier access to 'official_id': 'drainage area', geometry, regulation status
# da_dict = attr_df[['official_id', 'drainage_area_km2']].copy().set_index('official_id').to_dict()['drainage_area_km2']
# dam_dict = catchment_gdf[['Official_ID', 'contains_dam']].copy().set_index('Official_ID').to_dict()['contains_dam']
# polygon_dict = catchment_gdf[['Official_ID', 'geometry']].copy().set_index('Official_ID').to_dict()['geometry']

# # add the centroid point geometry to the attributes dataframe
# attr_df = attr_df[attr_df['official_id'].isin(catchment_gdf['Official_ID'].values)].copy()
# centroids = attr_df.apply(lambda x: polygon_dict[x['official_id']].centroid, axis=1)
# attr_gdf = gpd.GeoDataFrame(attr_df, geometry=centroids, crs=catchment_gdf.crs)
# attr_gdf["contains_dam"] = attr_gdf['official_id'].apply(lambda x: dam_dict[x] if x in dam_dict else False)
# add the concurrency status as a boolean column
# attr_df['LSTM_concurrent'] = attr_df['official_id'].apply(lambda x: x in daymet_concurrent_stations)
# attr_df.reset_index(inplace=True, drop=True)
# print(f'N network stations={len(attr_df)}')

## Non-Parametric Simulation

### Time-based ensemble

A probability distribution $\hat p = f(\tilde x(t))$ is estimated for a target (ungauged location) from a weighted mean of the runoff time series of the k nearest neighbour stations, $\tilde x(t) = \textbf{X}(t)\cdot w$, where $\textbf{X}(t) \in \mathbb{R}^{N \times k}$ and $w \in \mathbb{R}^{k\times 1}$ is a vector of k weights, so that $\hat p = f(\textbf{X}(t) \cdot w)$.  Weights $w$ are computed in three ways, described in the next subsection, and k-nearest neighbours are selected using the criteria defined below.  Each gauged station in the monitoring network is treated in turn as an ungauged location to generate a large sample of simulations across hydrologically diverse catchments, or rather as many catchments as can be tested.

### Frequency-based ensembles

A simulated probability density function is estimated from observations of k nearest neighbour stations.  First, k simulated series are generated by equal unit area runoff, and a density $\hat p_i = f(X_i(t))$ is estimated from each series.  The ensemble estimate is then the weighted mean of these densities, $\hat p = \hat P \cdot w$, where $\hat P = [\hat p_1, \hat p_2, \cdots, \hat p_k]$.

In both cases, the function $f \rightarrow \hat p(x)$ represents kernel density estimation, which defines the probability density as $$\hat p(x) = \frac{1}{n \cdot h(x)} \sum_{i=1}^{n}K\left( \frac{x-x_i}{h(x)}\right)$$ 

where $h(x)$ is an adaptive kernel bandwidth that addresses vestiges of precision in the observed data to reflect the nature of streamflow as a continuous variable, and additionally incorporates a piecewise linear model to represent overall measurement uncertainty.


## Notes on k-nearest neighbours

Time series streamflow records vary widely in their temporal coverage, and finding k-nearest neighbours presents a tradeoff between selecting nearest neighbours and maximizing the number of observations concurrent with the target.  From the literature, concurrency is assured by pre-selecting a subset of stations with continuous records over a common period of record, or by infilling gaps with k-nearest neighbours simulation.  Some kind of tradeoff must be made, and we aim to use a method that maximizes information content while minimizing the number of assumptions.  The following notes are intended to clarify the implications of using k-nearest neighbours to fill gaps in the time series.

1. **Infilled-by-kNN != Independent Proxy**: If a gap in an observation record is inferred from neighbors, it becomes redundant in the ensemble and increases the weight of the other (k minus n) neighbours.  So at that time step, its influence is non-unique, and including it in the ensemble is functionally equivalent to using the same set of other proxies directly, or just reducing the ensemble size.

2. **Inflated Ensemble Size**: Filling gaps by "nested" k-nearest neighbours inflates the expressed number of independent neighbours.  Comparing the effectiveness of ensemble simulations as a function of k is then misleading because the effective number of independent proxies is *at most* k. 

3. **Information leakage risk**: If you repeatedly use kNN to fill missing data from within the same pool, especially when simulating extreme values, you risk suppressing variability by biasing toward the central tendency of the ensemble.  This defeats one of the core motivations for kNN: to preserve structure and variability from observations at neighboring stations.

To address the nuance above, we propose three time-based methods for selecting k-nearest neighbours beyond strictly nodes in the network.  The problem is related to the set-cover problem where the goal is to select a subset of stations that maximizes the intersection of their data availability over a specified time period.  The following sections outline the three methods for selecting k-nearest neighbours based on availability of concurrent data.

### Summary: Set-Theoretic Foundations of Strict k-NN Concurrency Selection

This problem is closely related to classic combinatorial and set-theoretic optimization problems.

#### Set-Theoretic Definition

For each station $i$ (a column of the data matrix), let $S_i \subseteq T$ be the set of timestamps where station $i$ has valid (non-NaN) data.  
Let $\mathcal{S} = \{ S_1, S_2, \dots, S_n \}$ be the collection of all such subsets, sorted by proximity (e.g., distance or attribute similarity).  
The goal is to select a subset $\mathcal{K} \subset \mathcal{S}$ such that:
- $|\mathcal{K}| = k$
- $\bigcap_{S \in \mathcal{K}} S$ satisfies a temporal completeness constraint (e.g., ≥5 years with ≥10 observations in each of 12 months)

This is a constrained subset selection problem on the intersection of sets.

#### Related Concepts

| Concept                                 | Description |
|----------------------------------------|-------------|
| Set Intersection Selection             | Select $k$ sets whose intersection satisfies a completeness constraint. |
| Maximum Coverage under Cardinality Constraint | Choose $k$ sets to maximize the coverage (or completeness) of their intersection. |
| Recursive k-Subset Validation          | If the initial $k$ sets fail, iteratively add more candidates and evaluate all $\binom{k+1}{k}$ combinations, and so on. |
| NP-Hard Nature                         | This problem is computationally hard and shares structure with the Set Cover and Maximum Coverage problems. |

#### Practical Implication

This formulation justifies using greedy or approximate subset selection strategies when exhaustively testing all combinations becomes computationally infeasible.
## Define a universal parametric prior

In order to fairly test how parametric and non-parametric pdf estimation methods compare to each other, we need a consistent way to deal with indeterminate cases where the simulated distribution does not provide support coverage of the "ground truth" observations.  I feel two ways about this: the KL divergence is the culprit here, and the problem could be avoided by choosing another divergence measure.  However, the definition of KL divergence in information-theoretic terms of compression makes it seem more foundational than other measures, but is this ultimately true?  Should we look to mathematical statistics to make more direct links between f-divergences and what we use as a discriminant for a particular application?  Should we be more concerned about "Bayesian consistency" of the discriminant (or surrogate loss function) with the choice of divergence measure?


1.  **Quantify the distribution of unsupported mass across all models**.  It is important to describe the extent of the problem across the sample **and** across various methods.  i.e. discrete distributions have the issue of support coverage, but so do all methods!
2.  Even in kNN / ensemble simulation approaches, the problem of incomplete support coverage necessitates assuming some prior probability.  The issue is that setting a uniform prior over the observed range takes advantage of information about the observed range.




### Global Uniform Prior

$$f(x) = \frac{1}{b-a}, \quad x\in (a, b) \text{ and } f(x) = 0 \text{ otherwise.}$$
$$\int_a^b f(x)\text{dx} = 1$$

Given the target range is a sub interval $(c, d) \subseteq (a, b)$, then the **total** prior probability mass over (c, d) is:

$$M_\text{target} = \int_c^d \frac{1}{b-a}\text{dx} = \frac{d-c}{b-a}$$

Over the set of intervals $\Delta x_i$ covering the **target range**, the probability mass associated with each interval (bin) is given by: 

$$\frac{\Delta x_i}{d-c}\, M_\text{target} = \frac{\Delta x_i}{b-a}$$



A desirable property of the prior is that it reflects the strength of belief in the model (data), where a smaller prior reflects stronger belief in the data/model and vice versa.  Dividing by the number of observations has such an effect, however it also makes for very small priors.  The consequence of very small priors is they have negligible effect on models that provide complete support coverage, and they severely penalize models that do not, resulting in a form of instability.  The very small prior creates a heavy tail in the distribution of a large sample of KL divergences, with further downstream effects in optimization.  

A method that uses a prior with negligible effect on a model with complete support coverage and a very big effect on one without can be interpreted in a few ways:  

1.  Incomplete support coverage, or underspecification, is very heavily penalized.  The method does not tolerate a model that cannot predict the full observed range.
2.  A **proper** probability distribution sums (discrete) or integrates (continuous) to 1.  Very small probabilities are in a sense associated with a high degree of certainty since they reflect the expectation of the system being observed in a particular state.
3.  The penalty of underestimating a state frequency is that storing and transmitting information about the state requires (the log ratio) more bandwidth/disk space because it is assigned a longer bit string than the actual frequency calls for under optimal encoding.
4.  Assigning a very small probability to a state ...

In [8]:
# Load the mean predicted distribution parameters, indexed by station id
# (kept as a string so leading zeros survive).
parameter_prediction_results_folder = os.path.join(
    'data', 'results', 'parameter_prediction_results'
)
predicted_params_fpath = os.path.join(
    parameter_prediction_results_folder, 'mean_parameter_predictions.csv'
)
rdf = pd.read_csv(
    predicted_params_fpath,
    index_col=['official_id'],
    dtype={'official_id': str},
)
# {station id: {column name: value}}
predicted_param_dict = rdf.to_dict(orient='index')
# Show the available parameter names for one example station.
predicted_param_dict['0212414900'].keys()

dict_keys(['uar_mean_mean_predicted', 'uar_mean_actual', 'uar_std_mean_predicted', 'uar_std_actual', 'uar_median_mean_predicted', 'uar_median_actual', 'uar_mad_mean_predicted', 'uar_mad_actual', 'log_uar_mean_mean_predicted', 'log_uar_mean_actual', 'log_uar_std_mean_predicted', 'log_uar_std_actual', 'log_uar_median_mean_predicted', 'log_uar_median_actual', 'log_uar_mad_mean_predicted', 'log_uar_mad_actual'])

In [9]:
# Histogram the predicted log-space parameters across all stations and keep
# the raw samples in `predicted_param_sample`.
plots = []
predicted_param_sample = {}
param_axis_labels = [
    ('log_uar_mean_mean_predicted', r'$$\text{Log Mean UAR }(L/s/\text{km}^2)$$'),
    ('log_uar_std_mean_predicted', r'$$\text{Log SD UAR }(L/s/\text{km}^2)$$'),
]
for l, al in param_axis_labels:
    # one value per station for this parameter
    vals = [d[l] for d in predicted_param_dict.values()]
    predicted_param_sample[l] = vals
    # density-normalised histogram of the predicted values
    hist, edges = np.histogram(vals, bins=40, density=True)
    # draw the histogram as one quad per bin
    f = figure(title=f'Predicted {l}', width=600, height=400)
    f.quad(
        top=hist,
        bottom=0,
        left=edges[:-1],
        right=edges[1:],
        fill_color='lightblue',
        line_color='black',
        legend_label='',
    )
    f.xaxis.axis_label = al
    f.yaxis.axis_label = r'$$P(x)$$'
    f = dpf.format_fig_fonts(f, font_size=14)
    plots.append(f)

# lay the two histograms out side by side
lt = gridplot(plots, ncols=2, width=400, height=400)
show(lt)

In [10]:
class FDCEstimatorRunner:
    """Run a set of FDC estimation methods for one station and cache the results.

    For every method name in ``methods`` the runner looks up the estimator
    class in ``estimator_classes``, runs it for ``stn_id`` and writes the
    returned result as JSON to
    ``data/results/fdc_estimation_results/<method>/<stn_id>_fdc_results.json``.
    A method whose result file already exists is skipped.

    Parameters
    ----------
    stn_id : station identifier, used in the result file name.
    ctx : shared context object.  This class reads ``prior_strength``,
        ``official_ids``, ``LSTM_concurrent_network`` and (when that flag is
        set) ``global_start_date`` from it, and writes ``baseline_pmf`` to it.
    methods : iterable of method names; each must be a key of
        ``estimator_classes``.
    k_nearest : value copied onto the station data as ``k_nearest``.
    parametric_target_cols : value copied onto the station data as
        ``parametric_target_cols``.
    estimator_classes : mapping of method name to estimator class.  Each class
        is called as ``cls(ctx, stn_id, data)`` and must provide
        ``run_estimators(eval_metrics=...)``.
    **kwargs : accepted for call compatibility and ignored.
    """

    def __init__(self, stn_id, ctx, methods, k_nearest, parametric_target_cols, estimator_classes, **kwargs):
        self.stn_id = stn_id
        self.ctx = ctx
        self.methods = methods
        self.k_nearest = k_nearest
        self.parametric_target_cols = parametric_target_cols
        self.ESTIMATOR_CLASSES = estimator_classes
        self.prior_strength = ctx.prior_strength
        # self._check_min_overlap()
        self._create_results_folders()
        self._create_readme()

    def _create_results_folders(self):
        """Create the results folder and one sub-folder per method."""
        self.results_folder = os.path.join('data', 'results', 'fdc_estimation_results')
        # exist_ok avoids the check-then-create race when several runners
        # (e.g. pool workers) start at the same time; creating the root
        # explicitly keeps the README write valid even with no methods.
        os.makedirs(self.results_folder, exist_ok=True)
        for method in self.methods:
            os.makedirs(os.path.join(self.results_folder, method), exist_ok=True)

    def _create_readme(self):
        """(Re)write README.txt in the results folder listing the run constraints."""
        readme_file = os.path.join(self.results_folder, 'README.txt')

        with open(readme_file, 'w') as file:
            file.write("This folder contains the results of the FDC estimation.\n")
            file.write(f"Methods evaluated: {', '.join(self.methods)}\n")
            # add the concurrency constraint and number of stations represented in the network
            N = len(self.ctx.official_ids)
            if self.ctx.LSTM_concurrent_network:
                file.write(f'Uses only stations within Daymet input period of record / LSTM results: N={N} stations in the network.\n')
                file.write(f'Global start date on streamflow data: {self.ctx.global_start_date}\n')
            else:
                file.write(f'Uses all available network stations in the BCUB region (1950-2024): N={N} stations in the network.\n')

    def _load_reference_distributions(self):
        """Build the KDE estimator and publish the station's baseline PMF/PDF."""
        self.kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)
        self.baseline_pmf, self.baseline_pdf = self.data.baseline_pmf, self.data.baseline_pdf
        # estimators read the baseline through the shared context
        self.ctx.baseline_pmf = self.baseline_pmf

    def _save_result(self, result):
        """Write ``result`` as JSON to ``self.result_file`` atomically.

        The presence of the result file is what ``run_selected`` uses to skip
        finished work, so a partially written file (interrupted run,
        non-serialisable value) must never be left at the final path.  The
        JSON is written to a sibling temporary file and moved into place only
        after it has been written completely.
        """
        tmp_file = f'{self.result_file}.tmp'
        try:
            with open(tmp_file, 'w') as file:
                json.dump(result, file, indent=4)
            os.replace(tmp_file, self.result_file)
        finally:
            # only present here if the dump or the move failed
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    def run_selected(self):
        """Run every method that has no cached result file yet.

        Raises
        ------
        Exception
            If an estimator fails; the message names the method and station
            and the original exception is chained as the cause.
        """
        for method in self.methods:
            self.result_file = os.path.join(self.results_folder, method, f'{self.stn_id}_fdc_results.json')
            if os.path.exists(self.result_file):
                continue

            # fresh station data for each method that still has to run
            self.data = StationData(self.ctx, self.stn_id)
            self.data.k_nearest = self.k_nearest
            self.data.parametric_target_cols = self.parametric_target_cols
            self._load_reference_distributions()
            try:
                EstimatorClass = self.ESTIMATOR_CLASSES[method]
                estimator = EstimatorClass(self.ctx, self.stn_id, self.data)
                eval_metrics = EvaluationMetrics(self.data.baseline_log_grid, self.data.log_dx)
                result = estimator.run_estimators(eval_metrics=eval_metrics)
                self._save_result(result)
            except Exception as e:
                raise Exception(f"  {method} estimator failed for {self.stn_id}: {str(e)}") from e
                

In [11]:
class ParametricFDCEstimator:
    def __init__(self, ctx, target_stn, data, *args, **kwargs):
        # super().__init__(*args, **kwargs)
        self.ctx = ctx
        self.target_stn = target_stn
        self.data = data
        # self.data = data
        self.predicted_param_dict = self.ctx.predicted_param_dict
        self.predicted_param_df = pd.DataFrame(self.predicted_param_dict).T


    def _compute_lognorm_pmf(self, mu, sigma):
        pdf = norm.pdf(self.data.baseline_log_grid, loc=mu, scale=sigma)
        pdf /= jnp.trapezoid(pdf, x=self.data.baseline_log_grid)
        pmf = pdf * self.data.log_dx
        pmf /= pmf.sum()
        return pmf, pdf
    

    def _compute_GEV_pmf(self, xi, mu, sigma):
        # assert values are within the valid range for GEV
        xi = max(xi, -0.5 + 1e-12)  # clip xi to avoid numerical issues
        sigma = max(sigma, 1e-12)  # ensure sigma is positive
        pdf = genextreme.pdf(self.data.baseline_log_grid, xi, loc=mu, scale=sigma)
        pdf /= jnp.trapezoid(pdf, x=self.data.baseline_log_grid)
        pmf = pdf * self.data.log_dx
        pmf /= pmf.sum()  # normalize raw PMF
        return pmf, pdf


    def _estimate_from_mle(self):
        log_mu = self.predicted_param_dict[self.target_stn]['log_uar_mean_actual']
        log_sigma = self.predicted_param_dict[self.target_stn]['log_uar_std_actual']
        return self._compute_lognorm_pmf(log_mu, log_sigma)


    # def _estimate_from_observed_lmoments_gev(self):
    #     # compute the GEV parameters from the L-moments
    #     xi = self.data.LN_param_dict['logx_lmom_xi'][self.target_stn]['actual']
    #     loc = self.data.LN_param_dict['logx_lmom_loc'][self.target_stn]['actual']
    #     scale = self.data.LN_param_dict['logx_lmom_scale'][self.target_stn]['actual']
    #     return self._compute_GEV_pmf(xi, loc, scale)
    

    def _estimate_from_predicted_log_params(self):
        mu = self.predicted_param_dict[self.target_stn]['log_uar_mean_mean_predicted']
        sigma = self.predicted_param_dict[self.target_stn]['log_uar_std_mean_predicted']
        return self._compute_lognorm_pmf(mu, sigma)
        
    
    def _estimate_from_predicted_linear_mom(self):
        mean_x = self.predicted_param_dict[self.target_stn]['uar_mean_mean_predicted']
        sd_x = self.predicted_param_dict[self.target_stn]['uar_std_mean_predicted']
        v = np.log(1 + (sd_x / mean_x) ** 2)
        mu = np.log(mean_x) - 0.5 * v
        return self._compute_lognorm_pmf(mu, np.sqrt(v))
    

    def _estimate_LN_from_randomly_drawn_params(self):
        # randomly draw from the predicted parameters
        random_idx = np.random.choice(len(self.predicted_param_df))
        random_stn_idx = self.predicted_param_df.index[random_idx]
        mu_random =self.predicted_param_dict[random_stn_idx]['log_uar_mean_mean_predicted']
        sigma_random = self.predicted_param_dict[random_stn_idx]['log_uar_std_mean_predicted']
        return self._compute_lognorm_pmf(mu_random, sigma_random)


    def run_estimators(self, eval_metrics):
        results = {}
        fns = [
            self._estimate_from_mle, 
            self._estimate_from_predicted_log_params,
            self._estimate_from_predicted_linear_mom, 
            self._estimate_LN_from_randomly_drawn_params,
            # self._estimate_from_observed_lmoments_gev,
            # self._estimate_from_predicted_lmoments_gev, 
            # self._estimate_LMOM_gev_from_randomly_drawn_params
            ]
        labels = ['MLE', 'PredictedLog', 'PredictedMOM', 'RandomDraw', 
                  #'ObsLMomentsGEV', 'PredictedLMomentsGEV', 'LMomentsGEVRandomDraw',
                  ]
        for fn, label in zip(fns, labels):
            pmf, pdf = fn()
            _, pmf_posterior = self.data._compute_posterior_with_laplace_prior(pmf)            
            if 'Moments' in label:
                # assert no nan values in the pmf
                assert not np.any(np.isnan(pmf)), f'PMF contains NaN values for {label}: {pmf[:10]}'

            results[label] = {'pmf_posterior': pmf_posterior.tolist(), 'pmf': pmf.tolist()}

            estimation_metrics = eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, self.data.baseline_pmf)
            results[label]['eval'] = estimation_metrics

            # compute the bias
            bias_metrics = eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, pmf)
            results[label]['bias'] = bias_metrics
                
        # compute the bias from the eps
        return results

In [12]:
class LSTMFDCEstimator:
    """Estimate a station's PMF from LSTM ensemble simulations.

    Parameters
    ----------
    ctx : shared context; ``LSTM_forcings_folder``,
        ``LSTM_ensemble_result_folder`` and ``baseline_pmf`` are read from it.
    target_stn : station id; the ensemble file ``<target_stn>_ensemble.csv``
        is read from the ensemble result folder.
    data : station data object.  This class reads ``baseline_log_grid``,
        ``log_dx``, ``target_da`` and ``baseline_pmf`` and calls
        ``_compute_posterior_with_laplace_prior(pmf)``.
    *args, **kwargs : accepted for call compatibility and ignored.

    After construction ``self.df`` holds only rows from complete years, with
    'time' as a regular column, and ``self.sim_cols`` lists the
    ``streamflow_sim_*`` columns in sorted order.
    """

    def __init__(self, ctx, target_stn, data, *args, **kwargs):
        self.ctx = ctx
        self.target_stn = target_stn
        self.data = data
        self.LSTM_forcings_folder = self.ctx.LSTM_forcings_folder
        self.LSTM_ensemble_result_folder = self.ctx.LSTM_ensemble_result_folder
        self.df = self._load_ensemble_result()
        self.df = self._filter_for_complete_years()
        self.sim_cols = sorted([c for c in self.df.columns if c.startswith('streamflow_sim_')])
        self.kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)

    def _load_ensemble_result(self):
        """Read the station's ensemble CSV and index it by its 'time' column."""
        fpath = os.path.join(self.LSTM_ensemble_result_folder, f'{self.target_stn}_ensemble.csv')
        df = pd.read_csv(fpath)
        # the timestamp column is stored without a header
        df = df.rename(columns={'Unnamed: 0': 'time'})
        df['time'] = pd.to_datetime(df['time'])
        return df.set_index('time')

    def _filter_for_complete_years(self):
        """Return the rows of ``self.df`` that belong to complete years.

        Rows with any missing value are dropped first.  A month is complete
        when it has at least 20 distinct days of data and a year is complete
        when all 12 of its months are.  The complete years are stored in
        ``self.complete_years``.  The returned frame has 'time' as a column;
        ``self.df`` itself is left unchanged.
        """
        if self.df.empty:
            self.complete_years = []
            return pd.DataFrame()

        date_column = 'time'
        frame = self.df.reset_index()
        # Convert to datetime only if necessary
        if not np.issubdtype(frame[date_column].dtype, np.datetime64):
            frame[date_column] = pd.to_datetime(frame[date_column])

        # Filter out missing values first
        valid_data = frame.dropna().copy()
        valid_data['year'] = valid_data[date_column].dt.year
        valid_data['month'] = valid_data[date_column].dt.month
        valid_data['day'] = valid_data[date_column].dt.day

        # distinct observed days per (year, month)
        month_counts = valid_data.groupby(['year', 'month'])['day'].nunique()
        # complete months have at least 20 observations
        complete_months = (month_counts >= 20)
        # number of complete months in each year
        complete_month_counts = complete_months.groupby(level=0).sum()

        complete_years = complete_month_counts[complete_month_counts == 12]
        self.complete_years = list(complete_years.index.values)

        valid_data = valid_data[valid_data['year'].isin(complete_years.index)].copy()
        # drop the helper columns again
        return valid_data.drop(columns=['year', 'month', 'day'])

    def _load_LSTM_forcing_file(self):
        """Read the station's forcing file, aligned to the ensemble timestamps.

        Reads ``<target_stn>_forcing.csv`` from ``self.LSTM_forcings_folder``,
        keeps the rows at the timestamps present in ``self.df`` and adds a
        'uar' column, ``1000 * discharge / target_da``.

        This method previously read ``self.met_forcings_folder``,
        ``self.stn_df`` and ``self.target_da``, none of which this class
        sets; it now uses the attributes assigned in ``__init__``.
        NOTE(review): aligning on the timestamps of ``self.df`` follows the
        original intent comment ("same index as the LSTM results") -- confirm.

        Raises
        ------
        KeyError
            If a timestamp of the ensemble results is missing from the
            forcing file, or the file has no 'discharge' column.
        """
        fpath = os.path.join(self.LSTM_forcings_folder, f'{self.target_stn}_forcing.csv')
        ldf = pd.read_csv(fpath)
        ldf = ldf.rename(columns={'Unnamed: 0': 'time'})
        ldf['time'] = pd.to_datetime(ldf['time'])
        ldf = ldf.set_index('time')
        # after _filter_for_complete_years the timestamps live in a 'time' column
        sim_times = self.df['time'] if 'time' in self.df.columns else self.df.index
        ldf = ldf.loc[pd.DatetimeIndex(sim_times)].copy()
        # convert to unit area runoff (L/s/km2)
        ldf['uar'] = 1000 * ldf['discharge'] / self.data.target_da
        return ldf

    def _plot_pmfs(self, pmf_time, pmf_freq, line_dash='solid'):
        """Plot the time-ensemble, frequency-ensemble and observed PMFs with bokeh."""
        grid = self.data.baseline_log_grid
        f = figure(title=self.target_stn, width=600, height=400)
        f.line(grid, pmf_time, line_width=2, color='blue', legend_label='Time Ensemble', line_dash=line_dash)
        f.line(grid, pmf_freq, line_width=2, color='purple', legend_label='Frequency Ensemble', line_dash=line_dash)
        f.line(grid, self.ctx.baseline_pmf, line_width=2, color='green', legend_label='Observed', line_dash=line_dash)
        f.xaxis.axis_label = 'Log UAR (L/s/km2)'
        f.yaxis.axis_label = 'PMF'
        f.legend.location = 'top_left'
        f.legend.background_fill_alpha = 0.25
        f.legend.click_policy = 'hide'
        f = dpf.format_fig_fonts(f, font_size=14)
        show(f)

    def _compute_time_ensemble_pmf(self):
        """PMF of the ensemble-mean series; returns ``(pmf, pmf_posterior)``.

        The simulation columns are averaged per time step while still in log
        space, exponentiated, and passed to the KDE estimator.
        """
        data = self.df[self.sim_cols].copy()
        temporal_ensemble_log = data.mean(axis=1)  # this is still in log space
        self.temporal_ensemble = np.exp(temporal_ensemble_log.values)
        pmf, _ = self.kde.compute(self.temporal_ensemble, self.data.target_da)
        _, pmf_posterior = self.data._compute_posterior_with_laplace_prior(pmf)
        return (pmf, pmf_posterior)

    def _compute_frequency_ensemble_pmf(self):
        """Mean of the per-member PMFs; returns ``(pmf, pmf_posterior)``.

        Raises
        ------
        ValueError
            If there are no simulation columns to average.
        """
        if not self.sim_cols:
            raise ValueError(f'No streamflow_sim_ columns available for {self.target_stn}')
        data = self.df[self.sim_cols].copy()
        data = data.dropna()
        # one PMF per ensemble member, stacked as columns (grid x members)
        pmfs = np.column_stack([
            self.kde.compute(np.exp(data[c].values), self.data.target_da)[0]
            for c in self.sim_cols
        ])
        # average the pmfs over the ensemble
        pmf = pmfs.mean(axis=1)
        assert len(pmf) == len(self.data.baseline_log_grid), f'len(pmf) = {len(pmf)} != len(baseline_log_grid) = {len(self.data.baseline_log_grid)}'
        _, pmf_posterior = self.data._compute_posterior_with_laplace_prior(pmf)
        return (pmf, pmf_posterior)

    def _compute_ensemble_distribution_estimate(self, ensemble_type, eval_metrics):
        """Evaluate one ensemble type ('time' or 'frequency').

        Returns a dict with 'eval' (posterior against the baseline PMF) and
        'bias' (posterior against the raw PMF).

        Raises
        ------
        ValueError
            If ``ensemble_type`` is neither 'time' nor 'frequency'.
        """
        if ensemble_type == 'time':
            pmf, pmf_posterior = self._compute_time_ensemble_pmf()
        elif ensemble_type == 'frequency':
            pmf, pmf_posterior = self._compute_frequency_ensemble_pmf()
        else:
            raise ValueError(f'Unknown ensemble type: {ensemble_type}')

        # compute the divergence measures
        result = {}
        result['eval'] = eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, self.data.baseline_pmf)
        result['bias'] = eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, pmf)
        return result

    def run_estimators(self, eval_metrics):
        """Evaluate the time and frequency ensembles; returns ``{ensemble_type: result}``."""
        # met_forcing = self._load_LSTM_forcing_file()  # Load LSTM forcing data
        results = {}
        for ensemble_type in ['time', 'frequency']:
            print(f'     Processing {ensemble_type} ensemble for {self.target_stn}')
            result = self._compute_ensemble_distribution_estimate(ensemble_type, eval_metrics)
            results[ensemble_type] = result
        return results

In [13]:
class kNNFDCEstimator:
    def __init__(self, ctx, target_stn, data, *args, **kwargs):
        self.ctx = ctx
        self.target_stn = target_stn
        self.data = data
        self.k_nearest = data.k_nearest
        # self.max_to_check_start = data.max_to_check
        # self.max_to_check = data.max_to_check
        self.weight_schemes = [1, 2] #inverse distance and inverse square distance
        self.knn_simulation_data = {}
        self.knn_pdfs = pd.DataFrame()
        self.knn_pmfs = pd.DataFrame()


    def _find_k_nearest_neighbors(self, tree_type, max_to_check):
        # Query the k+1 nearest neighbors because the first neighbor is the target point itself
        target_idx = self.ctx.id_to_idx[self.target_stn]
        if tree_type == 'spatial_dist':
            distances, indices = self.ctx.spatial_tree.query(self.ctx.coords[target_idx], k=max_to_check)
            distances *= 1 / 1000
        elif tree_type == 'attribute_dist':
            # Example query: Find the nearest neighbors for the first point
            distances, indices = self.ctx.attribute_tree.query(self.ctx.normalized_attr_values[target_idx], k=max_to_check)
        else:
            raise Exception('tree type not identified, must be one of spatial_dist, or attribute_dist.')
        
        # Remove target (self) from the results
        self_index = target_idx
        keep = indices != self_index
        indices = indices[keep]
        distances = distances[keep]

        return indices, np.round(distances, 3)
    

    def _compute_effective_k(self, df, max_k=None):
        arr = df.to_numpy()
        T, K = arr.shape
        max_k = max_k or K

        nan_mask = np.isnan(arr)
        sorted_idx = np.argsort(nan_mask, axis=1)
        row_idx = np.arange(T)[:, None]

        ks = np.arange(1, max_k + 1)
        effective_k = []
        mean_furthest = []

        for k in ks:
            idx = sorted_idx[:, :k]
            valid = ~nan_mask[row_idx, idx]
            valid_count = valid.sum(axis=1)

            effective_k.append(valid_count.mean())
            furthest_idx = np.where(valid, idx, -1).max(axis=1)
            mean_furthest.append(furthest_idx[valid_count >= k].mean() if np.any(valid_count >= k) else np.nan)

        return pd.DataFrame({
            'effective_k': np.round(effective_k, 1),
            'mean_furthest_idx': np.round(mean_furthest, 2)
        }, index=ks)
    

    def _query_distance(self, tree, id1, id2):
        """Query distance between two points in a tree using official_id."""
        if id1 not in self.ctx.id_to_idx or id2 not in self.ctx.id_to_idx:
            raise ValueError(f"One or both IDs ({id1}, {id2}) not found.")
    
        # Get indices from ID mapping
        index1, index2 = self.ctx.id_to_idx[id1], self.ctx.id_to_idx[id2]
        # Query the distance
        distance = np.linalg.norm(tree.data[index1] - tree.data[index2])  # Euclidean distance
        return distance
    

    def _process_neighbor(args):
        """
        Process a single neighbor to retrieve its data and compute the number of complete years.
        This function is designed to be used with multiprocessing.
        """
        nbr_id, dist, retrieve_fn, find_complete_fn = args

        try:
            df = retrieve_fn(nbr_id)
            if not isinstance(df, pd.DataFrame) or df.empty:
                return None

            proxy_df = df[[f'{nbr_id}_uar']]
            complete_years = list(find_complete_fn(proxy_df))
            n_years = len(complete_years)
            return (nbr_id, dist, n_years, proxy_df)
        except Exception as e:
            print(f"Failed to process {nbr_id}: {e}")
            return None
        
    
    def _retrieve_nearest_nbr_data(self, tree_type):
        MAX_CHECK = 700
        REQUIRED_GOOD = 10
        # Get the index of the target station
        
        # Query once for all potential neighbors
        nbr_idxs, dists = self._find_k_nearest_neighbors(tree_type, MAX_CHECK)
        nbr_ids = [self.ctx.idx_to_id[i] for i in nbr_idxs if self.ctx.idx_to_id[i] != self.target_stn]
        distances = [d for i, d in zip(nbr_idxs, dists) if self.ctx.idx_to_id[i] != self.target_stn]

        good_nbrs = []
        sorted_nbrs = sorted(zip(nbr_ids, distances), key=lambda x: x[1])        

        for (nbr_id, dist) in sorted_nbrs:
            df = self.data.retrieve_timeseries_discharge(nbr_id)
            if not isinstance(df, pd.DataFrame) or df.empty:
                continue  # Skip bad or empty DataFrames
            proxy_df = df[[f'{nbr_id}_uar']]
            complete_years = self.data.complete_year_dict[nbr_id]
            n_years = len(complete_years)
            good_nbrs.append((nbr_id, dist, n_years, proxy_df))
            if len(good_nbrs) == REQUIRED_GOOD:
                break

        # Concatenate the timeseries
        nbr_df = pd.concat([r[3] for r in good_nbrs], axis=1)
        # Build metadata DataFrame
        complement_type = 'attribute_dist' if tree_type == 'spatial_dist' else 'spatial_dist'
        complement_tree = getattr(self.ctx, f"{complement_type.split('_')[0]}_tree")
        scale = 1 / 1000 if complement_type == 'spatial_dist' else 1

        nbr_data = pd.DataFrame(
            [r[:3] for r in good_nbrs],
            columns=['official_id', 'distance', 'n_years']
        )
        nbr_data[complement_type] = nbr_data['official_id'].apply(
            lambda x: round(scale * self._query_distance(complement_tree, x, self.target_stn), 3)
        )

        return nbr_df, nbr_data


    def _initialize_nearest_neighbour_data(self):
        """
        Generate nearest neighbours for spatial and attribute selected k-nearest neighbours for both concurrent and asynchronous records.
        """
        print(f'    ...initializing nearest neighbours with minimum concurrent record.')
        self.nbr_dfs = defaultdict(lambda: defaultdict(dict))
        
        for tree_type in ['spatial_dist', 'attribute_dist']:
            nbr_df, nbr_data = self._retrieve_nearest_nbr_data(tree_type)
            effective_k = self._compute_effective_k(nbr_df, max_k=self.k_nearest)
            self.nbr_dfs[tree_type] = {
                'nbr_df': nbr_df,
                'nbr_data': nbr_data,
                'effective_k': effective_k,
            }
    

    def _compute_weights(self, m, k, distances, epsilon=1e-3):
        """Compute normalized inverse (square) distance weights to a given power."""

        distances = jnp.where(distances == 0, epsilon, distances)

        if k == 1:
            return jnp.array([1])
        else:
            inv_weights = 1 / (jnp.abs(distances) ** m)
            return inv_weights / jnp.sum(inv_weights)
    
      
    def _compute_frequency_ensemble_mean(self, pdfs, weights):
        """
        This function computes the weighted ensemble distribution estimates.
        """
        # Normalize distance weights
        if weights is not None:
            weights /= jnp.sum(weights).astype(jnp.float32)
            weights = jnp.array(weights)  # Ensure 1D array
            pdf_est = jnp.asarray(pdfs.to_numpy() @ weights)
        else:
            pdf_est = jnp.asarray(pdfs.mean(axis=1).to_numpy())


        # Check integral before normalization
        pdf_check = jnp.trapezoid(pdf_est, x=self.data.baseline_log_grid)
        normalized_pdf = pdf_est / pdf_check
        assert jnp.isclose(jnp.trapezoid(normalized_pdf, x=self.data.baseline_log_grid), 1), f'ensemble pdf does not integrate to 1: {pdf_check:.4f}'
                
        # Compute PMF
        pmf_est = normalized_pdf * self.data.log_dx
        pmf_est /= jnp.sum(pmf_est)

        return pmf_est, pdf_est


    def _compute_ensemble_member_distribution_estimates(self, df):
        """
        Compute the ensemble distribution estimates based on the KNN dataframe.
        """    
        pdfs, prior_biases = pd.DataFrame(), {}
        # initialize a kde estimator object
        kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)
        for c in df.columns: 
            # evaluate the laplace on the prediction as a prior
            # drop the nan values
            values = df[c].dropna().values
            obs_count = len(values)
            assert len(values) > 0, f'0 values for {c}'

            # compute the pdf and pmf using kde
            assert sum(np.isnan(values)) == 0, f'NaN values in {c} {values[:5]}'

            kde_pmf, _ = kde.compute(
                values, self.data.target_da
            )

            prior = self.data._compute_prior_from_laplace_fit(values, n_cols=1) # priors are expressed in pseudo-counts
            # convert the pdf to counts and apply the prior
            counts = kde_pmf * obs_count + prior

            # re-normalize the pmf
            pmf = counts / jnp.sum(counts)
            pdf = pmf / self.data.log_dx

            pdf_check = jnp.trapezoid(pdf, x=self.data.baseline_log_grid)
            pdf /= pdf_check
            # pdf /= pdf_check
            assert jnp.isclose(jnp.trapezoid(pdf, x=self.data.baseline_log_grid), 1.0, atol=0.001), f'pdf does not integrate to 1 in compute_ensemble_member_distribution_estimates: {pdf_check:.4f}'
            pdfs[c] = pdf

            # convert the pdf to pmf
            pmf = pdf * self.data.log_dx
            pmf /= jnp.sum(pmf)
            # assert np.isclose(np.sum(pmf), 1, atol=1e-4), f'pmf does not sum to 1 in compute_ensemble_member_distribution_estimates: {np.sum(pmf):.5f}'
            
            # compute the bias added by the prior
            prior_biases[c.split('_')[0]] = {'DKL': self.data._compute_kld(kde_pmf, pmf), 'EMD': self.data._compute_emd(kde_pmf, pmf)}
        return pdfs, prior_biases
    
    
    def _compute_frequency_ensemble_distributions(self, nbr_df, nbr_data, distance_type):
        """
        For asynchronous comparisons, we estimate pdfs for ensemble members, then compute the mean in the time domain
        to represent the FDC simulation.  We do not do temporal averaging in this case.
        """
        knn_df_all = nbr_df.iloc[:, :self.k_nearest].copy()
        knn_data_all = nbr_data.iloc[:, :self.k_nearest].copy()
        proxy_ids = [c.split('_')[0] for c in knn_df_all.columns.tolist()]
        frequency_ensemble_pdfs = self.ctx.baseline_pmf_df[proxy_ids].copy()

        labels, pdfs, pmfs = [], [], []
        all_distances = knn_data_all['distance'].values
        all_ids = knn_data_all['official_id'].values
        # prior_bias_df = pd.DataFrame(prior_bias_dict)
        for wm in self.weight_schemes:
            for k in range(1, self.k_nearest + 1):
                distances = all_distances[:k]
                nbr_ids = all_ids[:k]
                knn_pdfs = frequency_ensemble_pdfs.iloc[:, :k].copy()

                label = f'{self.target_stn}_{k}_NN_{distance_type}_ID{wm}_freqEnsemble'
                weights = self._compute_weights(wm, k, distances)
                pmf_est, pdf_est = self._compute_frequency_ensemble_mean(knn_pdfs, weights)
                assert pmf_est is not None, f'pmf_est is None for {label}'
            
                # compute the mean number of observations (non-nan values) per row
                mean_obs_per_timestep = knn_df_all.iloc[:, :k].notna().sum(axis=1).mean()
                mean_obs_per_proxy = knn_df_all.iloc[:, :k].notna().sum(axis=0).mean()

                _, pmf_posterior = self.data._compute_posterior_with_laplace_prior(pmf_est)
                eval = self.data.eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, self.data.baseline_pmf)
                bias = self.data.eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, pmf_est)
      
                # compute the frequency-based ensemble pdf estimate
                self.knn_simulation_data[label] = {
                    'k': k, 'n_obs': mean_obs_per_proxy,
                    'mean_obs_per_timestep': mean_obs_per_timestep,
                    'nbrs': ','.join(nbr_ids),
                    'eval': eval,
                    'bias': bias,
                    }
                
                pdfs.append(np.asarray(pdf_est))
                pmfs.append(np.asarray(pmf_est))
                labels.append(label)

        # create a dataframe of labels(columns) for each pdf
        knn_pdfs = pd.DataFrame(pdfs, index=labels).T
        knn_pmfs = pd.DataFrame(pmfs, index=labels).T
        # Filter out already existing columns to avoid duplication
        new_pdf_cols = knn_pdfs.columns.difference(self.knn_pdfs.columns)
        new_pmf_cols = knn_pmfs.columns.difference(self.knn_pmfs.columns)
        # Concat only new columns
        self.knn_pdfs = pd.concat([self.knn_pdfs, knn_pdfs[new_pdf_cols]], axis=1)
        self.knn_pmfs = pd.concat([self.knn_pmfs, knn_pmfs[new_pmf_cols]], axis=1)
    
    
    def _delta_spike_pmf_pdf(self, single_val, log_grid):
        """
        Return a spike PMF and compatible PDF centered at the only value in the input.
        The spike is placed at the nearest log_grid point to log(single_val).
        """
        log_val = jnp.log(single_val)
        spike_idx = jnp.argmin(jnp.abs(log_grid - log_val))
        
        pmf = jnp.zeros_like(log_grid)
        pmf = pmf.at[spike_idx].set(1.0)

        dx = jnp.gradient(log_grid)
        pdf = pmf / dx  # assign all mass to one bin

        return pmf, pdf

    
    def _compute_ensemble_contribution_metrics(self, df: pd.DataFrame, weights: np.ndarray):
        mask = ~df.isna()
        
        # Mean number of valid values per row
        mean_valid_per_row = mask.sum(axis=1).mean()

        # Normalized weights per row, masking NaNs
        X = df.to_numpy()
        W = np.broadcast_to(weights, X.shape)
        masked_weights = np.where(mask, W, 0.0)
        weight_sums = masked_weights.sum(axis=1)
        weight_sums[weight_sums == 0] = np.nan
        normalized_weights = masked_weights / weight_sums[:, None]

        # Average contribution per column across all rows
        mean_w = np.nanmean(normalized_weights, axis=0)
        effective_n = 1.0 / np.nansum(mean_w ** 2)

        return mean_valid_per_row, effective_n
    

    def _weighted_row_mean_ignore_nan(self, df: pd.DataFrame, weights: np.ndarray):
        """
        Computes the weighted mean for each row, accounting for NaNs and ensuring that
        weights are re-normalized based on valid values only. Returns a Series aligned
        with df.index, as well as ensemble stats.
        """
        X = df.to_numpy()
        mask = ~np.isnan(X)

        W = np.broadcast_to(weights, X.shape)
        masked_weights = np.where(mask, W, 0.0)

        row_weight_sums = masked_weights.sum(axis=1)
        row_weight_sums[row_weight_sums == 0] = np.nan

        normalized_weights = masked_weights / row_weight_sums[:, None]
        estimated = np.nansum(X * normalized_weights, axis=1)

        # Return as Series aligned with index
        estimated_series = pd.Series(estimated, index=df.index)

        # Also compute stats on weights
        mean_valid_per_row = mask.sum(axis=1).mean()
        mean_weight_per_col = np.nanmean(normalized_weights, axis=0)
        effective_k = 1.0 / np.nansum(mean_weight_per_col ** 2)

        return estimated_series, mean_valid_per_row, effective_k


    def _finalize_temporal_ensemble(
            self, k, label, temporal_ensemble_mean, nbrs_used,
            effective_k, mean_valid_per_row
            ):

        # Clip to prevent zero runoff issues
        temporal_ensemble_mean = np.clip(
            temporal_ensemble_mean, 1000 * 1e-4 / self.data.target_da, None
        )

        # Estimate PDF/PMF using KDE or 
        # add small amount of random noise if there is no variance
        if len(jnp.unique(temporal_ensemble_mean.values)) == 1:
            est_pmf, est_pdf = self._delta_spike_pmf_pdf(
                temporal_ensemble_mean.values[0], self.data.baseline_log_grid
            )
        else:
            est_pmf, est_pdf = self.target_kde.compute(
                temporal_ensemble_mean.values, self.data.target_da
            )

        assert est_pmf is not None, f'pmf is None for {label}'

        _, pmf_posterior = self.data._compute_posterior_with_laplace_prior(est_pmf)

        eval = self.eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, self.data.baseline_pmf)
        bias = self.eval_metrics._evaluate_fdc_metrics_from_pmf(pmf_posterior, est_pmf)

        # Store simulation outputs and metadata
        self.knn_pdfs[label] = est_pdf
        self.knn_pmfs[label] = est_pmf
        self.knn_simulation_data[label] = {
            'nbrs': nbrs_used,
            'k': k,
            'n_obs': len(temporal_ensemble_mean),
            'mean_': mean_valid_per_row,
            'mean_nbrs_per_timestep': effective_k,  # rename if clearer
            'effective_k': effective_k,
            'eval': eval,
            'bias': bias,
        }


    def _compute_temporal_ensemble_distributions(self, distance_type, wm, nbr_df, nbr_data):
        distances = nbr_data['distance'].values
        for k in range(1, self.k_nearest + 1):
            knn_df = nbr_df.iloc[:, :k].copy()
            label = f'{self.target_stn}_{k}_NN_{distance_type}_ID{wm}_timeEnsemble'            
            weights = self._compute_weights(wm, k, distances[:k])
            temporal_ensemble_mean, mean_valid_per_row, effective_k = self._weighted_row_mean_ignore_nan(knn_df, weights)
            nbrs_used = [c.split('_')[0] for c in knn_df.columns]
            self._finalize_temporal_ensemble(
                k, label, temporal_ensemble_mean, nbrs_used,
                effective_k, mean_valid_per_row
            )
    
    
    def _compute_distribution_estimates(self, distance_type):

        nbr_df = self.nbr_dfs[distance_type]['nbr_df'].copy()
        nbr_data = self.nbr_dfs[distance_type]['nbr_data'].copy()

        for wm in self.weight_schemes:
            # compute the FDC estimate by temporal ensemble mean
            t0 = time()
            self._compute_temporal_ensemble_distributions(distance_type, wm, nbr_df, nbr_data,)
            t1 = time()
            # compute the frequency average ensemble pdfs
            self._compute_frequency_ensemble_distributions(nbr_df, nbr_data, distance_type)
            t2 = time()
            print(f'    ...{distance_type} ID{wm} took {t1 - t0:.2f}s for temporal ensemble, {t2 - t1:.2f}s for frequency ensemble.')

        # Validation
        sim_labels = list(self.knn_simulation_data.keys())
        pdf_labels = list(self.knn_pdfs.columns)
        assert set(sim_labels) == set(pdf_labels)
        
    
    def run_estimators(self, eval_metrics):
        self.eval_metrics = eval_metrics                  
        self._initialize_nearest_neighbour_data()
        # set the baseline pdf by kde
        self.target_kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)
        for dist in ['spatial_dist', 'attribute_dist']:        
            self._compute_distribution_estimates(dist)
        return self._format_results()
    
    
    def _make_json_serializable(self, d):
        output = {}
        for k, v in d.items():
            if isinstance(v, (np.ndarray, jnp.ndarray)):
                output[k] = v.tolist()
            elif hasattr(v, "tolist"):
                output[k] = v.tolist()
            else:
                output[k] = v
        return output
    
    
    def _format_results(self):
        pmf_labels, pdf_labels, sim_labels = list(self.knn_pmfs.columns), list(self.knn_pdfs.columns), list(self.knn_simulation_data.keys())
        # assert label sets are the same
        assert set(pmf_labels) == set(pdf_labels), f'pmf_labels {pmf_labels} != pdf_labels {pdf_labels}'
        assert set(pmf_labels) == set(sim_labels), f'pmf_labels {pmf_labels} != sim_labels {sim_labels}'
        results = self.knn_simulation_data
        for label in pmf_labels:
            # add the pmf and pdf in a json serializable format
            results[label]['pmf'] = self.knn_pmfs[label].tolist()
            results[label]['pdf'] = self.knn_pdfs[label].tolist()
            results[label] = self._make_json_serializable(results[label])
        return results
        

In [14]:
# Configuration cell: fixes the numpy seed, collects the run settings and
# builds the shared FDCEstimationContext used by all estimators.
np.random.seed(42)

# columns passed to the context as 'parametric_target_cols'
target_cols = [
    'mean_uar', 'sd_uar', 
    'mean_logx', 'sd_logx', 
]

# from utils import FDCEstimationContext
attr_df_fpath = os.path.join('data', f'catchment_attributes_with_runoff_stats.csv')
# NOTE(review): the two LSTM folders are machine-specific absolute paths and sit
# under different roots ('/home/danbot/neuralhydrology' vs
# '/home/danbot/code/neuralhydrology') -- confirm both are intended.
LSTM_forcings_folder = '/home/danbot/neuralhydrology/data/BCUB_catchment_mean_met_forcings_20250320'
LSTM_ensemble_result_folder = '/home/danbot/code/neuralhydrology/data/ensemble_results_20250514'
baseline_distribution_folder = os.path.join('data', 'results', 'baseline_distributions')
# parameter_prediction_results_folder = os.path.join('data', 'parameter_prediction_results')

# estimation methods to run; each name must be a key of ESTIMATOR_CLASSES
methods = ('parametric', 'lstm', 'knn',)
# methods = ('knn',)
exclude_pre_1980_data = False  # use only stations with data 1980-present concurrent with Daymet
daymet_start_date = '1950-01-01'  # default start date for Daymet data
k_nearest = 10  # number of neighbours handed to the estimator runner
if exclude_pre_1980_data:
    daymet_start_date = '1980-01-01'

processed = []
# method name -> estimator class
ESTIMATOR_CLASSES = {
    'parametric': ParametricFDCEstimator,
    'lstm': LSTMFDCEstimator,
    'knn': kNNFDCEstimator,
    # add others here
}
# keyword arguments for FDCEstimationContext; `predicted_param_dict`,
# `pmf_stations`, `station_ids` and `daymet_concurrent_stations` must already
# exist from earlier cells
input_data = {
    'attr_df_fpath': attr_df_fpath,
    'LSTM_forcings_folder': LSTM_forcings_folder,
    'LSTM_ensemble_result_folder': LSTM_ensemble_result_folder,
    'LSTM_concurrent_network': exclude_pre_1980_data,  # use only stations with data 1980-present concurrent with Daymet
    'daymet_start_date': daymet_start_date,
    # 'parameter_prediction_results_folder': parameter_prediction_results_folder,
    'predicted_param_dict': predicted_param_dict,
    'divergence_measures': ['DKL', 'EMD'],
    'baseline_pmf_stations': pmf_stations,
    'eps': 1e-12,
    'min_flow': 1e-4,
    'n_grid_points': 2**12,
    'min_record_length': 5,
    'minimum_days_per_month': 20,
    'parametric_target_cols': target_cols,
    'all_official_ids': station_ids,
    'daymet_concurrent_stations': daymet_concurrent_stations,
    'baseline_distribution_folder': baseline_distribution_folder,
    'prior_strength': 1e-2,  # prior strength for the Laplace fit
}

context = FDCEstimationContext(**input_data)

    Using all stations in the catchment data with a baseline PMF (validated): 1097
    ...overlap dict loaded from data/record_overlap_dict.json


In [15]:
import warnings
import sys
import traceback

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """
    ``warnings.showwarning`` replacement that also prints the call stack.

    The stack and the formatted warning are written to ``file`` when it has a
    ``write`` method, otherwise to ``sys.stderr``.
    """
    sink = sys.stderr if not hasattr(file, 'write') else file
    traceback.print_stack(file=sink)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    sink.write(formatted)

warnings.showwarning = warn_with_traceback

In [None]:
# Run the selected estimators for every station that is both in the context
# and in the Daymet-concurrent set.
processed = []
# Materialise the work list once so the progress denominator below matches
# what is actually iterated (previously it used all of context.official_ids).
target_stations = [s for s in context.official_ids if s in daymet_concurrent_stations]
t0 = time()
for stn in target_stations:
    if stn == '12414900': # this station has no data in the LSTM ensemble results
        print(f'    ...skipping {stn} due to naming issue.')
        continue
    print(f'Estimating FDC for {stn}...')
    runner = FDCEstimatorRunner(stn, context, methods, k_nearest, target_cols, ESTIMATOR_CLASSES)
    runner.run_selected()
    processed.append(stn)
    if len(processed) % 10 == 0:
        # running average of wall-clock time per processed station
        unit_time = (time() - t0) / len(processed)
        print(f'Processed {len(processed)}/{len(target_stations)} stations in {unit_time:.2f} seconds per station')

Estimating FDC for 05AA008...
     Processing time ensemble for 05AA008
     Processing frequency ensemble for 05AA008
    ...initializing nearest neighbours with minimum concurrent record.
    ...spatial_dist ID1 took 10.27s for temporal ensemble, 1.01s for frequency ensemble.
    ...spatial_dist ID2 took 5.95s for temporal ensemble, 0.32s for frequency ensemble.
    ...attribute_dist ID1 took 10.33s for temporal ensemble, 0.48s for frequency ensemble.
    ...attribute_dist ID2 took 8.01s for temporal ensemble, 0.31s for frequency ensemble.
Estimating FDC for 05AA022...
     Processing time ensemble for 05AA022
     Processing frequency ensemble for 05AA022
    ...initializing nearest neighbours with minimum concurrent record.
    ...spatial_dist ID1 took 10.21s for temporal ensemble, 0.28s for frequency ensemble.
    ...spatial_dist ID2 took 6.41s for temporal ensemble, 0.29s for frequency ensemble.
    ...attribute_dist ID1 took 9.31s for temporal ensemble, 0.32s for frequency ensem