# NB2: Analysis of streamflow simulations (non-bias corrected) from SUMMA+mizuRoute forced by ICAR downscaled meteorological data

1. Evaluation of Historical simulation compared to retrospective simulation and naturalized flow data
2. Future changes in various flow metrics and time series plots 

!!!! WARNING !!!!

There are cells that produce per-site plots (hundreds of sites) at the bottom of this notebook.

In [None]:
%matplotlib inline

import os,sys
import glob
import xarray as xr
import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import scale as mscale
import matplotlib as mpl
from matplotlib import cm
import cartopy.crs as ccrs
from scipy import stats

import scripts.metrics as metrics
from scripts.utility import PPFScale
from scripts.utility import AutoVivification
from scripts.utility import base_map
import scripts.colors as ccmap
mscale.register_scale(PPFScale)

print("\nThe Python version: %s.%s.%s" % sys.version_info[:3])
print(xr.__name__, xr.__version__)

In [None]:
def reorder_dataset(ds, siteID):
    """Return ``ds`` with its ``site`` dimension rearranged to follow ``siteID``.

    For every entry of ``siteID`` the position of the matching value in
    ``ds['site']`` is found with a sorted search, and ``ds`` is then indexed
    along ``site`` with those positions.

    Approach from:
    https://stackoverflow.com/questions/8251541/numpy-for-every-element-in-one-array-find-the-index-in-another-array

    NOTE: membership is not checked. An ID that is absent from ``ds['site']``
    still produces an index (its insertion point, clipped to the valid range),
    so the result silently holds a different site in that slot.
    """
    site_values = ds['site'].values
    # permutation that puts the dataset's site IDs in sorted order
    sort_perm = np.argsort(site_values)
    # slot of each requested ID within the sorted IDs
    slot = np.searchsorted(site_values, siteID, sorter=sort_perm)
    # map sorted slots back to positions in the dataset's own ordering
    remap_order = np.take(sort_perm, slot, mode="clip")
    return ds.isel(dict(site=remap_order))

## 1. Setup

In [None]:
# --- input / output locations ---
# !!! This is top directory of the dataset.
main_path = '/glade/campaign/ral/hap/mizukami/archive/pnw_hydrology/final_archive_v1'
nrni_path = os.path.join(main_path, 'ancillary_data')
geo_path = os.path.join(nrni_path, 'geospatial_data')

# figures go to figure_path, per-site figures to its 'per_site' sub-folder;
# creating the nested folder also creates figure_path itself
figure_path = 'NB2_figures'
os.makedirs(os.path.join(figure_path, 'per_site'), exist_ok=True)

In [None]:
# Meta dictionaries: simulations, scenarios, analysis periods and ensembles

# GCM-driven runs -> scenarios available and CMIP generation
_cmip6_scens = ['hist', 'ssp245', 'ssp370', 'ssp585']
_cmip5_scens = ['hist', 'rcp45', 'rcp85']
_gcm_table = [
    ('CanESM5',             _cmip6_scens,      6),
    ('CMCC-CM2-SR5',        _cmip6_scens,      6),
    ('NorESM2-MM',          _cmip6_scens,      6),
    ('MIROC-ES2L',          _cmip6_scens,      6),
    ('MPI-M.MPI-ESM1-2-LR', _cmip6_scens,      6),
    ('CanESM2',             _cmip5_scens,      5),
    ('CCSM4',               ['hist', 'rcp85'], 5),  # no rcp45 entry for this model
    ('CMCC-CM',             _cmip5_scens,      5),
    ('CNRM-CM5',            _cmip5_scens,      5),
    ('MIROC5',              _cmip5_scens,      5),
    ('MRI-CGCM3',           _cmip5_scens,      5),
]
# each entry gets its own copy of the scenario list
gcm_runs = {name: {'scen': list(scen), 'cmip': cmip} for name, scen, cmip in _gcm_table}

# retrospective (observation-forced) runs
retro_runs = {'GMET': {'period': 'control'}}

# scenario -> time span of the simulation and the analysis periods it covers
scens = {'hist': {'time': slice('1950-01-01', '2004-12-31'), 'period': ['control']}}
scens.update({
    name: {'time': slice('2005-01-01', '2099-12-31'), 'period': ['2040s', '2080s']}
    for name in ('ssp245', 'ssp370', 'ssp585', 'rcp45', 'rcp85')
})

# analysis period -> label, time window and line colour
# NOTE(review): 'control' is labelled WY1980-2004 but its window starts 1980-10-01,
# while the future windows start on 1 Oct of the year before the first labelled
# water year -- confirm which convention is intended.
periods = {
    'control': {'name': 'WY1980-2004', 'time': slice('1980-10-01', '2004-09-30'), 'lc': 'xkcd:blue'},
    '2040s':   {'name': 'WY2030-2060', 'time': slice('2029-10-01', '2060-09-30'), 'lc': 'xkcd:orange'},
    '2080s':   {'name': 'WY2070-2099', 'time': slice('2069-10-01', '2099-09-30'), 'lc': 'xkcd:magenta'},
}

# ensemble name -> CMIP generations and scenarios it pools
ensembles = {
    'cmip-hist':         {'cmip': [5, 6], 'scen': ['hist']},
    'cmip6-hist':        {'cmip': [6],    'scen': ['hist']},
    'cmip5-hist':        {'cmip': [5],    'scen': ['hist']},
    'cmip6-ssp370':      {'cmip': [6],    'scen': ['hist', 'ssp370']},
    'cmip6-ssp245':      {'cmip': [6],    'scen': ['hist', 'ssp245']},
    'cmip6-ssp585':      {'cmip': [6],    'scen': ['hist', 'ssp585']},
    'cmip5-rcp85':       {'cmip': [5],    'scen': ['hist', 'rcp85']},
    'cmip6':             {'cmip': [6],    'scen': ['hist', 'ssp245', 'ssp370', 'ssp585']},
    'cmip5':             {'cmip': [5],    'scen': ['hist', 'rcp45', 'rcp85']},
    'low_high-emission': {'cmip': [5, 6], 'scen': ['ssp245', 'rcp45', 'ssp585', 'rcp85']},
    'high-emission':     {'cmip': [5, 6], 'scen': ['ssp585', 'rcp85']},
}

# retrospective runs first, then GCM runs; this order fixes column order downstream
sims = {**retro_runs, **gcm_runs}
retro_names = list(retro_runs)
gcm_names = list(gcm_runs)
sim_names = list(sims)

## 2.Load data 

### 2.1 geopackage data

In [None]:
# Vector layers read from geo_path: flow sites, river reaches and HUC12 polygons
df_site, df_reach, df_huc12 = (
    gpd.read_file(os.path.join(geo_path, layer_file))
    for layer_file in ('PNW_flow_site.gpkg', 'rivEndoMERITpfaf_PNW.gpkg', 'HUC12_MERIT_PNW.gpkg')
)
# replace the HUC12 outlines with simplified geometries (tolerance 0.05)
df_huc12['geometry'] = df_huc12.geometry.simplify(0.05)

### 2.2 Link between river network reach ID and site name
There are fewer unique reach IDs co-located with flow sites than there are flow sites, presumably because some reaches contain more than one site.

In [None]:
# Site table read from CSV; preview the first rows to check its columns
df_merit_id = pd.read_csv(os.path.join(geo_path, 'PNW_flow_site.csv'))
df_merit_id.head()

### 2.2 Naturalized flow data

In [None]:
# Naturalized flow, cut to the time span of the 'hist' scenario
nrni_file = os.path.join(nrni_path, 'PNW_unimpaired_flow_1951-2018_latlon.nc')
ds_nrni = xr.open_dataset(nrni_file).sel(time=scens['hist']['time'])
nrni_site = ds_nrni.site.values
print('Number of nrni sites: %d'%len(nrni_site))

### 2.3 Read mizuRoute outputs

Read the mizuRoute output into xarray datasets and store them in the dictionary `ds_route[gcm_name][scen]`.

In [None]:
%%time
# get GCM sim
ds_route = AutoVivification()
for gcm_name, meta in gcm_runs.items():
    for scen in meta['scen']:
        analysis_period = scens[scen]['time']

        if scen=='hist' and meta['cmip']==5: # for cmip5 historical period, use rcp85 data
            case = f'{gcm_name}_rcp85'
        elif scen=='hist' and meta['cmip']==6: # for cmip6 historical period, use ssp585 data
            case = f'{gcm_name}_ssp585'
        else: # for future period
            case = f'{gcm_name}_{scen}'
        
        nclist=glob.glob(os.path.join(main_path, case, f'{case}_mizuRoute_daily_site.nc'))
        ds_tmp = xr.open_mfdataset(nclist, data_vars='minimal').sel(time=analysis_period)
        ds_route[gcm_name][scen] = ds_tmp.load()
        ds_route[gcm_name][scen] = ds_route[gcm_name][scen].assign_coords(seg=ds_route[gcm_name][scen]['site'])
        ds_route[gcm_name][scen] = ds_route[gcm_name][scen].drop_vars('site')
        ds_route[gcm_name][scen] = ds_route[gcm_name][scen].rename({'seg':'site'})
        ds_route[gcm_name][scen] = reorder_dataset(ds_route[gcm_name][scen], df_merit_id['location_name'].values)
        print(f'{case}')

flow_site = ds_route[gcm_name][scen].site.values
route_reachID = ds_route[gcm_name][scen].reachID.values
print('Number of routing sites: %d'%len(route_reachID))

In [None]:
%%time
# get retro sim and add it to ds_route dictionary
for retro_name, meta in retro_runs.items():
    scen='hist'
    analysis_period = scens[scen]['time']
    nclist=glob.glob(os.path.join(main_path, f'{retro_name}_hist','mizuRoute_daily_site.nc'))
    ds_tmp = xr.open_mfdataset(nclist, data_vars='minimal').sel(time=analysis_period)
    ds_route[retro_name][scen] = ds_tmp.load()
    ds_route[retro_name][scen] = ds_route[retro_name][scen].assign_coords(seg=ds_route[retro_name][scen]['site'])
    ds_route[retro_name][scen] = ds_route[retro_name][scen].drop_vars('site')
    ds_route[retro_name][scen] = ds_route[retro_name][scen].rename({'seg':'site'})
    ds_route[retro_name][scen] = reorder_dataset(ds_route[retro_name][scen], df_merit_id['location_name'].values)

## 3. Reformat xarray dataset to pandas DataFrame
Find the sites where naturalized flow exists (the simulated-flow sites include all the sites in the metadata).

In [None]:
# Sites present both in the simulated flow (flow_site, which includes all sites)
# and in the naturalized flow data (nrni_site).
# The intersection is sorted so the site order -- and with it the row order of
# every per-site array built from common_site -- is the same on every run;
# plain set iteration order can change between Python sessions.
common_site = np.asarray(sorted(set(flow_site).intersection(nrni_site)))
num_site = len(common_site)
common_site

### select common sites only 

In [None]:
# rows of the flow-site table whose location_name is one of the common sites
is_common_site = df_site['location_name'].isin(common_site)
df_site_selected = df_site.loc[is_common_site]

In [None]:
# NRNI data restricted to the common sites (other sites are dropped)
nrni_is_common = ds_nrni.site.isin(common_site)
ds_nrni_selected = ds_nrni.where(nrni_is_common, drop=True)

In [None]:
# GCM simulations: keep the common sites, then order them like the NRNI data
nrni_site_order = ds_nrni_selected.site.values
ds_qsim_selected = AutoVivification()
for gcm_name, meta in gcm_runs.items():
    for scen in meta['scen']:
        ds_all_sites = ds_route[gcm_name][scen]
        ds_common = ds_all_sites.where(ds_all_sites.site.isin(common_site), drop=True)
        ds_qsim_selected[gcm_name][scen] = reorder_dataset(ds_common, nrni_site_order)

In [None]:
# Retro simulations: keep the common sites, then order them like the NRNI data.
# The loop variable is the name used in the body; previously the body read
# `retro_name` left over from the loading cell instead of the loop's own variable.
for retro_name in retro_runs:
    scen = 'hist'
    ds_all_sites = ds_route[retro_name][scen]
    ds_common = ds_all_sites.where(ds_all_sites.site.isin(common_site), drop=True)
    # match order of site with order of NRNI data
    ds_qsim_selected[retro_name][scen] = reorder_dataset(ds_common, ds_nrni_selected.site.values)

-----
## 4. Compute flow metrics

- mean annual cycle at daily step (ds_seasonal_flow)
- annual maximum flow and day of year per wyr
- annual minimum flow and day of year per wyr
- annual centroid
- high flow at 90%
- low flow at 10%

In [None]:
%%time
# compute long term annual cycle at daily step
ds_seasonal_flow = AutoVivification()
for gcm in sim_names:
    for scen in ds_qsim_selected[gcm].keys():
        for period in scens[scen]['period']:
            ds_seasonal_flow[gcm][scen][period] = ds_qsim_selected[gcm][scen]['streamflow'].sel(time=periods[period]['time']).groupby('time.dayofyear').mean()

ds_seasonal_flow['obs'] = ds_nrni_selected['streamflow'].sel(time=periods['control']['time']).groupby('time.dayofyear').mean()

In [None]:
%%time
ds_flow_metrics = AutoVivification()

ds_flow_metrics['annual_max']['obs']      = metrics.annual_max(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']).rolling(time=7, center=True).mean())
ds_flow_metrics['annual_min']['obs']      = metrics.annual_min(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']).rolling(time=7, center=True).mean())
ds_flow_metrics['ctr']['obs']             = metrics.annual_centroid(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['BFI']['obs']             = metrics.BFI(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['FMS']['obs']             = metrics.FMS(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['FHV']['obs']             = metrics.FHV(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']),percent=0.98)
ds_flow_metrics['FLV']['obs']             = metrics.FLV(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['high_q_freq_dur']['obs'] = metrics.high_q_freq_dur(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['low_q_freq_dur']['obs']  = metrics.low_q_freq_dur(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['season_mean']['obs']     = metrics.season_mean(ds_nrni_selected['streamflow'].sel(time=periods['control']['time']))
ds_flow_metrics['annual_mean']['obs']     = ds_nrni_selected['streamflow'].sel(time=periods['control']['time']).resample(time='YS').mean('time').to_dataset(name='annual_mean')
ds_flow_metrics['annual_mean']['obs']     = ds_flow_metrics['annual_mean']['obs'].assign_coords(time=ds_flow_metrics['annual_mean']['obs']['time'].dt.year).rename({'time': 'year'})

for gcm in sim_names:
    for scen in ds_qsim_selected[gcm].keys():
        for period in scens[scen]['period']:
            ds1 = ds_qsim_selected[gcm][scen]['streamflow'].sel(time=periods[period]['time'])
            
            ds_flow_metrics['annual_max'][gcm][scen][period]      = metrics.annual_max(ds1.rolling(time=7, center=True).mean())
            ds_flow_metrics['annual_min'][gcm][scen][period]      = metrics.annual_min(ds1.rolling(time=7, center=True).mean())
            ds_flow_metrics['ctr'][gcm][scen][period]             = metrics.annual_centroid(ds1)
            ds_flow_metrics['BFI'][gcm][scen][period]             = metrics.BFI(ds1)
            ds_flow_metrics['FMS'][gcm][scen][period]             = metrics.FMS(ds1)
            ds_flow_metrics['FHV'][gcm][scen][period]             = metrics.FHV(ds1,percent=0.98)
            ds_flow_metrics['FLV'][gcm][scen][period]             = metrics.FLV(ds1)
            ds_flow_metrics['high_q_freq_dur'][gcm][scen][period] = metrics.high_q_freq_dur(ds1)
            ds_flow_metrics['low_q_freq_dur'][gcm][scen][period]  = metrics.low_q_freq_dur(ds1)
            ds_flow_metrics['season_mean'][gcm][scen][period]     = metrics.season_mean(ds1)
            ds_flow_metrics['annual_mean'][gcm][scen][period]     = ds1.resample(time='YS').mean('time').to_dataset(name='annual_mean')
            ds_flow_metrics['annual_mean'][gcm][scen][period]     = ds_flow_metrics['annual_mean'][gcm][scen][period].assign_coords(time=ds_flow_metrics['annual_mean'][gcm][scen][period]['time'].dt.year).rename({'time': 'year'})
        print(f'{gcm}-{scen}')

In [None]:
%%time
# annual flood frequency for two longer periods
for gcm in sim_names:
    for scen in ds_qsim_selected[gcm].keys():
        if scen=='hist':
            ds1 = ds_qsim_selected[gcm][scen]['streamflow'].sel(time=slice('1954-10-01', '2004-09-30'))
        else:
            ds1 = ds_qsim_selected[gcm][scen]['streamflow'].sel(time=slice('2049-10-01', '2099-09-30'))
        ds_flow_metrics['aff'][gcm][scen] = metrics.lp3_flood(ds1)

---
## 5. Error in metrics during control period

- bias at each site
- seasonal bias at each site

In [None]:
%%time

error_metric = {
                'annual_centroid': np.zeros((num_site, len(sim_names))),
                'annual_min_day' : np.zeros((num_site, len(sim_names))),
                'annual_max_day' : np.zeros((num_site, len(sim_names))),
                'annual_min_flow': np.zeros((num_site, len(sim_names))),
                'annual_max_flow': np.zeros((num_site, len(sim_names))),
                'pbiasFHV'       : np.zeros((num_site, len(sim_names))),
                'pbiasFLV'       : np.zeros((num_site, len(sim_names))),
                'mean_high_q_dur': np.zeros((num_site, len(sim_names))),
                'freq_high_q'    : np.zeros((num_site, len(sim_names))),
                'alpha'          : np.zeros((num_site, len(sim_names))),
                'beta'           : np.zeros((num_site, len(sim_names))),
                'corr_seas'      : np.zeros((num_site, len(sim_names))),
                'corr'           : np.zeros((num_site, len(sim_names))),
                'kge'            : np.zeros((num_site, len(sim_names))),
                'pbias_djf'      : np.zeros((num_site, len(sim_names))),
                'pbias_mam'      : np.zeros((num_site, len(sim_names))),
                'pbias_jja'      : np.zeros((num_site, len(sim_names))),
                'pbias_son'      : np.zeros((num_site, len(sim_names))), 
               }

for r, site in enumerate(common_site):
    # nrni
    sr_obs = ds_nrni_selected.sel(time=periods['control']['time'], site=site)['streamflow'].values
    sr_obs = np.where(sr_obs<0.0, 1.0e-7,sr_obs)

    sr_seas_obs = ds_seasonal_flow['obs'].sel(site=site).values
    
    sr_obs_centroid = ds_flow_metrics['ctr']['obs'].sel(site=site)['ann_centroid_day'].values
    sr_obs_max_day  = ds_flow_metrics['annual_max']['obs'].sel(site=site)['ann_max_day'].values
    sr_obs_max_flow = ds_flow_metrics['annual_max']['obs'].sel(site=site)['ann_max_flow'].values
    sr_obs_min_day  = ds_flow_metrics['annual_min']['obs'].sel(site=site)['ann_min_day'].values
    sr_obs_min_flow = ds_flow_metrics['annual_min']['obs'].sel(site=site)['ann_min_flow'].values
    sr_obs_mean_high_q_dur = ds_flow_metrics['high_q_freq_dur']['obs'].sel(site=site)['mean_high_q_dur'].values
    sr_obs_freq_high_q     = ds_flow_metrics['high_q_freq_dur']['obs'].sel(site=site)['freq_high_q'].values
    sr_obs_FHV      = ds_flow_metrics['FHV']['obs'].sel(site=site)['FHV'].values
    sr_obs_FLV      = ds_flow_metrics['FLV']['obs'].sel(site=site)['FLV'].values
    sr_obs_djf      = ds_flow_metrics['season_mean']['obs'].sel(site=site, season='DJF').values
    sr_obs_mam      = ds_flow_metrics['season_mean']['obs'].sel(site=site, season='MAM').values
    sr_obs_jja      = ds_flow_metrics['season_mean']['obs'].sel(site=site, season='JJA').values
    sr_obs_son      = ds_flow_metrics['season_mean']['obs'].sel(site=site, season='SON').values
    
    for c, sim_name in enumerate(sim_names): # sim_names is merge of retro and gcms
       
        # simulated flow series
        sr_sim = ds_qsim_selected[sim_name]['hist'].sel(time=periods['control']['time'], site=site)['streamflow'].values
        sr_sim = np.where(sr_sim<0.0, 1.0e-7,sr_sim)
        
        sr_seas_sim = ds_seasonal_flow[sim_name]['hist']['control'].sel(site=site).values
        
        sr_centroid = ds_flow_metrics['ctr'][sim_name]['hist']['control'].sel(site=site)['ann_centroid_day'].values
        sr_max_day  = ds_flow_metrics['annual_max'][sim_name]['hist']['control'].sel(site=site)['ann_max_day'].values
        sr_max_flow = ds_flow_metrics['annual_max'][sim_name]['hist']['control'].sel(site=site)['ann_max_flow'].values
        sr_min_day  = ds_flow_metrics['annual_min'][sim_name]['hist']['control'].sel(site=site)['ann_min_day'].values
        sr_min_flow = ds_flow_metrics['annual_min'][sim_name]['hist']['control'].sel(site=site)['ann_min_flow'].values
        sr_mean_high_q_dur = ds_flow_metrics['high_q_freq_dur'][sim_name]['hist']['control'].sel(site=site)['mean_high_q_dur'].values
        sr_freq_high_q     = ds_flow_metrics['high_q_freq_dur'][sim_name]['hist']['control'].sel(site=site)['freq_high_q'].values
        sr_FHV      = ds_flow_metrics['FHV'][sim_name]['hist']['control'].sel(site=site)['FHV'].values
        sr_FLV      = ds_flow_metrics['FLV'][sim_name]['hist']['control'].sel(site=site)['FLV'].values
        sr_djf      = ds_flow_metrics['season_mean'][sim_name]['hist']['control'].sel(site=site, season='DJF').values
        sr_mam      = ds_flow_metrics['season_mean'][sim_name]['hist']['control'].sel(site=site, season='MAM').values
        sr_jja      = ds_flow_metrics['season_mean'][sim_name]['hist']['control'].sel(site=site, season='JJA').values
        sr_son      = ds_flow_metrics['season_mean'][sim_name]['hist']['control'].sel(site=site, season='SON').values
        
        # compute error in flow metrics
        error_metric['annual_centroid'][r,c] = metrics.bias(sr_obs_centroid, sr_centroid)
        error_metric['annual_max_day'][r,c]  = metrics.bias(sr_obs_max_day, sr_max_day)
        error_metric['annual_max_flow'][r,c] = metrics.pbias(sr_obs_max_flow, sr_max_flow)*100
        error_metric['annual_min_day'][r,c]  = metrics.bias(sr_obs_min_day, sr_min_day)
        error_metric['annual_min_flow'][r,c] = metrics.pbias(sr_obs_min_flow, sr_min_flow)*100
        error_metric['mean_high_q_dur'][r,c] = metrics.bias(sr_obs_mean_high_q_dur, sr_mean_high_q_dur)
        error_metric['freq_high_q'][r,c]     = metrics.bias(sr_obs_freq_high_q, sr_freq_high_q)
        error_metric['pbiasFHV'][r,c]        = metrics.pbias(sr_obs_FHV, sr_FHV)*100
        error_metric['pbiasFLV'][r,c]        = metrics.pbias(sr_obs_FLV, sr_FLV)*100
        
        error_metric['alpha'][r,c] = metrics.alpha(sr_obs, sr_sim)
        error_metric['beta'][r,c]  = metrics.beta(sr_obs, sr_sim)
        error_metric['corr'][r,c]  = metrics.corr(sr_obs, sr_sim)
        error_metric['kge'][r,c]   = metrics.kge(sr_obs, sr_sim)

        error_metric['corr_seas'][r,c]  = metrics.corr(sr_seas_obs, sr_seas_sim)
        
        error_metric['pbias_djf'][r,c]  = metrics.pbias(sr_obs_djf, sr_djf)*100
        error_metric['pbias_mam'][r,c]  = metrics.pbias(sr_obs_mam, sr_mam)*100
        error_metric['pbias_jja'][r,c]  = metrics.pbias(sr_obs_jja, sr_jja)*100
        error_metric['pbias_son'][r,c]  = metrics.pbias(sr_obs_son, sr_son)*100

---
## 6. GCM analysis during control period

- bias at each site
- seasonal bias at each site

### Fig1. Maps of seasonal bias

In [None]:
%%time

%matplotlib agg
plt.rcParams.update({'figure.max_open_warning': 0})

mpl.rcParams['xtick.labelsize'] = 7 
mpl.rcParams['ytick.labelsize'] = 7 
mpl.rcParams['axes.labelsize'] = 8 
mpl.rcParams['axes.titlesize'] = 8 

ncols=2
nrows=2

for case in sim_names: # sim_names is merge of retro and gcms
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5.5, 5), subplot_kw={"projection": ccrs.PlateCarree()}, dpi=100,)
    fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.90, wspace=0.10, hspace=0.125)

    for ix, season in enumerate(['djf','mam','jja','son']):
        row = ix // ncols
        col = ix % ncols
    
        df_stat = pd.DataFrame(data=error_metric[f'pbias_{season}'], index=common_site, columns=sim_names)
        df_stat.reset_index(level=0, inplace=True)
        df_stat.rename(columns={'index':'location_name'},inplace=True)
        df_stat_final = df_site_selected.merge(df_stat, on="location_name", how = 'inner')
        
        base_map(ax[row,col],df_huc12)
        ax1 = df_stat_final.plot(ax=ax[row,col], column=case, markersize=15, cmap=ccmap.cmap_bias2, norm=ccmap.norm_bias2, legend=True, zorder=2, 
                           legend_kwds={'extend':'both', 'pad':0.02, 'shrink': 0.90});
        ax[row,col].set_title(f'{season}', fontsize=8)
        
    fig.suptitle(f'{case} seasonal %bias [%]', fontsize=9, y=0.975)
    fig.savefig(os.path.join(figure_path, f'./Fig1_seasonal_pbias_map_{case}.png'), dpi=200)
    break

### Fig2. Maps of error - high flow frequency [count] and duration [days]

In [None]:
%%time

%matplotlib agg
plt.rcParams.update({'figure.max_open_warning': 0})

mpl.rcParams['xtick.labelsize'] = 7 
mpl.rcParams['ytick.labelsize'] = 7 
mpl.rcParams['axes.labelsize'] = 8 
mpl.rcParams['axes.titlesize'] = 8

metric_list = {
    'freq_high_q':{'unit':'-','norm':ccmap.norm_freq_high_q, 'cmap_diff':ccmap.cmap_freq_high_q_diff, 'norm_diff':ccmap.norm_freq_high_q_diff}, 
    'mean_high_q_dur':{'unit':'day','norm':ccmap.norm_freq_high_dur,'cmap_diff':ccmap.cmap_freq_high_dur_diff, 'norm_diff':ccmap.norm_freq_high_dur_diff},
}

for case in sim_names: # sim_names is merge of retro and gcms
    fig, axs = plt.subplots(ncols=len(metric_list), nrows=2, figsize=(6.0, 5.0), subplot_kw={"projection": ccrs.PlateCarree()}, dpi=150,)
    fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.90, wspace=0.10, hspace=0.125)
    
    for ix, (metric_name, meta) in enumerate(metric_list.items()):
        #obs map - top rows
        df_obs = ds_flow_metrics['high_q_freq_dur']['obs'][metric_name].mean(dim='year').to_dataframe()
        df_obs.index.rename('location_name',inplace=True)
        df_obs_final = df_site_selected.merge(df_obs, on="location_name", how = 'inner')

        base_map(axs[0,ix],df_huc12)
        df_obs_final.plot(ax=axs[0,ix], column=metric_name, markersize=15, cmap='turbo', norm=meta['norm'], zorder=2, 
                          legend=True, legend_kwds={'extend':'max', 'pad':0.02, 'shrink': 0.85});
        axs[0,ix].text(0.5, 1.125, '%s [%s]'%(metric_name, meta['unit']), transform=axs[0,ix].transAxes, ha='center', fontsize=9)
        axs[0,ix].text(0.5, 1.025, 'Reference flow', transform=axs[0,ix].transAxes, ha='center', fontsize=8)

        base_map(axs[1,ix],df_huc12)
        df_err = pd.DataFrame(data=error_metric[metric_name], index=common_site, columns=sim_names)
        df_err.reset_index(level=0, inplace=True)
        df_err.rename(columns={'index':'location_name'},inplace=True)
        df_err_final = df_site_selected.merge(df_err, on="location_name", how = 'inner')
        df_err_final.plot(ax=axs[1,ix], column=case, markersize=10, cmap=meta['cmap_diff'], norm=meta['norm_diff'], zorder=2,
                          legend=True, legend_kwds={'extend':'both', 'pad':0.02, 'shrink': 0.875});
        axs[1,ix].text(0.5, 1.025, 'error %s [%s]'%(case,meta['unit']), transform=axs[1,ix].transAxes, ha='center', fontsize=8)
    fig.savefig(os.path.join(figure_path, f'Fig2_high_q_event_map_{case}.png'), dpi=200)
    break

### Fig 3. Maps of timing error - centroid, annual_max_day, annual_min_day

In [None]:
%%time

%matplotlib agg
plt.rcParams.update({'figure.max_open_warning': 0})

mpl.rcParams['xtick.labelsize'] = 7 
mpl.rcParams['ytick.labelsize'] = 7 
mpl.rcParams['axes.labelsize'] = 8 
mpl.rcParams['axes.titlesize'] = 8

metric_list = {
    'annual_centroid': {'vname':'ann_centroid_day','unit':'day','cmap_diff':ccmap.cmap_centroid_diff, 'norm_diff':ccmap.norm_centroid_diff}, 
    'annual_max_day':  {'vname':'ann_max_day','unit':'day','cmap_diff':ccmap.cmap_max_day_diff, 'norm_diff':ccmap.norm_max_day_diff}, 
    'annual_min_day':  {'vname':'ann_min_day','unit':'day','cmap_diff':ccmap.cmap_min_day_diff, 'norm_diff':ccmap.norm_min_day_diff},
}

for case in sim_names: # sim_names is merge of retro and gcms
    fig, axs = plt.subplots(ncols=len(metric_list), nrows=2, figsize=(6.5, 4.5), subplot_kw={"projection": ccrs.PlateCarree()}, dpi=150,)
    fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.90, wspace=0.10, hspace=0.125)
    
    for ix, (metric_name, meta) in enumerate(metric_list.items()):
        if metric_name == 'annual_max_day':
            ds_plot = ds_flow_metrics['annual_max'].copy()
        elif metric_name == 'annual_min_day':
            ds_plot = ds_flow_metrics['annual_min'].copy()
        elif metric_name == 'annual_centroid':
            ds_plot = ds_flow_metrics['ctr'].copy()

        #obs map - top rows
        df_obs = ds_plot['obs'][meta['vname']].mean(dim='year').to_dataframe()
        df_obs.index.rename('location_name',inplace=True)
        df_obs_final = df_site_selected.merge(df_obs, on="location_name", how = 'inner')

        base_map(axs[0,ix],df_huc12)
        df_obs_final.plot(ax=axs[0,ix], column=meta['vname'], markersize=10, cmap='turbo', zorder=2, 
                          legend=True, legend_kwds={'extend':'neither', 'pad':0.02, 'shrink': 0.825});
        axs[0,ix].text(0.5, 1.125, '%s [%s]'%(metric_name, meta['unit']), transform=axs[0,ix].transAxes, ha='center', fontsize=9)
        axs[0,ix].text(0.5, 1.025, 'Reference flow', transform=axs[0,ix].transAxes, ha='center', fontsize=8)

        base_map(axs[1,ix],df_huc12)
        df_err = pd.DataFrame(data=error_metric[metric_name], index=common_site, columns=sim_names)
        df_err.reset_index(level=0, inplace=True)
        df_err.rename(columns={'index':'location_name'},inplace=True)
        df_err_final = df_site_selected.merge(df_err, on="location_name", how = 'inner')
        df_err_final.plot(ax=axs[1,ix], column=case, markersize=10, cmap=meta['cmap_diff'], norm=meta['norm_diff'], zorder=2,
                          legend=True, legend_kwds={'extend':'both', 'pad':0.02, 'shrink': 0.85});
        axs[1,ix].text(0.5, 1.025, 'error %s [%s]'%(case,meta['unit']), transform=axs[1,ix].transAxes, ha='center', fontsize=8)
    fig.savefig(os.path.join(figure_path, f'Fig3_timing_map_{case}.png'), dpi=200)

### Figs 4-6: Boxplots of errors - KGE and its components, and FLV and FHV during control period for all GCMs and GMET

In [None]:
# for paper
%matplotlib inline
removed_sim=['']
sim_included = [s for s in sim_names if s not in removed_sim]

patch_colors=[]
for sim in sims:
    if sim in gcm_runs:
        if gcm_runs[sim]['cmip']==5:
            c='b'
        elif gcm_runs[sim]['cmip']==6:
            c='b'
    if sim in retro_runs:
        c='k'
    patch_colors.append(c)

mpl.rcParams['xtick.labelsize'] = 7 
mpl.rcParams['ytick.labelsize'] = 7 
mpl.rcParams['axes.labelsize'] = 8 
mpl.rcParams['axes.titlesize'] = 8

error_metric_list = {
    'alpha':{'range':[0.5,1,1.5],'header':r'a) flow variability ratio ($\alpha$) [-]'},
    'beta':{'range':[0.5,1,1.5],'header':r'b) mean flow ratio ($\beta$) [-]'},
    'pbiasFHV':{'range':[-50,0,50],'header':'c) %bias Flow >98% in FDC [%]'}, 
    'pbiasFLV':{'range':[-100,0,100],'header':'d) %bias Flow <10% in FDC [%]'},
    'corr_seas':{'range':[0.5,0.5,1],'header':'e) seasonal correlation (r) [-]'},
}

fig, axs = plt.subplots(nrows=len(error_metric_list), ncols=1, figsize=(6.75, 7.5), dpi=100)
fig.subplots_adjust(left=0.075, bottom=0.175, right=0.965, top=0.95, wspace=0.10, hspace=0.3)

df_describe = {}
for ix, (error_name, meta) in enumerate(error_metric_list.items()):
    mask = (~np.isnan(error_metric[error_name]) & ~np.isinf(error_metric[error_name]))
    filtered_data = [d[m] for s, d, m in zip(sim_included, error_metric[error_name].T, mask.T) if s not in removed_sim]
    #axs[ix].violinplot(filtered_data, showextrema=True, showmedians=True)
    bplot = axs[ix].boxplot(filtered_data, showfliers=False, patch_artist=True)
    axs[ix].axhline(y=meta['range'][1], color='k', linestyle='--', lw=0.5)

    # fill with colors
    for patch, color in zip(bplot['boxes'], patch_colors):
        patch.set_facecolor(color)
    
    axs[ix].set_title(meta['header'])
    axs[ix].set_ylim([meta['range'][0], meta['range'][2]])
    axs[ix].set_xlabel('')
    if ix==len(error_metric_list)-1:
        axs[ix].set_xticklabels(sim_included, rotation=90);
    else:
        axs[ix].set_xticklabels('');
    df_describe[error_name] = pd.DataFrame({name:data for name, data in zip(sim_included, filtered_data)}).describe()
fig.savefig(os.path.join(figure_path, f'Fig4_gcm_error_summary_ppt_paper.png'), dpi=300)

In [None]:
# get summary statistics for each metrics
error_name = 'beta' # alpha, beta, pbiasFHV, pbiasFLV
df_describe[error_name]

In [None]:
# errors in flow timing metrics (annual centroid, day of annual max / min flow)

removed_sim=['']
sim_included = [s for s in sim_names if s not in removed_sim]
# Columns of the error_metric arrays and the entries of patch_colors follow
# sim_names, so kept simulations are selected by position; pairing them with
# sim_included by position would misalign once removed_sim dropped a simulation.
included_cols = [i for i, s in enumerate(sim_names) if s not in removed_sim]
box_colors = [patch_colors[i] for i in included_cols]

mpl.rcParams['xtick.labelsize'] = 6.5
mpl.rcParams['ytick.labelsize'] = 7
mpl.rcParams['axes.labelsize'] = 8
mpl.rcParams['axes.titlesize'] = 9

# per panel: y-axis [lower limit, reference line, upper limit] and panel title
error_metric_list = {
    'annual_centroid': {'range':[-50,0,50],'header':r'annual centroid [day]'},
    'annual_max_day':{'range':[-50,0,50],'header':r'annual max flow day [day]'},
    'annual_min_day':{'range':[-200,0,160],'header':r'annual min flow day [day]'},
}

fig, axs = plt.subplots(nrows=len(error_metric_list), ncols=1, figsize=(5.0, 6.0), dpi=100)
fig.subplots_adjust(left=0.075, bottom=0.225, right=0.975, top=0.95, wspace=0.10, hspace=0.25)

for ix, (error_name, meta) in enumerate(error_metric_list.items()):
    values = error_metric[error_name]
    # drop NaN and +/-inf separately for every simulation
    finite = np.isfinite(values)
    filtered_data = [values[finite[:, i], i] for i in included_cols]

    bplot = axs[ix].boxplot(filtered_data, showfliers=False, patch_artist=True)
    axs[ix].axhline(y=meta['range'][1], color='k', linestyle='--', lw=0.5)

    # fill with colors
    for patch, color in zip(bplot['boxes'], box_colors):
        patch.set_facecolor(color)

    axs[ix].set_title(meta['header'])
    axs[ix].set_ylim([meta['range'][0], meta['range'][2]])
    axs[ix].set_xlabel('')
    # simulation names only under the bottom panel
    if ix==len(error_metric_list)-1:
        axs[ix].set_xticklabels(sim_included, rotation=90);
    else:
        axs[ix].set_xticklabels('');
fig.savefig(os.path.join(figure_path, 'Fig5_gcm_timing_error_summary.png'), dpi=300)

In [None]:
# errors in event metrics (high-flow frequency and mean high-flow duration)

removed_sim=['']
sim_included = [s for s in sim_names if s not in removed_sim]
# Columns of the error_metric arrays and the entries of patch_colors follow
# sim_names, so kept simulations are selected by position; pairing them with
# sim_included by position would misalign once removed_sim dropped a simulation.
included_cols = [i for i, s in enumerate(sim_names) if s not in removed_sim]
box_colors = [patch_colors[i] for i in included_cols]

mpl.rcParams['xtick.labelsize'] = 6.5
mpl.rcParams['ytick.labelsize'] = 7
mpl.rcParams['axes.labelsize'] = 8
mpl.rcParams['axes.titlesize'] = 9

# per panel: y-axis [lower limit, reference line, upper limit] and panel title
metric_list = {'freq_high_q':    {'range':[-5,0,5],'header':r'annual high flow frequency'},
               'mean_high_q_dur':{'range':[-30,0,30],'header':r'mean high flow duration'}
              }

fig, axs = plt.subplots(nrows=len(metric_list), ncols=1, figsize=(6.5, 4.0), dpi=100)
fig.subplots_adjust(left=0.075, bottom=0.225, right=0.975, top=0.95, wspace=0.10, hspace=0.25)

for ix, (error_name, meta) in enumerate(metric_list.items()):
    values = error_metric[error_name]
    # drop NaN and +/-inf separately for every simulation
    finite = np.isfinite(values)
    filtered_data = [values[finite[:, i], i] for i in included_cols]

    bplot = axs[ix].boxplot(filtered_data, showfliers=False, patch_artist=True)
    axs[ix].axhline(y=meta['range'][1], color='k', linestyle='--', lw=0.5)

    # fill with colors
    for patch, color in zip(bplot['boxes'], box_colors):
        patch.set_facecolor(color)

    axs[ix].set_title(meta['header'])
    axs[ix].set_ylim([meta['range'][0], meta['range'][2]])
    axs[ix].set_xlabel('')
    # simulation names only under the bottom panel
    if ix==len(metric_list)-1:
        axs[ix].set_xticklabels(sim_included, rotation=90);
    else:
        axs[ix].set_xticklabels('');

fig.savefig(os.path.join(figure_path, 'Fig6_gcm_event_error_summary.png'), dpi=300)

-----
## 7.Changes in flow metrics for future periods compared to control period

###  Fig 7. Each GCM and scenario

In [None]:
ds_flow_metrics.keys()

In [None]:
from matplotlib.colors import LinearSegmentedColormap, Normalize


def _red_white_blue(white_position):
    """Return a red-white-blue colormap with white at *white_position* (0-1).

    Values beyond the norm limits are drawn in dark red (under) and dark blue (over).
    """
    anchors = [
        (0.0, 'xkcd:red'),
        (white_position, 'xkcd:white'),
        (1.0, 'xkcd:blue'),
    ]
    cmap = LinearSegmentedColormap.from_list('custom1', anchors, N=255)
    cmap.set_over('xkcd:dark blue')
    cmap.set_under('xkcd:dark red')
    return cmap


# absolute difference in annual mean flow: norm spans -100..1500, so zero
# (white) sits at 100/1600 of the color range
norm_mean_flow_diff = Normalize(vmin=-100, vmax=1500)
cmap_mean_flow_diff = _red_white_blue(100/1600)

# percent difference: norm spans -40..60, so zero (white) sits at 40/100
norm_mean_flow_pdiff = Normalize(vmin=-40, vmax=60)
cmap_mean_flow_pdiff = _red_white_blue(40/100)

# annual mean flow itself is shown on a log color scale
norm_mean_flow = mpl.colors.LogNorm(vmin=1, vmax=10000)

In [None]:
%%time
# For individual GCM, change due to ssp/rcp and periods

%matplotlib inline
# CMIP6: 'CanESM5' 'CMCC-CM2-SR5' 'MIROC-ES2L' 'NorESM2-MM'
# CMIP5: 'CanESM2' 'CMCC-CM' 'CNRM-CM5' 'MIROC5' 'MRI-CGCM3' 'CCSM4' 'GFDL-CM3'
# retro: 'gmet'

flow_metric  = 'FHV'   #annual_mean ann_centroid_day, ann_max_day, ann_max_flow, ann_min_day, ann_min_flow
gcm          = 'NorESM2-MM' 
percent_change = False

if flow_metric == 'annual_mean':
    ds_plot = ds_flow_metrics['annual_mean'].copy()
    unit='m3/s'
    norm_flow=norm_mean_flow
    if percent_change:
        cmap_diff=cmap_mean_flow_pdiff
        norm_diff=norm_mean_flow_pdiff
    else:
        cmap_diff=cmap_mean_flow_diff
        norm_diff=norm_mean_flow_diff
elif flow_metric == 'ann_max_day':
    ds_plot = ds_flow_metrics['annual_max'].copy()
    unit='day since 10/1'
    cmap_diff=ccmap.cmap_max_day_diff
    norm_diff=ccmap.norm_max_day_diff
elif flow_metric == 'ann_min_day':
    ds_plot = ds_flow_metrics['annual_min'].copy()
    unit='day since 10/1'
    cmap_diff=ccmap.cmap_min_day_diff
    norm_diff=ccmap.norm_min_day_diff
elif flow_metric == 'ann_centroid_day':
    ds_plot = ds_flow_metrics['ctr'].copy()
    unit='day since 10/1'
    cmap_diff=ccmap.cmap_centroid_diff
    norm_diff=ccmap.norm_centroid_diff
elif flow_metric == 'ann_max_flow':
    ds_plot = ds_flow_metrics['annual_max'].copy()
    unit='m3/s'
    norm_flow=ccmap.norm_max_flow
    if percent_change:
        cmap_diff=ccmap.cmap_max_flow_pdiff
        norm_diff=ccmap.norm_max_flow_pdiff
    else:
        cmap_diff=ccmap.cmap_max_flow_diff
        norm_diff=ccmap.norm_max_flow_diff
elif flow_metric == 'ann_min_flow':
    ds_plot = ds_flow_metrics['annual_min'].copy()
    unit='m3/s'
    cmap_diff=ccmap.cmap_min_flow_diff
    norm_diff=ccmap.norm_min_flow_diff
    norm_flow=ccmap.norm_min_flow
elif flow_metric == 'freq_high_q':  # annual high flow frequency
    ds_plot = ds_flow_metrics['high_q_freq_dur'].copy()
    unit='counts/yr'
    cmap_diff=ccmap.cmap_freq_high_q_diff
    norm_diff=ccmap.norm_freq_high_q_diff
    norm_flow=ccmap.norm_freq_high_q
elif flow_metric == 'mean_high_q_dur': # annual mean high flow duration 
    ds_plot = ds_flow_metrics['high_q_freq_dur'].copy()
    unit='days'
    cmap_diff=ccmap.cmap_freq_high_dur_diff
    norm_diff=ccmap.norm_freq_high_dur_diff
    norm_flow=ccmap.norm_freq_high_dur
elif flow_metric == 'FHV': # annual mean high flow duration
    ds_plot = ds_flow_metrics['FHV'].copy()
    unit='m3/s'
    norm_flow=ccmap.norm_max_flow
    if percent_change:
        cmap_diff=ccmap.cmap_max_flow_pdiff
        norm_diff=ccmap.norm_max_flow_pdiff
    else:
        cmap_diff=ccmap.cmap_max_flow_diff
        norm_diff=ccmap.norm_max_flow_diff

ncols = 3
nrows = 3
top=0.925
shrink=0.875
figsize=(6.5, 6.25)
if gcm_runs[gcm]['cmip'] == 5: 
    nrows = 2
    top=0.925
    shrink=0.725
    figsize=(6.5, 4.75)
cbar_kwrgs = {"orientation":"vertical", "shrink":shrink, "pad":0.02, 'extend':'both'}

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, dpi=100,)
fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=top, wspace=0.10, hspace=0.125)

for ix, period in enumerate(['control','2040s','2080s']):
    for jx, scen in enumerate([item for item in gcm_runs[gcm]['scen'] if item!='hist']):

        if jx>0 and ix==0:
            continue
            
        base_map(ax[jx,ix],df_huc12)

        if ix==0 and jx==0:
            if 'year' in ds_plot[gcm_name]['hist']['control'].variables:
                df_control = ds_plot[gcm]['hist']['control'].mean(dim='year').to_dataframe()
            else:
                df_control = ds_plot[gcm]['hist']['control'].to_dataframe()
            df_control.index.rename('location_name',inplace=True)
            df_control_final = df_site_selected.merge(df_control, on="location_name", how = 'inner')
 
            if flow_metric in ['ann_min_flow', 'ann_max_flow', 'FHV', 'annual_mean']:
                df_control_final.plot(ax=ax[jx,ix], column=flow_metric, markersize=5, cmap='turbo_r', norm=norm_flow, legend=False, zorder=2,)
            elif flow_metric in ['freq_high_q']:
                cbar_kwrgs['extend']='max'
                df_control_final.plot(ax=ax[jx,ix], column=flow_metric, markersize=5, cmap='turbo_r', norm=norm_flow, legend=False, zorder=2,)
            else:
                df_control_final.plot(ax=ax[jx,ix], column=flow_metric, markersize=5, cmap='turbo', legend=False, zorder=2,)
            ax[jx, ix].set_title(f'{period} [{unit}]', fontsize=8)           
            print(f'{period}')
        else:
            if 'year' in ds_plot[gcm_name]['hist']['control'].variables:
                if percent_change:  # use % change
                    df_change = (100*(ds_plot[gcm][scen][period].mean(dim='year') - ds_plot[gcm]['hist']['control'].mean(dim='year'))/ds_plot[gcm]['hist']['control'].mean(dim='year')).to_dataframe()
                else:
                    df_change = (ds_plot[gcm][scen][period].mean(dim='year') - ds_plot[gcm]['hist']['control'].mean(dim='year')).to_dataframe()
            else:
                if percent_change:  # use % change
                    df_change = (100*(ds_plot[gcm][scen][period] - ds_plot[gcm]['hist']['control'])/ds_plot[gcm]['hist']['control']).to_dataframe()
                else:
                    df_change = (ds_plot[gcm][scen][period] - ds_plot[gcm]['hist']['control']).to_dataframe()
            if flow_metric in ['annual_max_day', 'ann_min_day', 'ann_centroid_day']:
                unit='days'
            if percent_change:  # use % change
                unit='%'
            df_change.index.rename('location_name',inplace=True)
            df_change_final = df_site_selected.merge(df_change, on="location_name", how = 'inner')
            df_change_final.plot(ax=ax[jx,ix], column=flow_metric, markersize=5, cmap=cmap_diff, norm=norm_diff, legend=False, zorder=2,)       
            ax[jx, ix].set_title(f'{scen} {period} [{unit}]', fontsize=8)
            print(f'{scen}-{period}')
            
        points = ax[jx,ix].collections[-1]
        cbar = plt.colorbar(points, ax=ax[jx,ix], **cbar_kwrgs);
        cbar.ax.tick_params(labelsize=6) 

for kx in range(1,nrows):
    fig.delaxes(ax[kx][0])

fig.suptitle(f'{gcm} {flow_metric} change', fontsize=9, y=0.985);
if percent_change:
    fig.savefig(os.path.join(figure_path, f'Fig7_{flow_metric}_{gcm}_pcnt_change.png'), dpi=300)
else:
    fig.savefig(os.path.join(figure_path, f'Fig7_{flow_metric}_{gcm}_change.png'), dpi=300)

### Fig 8. Ensemble of GCMs and scenarios

In [None]:
# Display the ensemble definitions; the cells below read ensembles[name]['cmip']
# and ensembles[name]['scen'] to pick the GCMs and scenarios of an ensemble.
ensembles

In [None]:
%%time
# For change in flow metrics for ensembles of GCMs/scnearios 

%matplotlib inline

flow_metric   = 'ann_max_flow'   #annual_mean, ann_centroid_day, ann_max_day, ann_max_flow, ann_min_day, ann_min_flow, FHV
ensemble_name = 'high-emission' 
percent_change = True

if flow_metric == 'ann_max_day':
    ds_plot = ds_flow_metrics['annual_mean'].copy()
    unit='m3/s'; metric_name = 'Annual mean flow'
    norm_flow=norm_mean_flow
    if percent_change:
        cmap_diff=cmap_mean_flow_pdiff
        norm_diff=norm_mean_flow_pdiff
    else:
        cmap_diff=cmap_mean_flow_diff
        norm_diff=norm_mean_flow_diff
elif flow_metric == 'ann_max_day':
    ds_plot = ds_flow_metrics['annual_max'].copy()
    unit='day since 10/1'; metric_name = 'Annual maximum flow day'
    cmap_diff=ccmap.cmap_max_day_diff
    norm_diff=ccmap.norm_max_day_diff
elif flow_metric == 'ann_min_day':
    ds_plot = ds_flow_metrics['annual_min'].copy()
    unit='day since 10/1'; metric_name = 'Annual minimum flow day'
    cmap_diff=ccmap.cmap_min_day_diff
    norm_diff=ccmap.norm_min_day_diff
elif flow_metric == 'ann_centroid_day':
    ds_plot = ds_flow_metrics['ctr'].copy()
    unit='day since 10/1'; metric_name = 'Annual centroid'
    cmap_diff=ccmap.cmap_centroid_diff
    norm_diff=ccmap.norm_centroid_diff
elif flow_metric == 'ann_max_flow':
    ds_plot = ds_flow_metrics['annual_max'].copy()
    unit='m3/s'; metric_name = 'Annual maximum flow'
    if percent_change:
        cmap_diff=cmap_mean_flow_pdiff #ccmap.cmap_max_flow_pdiff
        norm_diff=norm_mean_flow_pdiff #ccmap.norm_max_flow_pdiff
        norm_flow=ccmap.norm_max_flow
    else:
        cmap_diff=ccmap.cmap_max_flow_diff
        norm_diff=ccmap.norm_max_flow_diff
        norm_flow=ccmap.norm_max_flow
elif flow_metric == 'ann_min_flow':
    ds_plot = ds_flow_metrics['annual_min'].copy()
    unit='m3/s'; metric_name = 'Annual minimum flow'
    cmap_diff=ccmap.cmap_min_flow_diff
    norm_diff=ccmap.norm_min_flow_diff
    norm_flow=ccmap.norm_min_flow
elif flow_metric == 'freq_high_q':  # annual high flow frequency
    ds_plot = ds_flow_metrics['high_q_freq_dur'].copy()
    unit='counts/yr'; metric_name = 'Annual high flow frequency'
    cmap_diff=ccmap.cmap_freq_high_q_diff
    norm_diff=ccmap.norm_freq_high_q_diff
    norm_flow=ccmap.norm_freq_high_q
elif flow_metric == 'mean_high_q_dur': # annual mean high flow duration 
    ds_plot = ds_flow_metrics['high_q_freq_dur'].copy()
    unit='days'; metric_name = 'Annual mean high flow duration'
    cmap_diff=ccmap.cmap_freq_high_dur_diff
    norm_diff=ccmap.norm_freq_high_dur_diff
    norm_flow=ccmap.norm_freq_high_dur
elif flow_metric == 'FHV': # annual mean high flow duration
    ds_plot = ds_flow_metrics['FHV'].copy() 
    unit='m3/s'; metric_name = 'Flow>90% in FDC'
    if percent_change:
        cmap_diff=ccmap.cmap_max_flow_pdiff
        norm_diff=ccmap.norm_max_flow_pdiff
        norm_flow=ccmap.norm_max_flow
    else:
        cmap_diff=ccmap.cmap_max_flow_diff
        norm_diff=ccmap.norm_max_flow_diff
        norm_flow=ccmap.norm_max_flow        
    
# -------------------
print('Computing ensemble mean.....')
# -------------------
gcm_plots = {}
for gcm, meta in gcm_runs.items():
    if meta['cmip'] in ensembles[ensemble_name]['cmip']:
        gcm_plots[gcm] = [scen for scen in meta['scen'] if scen in ensembles[ensemble_name]['scen']]
print(gcm_plots)

dr_metric_cat = {}
count = 0 
for gcm_name, scen_list in gcm_plots.items():
    if count==0:
        dr_metric_cat['control'] = ds_plot[gcm_name]['hist']['control'][flow_metric]
    else:
        dr_metric_cat['control'] =  xr.concat([dr_metric_cat['control'], ds_plot[gcm_name]['hist']['control'][flow_metric]], "gcm")
    count+=1

count = 0 
for gcm_name, scen_list in gcm_plots.items():
    for scen in scen_list:
        if scen=='hist':
            continue
        if count==0:
            dr_metric_cat['2040s']  = ds_plot[gcm_name][scen]['2040s'][flow_metric]
            dr_metric_cat['2080s']  = ds_plot[gcm_name][scen]['2080s'][flow_metric]
        else:
            dr_metric_cat['2040s']  = xr.concat([dr_metric_cat['2040s'], ds_plot[gcm_name][scen]['2040s'][flow_metric]], "gcm")
            dr_metric_cat['2080s']  = xr.concat([dr_metric_cat['2080s'], ds_plot[gcm_name][scen]['2080s'][flow_metric]], "gcm")
        count+=1

# -------------------
print('plotting.....')
# -------------------
ncols = 3
nrows = 1
shrink=0.65
figsize=(6.25, 2.5)

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, dpi=150,)
fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.965, wspace=0.10, hspace=0.125)

for ix, period in enumerate(['control','2040s','2080s']):
    
    cbar_kwrgs = {"orientation":"vertical", "shrink":shrink, "pad":0.02, 'extend':'both'}
    
    print(f'{ix}. {period}')
    
    base_map(ax[ix],df_huc12)

    if period=='control':
        if 'year' in dr_metric_cat[period].coords:
            df_control = dr_metric_cat[period].mean(dim=['year','gcm']).to_dataframe()
            #hist_test = ds_metric_cat[period].sel(year=slice('1980','2005'))[flow_metric] # array for student t-test for mean difference
            hist_test = dr_metric_cat[period].mean(dim='year') # array for student t-test for mean difference
        else:
            df_control = dr_metric_cat[period].mean(dim='gcm').to_dataframe()
            hist_test = dr_metric_cat[period]
        df_control.index.rename('location_name',inplace=True)
        df_control_final = df_site_selected.merge(df_control, on="location_name", how = 'inner')
        if flow_metric in ['ann_min_flow', 'ann_max_flow', 'FHV', 'annual_mean']:
            df_control_final.plot(ax=ax[ix], column=flow_metric, markersize=3, cmap='turbo_r', norm=norm_flow, legend=False, zorder=2,)
        elif flow_metric in ['freq_high_q']:
            cbar_kwrgs['extend']='max'
            df_control_final.plot(ax=ax[ix], column=flow_metric, markersize=2, cmap='turbo_r', norm=norm_flow, legend=False, zorder=2,)
        else:
            df_control_final.plot(ax=ax[ix], column=flow_metric, markersize=2, cmap='turbo', legend=False, zorder=2,)
        ax[ix].set_title(f"{periods[period]['name']}", fontsize=8)
        if flow_metric=='ann_centroid_day':
            ax[ix].text(0.5, 1.145, f'{metric_name}', horizontalalignment='center', transform=ax[0].transAxes, fontsize=9.5)
        else:
            ax[ix].text(0.6, 1.145, f'{metric_name} [{unit}]', horizontalalignment='center', transform=ax[0].transAxes, fontsize=9.5)
    else:
        if percent_change:  # use % change
            if 'year' in dr_metric_cat[period].coords:
                dr_change = 100*(dr_metric_cat[period].mean(dim=['year','gcm']) - dr_metric_cat['control'].mean(dim=['year','gcm']))/dr_metric_cat['control'].mean(dim=['year','gcm'])
            else:
                dr_change = 100*(dr_metric_cat[period].mean(dim='gcm') - dr_metric_cat['control'].mean(dim='gcm'))/dr_metric_cat['control'].mean(dim='gcm')
            unit='%'
        else:
            if 'year' in dr_metric_cat[period].coords:
                dr_change = dr_metric_cat[period].mean(dim=['year','gcm']) - dr_metric_cat['control'].mean(dim=['year','gcm'])
            else:
                dr_change = dr_metric_cat[period].mean(dim='gcm') - dr_metric_cat['control'].mean(dim='gcm')
            if flow_metric in ['annual_max_day', 'ann_min_day', 'ann_centroid_day']:
                unit='days'

        # statistical test for mean difference between two periods
        if 'year' in dr_metric_cat[period].coords: 
            fut  = dr_metric_cat[period].mean(dim='year')
        else:
            fut  = dr_metric_cat[period]
        pvalue = stats.ttest_ind(hist_test,fut, axis=0, equal_var=False).pvalue
        
        df_change = dr_change.to_dataframe()
        df_change.index.rename('location_name',inplace=True)
        df_change['pvalue'] = pvalue
        df_change_final = df_site_selected.merge(df_change, on="location_name", how = 'inner')

        if len(df_change_final[df_change_final['pvalue']>0.05])>0:
            df_change_final[df_change_final['pvalue']>0.05].plot(ax=ax[ix], column=flow_metric, marker='^', markersize=4, edgecolor="black", linewidth=0.1, cmap=cmap_diff, norm=norm_diff, legend=False, zorder=2,)
        if len(df_change_final[df_change_final['pvalue']<0.05])>0:  
            df_change_final[df_change_final['pvalue']<0.05].plot(ax=ax[ix], column=flow_metric, markersize=4, edgecolor="black", linewidth=0.1, cmap=cmap_diff, norm=norm_diff, legend=False, zorder=2,)
        ax[ix].set_title(f"{periods[period]['name']}", fontsize=8)
        if ix==1:
            if percent_change:
                ax[1].text(0.5, 1.145, f'%change in {metric_name} [{unit}]', transform=ax[1].transAxes, fontsize=9.5)
            else:
                ax[1].text(0.5, 1.145, f'Change in {metric_name} [{unit}]', transform=ax[1].transAxes, fontsize=9.5)
    points = ax[ix].collections[-1]
    cbar = plt.colorbar(points, ax=ax[ix], **cbar_kwrgs);
    if flow_metric=='ann_centroid_day' and period=='control':
        cbar.set_ticks([124, 152, 183, 213, 244])
        cbar.set_ticklabels(['2/1','3/1','4/1','5/1','6/1'])    
    cbar.ax.tick_params(labelsize=6)
    
if percent_change:
    fig.savefig(os.path.join(figure_path,f'Fig8_{flow_metric}_{ensemble_name}_ensemble_pcnt_change.png'), dpi=300)
else:
    fig.savefig(os.path.join(figure_path,f'Fig8_{flow_metric}_{ensemble_name}_ensemble_change.png'), dpi=300)

### Fig 9. Changes in annual flow magnitude for a return period compared to control period

In [None]:
# Percent-change color scale for peak flow: the norm spans -20..80, so zero
# (white) sits at 20/100 of the color range; out-of-range values keep the end colors.
norm_max_flow_pdiff = mpl.colors.Normalize(vmin=-20, vmax=80)

cmap_max_flow_pdiff = LinearSegmentedColormap.from_list(
    name='custom1',
    colors=[
        (0.0, 'xkcd:red'),
        (20/100, 'xkcd:white'),
        (1.0, 'xkcd:blue'),
    ],
    N=255,
)
cmap_max_flow_pdiff.set_under('xkcd:red')
cmap_max_flow_pdiff.set_over('xkcd:blue')

In [None]:
%%time

%matplotlib inline

flow_metric   = 'aff'   # only aff
return_period = 20
ensemble_name = 'high-emission' 
percent_change = True

ds_plot = ds_flow_metrics[flow_metric].copy()
unit='m3/s'; metric_name = f'{return_period}-yr annual peak flow'
if percent_change:
    cmap_diff=ccmap.cmap_max_flow_pdiff
    norm_diff=ccmap.norm_max_flow_pdiff
    norm_flow=ccmap.norm_max_flow
else:
    cmap_diff=ccmap.cmap_max_flow_diff
    norm_diff=ccmap.norm_max_flow_diff
    norm_flow=ccmap.norm_max_flow
    
# -------------------
print('Computing ensemble mean.....')
# -------------------
gcm_plots = {}
for gcm, meta in gcm_runs.items():
    if meta['cmip'] in ensembles[ensemble_name]['cmip']:
        gcm_plots[gcm] = [scen for scen in meta['scen'] if scen in ensembles[ensemble_name]['scen']]
print(gcm_plots)

ds_metric_cat = {}
count = 0 
for gcm_name, scen_list in gcm_plots.items():
    if count==0:
        ds_metric_cat['hist'] = ds_plot[gcm_name]['hist'].sel(return_period=return_period, method='nearest')
    else:
        ds_metric_cat['hist'] = xr.concat([ds_metric_cat['hist'], ds_plot[gcm_name]['hist'].sel(return_period=return_period, method='nearest')], "gcm")
    count+=1

count = 0 
for gcm_name, scen_list in gcm_plots.items():
    for scen in scen_list:
        if scen=='hist':
            continue
        if count==0:
            ds_metric_cat['2050-2099']  = ds_plot[gcm_name][scen].sel(return_period=return_period, method='nearest')
        else:
            ds_metric_cat['2050-2099']  = xr.concat([ds_metric_cat['2050-2099'], ds_plot[gcm_name][scen].sel(return_period=return_period, method='nearest')], "gcm")
        count+=1

# -------------------
print('plotting.....')
# -------------------
ncols = 2
nrows = 1
shrink=0.85
figsize=(6.0, 3.00)

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, dpi=150,)
fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.90, wspace=0.05, hspace=0.125)

for ix, period in enumerate(['hist','2050-2099']):
    
    cbar_kwrgs = {"orientation":"vertical", "shrink":shrink, "pad":0.02, 'extend':'both'}
    
    print(f'{ix}. {period}')
    
    base_map(ax[ix],df_huc12)

    if period=='hist':
        hist_test = ds_metric_cat[period][flow_metric]
        df_hist = ds_metric_cat[period].mean(dim='gcm').to_dataframe()
        df_hist.index.rename('location_name',inplace=True)
        df_hist_final = df_site_selected.merge(df_hist, on="location_name", how = 'inner')
        df_hist_final.plot(ax=ax[ix], column=flow_metric, markersize=3, cmap='turbo_r', norm=norm_flow, legend=False, zorder=2,)
        cbar_kwrgs['extend']='max'
        ax[ix].set_title(f'{period}', fontsize=7)
        ax[ix].text(0.075, 1.1, f'{metric_name} [{unit}]', transform=ax[0].transAxes, fontsize=8.0)
    else:
        if percent_change:  # use % change
            ds_change = 100*(ds_metric_cat[period].mean(dim='gcm') - ds_metric_cat['hist'].mean(dim='gcm'))/ds_metric_cat['hist'].mean(dim='gcm')
            unit='%'
        else:
            ds_change = ds_metric_cat[period].mean(dim='gcm') - ds_metric_cat['hist'].mean(dim='gcm')

        # statistical test for mean difference between two periods
        fut  = ds_metric_cat[period][flow_metric]
        pvalue = stats.ttest_ind(hist_test, fut, axis=0, equal_var=False).pvalue
         
        df_change = ds_change.to_dataframe()
        df_change.index.rename('location_name',inplace=True)
        df_change['pvalue'] = pvalue
        df_change_final = df_site_selected.merge(df_change, on="location_name", how = 'inner')

        if len(df_change_final[df_change_final['pvalue']>0.05])>0:
            df_change_final[df_change_final['pvalue']>0.05].plot(ax=ax[ix], column=flow_metric, marker='^', markersize=5, edgecolor="black", linewidth=0.1,cmap=cmap_diff, norm=norm_diff, legend=False, zorder=2,)
        if len(df_change_final[df_change_final['pvalue']<0.05])>0:  
            df_change_final[df_change_final['pvalue']<0.05].plot(ax=ax[ix], column=flow_metric, markersize=5, edgecolor="black", linewidth=0.1, cmap=cmap_diff, norm=norm_diff, legend=False, zorder=2,)
        ax[ix].set_title(f'{period}', fontsize=7)
        if ix==1:
            if percent_change:
                ax[1].text(0.0, 1.1, f'%change in {metric_name} [{unit}]', transform=ax[1].transAxes, fontsize=8.0)
            else:
                ax[1].text(0.0, 1.1, f'Change in {metric_name} [{unit}]', transform=ax[1].transAxes, fontsize=8.0)
    points = ax[ix].collections[-1]
    cbar = plt.colorbar(points, ax=ax[ix], **cbar_kwrgs);
    cbar.ax.tick_params(labelsize=6) 
    
if percent_change:
    fig.savefig(os.path.join(figure_path,f'Fig9_{flow_metric}_{ensemble_name}_ensemble_pcnt_change.png'), dpi=300)
else:
    fig.savefig(os.path.join(figure_path,f'Fig9_{flow_metric}_{ensemble_name}_ensemble_change.png'), dpi=300)

### Fig 10. Changes in a future return period of historical flow magnitude

In [None]:
# For every simulation and future scenario, evaluate the historical 20-yr flow
# against the WY2050-2099 streamflow and express the result as a return period.
ds_future_aff = AutoVivification()
for gcm in sim_names:
    for scen in ds_qsim_selected[gcm].keys():
        if scen != 'hist':
            print(f'{gcm}-{scen}')
            # future streamflow, 2049-10-01 .. 2099-09-30
            ds1 = ds_qsim_selected[gcm][scen]['streamflow'].sel(time=slice('2049-10-01', '2099-09-30'))
            # 'aff' of the hist run at the return period nearest to 20 years
            hist_flow_20yr = ds_flow_metrics['aff'][gcm]['hist'].sel(return_period=20, method='nearest')['aff']
            hist_flow_20yr = hist_flow_20yr.drop_vars('return_period')
            # the result carries a 'cdf' entry; 1/(1-cdf) turns it into a return period
            # presumably cdf is a non-exceedance probability -- verify in scripts.metrics
            future_cdf = metrics.lp3_flood_return_period(ds1, hist_flow_20yr)
            ds_future_aff[gcm][scen] = (1.0/(1.0-future_cdf)).rename({'cdf':'return_period'})

In [None]:
# Fig 10: map of the ensemble-median future return period stored in ds_future_aff,
# with an inset histogram of the per-site values.
import matplotlib.ticker as ticker

# discrete color scale with boundaries at 1, 2, ..., 16
vals0 = np.arange(1, 17, 1)
cmap0 = plt.get_cmap('turbo', (len(vals0)))
norm0 = mpl.colors.BoundaryNorm(vals0, cmap0.N)

ensemble_name = 'high-emission'

# -------------------
print('Computing ensemble mean.....')
# -------------------
# GCM name -> scenarios of that GCM that belong to the selected ensemble
gcm_plots = {}
for gcm, meta in gcm_runs.items():
    if meta['cmip'] in ensembles[ensemble_name]['cmip']:
        gcm_plots[gcm] = [scen for scen in meta['scen'] if scen in ensembles[ensemble_name]['scen']]
print(gcm_plots)

# Collect the future members and stack them once along a new 'gcm' dimension.
# Stacking once also gives a one-member ensemble the 'gcm' dimension that the
# median below reduces over.
future_members = []
for gcm_name, scen_list in gcm_plots.items():
    for scen in scen_list:
        if scen == 'hist':
            continue
        future_members.append(ds_future_aff[gcm_name][scen])
ds_metric_cat = {'2050-2099': xr.concat(future_members, "gcm")}

# -------------------
print('plotting.....')
# -------------------
ncols = 1
nrows = 1
shrink = 0.75
figsize = (4.5, 5.00)

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, dpi=150,)
fig.subplots_adjust(left=0.025, bottom=0.025, right=0.965, top=0.975, wspace=0.05, hspace=0.125)
cbar_kwrgs = {"orientation": "horizontal", "shrink": shrink, "pad": 0.02, 'extend': 'neither'}
base_map(ax, df_huc12)
df_metric = ds_metric_cat['2050-2099'].median(dim='gcm').to_dataframe()
df_metric.index.rename('location_name', inplace=True)
df_metric_final = df_site_selected.merge(df_metric, on="location_name", how='inner')
df_metric_final.plot(ax=ax, column='return_period', markersize=10, edgecolor="None", linewidth=0.1, cmap=cmap0, norm=norm0, legend=False, zorder=2,)
# the most recently added collection is the site scatter just drawn
points = ax.collections[-1]
cbar = plt.colorbar(points, ax=ax, **cbar_kwrgs)
cbar.ax.tick_params(labelsize=7)

# inset axes over the main axes: histogram of the per-site values in 1-yr bins
ins = ax.inset_axes([0.725, 0.775, 0.25, 0.2])
df_metric.plot.hist(ax=ins, bins=range(1, 17), legend=False)
ins.set_xlabel('future return period [yr]', fontsize=6)
ins.set_xlim([1, 16])
ins.set_ylabel('# of sites', fontsize=6)
ins.xaxis.set_major_locator(ticker.MultipleLocator(5))  # major ticks every 5 yr
ins.xaxis.set_minor_locator(ticker.MultipleLocator(1))  # minor ticks every 1 yr
ins.tick_params(labelsize=5)

fig.savefig(os.path.join(figure_path, f'Fig10_return_period_{ensemble_name}_change.png'), dpi=300)

### Fig 11. Historical and future empirical annual flood frequency curves at a selected site

In [None]:
# Fig 11: empirical frequency curves of the annual maximum 7-day mean flow at one
# site, for the historical (WY1955-2004) and future (WY2050-2099, ssp585/rcp85)
# windows. Lines are the median across the GCMs in gcm_plots; bands are the
# 25th-75th percentile range. Uses gcm_plots and ensemble_name set by an earlier cell.
site = 'MROW'


def _sorted_water_year_max(flow, start, end):
    """Return the ascending annual (Oct-Sep) maxima of the 7-day rolling mean of *flow* at `site`."""
    smoothed = flow.sel(site=site, time=slice(start, end)).rolling(time=7).mean()
    return np.sort(smoothed.resample(time="YS-OCT").max().values)


# Probabilities matching the return_period coordinate of the 'aff' metric; read from
# the first GCM of the ensemble instead of a loop variable left over from another cell.
# assumes all GCMs share the same return_period coordinate -- TODO confirm
p = 1-1/ds_flow_metrics['aff'][next(iter(gcm_plots))]['hist']['return_period'].values

# Observation-based annual maxima. Resampled by water year ("YS-OCT") to match
# the Oct-Sep selection window and the simulations below.
obs_flow = ds_nrni_selected['streamflow'].sel(site=site, time=slice('1954-10-01', '2004-09-30')).rolling(time=7).mean()
ann_max_obs = obs_flow.resample(time="YS-OCT").max().values
ann_max_obs_sort = np.sort(ann_max_obs)
obs_prob = np.arange(1, float(len(ann_max_obs)+1))/(1+len(ann_max_obs))  # probability
#plt.plot(obs_prob, ann_max_obs_sort, 'o', markersize=3, c='k')
#aff_obs = metrics.lp3_flood(obs_flow)
#plt.plot(p, aff_obs['aff'].values, ls=':', c='k', label='hist')

# one row per GCM, one column per water year (50 in each window)
hist_ann_max_sort = np.full((len(gcm_plots), 50), np.nan)
fut_ann_max_sort = np.full((len(gcm_plots), 50), np.nan)
# rank/(n+1) probability assigned to each sorted annual maximum
hist_sim_prob = np.arange(1, float(hist_ann_max_sort.shape[1]+1))/(1+hist_ann_max_sort.shape[1])
fut_sim_prob = np.arange(1, float(fut_ann_max_sort.shape[1]+1))/(1+fut_ann_max_sort.shape[1])

for ix, (gcm_name, scen_list) in enumerate(gcm_plots.items()):
    hist_ann_max_sort[ix, :] = _sorted_water_year_max(
        ds_qsim_selected[gcm_name]['hist']['streamflow'], '1954-10-01', '2004-09-30')

    # future scenario: 'ssp585' when the GCM has it, otherwise 'rcp85'
    fut_scen = 'ssp585' if 'ssp585' in scen_list else 'rcp85'
    fut_ann_max_sort[ix, :] = _sorted_water_year_max(
        ds_qsim_selected[gcm_name][fut_scen]['streamflow'], '2049-10-01', '2099-09-30')

fig, ax = plt.subplots(1, figsize=(4.5, 3.5), dpi=100,)
# x axis: return period = 1/(1 - probability)
ax.plot(1/(1-hist_sim_prob), np.median(hist_ann_max_sort, axis=0), 'o-', markersize=3, color='b')
ax.plot(1/(1-fut_sim_prob), np.median(fut_ann_max_sort, axis=0), 'o-', markersize=3, color='r')
ax.fill_between(1/(1-hist_sim_prob), np.percentile(hist_ann_max_sort, 25, axis=0), np.percentile(hist_ann_max_sort, 75, axis=0), alpha=0.3, color='b', label='WY1955-2004')
ax.fill_between(1/(1-fut_sim_prob), np.percentile(fut_ann_max_sort, 25, axis=0), np.percentile(fut_ann_max_sort, 75, axis=0), alpha=0.3, color='r', label='WY2050-2099')

ax.set_xscale('log')
plt.xticks([1, 10, 50], [1, 10, 50])
plt.xlim([0.95, 55])
ax.set_yscale('linear')
ax.set_xlabel('return period [yr]')
ax.set_ylabel('Annual maximum flow [m3/s]')
plt.legend()
plt.tight_layout()
fig.savefig(os.path.join(figure_path, f'Fig11_return_period_{ensemble_name}_{site}.png'), dpi=300)

## Fig 12. Annual time series plots for 150 years at a specified site

In [None]:
%matplotlib inline
# Fig 12. Annual mean flow series (years starting in October) at one site:
#   - reference flow from ds_nrni_selected and the retrospective run(s), control period only
#   - every ESM run: historical over the control period, future scenarios over their full record
#   - the mean linear trend of the plotted future runs (dotted line)
site = 'DWR'  # MROW

# line colour and width per scenario
scen_style = {
    'hist': ('xkcd:gray', 0.15),
    'ssp245': ('#005AB5', 0.15),
    'rcp45': ('#005AB5', 0.15),
    'ssp585': ('#DC3220', 0.15),
    'rcp85': ('#DC3220', 0.15),
}
# scenarios left out of this figure; to draw one, remove it here and give it a style above
skip_scen = ('ssp370',)

control_time = periods['control']['time']

fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(6.5, 2.5), dpi=150)

# reference flow
ds_nrni_selected['streamflow'].sel(site=site, time=control_time).resample(time="YS-OCT").mean().plot(ax=axs, c='k', ls='--', lw=1.5)

# retrospective simulation(s)
c = 'k'
lw = 1
for case in retro_runs:
    ds_qsim_selected[case]['hist']['streamflow'].sel(site=site, time=control_time).resample(time="YS-OCT").mean().plot(ax=axs, lw=lw, color=c, label=case, zorder=0)

# ESM-driven simulations; collect one linear fit per future run
slopes = []
intercepts = []
trend_time = None  # time axis of the last future run, reused for the trend line
for case, meta in gcm_runs.items():
    for scen in meta['scen']:
        if scen in skip_scen:
            continue
        if scen in scen_style:
            c, lw = scen_style[scen]
        # a scenario without an entry in scen_style keeps the previous line style

        if scen == 'hist':
            ds_qsim_selected[case][scen]['streamflow'].sel(site=site, time=control_time).resample(time="YS-OCT").mean().plot(ax=axs, lw=lw, color=c, label=case, zorder=0)
            continue

        dr_ann_q = ds_qsim_selected[case][scen]['streamflow'].sel(site=site).resample(time="YS-OCT").mean()
        dr_ann_q.plot(ax=axs, lw=lw, color=c, label=case)
        # fit against the year index (0, 1, 2, ...) rather than the datetime axis
        slope, intsect = np.polyfit(np.arange(len(dr_ann_q.time)), dr_ann_q.values, 1)
        slopes.append(slope)
        intercepts.append(intsect)
        trend_time = dr_ann_q.time

# mean trend line; skipped when no future run was plotted (avoids dividing by zero)
if slopes:
    slope_ave = sum(slopes) / len(slopes)
    intsect_ave = sum(intercepts) / len(intercepts)
    # NOTE: the line is drawn on the last future run's time axis, so this
    # assumes all future runs cover the same years
    years = np.arange(len(trend_time))
    trend = xr.DataArray(
        slope_ave * years + intsect_ave,
        coords={"time": trend_time},
        dims=["time"],
    )
    trend.plot(ax=axs, lw=1, color='k', ls=':')

# xarray's plot() sets its own title and x label; clear them
axs.set_title('')
axs.set_xlabel('')
axs.set_xlim(pd.Timestamp('1980-01-01'), pd.Timestamp('2100-01-01'))
axs.set_ylabel('Annual mean flow [m3/s]')

plt.tight_layout()
plt.savefig(os.path.join(figure_path, f'Fig12_annual_series_{site}.png'), dpi=300)

## Fig 13. annual cycle plots at a specified site

In [None]:
# Fig 13. Daily seasonal-cycle plot at one site: obs and retrospective run(s) as
# lines, the ESM historical spread as a grey band, and one band per emission
# group for each future period that is not skipped.

# --- setup
site = 'MROW'
ensemble_name = "low_high-emission"
skip_period = ['2040s']  # control, 2040s, 2080s,
# ---

# ESMs in the chosen ensemble, each mapped to the scenarios it shares with that ensemble
gcm_plots = {gcm: list(set(meta['scen']).intersection(ensembles[ensemble_name]['scen'])) for gcm, meta in gcm_runs.items() if meta['cmip'] in ensembles[ensemble_name]['cmip']}
num_gcm = len(gcm_plots)

emission_scen = {'low emission': ['ssp245', 'rcp45'], 'high emission': ['ssp585', 'rcp85']}
emission_color = {'low emission': 'blue', 'high emission': 'red'}

days = np.arange(0, 366)  # x positions, one per day-of-year slot


def _rolled_cycle(da, site):
    """Return the day-of-year series of `da` at `site`, rolled by 92 slots.

    The roll moves the last 92 slots to the front so the series starts in
    autumn, matching the '10/1' first tick label used below.
    """
    return da.sel(site=site).roll(dayofyear=92, roll_coords=False).values


fig, axs = plt.subplots(1, 1, figsize=(6.5, 2.5), dpi=150)

zorder = 6
for period in periods:

    if period in skip_period:
        continue

    if period == 'control':
        axs.plot(days, _rolled_cycle(ds_seasonal_flow['obs'], site), linestyle='--', linewidth=1.0, color='black', label='obs', zorder=6)

        for retro in retro_names:
            axs.plot(days, _rolled_cycle(ds_seasonal_flow[retro]['hist']['control'], site), linestyle='-', linewidth=1.0, color='k', label='%s'%(retro), zorder=zorder)

        # spread (min-max envelope) of the ESM historical runs
        members = [_rolled_cycle(ds_seasonal_flow[gcm]['hist']['control'], site) for gcm in gcm_plots]
        zorder -= 1
        if members:
            stacked = np.vstack(members)
            axs.fill_between(days, stacked.min(axis=0), stacked.max(axis=0), alpha=0.4, color='grey', label=period, zorder=zorder)

    else:
        for emission, scen_list in emission_scen.items():
            members = []
            for gcm, gcm_scens in gcm_plots.items():
                for scen in gcm_scens:
                    if scen not in scen_list:
                        continue
                    sim_flow = _rolled_cycle(ds_seasonal_flow[gcm][scen][period], site)
                    if num_gcm == 1:
                        # a single ESM has no spread: draw its scenarios as lines
                        axs.plot(days, sim_flow, linestyle='--', lw=0.75, label='%s'%(scen))
                    else:
                        members.append(sim_flow)
            # Only draw a band from members collected for THIS period/emission group;
            # previously a stale envelope from an earlier group was drawn when
            # nothing (or only single-ESM lines) had been collected.
            if members:
                stacked = np.vstack(members)
                axs.fill_between(days, stacked.min(axis=0), stacked.max(axis=0), alpha=0.3, color=emission_color[emission], label=emission, zorder=zorder)
        zorder -= 1

axs.set_xlabel('')
axs.set_ylabel('Daily discharge [m3/s]')
axs.set_xlim([0, 367])
axs.set_xticks([1, 32, 62, 93, 124, 152, 183, 213, 244, 274, 305, 336])
axs.set_xticklabels(['10/1', '11/1', '12/1', '1/1', '2/1', '3/1', '4/1', '5/1', '6/1', '7/1', '8/1', '9/1'])
axs.legend(fontsize='x-small')
axs.set_title(f'{site}')
plt.tight_layout()
plt.savefig(os.path.join(figure_path, f'Fig13_seasonal_flow_ensemble_{site}.png'), dpi=300)

## Plot of Flow duration curve at each site
!!! This creates hundreds of plots for individual sites !!!

#### 1. One ESM with one scenario (3 lines for control, 2040s, 2080s) and obs (a dashed line)

In [None]:
%matplotlib auto
# Flow duration curve (FDC) per site for one ESM and one scenario: one line per
# period plus the reference flow from ds_nrni_selected as a dashed line.
plt.rcParams.update({'figure.max_open_warning': 0})

# --- setup
gcm_plot = 'CanESM5'
scen_plot = 'ssp245'
site_plot = ds_nrni_selected['site'].values # limit the sites of interest
skip_period = []
# ---

prob_cut = 0.9995  # only flows below this probability set the y-axis range


def _fdc(flow):
    """Return (probability, flow sorted from highest to lowest).

    Probabilities are the plotting positions i/(n+1), i = 1..n. Paired with
    the descending flows they are exceedance probabilities.
    """
    flow_sort = flow[np.argsort(flow)[::-1]]
    prob = np.arange(1, float(len(flow) + 1)) / (1 + len(flow))
    return prob, flow_sort


for site in site_plot:

    fig, maiax = plt.subplots(1, figsize=(7, 5))
    maxflow = 0
    minflow = 10**6

    # reference flow over the control period
    obs_flow = ds_nrni_selected['streamflow'].sel(site=site, time=periods['control']['time']).values
    obs_prob, obs_flow_sort = _fdc(obs_flow)
    maiax.plot(obs_prob, obs_flow_sort, ls='--', lw=2.0, color='black', label='obs')

    shown = obs_flow_sort[obs_prob < prob_cut]
    maxflow = max(maxflow, max(shown))
    minflow = min(minflow, min(shown))

    for period in periods:

        if period in skip_period:
            continue

        # the control period is drawn from the historical run, later periods from scen_plot
        scen = 'hist' if period == 'control' else scen_plot
        if scen == 'hist' and period != 'control':
            # only reachable when scen_plot itself is 'hist'
            continue

        sim_flow = ds_qsim_selected[gcm_plot][scen].sel(time=periods[period]['time'], site=site)['streamflow'].values
        sim_prob, sim_flow_sort = _fdc(sim_flow)

        shown = sim_flow_sort[sim_prob < prob_cut]
        maxflow = max(maxflow, max(shown))
        minflow = min(minflow, min(shown))
        maiax.plot(sim_prob, sim_flow_sort, ls='-', lw=2.0,
                   c=periods[period]['lc'], label='%s-%s'%(period, scen))

    maiax.set_xscale('ppf')
    maiax.set_yscale('log')
    maiax.set_xlim([0.007, 0.9995])
    maiax.set_ylim([minflow*0.95, maxflow*1.05])
    # flows are sorted high-to-low, so the x axis is exceedance probability
    maiax.set_ylabel('Discharge [m3/s]')
    maiax.set_xlabel('Exceedance probability [-]')
    maiax.legend(bbox_to_anchor=(0.975, 0.975), loc="upper right")
    maiax.set_title(f'{gcm_plot} at {site}')

    fig.savefig(os.path.join(figure_path, 'per_site', f'FDC_{gcm_plot}_{scen_plot}_{site}.png'), dpi=100)
    # stop after the first site; remove this break to produce a plot for every site
    break

#### 2. ESM ensemble (3 filled bands for control, 2040s, 2080s) and obs (a dashed line)

In [None]:
%matplotlib auto
# Flow duration curve (FDC) per site for an ESM ensemble: one min-max band per
# period, the retrospective run(s) as solid lines and the reference flow from
# ds_nrni_selected as a dashed line.
plt.rcParams.update({'figure.max_open_warning': 0})

# --- setup
ensemble_name = "cmip6"
site_plot = ds_nrni_selected['site'].values # limit the sites of interest
skip_period = []
# ---

# ESMs in the chosen ensemble, each mapped to its list of scenarios
gcm_plots = {gcm: meta['scen'] for gcm, meta in gcm_runs.items() if meta['cmip'] in ensembles[ensemble_name]['cmip']}
num_gcm = len(gcm_plots)

prob_cut = 0.9995  # only flows below this probability set the y-axis range


def _fdc(flow):
    """Return (probability, flow sorted from highest to lowest).

    Probabilities are the plotting positions i/(n+1), i = 1..n. Paired with
    the descending flows they are exceedance probabilities.
    """
    flow_sort = flow[np.argsort(flow)[::-1]]
    prob = np.arange(1, float(len(flow) + 1)) / (1 + len(flow))
    return prob, flow_sort


for site in site_plot:

    fig, maiax = plt.subplots(1, figsize=(7, 5))
    maxflow = 0
    minflow = 10**6

    # reference flow over the control period
    obs_flow = ds_nrni_selected['streamflow'].sel(site=site, time=periods['control']['time']).values
    obs_prob, obs_flow_sort = _fdc(obs_flow)

    shown = obs_flow_sort[obs_prob < prob_cut]
    maxflow = max(maxflow, max(shown))
    minflow = min(minflow, min(shown))
    maiax.plot(obs_prob, obs_flow_sort, linestyle='--', linewidth=1.5, color='black', label='obs', zorder=0)

    zorder = 6
    for period in periods:

        if period in skip_period:
            continue

        period_time = periods[period]['time']

        if period == 'control':
            # retrospective run(s) are drawn as lines and do not affect the y range
            for retro in retro_names:
                sim_flow = ds_qsim_selected[retro]['hist']['streamflow'].sel(site=site, time=period_time).values
                sim_prob, sim_flow_sort = _fdc(sim_flow)
                maiax.plot(sim_prob, sim_flow_sort, linestyle='-', linewidth=1.0, color='k', label='%s'%(retro), zorder=zorder)
            members = [(gcm, 'hist') for gcm in gcm_plots]
        else:
            # every non-historical scenario of every ESM contributes to the band
            members = [(gcm, scen) for gcm, scen_list in gcm_plots.items() for scen in scen_list if scen != 'hist']

        curves = []
        for gcm, scen in members:
            sim_flow = ds_qsim_selected[gcm][scen].sel(time=period_time, site=site)['streamflow'].values
            sim_prob, sim_flow_sort = _fdc(sim_flow)

            shown = sim_flow_sort[sim_prob < prob_cut]
            maxflow = max(maxflow, max(shown))
            minflow = min(minflow, min(shown))
            curves.append(sim_flow_sort)

        # Draw the band only from curves collected for THIS period; previously an
        # empty member list re-drew the previous period's envelope (or raised
        # NameError when the ensemble itself was empty).
        if curves:
            stacked = np.vstack(curves)  # requires all members to have the same record length
            maiax.fill_between(sim_prob, stacked.min(axis=0), stacked.max(axis=0), alpha=0.4, color=periods[period]['lc'], label=f'{period}', zorder=zorder)

        zorder -= 1

    maiax.tick_params(axis='x', which='major', labelsize=8)
    maiax.set_xscale('ppf')
    maiax.set_yscale('log')
    maiax.set_xlim([0.003, 0.991])
    maiax.set_ylim([minflow*0.95, maxflow*1.05])
    maiax.set_ylabel('Discharge [m3/s]')
    maiax.set_xlabel('Exceedance probability [-]')
    maiax.legend(bbox_to_anchor=(0.975, 0.975), loc="upper right")
    maiax.set_title(f'ESM ensembles at {site}')

    fig.savefig(os.path.join(figure_path, 'per_site', 'FDC_%s_%s.png'%(ensemble_name, site)), dpi=200)
    # stop after the first site; remove this break to produce a plot for every site
    break

## Monthly seasonal cycle at each site 

#### ESM ensemble (3 filled bands for control, 2040s, 2080s) and obs (a dashed line)

In [None]:
%matplotlib auto
# Monthly mean flow cycle per site for an ESM ensemble: one min-max band per
# period, the retrospective run(s) as solid lines and the reference flow from
# ds_nrni_selected as a dashed line.
plt.rcParams.update({'figure.max_open_warning': 0})

# --- setup
ensemble_name = "cmip6-ssp370"
site_plot = ds_nrni_selected['site'].values # limit the sites of interest
skip_period = []  # control, 2040s, 2080s,
# ---

# ESMs in the chosen ensemble, each mapped to the scenarios it shares with that ensemble
gcm_plots = {gcm: set(meta['scen']).intersection(ensembles[ensemble_name]['scen']) for gcm, meta in gcm_runs.items() if meta['cmip'] in ensembles[ensemble_name]['cmip']}
month = ['Oct', 'Nov', 'Dec', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep']
num_gcm = len(gcm_plots)

# line style per period, used when a single ESM is drawn as lines
ls = {'retro': '-',
      'control': '-',
      '2040s': '--',
      '2080s': '--',
     }


def _monthly_cycle(streamflow, site, time_sel):
    """Return the 12 monthly means of `streamflow` at `site` over `time_sel`.

    The values are rolled by three positions so they line up with the
    Oct..Sep order of the `month` labels.
    """
    monthly = streamflow.sel(site=site, time=time_sel).groupby("time.month").mean(dim="time")
    return monthly.roll(month=3, roll_coords=False).values


for site in site_plot:
    fig, ax = plt.subplots(figsize=(7, 5))

    zorder = 6
    for period in periods:

        if period in skip_period:
            continue

        period_time = periods[period]['time']

        if period == 'control':
            ax.plot(month, _monthly_cycle(ds_nrni_selected['streamflow'], site, period_time), linestyle='--', linewidth=1.0, color='black', label='obs', zorder=6)

            for retro in retro_names:
                ax.plot(month, _monthly_cycle(ds_qsim_selected[retro]['hist']['streamflow'], site, period_time), linestyle='-', linewidth=1.0, color='k', label='%s'%(retro), zorder=zorder)

            # spread (min-max envelope) of the ESM historical runs
            members = [_monthly_cycle(ds_qsim_selected[gcm]['hist']['streamflow'], site, period_time) for gcm in gcm_plots]
            zorder -= 1
        else:
            members = []
            for gcm, scen_list in gcm_plots.items():
                for scen in scen_list:
                    if scen == 'hist':  # history period is already plotted
                        continue
                    sim_flow = _monthly_cycle(ds_qsim_selected[gcm][scen]['streamflow'], site, period_time)
                    if num_gcm == 1:
                        # a single ESM has no spread: draw its scenarios as lines
                        ax.plot(month, sim_flow, linestyle=ls[period], lw=0.75, label='%s'%(scen))
                    else:
                        members.append(sim_flow)

        # Draw the band only from members collected for THIS period; previously a
        # stale envelope from the control period was re-drawn when nothing (or
        # only single-ESM lines) had been collected.
        if members:
            stacked = np.vstack(members)
            ax.fill_between(month, stacked.min(axis=0), stacked.max(axis=0), alpha=0.5, color=periods[period]['lc'], label=period, zorder=zorder)
        zorder -= 1

    ax.set_ylabel('Discharge [m3/s]')
    ax.set_xlabel('Month')
    ax.legend()
    ax.set_title(f'{site}')
    fig.savefig(os.path.join(figure_path, 'per_site', 'Monthly_cycle_%s_%s.png'%(ensemble_name, site)), dpi=200)
    # stop after the first site; remove this break to produce a plot for every site
    break

## Annual flood frequency

#### ESM ensemble (2 filled bands for 1955-2004 and 2050-2099) and obs (a dashed line)

In [None]:
### annual flood frequency
%matplotlib auto
# Annual flood frequency curve per site for an ESM ensemble: annual maxima of
# the 7-day running-mean flow against return period, as one min-max band for
# the historical window and one for the late-century window, plus the
# reference flow from ds_nrni_selected as a dashed line.
plt.rcParams.update({'figure.max_open_warning': 0})

# --- setup
ensemble_name = "high-emission"
site_plot = ds_nrni_selected['site'].values # limit the sites of interest
skip_period = []  # 'control' (obs + historical band) and/or '2080s' (future band)
# ---

# ESMs in the chosen ensemble, each mapped to its list of scenarios
gcm_plots = {gcm: meta['scen'] for gcm, meta in gcm_runs.items() if meta['cmip'] in ensembles[ensemble_name]['cmip']}
num_gcm = len(gcm_plots)

# two 50-year windows; the years resampled below start in October
hist_time = slice('1954-10-01', '2004-09-30')
futr_time = slice('2049-10-01', '2099-09-30')


def _annual_flood_curve(streamflow):
    """Return (return_period, annual maxima sorted from lowest to highest).

    The annual maxima are taken from the 7-day running mean of `streamflow`,
    one per year starting in October. Return periods are 1/(1-p) with the
    plotting positions p = i/(n+1), i = 1..n.
    """
    ann_max = streamflow.rolling(time=7).mean().resample(time="YS-OCT").max().values
    ann_max_sort = ann_max[np.argsort(ann_max)]
    prob = np.arange(1, float(len(ann_max) + 1)) / (1 + len(ann_max))
    return 1 / (1 - prob), ann_max_sort


for site in site_plot:

    fig, maiax = plt.subplots(1, figsize=(7, 5))

    # --- historical window: reference flow and ESM historical band
    # (skip_period is now checked per band; the old check read `period` before
    #  this cell had assigned it)
    if 'control' not in skip_period:
        obs_flow = ds_nrni_selected['streamflow'].sel(site=site, time=hist_time)
        # the reference curve is drawn only for a complete record without zero flows
        if np.all(~np.isnan(obs_flow)) and np.all(obs_flow != 0):
            obs_return_period, ann_max_obs_sort = _annual_flood_curve(obs_flow)
            maiax.plot(obs_return_period, ann_max_obs_sort, linestyle='--', linewidth=2.0, color='black', label='obs', zorder=0)

        curves = []
        for gcm in gcm_plots:
            return_period, ann_max_sim_sort = _annual_flood_curve(ds_qsim_selected[gcm]['hist']['streamflow'].sel(site=site, time=hist_time))
            curves.append(ann_max_sim_sort)

        if curves:
            stacked = np.vstack(curves)  # requires all members to have the same number of years
            maiax.fill_between(return_period, stacked.min(axis=0), stacked.max(axis=0), alpha=0.45, color=periods['control']['lc'], label='1955-2004', zorder=5)

    # --- late-century window: band over every non-historical scenario of every ESM
    if '2080s' not in skip_period:
        curves = []
        for gcm, scen_list in gcm_plots.items():
            for scen in scen_list:
                if scen == 'hist':  # history period is already plotted
                    continue
                return_period, ann_max_sim_sort = _annual_flood_curve(ds_qsim_selected[gcm][scen]['streamflow'].sel(site=site, time=futr_time))
                curves.append(ann_max_sim_sort)

        # draw only from curves collected here, never from the historical band's arrays
        if curves:
            stacked = np.vstack(curves)
            # labelled by the same convention as '1955-2004' above
            maiax.fill_between(return_period, stacked.min(axis=0), stacked.max(axis=0), alpha=0.45, color=periods['2080s']['lc'], label='2050-2099')

    maiax.set_xscale('log') # ppf
    maiax.set_yscale('linear')
    plt.xticks([1, 10, 50], [1, 10, 50])
    plt.xlim([0.95, 55])
    maiax.set_ylabel('Discharge [m3/s]')
    maiax.set_xlabel('Return period [yr]')
    maiax.legend(bbox_to_anchor=(0.01, 0.975), loc="upper left")
    maiax.set_title(f'ESM ensembles at {site}')
    fig.savefig(os.path.join(figure_path, 'per_site', f'AFFC_{ensemble_name}_{site}.png'), dpi=200)
    # stop after the first site; remove this break to produce a plot for every site
    break