In [1]:
import numpy as np
import xarray as xr
import dask
import os
from glob import glob

import xclim
xclim.set_options(cf_compliance="log");

### Preliminaries

In [2]:
###############################
# Set paths
# UPDATE THIS FOR REPRODUCTION
###############################
# Root of the project's derived data; the metric outputs and the historical
# quantile files both live underneath it.
_project_data_root = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/'

# Raw model files, laid out as <in_path>/<model>/<ssp>/<variable>/<file>.nc
in_path = '/gpfs/group/kaf26/default/public/NEX-GDDP-CMIP6/models/'
# One sub-directory per metric is written below this directory
out_path = _project_data_root + 'metrics/nex-gddp/'
# Historical quantile files used as thresholds by the hot/wet metrics
quantile_path = _project_data_root + 'quantiles/'

In [3]:
###################
# Models
###################
from utils import nex_ssp_dict

# nex_ssp_dict is keyed by model name (each value is that model's SSP list),
# so iterating the mapping yields the model names in its own order.
models = list(nex_ssp_dict)

In [4]:
###################
# Model details
###################
# Each file name carries a model-specific section between the SSP label and
# the year. Recover it for every model from its 2015 ssp126 tasmax file by
# stripping the known leading and trailing parts of the path.
model_info = {}
for model in models:
    # Sort so the chosen file is deterministic if the pattern matches several
    matches = sorted(glob(in_path + model + '/ssp126/tasmax/*_2015.nc'))
    if not matches:
        # Fail with a readable message instead of a bare IndexError
        raise FileNotFoundError('No 2015 ssp126 tasmax file found for ' + model +
                                ' under ' + in_path)
    model_info[model] = (matches[0]
                         .replace(in_path + model, '')
                         .replace('/ssp126/tasmax/tasmax_day_' + model + '_ssp126', '')
                         .replace('2015.nc', ''))

In [5]:
############
# Dask
############
from dask_jobqueue import PBSCluster
from dask.distributed import Client

# One single-core worker with 30GB of memory per PBS job.
# NOTE(review): scheduler directives belong in `job_extra_directives`;
# dask-jobqueue adds the '#PBS ' prefix itself when it writes the job header.
# Passing '#PBS -l feature=rhel7' through `worker_extra_args` appends it to the
# worker command line instead, where it does not reach the scheduler.
# Confirm the 'rhel7' feature is still the intended constraint on this system.
cluster = PBSCluster(cores=1, resource_spec='pmem=30GB', memory='30GB',
                     job_extra_directives=['-l feature=rhel7'],
                     walltime='15:00:00')

cluster.scale(jobs=55)  # ask for jobs

client = Client(cluster)

client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.102.201.235:38223,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Simple metrics (no historical quantiles required)

In [None]:
########################################################
# Calculate the metric for a 
# single model-year, including all SSPs and variables
########################################################
def model_year_metric(path, model, model_vers, ssps, var_ids, year, metric):
    """
    Calculate one annual metric for a single model-year, for every requested
    SSP and variable.

    Parameters
    ----------
    path : str
        Root directory of the model files (expected to end with '/').
    model : str
        Model name, used both as a directory and inside the file name.
    model_vers : str
        Model-specific part of the file name that sits between the SSP label
        and the year.
    ssps : list of str
        SSP labels to process; they become the 'ssp' coordinate of the result.
    var_ids : list of str
        Variables to read ('tas', 'tasmin', 'tasmax', 'pr'). The 'dry' branch
        renames the 'pr' variable and 'max5d' is only called with 'pr' in
        this notebook.
    year : int
        Year of the file to read.
    metric : str
        One of 'avg' (annual mean), 'max' (annual maximum), 'dry' (dry-day
        counts and longest dry streaks) or 'max5d' (xclim RX5day).

    Returns
    -------
    xarray.Dataset
        Per-variable results merged for each SSP and concatenated along a new
        'ssp' dimension.

    Raises
    ------
    ValueError
        If `metric` is not one of the supported names.
    """
    # Fail fast: an unrecognised metric would otherwise fall through every
    # branch below and return the un-reduced daily data.
    valid_metrics = ('avg', 'max', 'dry', 'max5d')
    if metric not in valid_metrics:
        raise ValueError("Unknown metric '" + str(metric) + "'; expected one of "
                         + ', '.join(valid_metrics))

    # Function for longest consecutive spell if needed
    def n_longest_consecutive(ds, dim='time'):
        # Running count of True days minus the running count at the most
        # recent False day gives the length of the current streak.
        running = ds.cumsum(dim=dim)
        ds = running - running.where(ds == 0).ffill(dim=dim).fillna(0)
        return ds.max(dim=dim)

    # One merged dataset per SSP
    ds_per_ssp = []
    for ssp in ssps:
        # Temporary list for each SSP
        ds_list = []
        for var in var_ids:
            file_path = (path + model + '/' + ssp + '/' +
                         var + '/' + var + '_day_' + model +
                         '_' + ssp + model_vers + str(year) + '.nc')

            # The context manager closes the file handle once the annual
            # result has been computed, instead of leaving it to the
            # garbage collector on a long-lived worker.
            with xr.open_dataset(file_path) as ds_in:
                ## Convert units
                # temperature: K -> C
                if var in ('tas', 'tasmax', 'tasmin') and ds_in[var].attrs['units'] == 'K':
                    ds_in[var] = ds_in[var] - 273.15

                # precip: kg m-2 s-1 -> mm day-1
                if var == 'pr' and ds_in.pr.attrs['units'] == 'kg m-2 s-1':
                    ds_in['pr'] = ds_in['pr'] * 86400
                    ds_in.pr.attrs['units'] = 'mm/day'

                # Calculate metric
                if metric == 'avg':
                    ds_metric = ds_in.resample(time='1Y').mean()
                elif metric == 'max':
                    ds_metric = ds_in.resample(time='1Y').max()
                elif metric == 'dry':
                    is_zero = ds_in == 0.      # 0mm
                    below_one = ds_in < 1.     # less than 1mm
                    # Number of dry days
                    count_0 = is_zero.resample(time='1Y').sum()
                    count_1 = below_one.resample(time='1Y').sum()
                    # Longest consecutive dry day streak
                    streak_0 = is_zero.resample(time='1Y').map(n_longest_consecutive)
                    streak_1 = below_one.resample(time='1Y').map(n_longest_consecutive)
                    # Merge
                    ds_metric = xr.merge([count_0.rename({'pr': 'count_eq_0'}),
                                          streak_0.rename({'pr': 'streak_eq_0'}),
                                          count_1.rename({'pr': 'count_lt_1'}),
                                          streak_1.rename({'pr': 'streak_lt_1'})])
                else:  # metric == 'max5d'
                    rx5day = xclim.indicators.icclim.RX5day(ds=ds_in, freq='Y')
                    ds_metric = xr.Dataset({'RX5day': rx5day})

                # Make sure the values are in memory before the file closes
                ds_list.append(ds_metric.load())

        # Merge the variables and label the result with its SSP
        ds_per_ssp.append(xr.merge(ds_list).assign_coords(ssp=ssp))

    # Concat along ssp dimension
    return xr.concat(ds_per_ssp, dim='ssp')

### Annual averages

In [7]:
# Loop through models: RUNTIME IS ~15 MINS PER MODEL WITH 30 DASK WORKERS
metric = 'avg'

# All variables
var_ids = ['tas', 'tasmin', 'tasmax', 'pr']

# Create the output directory up front so a missing folder is not discovered
# only after a model has been computed
os.makedirs(out_path + metric, exist_ok=True)

for model in models:
    out_file = out_path + metric + '/' + model + '.nc'

    # Check if already exists
    if os.path.isfile(out_file):
        print(model + ' already done')
        continue

    # Parallelize with dask over years
    delayed_res = [dask.delayed(model_year_metric)(path = in_path,
                                                   model = model,
                                                   model_vers = model_info[model],
                                                   ssps = nex_ssp_dict[model],
                                                   var_ids = var_ids,
                                                   year = year,
                                                   metric = metric)
                   for year in range(2015, 2101)]

    # Compute (results come back in the order the tasks were listed)
    res = dask.compute(*delayed_res)

    # Store: write to a temporary name and rename once complete, so an
    # interrupted write never leaves a partial file that the check above
    # would later report as 'already done'
    df_final = xr.concat(res, dim='time')
    tmp_file = out_file + '.tmp'
    df_final.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)

    print(model)

ACCESS-ESM1-5 already done
BCC-CSM2-MR already done
CanESM5 already done
CMCC-ESM2 already done
CNRM-CM6-1 already done
CNRM-ESM2-1 already done
EC-Earth3 already done
EC-Earth3-Veg-LR already done
GFDL-ESM4 already done
HadGEM3-GC31-LL already done
INM-CM4-8 already done
INM-CM5-0 already done
IPSL-CM6A-LR already done
MIROC-ES2L already done
MIROC6 already done
MPI-ESM1-2-HR
MPI-ESM1-2-LR already done
MRI-ESM2-0
NESM3
NorESM2-LM already done
NorESM2-MM already done
UKESM1-0-LL already done


### 1-day max

In [7]:
# Loop through models: RUNTIME IS ~10 MINS PER MODEL WITH 30 DASK WORKERS
metric = 'max'

# All variables
var_ids = ['tas', 'tasmin', 'tasmax', 'pr']

# Create the output directory up front so a missing folder is not discovered
# only after a model has been computed
os.makedirs(out_path + metric, exist_ok=True)

for model in models:
    out_file = out_path + metric + '/' + model + '.nc'

    # Check if already exists
    if os.path.isfile(out_file):
        print(model + ' already done')
        continue

    # Parallelize with dask over years
    delayed_res = [dask.delayed(model_year_metric)(path = in_path,
                                                   model = model,
                                                   model_vers = model_info[model],
                                                   ssps = nex_ssp_dict[model],
                                                   var_ids = var_ids,
                                                   year = year,
                                                   metric = metric)
                   for year in range(2015, 2101)]

    # Compute (results come back in the order the tasks were listed)
    res = dask.compute(*delayed_res)

    # Store: write to a temporary name and rename once complete, so an
    # interrupted write never leaves a partial file that the check above
    # would later report as 'already done'
    df_final = xr.concat(res, dim='time')
    tmp_file = out_file + '.tmp'
    df_final.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)

    print(model)

ACCESS-ESM1-5 already done
BCC-CSM2-MR already done
CanESM5 already done
CMCC-ESM2 already done
CNRM-CM6-1 already done
CNRM-ESM2-1 already done
EC-Earth3 already done
EC-Earth3-Veg-LR already done
GFDL-ESM4 already done
HadGEM3-GC31-LL already done
INM-CM4-8 already done
INM-CM5-0 already done
IPSL-CM6A-LR already done
MIROC-ES2L already done
MIROC6 already done
MPI-ESM1-2-HR
MPI-ESM1-2-LR already done
MRI-ESM2-0
NESM3
NorESM2-LM already done
NorESM2-MM already done
UKESM1-0-LL already done


### 5-day max (pr)

In [14]:
# Loop through models: RUNTIME IS ~5 MINS PER MODEL WITH 40 DASK WORKERS
metric = 'max5d'

# Precip only
var_ids = ['pr']

# Create the output directory up front so a missing folder is not discovered
# only after a model has been computed
os.makedirs(out_path + metric, exist_ok=True)

for model in models:
    out_file = out_path + metric + '/' + model + '.nc'

    # Check if already exists
    if os.path.isfile(out_file):
        print(model + ' already done')
        continue

    # Parallelize with dask over years
    delayed_res = [dask.delayed(model_year_metric)(path = in_path,
                                                   model = model,
                                                   model_vers = model_info[model],
                                                   ssps = nex_ssp_dict[model],
                                                   var_ids = var_ids,
                                                   year = year,
                                                   metric = metric)
                   for year in range(2015, 2101)]

    # Compute (results come back in the order the tasks were listed)
    res = dask.compute(*delayed_res)

    # Store: write to a temporary name and rename once complete, so an
    # interrupted write never leaves a partial file that the check above
    # would later report as 'already done'
    df_final = xr.concat(res, dim='time')
    tmp_file = out_file + '.tmp'
    df_final.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)

    print(model)

ACCESS-ESM1-5
BCC-CSM2-MR
CanESM5
CMCC-ESM2
CNRM-CM6-1
CNRM-ESM2-1
EC-Earth3
EC-Earth3-Veg-LR
GFDL-ESM4
HadGEM3-GC31-LL
INM-CM4-8
INM-CM5-0
IPSL-CM6A-LR
MIROC-ES2L
MIROC6
MPI-ESM1-2-HR
MPI-ESM1-2-LR
MRI-ESM2-0
NESM3
NorESM2-LM
NorESM2-MM
UKESM1-0-LL


### Dry days

In [7]:
# Loop through models: RUNTIME IS ~15 MINS PER MODEL WITH 30 DASK WORKERS
metric = 'dry'

# Precip only
var_ids = ['pr']

# Create the output directory up front so a missing folder is not discovered
# only after a model has been computed
os.makedirs(out_path + metric, exist_ok=True)

for model in models:
    out_file = out_path + metric + '/' + model + '.nc'

    # Check if already exists
    if os.path.isfile(out_file):
        print(model + ' already done')
        continue

    # Parallelize with dask over years
    delayed_res = [dask.delayed(model_year_metric)(path = in_path,
                                                   model = model,
                                                   model_vers = model_info[model],
                                                   ssps = nex_ssp_dict[model],
                                                   var_ids = var_ids,
                                                   year = year,
                                                   metric = metric)
                   for year in range(2015, 2101)]

    # Compute (results come back in the order the tasks were listed)
    res = dask.compute(*delayed_res)

    # Store: write to a temporary name and rename once complete, so an
    # interrupted write never leaves a partial file that the check above
    # would later report as 'already done'
    df_final = xr.concat(res, dim='time')
    tmp_file = out_file + '.tmp'
    df_final.to_netcdf(tmp_file)
    os.replace(tmp_file, out_file)

    print(model)

ACCESS-ESM1-5 already done
BCC-CSM2-MR already done
CanESM5 already done
CMCC-ESM2 already done
CNRM-CM6-1 already done
CNRM-ESM2-1 already done
EC-Earth3 already done
EC-Earth3-Veg-LR already done
GFDL-ESM4 already done
HadGEM3-GC31-LL already done
INM-CM4-8 already done
INM-CM5-0 already done
IPSL-CM6A-LR already done
MIROC-ES2L already done
MIROC6 already done
MPI-ESM1-2-HR already done
MPI-ESM1-2-LR already done
MRI-ESM2-0 already done
NESM3 already done
NorESM2-LM already done
NorESM2-MM already done
UKESM1-0-LL already done


## Less simple metrics (historical quantiles required)

In [6]:
def model_year_ssp_metric(model_path, quantile_path, model, model_vers, ssp, var_id, year, obs):
    """
    Reads NEX-GDDP model output for a given ssp-year and calculates the number of hot/wet days
    and the longest consecutive hot/wet day streak. This function will be wrapped in dask
    distributed.

    Parameters
    ----------
    model_path : str
        Root directory of the model files (expected to end with '/').
    quantile_path : str
        Directory holding the historical quantile datasets.
    model : str
        Model name, used both as a directory and inside the file name.
    model_vers : str
        Model-specific part of the file name that sits between the SSP label
        and the year.
    ssp : str
        SSP label; attached to the result as the scalar coordinate 'ssp'.
    var_id : str
        'tas', 'tasmin', 'tasmax' (temperature quantiles) or 'pr' (precip
        quantiles).
    year : int
        Year of the model file to read.
    obs : collection of str
        Observational products to take thresholds from; 'gmfd' and 'era5' are
        recognised, anything else is ignored.

    Returns
    -------
    xarray.Dataset
        For each threshold ('q99', 'rp10') and each requested product, the
        annual count of days above the threshold and the longest run of
        consecutive days above it.

    Raises
    ------
    ValueError
        If `var_id` is not one of the supported variables.
    """
    from contextlib import ExitStack

    # Subfunction to calculate longest consecutive spell
    def n_longest_consecutive(ds, dim='time'):
        # Running count of True days minus the running count at the most
        # recent False day gives the length of the current streak.
        running = ds.cumsum(dim=dim)
        ds = running - running.where(ds == 0).ffill(dim=dim).fillna(0)
        return ds.max(dim=dim)

    # Pick the quantile file family; reject unsupported variables here rather
    # than failing later on an undefined quantile dataset
    if var_id in ['tasmax', 'tasmin', 'tas']:
        quantile_kind = 'temperature'
    elif var_id == 'pr':
        quantile_kind = 'precip'
    else:
        raise ValueError("Unsupported var_id '" + str(var_id) +
                         "'; expected 'tas', 'tasmin', 'tasmax' or 'pr'")

    # ExitStack closes every dataset opened below, even if a later step fails
    with ExitStack() as stack:
        # Read historical quantiles (GMFD first, then ERA5, to keep the
        # variable order of the output fixed)
        ds_q = {}
        if 'gmfd' in obs:
            ds_q['gmfd'] = stack.enter_context(
                xr.open_dataset(quantile_path + 'gmfd_' + quantile_kind + '_quantiles_nex-cil-deepsd.nc'))
        if 'era5' in obs:
            ds_q['era5'] = stack.enter_context(
                xr.open_dataset(quantile_path + 'era5_' + quantile_kind + '_quantiles_nex-cil-deepsd', engine='zarr'))

        # Read model file
        ds_tmp = stack.enter_context(
            xr.open_dataset(model_path + model + '/' + ssp + '/' +
                            var_id + '/' + var_id + '_day_' + model +
                            '_' + ssp + model_vers + str(year) + '.nc'))

        # Temperature: K -> C
        if var_id in ('tas', 'tasmax', 'tasmin') and ds_tmp[var_id].attrs['units'] == 'K':
            ds_tmp[var_id] = ds_tmp[var_id] - 273.15

        # Precip: kg m-2 s-1 -> mm day-1
        if var_id == 'pr' and ds_tmp.pr.attrs['units'] == 'kg m-2 s-1':
            ds_tmp['pr'] = ds_tmp['pr'] * 86400

        # Calculate metrics
        ds_tmp_out = []
        for rp in ['q99', 'rp10']:
            for source, ds_q_source in ds_q.items():
                # Above/below binary
                exceeds = ds_tmp[var_id] > ds_q_source[var_id + '_' + rp]
                # NOTE: output names have no separator between the threshold
                # and the product (e.g. 'pr_q99gmfd_count'); kept as-is so
                # existing output files and their readers stay compatible.
                # Count
                count = exceeds.resample(time='1Y').sum()
                ds_tmp_out.append(xr.Dataset({var_id + '_' + rp + source + '_count': count}))
                # Streak
                streak = exceeds.resample(time='1Y').map(n_longest_consecutive)
                ds_tmp_out.append(xr.Dataset({var_id + '_' + rp + source + '_streak': streak}))

        # Merge, and make sure the values are in memory before files close
        ds_out = xr.merge(ds_tmp_out).load()

    ds_out = ds_out.assign_coords(ssp=ssp)
    return ds_out

### Wet days

In [9]:
%%time
# Loop through models: RUNTIME IS ~16 MINS PER MODEL WITH 55 DASK WORKERS
metric = 'wet'

# Precip only
var_id = 'pr'

for model in models:
    # Check if already exists
    if os.path.isfile(out_path + metric + '/' + model + '.nc'):
        print(model + ' ' + var_id + ' already done')
        continue

    # Parallelize with dask over ssp-years
    delayed_res = []

    for ssp in nex_ssp_dict[model]:
        for year in range(2015, 2101):
            tmp_res = dask.delayed(model_year_ssp_metric)(in_path,
                                                          quantile_path,
                                                          model,
                                                          model_info[model],
                                                          ssp,
                                                          var_id,
                                                          year,
                                                          ['gmfd', 'era5'])
            delayed_res.append(tmp_res)
        
    # Compute
    res = dask.compute(*delayed_res)

    # Combine in correct order along ssp, year
    df_final = xr.concat([xr.concat([ds for ds in res if ds.ssp == ssp], dim='time') for ssp in nex_ssp_dict[model]], dim='ssp')
    del res
    
    # Store
    df_final.to_netcdf(out_path + metric + '/' + model + '.nc')
    del df_final
    
    print(model)

ACCESS-ESM1-5 pr already done
BCC-CSM2-MR pr already done
CanESM5
CMCC-ESM2
CNRM-CM6-1
CNRM-ESM2-1
EC-Earth3
EC-Earth3-Veg-LR
GFDL-ESM4
HadGEM3-GC31-LL
INM-CM4-8
INM-CM5-0
IPSL-CM6A-LR
MIROC-ES2L
MIROC6
MPI-ESM1-2-HR
MPI-ESM1-2-LR
MRI-ESM2-0
NESM3
NorESM2-LM
NorESM2-MM
UKESM1-0-LL
CPU times: user 15min 32s, sys: 10min 25s, total: 25min 58s
Wall time: 4h 32min 15s


### Hot days

In [7]:
%%time
# Loop through models: RUNTIME IS ~16 MINS PER MODEL-VARIABLE WITH 55 DASK WORKERS
metric = 'hot'

for model in models:
    for var_id in ['tasmin', 'tasmax', 'tas']:
        # Check if already exists
        if os.path.isfile(out_path + metric + '/' + model + '_' + var_id + '.nc'):
            print(model + ' ' + var_id + ' already done')
            continue
    
        # Parallelize with dask over ssp-years
        delayed_res = []
    
        for ssp in nex_ssp_dict[model]:
            for year in range(2015, 2101):
                tmp_res = dask.delayed(model_year_ssp_metric)(in_path,
                                                              quantile_path,
                                                              model,
                                                              model_info[model],
                                                              ssp,
                                                              var_id,
                                                              year,
                                                              ['gmfd', 'era5'])
                delayed_res.append(tmp_res)
            
        # Compute
        res = dask.compute(*delayed_res)
    
        # Combine in correct order along ssp, year
        df_final = xr.concat([xr.concat([ds for ds in res if ds.ssp == ssp], dim='time') for ssp in nex_ssp_dict[model]], dim='ssp')
        del res

        # Store
        df_final.to_netcdf(out_path + metric + '/' + model + '_' + var_id + '.nc')
        del df_final 
        
        print(model + '_' + var_id)

ACCESS-ESM1-5 tasmin already done
ACCESS-ESM1-5 tasmax already done
ACCESS-ESM1-5 tas already done
BCC-CSM2-MR tasmin already done
BCC-CSM2-MR tasmax already done
BCC-CSM2-MR tas already done
CanESM5 tasmin already done
CanESM5 tasmax already done
CanESM5 tas already done
CMCC-ESM2 tasmin already done
CMCC-ESM2 tasmax already done
CMCC-ESM2 tas already done
CNRM-CM6-1 tasmin already done
CNRM-CM6-1 tasmax already done
CNRM-CM6-1 tas already done
CNRM-ESM2-1 tasmin already done
CNRM-ESM2-1 tasmax already done
CNRM-ESM2-1 tas already done
EC-Earth3_tasmin
EC-Earth3_tasmax
EC-Earth3_tas
EC-Earth3-Veg-LR_tasmin
EC-Earth3-Veg-LR_tasmax
EC-Earth3-Veg-LR_tas
GFDL-ESM4_tasmin
GFDL-ESM4_tasmax
GFDL-ESM4_tas
HadGEM3-GC31-LL_tasmin
HadGEM3-GC31-LL_tasmax
HadGEM3-GC31-LL_tas
INM-CM4-8_tasmin
INM-CM4-8_tasmax
INM-CM4-8_tas
INM-CM5-0_tasmin
INM-CM5-0_tasmax
INM-CM5-0_tas
IPSL-CM6A-LR_tasmin
IPSL-CM6A-LR_tasmax
IPSL-CM6A-LR_tas
MIROC-ES2L_tasmin
MIROC-ES2L_tasmax
MIROC-ES2L_tas
MIROC6_tasmin
MIROC6_t

In [8]:
# Disconnect the client before shutting the cluster down; otherwise it keeps
# trying to reconnect to the stopped scheduler and logs a reconnect error.
client.close()
cluster.close()

2022-12-02 05:14:03,141 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
