# Setting Up

In [1]:
#%% Imports: Python Libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import rioxarray as rxr
import random

import os
import gc
import time
import json
import tqdm
import sys
import argparse
from joblib import Parallel, delayed

import torch
from torch import nn

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.patches import ConnectionPatch
from matplotlib.collections import PolyCollection
import seaborn as sns

# Project directory layout. Every project sub-directory hangs off 'root';
# the watershed boundary dataset lives outside the project tree.
PATHS = {'root': '/data/sarth/operational-hydrologic-emulators'}
PATHS.update({sub: os.path.join(PATHS['root'], sub) for sub in ('datasets', 'assets', 'experiments')})
PATHS['watershed-boundary-dataset'] = '/data/sarth/rootdir/datadir/data/raw/watershed-boundary-dataset'

# Sub-directories of PATHS['datasets']; only the first (HYSETS) is loaded below.
dataset_names = ['hysets_HUCAll_lumped', 'CAMELS-US_HUCAll_lumped']

In [2]:
# Event catalogue. Per event:
#   event_start_date / event_end_date : event period, both ends inclusive
#   start_view_date / end_view_date   : wider window shaded on the hydrographs
#   event_days                        : inclusive length of the event period
#   label                             : letter used as the column prefix in
#                                       basin_info_filtered.csv (e.g. 'A_I_Q')
#   new_label                         : letter shown in the figures and legend
# The trailing "# X" comments give each event's `label`.
events_dict = {
    'hurricane_katrina_2005': { # A
        'event_start_date': '2005-08-23',
        'event_end_date': '2005-08-31',
        'start_view_date': '2005-07-23',
        'end_view_date': '2005-09-30',
        'event_days': 9,
        'label': 'A',
        'new_label': 'D'
    },
    'hurricane_gustav_2008': { # B
        'event_start_date': '2008-08-25',
        'event_end_date': '2008-09-07',
        'start_view_date': '2008-08-01',
        'end_view_date': '2008-09-30',
        'event_days': 14,
        'label': 'B',
        'new_label': 'E'
    },
    'hurricane_isaac_2012': { # C
        'event_start_date': '2012-08-21',
        'event_end_date': '2012-09-03',
        'start_view_date': '2012-08-01',
        'end_view_date': '2012-09-30',
        'event_days': 14,
        'label': 'C',
        'new_label': 'F'
    },
    'mississippi_river_floods_2011': { # D
        'event_start_date': '2011-04-01',
        'event_end_date': '2011-05-31',
        'start_view_date': '2011-03-01',
        'end_view_date': '2011-06-30',
        'event_days': 61,
        'label': 'D',
        'new_label': 'A'
    },
    'mississippi_winter_floods_2015': { # E
        'event_start_date': '2015-12-22',
        'event_end_date': '2016-01-10',
        'start_view_date': '2015-12-01',
        'end_view_date': '2016-01-31',
        'event_days': 20,
        'label': 'E',
        'new_label': 'B'
    },
    'mississippi_spring_summer_floods_2019': { # F
        'event_start_date': '2019-03-01',
        'event_end_date': '2019-07-31',
        'start_view_date': '2019-02-01',
        'end_view_date': '2019-08-31',
        'event_days': 153,
        'label': 'F',
        'new_label': 'C'
    }
}

# `event_days` is entered by hand; fail fast if it drifts from the dates.
_event_day_mismatches = [
    name for name, evt in events_dict.items()
    if (pd.Timestamp(evt['event_end_date']) - pd.Timestamp(evt['event_start_date'])).days + 1 != evt['event_days']
]
assert not _event_day_mismatches, f"event_days disagrees with the event dates for: {_event_day_mismatches}"

event_names = list(events_dict.keys())
event_names

['hurricane_katrina_2005',
 'hurricane_gustav_2008',
 'hurricane_isaac_2012',
 'mississippi_river_floods_2011',
 'mississippi_winter_floods_2015',
 'mississippi_spring_summer_floods_2019']

In [3]:
#%% Load Catmt Info
# HYSETS catchment attributes mapped onto a common schema:
# gauge_id, huc (2-char zero-padded), uparea, lon, lat, nesting, num_nodes,
# plus sample_idx (row position in the source CSV, taken before any sorting)
# and source.
attributes = (
    pd.read_csv(os.path.join(PATHS['datasets'], dataset_names[0], 'attributes_with_nesting.csv'))
    .assign(huc_02=lambda frame: frame['huc_02'].astype(str).str.zfill(2))
    .loc[:, ['gauge_id', 'huc_02', 'uparea_snapped', 'snapped_lon', 'snapped_lat', 'nesting', 'num_nodes']]
    .rename(columns={'huc_02': 'huc', 'uparea_snapped': 'uparea', 'snapped_lon': 'lon', 'snapped_lat': 'lat'})
)
attributes['sample_idx'] = attributes.index.values
attributes['source'] = 'HYSETS'

catmt_info = attributes.copy()
del attributes
gc.collect()

# CAMELS-US (disabled). Note its CSV names the area column 'snapped_uparea'
# and its gauge ids need zero-padding to 8 digits.
# attributes = pd.read_csv(os.path.join(PATHS['datasets'], dataset_names[1], 'attributes_with_nesting.csv'))
# attributes['huc_02'] = attributes['huc_02'].astype(str).str.zfill(2)
# attributes['gauge_id'] = attributes['gauge_id'].astype(str).str.zfill(8)
# attributes = attributes[['gauge_id', 'huc_02', 'snapped_uparea', 'snapped_lon', 'snapped_lat', 'nesting', 'num_nodes']]
# attributes = attributes.rename(columns={'huc_02': 'huc', 'snapped_uparea': 'uparea', 'snapped_lon': 'lon', 'snapped_lat': 'lat'})
# attributes['sample_idx'] = attributes.index.values
# attributes['source'] = 'CAMELS-US'
# catmt_info = pd.concat([catmt_info, attributes], axis=0, ignore_index=True)
# del attributes
# gc.collect()

catmt_info = catmt_info.sort_values(by=['huc', 'num_nodes']).reset_index(drop=True)

# Keep the HUC-2 regions of interest for this figure.
basin_info = (
    catmt_info.loc[catmt_info['huc'].isin(['05', '06', '07', '08', '10', '11'])]
    .reset_index(drop=True)
)
print(f"Total basins in HUCs 05, 06, 07, 08, 10, 11: {len(basin_info)}")

del catmt_info

Total basins in HUCs 05, 06, 07, 08, 10, 11: 1961


In [4]:
# HUC-2 region codes (zero-padded strings) treated as the Mississippi River Basin.
huc_to_include = [f'{code:02d}' for code in (5, 6, 7, 8, 10, 11)]

In [5]:
# HUC-2 watershed polygons. `gpd.read_file` has no `crs` option (extra keyword
# arguments are forwarded to the I/O engine), so passing crs= did not assign a
# CRS. Set EPSG:4326 explicitly, but only when the file itself carries none.
region_shp = gpd.read_file(os.path.join(PATHS['watershed-boundary-dataset'], 'huc02', 'shapefile.shp'))
if region_shp.crs is None:
    region_shp = region_shp.set_crs('epsg:4326')

all_watersheds = region_shp.rename(columns={'huc2': 'watershed'})
# The HUC code is the part of the watershed name before the first '_'.
all_watersheds['huc'] = all_watersheds['watershed'].map(lambda x: x.split('_')[0])

In [6]:
# GloFAS upstream-area raster, cropped to the bounding box of the HUC-2 regions.
upa = rxr.open_rasterio('/data/sarth/rootdir/datadir/data/raw/GloFAS/upstream_area.nc')
# xarray's sortby takes a single bool for `ascending` (a list only worked
# because it is truthy). Both coordinates must ascend for the slice(min, max)
# selection below to pick a non-empty window.
upa = upa.sortby(['x', 'y'], ascending=True)
minx, miny, maxx, maxy = all_watersheds.total_bounds
upa = upa.sel(x=slice(minx, maxx), y=slice(miny, maxy))
upa.rio.write_crs("epsg:4326", inplace=True)

In [7]:
# Display the HYSETS basins kept for HUC-2 regions 05, 06, 07, 08, 10 and 11.
basin_info

Unnamed: 0,gauge_id,huc,uparea,lon,lat,nesting,num_nodes,sample_idx,source
0,hysets_03293000,05,48.581585,-85.625,38.225,not_nested,2.0,428,HYSETS
1,hysets_03271300,05,47.663597,-84.175,39.625,not_nested,2.0,429,HYSETS
2,hysets_03353637,05,47.629690,-86.125,39.675,nested_upstream,2.0,2597,HYSETS
3,hysets_03302300,05,48.532246,-85.925,38.325,nested_upstream,2.0,2598,HYSETS
4,hysets_03267700,05,47.374184,-83.775,40.025,nested_upstream,2.0,2599,HYSETS
...,...,...,...,...,...,...,...,...,...
1956,hysets_07074850,11,52846.145000,-91.375,35.275,nested_downstream,2127.0,1024,HYSETS
1957,hysets_07138062,11,65540.790000,-101.425,37.875,nested_upstream,2696.0,4040,HYSETS
1958,hysets_07138065,11,65638.360000,-101.375,37.875,nested_upstream,2700.0,4041,HYSETS
1959,hysets_07144300,11,108111.810000,-97.325,37.675,nested_upstream,4445.0,4042,HYSETS


In [8]:
# Per-basin statistics used for the figure: thresholds (Q_p98, P3day_p98) and,
# for each event label A-F, the indices <label>_I_Q and <label>_I_P3day.
basin_info_detailed = pd.read_csv('basin_info_filtered.csv')

# The CSV stores HUC codes as plain integers (5, 6, ...). Restore the
# zero-padded strings used by `basin_info` and `huc_to_include` so HUC
# filters behave the same on every frame in this notebook.
basin_info_detailed['huc'] = basin_info_detailed['huc'].astype(str).str.zfill(2)

# Fail early if a column read by the later cells is missing.
_required_cols = {'gauge_id', 'lon', 'lat', 'Q_p98', 'P3day_p98'} | {
    f'{event_label}_I_{kind}' for event_label in 'ABCDEF' for kind in ('Q', 'P3day')
}
assert _required_cols <= set(basin_info_detailed.columns), \
    f"basin_info_filtered.csv is missing columns: {sorted(_required_cols - set(basin_info_detailed.columns))}"

basin_info_detailed

Unnamed: 0,gauge_id,huc,uparea,lon,lat,nesting,num_nodes,sample_idx,source,Q_p60,...,B_I_P3day,B_I_Q,C_I_P3day,C_I_Q,D_I_P3day,D_I_Q,E_I_P3day,E_I_Q,F_I_P3day,F_I_Q
0,hysets_03426800,5,100.277855,-86.075,35.825,nested_upstream,4.0,2637,HYSETS,1.135690,...,0.394296,0.007801,0.729647,0.025631,1.682081,2.952364,0.798662,0.898933,1.213032,1.441265
1,hysets_03413200,5,123.900430,-84.875,36.775,nested_upstream,5.0,2642,HYSETS,1.295482,...,1.307456,0.482748,0.364556,0.080704,1.202820,1.660834,0.940712,1.082143,1.012639,0.826109
2,hysets_03297845,5,121.314120,-85.375,38.325,nested_upstream,5.0,2643,HYSETS,1.305879,...,1.229431,0.377634,0.360778,0.086903,1.270236,1.767428,0.949035,1.100679,1.001570,0.799746
3,hysets_03364200,5,119.699360,-85.875,39.275,nested_upstream,5.0,2644,HYSETS,1.293977,...,1.326680,0.555569,0.383287,0.082698,1.126530,1.648573,0.899403,1.086682,1.044722,0.945121
4,hysets_03239000,5,118.785280,-83.725,39.825,nested_upstream,5.0,2645,HYSETS,1.324559,...,1.175449,0.285107,0.372050,0.096910,1.303012,1.858604,0.936653,1.113172,1.025514,0.865272
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1507,hysets_07182400,11,7511.928700,-95.875,38.275,nested_downstream,310.0,1018,HYSETS,0.190425,...,0.347492,3.691518,1.436320,0.110728,0.979480,0.712115,1.000189,0.160246,1.905621,3.960491
1508,hysets_07341301,11,10709.796000,-93.975,33.775,nested_downstream,419.0,1021,HYSETS,1.095752,...,1.595366,0.632766,0.241259,0.042725,1.829782,1.917225,2.813681,2.754994,2.009165,1.841624
1509,hysets_07074420,11,21948.215000,-91.275,35.775,nested_upstream,886.0,4038,HYSETS,0.071333,...,1.406051,1.162514,0.219562,0.037676,1.162638,1.532822,0.381277,0.004412,1.336960,1.724255
1510,hysets_07074850,11,52846.145000,-91.375,35.275,nested_downstream,2127.0,1024,HYSETS,0.930595,...,2.007156,1.226941,1.076819,0.194474,2.365629,2.805220,2.615668,2.947934,1.431998,1.712559


In [9]:
# A basin counts as affected by event <X> when both its discharge index
# (<X>_I_Q) and its 3-day precipitation index (<X>_I_P3day) exceed 1.
for label in 'ABCDEF':
    basin_info_detailed[f'affected_by_{label}'] = (
        basin_info_detailed[f'{label}_I_Q'].gt(1) & basin_info_detailed[f'{label}_I_P3day'].gt(1)
    )

# By `label` in events_dict: A-C are the hurricanes, D-F the Mississippi floods.
for group, members in (('hurricanes', 'ABC'), ('floods', 'DEF')):
    basin_info_detailed[f'affected_by_{group}'] = (
        basin_info_detailed[[f'affected_by_{member}' for member in members]].any(axis=1)
    )

basin_flood_affected = basin_info_detailed.loc[basin_info_detailed['affected_by_floods']].reset_index(drop=True)
basin_flood_affected

Unnamed: 0,gauge_id,huc,uparea,lon,lat,nesting,num_nodes,sample_idx,source,Q_p60,...,F_I_P3day,F_I_Q,affected_by_A,affected_by_B,affected_by_C,affected_by_D,affected_by_E,affected_by_F,affected_by_hurricanes,affected_by_floods
0,hysets_03426800,5,100.277855,-86.075,35.825,nested_upstream,4.0,2637,HYSETS,1.135690,...,1.213032,1.441265,True,False,False,True,False,True,True,True
1,hysets_03413200,5,123.900430,-84.875,36.775,nested_upstream,5.0,2642,HYSETS,1.295482,...,1.012639,0.826109,False,False,False,True,False,False,False,True
2,hysets_03297845,5,121.314120,-85.375,38.325,nested_upstream,5.0,2643,HYSETS,1.305879,...,1.001570,0.799746,False,False,False,True,False,False,False,True
3,hysets_03364200,5,119.699360,-85.875,39.275,nested_upstream,5.0,2644,HYSETS,1.293977,...,1.044722,0.945121,False,False,False,True,False,False,False,True
4,hysets_03239000,5,118.785280,-83.725,39.825,nested_upstream,5.0,2645,HYSETS,1.324559,...,1.025514,0.865272,False,False,False,True,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1459,hysets_07182400,11,7511.928700,-95.875,38.275,nested_downstream,310.0,1018,HYSETS,0.190425,...,1.905621,3.960491,True,False,False,False,False,True,True,True
1460,hysets_07341301,11,10709.796000,-93.975,33.775,nested_downstream,419.0,1021,HYSETS,1.095752,...,2.009165,1.841624,False,False,False,True,True,True,False,True
1461,hysets_07074420,11,21948.215000,-91.275,35.775,nested_upstream,886.0,4038,HYSETS,0.071333,...,1.336960,1.724255,False,True,False,True,False,True,True,True
1462,hysets_07074850,11,52846.145000,-91.375,35.275,nested_downstream,2127.0,1024,HYSETS,0.930595,...,1.431998,1.712559,False,True,False,True,True,True,True,True


In [10]:
# Example catchments for the hydrograph panels (one row per HUC-2 region in the
# output below), with each model run's NSE in the nse_* columns.
sample_flood_catchments = pd.read_csv('best_performance_floods.csv')
sample_flood_catchments

Unnamed: 0,gauge_id,huc,uparea,lon,lat,sample_idx,source,nse_hist,nse_FilteredERA5,nse_GPMFinal,nse_NoMeteo,nse_NoMeteo_ShortOutage,nse_FilteredERA5_ShortOutage,nse_GPMFinal_ShortOutage,nse_FilteredERA5_Shutdown,nse_GPMFinal_Shutdown,nse_NoMeteo_Shutdown,nse_mean
0,hysets_03435000,5,36796.223,-87.225,36.325,524,HYSETS,0.795282,0.828919,0.735448,0.545422,0.417551,0.425469,0.534743,0.113994,0.350286,0.35176,0.509887
1,hysets_03543005,6,44931.195,-84.775,35.675,560,HYSETS,0.728494,0.81446,0.807224,0.639414,0.452191,0.697365,0.631245,0.528834,0.562937,0.323582,0.618575
2,hysets_05583000,7,13159.676,-89.975,40.125,662,HYSETS,0.410998,0.737882,0.694564,0.403129,0.286209,0.460923,0.436748,0.195877,0.189343,0.20101,0.401668
3,hysets_07364100,8,28179.87,-92.125,33.075,3375,HYSETS,0.777503,0.819762,0.614326,0.103747,-0.002187,-0.063091,0.405479,-0.367185,0.243351,-0.057425,0.247428
4,hysets_06295000,10,103129.5,-106.725,46.275,3869,HYSETS,0.677805,0.831644,0.311248,0.450003,0.382935,0.809838,0.318155,0.745174,0.269499,0.324178,0.512048
5,hysets_07144300,11,108111.81,-97.325,37.675,4042,HYSETS,0.435249,0.577309,0.171422,0.176414,0.23334,0.703707,0.111464,0.665241,0.021302,0.206966,0.330241


In [11]:
# Row *positions* (used with .iloc below) of the example catchments in
# basin_info_detailed, in that frame's row order (HUC 05, 06, 07, 08, 10, 11 for
# this data). np.flatnonzero yields positions even if the index is not a
# RangeIndex; `.index.values` would yield labels, which .iloc misreads.
sample_flood_catchments_idx = np.flatnonzero(
    basin_info_detailed['gauge_id'].isin(sample_flood_catchments['gauge_id'].values).to_numpy()
)
assert len(sample_flood_catchments_idx) == len(sample_flood_catchments), \
    'every example gauge must appear exactly once in basin_info_detailed'
print(sample_flood_catchments_idx)

[ 385  506  867 1013 1318 1511]


# Helper Plot Functions

In [12]:
def plot_mississippi(ax, min_upa_km2=1000):
    """Draw the Mississippi River Basin overview map on ``ax``.

    Layers, bottom to top: HUC-2 regions outside ``huc_to_include`` (grey),
    the outline of the union of the included regions, the upstream-area raster
    clipped to that outline (log-scaled 'Blues', cells below ``min_upa_km2``
    hidden), flood-affected catchment outlets (orange triangles), the example
    catchments (maroon triangles), and the included HUC-2 boundaries.

    Reads the notebook globals ``all_watersheds``, ``huc_to_include``, ``upa``,
    ``basin_flood_affected``, ``basin_info_detailed`` and
    ``sample_flood_catchments_idx``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    min_upa_km2 : float, default 1000
        Raster cells with a smaller upstream area (km²) are not drawn.

    Returns
    -------
    matplotlib Axes
        The same ``ax``, with axis decorations removed.
    """
    in_basin = all_watersheds['huc'].isin(huc_to_include)
    outside_hucs = all_watersheds[~in_basin]
    inside_hucs = all_watersheds[in_basin]

    # HUCs not part of this figure: thin outline plus a faint fill
    outside_hucs.plot(ax=ax, color='none', edgecolor='lightgray', linewidth=0.5)
    outside_hucs.plot(ax=ax, color='lightgrey', edgecolor='none', linewidth=0.5, alpha=0.33)

    # Mississippi River Basin outline: union of the included HUC-2 polygons
    basin_outline = inside_hucs.unary_union
    gpd.GeoSeries(basin_outline).plot(ax=ax, color='none', edgecolor='k', linewidth=1)

    upa_masked = upa.rio.clip([basin_outline], all_watersheds.crs, drop=False, invert=False)

    # convert to km² (the raster is taken to be in m²) and mask non-positive / nodata values
    da_km2 = upa_masked / 1e6
    da_km2 = da_km2.where(da_km2 > 0)

    vals = da_km2.values
    finite_vals = vals[np.isfinite(vals)]
    if finite_vals.size:
        vmin = max(float(finite_vals.min()), min_upa_km2)
        vmax = float(finite_vals.max())
        # LogNorm raises at draw time when vmin > vmax (no cell reaches the
        # threshold), so only draw the raster when the range is valid.
        if vmax >= vmin:
            da_km2 = da_km2.where(da_km2 >= vmin)
            norm = colors.LogNorm(vmin=vmin, vmax=vmax)
            da_km2.plot(ax=ax, cmap='Blues', alpha=1.0, norm=norm, add_colorbar=False)

    # Outlets of flood-affected catchments
    ax.scatter(basin_flood_affected['lon'], basin_flood_affected['lat'], color='orange', s=10, marker='^')

    # Example catchments shown in the hydrograph panels
    examples = basin_info_detailed.iloc[sample_flood_catchments_idx]
    ax.scatter(examples['lon'], examples['lat'], color='maroon', s=25, marker='^')

    # HUCs within Mississippi River Basin
    inside_hucs.plot(ax=ax, color='none', edgecolor='k', linewidth=0.5, alpha=0.7)

    ax.axis('off')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_title('')

    return ax



In [13]:
def plot_timeseries_from_csv(row, timeseries_dir, ax_q, start_date='2005-01-01'):
    """Plot one catchment's discharge and precipitation time series.

    Reads ``<timeseries_dir>/<gauge_id>_<sample_idx>.csv`` (first column is the
    date index) and draws:

    * on ``ax_q``: observed discharge ``Q``; whichever of the simulated series
      ``hist``, ``FilteredERA5``, ``GPMFinal`` and ``NoMeteo`` are present
      (clipped at 0); and a dashed line at ``row['Q_p98']``;
    * on a twin y-axis (inverted, so the bars hang from the top): ``P3day`` as
      bars and a dashed line at ``row['P3day_p98']``;
    * for every event in the global ``events_dict``: its view window shaded,
      and its ``new_label`` written just above the top edge of ``ax_q``.

    Parameters
    ----------
    row : pandas.Series
        Basin record. ``Q_p98`` and ``P3day_p98`` are required. ``gauge_id`` and
        ``sample_idx`` fall back to values derived from ``row.name``.
    timeseries_dir : str
        Directory holding the per-basin CSV files.
    ax_q : matplotlib Axes
        Axes for discharge.
    start_date : str or None, default '2005-01-01'
        Rows before this date are dropped; ``None`` keeps the full record.

    Returns
    -------
    (ax_p, ax_q)
        The precipitation twin axes and the discharge axes.
    """
    gauge = str(row.get('gauge_id', f"idx_{row.name}"))
    sample_idx = int(row.get('sample_idx', row.name))
    path = os.path.join(timeseries_dir, f"{gauge}_{sample_idx}.csv")

    ts = pd.read_csv(path, index_col=0, parse_dates=True)
    # Ensure a DatetimeIndex even if parse_dates left the index unconverted.
    ts.index = pd.to_datetime(ts.index)
    ts = ts.sort_index()
    if start_date is not None:
        ts = ts.loc[start_date:]

    q_color = "#8c510a"
    p_color = "#01665e"
    series_colors = {
        'hist': '#d8b365',
        'NoMeteo': '#f6e8c3',
        'FilteredERA5': '#5ab4ac',
        'GPMFinal': '#c7eae5'
    }
    shade_color = "#f0ad4e"

    ax_p = ax_q.twinx()

    # width=1.0 on a datetime axis is one day per bar
    ax_p.bar(ts.index, ts['P3day'].values, color=p_color, width=1.0, alpha=0.33, label='3-day Precipitation (P3day)', align='center')
    ax_p.invert_yaxis()
    ax_p.tick_params(axis='y', colors=p_color)

    ax_q.plot(ts.index, ts['Q'].values, color=q_color, lw=0.9, label='Discharge (Q)')

    for k in ['hist', 'FilteredERA5', 'GPMFinal', 'NoMeteo']:
        if k in ts.columns:
            ax_q.plot(ts.index, ts[k].clip(lower=0).values, color=series_colors.get(k, 'gray'), lw=0.5, alpha=0.9, label=k.replace('_', ' '))

    ax_q.axhline(float(row['Q_p98']), color='gray', linestyle='--', lw=1.5, label='Q p98')
    ax_p.axhline(float(row['P3day_p98']), color=p_color, linestyle='--', lw=1.5, label='P3day p98')

    ax_q.tick_params(axis='y', colors=q_color)
    ax_q.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)

    # x in data coordinates, y as a fraction of the axes height: 1.01 is just
    # above the top edge and stays there if the y-limits change later.
    label_transform = ax_q.get_xaxis_transform()
    for evt in events_dict.values():
        sv = pd.to_datetime(evt['start_view_date'])
        ev = pd.to_datetime(evt['end_view_date'])
        ax_q.axvspan(sv, ev, color=shade_color, alpha=0.33, lw=0)
        mid = sv + (ev - sv) / 2
        ax_q.text(mid, 1.01, evt.get('new_label', ''), transform=label_transform,
                  ha='center', va='bottom', fontsize=8, color='red', clip_on=False)

    return ax_p, ax_q

In [14]:
def plot_event_metrics_violin(input_dir, metric, min_max, ax, metric_label, font_size=12, no_title=False):
    """Half-violin + strip plot of one performance metric for each model run.

    Every ``*.csv`` file in ``input_dir`` is one model run (file stem = run
    name, e.g. ``FilteredERA5_ShortOutage``) whose ``metric`` column holds the
    values to plot. Known runs are drawn in a fixed order: historical, then the
    operational, short-outage and shutdown groups. Each violin keeps only its
    left half, the raw points are jittered to its right, medians are white dots,
    and dashed vertical lines separate the groups.

    Parameters
    ----------
    input_dir : str
        Directory of per-run CSV files.
    metric : str
        Column to plot (e.g. 'nse', 'f1_score_peaks').
    min_max : tuple
        ``(low, high)``. Values are clipped to this range, and it is used as the
        y-limits. ``(None, None)`` disables clipping and keeps autoscaled limits.
    ax : matplotlib Axes
        Axes to draw on.
    metric_label : str
        y-axis label.
    font_size : int, default 12
        Font size for tick labels, the y-label and the group headers.
    no_title : bool, default False
        When False, write 'Operational' / 'Short Outage' / 'Shutdown' above
        their groups.

    Returns
    -------
    matplotlib Axes
        The same ``ax``.
    """
    csv_files = sorted(f for f in os.listdir(input_dir) if f.lower().endswith('.csv'))
    model_paths = {os.path.splitext(f)[0]: os.path.join(input_dir, f) for f in csv_files}

    # One array of metric values per model run.
    collected = {}
    for model, path in model_paths.items():
        try:
            df = pd.read_csv(path)
            if metric in df.columns:
                # +/-inf -> NaN so they are treated like missing values
                collected[model] = df[metric].replace([np.inf, -np.inf], np.nan).astype(float).values
            else:
                # metric not present -> all-NaN column with one entry per row
                collected[model] = np.full(len(df), np.nan)
        except Exception as exc:
            # Best effort: keep plotting the other runs, but report the failure.
            # This is a print, not a warning, because the notebook ignores warnings.
            print(f"plot_event_metrics_violin: skipping {path}: {exc}", file=sys.stderr)
            collected[model] = np.array([np.nan])

    # Wide form (one column per run) for seaborn; shorter runs are NaN-padded.
    df_box = pd.DataFrame({k: pd.Series(v) for k, v in collected.items()})

    if min_max != (None, None):
        df_box = df_box.clip(lower=min_max[0], upper=min_max[1])

    desired_order = ['hist', 'FilteredERA5', 'GPMFinal', 'NoMeteo',
                     'FilteredERA5_ShortOutage', 'GPMFinal_ShortOutage', 'NoMeteo_ShortOutage',
                     'FilteredERA5_Shutdown', 'GPMFinal_Shutdown', 'NoMeteo_Shutdown']
    df_box = df_box[[m for m in desired_order if m in df_box.columns]]

    series_colors = {
        'hist': '#d8b365',
        'NoMeteo': '#f6e8c3',
        'FilteredERA5': '#5ab4ac',
        'GPMFinal': '#c7eae5'
    }
    # Colour by the run's base product (text before the first '_').
    ordered_colors = [series_colors.get(m.split('_')[0], 'gray') for m in df_box.columns]

    sns.violinplot(
        data=df_box,
        inner='quartile',
        orient='v',
        cut=0,
        bw=0.3,
        linewidth=0.5,
        ax=ax,
        palette=ordered_colors,
        order=df_box.columns
    )

    # Keep only the LEFT half of each violin (the right half is clipped away);
    # the strip plot below is shifted right into the freed space.
    for item in ax.collections:
        paths = item.get_paths()
        if not paths:
            continue
        x0, y0, width, height = paths[0].get_extents().bounds
        if metric == 'nse':
            # NOTE(review): hard-coded to the (-1, 1) NSE plotting range
            y0, height = -1, 2
        item.set_clip_path(plt.Rectangle((x0, y0), width / 2, height, transform=ax.transData))

    num_items = len(ax.collections)
    sns.stripplot(
        data=df_box,
        color='k',
        size=3,
        ax=ax,
        alpha=0.05,
        jitter=True,
        order=df_box.columns
    )
    for item in ax.collections[num_items:]:
        item.set_offsets(item.get_offsets() + np.array([0.15, 0]))

    medians = df_box.median()
    ax.scatter(np.arange(len(medians)), medians.values, facecolor='white', edgecolor='k', zorder=10, s=25)

    # Dashed separators after the last run of each group
    for model in ['hist', 'NoMeteo', 'NoMeteo_ShortOutage']:
        if model in df_box.columns:
            ax.axvline(df_box.columns.get_loc(model) + 0.5, color='gray', linestyle='--', lw=1)

    # Tick labels derived from the columns actually present, so a missing run
    # file no longer causes a tick/label count mismatch.
    display_names = {
        'hist': 'Historical',
        'FilteredERA5': 'Filtered-ERA5',
        'GPMFinal': 'GPM-Final',
        'NoMeteo': 'No-Meteorological',
    }
    tick_labels = [display_names.get(m.split('_')[0], m) for m in df_box.columns]
    ax.set_xticks(np.arange(len(tick_labels)))
    ax.set_xticklabels(tick_labels, rotation=15, ha='center', fontsize=font_size)

    ax.set_ylim(min_max[0], min_max[1])

    ax.set_ylabel(metric_label, fontsize=font_size)

    # Group headers just above the top of the plot, centred over each group's
    # runs. Uses the actual y-limits so (None, None) works too.
    if not no_title:
        y_lo, y_hi = ax.get_ylim()
        y_text = y_hi + 0.025 * (y_hi - y_lo)
        for suffix, group_label in (('', 'Operational'), ('_ShortOutage', 'Short Outage'), ('_Shutdown', 'Shutdown')):
            positions = [i for i, m in enumerate(df_box.columns)
                         if m != 'hist' and m[len(m.split('_')[0]):] == suffix]
            if positions:
                ax.text((positions[0] + positions[-1]) / 2, y_text, group_label,
                        ha='center', va='bottom', fontsize=font_size, color='black', clip_on=False)

    ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)

    return ax
# Example usage
# fig, ax = plt.subplots(figsize=(20, 4))
# ax = plot_event_metrics_violin(
#     input_dir='performance_metrics_floods',
#     metric='f1_score_peaks',
#     min_max=(0, 1),
#     ax=ax,
#     metric_label='F1 score of peaks captured'
# )

# Final Figure 6 Panels

In [15]:
# Re-display the example catchments (and their per-run NSE) beside the final figure cells.
sample_flood_catchments

Unnamed: 0,gauge_id,huc,uparea,lon,lat,sample_idx,source,nse_hist,nse_FilteredERA5,nse_GPMFinal,nse_NoMeteo,nse_NoMeteo_ShortOutage,nse_FilteredERA5_ShortOutage,nse_GPMFinal_ShortOutage,nse_FilteredERA5_Shutdown,nse_GPMFinal_Shutdown,nse_NoMeteo_Shutdown,nse_mean
0,hysets_03435000,5,36796.223,-87.225,36.325,524,HYSETS,0.795282,0.828919,0.735448,0.545422,0.417551,0.425469,0.534743,0.113994,0.350286,0.35176,0.509887
1,hysets_03543005,6,44931.195,-84.775,35.675,560,HYSETS,0.728494,0.81446,0.807224,0.639414,0.452191,0.697365,0.631245,0.528834,0.562937,0.323582,0.618575
2,hysets_05583000,7,13159.676,-89.975,40.125,662,HYSETS,0.410998,0.737882,0.694564,0.403129,0.286209,0.460923,0.436748,0.195877,0.189343,0.20101,0.401668
3,hysets_07364100,8,28179.87,-92.125,33.075,3375,HYSETS,0.777503,0.819762,0.614326,0.103747,-0.002187,-0.063091,0.405479,-0.367185,0.243351,-0.057425,0.247428
4,hysets_06295000,10,103129.5,-106.725,46.275,3869,HYSETS,0.677805,0.831644,0.311248,0.450003,0.382935,0.809838,0.318155,0.745174,0.269499,0.324178,0.512048
5,hysets_07144300,11,108111.81,-97.325,37.675,4042,HYSETS,0.435249,0.577309,0.171422,0.176414,0.23334,0.703707,0.111464,0.665241,0.021302,0.206966,0.330241


In [15]:
#%% Figure 6 (top): basin map flanked by example-catchment hydrographs
fig = plt.figure(constrained_layout=True, figsize=(21, 7), dpi=1500)
font_size = 12
gs = fig.add_gridspec(nrows=3, ncols=7, wspace=0.0, hspace=0.01)

ax_map = fig.add_subplot(gs[0:3, 2:5])

ax_left1 = fig.add_subplot(gs[0, 0:2])
ax_left2 = fig.add_subplot(gs[1, 0:2])
ax_left3 = fig.add_subplot(gs[2, 0:2])

ax_right1 = fig.add_subplot(gs[0, 5:7])
ax_right2 = fig.add_subplot(gs[1, 5:7])
ax_right3 = fig.add_subplot(gs[2, 5:7])

for ax, name in zip(
    [ax_map, ax_left1, ax_left2, ax_left3, ax_right1, ax_right2, ax_right3],
    ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(g)'],
):
    ax.set_title(name, loc='left', fontsize=font_size, fontweight='bold')

ax_map = plot_mississippi(ax_map)

# Single source of truth: (panel, position in sample_flood_catchments_idx, label).
# It drives the hydrographs, the map arrows and the panel labels, so they cannot
# disagree. The hydrographs used to be drawn in the opposite column from the
# arrows and HUC labels, so each label named a different catchment than the one
# plotted. Western catchments go left and eastern ones right, so arrows do not
# cross the map.
left_panels = [
    (ax_left1, 4, 'HUC10: Missouri'),
    (ax_left2, 5, 'HUC11: Arkansas-White-Red'),
    (ax_left3, 3, 'HUC08: Lower Mississippi'),
]
right_panels = [
    (ax_right1, 2, 'HUC07: Upper Mississippi'),
    (ax_right2, 0, 'HUC05: Ohio'),
    (ax_right3, 1, 'HUC06: Tennessee'),
]

(ax_p1, ax_q1), (ax_p2, ax_q2), (ax_p3, ax_q3), (ax_p4, ax_q4), (ax_p5, ax_q5), (ax_p6, ax_q6) = [
    plot_timeseries_from_csv(basin_info_detailed.iloc[sample_flood_catchments_idx[pos]], 'timeseries', panel_ax)
    for panel_ax, pos, _ in left_panels + right_panels
]

ax_q2.set_ylabel(r'Discharge ($\frac{mm}{day}$) $\longrightarrow$', color="#8c510a")
ax_p5.set_ylabel(r'$\longleftarrow$ 3-day Precipitation ($\frac{mm}{3-days}$)', color="#01665e")

# Share the time axis within each column
ax_q2.sharex(ax_q1)
ax_q3.sharex(ax_q1)
ax_q5.sharex(ax_q4)
ax_q6.sharex(ax_q4)

for ax in [ax_q1, ax_q2, ax_q3, ax_q4, ax_q5, ax_q6, ax_p1, ax_p2, ax_p3, ax_p4, ax_p5, ax_p6]:
    ax.tick_params(axis='y', labelsize=font_size)

# Bottom row: rotated date labels, every second one hidden to avoid overlap
for ax in [ax_q3, ax_q6]:
    for i, tick_label in enumerate(ax.get_xticklabels()):
        tick_label.set(rotation=15, ha='right', fontsize=font_size)
        if i % 2 != 0:
            tick_label.set_visible(False)
# Upper rows share the bottom row's dates, so hide their x tick labels
for ax in [ax_q1, ax_q2, ax_q4, ax_q5]:
    ax.tick_params(axis='x', labelbottom=False)

# Arrow from each example outlet on the map to its panel, plus the panel label.
# Right-column arrows leave the outlet eastward and end just left of the panel;
# left-column arrows leave westward and end just right of it.
for panels, lon_offset, panel_anchor in [
    (right_panels, 0.05, (-0.05, 0.5)),
    (left_panels, -0.05, (1.07, 0.5)),
]:
    for target_ax, pos, label in panels:
        row = basin_info_detailed.iloc[sample_flood_catchments_idx[pos]]
        lon, lat = float(row['lon']), float(row['lat'])
        con = ConnectionPatch(
            xyA=(lon + lon_offset, lat), coordsA='data', axesA=ax_map,
            xyB=panel_anchor, coordsB='axes fraction', axesB=target_ax,
            arrowstyle='->',
            mutation_scale=12, lw=1.0, color='maroon', alpha=0.8
        )
        fig.add_artist(con)
        target_ax.text(0.5, 1.15, label, transform=target_ax.transAxes, fontsize=font_size,
                       ha='center', va='center', color='maroon')

# No plt.subplots_adjust here: matplotlib ignores it under constrained_layout.
fig.savefig('F6_top.png')
plt.close(fig)

In [215]:
#%% Figure 6 (middle): NSE and peak-F1 distributions per model run
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(21, 7), sharex=True, dpi=1500)
font_size = 12

# (metric column, clip / y-range, y-axis label, panel tag). Only the first panel
# carries the Operational / Short Outage / Shutdown group headers.
middle_panels = [
    ('nse', (-1, 1), 'NSE', '(h)'),
    ('f1_score_peaks', (0, 1), 'F1 score of peaks captured', '(i)'),
]
for panel_no, (panel_ax, (metric_name, limits, y_label, tag)) in enumerate(zip(axs, middle_panels)):
    plot_event_metrics_violin(
        input_dir='performance_metrics_floods',
        metric=metric_name,
        min_max=limits,
        ax=panel_ax,
        metric_label=y_label,
        font_size=font_size,
        no_title=panel_no > 0,
    )
    panel_ax.set_title(tag, loc='left', fontsize=font_size, fontweight='bold')
ax_bottom1, ax_bottom2 = axs

fig.tight_layout()
fig.savefig('F6_middle.png')
plt.close(fig)

In [18]:
# Legend handles in a separate figure
fig, ax_legend = plt.subplots(figsize=(21, 5), dpi=1500)
font_size = 12

# First (Left-most): Discharge (Q) (Line), 3-day Precipitation (P_{3-day}) (Bar), 98 Percentile Q (Dashed Line), 98 Percentile P_{3day} (Dashed Line), Event Periods (Shaded), Affected Catchments (Marker), Example Catchments (Marker with different color).
legend_handles = [
    plt.Line2D([0], [0], color='#8c510a', lw=2, label='Discharge (Q)'),
    plt.Line2D([0], [0], color='#01665e', lw=2, label=r'3-day Precipitation ($P_{3 days}$)'),
    plt.Line2D([0], [0], color='#8c510a', lw=2, linestyle='--', label=r'98 Percentile Q'),
    plt.Line2D([0], [0], color='#01665e', lw=2, linestyle='--', label=r'98 Percentile $P_{3 days}$'),
    plt.Rectangle((0, 0), 1, 1, color='#f0ad4e', alpha=0.33, label='Event Periods'),
    plt.Line2D([0], [0], marker='^', color='w', markerfacecolor='orange', markersize=10, label='Affected Catchments'),
    plt.Line2D([0], [0], marker='^', color='w', markerfacecolor='maroon', markersize=10, label='Example Catchments')
]
l1 = ax_legend.legend(
    handles=legend_handles,
    loc='lower left',
    bbox_to_anchor=(0.1, 0.0),
    fontsize=font_size,
    frameon=False,
    # ncol=len(legend_handles)
)

# Second: Historical (Line and Patch), Filtered-ERA5 (Line and Patch), GPM-Final (Line and Patch), No-Meteorological (Line and Patch). Text at center; Line and Patch on either side (can be done as no text on the Patch column).
legend_handles_2 = [
    plt.Line2D([0], [0], color='#d8b365', lw=2, label='Historical'),
    plt.Line2D([0], [0], color='#5ab4ac', lw=2, label='Filtered-ERA5'),
    plt.Line2D([0], [0], color='#c7eae5', lw=2, label='GPM-Final'),
    plt.Line2D([0], [0], color='#f6e8c3', lw=2, label='No-Meteorological')
]
l2 = ax_legend.legend(
    handles=legend_handles_2,
    loc='lower right',
    bbox_to_anchor=(0.5, 0.0),
    fontsize=font_size,
    frameon=False,
    # ncol=len(legend_handles_2)
)

ax_legend.add_artist(l1)

legend_handles_3 = [
    plt.Rectangle((0, 0), 1, 1, color='#d8b365', alpha=1, label=''),
    plt.Rectangle((0, 0), 1, 1, color='#5ab4ac', alpha=1, label=''),
    plt.Rectangle((0, 0), 1, 1, color='#c7eae5', alpha=1, label=''),
    plt.Rectangle((0, 0), 1, 1, color='#f6e8c3', alpha=1, label='')
]
l3 = ax_legend.legend(
    handles=legend_handles_3,
    # labels=['Operational', 'Short Outage', 'Shutdown', ''],
    loc='lower left',
    bbox_to_anchor=(0.5, 0.0),
    fontsize=font_size,
    frameon=False,
    # ncol=len(legend_handles_3)
)

ax_legend.add_artist(l2)

# Nomenclature Texts: Event A-F names.
event_texts = [
    'A: Mississippi River Floods (2011)',
    'B: Mississippi Winter Floods (2015-2016)',
    'C: Mississippi Spring-Summer Floods (2019)',
    'D: Hurricane Katrina (2005)',
    'E: Hurricane Gustav (2008)',
    'F: Hurricane Isaac (2012)'
]
x_start = 0.66
y_start = 0.05
y_gap = 0.075

for i, text in enumerate(event_texts[::-1]):
    ax_legend.text(x_start, y_start + i * y_gap, text, transform=ax_legend.transAxes, fontsize=font_size, ha='left', va='center')
ax_legend.text(x_start, y_start + len(event_texts)*y_gap, 'Events:', transform=ax_legend.transAxes, fontsize=font_size, ha='left', va='center', fontweight='bold')

plt.axis('off')
# plt.show()

fig.savefig('F6_legend.png')
plt.close()

# Development: Supplementary Metrics (KGE, RMSE, PBIAS, r)

In [None]:
# metrics = ['nse', 'kge', 'rmse', 'pbias', 'pearsonr', 'f1_score_peaks']
# metrics = ['kge', 'rmse', 'pbias', 'pearsonr']

In [27]:
#%% Supplementary: KGE, RMSE, PBIAS and Pearson r distributions per model run
fig, axs = plt.subplots(nrows=4, ncols=1, figsize=(15, 15), sharex=True, dpi=1500)
font_size = 12

# (metric column, clip / y-range, y-axis label, panel tag). (None, None) means
# no clipping and autoscaled y-limits. Only the first panel has group headers.
other_panels = [
    ('kge', (-1, 1), 'KGE', '(a)'),
    ('rmse', (None, None), 'RMSE', '(b)'),
    ('pbias', (-100, 100), 'PBIAS', '(c)'),
    ('pearsonr', (None, None), 'r', '(d)'),
]
for panel_no, (panel_ax, (metric_name, limits, y_label, tag)) in enumerate(zip(axs, other_panels)):
    plot_event_metrics_violin(
        input_dir='performance_metrics_floods',
        metric=metric_name,
        min_max=limits,
        ax=panel_ax,
        metric_label=y_label,
        font_size=font_size,
        no_title=panel_no > 0,
    )
    panel_ax.set_title(tag, loc='left', fontsize=font_size, fontweight='bold')
ax_bottom1, ax_bottom2, ax_bottom3, ax_bottom4 = axs

fig.tight_layout()
fig.savefig('F6_other_metrics.png')
plt.close(fig)