# Predicted vs. Empirical Distribution Fitting

## Introduction

### Question: What is the best general approach to estimate the PDF at an ungauged location?

We test several methods of estimating flow duration curves for ungauged locations:

1. k-nearest neighbours (KNN): frequency and time averaging, with equal weighting, inverse distance weighting, and catchment similarity weighting.

We know that mean annual runoff can be predicted with reasonable accuracy from catchment attributes, and that the mean and variance of runoff are correlated. Among parametric distributions, runoff is generally best approximated by the lognormal. However, the lognormal parameters (the mean and standard deviation of log-runoff) are not given directly by the sample mean and variance. They can be estimated from the mean and variance by the method of moments, at some loss of accuracy relative to maximum likelihood estimation. The mean and variance (and even the log-mean) are predictable from catchment attributes, but the log-variance is not. This experiment asks whether the poor predictability of the log-variance yields a worse estimate of the runoff distribution than the method-of-moments approximation does.

2. Parametric estimation: predicting lognormal distribution parameters from catchment attributes by a) applying the method of moments to the predicted mean and variance, and b) predicting the maximum likelihood parameters directly from catchment attributes.
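The sketch below illustrates option (a). It assumes the mean and standard deviation (SD) are in the same runoff units. A linear fit of SD on the mean supplies the second moment; this mirrors the linear approximation used by the parametric method listed further down. The method-of-moments relations sigma^2 = ln(1 + SD^2 / mean^2) and mu = ln(mean) - sigma^2 / 2 then convert the pair into lognormal parameters for log-runoff. The function names and toy numbers are illustrative only.

```python
import numpy as np
from scipy.stats import linregress


def mom_lognormal_params(mean, sd):
    """Method-of-moments lognormal parameters (mu, sigma of log-runoff)
    from the mean and standard deviation of runoff itself."""
    mean = np.asarray(mean, dtype=float)
    sd = np.asarray(sd, dtype=float)
    sigma2 = np.log1p((sd / mean) ** 2)  # sigma^2 = ln(1 + CV^2)
    mu = np.log(mean) - sigma2 / 2.0     # mu = ln(mean) - sigma^2 / 2
    return mu, np.sqrt(sigma2)


# toy example (hypothetical values): SD assumed to scale linearly with the mean
obs_mean = np.array([0.5, 1.0, 2.0, 4.0])
obs_sd = np.array([0.75, 1.5, 3.0, 6.0])
fit = linregress(obs_mean, obs_sd)

# for an ungauged site: predict the mean (e.g. from attributes), then the SD
pred_mean = 2.5
pred_sd = fit.intercept + fit.slope * pred_mean
mu, sigma = mom_lognormal_params(pred_mean, pred_sd)

# the lognormal with these parameters reproduces both moments
implied_mean = np.exp(mu + sigma**2 / 2)
implied_var = (np.exp(sigma**2) - 1) * np.exp(2 * mu + sigma**2)
print(f'slope={fit.slope:.2f}, sd={pred_sd:.3f}, mu={mu:.4f}, sigma={sigma:.4f}')
```

By construction, the implied lognormal matches the predicted mean and variance exactly. Any mismatch with the observed distribution therefore comes from the assumed lognormal shape.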
 

We predicted the mean and standard deviation (SD) of runoff from catchment attributes alone. We then computed the KL divergence of the predicted (proxy) parametric distributions from the KDE fits of the target catchment distributions. Next, we tested whether these KL divergences are predictable from catchment attributes, both individually and in pairs. Individually, we tested catchment attributes for how well they predict the divergence between the predicted parametric distributions and the (error-model) adjusted KDE fits. Pairwise, we tested the predictability of the divergence between the predicted parametric distribution of a potential proxy and an (error-model) adjusted KDE fit.
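The comparison above can be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's pipeline:

* the target sample is synthetic;
* SciPy's fixed-bandwidth `gaussian_kde` stands in for the error-model-adjusted KDE;
* the predicted parameters (`loc=0.5`, `scale=1.0`) are made up;
* the divergence direction D_KL(target || predicted) is assumed.

Both densities are discretized onto one shared log-runoff grid so their PMFs are directly comparable. The helper names are hypothetical.

```python
import numpy as np
from scipy.stats import gaussian_kde, norm
from scipy.special import rel_entr


def pdf_to_pmf(pdf, grid):
    """Convert a density sampled on an evenly spaced grid into bin masses."""
    mass = pdf * (grid[1] - grid[0])
    return mass / mass.sum()


def kl_divergence(p, q):
    """D_KL(P || Q) in nats for two PMFs on the same grid (inf if Q misses P's support)."""
    return float(np.sum(rel_entr(p, q)))


rng = np.random.default_rng(42)
# synthetic log-runoff sample standing in for a target catchment's record
log_obs = rng.normal(loc=0.3, scale=1.1, size=3000)

# shared grid in log-runoff space, padded beyond the observed range
log_grid = np.linspace(log_obs.min() - 3, log_obs.max() + 3, 2**12)

# target: KDE of the observations; proxy: lognormal with (hypothetical) predicted parameters,
# whose log-runoff density is normal
p_target = pdf_to_pmf(gaussian_kde(log_obs)(log_grid), log_grid)
q_pred = pdf_to_pmf(norm.pdf(log_grid, loc=0.5, scale=1.0), log_grid)

kl = kl_divergence(p_target, q_pred)
print(f'D_KL(target || predicted) = {kl:.4f} nats')
```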

Here we generate distribution estimates from k-nearest neighbours (KNN) for k = 1, ..., 10. For each simulated location, we then compare all methods of estimating distributions:

1. **Parametric**: lognormal distributions estimated from mean runoff predicted from catchment attributes, plus the linear relationship between the mean and standard deviation of runoff.
2. **KNN**: For each location we approximate a distribution using 1, ..., 10 nearest neighbours by:   
    * equal weighting (EW)
    * inverse-distance weighting (IDW)
    * catchment similarity weighting (CSW; labelled `CAS` in the code)
3. Find the most similar distribution in the monitoring network:
    * by Kullback-Leibler divergence, Wasserstein distance, total variation distance (TVD), and Hellinger distance
    * record the distance ranks of the neighbouring distributions
    * assess how stable the rankings are under different assumptions (e.g. priors)
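The four measures in this list can be computed on PMFs that share a grid, and candidate proxies can then be ranked by any of them. The sketch below uses hypothetical helper names (`distribution_distances`, `rank_proxies`) and toy PMFs. Some properties to keep in mind:

* KL is asymmetric, and it becomes infinite when the proxy assigns zero mass where the target has mass (the situation that `check_support_coverage` measures).
* TVD and Hellinger (with the 1/2 convention) are bounded in [0, 1].
* The Wasserstein distance is expressed in the grid's units (log-runoff here).

```python
import numpy as np
from scipy.special import rel_entr
from scipy.stats import wasserstein_distance


def distribution_distances(p, q, grid):
    """Compare two PMFs defined on the same (e.g. log-runoff) grid.

    Returns D_KL(P || Q) in nats, the 1-Wasserstein distance (grid units),
    the total variation distance, and the Hellinger distance.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return {
        'KL': float(np.sum(rel_entr(p, q))),  # inf if Q has zero mass where P > 0
        'Wasserstein': wasserstein_distance(grid, grid, u_weights=p, v_weights=q),
        'TVD': 0.5 * float(np.sum(np.abs(p - q))),
        'Hellinger': float(np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))),
    }


def rank_proxies(target_pmf, proxy_pmfs, grid, metric='KL'):
    """Rank candidate proxy PMFs (dict id -> pmf) from most to least similar."""
    scores = {pid: distribution_distances(target_pmf, pmf, grid)[metric]
              for pid, pmf in proxy_pmfs.items()}
    return sorted(scores, key=scores.get)


grid = np.array([0.0, 1.0, 2.0])
target = np.array([0.2, 0.5, 0.3])
proxies = {'A': np.array([0.25, 0.5, 0.25]), 'B': np.array([0.6, 0.3, 0.1])}
print(distribution_distances(target, proxies['A'], grid))
print(rank_proxies(target, proxies, grid, metric='Hellinger'))
```

Repeating `rank_proxies` with each metric gives one way to record the distance ranks and to check how stable they are across measures.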



## Data Import and Model Setup

In [1]:
import os
import pandas as pd
import numpy as np
from time import time

import geopandas as gpd
from shapely.geometry import Point
import xyzservices.providers as xyz
from concurrent.futures import ThreadPoolExecutor

from bokeh.plotting import figure, show
from bokeh.layouts import gridplot, row, column
from bokeh.transform import factor_cmap, linear_cmap
from bokeh.models import ColumnDataSource, LinearAxis, Range1d
from bokeh.io import output_notebook
from bokeh.palettes import Sunset10, Vibrant7, Category20, Bokeh6, Bokeh7, Bokeh8, Greys256

import xgboost as xgb
# set global verbosity (config_context() is a context manager and only applies inside a `with` block)
xgb.set_config(verbosity=2)

from sklearn.cluster import AgglomerativeClustering

from sklearn.metrics import (
    root_mean_squared_error,
    mean_absolute_error,
    roc_auc_score,
    roc_curve, auc,
    accuracy_score,
    confusion_matrix,
)

from scipy.stats import linregress
from scipy.stats import lognorm, norm, rdist
from scipy.special import kl_div


import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde as jkde
from jax import config as jax_config
jax_config.update("jax_enable_x64", False)

from jax import jit
from jax import vmap

from KDEpy import FFTKDE

import data_processing_functions as dpf

# from sklearn.model_selection import StratifiedKFold
output_notebook()

In [2]:
BASE_DIR = os.getcwd()
tiles = xyz['USGS']['USTopo']


In [3]:
# load the catchment characteristics
fname = 'BCUB_watershed_attributes_updated.csv'
attr_df = pd.read_csv(os.path.join('data', fname))
attr_df['log_drainage_area_km2'] = np.log(attr_df['drainage_area_km2'])
attr_df.columns = [c.lower() for c in attr_df.columns]
attr_df['tmean'] = (attr_df['tmin'] + attr_df['tmax']) / 2.0
station_ids = attr_df['official_id'].values
print(f'There are {len(station_ids)} monitored basins in the attribute set.')

There are 1325 monitored basins in the attribute set.


In [4]:
# load the lognormal fit parameter results
ln_fit_fname = 'LN_fit_method_comparison_20250128.csv'
ln_fit_fpath = os.path.join('data', 'results', ln_fit_fname)
ln_df = pd.read_csv(ln_fit_fpath)

ln_df = ln_df[ln_df['official_id'].isin(attr_df['official_id'].values)]
# 1 L/s/km^2 = 86400 L/day / 1e6 m^2 = 0.0864 mm/day
ln_df['mean_runoff_mm_day'] = ln_df['mean_uar'] * 86.4 / 1000
ln_df['sd_runoff_mm_day'] = ln_df['sd_uar'] * 86.4 / 1000

target_columns = [c for c in ln_df.columns if c not in attr_df.columns]
for tc in target_columns:
    # map official_id -> value of this target column, then add it to the attribute table
    target_dict = ln_df.set_index('official_id')[tc].to_dict()
    attr_df[tc] = attr_df['official_id'].map(target_dict)

In [5]:
# open an example pairwise results file
input_folder = os.path.join(
    BASE_DIR, "data", "processed_divergence_inputs",
)
pairs_files = os.listdir(input_folder)
rev_date = '20250119'
n_rows = None
# parametric_df = pd.read_csv(os.path.join(input_folder, f'MEMBAKDE_results_{rev_date}.csv'), nrows=n_rows)
# parametric_df.head()
fname = 'Results_estimated_vs_observed_LN_fits_20250118.csv'
bootstrap_result_fpath = os.path.join(os.getcwd(), 'data', 'parametric_fits', fname)
param_df = pd.read_csv(bootstrap_result_fpath)
param_df.drop('Unnamed: 0', inplace=True, axis=1)
param_df.head()

| | official_id | obs_mean_mm_day | obs_std | pred_mean_mm_day | pred_sigma | KL_KDE_AKDE_2.5 | KL_KDE_AKDE_50 | KL_KDE_AKDE_97.5 | KL_AKDE_LNobs_2.5 | KL_AKDE_LNobs_50 | KL_AKDE_LNobs_97.5 | KL_AKDE_LNest_2.5 | KL_AKDE_LNest_50 | KL_AKDE_LNest_97.5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 05AA023 | 0.763394 | 1.286813 | 0.821331 | 1.201441 | 0.004435 | 0.00669 | 0.011342 | 1.026485 | 1.044057 | 1.065971 | 0.700195 | 0.715953 | 0.735906 |
| 1 | 05AA035 | 0.772076 | 1.449857 | 0.662123 | 1.017802 | 0.021181 | 0.02906 | 0.043829 | 0.687753 | 0.749125 | 0.806108 | 0.678864 | 0.744101 | 0.802399 |
| 2 | 05AD033 | 4.073865 | 5.079821 | 3.869882 | 4.717802 | 0.039788 | 0.053163 | 0.073777 | 0.698456 | 0.78305 | 0.865009 | 0.741019 | 0.827571 | 0.912939 |
| 3 | 05BF017 | 1.162493 | 1.957323 | 1.48197 | 1.963457 | 2.081178 | 2.130511 | 2.189047 | 0.864798 | 0.900978 | 0.940367 | 0.476573 | 0.506467 | 0.536577 |
| 4 | 05BJ010 | 0.821697 | 1.169613 | 0.695347 | 1.056124 | 0.011753 | 0.021304 | 0.038843 | 1.333488 | 1.355958 | 1.382492 | 1.697818 | 1.721834 | 1.752463 |


In [6]:
# create a dict of 'official_id': 'drainage area'
da_dict = attr_df[['official_id', 'drainage_area_km2']].copy().set_index('official_id').to_dict()['drainage_area_km2']

In [7]:
centroids = attr_df.apply(lambda row: Point(row['centroid_lon_deg_e'], row['centroid_lat_deg_n']), axis=1)
attr_gdf = gpd.GeoDataFrame(attr_df, geometry=centroids, crs='EPSG:4326')
attr_gdf.drop('unnamed: 0', inplace=True, axis=1)
attr_gdf.reset_index(inplace=True)
# convert to BC Albers for computing distances
attr_gdf = attr_gdf.to_crs(3005)

In [9]:
# preload the FFT KDE fit results (note these are not "error adapted")
kde_file = 'KL_fft_kde_fits_20241226.csv'
kde_fit_df = pd.read_csv(os.path.join('data', 'parametric_divergence_test', kde_file))
unique_proxies = np.unique(kde_fit_df['proxy'].values)
unique_targets = np.unique(kde_fit_df['target'].values)
print(f'{len(unique_proxies)} unique proxies, {len(unique_targets)} unique targets')

1323 unique proxies, 1324 unique targets


In [10]:
# load the predicted parameter results
predict_result_folder = os.path.join(BASE_DIR, 'data', 'prediction_results', 'runoff_prediction_results')
best_result_files = [e for e in os.listdir(predict_result_folder) if e.startswith('best_')]
predicted_params, best_result_dfs = [], []
for f in best_result_files:
    param = '_'.join(f.split('_')[4:-1])
    rdf = pd.read_csv(os.path.join(predict_result_folder, f), index_col='official_id')
    rdf = rdf[[c for c in rdf.columns if not c.startswith('Unnamed:')]]
    rdf.columns = [f'{e}_{param}' for e in rdf.columns]
    best_result_dfs.append(rdf)
    predicted_params.append(param)
    
# predicted_params = ['LN_MMO_mu_hat', 'LN_MMO_sd_hat', 'mean_logx', 'sd_logx']
predicted_param_df = pd.concat(best_result_dfs, join='inner', axis=1)
predicted_param_dict = predicted_param_df.to_dict(orient='index')

## Perform KNN distribution estimation


In [11]:
from scipy.spatial import cKDTree
from sklearn.preprocessing import StandardScaler

coords = np.array([[geom.x, geom.y] for geom in attr_gdf.geometry])
stn_tree = cKDTree(coords)

# Create mapping from official_id to index
id_to_index = {oid: i for i, oid in enumerate(attr_gdf["official_id"])}
index_to_id = {i: oid for oid, i in id_to_index.items()}  # Reverse mapping

scaler = StandardScaler()
# Extract values (excluding 'official_id' since it's categorical)
attribute_columns = ['log_drainage_area_km2', 'elevation_m', 'prcp', 'tmean', 'swe', 
                     'land_use_forest_frac_2010', 'land_use_snow_ice_frac_2010', 'land_use_water_frac_2010', 'land_use_wetland_frac_2010']
attr_gdf['log_drainage_area_km2'] = np.log(attr_gdf['drainage_area_km2'])
attr_values = attr_gdf[attribute_columns].to_numpy()
normalized_attr_values = scaler.fit_transform(attr_values)
# Convert normalized distances back to original units
std_devs_attrs = scaler.scale_  # Standard deviation of each feature

attr_tree = cKDTree(normalized_attr_values)

In [12]:
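def compute_lognorm_pmf(log_x, mu, sigma, integral_tol=2e-3):
    """Evaluate a lognormal distribution on a log-spaced grid.

    log(X) ~ Normal(mu, sigma), so the density is evaluated in log space.
    Returns (pmf, pdf), where the PMF is the probability mass per grid point.
    """
    norm_pdf = norm.pdf(log_x, loc=mu, scale=sigma)
    # diagnostic checks: the PDF should integrate to ~1 in log space,
    # and pdf / x should integrate to ~1 in linear space
    norm_check = np.trapz(norm_pdf, x=log_x)
    lin_grid = np.exp(log_x)
    norm_lin_check = np.trapz(norm_pdf / lin_grid, x=lin_grid)
    if not (np.isclose(norm_check, 1, atol=integral_tol) and np.isclose(norm_lin_check, 1, atol=integral_tol)):
        print(f'    lognormal integral check: {norm_check:.4f} (log space), {norm_lin_check:.4f} (linear space)')

    norm_cdf = np.cumsum(norm_pdf)
    norm_cdf /= norm_cdf[-1]
    norm_pmf = np.diff(norm_cdf, prepend=0)
    return norm_pmf, norm_pdf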
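`compute_lognorm_pmf` normalizes a cumulative sum of PDF samples, so its PMF is simply the PDF values divided by their sum: a midpoint-rule share of mass per grid point. The sketch below shows an alternative, which is our assumption and not the notebook's method. It integrates the normal CDF exactly over cells centred on the grid points. On a fine, evenly spaced grid the two versions agree closely, and the exact version keeps any tail mass beyond the grid ends in the outermost cells. The function name and parameters are hypothetical.

```python
import numpy as np
from scipy.stats import norm


def lognorm_pmf_exact(log_x, mu, sigma):
    """Probability mass of a lognormal in each grid cell, from the exact CDF.

    Cells are centred on the (evenly spaced) log grid points with edges at the
    midpoints; the outermost cells extend to -inf/+inf so the masses sum to 1.
    """
    mids = 0.5 * (log_x[:-1] + log_x[1:])
    edges = np.concatenate(([-np.inf], mids, [np.inf]))
    return np.diff(norm.cdf(edges, loc=mu, scale=sigma))


log_grid = np.linspace(-8, 8, 2**14)
mu, sigma = 0.3, 1.1  # illustrative parameters of log-runoff

exact_pmf = lognorm_pmf_exact(log_grid, mu, sigma)
# the cumulative-sum construction reduces to normalizing the PDF samples
pdf = norm.pdf(log_grid, loc=mu, scale=sigma)
riemann_pmf = pdf / pdf.sum()
print(exact_pmf.sum(), np.max(np.abs(exact_pmf - riemann_pmf)))
```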
    

In [13]:
# Define the ranges and associated errors
error_points = np.array([0.01, 0.1, 1.0, 10, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7])  # Magnitude points in L/s/km^2
error_values = np.array([1., 0.5, 0.25, 0.15, 0.1, 0.1, 0.1, 0.15, 0.2, 0.25])    # Associated errors (as proportions)

efig = figure(title="Estimated Measurement Error Model", width=600, height=400, x_axis_type='log')
efig.line(error_points, error_values, line_color='red', line_width=2, legend_label='Measurement Error Model')
efig.xaxis.axis_label = r'$$\text{Unit Area Runoff }L s^{-1} \text{km}^{-2}$$'
efig.yaxis.axis_label = r'$$\text{Error } [\text{x}100\%]$$'
efig.legend.background_fill_alpha = 0.5
efig = dpf.format_fig_fonts(efig, font_size=12)

layout = gridplot([efig], ncols=2, width=500, height=350)
show(layout)
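The figure draws the error model as straight segments on a logarithmic axis. However, `np.interp`/`jnp.interp` applied to raw magnitudes, as in `measurement_error_bandwidth_function` below, interpolates linearly in the magnitude itself. Between two knots a decade apart, the error therefore stays close to the lower knot's value. If the log-linear shape shown in the plot is intended (an assumption; the notebook does not say), interpolating on `log10` of the magnitude reproduces it. The sketch reuses the plotted knots and uses hypothetical function names.

```python
import numpy as np

# error model from the cell above: magnitude points (L/s/km^2) and proportional errors
error_points = np.array([0.01, 0.1, 1.0, 10, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7])
error_values = np.array([1., 0.5, 0.25, 0.15, 0.1, 0.1, 0.1, 0.15, 0.2, 0.25])


def error_linear(x):
    """Piecewise-linear interpolation in the raw magnitude (np.interp on x)."""
    return np.interp(x, error_points, error_values)


def error_loglinear(x):
    """Piecewise-linear interpolation in log10(magnitude), matching straight
    segments drawn on a log-scaled x axis."""
    return np.interp(np.log10(x), np.log10(error_points), error_values)


x_mid = 10 ** -1.5  # geometric midpoint of the first segment (0.01 to 0.1)
print(f'linear: {error_linear(x_mid):.3f}, log-linear: {error_loglinear(x_mid):.3f}')
```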

In [14]:
def check_support_coverage(baseline_pmf, proxy_distribution):
    # find the total mass of the KDE baseline distribution
    # where the proxy distribution = 0
    mask = np.where(proxy_distribution == 0)
    unsupported_pmf = baseline_pmf[mask].sum()
    return unsupported_pmf


In [15]:
def epanechnikov_kernel(u):
    """ Epanechnikov kernel: finite support in [-1, 1] """
    return jnp.where(jnp.abs(u) <= 1, 0.75 * (1 - u**2), 0)

def top_hat_kernel(u):
    return jnp.where(jnp.abs(u) <= 1, 0.5, 0)

def gaussian_kernel(u):
    """ Gaussian kernel: smooth, infinite support """
    return (1 / jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * u**2)

def compute_silverman_approx(log_data):
    q75, q25 = np.percentile(log_data, (75, 25))
    stdev = np.std(log_data)
    A = np.min([stdev, (q75 - q25) / 1.34])
    return 1.06 * A / len(log_data)**0.2


def measurement_error_bandwidth_function(x):
    error_points = jnp.array([1e-4, 1e-3, 1e-2, 1e-1, 1., 1e1, 1e2, 1e3, 1e4, 1e5])  #  Reference flow points in m^3/s
    error_values = jnp.array([1.0, 0.5, 0.2, 0.1, 0.1, 0.1, 0.1, 0.15, 0.2, 0.25])     # Associated errors (as proportions)
    # scaling_factor = baseline_h / jnp.min(error_values) 
    return jnp.interp(x, error_points, error_values) 

  
def compute_measurement_error_informed_adaptive_bandwidth(uar, da):
    """Compute a per-observation KDE bandwidth in log-UAR space.

    Each bandwidth is the larger of (a) the measurement-error width from the
    error model and (b) the precision gap to neighbouring unique values.
    """
    # approximate the measurement error associated with each unique FLOW value (m^3/s);
    # `inverse` maps every observation back to its unique value
    flow_data = uar * da / 1000
    unique_q, inverse = jnp.unique(flow_data, return_inverse=True)
    error_model = measurement_error_bandwidth_function(unique_q)

    # error widths must be in unit area runoff log space
    # to align with the precision bandwidth correction,
    # since the KDE is fit to UAR in log space
    unique_UAR = (1000 / da) * unique_q
    log_UAR = jnp.log(unique_UAR)
    err_widths_UAR = jnp.log(unique_UAR * (1 + error_model)) - log_UAR

    # define support bounds from the midpoints between neighbouring unique values (log space)
    log_midpoints = jnp.log((unique_UAR[:-1] + unique_UAR[1:]) / 2)
    # extend the end intervals by mirroring them about the end values (also in log space)
    left_mirror = log_UAR[0] - (log_midpoints[0] - log_UAR[0])
    right_mirror = log_UAR[-1] + (log_UAR[-1] - log_midpoints[-1])
    log_midpoints = jnp.concatenate((jnp.array([left_mirror]), log_midpoints, jnp.array([right_mirror])))

    # the precision width is half the interval between midpoints, divided by a
    # z-score (1.15) so that roughly 75% of the kernel mass falls between the midpoints
    log_diffs = jnp.diff(log_midpoints) / 2 / 1.15

    # use whichever is larger: the precision gap or the assumed measurement error
    bw_vals = jnp.where(log_diffs > err_widths_UAR, log_diffs, err_widths_UAR)

    # broadcast the adaptive bandwidth values back to the input time series
    return bw_vals[inverse.ravel()]


def adaptive_kde(uar_data, log_grid, da, estimated_grid=None, pdf_error_tol=1e-3, min_allowable_bandwidth=1e-3):
    """Gaussian KDE in log-UAR space with a per-observation (adaptive) bandwidth.

    Returns (pmf, pdf): the PMF on `log_grid` and the PDF on the evaluation grid
    (`estimated_grid` if given, otherwise `log_grid`).
    """
    eval_grid = estimated_grid if estimated_grid is not None else log_grid

    # compute bandwidths according to a measurement error model,
    # incorporating a test for precision vestiges; floor at the minimum allowable value
    bw_values = compute_measurement_error_informed_adaptive_bandwidth(uar_data, da)
    bw_values = jnp.maximum(bw_values, min_allowable_bandwidth)

    # u matrix of shape (N, M) via broadcasting (avoids tiling three N x M copies)
    log_data = jnp.log(uar_data)
    U = (eval_grid[None, :] - log_data[:, None]) / bw_values[:, None]

    # apply the Gaussian kernel, scaled by each observation's bandwidth
    # (epanechnikov_kernel or top_hat_kernel can be swapped in here)
    K = gaussian_kernel(U) / bw_values[:, None]

    # sum contributions across observations and normalize by sample count
    pdf = K.sum(axis=0) / len(uar_data)
    pdf_check = jnp.trapezoid(pdf, x=eval_grid)
    pdf /= pdf_check
    assert abs(jnp.trapezoid(pdf, x=eval_grid) - 1) < pdf_error_tol, "PDF does not integrate to 1 in adaptive_kde()"

    cdf = jnp.cumsum(pdf)
    cdf /= cdf[-1]
    pmf = jnp.diff(cdf, prepend=0)

    # if estimated_grid is used, interpolate the PMF onto log_grid
    if estimated_grid is not None:
        pmf_interp = jnp.interp(log_grid, estimated_grid, pmf, left=0, right=0)
        pmf_interp = np.where(pmf_interp > 0, pmf_interp, 0)
        # do not normalize here: leave unsupported probability mass out
        return pmf_interp, pdf

    pmf = np.where(pmf > 0, pmf, 0)
    return pmf, pdf
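At its core, `adaptive_kde` is a sample-point (variable-bandwidth) Gaussian KDE in log space: f(y) = (1/n) sum_i phi((y - y_i)/h_i)/h_i. The NumPy sketch below makes two stated assumptions:

* the bandwidths come from a hypothetical two-level rule that stands in for the measurement-error model;
* observations are processed in chunks, which bounds memory for long daily records evaluated on large grids.

The function name and data are hypothetical.

```python
import numpy as np


def variable_bandwidth_kde(log_data, bandwidths, grid, chunk=512):
    """Sample-point Gaussian KDE where observation i has its own bandwidth h_i.

    f(y) = (1/n) * sum_i phi((y - y_i) / h_i) / h_i, evaluated on `grid`.
    Observations are processed in chunks so memory stays O(chunk * len(grid)).
    """
    log_data = np.asarray(log_data, dtype=float)
    bandwidths = np.asarray(bandwidths, dtype=float)
    grid = np.asarray(grid, dtype=float)
    pdf = np.zeros(grid.size)
    for start in range(0, log_data.size, chunk):
        y = log_data[start:start + chunk, None]    # shape (c, 1)
        h = bandwidths[start:start + chunk, None]  # shape (c, 1)
        u = (grid[None, :] - y) / h                # shape (c, m) by broadcasting
        pdf += (np.exp(-0.5 * u**2) / (np.sqrt(2 * np.pi) * h)).sum(axis=0)
    return pdf / log_data.size


rng = np.random.default_rng(0)
uar = rng.lognormal(mean=1.0, sigma=0.8, size=5000)  # synthetic unit-area runoff
log_uar = np.log(uar)
# hypothetical bandwidth rule: wider kernels (50% error) for low flows, 10% otherwise
h = np.where(uar < 1.0, np.log1p(0.5), np.log1p(0.1))
grid = np.linspace(log_uar.min() - 3, log_uar.max() + 3, 2**12)

pdf = variable_bandwidth_kde(log_uar, h, grid)
integral = pdf.sum() * (grid[1] - grid[0])  # simple Riemann sum on the even grid
print(f'integral = {integral:.5f}')
```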
    


In [16]:
def single_kde_fit(log_data, log_grid):
    """Fixed-bandwidth FFT KDE (ISJ bandwidth) in log space; returns (pmf, pdf) on log_grid."""
    kde_pdf = FFTKDE(bw='ISJ').fit(log_data).evaluate(log_grid)
    # check that the numerical integration over the KDE pdf is close to 1, then normalize
    pdf_check = np.trapz(kde_pdf, x=log_grid)
    assert np.isclose(pdf_check, 1, atol=1e-5), f'{pdf_check:.4f} {kde_pdf[:5]} - {kde_pdf[-5:]}'
    kde_pdf /= pdf_check
    kde_cdf = np.cumsum(kde_pdf)
    kde_cdf /= kde_cdf[-1]
    kde_pmf = np.diff(kde_cdf, prepend=0)
    assert np.abs(np.sum(kde_pmf) - 1.0) <= 0.001
    return kde_pmf, kde_pdf


In [17]:
def compute_similarity_weights(stns, attributes=None):
    cols = ['official_id', 'log_drainage_area_km2', 'elevation_m', 'prcp', 'tmean', 'swe']
    if len(stns) == 1:
        return [1.0]
    attrs = attr_df[attr_df['official_id'].isin(stns['official_id'])][cols].copy()
    attrs = attrs.set_index('official_id').loc[stns['official_id'].values].reset_index()
    # make sure the official id order is maintained so the weights are properly assigned
    assert np.array_equal(stns['official_id'].values, attrs['official_id'].values)
    attrs = attrs[[c for c in attrs.columns if c != 'official_id']]
    # normalize weights first within column, then equally along rows
    attrs = 1.0 * (attrs - attrs.min()) / (attrs.max() - attrs.min())
    normalized_similarity = attrs.mean(axis=1).values
    return normalized_similarity

In [18]:
def find_k_nearest_neighbors(target_index, tree_type, k=3):
    """Return indices and distances of the k nearest neighbours of a station.

    'EW' and 'IDW' search geographic space (BC Albers, metres) and return distances in km;
    'CAS' searches standardized attribute space and returns unitless distances.
    """
    # query k+1 neighbours because the nearest point is the target itself
    if tree_type in ['EW', 'IDW']:
        distances, indices = stn_tree.query(coords[target_index], k=k+1)
        distances = np.round(distances / 1000, 1)  # metres -> km
    elif tree_type == 'CAS':
        distances, indices = attr_tree.query(normalized_attr_values[target_index], k=k+1)
    else:
        raise ValueError('tree type not identified, must be one of EW, IDW, or CAS.')

    # remove the target itself from the results
    return indices[1:], distances[1:]
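Spatial KNN selection and IDW weighting reduce to two steps: a KD-tree query that drops the target point, then weights proportional to d^-2 (the default `power=2` in `compute_weights`). The sketch below uses hypothetical toy coordinates and hypothetical function names, and computes weights from raw distances in km. If distances are first min-max normalized, as `normalize_dataframe` does, the nearest neighbour's distance becomes 0. Its weight is then set by the distance floor rather than by the geometry.

```python
import numpy as np
from scipy.spatial import cKDTree


def knn_excluding_self(coords, target_index, k):
    """Indices and distances of the k nearest points to coords[target_index],
    dropping the query point itself (assumes no duplicate coordinates)."""
    tree = cKDTree(coords)
    distances, indices = tree.query(coords[target_index], k=k + 1)
    return indices[1:], distances[1:]


def inverse_distance_weights(distances, power=2, min_distance=1e-4):
    """Normalized inverse-distance weights, w_i proportional to 1 / d_i**power."""
    d = np.maximum(np.asarray(distances, dtype=float), min_distance)
    w = 1.0 / d**power
    return w / w.sum()


# toy station coordinates in metres (hypothetical)
coords = np.array([[0, 0], [1000, 0], [0, 2000], [3000, 0], [0, 5000]], dtype=float)
idx, dist = knn_excluding_self(coords, target_index=0, k=3)
weights = inverse_distance_weights(dist / 1000)  # distances in km
print(idx, dist / 1000, weights.round(3))
```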

In [19]:
def check_neighbours(stn, distance, data):
    """Return (stn, distance, proxy_df) if the proxy record covers at least 75% of the
    target record `data` and more than 350 concurrent rows (days); otherwise None."""
    proxy_df = dpf.get_timeseries_data(stn)
    proxy_df.set_index('time', inplace=True)
    df = pd.concat([data, proxy_df], axis=1, join='inner')
    pct_covered = len(df) / len(data)
    return (stn, distance, proxy_df) if (pct_covered >= 0.75) and (len(df) > 350) else None
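The neighbour screening rule in `check_neighbours` requires at least 75% of the target record to be covered and more than 350 concurrent rows. It can be isolated for testing, as sketched below. The sketch assumes daily series with a unique `DatetimeIndex` and uses an index intersection, which matches the inner join above when indices are unique. The function name and data are hypothetical.

```python
import pandas as pd


def has_concurrent_record(target, proxy, min_fraction=0.75, min_days=350):
    """True if the proxy overlaps at least `min_fraction` of the target's time index
    and the overlap is longer than `min_days` rows."""
    overlap = target.index.intersection(proxy.index)
    return len(overlap) / len(target) >= min_fraction and len(overlap) > min_days


# synthetic daily series (hypothetical values)
target = pd.Series(1.0, index=pd.date_range('2000-01-01', periods=1000, freq='D'))
proxy_ok = pd.Series(2.0, index=pd.date_range('2000-07-01', periods=1000, freq='D'))
proxy_short = pd.Series(2.0, index=pd.date_range('2002-06-01', periods=400, freq='D'))
print(has_concurrent_record(target, proxy_ok), has_concurrent_record(target, proxy_short))
```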


In [20]:
from bokeh.plotting import figure, output_file, save
from bokeh.models import Div
from bokeh.layouts import layout

def output_figure(target_stn, eval_grid, log_data, predicted_param_dist, lognorm_dist, mom_ln_dist, kde_dist, epan_kde_dist, best_knn, fpath, result, support_dict, plot_type='pdf'):
    print('    processing output figure')
    fig = figure(title=f'{target_stn} PDF estimation', width=900, height=450, x_axis_type='log')
    hist, log_edges = np.histogram(log_data, bins=2**6, density=True)
    edges = np.exp(log_edges)
    
    # convert density to mass
    cdf = np.cumsum(hist)
    cdf /= cdf[-1]
    pmf = np.diff(cdf, prepend=0)
    pmf /= pmf.sum()
    
    if plot_type == 'pmf':
        hist = pmf
    
    if not np.isclose(np.sum(pmf), 1, atol=1e-3):
        print(f'Histogram sum != 1: {pmf.sum():.3f}')

    yrs = len(log_data) / 365
    fig.quad(bottom=0, left=edges[:-1], right=edges[1:], top=hist, color='lightgreen', fill_alpha=0.5, legend_label=f'Data (N={yrs:.1f} yrs)')

    x = np.exp(eval_grid)
    fig.line(x, predicted_param_dist, color='dodgerblue', line_dash='dotted', line_width=2, legend_label='Predicted LN Params')
    fig.line(x, lognorm_dist, color='dodgerblue', line_dash='solid', line_width=2, legend_label='MLE LogNorm')
    fig.line(x, mom_ln_dist, color='dodgerblue', line_dash='dashed', line_width=2, legend_label='MOM LogNorm')
    fig.line(x, kde_dist, color='grey', line_dash='solid', line_width=2, legend_label='KDE')
    fig.line(x, epan_kde_dist, color='black', line_dash='solid', line_width=2, legend_label='EpanKDE')
    fig.line(x, best_knn['EW'][2], color='orange', line_dash='solid', line_width=2, legend_label=best_knn['EW'][0])
    fig.line(x, best_knn['IDW'][2], color='orange', line_dash='dashed', line_width=2, legend_label=best_knn['IDW'][0])
    fig.line(x, best_knn['CAS'][2], color='orange', line_dash='dotted', line_width=2, legend_label=best_knn['CAS'][0])
    fig.line(x, best_knn['CASdist'][2], color='darkorange', line_dash='solid', line_width=2, legend_label=best_knn['CASdist'][0])
    
    fig.legend.click_policy='hide'
    fig.add_layout(fig.legend[0], 'right')
    fig.legend.background_fill_alpha = 0.5
    fig.xaxis.axis_label = r'$$\text{Runoff } [Ls^{-1}\text{km}^{-2}]$$'
    fig.yaxis.axis_label = r'$$\text{Probability Density}$$'
    fig = dpf.format_fig_fonts(fig, font_size=14)
    res_table = pd.DataFrame(result, index=['DKL']).T    
    res_table.index.name = 'Model'
    res_table = res_table.sort_values(by='DKL')
    res_table.reset_index(inplace=True)

    min_res = res_table['DKL'].min()
    res_table['pct_from_top'] = round(100 * (res_table['DKL'] - min_res) / min_res, 0)
    suppt_table = pd.DataFrame(support_dict, index=['unsupported_mass']).T
    suppt_table.index.name = 'Model'
    suppt_table = suppt_table.sort_values(by='unsupported_mass')
    suppt_table.reset_index(inplace=True)
    top_div = Div(text=res_table.head(5).to_html(index=True, border=0))
    divs, sup_tabs, param_tabs = [], [], []
    np_models = ['EW_', 'IDW_', 'CAS_', 'CASdist_']
    for tp in np_models:
        res = res_table[res_table['Model'].str.contains(tp)].copy().round(3)
        tab = Div(text=res.to_html(index=True, border=0))
        divs.append(tab)
        sup = suppt_table[suppt_table['Model'].str.contains(tp)].copy().round(2)
        result2 = Div(text=sup.to_html(index=False, border=0))
        sup_tabs.append(result2)
    p_mods = [c for c in res_table['Model'].values if all(n not in c for n in np_models)]
    res = res_table[res_table['Model'].isin(p_mods)].copy().round(3)
    tab = Div(text=res.to_html(index=True, border=0))
    param_tabs.append(tab)
        
    lt = layout([[fig], [top_div, param_tabs], divs, sup_tabs])
    # show(fig)
    output_file(filename=fpath, title=f"{target_stn}")
    save(lt)

In [1]:
class Experiment:
    def __init__(self, i, row, k_nearest=10, n_neighbours_to_check=150, n_grid_points=2**14, left_log=-3, right_log=3):
        for k, v in row.items():
            setattr(self, k, v)

        self.result, self.support_dict = {}, {}  
        stn = self.official_id # target catchment
        self.figure_fpath = os.path.join(BASE_DIR, 'data', 'knn_comparison_plots', f'{stn}_distribution_prediction.html')
        self.target_stn = stn
        # self.params = {k: v[0] for k, v in ln_df[ln_df['official_id'] == stn].copy().to_dict(orient='list').items()}
        self.params = predicted_param_dict[stn] 
        
        # import the streamflow data and do uar and log conversions
        self.stn_df = dpf.get_timeseries_data(stn)
        self.stn_df.set_index('time', inplace=True)
        self.stn_df[f'{stn}_uar'] = 1000 * self.stn_df[stn] / self.drainage_area_km2
        
        self.uar = self.stn_df[f'{stn}_uar'].dropna().values
        self.log_uar = np.log(self.uar.reshape(-1, 1))
        self.n_obs = len(self.uar)
        self.n_grid_points = n_grid_points
        self.n_neighbours_to_check = n_neighbours_to_check
        self.k_nearest = k_nearest
        self.index_to_id = index_to_id
        self.id_to_index = id_to_index
        self.da_dict = da_dict  # Store drainage areas
        self.target_da = da_dict[stn]
        self.target_tree_index = i
        self.initialize_nearest_neighbour_data()
        self.left_log = left_log
        self.right_log = right_log
        self.set_grid()
        print('    ...completed initialization.')

    def set_nearest_nbr_data(self, tree_type):

        neighbour_idxs, distances = find_k_nearest_neighbors(self.target_tree_index, tree_type, k=self.n_neighbours_to_check)
        neighbours = attr_gdf.iloc[neighbour_idxs]['official_id'].tolist()
        # keep only neighbours with sufficient concurrent record (checked in parallel)
        with ThreadPoolExecutor(max_workers=22) as executor:  # adjust max workers to the machine
            checked_neighbours = list(executor.map(check_neighbours, neighbours, distances, [self.stn_df.copy()]*len(neighbours)))
    
        good_nbrs = [e for e in checked_neighbours if e is not None]
        if len(good_nbrs) == 0:
            raise Exception('No suitable nearest neighbours found')
        
        good_nbrs = sorted(good_nbrs, key=lambda tup: tup[1])
        nbr_data = pd.DataFrame([e[:2] for e in good_nbrs], columns=['official_id', 'distance'])
        
        nbr_df = pd.concat([e[2] for e in good_nbrs], join='inner', axis=1)
        n_found = len(good_nbrs)
        if n_found < self.k_nearest:
            raise Exception(f'{n_found}/{self.k_nearest} suitable nearest neighbours found')
        return nbr_df, nbr_data


    def initialize_nearest_neighbour_data(self):
        print(f'    ...searching the {self.n_neighbours_to_check} nearest neighbours for at least {self.k_nearest} with sufficient concurrent record.')
        self.nbr_df, self.nbr_data = self.set_nearest_nbr_data('EW')
        self.nbr_df_attr, self.nbr_data_attr = self.set_nearest_nbr_data('CAS')
        # set spatial distance in attribute space dataframe
        self.nbr_data_attr['spatial_dist'] = self.nbr_data_attr['official_id'].apply(lambda x: self.query_distance(stn_tree, x, self.target_stn)/1e3)
        # set attribute distances in the geographic distance dataframe
        self.nbr_data['attr_dist'] = self.nbr_data['official_id'].apply(lambda x: self.query_distance(attr_tree, x, self.target_stn))

        self.nbr_data = self.normalize_dataframe(self.nbr_data)
        self.nbr_data_attr = self.normalize_dataframe(self.nbr_data_attr)
        

    def set_grid(self):
        epsilon = 1e-6 
        minx, maxx = np.min(self.uar) - epsilon, np.max(self.uar) + epsilon
        self.baseline_log_grid = np.linspace(np.log(minx) + self.left_log, np.log(maxx) + self.right_log, self.n_grid_points)
        self.baseline_lin_grid = np.exp(self.baseline_log_grid)

    
    def query_distance(self, tree, id1, id2):
        """Query distance between two points in a tree using official_id."""
        if id1 not in self.id_to_index or id2 not in self.id_to_index:
            raise ValueError(f"One or both IDs ({id1}, {id2}) not found.")
    
        # Get indices from ID mapping
        index1, index2 = self.id_to_index[id1], self.id_to_index[id2]        
        # Query the distance
        distance = np.linalg.norm(tree.data[index1] - tree.data[index2])  # Euclidean distance
        return distance

    def compute_MLE_lognorm(self):
        obs_mean, obs_std = self.params['actual_mean_logx'], self.params['actual_sd_logx']
        self.lognorm_pmf, self.lognorm_pdf = compute_lognorm_pmf(self.baseline_log_grid, obs_mean, obs_std)
        kld, support = self.compute_kld(self.baseline_pmf, self.lognorm_pmf)
        self.result['LN_MLE_DKL'] = kld.item()
        self.support_dict['LN_MLE_DKL'] = support
        

    def ensemble_distribution_estimates(self, knn_df, ensemble_label, distance_weights=None, epsilon=0.5):
        """Fit an adaptive KDE to each proxy series and combine them into one estimate.

        Each proxy PDF is evaluated on its own grid, then interpolated onto
        self.baseline_log_grid so the ensemble is averaged on a common support.
        """
        pdfs, pmfs = pd.DataFrame(), pd.DataFrame()

        for c in knn_df.columns:
            proxy_stn = c.split('_')[0]
            est_min, est_max = knn_df[c].min(), knn_df[c].max()
            # extend the grid beyond the observed range (left_log < 0 extends it to the left)
            est_grid = np.linspace(np.log(est_min) + self.left_log, np.log(est_max) + self.right_log, self.n_grid_points)
            k_pmf, k_pdf = adaptive_kde(knn_df[c].values, self.baseline_log_grid, da_dict[proxy_stn], estimated_grid=est_grid)
            pmfs[c] = k_pmf
            # the returned PDF lives on est_grid; move it onto the baseline grid
            pdfs[c] = np.interp(self.baseline_log_grid, est_grid, np.asarray(k_pdf), left=0, right=0)

        if distance_weights is not None:
            # ensure a 1D float array and normalize the weights to sum to 1
            distance_weights = np.asarray(distance_weights, dtype=float)
            distance_weights = distance_weights / distance_weights.sum()
            pdf_est = pdfs.to_numpy() @ distance_weights
        else:
            pdf_est = pdfs.mean(axis=1).to_numpy()

        # check the integral and renormalize only if necessary
        pdf_check = np.trapz(pdf_est, x=self.baseline_log_grid)
        if not np.isclose(pdf_check, 1, atol=1e-3):
            pdf_est /= pdf_check

        # compute CDF and PMF on the baseline grid
        cdf_est = np.cumsum(pdf_est)
        cdf_est /= cdf_est[-1]
        pmf_est = np.diff(cdf_est, prepend=0)
        return pmf_est, pdf_est

    
    def compute_weights(self, distances, power=2):
        """Compute normalized inverse distance weights."""
        distances = np.maximum(distances, 1e-4)  # Prevent division by zero
        weights = 1 / (distances ** power)
        return weights / weights.sum()  # Normalize to sum to 1

    
    def normalize_dataframe(self, df, exclude_col='official_id'):
        """Normalize all columns except `exclude_col` using min-max scaling."""
        cols = [c for c in df.columns if c != exclude_col]
        df[cols] = (df[cols] - df[cols].min()) / (df[cols].max() - df[cols].min())
        return df

            
    def format_knn_pmf_inputs(self, nbrs, knn_df):
        # Compute Unit Area Runoff (UAR) for each neighbor
        cols = []
        for s in nbrs:
            uar_col = f'{s}_uar'    
            da = da_dict[s]
            assert ~np.isnan(da)
            knn_df[uar_col] = 1000 * knn_df[s] / da
            cols.append(uar_col)
        return knn_df, cols
        
    
    def estimate_knn_pmf_pdf(self, df, cols, weights=None):
        """Estimate PMF and PDF using adaptive KDE with optional weights."""
        
        assert not df[cols].empty, 'dataframe is empty'
        assert df[cols].notna().all().all(), "NaN values found in df[cols] before processing"
        if weights is not None:
            assert not np.any(np.isnan(weights)), f'nan weight found: {weights}'
            assert np.isclose(np.sum(weights), 1), f'weights do not sum to 1: {weights}'
            assert (weights > 0).all(), f'not all weights > 0, {weights}'
            # weighted mean of the neighbours' UAR series at each time step
            estimate = df[cols].mul(weights, axis=1).sum(axis=1)
        else:
            estimate = df[cols].mean(axis=1)

        # evaluation grid extended beyond the observed range (left_log < 0 extends it to the left)
        est_min, est_max = df[cols].min().min(), df[cols].max().max()
        est_grid = np.linspace(np.log(est_min) + self.left_log, np.log(est_max) + self.right_log, self.n_grid_points)
    
        assert (estimate >= 0).all(), f"Estimate < 0 detected: {np.min(estimate)}"
        pmf, pdf = adaptive_kde(estimate.values, self.baseline_log_grid, self.target_da, estimated_grid=est_grid)
        return pmf, pdf
        

    def knn_pmf_estimation(self):
        """
        Generate PDF/PMF estimates for the target catchment using kNN.
        """
        knn_pmfs, knn_pdfs = pd.DataFrame(), pd.DataFrame()
        stn = self.official_id
        print(f'    Processing kNN for {stn}')
        
        t0 = time()
        for k in range(1, self.k_nearest+1):
            # Get spatial kNN
            k_nbrs = self.nbr_data.iloc[:k].copy()
            nbr_stns = k_nbrs['official_id'].values
            knn_df, knn_cols = self.format_knn_pmf_inputs(nbr_stns, self.nbr_df.copy()) 

            # Compute Equal Weighting (EW) PMF/PDF
            label = f'{k}NN_EW_{stn}'
            knn_pmfs[label], knn_pdfs[label] = self.estimate_knn_pmf_pdf(knn_df, knn_cols)

            # Compute IDW-based PMF/PDF
            k_nbrs['distance'] = np.maximum(k_nbrs['distance'].values, 1e-5)
            distance_weights = self.compute_weights(k_nbrs['distance'].values)
            
            label = f'{k}NN_IDW_{stn}'
            knn_pmfs[label], knn_pdfs[label] = self.estimate_knn_pmf_pdf(knn_df, knn_cols, distance_weights)
                
            # Get attribute-space kNN
            k_nbrs_attr = self.nbr_data_attr.iloc[:k].copy()
            nbr_stns_attr = k_nbrs_attr['official_id'].values
            knn_df_attr, knn_attr_cols = self.format_knn_pmf_inputs(nbr_stns_attr, self.nbr_df_attr.copy())    
    
            # Compute CAS-based PMF/PDF (inverse attribute-space distance weights)
            if k == 1:
                cas_weights = np.array([1.0])  # a single neighbour receives all the weight
            else:
                k_nbrs_attr['distance'] = np.maximum(k_nbrs_attr['distance'].values, 1e-5)
                cas_weights = self.compute_weights(k_nbrs_attr['distance'].values)
                
            label = f'{k}NN_CAS_{stn}'
            knn_pmfs[label], knn_pdfs[label] = self.estimate_knn_pmf_pdf(knn_df_attr, knn_attr_cols, cas_weights)
            # Compute CAS + Distance PMF/PDF
            k_nbrs_dist_attr = self.nbr_data.iloc[:k].copy()
            nbr_stns_dist_attr = k_nbrs_dist_attr['official_id'].values
            knn_df_dist_attr, knn_dist_attr_cols = self.format_knn_pmf_inputs(nbr_stns_dist_attr, self.nbr_df.copy())  
            
            k_nbrs_dist_attr['sum'] = np.maximum(k_nbrs_dist_attr['distance'] + k_nbrs_dist_attr['attr_dist'], 1e-3)
            cas_dist_weights = self.compute_weights(k_nbrs_dist_attr['sum'].values)
    
            label = f'{k}NN_CASdist_{stn}'
            knn_pmfs[label], knn_pdfs[label] = self.estimate_knn_pmf_pdf(knn_df_dist_attr, knn_dist_attr_cols, cas_dist_weights)
            # Compute ensemble estimates
            knn_sets = [
                ("EW", knn_df, knn_cols, None), 
                ("IDW", knn_df, knn_cols, distance_weights), 
                ("CAS", knn_df_attr, knn_attr_cols, cas_weights), 
                ("CASdist", knn_df_dist_attr, knn_dist_attr_cols, cas_dist_weights)
            ]
            for method, df, cols, weights in knn_sets:
                ensemble_label = f'{k}NN_{method}_ensemble_{stn}'
                pmf_est, pdf_est = self.ensemble_distribution_estimates(df[cols].copy(), ensemble_label, distance_weights=weights)
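                if pmf_est is None:
                    # ensemble estimation failed: return empty frames, keeping the (pmfs, pdfs) return signature
                    return pd.DataFrame(), pd.DataFrame()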
                knn_pdfs[ensemble_label] = pdf_est
                knn_pmfs[ensemble_label] = pmf_est
    
        return knn_pmfs, knn_pdfs
        

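    def compute_pmf_from_predicted_params(self, mu_hat, sd_hat):
        """Discretize a lognormal (a normal in log space) with predicted parameters onto the baseline log grid.

        `mu_hat` and `sd_hat` are keys into `self.params`. If the grid does not capture
        (almost) all of the probability mass, self.left_log or self.right_log is adjusted
        on the truncated side and an exception is raised.
        """
        norm_pdf = norm.pdf(self.baseline_log_grid, loc=self.params[mu_hat], scale=self.params[sd_hat])
        # check the mass captured by the grid *before* renormalizing; afterwards it is always 1
        pdf_check = np.trapz(norm_pdf, x=self.baseline_log_grid)
        if not np.isclose(pdf_check, 1, atol=1e-3):
            if norm_pdf[0] > norm_pdf[-1]:
                self.left_log -= 1
            else:
                self.right_log += 1
            msg = f'   Predicted param pdf_check failed: {pdf_check:.5f} {norm_pdf[:5]}, {norm_pdf[-5:]} new log range {self.left_log}-{self.right_log}'
            print(msg)
            raise Exception(msg)
        norm_pdf /= pdf_check

        # cumulative sum -> normalized CDF -> first difference gives the PMF (equivalent to norm_pdf / norm_pdf.sum())
        norm_cdf = norm_pdf.cumsum()
        norm_cdf /= norm_cdf[-1]
        norm_pmf = np.diff(norm_cdf, prepend=0)
        return norm_pmf, norm_pdf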
        

    def process_best_knn_results(self, knn_pmfs, knn_pdfs):
        """Score every kNN estimate by D_KL(baseline || estimate) and keep the best one per weighting scheme."""
        print('    ...processing knn result')
        knn_results_dict = {}
        for knn_type in ['EW', 'IDW', 'CAS', 'CASdist']:
            knn_cols = [c for c in knn_pmfs.columns if f'_{knn_type}_' in c]
            min_knn, best_knn = np.inf, None
            for c in knn_cols:
                q = knn_pmfs[c].values
                q_pdf = knn_pdfs[c].values
                knn_kld, support = self.compute_kld(self.baseline_pmf, q)
                kld_val = float(knn_kld)
                self.support_dict[c] = support
                prior_bias = float(support['bias'])
                if kld_val > 0 and prior_bias > 0.1 * kld_val:
                    # record the prior's contribution as a percentage of the KLD
                    self.support_dict[c]['pct_of_signal'] = round(100 * prior_bias / kld_val, 1)

                self.result[c] = kld_val
                if kld_val < min_knn:
                    min_knn = kld_val
                    best_knn = (c, q, q_pdf)
            knn_results_dict[knn_type] = best_knn
        return knn_results_dict

        
    def compute_kld(self, p, q, prior=1):
        """Return D_KL(p || q) in bits and a dict of support diagnostics.

        If q has empty bins, it is treated as a histogram of self.n_obs samples and
        smoothed with an additive pseudo-count `prior` per bin before computing the divergence.
        """
        mask = (p > 0)  # only bins where p has mass contribute to D_KL(p || q)
        unsupported_mass = check_support_coverage(p, q)
        prior_bias = 0
        if not (q > 0).all():
            q_mod = self.n_obs * q + prior  # add `prior` pseudo-counts to every bin
            q_mod = q_mod / q_mod.sum()
            prior_bias = jnp.sum(jnp.where(mask, p * jnp.log2(p / q_mod), 0))
            q = q_mod

        prior_bias_dict = {
            'bias': prior_bias,
            'prior': prior,
            'unsupported_mass': round(100 * unsupported_mass, 1)
        }
        return jnp.sum(jnp.where(mask, p * jnp.log2(p / q), 0)), prior_bias_dict
        

    def process_target(self):
        stn = self.official_id # target catchment
        
        self.kde_pmf, self.kde_pdf = single_kde_fit(self.log_uar, self.baseline_log_grid)
        print('   ...processed single kde fit')
        
        self.baseline_pmf, self.baseline_pdf = adaptive_kde(self.uar, self.baseline_log_grid, self.drainage_area_km2)        
        print('   ...adaptive (baseline) kde fit')
        
        # compute parametric PMFs (Lognorm MLE and predicted params) for the target
        # compute the pmf from the lognorm parameters predicted from catchment attributes
        mom_param_pmf, mom_param_pdf = self.compute_pmf_from_predicted_params('predicted_LN_MMO_mu_hat', 'predicted_LN_MMO_sd_hat')
        kld, support = self.compute_kld(self.baseline_pmf, mom_param_pmf)
        self.result['LN_MOM_DKL'] = kld.item()
        self.support_dict['LN_MOM_DKL'] = support
        print('   ...processed pmf from predicted MOM parameters')
        
        predicted_param_pmf, predicted_param_pdf = self.compute_pmf_from_predicted_params('predicted_mean_logx', 'predicted_sd_logx')
        kld, support = self.compute_kld(self.baseline_pmf, predicted_param_pmf)
        self.result['LN_predicted_params_DKL'] = kld.item()
        self.support_dict['LN_predicted_params_DKL'] = support
        print('   ...processed pmf from direct predicted LN parameters')
        
        tc = time()
        knn_pmf_estimates, knn_pdf_estimates = self.knn_pmf_estimation()
        td = time()
        print(f'Time to complete knn: {td-tc:.1f}')
        knn_results_dict = self.process_best_knn_results(knn_pmf_estimates, knn_pdf_estimates)
        
        # keep track of whether the lognorm yields incomplete support coverage
        self.compute_MLE_lognorm()
            
        summary = pd.DataFrame(self.result, index=['DKL']).T.round(4)
        summary.index.name = 'Model'
        print(summary.sort_values(by='DKL').head(10))
        
        output_figure(stn, self.baseline_log_grid, self.log_uar, predicted_param_pdf, 
                      self.lognorm_pdf, mom_param_pdf, self.kde_pdf, self.baseline_pdf, 
                      knn_results_dict, self.figure_fpath, self.result, self.support_dict)
        
        return self.result, self.support_dict
   
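The equal-weight (EW) and inverse-distance-weight (IDW) kNN estimates above both pool neighbour unit area runoff series into one series before the KDE step; the CAS variant applies the same weighting to attribute-space distances. The following is a minimal, self-contained sketch of that pooling. It assumes flows in m³/s, drainage areas in km², and an inverse-distance power of 2 (the default `power` of `compute_weights` is not shown in this section). The function names are hypothetical, the sketch is not the notebook's implementation, and it omits the adaptive KDE step.

```python
import numpy as np


def inverse_distance_weights(distances, power=2.0, min_dist=1e-4):
    """Normalized inverse-distance weights w_i = d_i**-power / sum_j d_j**-power."""
    d = np.maximum(np.asarray(distances, dtype=float), min_dist)  # clamp to avoid division by zero
    w = 1.0 / d ** power
    return w / w.sum()


def weighted_uar_series(flows, drainage_areas_km2, weights=None):
    """Pool k neighbour flow series into one unit area runoff (UAR) series.

    flows: array (n_days, k), assumed in m^3/s; drainage_areas_km2: array (k,).
    Returns UAR in L/s/km^2; uses equal weights if `weights` is None.
    """
    uar = 1000.0 * np.asarray(flows, dtype=float) / np.asarray(drainage_areas_km2, dtype=float)
    if weights is None:
        return uar.mean(axis=1)
    return uar @ np.asarray(weights, dtype=float)


# Two neighbours at 10 km and 20 km: weights proportional to 1/100 and 1/400, i.e. 0.8 and 0.2
w = inverse_distance_weights([10.0, 20.0])
flows = np.array([[1.0, 2.0],
                  [3.0, 4.0]])  # two days, two neighbours
areas = np.array([100.0, 200.0])
ew = weighted_uar_series(flows, areas)      # equal weighting
idw = weighted_uar_series(flows, areas, w)  # inverse distance weighting
print(w, ew, idw)
```

Normalizing the weights to sum to 1 keeps the pooled series on the same UAR scale as its members, which is also what the weight assertions in `estimate_knn_pmf_pdf` check.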

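The `LN_MOM` model uses lognormal parameters obtained by method of moments from the predicted arithmetic mean and SD, which `compute_pmf_from_predicted_params` then discretizes onto the log grid. The standard relations are σ² = ln(1 + s²/m²) and μ = ln m − σ²/2. The sketch below assumes the grid is in natural-log units (the base of `baseline_log_grid` is not shown here) and uses hypothetical function names. It also shows the coverage check performed on the integrated density *before* renormalization, which is what detects a truncated grid.

```python
import numpy as np
from scipy.integrate import trapezoid
from scipy.stats import norm


def lognormal_mom_params(mean, sd):
    """Method-of-moments lognormal parameters from the arithmetic mean and SD.

    Returns (mu, sigma), the mean and SD of ln(x).
    """
    sigma2 = np.log1p((sd / mean) ** 2)  # sigma^2 = ln(1 + CV^2)
    mu = np.log(mean) - 0.5 * sigma2     # mu = ln(mean) - sigma^2 / 2
    return mu, np.sqrt(sigma2)


def lognormal_pmf_on_log_grid(mu, sigma, log_grid, atol=1e-3):
    """Discretize N(mu, sigma), i.e. a lognormal in ln-space, onto `log_grid`.

    Raises ValueError if the integrated density differs from 1 by more than
    `atol` (e.g. because the grid truncates a tail).
    """
    pdf = norm.pdf(log_grid, loc=mu, scale=sigma)
    mass = trapezoid(pdf, x=log_grid)
    if not np.isclose(mass, 1.0, atol=atol):
        raise ValueError(f"grid captures {mass:.4f} of the mass; widen it")
    pdf = pdf / mass
    pmf = pdf / pdf.sum()  # normalized point masses on the (evenly spaced) grid
    return pmf, pdf


mu, sigma = lognormal_mom_params(mean=10.0, sd=5.0)
grid = np.linspace(mu - 8 * sigma, mu + 8 * sigma, 2**12)
pmf, pdf = lognormal_pmf_on_log_grid(mu, sigma, grid)
```

Mapping back with exp(μ + σ²/2) recovers the input mean, so the approximation error of the MOM model comes from the lognormal shape and the predicted moments, not from this conversion.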

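`compute_kld` scores each candidate PMF against the baseline adaptive-KDE PMF using D_KL(p || q) = Σ p log₂(p / q), in bits. When q has empty bins, it first applies pseudo-count smoothing (`self.n_obs * q` plus a prior). Below is a NumPy-only sketch of the same idea. The function is hypothetical and drops the support diagnostics. If no smoothing is requested and q misses part of p's support, it returns infinity.

```python
import numpy as np


def kld_bits(p, q, n_obs=None, prior=1.0):
    """D_KL(p || q) in bits, summed over bins where p > 0.

    If n_obs is given and q has empty bins, q is treated as a histogram of
    n_obs samples and smoothed with `prior` pseudo-counts per bin.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if n_obs is not None and np.any(q <= 0):
        q = n_obs * q + prior  # additive (Laplace-style) smoothing
        q = q / q.sum()
    mask = p > 0
    if np.any(q[mask] <= 0):
        return np.inf  # p has mass where q has none
    return float(np.sum(p[mask] * np.log2(p[mask] / q[mask])))


p = np.array([0.5, 0.5])
print(kld_bits(p, [0.5, 0.5]))           # identical distributions
print(kld_bits(p, [1.0, 0.0]))           # unsupported bin, no smoothing
print(kld_bits(p, [1.0, 0.0], n_obs=2))  # smoothed q = [0.75, 0.25]
```

The smoothing keeps the divergence finite, but it also adds a prior-dependent term to the score. That term shrinks as `n_obs` grows, which is why the notebook tracks it separately as `bias`.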
In [2]:
display_problem_fig = True
n_bootstrap_samples = 100
n_grid_points = 2**12
n_neighbours_to_check = 150
k_nearest = 10
to_check = ['05AD003', '05AD031', '05AB022', '05AB030', '12091050']
n_processed = 1

for i, row in attr_gdf.iterrows():
    stn = row['official_id']
    result_fpath = os.path.join(BASE_DIR, 'data', 'temp', f'{stn}_knn_result')
    support_fpath = os.path.join(BASE_DIR, 'data', 'temp', f'{stn}_knn_support_result')
    # if os.path.exists(result_fpath): continue  # optionally skip stations already processed
    
    experiment = Experiment(i, row, k_nearest=k_nearest, n_neighbours_to_check=n_neighbours_to_check, n_grid_points=n_grid_points)
    left_log, right_log = -2, 2
    result, support_dict = experiment.process_target()                
    res = pd.DataFrame([result])  # one row: KLD for each model
    support_res = pd.DataFrame(support_dict)  # rows: diagnostics (bias, prior, unsupported_mass, ...), columns: models
    res.to_csv(result_fpath)
    support_res.to_csv(support_fpath)
