## Generating Timeseries netCDF Files

This notebook uses `intake-esm` to read in the CESM1 runs (from CMIP5) and CESM2 runs (from CMIP6).
For each CMIP experiment and each variable of interest, a netCDF file is constructed
containing the time series of annual global means for the entire time period of the experiment.

### This notebook uses several Python packages

The `watermark` package reports the version of each imported package so that others can recreate this environment.

In [1]:
import os
import time # Want finer control than %time allows

import xarray as xr
import numpy as np
import esmlab

import intake
import intake_esm
import ncar_jobqueue
from dask.distributed import Client

import xpersist as xp
# Set up xpersist cache
cache_dir = os.path.join(os.path.sep, 'glade', 'p', 'cgd', 'oce', 'projects', 'cesm2-marbl', 'xpersist_cache')
if (os.path.isdir(cache_dir)):
    xp.settings['cache_dir'] = cache_dir
os.makedirs(os.path.join(xp.settings['cache_dir'], 'with_marginal_seas'), exist_ok=True)
os.makedirs(os.path.join(xp.settings['cache_dir'], 'no_marginal_seas'), exist_ok=True)

import ann_avg_utils as aau
units, _ = aau.get_pint_units()

%load_ext watermark
%watermark -a "Mike Levy" -d -iv -m -g -h

Author: Mike Levy

Compiler    : GCC 9.3.0
OS          : Linux
Release     : 3.10.0-1127.18.2.el7.x86_64
Machine     : x86_64
Processor   : x86_64
CPU cores   : 72
Architecture: 64bit

Hostname: crhtc47

Git hash: 9ea3ad60e83b9a610032d9a3d502b83a1168dc24

intake       : 0.6.2
xarray       : 0.18.0
esmlab       : 2019.4.27.post56
numpy        : 1.20.2
ncar_jobqueue: 2021.4.14
xpersist     : 2020.4.30.post8
intake_esm   : 2021.1.15.post12



#### Spin up a dask cluster

Some of these computations take a while, so the work is distributed across a dask cluster that scales adaptively between 0 and 24 jobs.

In [2]:
cluster = ncar_jobqueue.NCARCluster(project='P93300606')
cluster.adapt(minimum_jobs=0, maximum_jobs=24)
client = Client(cluster)
client

Client: Scheduler: tcp://10.12.206.48:42526, Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/mlevy/proxy/8787/status
Cluster: Workers: 0, Cores: 0, Memory: 0 B


### Read the intake_esm datastores

The `intake_esm` package is used to help identify which files belong in each experiment.
The `get_var_from_catalog()` function is a wrapper to read specific files.

In [3]:
catalogs = dict()
catalogs['cesm2'] = intake.open_esm_datastore('data/campaign-cesm2-cmip6-timeseries.json')

#cesm1 = intake.open_esm_datastore('/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cmip5_NOT_CMORIZED.json')
catalogs['cesm1'] = intake.open_esm_datastore('data/glade-cesm1-cmip5-timeseries.json')

### Define our experiments

In [4]:
global_vars = aau.global_vars()

xp_dir = global_vars['xp_dir']
vars = global_vars['vars']
experiments = global_vars['experiments']
experiment_longnames = global_vars['experiment_longnames']
experiment_dict = global_vars['experiment_dict']
time_slices = global_vars['time_slices']
include_marg_seas = global_vars['include_marg_seas']

In [5]:
def get_var_from_catalog(catalog, variable, exp):

    print(f'Reading {variable} from {exp}\n----\n')
    start = time.time()
    cesm1_exp = None
    cesm1_var = None
    cesm2_exp = None
    if experiment_dict[exp][0] == 'cesm1':
        cesm1_exp = experiment_dict[exp][1]
    elif experiment_dict[exp][0] == 'cesm2':
        cesm2_exp = experiment_dict[exp][1]
    else:
        print(f'WARNING: cannot determine model version from {exp}')
        return(None)
    chunk_key = experiment_dict[exp][0]
    
    # Chunk sizes depend on the dimensionality of the variable; some variables are
    # also renamed below to account for changes between CESM1 and CESM2
    chunks = dict()
    chunk_2d = {'time' : 180}
    chunk_3d = {'time' : 4}
    chunk_100m = {'time' : 20}
    chunk_150m = {'time' : 15}
    # default is 2D chunk size
    chunks['cesm1'] = chunk_2d
    chunks['cesm2'] = chunk_2d
    if cesm1_exp:
        depth_100m = False
        if variable == 'CaCO3_FLUX_100m':
            cesm1_var = 'CaCO3_FLUX_IN'
        elif variable == 'POC_FLUX_100m':
            cesm1_var = 'POC_FLUX_IN'
        elif variable in ['photoC_diat_zint_100m', 'photoC_diat_zint']:
            cesm1_var = 'photoC_diat'
            depth_100m = (variable == 'photoC_diat_zint_100m') # Will use 150m for "full depth" integral
            chunks['cesm1'] = chunk_100m if depth_100m else chunk_150m
        elif variable in ['photoC_TOT_zint_100m', 'photoC_TOT_zint']:
            cesm1_var = ['photoC_sp', 'photoC_diat', 'photoC_diaz']
            depth_100m = (variable == 'photoC_TOT_zint_100m') # Will use 150m for "full depth" integral
            chunks['cesm1'] = chunk_100m if depth_100m else chunk_150m
        elif variable not in ['SiO2_FLUX_100m', 'NHx_SURFACE_EMIS', 'SedDenitrif', 'ponToSed', 'NO3_RIV_FLUX', 'DON_RIV_FLUX', 'DONr_RIV_FLUX']:
            # for variables in above list, cesm1_var is None
            cesm1_var = variable
    # Hardcode chunks for 3D vars (not ideal)
    if variable == 'diaz_Nfix':
        chunks['cesm1'] = chunk_150m
        chunks['cesm2'] = chunk_150m
    if variable in ['DENITRIF', 'O2']:
        chunks['cesm1'] = chunk_3d
        chunks['cesm2'] = chunk_3d

    if cesm1_exp and (cesm1_var is None):
        print(f'{variable} is not available in {exp}')
        return(None)

    if cesm1_exp and (cesm1_var is not None):
        if isinstance(cesm1_var, list):
            tmp_dataset = dict()
            for var_from_list in cesm1_var:
                print(f'Reading {var_from_list} to compute {variable}')
                dq = catalog.search(experiment=cesm1_exp, variable=var_from_list).to_dataset_dict(cdf_kwargs={'chunks':chunks[chunk_key]})
                tmp_dataset[var_from_list] = _read_var_from_exp(dq, exp, f'ocn.{cesm1_exp}.pop.h', variable, var_from_list)
                if depth_100m and 'z_t_150m' in tmp_dataset[var_from_list][variable].dims:
                    tmp_da = tmp_dataset[var_from_list][variable].isel(z_t_150m=slice(0,10)).rename({'z_t_150m' : 'z_t_100m'})
                    tmp_dataset[var_from_list][variable] = tmp_da
            dataset = tmp_dataset[cesm1_var[0]]
            for var_from_list in cesm1_var[1:]:
                dataset[variable].data = dataset[variable].data + tmp_dataset[var_from_list][variable].data
        else:
            dq = catalog.search(experiment=cesm1_exp, variable=cesm1_var).to_dataset_dict(cdf_kwargs={'chunks':chunks[chunk_key]})
            dataset = _read_var_from_exp(dq, exp, f'ocn.{cesm1_exp}.pop.h', variable, cesm1_var)
            if depth_100m and 'z_t_150m' in dataset[variable].dims:
                tmp_da = dataset[variable].isel(z_t_150m=slice(0,10)).rename({'z_t_150m' : 'z_t_100m'})
                dataset[variable] = tmp_da
    if cesm2_exp:
        dq = catalog.search(experiment=cesm2_exp, variable=variable).to_dataset_dict(cdf_kwargs={'chunks':chunks[chunk_key]})
        dataset = _read_var_from_exp(dq, exp, f'ocn.{cesm2_exp}.pop.h', variable, cesm1_var)
    
    end = time.time()
    print(f'\nDone reading {variable} from {exp} in {np.round(end - start, 1)}s\n')
    return(dataset)

def _read_var_from_exp(dq, exp, stream, variable, cesm1_var):
    # Define dataset
    dataset_full = dq[stream]

    # Initialize dataset with only time-invariant fields
    keep_vars_no_time = ['REGION_MASK', 'z_t', 'z_t_150m', 'dz', 'TAREA', 'TLONG', 'TLAT', 'member_id', 'ctrl_member_id']
    dataset = dataset_full.isel(time=0).drop_vars([v for v in dataset_full.variables if v not in keep_vars_no_time])

    # Then add variable / cesm1_var (and the time coordinates) with the full time dimension
    keep_vars_with_time = ['time', 'time_bound'] + [variable, cesm1_var]
    for var in keep_vars_with_time:
        if var in dataset_full and var not in dataset:
            dataset[var] = dataset_full[var]
    del(dataset_full)
    if variable not in dataset:
        dataset = dataset.rename({cesm1_var : variable})
        if variable in ['POC_FLUX_100m', 'CaCO3_FLUX_100m', 'SiO2_FLUX_100m']:
            dataset = dataset.isel(z_t=10)  # 100m is top of 11th level, or z_t = 10 counting from 0
    if (cesm1_var != variable) and (cesm1_var in dataset):
        dataset = dataset.drop_vars(cesm1_var)

    # Include marginal seas in computation?
    if include_marg_seas:
        dataset[variable] = dataset[variable].where(dataset['REGION_MASK'] != 0)
    else:
        dataset[variable] = dataset[variable].where(dataset['REGION_MASK'] > 0)

    return(dataset)
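
The CESM1 renaming logic in `get_var_from_catalog()` can be hard to follow as an `if`/`elif` chain. The sketch below restates the same mapping as a lookup table. The names `CESM1_SOURCE_VARS`, `CESM1_UNAVAILABLE`, and `cesm1_source_vars` are hypothetical and are not part of the notebook or of `ann_avg_utils`. It only illustrates which CESM1 fields feed each requested variable.

```python
# Hypothetical lookup table mirroring the if/elif chain in get_var_from_catalog().
# Keys are the requested (CESM2-style) names; values are the CESM1 fields to read.
CESM1_SOURCE_VARS = {
    'CaCO3_FLUX_100m': ['CaCO3_FLUX_IN'],
    'POC_FLUX_100m': ['POC_FLUX_IN'],
    'photoC_diat_zint_100m': ['photoC_diat'],
    'photoC_diat_zint': ['photoC_diat'],
    # Total production is the sum of the three CESM1 autotroph fields
    'photoC_TOT_zint_100m': ['photoC_sp', 'photoC_diat', 'photoC_diaz'],
    'photoC_TOT_zint': ['photoC_sp', 'photoC_diat', 'photoC_diaz'],
}

# Variables the notebook treats as unavailable in CESM1 output
CESM1_UNAVAILABLE = {
    'SiO2_FLUX_100m', 'NHx_SURFACE_EMIS', 'SedDenitrif', 'ponToSed',
    'NO3_RIV_FLUX', 'DON_RIV_FLUX', 'DONr_RIV_FLUX',
}


def cesm1_source_vars(variable):
    """Return the list of CESM1 fields needed to build `variable` (empty if unavailable)."""
    if variable in CESM1_UNAVAILABLE:
        return []
    # Any variable not listed explicitly is read under its own name
    return CESM1_SOURCE_VARS.get(variable, [variable])


for name in ['O2', 'photoC_TOT_zint_100m', 'SedDenitrif']:
    print(name, '->', cesm1_source_vars(name))
```

A table like this keeps the renaming rules in one place, although the chunk-size and depth-slicing decisions made in the notebook's function would still need separate handling.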

In [6]:
def _get_TAREA_and_dz(catalog, exp):
    # Read in any 3D variable to get dataset containing TAREA and dz
    var = 'DENITRIF'

    full_ds = get_var_from_catalog(catalog, var, exp)
    ds = full_ds['dz'].to_dataset(name='dz')
    ds['TAREA'] = full_ds['TAREA']
    ds[var] = full_ds[var].isel(time=0, member_id=0)
    ds['active'] = ds[var].copy()
    ds['active'].data = np.logical_not(np.isnan(ds[var].data))
    del(ds[var])

    return(ds.compute())
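
To show what the `active` mask represents, here is a minimal NumPy sketch on a made-up 2-level, 2x2 grid. The array values and shapes are assumptions chosen for illustration and are not taken from any CESM file. Cells where the field is NaN (land, or below the sea floor) are treated as inactive, and the ocean volume is the sum of `TAREA * dz` over the remaining cells.

```python
import numpy as np

# Hypothetical field with dimensions (z_t, nlat, nlon); NaN marks inactive cells
field = np.array([
    [[1.0, 2.0], [np.nan, 4.0]],
    [[5.0, np.nan], [np.nan, 8.0]],
])
tarea = np.array([[10.0, 10.0], [20.0, 20.0]])  # cell areas, e.g. cm^2
dz = np.array([1000.0, 2000.0])                 # layer thicknesses, e.g. cm

# Same idea as np.logical_not(np.isnan(...)) in _get_TAREA_and_dz()
active = np.logical_not(np.isnan(field))

# Broadcast area over depth and thickness over the horizontal grid
cell_volume = tarea[np.newaxis, :, :] * dz[:, np.newaxis, np.newaxis]
total_volume = (active * cell_volume).sum()

print(f'Active cells: {int(active.sum())}, total volume: {total_volume}')
```

In the notebook, xarray handles the broadcasting by dimension name, so the explicit `np.newaxis` indexing is only needed in this plain NumPy version.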

In [7]:
%%time

# Set up units
integral_units = dict()
integral_units['area'] = dict()
integral_units['volume'] = dict()

total_volume = dict()
for model_version in experiments:
    for exp in experiments[model_version]:
        xp_func = xp.persist_ds(_get_TAREA_and_dz, name=f'{xp_dir}/{exp}_active', trust_cache=True)
        ds = xp_func(catalogs[model_version], exp)

        integral_units['area'][exp] = units[ds['TAREA'].attrs['units']]
        integral_units['volume'][exp] = integral_units['area'][exp] * units[ds['dz'].attrs['units']]

        # Compute total volume of ocean
        # Sum wgt over active ocean cells
        total_volume[exp] = (ds['active'] * ds['TAREA'] * ds['dz']).sum().values

print('\n----\n')

# Estimate total ocean volume
# (surface of earth is 71% water, avg depth is 3.7 km)
est_volume = 4*np.pi*0.71*((6371.22*units['km'])**2)*(3.7*units['km'])
for exp in total_volume:
    print(f'Ocean volume in {exp}: {(total_volume[exp] * integral_units["volume"][exp]).to("L")}')
print(f'Estimated ocean volume: {est_volume.to("L")}')

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_PI_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_PI_esm_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_hist_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_hist_esm_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_RCP85_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_PI_active.nc
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_hist_active.nc
assuming cache is correct
readi
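
As a sanity check on the back-of-the-envelope estimate in the cell above, the same arithmetic can be done without `pint`. This is a sketch that uses the constants from the notebook (Earth radius 6371.22 km, 71% ocean coverage, 3.7 km mean depth); the constant and variable names are illustrative.

```python
import math

EARTH_RADIUS_KM = 6371.22
OCEAN_FRACTION = 0.71
MEAN_DEPTH_KM = 3.7
LITERS_PER_KM3 = 1.0e12  # 1 km^3 = 1e9 m^3, and 1 m^3 = 1e3 L

# Area of the sphere covered by ocean, then multiply by the mean depth
ocean_area_km2 = 4 * math.pi * EARTH_RADIUS_KM**2 * OCEAN_FRACTION
ocean_volume_km3 = ocean_area_km2 * MEAN_DEPTH_KM
ocean_volume_liters = ocean_volume_km3 * LITERS_PER_KM3

print(f'Estimated ocean volume: {ocean_volume_km3:.3e} km^3 = {ocean_volume_liters:.3e} L')
```

The result is on the order of 1.3e21 L, which gives a rough target for the volumes summed over active grid cells in each experiment.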

### Individual Table Computations

In this section, we compute each of the requested values for each dataset.

#### Net primary production (PgC/yr)

CESM1 doesn't have `photoC_TOT_zint`

#### Diatom primary production (%)

CESM1 doesn't have `photoC_diat_zint`

#### Sinking POC at 100 m (PgC/yr)

CESM1 doesn't have `POC_FLUX_100m`

#### Sinking CaCO3 at 100 m (PgC/yr)

CESM1 doesn't have `CaCO3_FLUX_100m`

#### Sinking SiO2 at 100 m (PgC/yr)

CESM1 doesn't have `SiO2_FLUX_100m`

#### Rain ratio (CaCO3/POC) 100 m

Missing necessary vars to compute

#### Nitrogen deposition (TgN/yr)

#### Denitrification (TgN/yr)

In [8]:
import datetime
def _debug_print(message):
    print(f'{str(datetime.datetime.now())}: {message}')

def resample_ann(ds):
    """Compute the time-weighted annual mean of a Dataset."""

    ds = ds.copy()

    # compute temporal weights using time_bound attr
    assert 'bounds' in ds.time.attrs, 'missing "bounds" attr on time'
    tb_name = ds.time.attrs['bounds']
    dim = ds[tb_name].dims[-1]
    ds['time'] = ds[tb_name].compute().mean(dim).squeeze()

    # compute weights from diff of time_bound
    weights = ds[tb_name].compute().diff(dim).squeeze()
    weights = weights.groupby('time.year') / weights.groupby('time.year').sum()

    # ensure they all add to one
    # TODO: build support for situations when they don't, i.e. define min coverage threshold
    nyr = len(weights.groupby('time.year'))
    np.testing.assert_allclose(weights.groupby('time.year').sum().values, np.ones(nyr))

    # ascertain which variables have time and which don't
    time_vars = [v for v in ds.data_vars if 'time' in ds[v].dims and v != tb_name]
    other_vars = list(set(ds.variables) - set(time_vars) - {tb_name, 'time'} )

    # compute
    with xr.set_options(keep_attrs=True):
        return xr.merge((
            ds[other_vars],
            (ds[time_vars] * weights).groupby('time.year').sum(dim='time'),
        )).rename({'year': 'time'})

def _compute_global_average_and_resample(dataset, exp, variable, integral_units):
    unit_key = 'volume' if any(zdim in dataset[variable].dims for zdim in ['z_t_100m', 'z_t_150m', 'z_t']) else 'area'

    # 1) Compute global averages
    _debug_print('Calling esmlab.weighted_sum')
    wgts = dataset['TAREA']
    dims = ['nlat', 'nlon']
    if 'z_t_100m' in dataset[variable].dims:
        wgts = wgts * dataset['dz'].isel(z_t=slice(0,10))
        wgts = wgts.rename({'z_t' : 'z_t_100m'})
        dims.append('z_t_100m')
    elif 'z_t_150m' in dataset[variable].dims:
        wgts = wgts * dataset['dz'].isel(z_t=slice(0,15))
        wgts = wgts.rename({'z_t' : 'z_t_150m'})
        dims.append('z_t_150m')
    elif 'z_t' in dataset[variable].dims:
        wgts = wgts * dataset['dz']
        dims.append('z_t')
    normalize = (variable == 'O2')
    if not normalize:
        glb_avg = esmlab.weighted_sum(dataset[variable], dim=dims, weights=wgts).to_dataset(name=variable)
    else:
        glb_avg = esmlab.weighted_mean(dataset[variable], dim=dims, weights=wgts).to_dataset(name=variable)

    # 2) Resample to annual means
    _debug_print('Calling esmlab.resample')
    print(f'   ... computing for {exp}')
    glb_avg['time_bound'] = dataset['time_bound']
#     ann_avg = esmlab.resample(glb_avg, freq='ann').compute()
    ann_avg = resample_ann(glb_avg)
    
    # store some unit metadata
    _debug_print('Determining units and returning')
    if not normalize:
        new_units = (units[dataset[variable].attrs['units']] * integral_units[unit_key][exp]).units
    else:
        new_units = dataset[variable].attrs['units']
    ann_avg[variable].attrs['units'] = str(new_units)
    return ann_avg

def _compute_global_time_series_single_exp(catalog, exp, variable, integral_units):

    _debug_print('Getting data via intake-esm')
    if variable == 'O2_under_thres':
        o2_thres = [5, 20, 60, 80]
        dataset = get_var_from_catalog(catalog, 'O2', exp)
        tmp_ann_avg = []
        for threshold in o2_thres:
            dataset[variable] = xr.where(dataset['O2'] < threshold, 1, 0)
            dataset[variable].attrs['units'] = ''
            tmp_ann_avg.append(_compute_global_average_and_resample(dataset, exp, variable, integral_units))
        ann_avg = xr.concat(tmp_ann_avg, dim='o2_thres')
        ann_avg['o2_thres'] = o2_thres
    else:
        dataset = get_var_from_catalog(catalog, variable, exp)
        ann_avg = _compute_global_average_and_resample(dataset, exp, variable, integral_units)

    return ann_avg

def compute_global_time_series(integral_units, variable, experiments, catalogs):
    ann_avg = dict()

    print(f'Computing global average of {variable}...')
    for model_version in experiments:
        for exp in experiments[model_version]:
            # Compute global average (use xpersist to read from disk if available)
            if (variable in ['SiO2_FLUX_100m', 'NHx_SURFACE_EMIS', 'SedDenitrif', 'ponToSed', 'NO3_RIV_FLUX', 'DON_RIV_FLUX', 'DONr_RIV_FLUX']) and (model_version == 'cesm1'):
                continue
            xp_func = xp.persist_ds(_compute_global_time_series_single_exp, name=f'{xp_dir}/{exp}_{variable}', trust_cache=True)
            ann_avg[exp] = xp_func(catalogs[model_version], exp, variable, integral_units)

            print('')
    return ann_avg
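
The weighting in `resample_ann()` can be illustrated with plain NumPy. The sketch below assumes monthly data on a 365-day calendar for two made-up years; the arrays are hypothetical and stand in for `time_bound` and a globally integrated variable. Each month is weighted by its length divided by the total length of its year, so the weights within a year sum to one.

```python
import numpy as np

# Hypothetical month lengths (days) on a 365-day calendar, repeated for two years
days_in_month = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
lengths = np.tile(days_in_month, 2).astype(float)  # analogous to diff of time_bound
year = np.repeat([1, 2], 12)                       # year label for each month
values = np.tile(np.arange(1.0, 13.0), 2)          # made-up monthly values 1..12

# Normalize the month lengths within each year
weights = np.empty_like(lengths)
for y in np.unique(year):
    in_year = year == y
    weights[in_year] = lengths[in_year] / lengths[in_year].sum()

# Weighted annual mean: sum of value * weight within each year
annual_mean = np.array([
    (values[year == y] * weights[year == y]).sum() for y in np.unique(year)
])

print('weighted annual means:  ', annual_mean)
print('unweighted annual means:', [values[year == y].mean() for y in np.unique(year)])
```

The weighted and unweighted means differ slightly because months are not all the same length, which is why the notebook derives weights from the time bounds rather than taking a simple average.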

In [11]:
%%time

ann_avg = dict()

for variable in vars:
    # I think this is ~60 minutes for 3D vars and 45 min for 2D vars?
    # (when using all 10 datasets)
    ann_avg[variable] = compute_global_time_series(integral_units, variable, experiments, catalogs)

Computing global average of photoC_TOT_zint_100m...
assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_PI_photoC_TOT_zint_100m.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_PI_esm_photoC_TOT_zint_100m.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_hist_photoC_TOT_zint_100m.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_hist_esm_photoC_TOT_zint_100m.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm1_RCP85_photoC_TOT_zint_100m.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_PI_photoC_TOT_zint_100m.nc

assuming cache is correct
r


Done reading DENITRIF from cesm2_hist in 13.7s

2021-05-25 17:22:03.025551: Calling esmlab.weighted_sum
2021-05-25 17:22:03.332744: Calling esmlab.resample
   ... computing for cesm2_hist
2021-05-25 17:22:06.919938: Determining units and returning
writing cache file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_hist_DENITRIF.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_SSP1-2.6_DENITRIF.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_SSP2-4.5_DENITRIF.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_SSP3-7.0_DENITRIF.nc

assuming cache is correct
reading cached file: /glade/p/cgd/oce/projects/cesm2-marbl/xpersist_cache/no_marginal_seas/cesm2_SSP5-8.5_DENITRIF.nc

Computing global average of SedDenitrif...
assuming cache