In [10]:
import os
import pandas as pd
import numpy as np
import json
from time import time
import xarray as xr
import geopandas as gpd

from scipy.spatial import cKDTree
from sklearn.preprocessing import StandardScaler
from collections import defaultdict

import xgboost as xgb
xgb.config_context(verbosity=2)
from scipy.stats import wasserstein_distance
from scipy.stats import norm, laplace

from bokeh.plotting import figure, show, save, output_file
# from bokeh.io import output_notebook
from bokeh.models import LinearColorMapper

import jax
import jax.numpy as jnp
from jax import jit

# from KDEpy import FFTKDE

import data_processing_functions as dpf

from concurrent.futures import ThreadPoolExecutor

from kde_estimator import KDEEstimator

from pathlib import Path
BASE_DIR = os.getcwd()



In [2]:
from bokeh.plotting import figure, show, output_notebook
from bokeh.layouts import gridplot

import xyzservices.providers as xyz
tiles = xyz['USGS']['USTopo']
output_notebook()


In [12]:
# load the catchment characteristics
rev_date = '20250227'
fname = f'BCUB_watershed_attributes_updated_{rev_date}.csv'
attr_df = pd.read_csv(os.path.join('data', fname))
attr_df.columns = [c.lower() for c in attr_df.columns]
attr_df['log_drainage_area_km2'] = np.log(attr_df['drainage_area_km2'])
# attr_df = attr_df[~attr_df['official_id'].isin(exclude)]
# attr_df.columns = [c.lower() for c in attr_df.columns]
attr_df['tmean'] = (attr_df['tmin'] + attr_df['tmax']) / 2.0
station_ids = attr_df['official_id'].values

print(f'There are {len(station_ids)} monitored basins in the attribute set.')



There are 1308 monitored basins in the attribute set.


In [13]:
# streamflow folder from (updated) HYSETS
HYSETS_DIR = Path('/home/danbot2/code_5820/large_sample_hydrology/common_data/HYSETS_data')
STREAMFLOW_DIR = HYSETS_DIR / 'streamflow'

hs_df = pd.read_csv(HYSETS_DIR / 'HYSETS_watershed_properties.txt', sep=';')
hs_df = hs_df[hs_df['Official_ID'].isin(station_ids)]
hs_df.head(2)


Unnamed: 0,Watershed_ID,Source,Name,Official_ID,Centroid_Lat_deg_N,Centroid_Lon_deg_E,Drainage_Area_km2,Drainage_Area_GSIM_km2,Flag_GSIM_boundaries,Flag_Artificial_Boundaries,...,Land_Use_Wetland_frac,Land_Use_Water_frac,Land_Use_Urban_frac,Land_Use_Shrubs_frac,Land_Use_Crops_frac,Land_Use_Snow_Ice_frac,Flag_Land_Use_Extraction,Permeability_logk_m2,Porosity_frac,Flag_Subsoil_Extraction
846,847,HYDAT,CROWSNEST RIVER AT FRANK,05AA008,49.59732,-114.4106,402.6522,,0,0,...,0.0103,0.0065,0.0328,0.0785,0.0015,0.0002,1,-15.543306,0.170479,1
849,850,HYDAT,CASTLE RIVER NEAR BEAVER MINES,05AA022,49.48866,-114.1444,820.651,,0,0,...,0.0058,0.0023,0.0105,0.1156,0.0246,0.0,1,-15.929747,0.150196,1


### Global Uniform Prior

$$f(x) = \frac{1}{b-a}, \quad x\in (a, b) \text{ and } f(x) = 0 \text{ otherwise.}$$
$$\int_a^b f(x)\text{dx} = 1$$

Given the target range is a sub interval $(c, d) \subseteq (a, b)$, then the **total** prior probability mass over (c, d) is:

$$M_\text{target} = \int_c^d \frac{1}{b-a}\text{dx} = \frac{d-c}{b-a}$$

Over the set of intervals $\Delta x_i$ covering the **target range**, the probability mass associated with each interval (bin) is given by: 

$$\Delta x_i \frac{d-c}{b-a}$$



A desirable property of the prior is that it reflects the strength of belief in the model (data), where a smaller prior reflects stronger belief in the data/model and vice versa.  Dividing by the number of observations has such an effect, however it also makes for very small priors.  The consequence of very small priors is they have negligible effect on models that provide complete support coverage, and they severely penalize models that do not, resulting in a form of instability.  The very small prior creates a heavy tail in the distribution of a large sample of KL divergences, with further downstream effects in optimization.  

A method that uses a prior with negligible effect on a model with complete support coverage and a very big effect on one without can be interpreted in a few ways:  

1.  Incomplete support coverage, or underspecification, is very heavily penalized.  The method does not tolerate a model that cannot predict the full observed range.
2.  A **proper** probability distribution sums (discrete) or integrates (continuous) to 1.  Very small probabilities are in a sense associated with a high degree of certainty since they reflect the expectation of the system being observed in a particular state.
3.  The penalty of underestimating a state frequency is that storing and transmitting information about the state requires (the log ratio) more bandwidth/disk space because it is assigned a longer bit string than the actual frequency calls for under optimal encoding.
4.  Assigning a very small probability to a state ...

In [14]:
@jit
def compute_overlap_matrix(mask_jax):
    return jnp.matmul(mask_jax, mask_jax.T)

class FDCEstimationContext:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

        self._load_catchment_data()
        self._load_and_filter_hysets_data()
        self.LN_param_dict = self._get_ln_params()
        self.id_to_idx, self.idx_to_id = self._set_tree_idx_mappers()
        self._build_network_trees()
        self._set_attribute_indexers()
        self.overlap_dict = self._compute_concurrent_overlap_dict()
        
    
    def _load_and_filter_hysets_data(self):
        hs_df = pd.read_csv(HYSETS_DIR / 'HYSETS_watershed_properties.txt', sep=';')
        
        hs_df = hs_df[hs_df['Official_ID'].isin(self.study_ids)]
        self.global_start_date = pd.to_datetime('1950-01-01') # Hysets streamflow starts 1950-01-01
        self.hs_df = hs_df
        # load the updated HYSETS data

        updated_filename = 'HYSETS_2023_update_QC_stations.nc'
        ds = xr.open_dataset(HYSETS_DIR / updated_filename)

        # Get valid IDs as a NumPy array
        selected_ids = hs_df['Watershed_ID'].values

        # Get boolean index where watershedID in selected_set
        # safely access watershedID as a variable first
        ws_ids = ds['watershedID'].data  # or .values if you prefer
        mask = np.isin(ws_ids, selected_ids)

        # Apply mask to data
        ds = ds.sel(watershed=mask)
        # Step 1: Promote 'watershedID' to a coordinate on the 'watershed' dimension
        ds = ds.assign_coords(watershedID=("watershed", ds["watershedID"].data))

        # filter the timeseries by the global start date
        ds = ds.sel(time=slice(self.global_start_date, None))

        # Step 2: Set 'watershedID' as the index for the 'watershed' dimension
        self.ds = ds.set_index(watershed="watershedID")
    
    
    def _load_catchment_data(self):
        gdf = gpd.read_file(self.attr_gdf_fpath)
        gdf.columns = [c.lower() for c in gdf.columns]

        self.study_ids = gdf['official_id'].values
        
        # import the license water extraction points
        dam_gdf = gpd.read_file('data/Dam_Points_20240103.gpkg')
        assert gdf.crs == dam_gdf.crs, "CRS of catchment and dam dataframes do not match"
        joined = gpd.sjoin(gdf, dam_gdf, how="inner", predicate="contains")
        joined = joined[joined['official_id'].isin(self.study_ids)]
        # Create a new boolean column 'contains_dam' in catchment_gdf.
        # If a polygon's index appears in the joined result, it means it contains at least one point.
        regulated = joined['official_id'].values
        N = len(gdf)
        print(f'{len(regulated)}/{N} catchments contain withdrawal licenses')
                
        # create dict structures for easier access of attributes and geometries
        self.da_dict = gdf[['official_id', 'drainage_area_km2']].set_index('official_id').to_dict()['drainage_area_km2']
        self.polygon_dict = gdf[['official_id', 'geometry']].set_index('official_id').to_dict()['geometry']
        
        centroids = gdf.apply(lambda x: self.polygon_dict[x['official_id']].centroid, axis=1)
        attr_gdf = gpd.GeoDataFrame(gdf, geometry=centroids, crs=gdf.crs)
        attr_gdf["contains_dam"] = attr_gdf['official_id'].apply(lambda x: x in regulated)
        attr_gdf.reset_index(inplace=True, drop=True)
        self.attr_gdf = attr_gdf


    def _build_network_trees(self, attribute_cols=['log_drainage_area_km2', 'elevation_m', 'prcp', 'tmean', 'swe', 'srad',
                            'centroid_lon_deg_e', 'centroid_lat_deg_n', 'land_use_forest_frac_2010', 'land_use_snow_ice_frac_2010',
                            #  , 'land_use_wetland_frac_2010', 'land_use_water_frac_2010', 
                            ]):
        self.coords = np.array([[geom.x, geom.y] for geom in self.attr_gdf.geometry.values])
        self.spatial_tree = cKDTree(self.coords)

        # Create mapping from official_id to index
        self.id_to_index = {oid: i for i, oid in enumerate(self.attr_gdf["official_id"])}
        self.index_to_id = {i: oid for oid, i in self.id_to_index.items()}  # Reverse mapping

        # Extract values (excluding 'official_id' since it's categorical)
        self.attr_gdf['tmean'] = (self.attr_gdf['tmin'] + self.attr_gdf['tmax']) / 2.0
        self.attr_gdf['log_drainage_area_km2'] = np.log(self.attr_gdf['drainage_area_km2'])
        attr_values = self.attr_gdf[attribute_cols].to_numpy()
        scaler = StandardScaler()
        self.normalized_attr_values = scaler.fit_transform(attr_values)
        # Construct the KDTree for the attribute space
        self.attribute_tree = cKDTree(self.normalized_attr_values)


    def _set_tree_idx_mappers(self):
        """Set the index mappers for the official_id to index and vice versa.
        This is for the network TREES"""
        id_to_idx = {id: i for i, id in enumerate(self.attr_gdf['official_id'].values)}
        idx_to_id = {i: id for i, id in enumerate(self.attr_gdf['official_id'].values)}

        return id_to_idx, idx_to_id
    

    def _set_attribute_indexers(self):
        # map keys to their 
        # overlap_dict[1].keys()
        # create a dictionary where the keys are Watershed_ID and the values are Official_ID
        self.watershed_id_dict = {row['Watershed_ID']: row['Official_ID'] for _, row in self.hs_df.iterrows()}
        # and the inverse
        self.official_id_dict = {row['Official_ID']: row['Watershed_ID'] for _, row in self.hs_df.iterrows()}
        # also for drainage areas
        self.da_dict = {row['Official_ID']: row['Drainage_Area_km2'] for _, row in self.hs_df.iterrows()}
        

    def _get_ln_params(self):
        """Retrieve log-normal parameters for a station."""
        target_columns = ['mean_uar', 'sd_uar', 'mean_logx', 'sd_logx']

        predicted_param_dict = {}
        for t in target_columns:
            fpath = os.path.join(self.parameter_prediction_results_folder, f'best_out_of_sample_{t}_predictions.csv')
            rdf = pd.read_csv(fpath, index_col='official_id')
            rdf = rdf[[c for c in rdf.columns if not c.startswith('Unnamed:')]].sort_values('official_id')
            predicted_param_dict[t] = rdf.to_dict(orient='index')
        return predicted_param_dict
    

    def _generate_12_month_windows(self, index):
        months = pd.date_range(index.min(), index.max(), freq='MS')
        windows = [(start, start + pd.DateOffset(months=12) - pd.Timedelta(days=1)) for start in months]
        return [w for w in windows if w[1] <= index.max()]
    
    
    def _is_window_valid(self, ts, start, end):
        window = ts.loc[start:end]
        if window.empty:
            return False
        grouped = window.groupby(window.index.month).size()
        if set(grouped.index) != set(range(1, 13)):
            return False
        if grouped.min() < 10:
            return False
        return True
    
    
    def _compute_station_valid_windows(self, ts, windows):
        return [self._is_window_valid(ts, start, end) for (start, end) in windows]
    
    
    def _count_valid_shared_windows(self, valid_i, valid_j):
        return sum(np.logical_and(valid_i, valid_j))
    
    
    def _compute_concurrent_overlap_dict(self, variable='discharge'):
        """
        Compute the concurrent overlap of monitored watersheds in the dataset.
        Threshold years represent the minimum number of days of overlap 
        (ignoring continuity) for a watershed to be considered concurrent.
        """
        overlap_dict_file = os.path.join('data', 'record_overlap_dict.json')
        if os.path.exists(overlap_dict_file):
            with open(overlap_dict_file, 'r') as f:
                overlap_dict = json.load(f)
            print(f'    ...overlap dict loaded from {overlap_dict_file}')
            return overlap_dict
        
        watershed_ids = self.ds['watershed'].values
        data = self.ds[variable].load().values  # (N, T)
        threshold_years = self.minimum_concurrent_years

        # Compute mask on GPU
        M = jnp.asarray(~np.isnan(data), dtype=jnp.uint16) # 
        O = compute_overlap_matrix(M)

        N = M.shape[0] # number of sites
        T = M.shape[1] # number of time steps
        connectivity_factor = np.sum(O) / float(N**2 * T)
        print(f'Connectivity factor: {connectivity_factor:.4f}')

        # Build output
        N = len(watershed_ids)
        print(f'    ...building overlap dict for N={N} monitored watersheds in the network.')
        thresholds_days = 365 * np.array(threshold_years) 
        overlap_dict = {t: {} for t in threshold_years}

        for t_years, t_days in zip(threshold_years, thresholds_days):
            over_thresh = O >= t_days
            over_thresh = over_thresh.at[jnp.diag_indices(N)].set(False)

            over_thresh_np = np.array(over_thresh)
            for i in range(N):
                ws = int(watershed_ids[i])
                overlap_dict[t_years][ws] = list(watershed_ids[over_thresh_np[i]])

        # Save the overlap dictionary to a JSON file
        with open(overlap_dict_file, 'w') as f:
            json.dump(overlap_dict, f)
        print(f'    ...overlap dict saved to {overlap_dict_file}')

        return overlap_dict

In [15]:
from dataclasses import dataclass

@dataclass
class StationData:
    def __init__(self, context, stn):
        self.ctx = context
        self.target_stn = stn
        self.attr_gdf = context.attr_gdf
        self.LN_param_dict = context.LN_param_dict
        self.n_grid_points = context.n_grid_points
        # self.catchments = catchments
        self.min_flow = context.min_flow # don't allow flows below this value
        self.divergence_measures = context.divergence_measures
        self.met_forcings_folder = context.LSTM_forcings_folder
        self.LSTM_ensemble_result_folder = context.LSTM_ensemble_result_folder
        
        self.target_da = self.attr_gdf[self.attr_gdf['official_id'] == stn]['drainage_area_km2'].values[0]
        self._initialize_target_streamflow_data()
        self.global_max, self.global_min = self._set_global_range(epsilon=1e-12)
        self._set_grid()
        self._set_divergence_measure_functions()
    

    def retrieve_timeseries_discharge(self, stn):
        watershed_id = self.ctx.official_id_dict[stn]
        df = self.ctx.ds['discharge'].sel(watershed=str(watershed_id)).to_dataframe(name='discharge').reset_index()
        df = df.set_index('time')[['discharge']]
        df.dropna(inplace=True)
        # clip minimum flow to 1e-4
        df['discharge'] = np.clip(df['discharge'], self.ctx.min_flow, None)
        df.rename(columns={'discharge': stn}, inplace=True)
        df[f'{stn}_uar'] = 1000 * df[stn] / self.ctx.da_dict[stn]      
        return df
        

    def _initialize_target_streamflow_data(self):
        # self.stn_df = dpf.get_timeseries_data(self.target_stn)
        self.stn_df = self.retrieve_timeseries_discharge(self.target_stn)
        self.uar_label = f'{self.target_stn}_uar'
        self.n_observations = len(self.stn_df[self.uar_label].dropna())  
    

    def _set_grid(self):        
        # self.baseline_log_grid = np.linspace(np.log(adjusted_min_uar), np.log(max_uar), self.n_grid_points)
        self.baseline_log_grid = np.linspace(self.global_min, self.global_max, self.n_grid_points)
        self.baseline_lin_grid = np.exp(self.baseline_log_grid)
        self.log_dx = np.gradient(self.baseline_log_grid)
        # max_step_size = self.baseline_lin_grid[-1] - self.baseline_lin_grid[-2]
        # print(f'    max step size = {max_step_size:.1f} L/s/km^2 for n={self.n_grid_points:.1e} grid points')        
        
        
    def _set_global_range(self, xminglobal=1e-1, xmaxglobal=1e5, epsilon=1e-12):
        """
        xminglobal should be expressed in terms that equate to 
        a minimum flow of 0.1 L/s, or 1e-4 m^3/s must be divided by area to make UAR
        NOTE: xmaxglobal is already expressed in unit area (100m^3/s/km^2 = 1e5 L/s/km^2
        """
        # shift the minimum flow by a small amount to force the left interval bound
        # to include the assumed global minimum flow (1e-4 m^/3) or 0.1 L/s) 
        # expressed in unit area terms on a site-specific basis
        gmin_uar = (xminglobal - epsilon) / self.target_da
        # w_global = np.log(xmaxglobal) - np.log(gmin_uar)
        # check to make sure that the maximum UAR contains the max observed value
        max_uar = self.stn_df[self.uar_label].max()
        assert max_uar <= xmaxglobal, f'max UAR > max global assumption {max_uar:.1e} > {xmaxglobal:.1e}'
        return np.log(xmaxglobal), np.log(gmin_uar)
    

    def _adjust_Q_pdf_with_prior(self, Q, label):
        """
        Adjusts the simulated PDF Q(x), originally defined on x_sim, by incorporating a Dirichlet prior,
        to produce an adjusted (posterior) PDF Q_mod on x_target.

        """
        # Ensure the target grid lies within the global range.
        # Compute the global log-width.
                        
        # Convert Q_interp (a PDF) into effective "counts" using the number of observations.
        # counts_Q = N_obs * Q
        
        # Convert the prior density into pseudo-counts.
        # years_equiv = n_series * (N_obs / 365.25)
        # counts_prior = (self.pdf_prior / years_equiv) * dx
        # prior_pdf_interp = jnp.interp(self.baseline_log_grid, self.ba, prior_pdf, left=0, right=0)
        prior_pseudo_counts = self.knn_simulation_data[label]['prior']
        n_observations = self.knn_simulation_data[label]['n_obs']
        
        # Combine the counts from Q and the prior.
        Q_mod = Q * n_observations + prior_pseudo_counts
        # Renormalize to obtain the adjusted PDF (discrete PMF) on x_target.
        Q_mod /= np.sum(Q_mod)
        assert np.all(np.isfinite(Q_mod)), 'Q_mod is messed up'
        if not np.min(Q_mod) > 0:
            print('Q_mod min:', np.min(Q_mod))
            print('Q_mod sum:', np.sum(Q_mod))
            print('Q_mod:', Q_mod)
            print('Q:', Q)
            print('prior_pseudo_counts:', prior_pseudo_counts)
            print('n_observations:', n_observations)
            Q_mod += 1e-18
            Q_mod /= np.sum(Q_mod)
            # raise ValueError(f'Q_mod min < 0: {np.min(Q_mod)}')
        assert np.min(Q_mod) > 0, f'qmod_i < 0 ({np.min(Q_mod)})'
        assert np.isclose(np.sum(Q_mod), 1), f"Q_mod doesn't sum to 1: {np.sum(Q_mod):.5f}"

        q_mask = (Q > 0)
        prior_bias = jnp.sum(jnp.where(q_mask, Q * jnp.log2(Q / Q_mod), 0))
        if prior_bias < -0.0001:
            prior_pdf = self.knn_simulation_data[label]['prior']
            print('prior pmf sum =', np.sum(prior_pdf))
            print('Q_sum = ', np.sum(Q))
            print('Q_mod sum = ', np.sum(Q_mod))
            msg = f'    Prior bias {prior_bias:.3f} bits/sample bias'
            raise ValueError(msg)
        
        return Q_mod
    

    def _compute_emd(self, p, q, label=None):
        assert np.isclose(np.sum(p), 1, atol=1e-3), f'sum P = {np.sum(p)}'
        assert np.all(q >= 0), f'min q_i < 0: {np.min(q)}'
        emd = wasserstein_distance(self.baseline_log_grid, self.baseline_log_grid, p, q)
        return float(round(emd, 3))#, {'bias': None, 'unsupported_mass': None, 'pct_of_signal': None}

    
    def _compute_kld(self, p, q, label=None):
        # assert p and q are 1d
        assert jnp.ndim(p) >= 1, f'p is not 1D: {jnp.ndim(p)}'
        assert jnp.ndim(q) >= 1, f'q is not 1D: {jnp.ndim(q)}'
        # Ensure q is at least 2D for consistent broadcasting
        assert jnp.isclose(np.sum(p), 1, atol=1e-3), f'sum P = {np.sum(p)}'
        assert jnp.isclose(np.sum(q), 1, atol=1e-3), f'sum Q = {np.sum(q)}'
        assert jnp.all(q >= 0), f'min q_i < 0: {np.min(q)}'
        assert jnp.all(p >= 0), f'min p_i < 0: {np.min(p)}'
        p_mask = (p > 0)
        dkl = jnp.sum(jnp.where(p_mask, p * jnp.log2(p / q), 0)).item()
        
        return round(dkl, 3)


    def _set_divergence_measure_functions(self):
        self.divergence_functions = {k: None for k in self.divergence_measures}
        for dm in self.divergence_measures:
            # set the divergence measure functions
            if dm == 'DKL':
                self.divergence_functions[dm] = self._compute_kld
            elif dm == 'EMD':
                self.divergence_functions[dm] = self._compute_emd
            else:
                raise Exception("only KL divergence (DKL) and Earth Mover's Distance (EMD) are implemented")
            

    def _compute_bias_from_eps(self, pmf: jnp.ndarray, divergence_measure: str, eps: float = 1e-12) -> float:
        """Compute KL divergence between original PMF and PMF + eps.

        Parameters
        ----------
        pmf : jnp.ndarray
            The original PMF (should sum to 1).
        eps : float
            Small value added to avoid zero bins.

        Returns
        -------
        float
            D_KL(pmf || pmf_eps) representing the bias introduced by smoothing.
        """
        pmf_eps = pmf + eps
        pmf_eps /= pmf_eps.sum()
        return self.divergence_functions[divergence_measure](pmf, pmf_eps)


In [16]:
class kNNFDCEstimator:
    def __init__(self, ctx, target_stn, data, *args, **kwargs):
        self.ctx = ctx
        self.target_stn = target_stn
        self.data = data
        self.k_nearest = data.k_nearest
        self.weight_schemes = ['ID1', 'ID2'] #inverse distance and inverse square distance
        self.knn_simulation_data = {}
        self.knn_pdfs = pd.DataFrame()
        self.knn_pmfs = pd.DataFrame()


    def _find_k_nearest_neighbors(self, target_index, tree_type, overlapping_tree_idxs):
        # Query the k+1 nearest neighbors because the first neighbor is the target point itself

        if len(overlapping_tree_idxs) >= 10: 
            max_search = len(overlapping_tree_idxs)
        else:
            max_search = self.data.max_to_check

        if tree_type == 'spatial_dist':
            distances, indices = self.ctx.spatial_tree.query(self.ctx.coords[target_index], k=max_search+1)
            distances *= 1 / 1000
        elif tree_type == 'attribute_dist':
            # Example query: Find the nearest neighbors for the first point
            distances, indices = self.ctx.attribute_tree.query(self.ctx.normalized_attr_values[target_index], k=max_search+1)
        else:
            raise Exception('tree type not identified, must be one of spatial_dist, or attribute_dist.')
        
        # Remove target (self) from the results
        self_index = target_index
        keep = indices != self_index
        indices = indices[keep]
        distances = distances[keep]

        # Filter by the pre-processed overlapping tree indices
        if len(overlapping_tree_idxs) >= 10:
            overlap_set = set(overlapping_tree_idxs)
            filtered = [(i, d) for i, d in zip(indices, distances) if i in overlap_set]
        else:
            filtered = list(zip(indices, distances))

        neighbour_indices, neighbour_distances = zip(*filtered)
        return np.array(neighbour_indices), np.round(np.array(neighbour_distances), 3)
    

    def _check_time_series_completeness(self, timeseries):
        df = timeseries.copy().dropna(subset=[c for c in timeseries.columns if c.endswith('_uar')])
        # print(f'           After dropping len={len(df)}, {df.index.min()} - {df.index.max()}')
        # compute the total number of observations per month
        # we want to ensure the simulation period is not seasonal
        if df.empty:
            return False
        df['month'] = df.index.month
        days_by_month = df.groupby(['month']).size()
        # Check that all 12 months are present
        if set(days_by_month.index) != set(range(1, 13)):
            print('    Not all months represented year.  Skipping.')
            return False
        min_days = days_by_month.min()
        # Check that every month has at least 10 unique days
        if np.any(min_days < 10):
            print('    At least one month is not represented by at least ten days, skipping comparison.')
            return False
        # count the number of complete years of record
        return True


    def _compute_effective_k(self, df, max_k=None):
        arr = df.to_numpy()
        T, K = arr.shape
        max_k = max_k or K

        nan_mask = np.isnan(arr)
        sorted_idx = np.argsort(nan_mask, axis=1)
        row_idx = np.arange(T)[:, None]

        ks = np.arange(1, max_k + 1)
        effective_k = []
        mean_furthest = []

        for k in ks:
            idx = sorted_idx[:, :k]
            valid = ~nan_mask[row_idx, idx]
            valid_count = valid.sum(axis=1)

            effective_k.append(valid_count.mean())
            furthest_idx = np.where(valid, idx, -1).max(axis=1)
            mean_furthest.append(furthest_idx[valid_count >= k].mean() if np.any(valid_count >= k) else np.nan)

        return pd.DataFrame({
            'effective_k': np.round(effective_k, 1),
            'mean_furthest_idx': np.round(mean_furthest, 2)
        }, index=ks)

        
    def _check_neighbours(self, stn, distance, concurrent, target_df, min_years=5):
        # proxy_df = dpf.get_timeseries_data(stn)
        proxy_df = self.data.retrieve_timeseries_discharge(stn)[[f'{stn}_uar']]
        if concurrent == 'concurrent':
            proxy_df = proxy_df.reindex(target_df.index)

        full_year_represented = self._check_time_series_completeness(proxy_df)
        n_years = proxy_df.dropna().shape[0] // 365
        if n_years < min_years or not full_year_represented:
            print(f'    Skipping {stn}: <1 year or incomplete.')
            return None
        return [stn, round(distance, 3), n_years, proxy_df]
    
    
    def _query_distance(self, tree, id1, id2):
        """Query distance between two points in a tree using official_id."""
        if id1 not in self.ctx.id_to_index or id2 not in self.ctx.id_to_index:
            raise ValueError(f"One or both IDs ({id1}, {id2}) not found.")
    
        # Get indices from ID mapping
        index1, index2 = self.ctx.id_to_index[id1], self.ctx.id_to_index[id2]        
        # Query the distance
        distance = np.linalg.norm(tree.data[index1] - tree.data[index2])  # Euclidean distance
        return distance
    

    def _retrieve_nearest_nbr_data(self, tree_type, concurrent, min_concurrent_years):
        tree_idx = self.ctx.id_to_index[self.target_stn]
        # get the pre-screened concurrent stations
        target_ws_id = self.ctx.official_id_dict[self.target_stn]

        max_nbrs = self.data.max_to_check

        # use the pre-computed overlap dictionary to find the concurrent stations
        overlapping_tree_idxs = []
        if concurrent == 'concurrent':
            assert min_concurrent_years > 0, f'concurrent min years must be > 0, not {min_concurrent_years}'
            overlapping_ws_ids = self.ctx.overlap_dict[min_concurrent_years].get(target_ws_id, [])
            overlapping_stn_official_ids = [self.ctx.watershed_id_dict[e] for e in overlapping_ws_ids]

            existing_keys = [e for e in overlapping_stn_official_ids if e in self.ctx.id_to_index]
            overlapping_tree_idxs = [self.ctx.id_to_index[e] for e in existing_keys]
            overlapping_tree_idxs = [e for e in overlapping_tree_idxs if e is not None]

        neighbour_idxs, distances = self._find_k_nearest_neighbors(tree_idx, tree_type, overlapping_tree_idxs)
        neighbours = [self.ctx.index_to_id[e] for e in neighbour_idxs.tolist()][:max_nbrs]

        assert self.target_stn not in neighbours, f'{self.target_stn} is in the list of neighbours: {neighbours}'
        concurrent_inputs = [concurrent] * len(neighbours)
        df_inputs = [self.data.stn_df.copy()]*len(neighbours)

        # don't need to recheck neighbours
        with ThreadPoolExecutor(max_workers=20) as executor:  # Adjust max workers
            checked_neighbours = list(executor.map(self._check_neighbours, neighbours, distances, concurrent_inputs, df_inputs))
        
        good_nbrs = [e for e in checked_neighbours if e is not None]
        print(f'     ...found {len(good_nbrs)} ({tree_type} tree) neighbours ({concurrent}) for {self.target_stn} from {max_nbrs} nearest neighbours')
        if len(good_nbrs) < 10:
            print('   need', 10 - len(good_nbrs), '  more good records')
            self.data.max_to_check += 50
            max_nbrs = self.data.max_to_check
            if max_nbrs <= 600:
                return self._retrieve_nearest_nbr_data(tree_type, concurrent, min_concurrent_years)
            else:
                raise Exception(f'{len(good_nbrs)}/10 suitable nearest neighbours found ({max_nbrs} searched)')

        complement_dist = 'attribute_dist' if tree_type == 'spatial_dist' else 'spatial_dist' 
        # set the OPPOSITE tree to the current tree type
        complement_tree = self.ctx.attribute_tree if tree_type == 'spatial_dist' else self.ctx.spatial_tree 
        # the multiplier should reflect conversion to km from m for spatial distance 
        multiplier = 1 / 1000 if complement_dist == 'spatial_dist' else 1 

        # sort the neighbours by distance
        # output is ['official_id', 'distance', 'n_years', 'timeseries df']
        good_nbrs = sorted(good_nbrs, key=lambda tup: tup[1])        
        nbr_df = pd.concat([e[3] for e in good_nbrs], join='outer', axis=1)

        # Align with target index if concurrent, leave as is if async.
        if concurrent == 'concurrent':
            target_series = self.data.stn_df.dropna(subset=[f'{self.target_stn}_uar'])
            nbr_df = nbr_df.reindex(index=target_series.index)

        nbr_data = pd.DataFrame([e[:3] for e in good_nbrs], columns=['official_id', 'distance', 'n_years'])  
        nbr_data[complement_dist] = (nbr_data['official_id'].apply(lambda x: self._query_distance(complement_tree, x, self.target_stn)) * multiplier).round(3)

        effective_k = self._compute_effective_k(nbr_df)
        return nbr_df, nbr_data, effective_k


    def _contributing_ensemble_check(self, nbr_data, min_years, previous_ids, min_years_prev):
        current_ids = set([c.split('_')[0] for c in nbr_data.columns])
        if previous_ids is not None:
            if current_ids == previous_ids:
                print(f"    No change in neighbor set between {min_years_prev} and {min_years} years.")
            else:
                added = current_ids - previous_ids
                removed = previous_ids - current_ids
                print(f"    Change detected for min_concurrent_years {min_years_prev} → {min_years}:")
                if added:
                    print(f"    + Added:   {sorted(added)}")
                if removed:
                    print(f"    - Removed: {sorted(removed)}")

        previous_ids = current_ids
        min_years_prev = min_years
        return previous_ids, min_years_prev
    
    
    def _initialize_nearest_neighbour_data(self):
        """
        Generate nearest neighbours for spatial and attribute selected k-nearest neighbours for both concurrent and asynchronous records.
        """
        print(f'    ...initializing {self.data.max_to_check} nearest neighbours with minimum concurrent record.')
        self.nbr_dfs = defaultdict(lambda: defaultdict(dict))
        
        for tree_type in ['spatial_dist', 'attribute_dist']:
            for concurrent in ['concurrent', 'async']:
                if (self.ctx.LSTM_concurrent_network is False) & (concurrent == 'concurrent'):
                    print('    Skipping concurrent check for non-LSTM concurrent network.')
                    continue

                if concurrent == 'async':
                    min_concurrent_years = [0]
                else:
                    min_concurrent_years = self.ctx.minimum_concurrent_years
                previous_ids, min_years_prev = None, None
                for min_overlap in min_concurrent_years:
                    nbr_df, nbr_data, effective_k = self._retrieve_nearest_nbr_data(tree_type, concurrent, min_overlap)
                    previous_ids, min_years_prev = self._contributing_ensemble_check(nbr_df, min_overlap, previous_ids, min_years_prev)
                    assert not nbr_df.empty, f'{tree_type} attr concurrent={concurrent} nbr df empty'
                    print('       ', tree_type, concurrent, min_overlap, len(nbr_df))
                    self.nbr_dfs[tree_type][concurrent][min_overlap] = {
                        'nbr_df': nbr_df,
                        'nbr_data': nbr_data,
                        'effective_k': effective_k
                    }
    
    
    def _run_spatial_knn(self, num_neighbours, time_type, weight_scheme):
        target_idx = self.ctx.id_to_idx[self.target_stn]
        
        for k in range(1, num_neighbours+1):
            sim_label = f'{k}NN_{time_type}_{weight_scheme}'

    
    def _compute_weights(self, m, k, distances, epsilon=1e-3):
        """Compute normalized inverse (square) distance weights to a given power."""

        distances = jnp.where(distances == 0, epsilon, distances)

        if m == 'ID1':
            power = 1.0
        elif m == 'ID2':
            power = 2.0
        else:
            raise ValueError(f"Unknown weighting method: {m}.  Only ID1 (inverse distance) or ID2 (inverse square distance) are implemented.")
        
        if k == 1:
            return jnp.array([1])
        else:
            inv_weights = 1 / (jnp.abs(distances) ** power)
            return inv_weights / jnp.sum(inv_weights)
    
    
    def _compute_prior_from_laplace_fit(self, predicted_uar, n_cols=1, min_prior=1e-10, scale_factor=1.05, recursion_depth=0, max_depth=100):
        """
        Fit a Laplace distribution to the simulation and define a 
        pdf across a pre-determined "global" range to avoid data
        leakage.  "Normalize" by setting the total prior mass to
        integrate to a factor related to the number of observations.
        """
        # assert no nan values
        assert np.isfinite(predicted_uar).all(), f'NaN values in predicted_uar: {predicted_uar}'
        # assert all positive values
        # assert np.all(predicted_uar > 0), f'Negative values in predicted_uar: {np.min(predicted_uar)}'
        # replace anything <= 0 with 1e-4 scaled by the drainage area
        predicted_uar = np.where(predicted_uar <= 0, 1000 * 1e-4 / self.data.target_da, predicted_uar)
        assert np.all(predicted_uar > 0), f'Negative values in predicted_uar: {np.min(predicted_uar)}'
        # print('min/max: ', np.min(predicted_uar), np.max(predicted_uar))
        loc, scale = laplace.fit(np.log(predicted_uar))       

        # Apply scale factor in case of recursion
        if scale <= 0:
            original_scale = scale
            scale = scale_factor ** recursion_depth
            print(f'   Adjusting scale from {original_scale:.3f} to {scale:.3f} for recursion depth {recursion_depth}')

        prior_pdf = laplace.pdf(self.data.baseline_log_grid, loc=loc, scale=scale)
        prior_check = jnp.trapezoid(prior_pdf, x=self.data.baseline_log_grid)
        prior_pdf /= prior_check

        # Check for zeros
        if np.any(prior_pdf == 0) | np.any(np.isnan(prior_pdf)):
            # Prevent scale from being too small
            if recursion_depth >= max_depth:
                # set a very small prior
                prior_pdf = np.ones_like(self.data.baseline_log_grid)
                err_msg = f"Recursion limit reached. Scale={scale}, setting default prior to 1 pseudo-count uniform distribution."
                print(err_msg)
                return prior_pdf
                # raise ValueError(err_msg)
            # print(f"Recursion {recursion_depth}: Zero values detected. Increasing scale to {scale:.6f}")
            return self._compute_prior_from_laplace_fit(predicted_uar, n_cols=n_cols, recursion_depth=recursion_depth + 1)
        
        second_check = jnp.trapezoid(prior_pdf, x=self.data.baseline_log_grid)
        assert np.isclose(second_check, 1, atol=2e-4), f'prior check != 1, {second_check:.6f} N={len(predicted_uar)} {predicted_uar}'
        assert np.min(prior_pdf) > 0, f'min prior == 0, scale={scale:.5f}'

        # convert prior PDF to PMF (pseudo-count mass function)
        prior_pmf = prior_pdf * self.data.log_dx

        # scale the number of pseudo-counts based on years of record  (365 / n_observations)
        # and number of models in the ensemble (given by n_cols)
        prior_pseudo_counts = prior_pmf * (365 / (len(predicted_uar) * n_cols))
        
        # return weighted_prior_pdf
        return prior_pseudo_counts
    

    def _compute_frequency_ensemble_mean(self, pdfs, weights):
        """
        This function computes the weighted ensemble distribution estimates.
        """
        # Normalize distance weights
        if weights is not None:
            weights /= jnp.sum(weights).astype(float)
            weights = jnp.array(weights)  # Ensure 1D array
            pdf_est = jnp.asarray(pdfs.to_numpy() @ weights)
        else:
            pdf_est = jnp.asarray(pdfs.mean(axis=1).to_numpy())


        # Check integral before normalization
        pdf_check = jnp.trapezoid(pdf_est, x=self.data.baseline_log_grid)
        normalized_pdf = pdf_est / pdf_check
        assert jnp.isclose(jnp.trapezoid(normalized_pdf, x=self.data.baseline_log_grid), 1), f'ensemble pdf does not integrate to 1: {pdf_check:.4f}'
                
        # Compute PMF
        pmf_est = normalized_pdf * self.data.log_dx
        pmf_est /= jnp.sum(pmf_est)

        return pmf_est, pdf_est
    

    def _compute_ensemble_member_distribution_estimates(self, df):
        """
        Compute the ensemble distribution estimates based on the KNN dataframe.
        """    
        pdfs, prior_biases = pd.DataFrame(), {}
        # initialize a kde estimator object
        kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)
        for c in df.columns: 
            # evaluate the laplace on the prediction as a prior
            # drop the nan values
            values = df[c].dropna().values
            obs_count = len(values)
            assert len(values) > 0, f'0 values for {c}'

            # compute the pdf and pmf using kde
            assert sum(np.isnan(values)) == 0, f'NaN values in {c} {values[:5]}'

            kde_pmf, _ = kde.compute(
                values, self.data.target_da
            )

            prior = self._compute_prior_from_laplace_fit(values, n_cols=1) # priors are expressed in pseudo-counts
            # convert the pdf to counts and apply the prior
            counts = kde_pmf * obs_count + prior

            # re-normalize the pmf
            pmf = counts / jnp.sum(counts)
            pdf = pmf / self.data.log_dx

            pdf_check = jnp.trapezoid(pdf, x=self.data.baseline_log_grid)
            pdf /= pdf_check
            # pdf /= pdf_check
            assert jnp.isclose(jnp.trapezoid(pdf, x=self.data.baseline_log_grid), 1.0, atol=0.001), f'pdf does not integrate to 1 in compute_ensemble_member_distribution_estimates: {pdf_check:.4f}'
            pdfs[c] = pdf

            # convert the pdf to pmf
            pmf = pdf * self.data.log_dx
            pmf /= jnp.sum(pmf)
            # assert np.isclose(np.sum(pmf), 1, atol=1e-4), f'pmf does not sum to 1 in compute_ensemble_member_distribution_estimates: {np.sum(pmf):.5f}'
            
            # compute the bias added by the prior
            prior_biases[c.split('_')[0]] = {'DKL': self.data._compute_kld(kde_pmf, pmf), 'EMD': self.data._compute_emd(kde_pmf, pmf)}
        return pdfs, prior_biases
    
    
    def _compute_frequency_ensemble_distribution(self, knn_df_all, knn_data_all, distance_type, concurrent, ensemble_pdfs, min_overlap=0):
        """
        For asynchronous comparisons, we estimate pdfs for ensemble members, then compute the mean in the time domain
        to represent the FDC simulation.  We do not do temporal averaging in this case.
        Default min_overlap is 0 for asynchronous comparisons.
        """
        # distances_all = knn_data_all['distance'].values[:self.k_nearest]
        # nbr_ids_all = knn_data_all['official_id'].values[:self.k_nearest]
        
        # distances = jnp.array(nbr_data['distance'].astype(float).values)
        labels, pdfs, pmfs = [], [], []
        all_distances = knn_data_all['distance'].values
        all_ids = knn_data_all['official_id'].values
        # prior_bias_df = pd.DataFrame(prior_bias_dict)
        for wm in self.weight_schemes:
            for k in range(1, self.k_nearest + 1):
                distances = all_distances[:k]
                nbr_ids = all_ids[:k]
                knn_pdfs = ensemble_pdfs.iloc[:, :k].copy()

                label = f'{self.target_stn}_{k}_NN_{min_overlap}_minYears_{concurrent}_{distance_type}_{wm}_freqEnsemble'
                weights = self._compute_weights(wm, k, distances)
                pmf_est, pdf_est = self._compute_frequency_ensemble_mean(knn_pdfs, weights)
                assert pmf_est is not None, f'pmf_est is None for {label}'
            
                # compute the mean number of observations (non-nan values) per row
                mean_obs_per_timestep = knn_df_all.iloc[:, :k].notna().sum(axis=1).mean()
                mean_obs_per_proxy = knn_df_all.iloc[:, :k].notna().sum(axis=0).mean()

                # max_prior_bias = prior_bias_df.iloc[:k].max(axis=0)                
                # compute the frequency-based ensemble pdf estimate
                self.knn_simulation_data[label] = {'k': k, 'n_obs': mean_obs_per_proxy,
                                                'mean_obs_per_timestep': mean_obs_per_timestep,
                                                'nbrs': ','.join(nbr_ids)}

                self.knn_simulation_data[label]['DKL'] = self.data._compute_kld(self.ctx.baseline_pmf, pmf_est)
                self.knn_simulation_data[label]['EMD'] = self.data._compute_emd(self.ctx.baseline_pmf, pmf_est)

                # print(k, wm, label, self.knn_simulation_data[label]['DKL'])
                
                pdfs.append(np.asarray(pdf_est))
                pmfs.append(np.asarray(pmf_est))
                labels.append(label)

        # create a dataframe of labels(columns) for each pdf
        knn_pdfs = pd.DataFrame(pdfs, index=labels).T
        knn_pmfs = pd.DataFrame(pmfs, index=labels).T
        # Filter out already existing columns to avoid duplication
        new_pdf_cols = knn_pdfs.columns.difference(self.knn_pdfs.columns)
        new_pmf_cols = knn_pmfs.columns.difference(self.knn_pmfs.columns)
        # Concat only new columns
        self.knn_pdfs = pd.concat([self.knn_pdfs, knn_pdfs[new_pdf_cols]], axis=1)
        self.knn_pmfs = pd.concat([self.knn_pmfs, knn_pmfs[new_pmf_cols]], axis=1)
    
    
    def _run_async_knn(self, dist):
        min_overlap = 0
        nbr_df = self.nbr_dfs[dist]['async'][min_overlap]['nbr_df'].copy()
        nbr_data = self.nbr_dfs[dist]['async'][min_overlap]['nbr_data'].copy()
        knn_df_all = nbr_df.iloc[:, :self.k_nearest].copy()
        knn_data_all = nbr_data.iloc[:, :self.k_nearest].copy()
        frequency_ensemble_pdfs, _ = self._compute_ensemble_member_distribution_estimates(knn_df_all)
        self._compute_frequency_ensemble_distribution(knn_df_all, knn_data_all, dist, 'async', frequency_ensemble_pdfs)

    
    def _delta_spike_pmf_pdf(self, single_val, log_grid):
        """
        Return a spike PMF and compatible PDF centered at the only value in the input.
        The spike is placed at the nearest log_grid point to log(single_val).
        """
        log_val = jnp.log(single_val)
        spike_idx = jnp.argmin(jnp.abs(log_grid - log_val))
        
        pmf = jnp.zeros_like(log_grid)
        pmf = pmf.at[spike_idx].set(1.0)

        dx = jnp.gradient(log_grid)
        pdf = pmf / dx  # assign all mass to one bin

        return pmf, pdf

    
    def _compute_nse(self, obs, sim):
        """Compute the Nash-Sutcliffe Efficiency (NSE) between observed and simulated values."""
        assert not np.isnan(obs).any(), f'NaN values in obs: {obs}'
        assert not np.isnan(sim).any(), f'NaN values in sim: {sim}'
        assert (obs >= 0).all(), f'Negative values in obs: {obs}'
        assert (sim >= 0).all(), f'Negative values in sim: {sim}'
        # Compute the NSE
        numerator = jnp.sum((obs - sim) ** 2)
        denominator = jnp.sum((obs - obs.mean()) ** 2)
        nse = 1 - (numerator / denominator)
        return nse


    def _compute_KGE(self, obs, sim):
        """Compute the Kling-Gupta Efficiency (KGE) between observed and simulated values."""
        assert not np.isnan(obs).any(), f'NaN values in obs: {obs}'
        assert not np.isnan(sim).any(), f'NaN values in sim: {sim}'
        assert (obs >= 0).all(), f'Negative values in obs: {obs}'
        assert (sim >= 0).all(), f'Negative values in sim: {sim}'
        # Compute the KGE
        r = jnp.corrcoef(obs, sim)[0, 1]
        alpha = sim.mean() / obs.mean()
        beta = sim.std() / obs.std()
        kge = 1 - jnp.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)
        return kge
    
    
    
    def _compute_ensemble_contribution_metrics(self, df: pd.DataFrame, weights: np.ndarray):
        mask = ~df.isna()
        
        # Mean number of valid values per row
        mean_valid_per_row = mask.sum(axis=1).mean()

        # Normalized weights per row, masking NaNs
        X = df.to_numpy()
        W = np.broadcast_to(weights, X.shape)
        masked_weights = np.where(mask, W, 0.0)
        weight_sums = masked_weights.sum(axis=1)
        weight_sums[weight_sums == 0] = np.nan
        normalized_weights = masked_weights / weight_sums[:, None]

        # Average contribution per column across all rows
        mean_w = np.nanmean(normalized_weights, axis=0)
        effective_n = 1.0 / np.nansum(mean_w ** 2)

        return mean_valid_per_row, effective_n
    
    
    
    
   
    def run_estimators(self, divergence_measures, eps, baseline_pmf):
        self._initialize_nearest_neighbour_data()
        
        for dist in ['spatial_dist', 'attribute_dist']:            
            self._run_async_knn(dist) 
            if self.ctx.LSTM_concurrent_network is False:
                print('    Skipping concurrent check for non-LSTM concurrent network.')
                continue 
            for min_concurrent_years in self.ctx.minimum_concurrent_years[::-1]: # go from most to least minimum required concurrent years 
                self._initialize_concurrent_ensemble_inputs(dist, min_concurrent_years)
        return self._process_results()

    
    def _make_json_serializable(self, d):
        output = {}
        for k, v in d.items():
            if isinstance(v, (np.ndarray, jnp.ndarray)):
                output[k] = v.tolist()
            elif hasattr(v, "tolist"):
                output[k] = v.tolist()
            else:
                output[k] = v
        return output
    
    
    def _process_results(self):        
        pmf_labels, pdf_labels, sim_labels = list(self.knn_pmfs.columns), list(self.knn_pdfs.columns), list(self.knn_simulation_data.keys())
        # assert label sets are the same
        assert set(pmf_labels) == set(pdf_labels), f'pmf_labels {pmf_labels} != pdf_labels {pdf_labels}'
        assert set(pmf_labels) == set(sim_labels), f'pmf_labels {pmf_labels} != sim_labels {sim_labels}'
        results = self.knn_simulation_data
        for label in pmf_labels:
            # add the pmf and pdf in a json serializable format
            results[label]['pmf'] = self.knn_pmfs[label].tolist()
            results[label]['pdf'] = self.knn_pdfs[label].tolist()
            results[label] = self._make_json_serializable(results[label])
        return results
        

In [17]:
class PDFEstimatorRunner:
    def __init__(self, stn_id, ctx, methods, k_nearest, max_to_check, **kwargs):
        self.stn_id = stn_id
        self.ctx = ctx
        self.methods = methods
        self.k_nearest = k_nearest
        self.max_to_check = max_to_check
        self._check_min_overlap()
        self._create_results_folders()


    def _create_results_folders(self):
        # create a results foder for each method if it doesn't exist
        self.results_folder = os.path.join('data', f'pdf_estimates',)
        os.makedirs(self.results_folder)


    def _process_ground_truth(self):
        self.kde = KDEEstimator(self.data.baseline_log_grid, self.data.log_dx)
        self.baseline_pmf, self.baseline_pdf = self.kde.compute(
            self.data.stn_df[self.data.uar_label].values, self.data.target_da
        )
        self.ctx.baseline_pmf = self.baseline_pmf


    def _save_result(self, result):
        with open(self.result_file, 'w') as file:
            json.dump(result, file, indent=4)


    def _check_min_overlap(self):
        for min_years, cdict in self.ctx.overlap_dict.items():
            min_concurrent_stns = 1e6
            n_less_than_ten, lonely_stns = 0, []
            print(f'Processing {min_years} years of concurrent record')
            for stn, concurrent_ids in cdict.items():
                n_stns = len(concurrent_ids)
                if n_stns < 10:
                    n_less_than_ten += 1
                    lonely_stns.append(stn)
                if n_stns < min_concurrent_stns:
                    min_concurrent_stns = n_stns
            print(f"    {n_less_than_ten} stations do not have at least 10 viable sensors in the network with at least {min_years} years of concurrent record.")

 
    def run_selected(self):
        # check the minimum number of years of overlap for all stations in self.ctx.overlap_dict
        
        self.result_file = os.path.join(self.results_folder, f'{self.stn_id}_estimated_pdfs.json')
        if os.path.exists(self.result_file):
            return None
        else:
            self.data = StationData(self.ctx, self.stn_id)
            # self.data.k_nearest = self.k_nearest
            # self.data.max_to_check = self.max_to_check
            self._process_ground_truth()
        try:
            EstimatorClass = ESTIMATOR_CLASSES[method]
            estimator = EstimatorClass(
                self.ctx, self.stn_id, self.data
            )
            result = estimator.run_estimators(
                divergence_measures=self.ctx.divergence_measures, 
                eps=self.ctx.eps,
                baseline_pmf=self.baseline_pmf,
                )
            self._save_result(result)
        except Exception as e:
            raise Exception(f"  {method} estimator failed for {self.stn_id}: {str(e)}")
                

In [18]:
np.random.seed(42)

# from utils import FDCEstimationContext
rev_date = '20250227'
attr_gdf_fpath = os.path.join('data', f'BCUB_watershed_attributes_updated_{rev_date}.geojson')
LSTM_forcings_folder = '/home/danbot2/code_5820/neuralhydrology/data/BCUB_catchment_mean_met_forcings_20250320'
LSTM_ensemble_result_folder = '/home/danbot2/code_5820/neuralhydrology/data/ensemble_results'
parameter_prediction_results_folder = os.path.join('data', 'parameter_prediction_results')

# methods = ('parametric', 'lstm',)
methods = ('parametric', 'lstm', 'knn',)
methods = ('knn',)

processed = []
ESTIMATOR_CLASSES = {
}
input_data = {
    'attr_gdf_fpath': attr_gdf_fpath,
    'LSTM_forcings_folder': LSTM_forcings_folder,
    'LSTM_ensemble_result_folder': LSTM_ensemble_result_folder,
    'LSTM_concurrent_network': False,  # use only stations with data 1980-present concurrent with Daymet
    'parameter_prediction_results_folder': parameter_prediction_results_folder,
    'streamflow_dir': STREAMFLOW_DIR,
    # 'predicted_param_sample': predicted_param_sample,
    'divergence_measures': ['DKL', 'EMD'],
    'minimum_concurrent_years': [0, 1, 2, 5, 10, 20],
    'eps': 1e-12,
    'min_flow': 1e-4,
    'n_grid_points': 2**12,
    'min_overlap_years': 5,
}

context = FDCEstimationContext(**input_data)

2889/1324 catchments contain withdrawal licenses
    ...overlap dict loaded from data/record_overlap_dict.json


In [20]:
plot_network_records = False
if plot_network_records:
    discharge = context.ds['discharge']
    # mask = (~np.isnan(discharge)).astype(int)
    output_file("images/data_availability_matrix.html")
    # Step 1: Convert to DataArray and group by week
    # Convert time coordinate to Pandas Index first
    time_index = pd.DatetimeIndex(discharge.time.values)

    # Convert to weekly periods and back to timestamps (start of week)
    week = time_index.to_period("W").to_timestamp()

    # Assign 'week' as a new coordinate aligned with time
    d = context.ds['discharge']  # shape: (w, time)

    d.coords["week"] = ("time", week)
    # Count non-NaNs per watershed/week
    weekly_counts = d.groupby("week").map(lambda x: np.isfinite(x).sum(dim="time"))

    # Boolean: weeks with ≥3 days of data
    weekly_mask = (weekly_counts >= 3).transpose("watershed", "week")
    week_dt = pd.to_datetime(weekly_mask['week'].values) 
    # Convert to float and flip for plotting
    img = weekly_mask.astype(float).values.T#[::-1, :]

    watershed_ids = d.watershed.values
    x_start, x_end = 0, len(watershed_ids)
    y_start, y_end = week_dt.min(), week_dt.max()
    output_file("images/weekly_data_availability_matrix.html")
    # Assume weekly_mask is a DataArray: dimensions ('watershed', 'week')
    img = weekly_mask.astype(float).values.T[::-1, :]  # shape: (ny, nx)

    # Get the datetime values for the weekly bins
    week_dt = pd.to_datetime(weekly_mask['week'].values)  # len = nx
    watersheds = weekly_mask['watershed'].values          # len = ny


    p = figure(
        width=1200,
        height=300,
        y_axis_type="datetime",
        y_range=(week_dt[0], week_dt[-1]),
        x_range=(0, len(watersheds)),
        title='',
        # toolbar_location=None,
        # title="Weekly Data Availability Matrix (≥3 days/week)"
    )

    # Define color mapper
    mapper = LinearColorMapper(palette=["#ffffff", "#444444"], low=0, high=1)

    # Use image glyph; Bokeh handles datetime x when x is a datetime64
    p.image(
        image=[img],
        y=week_dt[0],
        x=0,
        dh=(week_dt[-1] - week_dt[0]),#.astype('timedelta64[D]').astype(int),  # width in days
        dw=len(watersheds),
        color_mapper=mapper
    )

    # Format ticks to show only years
    # p.xaxis.formatter = DatetimeTickFormatter(years="%Y")

    # Minimalist style
    p.xaxis.axis_label = 'Watershed ID'
    p.yaxis.axis_label = 'Date'
    p.grid.visible = False
    p.outline_line_color = None

    save(p)


In [94]:
processed = []
t0 = time()
for stn in context.official_ids:
    print(stn)
    runner = PDFEstimatorRunner(stn, context, methods, k_nearest=10, max_to_check=20)
    runner.run_selected()
    processed.append(stn)
    if len(processed) % 10 == 0:
        t1 = time()
        elapsed = t1 - t0
        unit_time = elapsed / len(processed)
        print(f'Processed {len(processed)}/{len(context.official_ids)} stations in {unit_time:.2f} seconds per station')

AttributeError: 'FDCEstimationContext' object has no attribute 'official_ids'

#### Kernel Density Estimator in 1D

$$\hat g (x) = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{h(x_i)} K \left[ \frac{x-x_i}{h(x_i)} \right]$$

Where: 
* $N$ is the number of data points
* $x_i$ is the $i^\text{th}$ data point
* $K(x)$ is the kernel function, normalized to 1.
* $h$ is the bandwidth, in this case it is a function of x.

In [None]:
# plot the measurement error:

efig = figure(title="Estimated Measurement Error Model", width=600, height=400, x_axis_type='log')
error_points = jnp.array([1e-4, 1e-3, 1e-2, 1e-1, 1., 1e1, 1e2, 1e3, 1e4, 1e5])  #  Reference flow points in m^3/s
error_values = jnp.array([1.0, 0.5, 0.2, 0.1, 0.1, 0.1, 0.1, 0.15, 0.2, 0.25])
efig.line(error_points, error_values, line_color='red', line_width=2, legend_label='Measurement Error Model')
efig.xaxis.axis_label = r'$$\text{Flow } m^3/s$$'
efig.yaxis.axis_label = r'$$\text{Error } [\text{x}100\%]$$'
efig.legend.background_fill_alpha = 0.5
efig = dpf.format_fig_fonts(efig, font_size=14)

layout = gridplot([efig], ncols=2, width=500, height=350)
show(layout)