In [1]:
import os

import numpy as np
import xarray as xr

## Preliminaries

In [2]:
###############################
# Set paths
# UPDATE THIS FOR REPRODUCTION
###############################
# Shared root of the downscaled CMIP6 data; each downscaling ensemble lives in
# its own sub-directory (trailing slashes are kept on every path).
DATA_ROOT = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/cmip6/'
nex_in = DATA_ROOT + 'nex-gddp/'    # NEX-GDDP-CMIP6 ensemble
cil_in = DATA_ROOT + 'cil-gdpcir/'  # CIL-GDPCIR ensemble

In [3]:
###################
# Models
###################

# NEX-GDDP models that provide every SSP and both variables (tas, pr)
complete_nex_models = [
    'ACCESS-CM2', 'ACCESS-ESM1-5', 'CanESM5', 'CMCC-ESM2',
    'CNRM-CM6-1', 'CNRM-ESM2-1', 'EC-Earth3',
    'EC-Earth3-Veg-LR', 'FGOALS-g3', 'GFDL-CM4', 'GFDL-ESM4',
    'GISS-E2-1-G', 'INM-CM4-8', 'INM-CM5-0',
    'IPSL-CM6A-LR', 'KACE-1-0-G', 'MIROC-ES2L', 'MIROC6',
    'MPI-ESM1-2-HR', 'MPI-ESM1-2-LR', 'MRI-ESM2-0', 'NorESM2-LM',
    'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL',
]

# CIL-GDPCIR models that provide every SSP and variable
complete_cil_models = [
    'INM-CM4-8', 'INM-CM5-0', 'BCC-CSM2-MR', 'CMCC-CM2-SR5',
    'CMCC-ESM2', 'MIROC-ES2L', 'MIROC6', 'UKESM1-0-LL', 'MPI-ESM1-2-LR',
    'NorESM2-LM', 'NorESM2-MM', 'GFDL-ESM4', 'EC-Earth3',
    'EC-Earth3-Veg-LR', 'EC-Earth3-Veg', 'CanESM5',
]

# Models available in both ensembles, as a sorted, de-duplicated numpy array
# (np.intersect1d is symmetric in its arguments).
models = np.intersect1d(complete_nex_models, complete_cil_models)

In [4]:
######################################
# Prep each ensemble for merge
######################################
def preprocess_nex(ds):
    """Normalise one NEX-GDDP file before ``xr.open_mfdataset`` combines it.

    - Converts longitudes above 180 to the -180..180 convention and re-sorts.
    - Sorts by ``ssp`` and tags the data with ``ensemble='NEX'``.
    - Sets the ``model`` coordinate to the file name without its extension.
    - Replaces the ``time`` coordinate with its calendar year.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset for a single file; ``ds.encoding['source']`` is its path.

    Returns
    -------
    xarray.Dataset
    """
    ds['lon'] = np.where(ds['lon'] > 180, ds['lon'] - 360, ds['lon'])
    ds = ds.sortby('lon')
    ds = ds.sortby('ssp')
    ds = ds.assign_coords(ensemble='NEX')
    # Take the model name from the file name itself. The previous fixed slice
    # ([93:-3]) only worked for one data root and 11-character metric names,
    # and silently mislabelled models for any other path.
    source = ds.encoding['source']
    ds = ds.assign_coords(model=os.path.splitext(os.path.basename(source))[0])
    # assumes a datetime-like time index (DatetimeIndex/CFTimeIndex expose .year)
    ds['time'] = ds.indexes['time'].year
    return ds

def preprocess_cil(ds):
    """Normalise one CIL-GDPCIR file before ``xr.open_mfdataset`` combines it.

    - Renames ``tasavg`` to ``tas`` to match the NEX naming.
    - Keeps latitudes between -60 and 90 only.
    - Tags the data with ``ensemble='CIL'`` and sorts by ``ssp``.
    - Sets the ``model`` coordinate to the file name without its extension.
    - Replaces the ``time`` coordinate with its calendar year.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset for a single file; ``ds.encoding['source']`` is its path.

    Returns
    -------
    xarray.Dataset
    """
    ds = ds.rename({'tasavg': 'tas'})
    # NOTE(review): presumably trims the domain to match NEX-GDDP coverage — confirm
    ds = ds.sel(lat=slice(-60, 90))
    ds = ds.assign_coords(ensemble='CIL')
    ds = ds.sortby('ssp')
    # Take the model name from the file name itself. The previous fixed slice
    # ([95:-3]) only worked for one data root and 11-character metric names,
    # and silently mislabelled models for any other path.
    source = ds.encoding['source']
    ds = ds.assign_coords(model=os.path.splitext(os.path.basename(source))[0])
    # assumes a datetime-like time index (DatetimeIndex/CFTimeIndex expose .year)
    ds['time'] = ds.indexes['time'].year
    return ds

In [5]:
############
# Dask
############
# Start a PBS-backed dask cluster and attach a client to it.
from dask.distributed import Client
from dask_jobqueue import PBSCluster

cluster = PBSCluster(
    cores=1,
    memory='20GB',
    resource_spec='pmem=20GB',
    project='open',
    # NOTE(review): these lines are placed in the job script after the
    # generated #PBS header, which is presumably why a #PBS directive works
    # here. Newer dask-jobqueue releases rename this argument
    # (job_script_prologue / job_extra_directives) — confirm against the
    # installed version.
    env_extra=['#PBS -l feature=rhel7'],
    walltime='00:20:00',
)

# Request 20 PBS jobs for workers.
cluster.scale(jobs=20)

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.102.201.228:43499,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [28]:
###########################
####### HS09 method #######
###########################
def uc_hs09(ds):
    """Hawkins & Sutton (2009)-style variance decomposition.

    Parameters
    ----------
    ds : xarray.Dataset
        Must carry ``model``, ``ssp`` and ``ensemble`` dimensions.

    Returns
    -------
    xarray.Dataset
        The four components stacked along a new ``uncertainty`` dimension,
        in the order ``model``, ``scenario``, ``ensemble``, ``total``:

        - model:    variance across models, averaged over SSPs and ensembles
        - scenario: variance across SSPs of the multi-model, multi-ensemble mean
        - ensemble: variance across downscaling ensembles, averaged over SSPs
          and models
        - total:    variance over all three dimensions jointly

    (An alternative scenario term, ``ds.var(['model', 'ensemble']).mean('ssp')``,
    was considered and left unused.)
    """
    # Keep the input in (distributed) memory so each reduction reuses it.
    held = ds.persist()

    lazy_terms = {
        'model': held.var(dim='model').mean(dim=['ssp', 'ensemble']),
        'scenario': held.mean(dim=['model', 'ensemble']).var(dim='ssp'),
        'ensemble': held.var(dim='ensemble').mean(dim=['ssp', 'model']),
        'total': held.var(dim=['ensemble', 'ssp', 'model']),
    }

    labelled = [
        term.compute().assign_coords(uncertainty=label)
        for label, term in lazy_terms.items()
    ]
    return xr.concat(labelled, dim='uncertainty')

In [27]:
###########################
####### K19 method ########
###########################
def uc_k19(ds):
    """Sequential (Kim et al. 2019-style) variance decomposition.

    Each component is the increase in variance obtained by pooling one more
    dimension:

    - scenario: variance across SSPs, averaged over models and ensembles
    - model:    variance across (SSP, model), averaged over ensembles, minus
      the scenario term
    - ensemble: variance across (ensemble, SSP, model) minus the (SSP, model)
      variance
    - total:    variance across (ensemble, SSP, model)

    Parameters
    ----------
    ds : xarray.Dataset
        Must carry ``model``, ``ssp`` and ``ensemble`` dimensions.

    Returns
    -------
    xarray.Dataset
        Components stacked along a new ``uncertainty`` dimension in the order
        ``model``, ``scenario``, ``ensemble``, ``total``.
    """
    # Keep the input in (distributed) memory so each reduction reuses it.
    # Bug fix: the reductions below previously read the global `ds_test`
    # instead of this argument, so the input was ignored.
    ds = ds.persist()

    # Scenario uncertainty
    U_scen = ds.var(dim='ssp').mean(dim=['model', 'ensemble']).compute()

    # Model uncertainty
    U_scen_model = ds.var(dim=['ssp', 'model']).mean(dim='ensemble').compute()
    U_model = U_scen_model - U_scen

    # Downscaling uncertainty
    U_scen_model_ens = ds.var(dim=['ensemble', 'ssp', 'model']).compute()
    U_ens = U_scen_model_ens - U_scen_model

    # Label and merge
    U_model = U_model.assign_coords(uncertainty='model')
    U_scen = U_scen.assign_coords(uncertainty='scenario')
    U_ens = U_ens.assign_coords(uncertainty='ensemble')
    U_scen_model_ens = U_scen_model_ens.assign_coords(uncertainty='total')

    return xr.concat([U_model, U_scen, U_ens, U_scen_model_ens], dim='uncertainty')

In [26]:
# function for performing uncertainty analysis
def perform_uc(nex_in, cil_in, metric, generation, method, store, models=None):
    """Merge the NEX and CIL ensembles for one metric and decompose its uncertainty.

    Parameters
    ----------
    nex_in, cil_in : str
        Root directories of the NEX-GDDP and CIL-GDPCIR data. Every file in
        ``<root>/<metric>/`` is opened and concatenated along ``model``.
    metric : str
        Sub-directory to read and output-file prefix, e.g. ``'annual_avgs'``.
    generation : str
        Label for the output directory ``../data/<generation>_results/``.
    method : {'HS09', 'K19'}
        ``'HS09'`` applies ``uc_hs09`` (Hawkins & Sutton 2009);
        ``'K19'`` applies ``uc_k19`` (Kim et al. 2019).
    store : bool
        If True, write the result to
        ``../data/<generation>_results/<metric>_<method>.nc`` and return None;
        otherwise return the result.
    models : iterable of str, optional
        If given, keep only these models in both ensembles before merging.
        NOTE(review): both decompositions reduce over a shared ``model``
        dimension; models present in only one ensemble are presumably
        NaN-filled by the concat — consider passing the models common to both.

    Returns
    -------
    xarray.Dataset or None
        The decomposition (with an ``uncertainty`` dimension) when ``store``
        is False, else None.

    Raises
    ------
    ValueError
        If ``method`` is not recognised. Checked before any data is opened.
    """
    valid_methods = ('HS09', 'K19')
    if method not in valid_methods:
        # Fail fast: an unknown method used to surface only as an
        # UnboundLocalError after both ensembles had been loaded.
        raise ValueError(f"method must be one of {valid_methods}, got {method!r}")

    open_kwargs = dict(combine='nested', concat_dim='model',
                       chunks='auto', compat='identical')

    # read NEX
    ds_nex = xr.open_mfdataset(os.path.join(nex_in, metric, '*'),
                               preprocess=preprocess_nex, **open_kwargs)

    # read CIL
    ds_cil = xr.open_mfdataset(os.path.join(cil_in, metric, '*'),
                               preprocess=preprocess_cil, **open_kwargs)

    # CIL temperatures: K -> C. Skip variables this metric does not contain
    # instead of raising KeyError.
    for var in ('tas', 'tasmax', 'tasmin'):
        if var in ds_cil:
            ds_cil[var] = ds_cil[var] - 273.15

    if models is not None:
        models = list(models)
        ds_nex = ds_nex.sel(model=models)
        ds_cil = ds_cil.sel(model=models)

    # merge all
    ds = xr.concat([ds_nex, ds_cil], dim='ensemble', compat='equals')

    # calculate UC metrics
    if method == 'HS09':
        ds_out = uc_hs09(ds)
    else:  # 'K19' (validated above)
        ds_out = uc_k19(ds)

    # store or return
    if store:
        out_path = os.path.join('../data', generation + '_results',
                                metric + '_' + method + '.nc')
        ds_out.to_netcdf(out_path)
        return None
    return ds_out

## Analysis

In [None]:
%%time
ds_out = perform_uc(nex_in, cil_in, 'annual_avgs', 'cmip6', 'HS09', False)