In [None]:
#%% Setting Up
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including numpy RuntimeWarnings (divide-by-zero, mean of empty
# slice) that the metric functions defined later can emit on degenerate
# series. Consider narrowing the filter (by category or module) so such
# cases stay visible.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
# import xesmf as xe
import networkx as nx
# import rioxarray as rxr

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

import torch

import configparser
# Load the project-wide configuration (section 'PATHS' is used below).
# The file location can be overridden with the DEFAULTS_INI environment
# variable; otherwise the original default location is used.
DEFAULTS_INI = os.environ.get('DEFAULTS_INI', '/home/sarth/rootdir/datadir/assets/defaults.ini')
_cfg_parser = configparser.ConfigParser()
_cfg_parser.optionxform = str  # keep option names case-sensitive (ConfigParser lower-cases them by default)
if not _cfg_parser.read(DEFAULTS_INI):
    # ConfigParser.read() silently skips missing/unreadable files; without this
    # check the failure would only surface later as an opaque KeyError on 'PATHS'.
    raise FileNotFoundError(f"Could not read config file: {DEFAULTS_INI}")
cfg = {s: dict(_cfg_parser.items(s)) for s in _cfg_parser.sections()}
PATHS = cfg['PATHS']
del _cfg_parser

print("Setting up...")

In [None]:
#%% Region-Specific: CAMELS-US
DIRNAME = '03min_GloFAS_CAMELS-US'
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
resolution = 0.05


def lon_360_180(x):
    """Convert longitudes from the 0..360 convention to -180..180."""
    return (x + 180) % 360 - 180


def lon_180_360(x):
    """Convert longitudes from the -180..180 convention to 0..360."""
    return x % 360


region_bounds = dict(minx=-130, miny=20, maxx=-65, maxy=50)

# Per-gauge graph attributes. IDs are normalised to zero-padded strings
# (8-digit gauge IDs, 2-digit HUC-02 codes) so they match the shapefile IDs below.
_graph_attrs = pd.read_csv(os.path.join(SAVE_PATH, 'graph_attributes.csv'), index_col=0)
_graph_attrs.index = _graph_attrs.index.map(lambda gauge_id: str(gauge_id).zfill(8))
_graph_attrs['huc_02'] = _graph_attrs['huc_02'].map(lambda huc: str(huc).zfill(2))

# Keep gauges with area_percent_difference < 10 whose graph has more than one node.
_keep = (_graph_attrs['area_percent_difference'] < 10) & (_graph_attrs['num_nodes'] > 1)
camels_graph = _graph_attrs.loc[_keep].copy()
print(f"Number of CAMELS-US catmt's: {len(camels_graph)}")
del _graph_attrs, _keep

# HUC-02 watershed polygons.
# NOTE(review): gpd.read_file forwards extra kwargs such as `crs=` to its I/O
# engine; it likely does not assign a CRS to the result — confirm the
# shapefiles' .prj files are EPSG:4326.
region_shp = gpd.read_file(os.path.join(PATHS['watershed-boundary-dataset'], 'huc02', 'shapefile.shp'), crs = 'epsg:4326')
all_watersheds = region_shp.rename(columns={'huc2': 'watershed'})
all_watersheds['huc_02'] = all_watersheds['watershed'].map(lambda name: name.split('_')[0])

# CAMELS-US catchment polygons, indexed by zero-padded HRU ID.
_hcdn = gpd.read_file(os.path.join(PATHS['CAMELS'], 'CAMELS-US', 'HCDN_nhru_final_671.shp'), crs = 'epsg:4326')
_hcdn = (
    _hcdn[['hru_id', 'geometry']]
    .assign(hru_id=lambda frame: frame['hru_id'].map(lambda hid: str(hid).zfill(8)))
    .set_index('hru_id')
)

# Attach a polygon to every retained gauge (left join keeps all gauges).
_catchment_cols = ['huc_02', 'gauge_lon', 'gauge_lat', 'area_geospa_fabric', 'geometry', 'snapped_lon', 'snapped_lat']
all_catchments = camels_graph.merge(_hcdn, left_index=True, right_index=True, how='left')[_catchment_cols]
all_catchments = gpd.GeoDataFrame(all_catchments, crs='epsg:4326', geometry='geometry').reset_index()
del _hcdn, _catchment_cols

In [None]:
from scipy import stats, signal

def _mask_valid(pred, true):
    mask = ~np.isnan(true) & ~np.isnan(pred)
    pred = pred[mask]
    true = true[mask]
    pred[pred < 0] = 0
    true[true < 0] = 0
    return pred, true

def RMSE(pred, true):
    """Root-mean-square error between `pred` and `true` over jointly non-NaN steps."""
    pred, true = _mask_valid(pred, true)
    squared_err = np.square(pred - true)
    return np.sqrt(squared_err.mean())

def pearsonr(pred, true):
    """Pearson correlation coefficient of `true` vs `pred` (p-value discarded)."""
    pred, true = _mask_valid(pred, true)
    result = stats.pearsonr(true, pred)
    return result[0]

def NSE(pred, true):
    """Nash-Sutcliffe efficiency: 1 - SSE / (sum of squared deviations of `true` from its mean)."""
    pred, true = _mask_valid(pred, true)
    residual_ss = np.sum((true - pred) ** 2)
    total_ss = np.sum((true - true.mean()) ** 2)
    return 1 - residual_ss / total_ss

def KGE(pred, true):
    """Kling-Gupta efficiency: 1 minus the distance of (r, alpha, beta) from (1, 1, 1).

    r is the Pearson correlation, alpha = std(pred) / std(true) and
    beta = mean(pred) / mean(true).
    """
    pred, true = _mask_valid(pred, true)
    corr = pearsonr(pred, true)
    variability_ratio = pred.std() / true.std()
    bias_ratio = pred.mean() / true.mean()
    distance = np.sqrt((corr - 1) ** 2 + (variability_ratio - 1) ** 2 + (bias_ratio - 1) ** 2)
    return 1 - distance

def PBIAS(pred, true):
    """Percent bias: sum(true - pred) / sum(true) * 100 (positive when pred underestimates)."""
    pred, true = _mask_valid(pred, true)
    total_obs = np.sum(true)
    return np.sum(true - pred) / total_obs * 100

def alpha_NSE(pred, true):
    """Ratio of the standard deviation of `pred` to that of `true`."""
    pred, true = _mask_valid(pred, true)
    return pred.std() / true.std()

def beta_NSE(pred, true):
    """Difference of means (pred - true), scaled by the standard deviation of `true`."""
    pred, true = _mask_valid(pred, true)
    mean_bias = pred.mean() - true.mean()
    return mean_bias / true.std()

def _get_fdc(data):
    data = np.sort(data)[::-1]
    return data

def fdc_fms(pred, true, lower = 0.2, upper = 0.7):
    """Percent bias of the mid-segment slope of the flow-duration curve (FMS).

    The slope is the log-flow difference between the FDC values at exceedance
    fractions `lower` and `upper`; the simulated slope is compared with the
    observed one relative to the observed slope (+1e-6 for stability).
    """
    pred, true = _mask_valid(pred, true)

    def _log_mid_slope(series):
        curve = _get_fdc(series)
        curve[curve <= 0] = 1e-6  # avoid log(0)
        n = len(curve)
        log_q_lower = np.log(curve[np.round(lower * n).astype(int)])
        log_q_upper = np.log(curve[np.round(upper * n).astype(int)])
        return log_q_lower - log_q_upper

    sim_slope = _log_mid_slope(pred)
    obs_slope = _log_mid_slope(true)

    return (sim_slope - obs_slope) / (obs_slope + 1e-6) * 100

def fdc_fhv(pred, true, h = 0.02):
    """Percent bias of the high-flow segment (top `h` fraction) of the flow-duration curve."""
    pred, true = _mask_valid(pred, true)

    n_top_obs = np.round(h * len(true)).astype(int)
    n_top_sim = np.round(h * len(pred)).astype(int)
    top_obs = _get_fdc(true)[:n_top_obs]
    top_sim = _get_fdc(pred)[:n_top_sim]

    return np.sum(top_sim - top_obs) / np.sum(top_obs) * 100

def fdc_flv(pred, true, l = 0.3):
    """Percent bias in the low-flow segment of the flow-duration curve (FLV).

    Takes the lowest `l` fraction of each (descending) FDC, log-transforms it,
    and compares the summed log-flows above each curve's own minimum.

    Parameters
    ----------
    pred, true : np.ndarray
        Predicted and observed series; NaN pairs are dropped and negatives
        clipped to 0 by ``_mask_valid``.
    l : float
        Fraction of the FDC, counted from the low-flow end.

    Returns
    -------
    float
        FLV in percent, or NaN when ``l * n`` rounds to zero samples.
    """
    pred, true = _mask_valid(pred, true)

    sim = _get_fdc(pred)
    obs = _get_fdc(true)
    # Avoid log(0) below.
    sim[sim <= 0] = 1e-6
    obs[obs <= 0] = 1e-6

    # Number of low-flow samples to keep (sim and obs have equal length after
    # _mask_valid). Slice with an explicit start index: the former `x[-k:]`
    # turned k == 0 into `x[0:]`, silently using the whole curve, and raised
    # on empty input when calling .min().
    n_low = int(np.clip(np.round(l * len(obs)), 0, len(obs)))
    if n_low == 0:
        return np.nan
    obs = np.log(obs[len(obs) - n_low:])
    sim = np.log(sim[len(sim) - n_low:])

    # Summed log-flows above each curve's own minimum.
    qsl = np.sum(sim - sim.min())
    qol = np.sum(obs - obs.min())

    flv = -1 * (qsl - qol) / (qol + 1e-6)

    return flv * 100

def mean_peak_timing(pred, true, window = 3):
    """Mean absolute offset (in samples) between observed peaks and matched predicted peaks.

    Observed peaks are found with prominence >= std(true) and a minimum spacing
    of 2*window samples. For each, the prediction is matched at the same index
    if it is a strict local maximum there; otherwise at the argmax of `pred`
    within +/- `window` samples. Indices refer to the NaN-filtered series.
    Returns NaN when no observed peak is found.
    """
    pred, true = _mask_valid(pred, true)

    obs_peaks, _ = signal.find_peaks(true, distance=2*window, prominence=np.std(true))

    def _matched_pred_index(i):
        if pred[i] > pred[i - 1] and pred[i] > pred[i + 1]:
            return i
        start = max(i - window, 0)
        return start + np.argmax(pred[start:i + window + 1])

    offsets = [np.abs(_matched_pred_index(i) - i) for i in obs_peaks]

    return np.mean(offsets) if offsets else np.nan

def missed_peaks(pred, true, window = 3, threshold = 80):
    """Percentage of observed peaks with no predicted peak within +/- `window` samples.

    Peaks in each series must exceed that series' `threshold`-th percentile and
    be at least 2*window samples apart. Returns NaN if `true` has no such peak.
    """
    pred, true = _mask_valid(pred, true)

    obs_peak_idx, _ = signal.find_peaks(true, distance=2*window, height = np.percentile(true, threshold))
    sim_peak_idx, _ = signal.find_peaks(pred, distance=2*window, height = np.percentile(pred, threshold))

    if len(obs_peak_idx) == 0:
        return np.nan

    n_missed = sum(
        1 for i in obs_peak_idx
        if not np.any(np.abs(sim_peak_idx - i) <= window)
    )
    return (n_missed / len(obs_peak_idx)) * 100

def F1_score_of_capturing_peaks(pred, true, window = 3, threshold = 80):
    """F1 score of predicted peaks matching observed peaks within +/- `window` samples.

    Peaks in each series must exceed that series' `threshold`-th percentile and
    be at least 2*window samples apart.

    - TP: observed peak with a predicted peak nearby.
    - FN: observed peak without one.
    - FP: predicted peak with no observed peak nearby.

    F1 is computed as 2*TP / (2*TP + FP + FN). This equals 2PR/(P+R) whenever
    TP > 0. It is 0 (rather than NaN) when peaks exist but none are matched,
    so the worst cases are not silently dropped from aggregated statistics.
    It is NaN only when neither series has any peak.
    """
    pred, true = _mask_valid(pred, true)

    peaks_obs_times, _ = signal.find_peaks(true, distance=2*window, height = np.percentile(true, threshold))
    peaks_sim_times, _ = signal.find_peaks(pred, distance=2*window, height = np.percentile(pred, threshold))

    def _has_match(idx, candidates):
        # True if any candidate peak lies within +/- window samples of idx.
        return np.any(np.abs(candidates - idx) <= window)

    true_positive_peaks = sum(1 for idx in peaks_obs_times if _has_match(idx, peaks_sim_times))
    false_negative_peaks = len(peaks_obs_times) - true_positive_peaks
    false_positive_peaks = sum(1 for idx in peaks_sim_times if not _has_match(idx, peaks_obs_times))

    denom = 2 * true_positive_peaks + false_positive_peaks + false_negative_peaks
    return 2 * true_positive_peaks / denom if denom > 0 else np.nan

In [None]:
import numpy as np
import xarray as xr

def compute_metrics_ds(y_pred, y_true):
    """
    Compute a set of metrics for each lead time and catchment.
    y_pred and y_true: arrays of shape (time_idx, lead_time, catmt_idx)

    Series with fewer than 2 jointly non-NaN samples get NaN for every metric.
    Several metric functions raise on empty or length-1 input, which previously
    aborted the whole computation.

    Returns:
      xr.Dataset with dims ("time_idx", "lead_time", "catmt_idx") and variables:
          - RMSE, PearsonR, NSE, KGE, PBIAS, alpha_NSE, beta_NSE,
            FDC_FMS, FDC_FHV, FDC_FLV, mean_peak_timing, missed_peaks, F1_score
            (each over ("lead_time", "catmt_idx"))
          - Also includes raw y_pred and y_true.

    Raises:
      ValueError: if y_pred and y_true are not 3-D arrays of identical shape.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.ndim != 3 or y_pred.shape != y_true.shape:
        raise ValueError(
            f"y_pred and y_true must be 3-D with identical shapes, got {y_pred.shape} and {y_true.shape}"
        )

    time_steps, num_leadtimes, num_catmt = y_pred.shape

    # Output variable name -> metric function; the dict order is both the
    # evaluation order and the variable order of the returned Dataset.
    metric_fns = {
        "RMSE":             RMSE,
        "PearsonR":         pearsonr,
        "NSE":              NSE,
        "KGE":              KGE,
        "PBIAS":            PBIAS,
        "alpha_NSE":        alpha_NSE,
        "beta_NSE":         beta_NSE,
        "FDC_FMS":          fdc_fms,
        "FDC_FHV":          fdc_fhv,
        "FDC_FLV":          fdc_flv,
        "mean_peak_timing": mean_peak_timing,
        "missed_peaks":     missed_peaks,
        "F1_score":         F1_score_of_capturing_peaks,
    }

    # NaN-initialised (not np.empty) so skipped entries never contain garbage.
    results = {name: np.full((num_leadtimes, num_catmt), np.nan) for name in metric_fns}

    for lt in range(num_leadtimes):
        for cat in range(num_catmt):
            pred = y_pred[:, lt, cat]
            true = y_true[:, lt, cat]
            n_valid = np.count_nonzero(~np.isnan(pred) & ~np.isnan(true))
            if n_valid < 2:
                # Too few paired samples: leave NaN rather than raising inside
                # pearsonr / np.percentile / FDC indexing.
                continue
            for name, fn in metric_fns.items():
                results[name][lt, cat] = fn(pred, true)

    metric_dims = ("lead_time", "catmt_idx")
    series_dims = ("time_idx", "lead_time", "catmt_idx")
    data_vars = {name: (metric_dims, arr) for name, arr in results.items()}
    data_vars["y_pred"] = (series_dims, y_pred)
    data_vars["y_true"] = (series_dims, y_true)

    return xr.Dataset(
        data_vars,
        coords={
            "time_idx": np.arange(time_steps),
            "lead_time": np.arange(num_leadtimes),
            "catmt_idx": np.arange(num_catmt),
        },
    )

# # Example usage:
# metrics_ds = compute_metrics_ds(y_pred, y_true)
# metrics_ds

In [None]:
# Directories holding the precomputed per-collection metric files.
_FIGURE04_DIR = '/home/sarth/rootdir/workdir/projects/Paper_Data_Latency/Figure04'
_HYSETS_DIR = '/home/sarth/rootdir/workdir/projects/Paper_Data_Latency/hysets'


def filepath_camelsus(x):
    """Path of the CAMELS-US metric file for version `x`."""
    return os.path.join(_FIGURE04_DIR, f'US_v{x}.nc')


def filepath_camelsind(x):
    """Path of the CAMELS-IND metric file for version `x`."""
    return os.path.join(_FIGURE04_DIR, f'IND_v{x}.nc')


def filepath_hysets(x):
    """Path of the HYSETS metric file for version `x`."""
    return os.path.join(_HYSETS_DIR, f'US_hysets_v{x}.nc')

In [None]:
# Version suffix shared by the three metric files (e.g. 'US_v3.nc').
V = 3

ds_camelsus = xr.open_dataset(filename_or_obj=filepath_camelsus(V))    # CAMELS-US
ds_camelsind = xr.open_dataset(filename_or_obj=filepath_camelsind(V))  # CAMELS-IND
ds_hysets = xr.open_dataset(filename_or_obj=filepath_hysets(V))        # HYSETS ('US_hysets' files)

In [None]:
# ds_hysets

In [None]:
# def plot_cdf(ds, varname, lead_times=None, clip_min_max=[-1, 1], cmap_name = 'viridis'):
#     grid = np.linspace(clip_min_max[0], clip_min_max[1], 100)
#     all_cdfs = {}
#     if lead_times is None:
#         lead_times = ds['lead_time'].values
#         lead_times += 1
    
#     if 'F1' in varname:
#         varname_label = 'F1 score of peaks captured'
#         varname_legend = 'F1'
#     else:
#         varname_label = varname
#         varname_legend = varname

#     for lt in lead_times:
#         data = ds[varname].sel(lead_time=lt-1).values.flatten()
#         data = data[~np.isnan(data)]
#         cdf_values = np.array([np.mean(data <= val) for val in grid])
#         all_cdfs[lt] = cdf_values

#     cmap = cm.get_cmap(cmap_name, len(lead_times))

#     fig, ax = plt.subplots(figsize=(8, 6))
#     ax.set_facecolor('whitesmoke')
#     for i, lt in enumerate(lead_times):
#         median_lt = np.nanmedian(ds[varname].sel(lead_time=lt-1).values.flatten())
#         plt.plot(grid, all_cdfs[lt], label=f't+{lt}: {median_lt:.2f}', color=cmap(i))

#     plt.axhline(0.5, color='black', linestyle='--', linewidth=1)

#     # Create a secondary y-axis for PDF
#     ax2 = ax.twinx()
#     ax2.set_ylabel('PDF')

#     bar_width = grid[1] - grid[0]
#     all_pdfs = {}
#     for i, lt in enumerate(lead_times):
#         pdf_values = np.gradient(all_cdfs[lt], grid)
#         all_pdfs[lt] = pdf_values
#         ax2.bar(grid, pdf_values, width=bar_width, alpha=0.1, color=cmap(i), align='center')

#     ax.set_xlabel(varname_label)
#     ax.set_ylabel('CDF')
#     ax.legend(loc = 'upper left', title=f'Lead Time (Median {varname_legend})', fontsize=10)
#     ax.grid()
#     plt.show()


In [None]:
from scipy import stats
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def plot_cdf(ds, varname, lead_times=None, clip_min_max=(-1, 1), cmap_name='viridis', high_res = False, inset_var='NSE'):
    """Plot per-lead-time CDFs (plus finite-difference PDFs) of a metric across catchments.

    Parameters
    ----------
    ds : xr.Dataset
        Metric dataset with a 0-indexed ``lead_time`` coordinate and a
        ``varname`` variable (e.g. the output of ``compute_metrics_ds``).
    varname : str
        Metric shown in the main panel.
    lead_times : sequence of int, optional
        1-based lead times (days) to draw; defaults to all lead times in `ds`.
        The first entry is the baseline for the two-sample KS statistic in the legend.
    clip_min_max : (float, float)
        Range of the 100-point evaluation grid for the CDF/PDF.
    cmap_name : str
        Matplotlib colormap name; one colour per lead time.
    high_res : bool
        If True draw a 16x12 in figure at 600 dpi, otherwise 8x6 in.
    inset_var : str
        Metric summarised (median with 25th-75th percentile bars) in the inset
        for every lead time in `ds`. Defaults to 'NSE', the previous hard-coded choice.
    """
    grid = np.linspace(clip_min_max[0], clip_min_max[1], 100)
    if lead_times is None:
        # Build a new array: the former in-place `+=` on `.values` wrote into
        # the dataset's own lead_time coordinate array.
        lead_times = ds['lead_time'].values + 1

    if 'F1' in varname:
        varname_label = 'F1 score of peaks captured'
        varname_legend = 'F1'
    else:
        varname_label = varname
        varname_legend = varname

    def _valid_values(var, lt):
        # `lt` is 1-based while the dataset's lead_time coordinate starts at 0.
        values = ds[var].sel(lead_time=lt - 1).values.flatten()
        return values[~np.isnan(values)]

    # Read each lead time once; reused for the CDF, KS statistic and median.
    samples = {lt: _valid_values(varname, lt) for lt in lead_times}
    baseline_data = samples[lead_times[0]]

    all_cdfs = {
        lt: np.array([np.mean(data <= val) for val in grid])
        for lt, data in samples.items()
    }

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments.
    cmap = plt.get_cmap(cmap_name, len(lead_times))
    if high_res:
        fig, ax = plt.subplots(figsize=(16, 12), dpi = 600)
    else:
        fig, ax = plt.subplots(figsize=(8, 6))

    for i, lt in enumerate(lead_times):
        data = samples[lt]
        ks_stat, _ = stats.ks_2samp(data, baseline_data)
        median_lt = np.nanmedian(data)
        ax.plot(grid, all_cdfs[lt],
                label=f't+{lt:02d} day: {median_lt:.2f} ({ks_stat:.2f})',
                color=cmap(i))

    ax.axhline(0.5, color='black', linestyle='--', linewidth=1, alpha=0.33)

    # Secondary y-axis: PDF obtained as the numerical gradient of each CDF.
    ax2 = ax.twinx()
    ax2.set_ylabel('PDF')

    bar_width = grid[1] - grid[0]
    for i, lt in enumerate(lead_times):
        pdf_values = np.gradient(all_cdfs[lt], grid)
        ax2.bar(grid, pdf_values, width=bar_width, alpha=0.1, color=cmap(i), align='center')

    # Inset: median and interquartile range of `inset_var` for every lead time in `ds`.
    lead_time_vals = ds['lead_time'].values + 1
    inset_samples = [_valid_values(inset_var, lt) for lt in lead_time_vals]
    medians = np.array([np.nanmedian(d) for d in inset_samples])
    q25 = np.array([np.nanpercentile(d, 25) for d in inset_samples])
    q75 = np.array([np.nanpercentile(d, 75) for d in inset_samples])

    axins = inset_axes(
        ax,
        width="40%",
        height="20%",
        bbox_to_anchor=(0.05, -0.05, 0.8, 0.75),  # left, bottom, width, height
        bbox_transform=ax.transAxes,
        loc='upper left',
    )
    axins.set_facecolor('whitesmoke')
    axins.errorbar(lead_time_vals, medians, yerr=[medians - q25, q75 - medians],
                   fmt='-o', color='black', capsize=3)
    axins.set_ylabel(inset_var, fontsize=12, labelpad=0)
    axins.tick_params(axis='both', which='major', labelsize=8)
    axins.yaxis.set_tick_params(rotation=60)  # rotated y-ticks for readability
    axins.set_xticks(lead_time_vals)
    axins.set_xticklabels([f't+{int(lt)}' if lt in [1, 3, 5, 7, 10] else "" for lt in lead_time_vals], fontsize=8)
    axins.grid(True)

    ax.set_xlabel(varname_label)
    ax.set_ylabel('CDF')
    ax.legend(loc='upper left', title=f'Lead Time: Median {varname_legend} (KS statistic)', fontsize=10)
    ax.grid(alpha=0.5)
    plt.show()

In [None]:
# plot_cdf(ds_camelsus, 'NSE', lead_times=[1,3,5,7,10], clip_min_max=[-1, 1],cmap_name='plasma', high_res = True)

In [None]:
plot_cdf(
    ds_camelsus,
    'F1_score',
    lead_times=[1, 3, 5, 7, 10],
    clip_min_max=[0, 1],
    cmap_name='plasma',
    high_res=True,
)

In [None]:
# plot_cdf(ds_hysets, 'NSE', lead_times=[1,3,5,7,10], clip_min_max=[-1, 1],cmap_name='plasma', high_res = True)

In [None]:
plot_cdf(
    ds_hysets,
    'F1_score',
    lead_times=[1, 3, 5, 7, 10],
    clip_min_max=[0, 1],
    cmap_name='plasma',
    high_res=True,
)

In [None]:
# plot_cdf(ds_camelsind, 'NSE', lead_times=[1,3,5,7,10], clip_min_max=[-1, 1],cmap_name='plasma', high_res = True)

In [None]:
plot_cdf(
    ds_camelsind,
    'F1_score',
    lead_times=[1, 3, 5, 7, 10],
    clip_min_max=[0, 1],
    cmap_name='plasma',
    high_res=True,
)