In [1]:
import os

import numpy as np
import xarray as xr

### Preliminaries

In [2]:
# ---------------------------------------------------------------
# Input / output locations -- UPDATE THIS FOR REPRODUCTION.
# All three paths hang off one root directory, so normally only
# `_data_root` needs to change.
# ---------------------------------------------------------------
_data_root = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/'

nex_in = _data_root + 'cmip6/nex-gddp/'    # NEX ensemble inputs
cil_in = _data_root + 'cmip6/cil-gdpcir/'  # CIL ensemble inputs
out = _data_root + 'uc_results/'           # uncertainty-partition outputs

In [3]:
# ---------------------------------------------------------------
# GCMs available in each downscaled ensemble
# ---------------------------------------------------------------

# NEX models with all SSPs and variables (tas, pr)
complete_nex_models = [
    'ACCESS-CM2', 'ACCESS-ESM1-5', 'CanESM5', 'CMCC-ESM2',
    'CNRM-CM6-1', 'CNRM-ESM2-1', 'EC-Earth3', 'EC-Earth3-Veg-LR',
    'FGOALS-g3', 'GFDL-CM4', 'GFDL-ESM4', 'GISS-E2-1-G',
    'INM-CM4-8', 'INM-CM5-0', 'IPSL-CM6A-LR', 'KACE-1-0-G',
    'MIROC-ES2L', 'MIROC6', 'MPI-ESM1-2-HR', 'MPI-ESM1-2-LR',
    'MRI-ESM2-0', 'NorESM2-LM', 'NorESM2-MM', 'TaiESM1',
    'UKESM1-0-LL',
]

# CIL models with all SSPs and variables
complete_cil_models = [
    'INM-CM4-8', 'INM-CM5-0', 'BCC-CSM2-MR', 'CMCC-CM2-SR5',
    'CMCC-ESM2', 'MIROC-ES2L', 'MIROC6', 'UKESM1-0-LL',
    'MPI-ESM1-2-LR', 'NorESM2-LM', 'NorESM2-MM', 'GFDL-ESM4',
    'EC-Earth3', 'EC-Earth3-Veg-LR', 'EC-Earth3-Veg', 'CanESM5',
]

# Models common to both ensembles, as a sorted, de-duplicated numpy array
models = np.intersect1d(complete_cil_models, complete_nex_models)

In [4]:
######################################
# Prep each ensemble for merge
######################################
def preprocess_nex(ds):
    """Harmonise one NEX file before it is stacked along ``model``.

    Used as the ``preprocess=`` hook of ``xr.open_mfdataset`` in ``read_all``,
    so it runs once per input file. It:

    * moves longitudes above 180 into the negative half (``lon - 360``) and
      sorts by longitude;
    * sorts by ``ssp``;
    * adds scalar coordinates ``ensemble='NEX'`` and ``model=<file stem>``;
    * replaces the ``time`` coordinate with integer calendar years.

    Parameters
    ----------
    ds : xr.Dataset
        One per-model file; ``ds.encoding['source']`` is its path.

    Returns
    -------
    xr.Dataset
    """
    # Wrap longitudes > 180 into negative values so the grid runs -180..180
    ds['lon'] = np.where(ds['lon'] > 180, ds['lon'] - 360, ds['lon'])
    ds = ds.sortby('lon')
    ds = ds.sortby('ssp')
    ds = ds.assign_coords(ensemble='NEX')
    # Model name = file name without its extension. Taken from the basename
    # instead of fixed character offsets so it does not depend on `nex_in`.
    source = ds.encoding['source']
    ds = ds.assign_coords(model=os.path.splitext(os.path.basename(source))[0])
    # Label time by integer calendar year
    ds['time'] = ds.indexes['time'].year
    return ds

def preprocess_cil(ds):
    """Harmonise one CIL file before it is stacked along ``model``.

    Used as the ``preprocess=`` hook of ``xr.open_mfdataset`` in ``read_all``,
    so it runs once per input file. It:

    * renames ``tasavg`` to ``tas`` (when present) to match the NEX variable name;
    * keeps latitudes -60..90 (presumably the NEX coverage -- confirm; assumes
      ``lat`` is ascending);
    * adds scalar coordinates ``ensemble='CIL'`` and ``model=<file stem>``;
    * sorts by ``ssp`` and replaces ``time`` with integer calendar years.

    Parameters
    ----------
    ds : xr.Dataset
        One per-model file; ``ds.encoding['source']`` is its path.

    Returns
    -------
    xr.Dataset
    """
    if 'tasavg' in ds.data_vars:
        ds = ds.rename({'tasavg': 'tas'})
    ds = ds.sel(lat=slice(-60, 90))
    ds = ds.assign_coords(ensemble='CIL')
    ds = ds.sortby('ssp')
    # Model name = file name without its extension. Taken from the basename
    # instead of fixed character offsets so it does not depend on `cil_in`.
    source = ds.encoding['source']
    ds = ds.assign_coords(model=os.path.splitext(os.path.basename(source))[0])
    # Label time by integer calendar year
    ds['time'] = ds.indexes['time'].year
    return ds

In [5]:
######################################
# Read all outputs for a given metric
######################################
def read_all(nex_in, cil_in, metric):
    """Load one metric from both downscaled ensembles and merge them.

    Parameters
    ----------
    nex_in, cil_in : str
        Root directories of the NEX and CIL outputs. Every file matching
        ``<root>/<metric>/*`` is opened and stacked along ``model``.
    metric : str
        Metric sub-directory: 'annual_avgs', 'annual_maxs', 'annual_mins'
        or 'precip_inds'.

    Returns
    -------
    xr.Dataset
        Lazy (dask-backed) dataset with an ``ensemble`` dimension
        (index 0 = NEX, index 1 = CIL). Grid points that are missing in the
        first variable of the first ensemble/ssp/time/model slice are set to
        NaN everywhere.
    """
    def _open_ensemble(root, preprocess):
        # Open every per-model file of one ensemble and stack along 'model'.
        return xr.open_mfdataset(os.path.join(root, metric, '*'),
                                 parallel=True, preprocess=preprocess,
                                 combine='nested', concat_dim='model',
                                 chunks={'model':1, 'ssp':1, 'time':83, 'lat':600, 'lon':1440},
                                 compat='identical')

    ds_nex = _open_ensemble(nex_in, preprocess_nex)
    ds_cil = _open_ensemble(cil_in, preprocess_cil)

    # CIL temperatures are in Kelvin; precipitation indices need no conversion
    if metric != 'precip_inds':
        ds_cil = convert_K_to_C(ds_cil)

    # We did not calculate tasavg annual maxs/mins in CIL, so drop NEX 'tas'
    if metric in ('annual_maxs', 'annual_mins'):
        ds_nex = ds_nex.drop_vars('tas')

    # Merge along a new 'ensemble' dimension (NEX first, then CIL)
    ds = xr.concat([ds_nex, ds_cil], dim='ensemble', compat='equals')

    # Mask out ocean points (NEX is only available over land): use the NaN
    # pattern of the first variable in the first (NEX) slice as the mask.
    first_var = list(ds.keys())[0]
    ds_mask = ds.isel(ensemble=0, ssp=0, time=0, model=0)[first_var].isnull()
    ds = xr.where(ds_mask, np.nan, ds)

    return ds

In [6]:
# NOTE(review): disabled K19 implementations, kept for reference only.
# Before re-enabling uc_k19_iav, check its labelling: U_model is tagged
# 'scenario', U_scen 'model', U_ens 'variability' and U_iav 'ensemble', and its
# "Scenario uncertainty" term is a variance across models (dim='model').
# ###########################
# ####### K19 method ########
# ###########################
# def uc_k19(ds):
#     # Scenario uncertainty
#     U_scen = ds.var(dim='ssp').mean(dim=['model', 'ensemble']).compute()

#     # Model uncertainty
#     U_scen_model = ds.var(dim=['ssp', 'model']).mean(dim='ensemble').compute()
#     U_model = U_scen_model - U_scen

#     # Downscaling uncertainy
#     U_scen_model_ens = ds.var(dim=['ensemble', 'ssp', 'model']).compute()
#     U_ens = U_scen_model_ens - U_scen_model
    
#     # Merge and return
#     U_model = U_model.assign_coords(uncertainty = 'model')
#     U_scen = U_scen.assign_coords(uncertainty = 'scenario')
#     U_ens = U_ens.assign_coords(uncertainty = 'ensemble')
#     U_scen_model_ens = U_scen_model_ens.assign_coords(uncertainty = 'total')
    
#     return xr.concat([U_model, U_scen, U_ens, U_scen_model_ens], dim='uncertainty')

# ####################################
# ####### K19 method with IAV ########
# ####################################
# def uc_k19_iav(ds_in):
#     # Get rolling mean
#     ds_rolling = ds_in.rolling(time=10, center=True).mean().dropna('time')
#     ds_rolling = ds_rolling.assign_coords(iav = 'No')
    
#     ds = xr.concat([ds_rolling, ds_in.assign_coords(iav = 'Yes')], dim='iav').dropna('time')

#     # Scenario uncertainty
#     U_scen = ds.var(dim='model').mean(dim=['ssp', 'ensemble', 'iav']).compute()

#     # Model uncertainty
#     U_scen_model = ds.var(dim=['ssp', 'model']).mean(dim=['ensemble', 'iav']).compute()
#     U_model = U_scen_model - U_scen

#     # Downscaling uncertainy
#     U_scen_model_ens = ds.var(dim=['iav', 'ssp', 'model']).mean(dim='ensemble').compute()
#     U_ens = U_scen_model_ens - U_scen_model
                                                          
#     # Interannual variability
#     U_scen_model_ens_iav = ds.var(dim=['ensemble', 'ssp', 'model', 'iav'])
#     U_iav = U_scen_model_ens_iav - U_scen_model_ens
    
#     # Merge and return
#     U_model = U_model.assign_coords(uncertainty = 'scenario')
#     U_scen = U_scen.assign_coords(uncertainty = 'model')
#     U_ens = U_ens.assign_coords(uncertainty = 'variability')
#     U_iav = U_iav.assign_coords(uncertainty = 'ensemble')
#     U_scen_model_ens_iav = U_scen_model_ens_iav.assign_coords(uncertainty = 'total')
    
#     return xr.concat([U_model, U_scen, U_ens, U_iav, U_scen_model_ens_iav], dim='uncertainty')

In [7]:
###########################
####### HS09 method #######
###########################
def uc_hs09(ds):
    """Partition the spread of ``ds`` into model, scenario and downscaling parts.

    Each component is evaluated eagerly (``.compute()``) and stacked along a
    new ``uncertainty`` dimension, in this order:

    * 'model'    -- variance across models, averaged over SSPs and ensembles
    * 'scenario' -- variance across SSPs of the model/ensemble mean
    * 'ensemble' -- variance across downscaling ensembles, averaged over SSPs and models
    * 'total'    -- joint variance across ensembles, SSPs and models

    Parameters
    ----------
    ds : xr.Dataset
        Must carry 'model', 'ssp' and 'ensemble' dimensions (as built by ``read_all``).

    Returns
    -------
    xr.Dataset
    """
    components = [
        ('model', ds.var(dim='model').mean(dim=['ssp', 'ensemble'])),
        ('scenario', ds.mean(dim=['model', 'ensemble']).var(dim='ssp')),
        ('ensemble', ds.var(dim='ensemble').mean(dim=['ssp', 'model'])),
        ('total', ds.var(dim=['ensemble', 'ssp', 'model'])),
    ]
    # Alternative 'LEE' scenario term, not used:
    #   ds.var(dim=['model', 'ensemble']).mean(dim='ssp')

    labelled = [
        part.compute().assign_coords(uncertainty=label)
        for label, part in components
    ]
    return xr.concat(labelled, dim='uncertainty')

In [15]:
####################################
####### HS09 method with IAV #######
####################################
def uc_hs09_iav(ds_in, ds_rolling=None, window=10, start=2020, end=2096):
    """HS09-style partition that also separates inter-annual variability (IAV).

    Parameters
    ----------
    ds_in : xr.Dataset
        Annual data with 'time' (integer years), 'model', 'ssp' and 'ensemble' dims.
    ds_rolling : xr.Dataset, optional
        Smoothed ("forced response") version of ``ds_in`` over the analysis
        years. If omitted, it is computed here as the centred ``window``-year
        rolling mean of ``ds_in`` restricted to ``start``..``end``. This is the
        same recipe the notebook uses when it passes ``ds_rolling`` explicitly.
    window : int
        Rolling-window length in years; used only when ``ds_rolling`` is None.
    start, end : int
        First and last year (inclusive) of the analysis period.

    Returns
    -------
    xr.Dataset
        Components stacked along 'uncertainty', in this order:
        'model', 'scenario', 'ensemble', 'variability', 'total'.
    """
    if ds_rolling is None:
        ds_rolling = ds_in.rolling(time=window, center=True).mean().sel(time=slice(start, end))

    # Total uncertainty including inter-annual variability
    U_total = ds_in.var(dim=['ensemble', 'ssp', 'model']).sel(time=slice(start, end)).compute()

    # Interannual variability: variance over time of the residual about the
    # forced response, averaged over members -- a single value per grid point.
    # It gets the same year coordinate (and dtype) as U_total so that it
    # broadcasts over those years in the concat below.
    U_iav = (ds_in - ds_rolling).var(dim='time').mean(dim=['ensemble', 'ssp', 'model'])
    U_iav = U_iav.assign_coords(time=U_total['time'].values).compute()

    ####### Work with forced response from here ########

    # Model uncertainty: variance across models, averaged over scenarios and ensembles
    U_model = ds_rolling.var(dim='model').mean(dim=['ssp', 'ensemble']).compute()

    # Scenario uncertainty: variance across multi-model means
    U_scen = ds_rolling.mean(dim=['model', 'ensemble']).var(dim='ssp').compute()

    # Downscaling uncertainty: variance across ensembles, averaged over models and scenarios
    U_ens = ds_rolling.var(dim='ensemble').mean(dim=['ssp', 'model']).compute()

    # Merge and return
    U_model = U_model.assign_coords(uncertainty = 'model')
    U_scen = U_scen.assign_coords(uncertainty = 'scenario')
    U_ens = U_ens.assign_coords(uncertainty = 'ensemble')
    U_iav = U_iav.assign_coords(uncertainty = 'variability')
    U_total = U_total.assign_coords(uncertainty = 'total')

    return xr.concat([U_model, U_scen, U_ens, U_iav, U_total], dim='uncertainty')

In [9]:
# Needed for CIL
def convert_K_to_C(ds, offset=273.15):
    """Return a copy of ``ds`` with its temperature variables converted K -> degC.

    'tasmax' and 'tasmin' are required (KeyError if either is missing). 'tas'
    is converted only when present. All other variables are passed through
    unchanged. The caller's dataset is not modified.

    Parameters
    ----------
    ds : xr.Dataset
    offset : float, default 273.15
        Value subtracted from each temperature variable (0 degC in Kelvin).

    Returns
    -------
    xr.Dataset
    """
    temp_vars = ['tasmax', 'tasmin']
    if 'tas' in ds.data_vars:
        temp_vars.append('tas')
    # assign() builds a new Dataset rather than writing into the caller's object
    return ds.assign({name: ds[name] - offset for name in temp_vars})

In [10]:
############
# Dask
############
from dask.distributed import Client
from dask_jobqueue import PBSCluster

# PBS-backed cluster of single-core jobs with 50 GB each and a 30-minute
# walltime; 30 such jobs are requested.
cluster = PBSCluster(
    cores=1,
    memory='50GB',
    resource_spec='pmem=50GB',
    project='open',
    env_extra=['#PBS -l feature=rhel7'],
    walltime='00:30:00',
)
cluster.scale(jobs=30)

# Connect this notebook to the cluster; the bare name renders its summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.102.201.236:33146,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Not including interannual variability

## Annual averages

In [10]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result. Persisting before
# rechunking would leave a lazy rechunk on top of the persisted data that every
# .compute() inside uc_hs09 would redo.
ds = read_all(nex_in, cil_in, 'annual_avgs')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [11]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_avgs_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/annual_avgs_K19.nc')

CPU times: user 3min 48s, sys: 18.4 s, total: 4min 6s
Wall time: 7min 39s


In [12]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
# before the next metric is loaded
client.cancel(ds)

## Annual maxs

In [13]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result. Persisting before
# rechunking would leave a lazy rechunk on top of the persisted data that every
# .compute() inside uc_hs09 would redo.
ds = read_all(nex_in, cil_in, 'annual_maxs')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [14]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_maxs_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/annual_maxs_K19.nc')

CPU times: user 2min 54s, sys: 16.5 s, total: 3min 10s
Wall time: 6min 34s


In [15]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
# before the next metric is loaded
client.cancel(ds)

## Annual mins

In [14]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result. Persisting before
# rechunking would leave a lazy rechunk on top of the persisted data that every
# .compute() inside uc_hs09 would redo.
ds = read_all(nex_in, cil_in, 'annual_mins')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [16]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_mins_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/annual_mins_K19.nc')

CPU times: user 1min 30s, sys: 7.32 s, total: 1min 37s
Wall time: 3min 14s


In [17]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
# before the next metric is loaded
client.cancel(ds)

## Precip inds

In [40]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result. Persisting before
# rechunking would leave a lazy rechunk on top of the persisted data that every
# .compute() inside uc_hs09 would redo.
ds = read_all(nex_in, cil_in, 'precip_inds')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [41]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/precip_inds_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/precip_inds_K19.nc')

CPU times: user 2min 43s, sys: 18.8 s, total: 3min 2s
Wall time: 7min 16s


In [42]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
# before the next metric is loaded
client.cancel(ds)

# Including interannual variability

## Annual averages

In [11]:
# Read, rechunk for the SA calculations (reductions along model, ssp and
# ensemble) and persist the rechunked data, so the repeated .compute() calls
# in uc_hs09_iav do not redo the rechunk each time.
ds = read_all(nex_in, cil_in, 'annual_avgs')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144}).persist()

# Forced response: centred 10-year rolling mean, kept for 2020-2096
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096))
ds_rolling = ds_rolling.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144}).persist()

In [12]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/annual_avgs_HS09_iav.nc')

CPU times: user 4min 41s, sys: 28.8 s, total: 5min 10s
Wall time: 10min 32s


In [None]:
# clear: cancel the futures backing the persisted `ds` and `ds_rolling` to
# free worker memory before the next metric is loaded
client.cancel(ds)
client.cancel(ds_rolling)

## Annual maxs

In [14]:
# Read, rechunk for the SA calculations (reductions along model, ssp and
# ensemble) and persist the rechunked data, so the repeated .compute() calls
# in uc_hs09_iav do not redo the rechunk each time.
ds = read_all(nex_in, cil_in, 'annual_maxs')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144}).persist()

# Forced response: centred 10-year rolling mean, kept for 2020-2096
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096))
ds_rolling = ds_rolling.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144}).persist()

In [None]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/annual_maxs_HS09_iav.nc')

In [None]:
# clear: cancel the futures backing the persisted `ds` and `ds_rolling` to
# free worker memory before the next metric is loaded
client.cancel(ds)
client.cancel(ds_rolling)

## Annual mins

In [None]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result, so the repeated
# .compute() calls downstream do not redo the rechunk each time.
ds = read_all(nex_in, cil_in, 'annual_mins')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [None]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds)
ds_out.to_netcdf(out + 'cmip6/annual_mins_HS09_iav.nc')

In [None]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
# before the next metric is loaded
client.cancel(ds)

## Precip inds

In [18]:
# Read, then rechunk for the SA calculations (they reduce along the model, ssp
# and ensemble dimensions) and persist the rechunked result, so the repeated
# .compute() calls downstream do not redo the rechunk each time.
ds = read_all(nex_in, cil_in, 'precip_inds')
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds = ds.persist()

In [None]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds)
ds_out.to_netcdf(out + 'cmip6/precip_inds_HS09_iav.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/precip_inds_K19.nc')

In [None]:
# clear: cancel the futures backing the persisted `ds` to free worker memory
client.cancel(ds)