In [1]:
import numpy as np
import xarray as xr
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False});

### Preliminaries

In [2]:
###############################
# Set paths
# UPDATE THIS FOR REPRODUCTION
###############################
# NOTE: every path below is joined to sub-paths by plain string concatenation
# (e.g. `nex_in + metric + '/*.nc'`), so the trailing '/' is required.

# Per-metric netCDF files for the NEX ensemble
nex_in = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/metrics/nex-gddp/'
# Per-metric stores for the CIL ensemble (opened with the zarr engine)
cil_in = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/metrics/cil-gdpcir/'
# Per-metric netCDF files for the ISIMIP ensemble
isi_in = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/metrics/isimip3b/regridded/conservative/'
# Carbonplan root; expected to contain 'GARD-SV/' and 'DeepSD-BC/' sub-directories
cbp_in = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/metrics/carbonplan/regridded/conservative/'

# Root directory for results; outputs are written below `out + 'cmip6/'`
out = '/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/uc_results/'

In [3]:
######################################
# Prep each ensemble for merge
######################################
def preprocess_nex(ds):
    """
    Prepare one NEX file for merging with the other ensembles.

    - longitudes greater than 180 are shifted by -360 and the dataset
      is re-sorted along 'lon', then along 'ssp'
    - scalar coordinates `ensemble='NEX'` and `model=<file name without
      its last 3 characters>` are attached (presumably stripping '.nc')
    - the 'time' coordinate is replaced by the year of each time stamp
    """
    # Last path component of the file this dataset was opened from
    file_name = ds.encoding['source'].replace(nex_in, '').split('/')[-1]

    # Wrap longitudes and restore ascending order
    lon = ds['lon']
    ds['lon'] = np.where(lon > 180, lon - 360, lon)
    ds = ds.sortby('lon').sortby('ssp')

    # Label the dataset so it can be concatenated along model / ensemble
    ds = ds.assign_coords(ensemble='NEX', model=file_name[:-3])

    # Keep only the year as the time coordinate
    ds['time'] = ds.indexes['time'].year
    return ds

def preprocess_cil(ds):
    """
    Prepare one CIL store for merging with the other ensembles.

    - latitudes are restricted to the slice [-60, 90]
    - the dataset is sorted along 'ssp'
    - scalar coordinates `ensemble='CIL'` and `model=<last path component
      of the source>` are attached (no suffix is stripped here)
    - the 'time' coordinate is replaced by the year of each time stamp
    """
    # Last path component of the store this dataset was opened from
    store_name = ds.encoding['source'].replace(cil_in, '').split('/')[-1]

    ds = ds.sel(lat=slice(-60, 90)).sortby('ssp')
    ds = ds.assign_coords(ensemble='CIL', model=store_name)

    # Keep only the year as the time coordinate
    ds['time'] = ds.indexes['time'].year
    return ds

def preprocess_isi(ds):
    """
    Prepare one ISIMIP file for merging with the other ensembles.

    - longitudes greater than 180 are shifted by -360 and the dataset
      is re-sorted along 'lon', then along 'ssp'
    - scalar coordinates `ensemble='ISIMIP'` and `model=<file name without
      its last 3 characters>` are attached (presumably stripping '.nc')
    - the 'time' coordinate is replaced by the year of each time stamp
    """
    # Last path component of the file this dataset was opened from
    file_name = ds.encoding['source'].replace(isi_in, '').split('/')[-1]

    # Wrap longitudes and restore ascending order
    lon = ds['lon']
    ds['lon'] = np.where(lon > 180, lon - 360, lon)
    ds = ds.sortby('lon').sortby('ssp')

    # Label the dataset so it can be concatenated along model / ensemble
    ds = ds.assign_coords(ensemble='ISIMIP', model=file_name[:-3])

    # Keep only the year as the time coordinate
    ds['time'] = ds.indexes['time'].year
    return ds

def preprocess_cbp_gard(ds):
    """
    Prepare one carbonplan GARD-SV file for merging with the other ensembles.

    - latitudes are restricted to the slice [-60, 90]
    - the dataset is sorted along 'ssp'
    - scalar coordinates `ensemble='GARD-SV'` and `model=<file name without
      its last 3 characters>` are attached (presumably stripping '.nc')
    - the 'time' coordinate is replaced by the year of each time stamp
    - if the file has no 'pr' variable, an all-NaN 'pr' shaped like 'tas'
      is added so every file exposes the same variables
    """
    ds = ds.sel(lat=slice(-60, 90))
    ds = ds.sortby('ssp')
    ds = ds.assign_coords(ensemble = 'GARD-SV')
    ds = ds.assign_coords(model = ds.encoding['source'].replace(cbp_in, '').split('/')[-1][:-3])
    ds['time'] = ds.indexes['time'].year
    # for some models/methods we are missing
    # precip so need to fill with NaNs
    if 'pr' not in ds.data_vars:
        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0
        ds['pr'] = xr.full_like(ds['tas'], np.nan)
    return ds

def preprocess_cbp_deep(ds):
    """
    Prepare one carbonplan DeepSD-BC file for merging with the other ensembles.

    - latitudes are restricted to the slice [-60, 90]
    - the dataset is sorted along 'ssp'
    - scalar coordinates `ensemble='DeepSD-BC'` and `model=<file name without
      its last 3 characters>` are attached (presumably stripping '.nc')
    - the 'time' coordinate is replaced by the year of each time stamp
    - if the file has no 'pr' variable, an all-NaN 'pr' shaped like 'tas'
      is added so every file exposes the same variables
    """
    ds = ds.sel(lat=slice(-60, 90))
    ds = ds.sortby('ssp')
    ds = ds.assign_coords(ensemble = 'DeepSD-BC')
    ds = ds.assign_coords(model = ds.encoding['source'].replace(cbp_in, '').split('/')[-1][:-3])
    ds['time'] = ds.indexes['time'].year
    # for some models/methods we are missing
    # precip so need to fill with NaNs
    if 'pr' not in ds.data_vars:
        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0
        ds['pr'] = xr.full_like(ds['tas'], np.nan)
    return ds

In [4]:
######################################
# Read all outputs for a given metric
######################################
def read_all(metric):
    """
    Open every ensemble for `metric` and stack them along 'ensemble'.

    Parameters
    ----------
    metric : str
        Name of the metric sub-directory (e.g. 'annual_avgs') appended to
        each of the module-level input paths.

    Returns
    -------
    xarray.Dataset
        NEX, CIL, ISIMIP, GARD-SV and DeepSD-BC (in that order) concatenated
        along a new 'ensemble' dimension. Grid cells that are null in the
        first variable of the first ensemble/ssp/time/model slice are set
        to NaN everywhere.
    """
    chunking = {'model':1, 'ssp':1, 'time':86, 'lat':300, 'lon':720}

    # Options common to every open_mfdataset call: each file is one model
    shared = dict(parallel=True, combine='nested',
                  concat_dim='model', compat='identical')

    # NEX: chunked at open time
    ds_nex = xr.open_mfdataset(nex_in + metric + '/*.nc',
                               preprocess=preprocess_nex,
                               chunks=chunking, **shared)

    # CIL: zarr stores, rechunked after opening
    ds_cil = xr.open_mfdataset(cil_in + metric + '/*',
                               engine='zarr',
                               preprocess=preprocess_cil, **shared)
    ds_cil = ds_cil.chunk(chunking)

    # ISIMIP: chunked at open time
    ds_isi = xr.open_mfdataset(isi_in + metric + '/*.nc',
                               preprocess=preprocess_isi,
                               chunks=chunking, **shared)

    # carbonplan GARD-SV: rechunked after opening
    ds_cbp_gard = xr.open_mfdataset(cbp_in + 'GARD-SV/' + metric + '/*.nc',
                                    preprocess=preprocess_cbp_gard, **shared)
    ds_cbp_gard = ds_cbp_gard.chunk(chunking)

    # carbonplan DeepSD-BC: rechunked after opening
    ds_cbp_deep = xr.open_mfdataset(cbp_in + 'DeepSD-BC/' + metric + '/*.nc',
                                    preprocess=preprocess_cbp_deep, **shared)
    ds_cbp_deep = ds_cbp_deep.chunk(chunking)

    # Stack the ensembles; missing model/ssp combinations are filled with NaN
    members = [ds_nex, ds_cil, ds_isi, ds_cbp_gard, ds_cbp_deep]
    ds = xr.concat(members, dim='ensemble', fill_value=np.nan)

    # Mask out ocean points (NEX, the first ensemble, is only available over land)
    first_var = list(ds.keys())[0]
    ds_mask = ds.isel(ensemble=0, ssp=0, time=0, model=0)[first_var].isnull()
    return xr.where(ds_mask, np.nan, ds)

In [31]:
###########################
####### HS09 method #######
###########################
def uc_hs09(ds):
    """
    Partition the variance of `ds` into model, scenario and ensemble
    (downscaling) components using the HS09 approach.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with 'model', 'ssp', 'ensemble', 'time', 'lat' and 'lon'
        dimensions. Missing model/ssp/ensemble combinations are expected
        to be NaN.

    Returns
    -------
    xarray.Dataset
        The individual components concatenated along a new 'uncertainty'
        dimension with labels 'model', 'scenario', 'ensemble',
        'total_true' (variance across everything) and 'total_sim'
        (sum of the three components).

    Notes
    -----
    The weights are counted at a single hard-coded grid cell
    (time=0, lat=300, lon=800) of the first data variable.
    NOTE(review): this assumes that cell is over land on the grid in use
    and that availability there is representative -- confirm.
    """
    #####  Model uncertainty #####
    # Variance across models, averaged over scenarios and ensembles
    U_model = ds.var(dim='model')
    # weights (choose point over land): number of models per (ssp, ensemble)
    weights = ds.isel(time=0, lat=300, lon=800)[list(ds.data_vars)[0]].count(dim='model').rename('weights')
    weights = xr.where(weights == 1, 0, weights) # remove combinations where variance was calculated over 1 entry
    U_model = U_model.weighted(weights).mean(dim=['ssp', 'ensemble'])

    ##### Scenario uncertainty #####
    # Variance across multi-model means (HS09 approach)
    U_scen = ds.mean(dim=['model', 'ensemble']).var(dim='ssp')

    ##### Downscaling uncertainty #####
    # Variance across ensembles, averaged over models and scenarios
    U_ens = ds.var(dim='ensemble')
    # weights: number of ensembles per (ssp, model)
    weights = ds.isel(time=0, lat=300, lon=800)[list(ds.data_vars)[0]].count(dim='ensemble').rename('weights')
    weights = xr.where(weights == 1, 0, weights) # remove combinations where variance was calculated over 1 entry
    U_ens = U_ens.weighted(weights).mean(dim=['ssp', 'model'])

    ##### Total uncertainty #####
    # Variance across everything
    # Note this in general will not equal
    # the sum of individual uncertainties
    U_total_true = ds.var(dim=['ensemble', 'ssp', 'model'])

    # Our 'simulated' total uncertainty
    # (was `U_end`, an undefined name that raised NameError)
    U_total_sim = U_model + U_scen + U_ens

    ##### Merge and return #####
    U_model = U_model.assign_coords(uncertainty = 'model')
    U_scen = U_scen.assign_coords(uncertainty = 'scenario')
    U_ens = U_ens.assign_coords(uncertainty = 'ensemble')
    U_total_true = U_total_true.assign_coords(uncertainty = 'total_true')
    U_total_sim = U_total_sim.assign_coords(uncertainty = 'total_sim')

    return xr.concat([U_model, U_scen, U_ens, U_total_true, U_total_sim], dim='uncertainty')

In [6]:
####################################
####### HS09 method with IAV #######
####################################
def uc_hs09_iav(ds_in, ds_rolling):
    """
    HS09 variance partitioning including interannual variability.

    Parameters
    ----------
    ds_in : xarray.Dataset
        Annual values with 'model', 'ssp', 'ensemble' and 'time' dimensions.
    ds_rolling : xarray.Dataset
        Forced response: a smoothed version of `ds_in` (the callers in this
        notebook pass a 10-year centred rolling mean restricted to 2020-2096).

    Returns
    -------
    xarray.Dataset
        Components concatenated along a new 'uncertainty' dimension with
        labels 'model', 'scenario', 'ensemble', 'variability' and 'total'.
        Every component is computed eagerly (`.compute()`).
    """
    everything = ['ensemble', 'ssp', 'model']

    # Total uncertainty including inter-annual variability
    total = ds_in.var(dim=everything).sel(time=slice(2020,2096)).compute()

    # Interannual variability: variance in time of the residual about the
    # forced response, averaged over everything (single value for all years)
    residual = ds_in - ds_rolling
    iav = residual.var(dim='time').mean(dim=everything)
    iav = iav.assign_coords(time = np.arange(2020,2097)).compute()

    ####### Forced response only from here ########

    # Model uncertainty: variance across models, averaged over scenarios and ensembles
    model = ds_rolling.var(dim='model').mean(dim=['ssp', 'ensemble']).compute()

    # Scenario uncertainty: variance across multi-model means
    scenario = ds_rolling.mean(dim=['model', 'ensemble']).var(dim='ssp').compute()

    # Downscaling uncertainty: variance across ensembles, averaged over models and scenarios
    ensemble = ds_rolling.var(dim='ensemble').mean(dim=['ssp', 'model']).compute()

    # Label each component and stack them in a fixed order
    components = [
        (model, 'model'),
        (scenario, 'scenario'),
        (ensemble, 'ensemble'),
        (iav, 'variability'),
        (total, 'total'),
    ]
    labelled = [part.assign_coords(uncertainty = label) for part, label in components]
    return xr.concat(labelled, dim='uncertainty')

In [7]:
############
# Dask
############
from dask_jobqueue import PBSCluster
from dask.distributed import Client

# Each PBS job provides a single-core worker with 50 GB of memory
# and a 30 minute wall-time limit
cluster = PBSCluster(
    cores=1,
    memory='50GB',
    resource_spec='pmem=50GB',
    project='open',
    env_extra=['#PBS -l feature=rhel7'],
    walltime='00:30:00',
)

# Ask the scheduler for 40 such jobs
cluster.scale(jobs=40)

# Connect this notebook to the cluster and display its summary
client = Client(cluster)

client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.102.201.237:42725,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Not including interannual variability

## Annual averages

In [28]:
%%time
# Read
ds = read_all('annual_avgs')

CPU times: user 4.64 s, sys: 144 ms, total: 4.78 s
Wall time: 12.3 s


In [29]:
ds.nbytes/1e9

653.875218628

In [30]:
ds.tas

Unnamed: 0,Array,Chunk
Bytes,121.79 GiB,70.86 MiB
Shape,"(600, 1440, 5, 22, 4, 86)","(300, 720, 1, 1, 1, 86)"
Count,389 Graph Layers,1760 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 121.79 GiB 70.86 MiB Shape (600, 1440, 5, 22, 4, 86) (300, 720, 1, 1, 1, 86) Count 389 Graph Layers 1760 Chunks Type float32 numpy.ndarray",5  1440  600  86  4  22,

Unnamed: 0,Array,Chunk
Bytes,121.79 GiB,70.86 MiB
Shape,"(600, 1440, 5, 22, 4, 86)","(300, 720, 1, 1, 1, 86)"
Count,389 Graph Layers,1760 Chunks
Type,float32,numpy.ndarray


In [19]:
# # Persist for faster computation
# ds = ds.persist()

In [25]:
# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':22, 'ssp':4, 'ensemble':5, 'time':86, 'lat':30, 'lon':30})

In [26]:
ds.tas

Unnamed: 0,Array,Chunk
Bytes,121.79 GiB,129.91 MiB
Shape,"(600, 1440, 5, 22, 4, 86)","(30, 30, 5, 22, 4, 86)"
Count,396 Graph Layers,960 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 121.79 GiB 129.91 MiB Shape (600, 1440, 5, 22, 4, 86) (30, 30, 5, 22, 4, 86) Count 396 Graph Layers 960 Chunks Type float32 numpy.ndarray",5  1440  600  86  4  22,

Unnamed: 0,Array,Chunk
Bytes,121.79 GiB,129.91 MiB
Shape,"(600, 1440, 5, 22, 4, 86)","(30, 30, 5, 22, 4, 86)"
Count,396 Graph Layers,960 Chunks
Type,float32,numpy.ndarray


In [4]:
import dask.distributed

In [6]:
dask.distributed.__version__

'2022.8.1'

In [None]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_avgs_HS09.nc')

Task exception was never retrieved
future: <Task finished name='Task-694272' coro=<Client._gather.<locals>.wait() done, defined at /storage/home/d/dcl5300/work/ENVS/climate-stack/lib/python3.10/site-packages/distributed/client.py:2038> exception=AllExit()>
Traceback (most recent call last):
  File "/storage/home/d/dcl5300/work/ENVS/climate-stack/lib/python3.10/site-packages/distributed/client.py", line 2047, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-694167' coro=<Client._gather.<locals>.wait() done, defined at /storage/home/d/dcl5300/work/ENVS/climate-stack/lib/python3.10/site-packages/distributed/client.py:2038> exception=AllExit()>
Traceback (most recent call last):
  File "/storage/home/d/dcl5300/work/ENVS/climate-stack/lib/python3.10/site-packages/distributed/client.py", line 2047, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Ta

In [12]:
# clear
client.cancel(ds)

In [None]:
%%time
# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':22, 'ssp':4, 'ensemble':5, 'time':86, 'lat':40, 'lon':40})

In [None]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_avgs_HS09_rechunk.nc')

## Annual maxs

In [13]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('annual_maxs')

# Persist for faster computation
ds = ds.persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [14]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_maxs_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/annual_maxs_K19.nc')

CPU times: user 2min 54s, sys: 16.5 s, total: 3min 10s
Wall time: 6min 34s


In [15]:
# clear
client.cancel(ds)

## Annual mins

In [14]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('annual_mins')

# Persist for faster computation
ds = ds.persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [16]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/annual_mins_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/annual_mins_K19.nc')

CPU times: user 1min 30s, sys: 7.32 s, total: 1min 37s
Wall time: 3min 14s


In [17]:
# clear
client.cancel(ds)

## Precip inds

In [40]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('precip_inds')

# Persist for faster computation
ds = ds.persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [41]:
%%time
# HS09 method
ds_out = uc_hs09(ds)
ds_out.to_netcdf(out + 'cmip6/precip_inds_HS09.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/precip_inds_K19.nc')

CPU times: user 2min 43s, sys: 18.8 s, total: 3min 2s
Wall time: 7min 16s


In [42]:
# clear
client.cancel(ds)

# Including interannual variability

## Annual averages

In [11]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('annual_avgs')

# Persist for faster computation
ds = ds.persist()
# Forced response: 10-year centred rolling mean, restricted to 2020-2096
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096)).persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds_rolling = ds_rolling.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [12]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/annual_avgs_HS09_iav.nc')

CPU times: user 4min 41s, sys: 28.8 s, total: 5min 10s
Wall time: 10min 32s


In [None]:
# clear
client.cancel(ds)
client.cancel(ds_rolling)

## Annual maxs

In [14]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('annual_maxs')

# Persist for faster computation
ds = ds.persist()
# Forced response: 10-year centred rolling mean, restricted to 2020-2096
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096)).persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})
ds_rolling = ds_rolling.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [None]:
%%time
# HS09 method
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/annual_maxs_HS09_iav.nc')

In [None]:
# clear
client.cancel(ds)
client.cancel(ds_rolling)

## Annual mins

In [None]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('annual_mins')

# Persist for faster computation
ds = ds.persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [None]:
%%time
# HS09 method
# uc_hs09_iav needs the forced response as its second argument; derive it from
# the current `ds` here so a stale `ds_rolling` from another metric is never reused
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096))
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/annual_mins_HS09_iav.nc')

In [None]:
# clear
client.cancel(ds)

## Precip inds

In [18]:
# Read
# read_all takes only the metric name; the input directories are module-level globals
ds = read_all('precip_inds')

# Persist for faster computation
ds = ds.persist()

# rechunk for SA calculations which involve acting along model, ssp, ensemble dimensions
ds = ds.chunk({'model':13, 'ssp':4, 'ensemble':2, 'time':43, 'lat':60, 'lon':144})

In [None]:
%%time
# HS09 method
# uc_hs09_iav needs the forced response as its second argument; derive it from
# the current `ds` here so a stale `ds_rolling` from another metric is never reused
ds_rolling = ds.rolling(time=10, center=True).mean().sel(time=slice(2020,2096))
ds_out = uc_hs09_iav(ds, ds_rolling)
ds_out.to_netcdf(out + 'cmip6/precip_inds_HS09_iav.nc')

# # K19
# ds_out = uc_k19(ds)
# ds_out.to_netcdf(out + 'cmip6/precip_inds_K19.nc')

In [None]:
# clear
client.cancel(ds)