# Compute dimension reductions on CESM-LE

In [1]:
%matplotlib inline
import os
import shutil

from glob import glob

import cftime

import numpy as np
import xarray as xr

import matplotlib.pyplot as plt

import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
import cmocean

import intake
import pop_tools

import util

Cannot write to data cache '/glade/p/cesmdata/cseg'. Will not be able to download remote data files. Use environment variable 'CESMDATAROOT' to specify another directory.


## Notebook parameters

In [2]:
# intake-esm catalog (JSON, ESM collection spec) describing CESM-LE output
catalog_file = './data/glade-cesm1-le.json'

# Requested fields: a mix of model output (SST, IFRAC) and derived
# quantities (Chl_surf, KGP) that are computed from model output
variables = [
    'SST',
    'Chl_surf',
    'IFRAC',
    'KGP',
]

# Which experiments / output stream / model component to query
experiments = ['20C', 'RCP85']
stream = 'pop.h'
component = 'ocn'

# Restrict to ensemble members that still have ocean biogeochemistry output
require_ocn_bgc = True

# Name of the region mask used for regional means
mask_name = 'krill-ToE'

# dask chunking passed to the file reader
chunks = dict(time=60)

## Spin up dask cluster

In [3]:
# Launch a dask cluster through ncar_jobqueue and request 36 via cluster.scale
# (worker/job layout comes from the NCARCluster configuration, not visible here).
from ncar_jobqueue import NCARCluster
cluster = NCARCluster()
cluster.scale(36)
# display the cluster widget
cluster

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


VBox(children=(HTML(value='<h2>NCARCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

In [4]:
# Attach a dask Client so later dask computations run on the cluster above.
from dask.distributed import Client
client = Client(cluster) # Connect this local process to remote workers
client

0,1
Client  Scheduler: tcp://128.117.181.211:42118  Dashboard: https://jupyterhub.ucar.edu/dav/user/mclong/proxy/34256/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


tornado.application - ERROR - Uncaught exception GET /status/ws (::1)
HTTPServerRequest(protocol='http', host='jupyterhub.ucar.edu', method='GET', uri='/status/ws', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
  File "/glade/work/mclong/miniconda3/envs/funnel/lib/python3.7/site-packages/tornado/websocket.py", line 956, in _accept_connection
    open_result = handler.open(*handler.open_args, **handler.open_kwargs)
  File "/glade/work/mclong/miniconda3/envs/funnel/lib/python3.7/site-packages/bokeh/server/views/ws.py", line 123, in open
    raise ProtocolError("Subprotocol header is not 'bokeh'")
bokeh.protocol.exceptions.ProtocolError: Subprotocol header is not 'bokeh'
tornado.application - ERROR - Uncaught exception GET /status/ws (::1)
HTTPServerRequest(protocol='http', host='jupyterhub.ucar.edu', method='GET', uri='/status/ws', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
  File "/glade/work/mclong/miniconda3/envs/funnel/lib/python

## Get a region mask 

In [6]:
grid_name = 'POP_gx1v6'

nb_parameters = dict(
    mask_name=mask_name,
    grid_name=grid_name,
)

# call _pop_region_mask.ipynb(**parameters)
# TODO: make this a return value
zarr_name = f'./data/region-mask-{grid_name}-{mask_name}.zarr'

# xr.open_zarr always returns a Dataset, but the regional-mean cell below
# treats `masked_area` as an array: it divides by the regional total, passes
# the weights to np.testing.assert_allclose (a Dataset cannot be converted to
# a numpy array), and multiplies them against every variable of ds_djf
# (Dataset * Dataset would keep only variables common to both).
# Extract the store's single data variable as a DataArray.
ds_mask = xr.open_zarr(zarr_name)
mask_vars = list(ds_mask.data_vars)
if len(mask_vars) != 1:
    raise ValueError(
        f'expected exactly one data variable in {zarr_name}, found {mask_vars}'
    )
masked_area = ds_mask[mask_vars[0]]
masked_area

## Read the CESM-LE data 

We will use [`intake-esm`](https://intake-esm.readthedocs.io/en/latest/), which is a data catalog tool.
It enables querying a database for the files we want, then loading those directly as an `xarray.Dataset`.

The first step is to open the "collection" for the CESM-LE, which is described by a JSON file conforming to the [ESM Catalog Specification](https://github.com/NCAR/esm-collection-spec).

In [7]:
# Open the intake-esm collection described by the catalog JSON file.
col = intake.open_esm_datastore(catalog_file)
col

glade-cesm1-le-ESM Collection with 191066 entries:
	> 7 experiment(s)

	> 108 case(s)

	> 6 component(s)

	> 15 stream(s)

	> 1052 variable(s)

	> 116 date_range(s)

	> 40 member_id(s)

	> 191066 path(s)

	> 6 ctrl_branch_year(s)

	> 4 ctrl_experiment(s)

	> 41 ctrl_member_id(s)

### Select ensemble members

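Ocean biogeochemistry output was corrupted in some ensemble members and has been deleted. When `require_ocn_bgc` is set, we keep only the members that still have BGC output, identified by the presence of `diatChl` in the 20C experiment.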

In [8]:
# Members are taken from the 20C experiment. When BGC output is required,
# only members that still have a diatChl entry are kept.
query = dict(experiment=['20C'])
if require_ocn_bgc:
    query['variable'] = ['diatChl']
col_sub = col.search(**query)

member_id = list(col_sub.df.member_id.unique())
print(member_id)

[1, 2, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 101, 102, 103, 104, 105]


### Process variable list

Define functions for derived variables.

In [9]:
def compute_chl_surf(ds):
    """Add 'Chl_surf': total chlorophyll at the first z_t_150m level.

    Sums diatChl + spChl + diazChl, takes index 0 along z_t_150m, and gives
    the result the diatChl attributes with long_name 'Surface chlorophyll'.
    Each of the three component variables is then dropped unless it is
    listed in the notebook-level `variables`.

    Note: 'Chl_surf' is assigned onto the `ds` passed in before any
    components are dropped; the returned object may be a new Dataset.
    """
    chl_components = ['diatChl', 'spChl', 'diazChl']

    chl_total = ds.diatChl + ds.spChl + ds.diazChl
    ds['Chl_surf'] = chl_total.isel(z_t_150m=0)
    ds['Chl_surf'].attrs = dict(ds.diatChl.attrs, long_name='Surface chlorophyll')

    # components only needed to build Chl_surf are removed from the output
    unrequested = [name for name in chl_components if name not in variables]
    for name in unrequested:
        ds = ds.drop(name)

    return ds


def compute_kgp(ds):
    """Compute Krill Growth Potential (KGP) and add it to `ds` as 'KGP'.

    The daily growth rate is evaluated from the empirical model

        KGP = a + b*L + c*L**2 + d*Chl/(e + Chl) + f*SST + g*SST**2

    for body lengths L = 20, 40, 60 mm (a new 'length' dimension), using
    ds.SST and ds.Chl_surf. Before use, Chl values below 0.5 (and NaNs)
    are set to 0. KGP is set to 0 wherever SST is outside [-1, 5] (or NaN),
    then set to NaN wherever ds.KMT is not > 0 (presumably land in POP).
    TLONG/TLAT are attached as coordinates. `ds` is modified in place
    (ds['KGP'] is assigned) and also returned.

    Units of SST/Chl are taken as-is from the dataset; the model is
    presumably calibrated for degC and mg Chl m^-3 -- TODO confirm.

    References
    ----------
    Natural growth rates in Antarctic krill (Euphausia superba): II. Predictive
    models based on food, temperature, body length, sex, and maturity
    stage doi: 10.4319/lo.2006.51.2.0973
    A Atkinson, RS Shreeve, AG Hirst, P Rothery, GA Tarling
    Limnol Oceanogr, 2006

    Oceanic circumpolar habitats of Antarctic krill
    doi: 10.3354/meps07498
    A Atkinson, V Siegel, EA Pakhomov, P Rothery, V Loeb
    Mar Ecol Prog Ser, 2008

    """
    
    # specify coefs
    # NOTE(review): with f > 0 and g > 0 the SST term has its minimum at
    # SST = -f/(2g) ~= -0.39 and increases monotonically above that across
    # the [-1, 5] window; confirm the sign of `g` against the published
    # coefficient table.
    a = -0.066
    b = 0.002
    c = -0.000061
    d = 0.385
    e = 0.328
    f = 0.0078
    g = 0.0101
    
    # local pointers
    sst = ds.SST
    chl = ds.Chl_surf
    
    # mask chl with lower bound: values < 0.5 (and NaN) become 0
    chl = chl.where(chl >= 0.5).fillna(0.)
        
    # length coordinate (broadcasts KGP over three body lengths)
    length = xr.DataArray(
        [20., 40., 60.], 
        name='length',
        dims=('length'), 
        attrs={
            'units': 'mm', 
            'long_name': 
            'Krill body length'
        }
    )

    # compute terms and sum
    length_term = a + (b * length) + (c * length**2)
    chl_term = d * chl / (e + chl)    
    sst_term = f * sst + g * sst**2    
    kgp = length_term + chl_term + sst_term
    kgp.name = 'KGP'
    
    # mask based on SST range: outside [-1, 5] (or NaN SST) -> 0,
    # then points with KMT <= 0 -> NaN
    kgp = kgp.where((-1. <= sst) & (sst <= 5.)).fillna(0.).where(ds.KMT > 0)
    
    # add coordinates; `length` above has a 'length' dim but no coordinate
    # values on the result until assigned here
    kgp = kgp.assign_coords({'length': length})
    kgp = kgp.assign_coords({'TLONG': ds.TLONG, 'TLAT': ds.TLAT})

    # add attrs (replaces any attrs inherited from the inputs)
    kgp.attrs = {'units': 'mm d$^{-1}$', 'long_name': 'Daily growth rate'}
    ds['KGP'] = kgp

    return ds

Categorize requested variables as derived or directly written out by the model.

In [10]:
# Every variable name present in the catalog; manage_var_dep treats a
# requested variable found here as model output that can be read directly.
defined_model_variables = list(col.df.variable.unique())
print(f'{len(defined_model_variables)} variables in catalog')

1052 variables in catalog


In [11]:
# Registry of derived variables: for each, the variables it is computed
# from and the function that adds it to a dataset.
defined_derived_variables = dict(
    Chl_surf=dict(
        dependencies=['diatChl', 'spChl', 'diazChl'],
        function=compute_chl_surf,
    ),
    KGP=dict(
        dependencies=['SST', 'Chl_surf'],
        function=compute_kgp,
    ),
)

def manage_var_dep(var_list, model_variables=None, derived_definitions=None):
    """Determine if a variable is written directly by the model
    (and therefore just read in) or derived from dependencies.

    Dependencies of derived variables are resolved recursively, so a
    derived variable that depends on another derived variable
    (e.g. KGP -> Chl_surf) schedules both, plus all the model output
    they need.

    Parameters
    ----------
    var_list : iterable of str
        Requested variable names.
    model_variables : collection of str, optional
        Names written directly by the model. Defaults to the notebook-level
        ``defined_model_variables``.
    derived_definitions : dict, optional
        Mapping name -> {'dependencies': [...], 'function': f}. Defaults to
        the notebook-level ``defined_derived_variables``.

    Returns
    -------
    query_variables : list of str
        Sorted, de-duplicated model variables that must be read.
    derived_variables : list of str
        De-duplicated derived variables, ordered so that each one comes
        after every derived variable it depends on (safe to compute in
        list order).

    Raises
    ------
    ValueError
        If a name is neither model output nor a known derived variable.
    """
    if model_variables is None:
        model_variables = defined_model_variables
    if derived_definitions is None:
        derived_definitions = defined_derived_variables

    query_variables = set()
    derived_variables = []  # dependency order, no duplicates

    for v in var_list:
        if v in model_variables:
            query_variables.add(v)

        elif v in derived_definitions:
            q, d = manage_var_dep(
                derived_definitions[v]['dependencies'],
                model_variables,
                derived_definitions,
            )
            query_variables.update(q)
            # Nested derived dependencies (d) must be computed before v;
            # previously they were discarded, so e.g. requesting only KGP
            # never scheduled Chl_surf.
            for name in d + [v]:
                if name not in derived_variables:
                    derived_variables.append(name)
        else:
            raise ValueError(f'unknown variable {v}')

    return sorted(query_variables), derived_variables

# Split the requested `variables` into model output to read (query_variables)
# and derived quantities to compute afterwards (derived_variables).
query_variables, derived_variables = manage_var_dep(variables)
print(f'query_variables: {query_variables}')
print(f'derived_variables: {derived_variables}')

query_variables: ['IFRAC', 'SST', 'diatChl', 'diazChl', 'spChl']
derived_variables: ['Chl_surf', 'KGP']


### Search for the queried data

Specify a list of variables and perform a search. Under the hood, the `search` functionality uses [`pandas`](https://pandas.pydata.org/) data frames. We can view that frame here using the `.df` syntax.

In [12]:
# Query the catalog for every requested experiment / stream / variable /
# ensemble member combination needed to build the requested variables.
col_sub = col.search(
    experiment=experiments, 
    stream=stream, 
    variable=query_variables,
    member_id=member_id,
    )

print(col_sub)

# first rows of the matching file table
col_sub.df.head()

glade-cesm1-le-ESM Collection with 475 entries:
	> 2 experiment(s)

	> 68 case(s)

	> 1 component(s)

	> 1 stream(s)

	> 5 variable(s)

	> 5 date_range(s)

	> 34 member_id(s)

	> 475 path(s)

	> 3 ctrl_branch_year(s)

	> 2 ctrl_experiment(s)

	> 34 ctrl_member_id(s)



Unnamed: 0,experiment,case,component,stream,variable,date_range,member_id,path,ctrl_branch_year,ctrl_experiment,ctrl_member_id
0,20C,b.e11.B20TRC5CNBDRD.f09_g16.001,ocn,pop.h,IFRAC,185001-200512,1,/glade/campaign/cesm/collections/cesmLE/CESM-C...,402,CTRL,1
1,20C,b.e11.B20TRC5CNBDRD.f09_g16.002,ocn,pop.h,IFRAC,192001-200512,2,/glade/campaign/cesm/collections/cesmLE/CESM-C...,1920,20C,1
2,20C,b.e11.B20TRC5CNBDRD.f09_g16.009,ocn,pop.h,IFRAC,192001-200512,9,/glade/campaign/cesm/collections/cesmLE/CESM-C...,1920,20C,1
3,20C,b.e11.B20TRC5CNBDRD.f09_g16.010,ocn,pop.h,IFRAC,192001-200512,10,/glade/campaign/cesm/collections/cesmLE/CESM-C...,1920,20C,1
4,20C,b.e11.B20TRC5CNBDRD.f09_g16.011,ocn,pop.h,IFRAC,192001-200512,11,/glade/campaign/cesm/collections/cesmLE/CESM-C...,1920,20C,1


Now use the [`to_dataset_dict`](https://intake-esm.readthedocs.io/en/latest/api.html#intake_esm.core.esm_datastore.to_dataset_dict) method to return a dictionary of `xarray.Dataset` objects.

`intake_esm` makes groups of these according to rules in the collection spec file.

We can use the `preprocess` parameter to pass in a function that makes some corrections to the dataset. So first we define a function that does the following:
- drop the singleton `z_t` dimension on SST (which otherwise breaks coordinate alignment)
- make the 2D grid vars coordinates
- [add your correction here]

In [13]:
def preprocess(ds):
    """Fix some things in the dataset and subset it; applied by intake-esm
    to each opened file before aggregation. This is POP-centric.

    - drop the singleton ``z_t`` dimension from SST (it otherwise breaks
      coordinate alignment);
    - keep only the queried data variables (notebook-level
      ``query_variables``) plus the grid variables present in the file;
    - promote those grid variables to coordinates to ease concatenation
      in intake-esm.

    A stray early ``return ds`` previously made the subsetting and
    ``set_coords`` unreachable. Grid/auxiliary fields therefore stayed
    data variables and leaked into the per-variable regional means and
    plots downstream.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset for a single file.

    Returns
    -------
    xarray.Dataset
    """
    # TODO: make this component model agnostic

    grid_vars = ['KMT', 'TAREA', 'TLAT', 'TLONG', 'z_t', 'dz', 'z_t_150m', 'time', 'time_bound']

    if 'SST' in ds:
        ds = ds.assign(SST=ds.SST.isel(z_t=0, drop=True))

    data_vars = [v for v in ds.data_vars if v in query_variables]

    # only request grid variables this file actually has, so a missing one
    # does not raise KeyError on selection
    present_grid_vars = [v for v in grid_vars if v in ds]

    ds = ds[data_vars + present_grid_vars]

    # TODO: this could be a good place to apply domain subsetting

    # set grid variables to coordinates to ease concatenation in intake-esm
    new_coords = [v for v in present_grid_vars if v not in ds.coords]

    return ds.set_coords(new_coords)

In [None]:
%%time
dset_orig = col_sub.to_dataset_dict(
    cdf_kwargs={
        'chunks': chunks, 
        'decode_times': False},
    preprocess=preprocess
)
dsets_orig


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.stream'
                
--> There is/are 2 group(s)


tornado.application - ERROR - Uncaught exception GET /status/ws (::1)
HTTPServerRequest(protocol='http', host='jupyterhub.ucar.edu', method='GET', uri='/status/ws', version='HTTP/1.1', remote_ip='::1')
Traceback (most recent call last):
  File "/glade/work/mclong/miniconda3/envs/funnel/lib/python3.7/site-packages/tornado/websocket.py", line 956, in _accept_connection
    open_result = handler.open(*handler.open_args, **handler.open_kwargs)
  File "/glade/work/mclong/miniconda3/envs/funnel/lib/python3.7/site-packages/bokeh/server/views/ws.py", line 123, in open
    raise ProtocolError("Subprotocol header is not 'bokeh'")
bokeh.protocol.exceptions.ProtocolError: Subprotocol header is not 'bokeh'


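Now apply post-query corrections: replace the `time` coordinate with the midpoint of each `time_bound` interval, converted to `cftime` dates.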

In [None]:
def fix_time(ds):
    """Return a deep copy of `ds` whose `time` is the midpoint of each
    `time_bound` interval, converted to dates with cftime.num2date using
    the `units` and `calendar` attributes of the (undecoded) `time`.
    """
    fixed = ds.copy(deep=True)
    midpoints = fixed.time_bound.mean(dim='d2')
    dates = cftime.num2date(
        midpoints,
        units=fixed.time.units,
        calendar=fixed.time.calendar,
    )
    fixed['time'] = xr.DataArray(dates, dims=('time',))
    return fixed

dsets = {key: fix_time(ds) for key, ds in dsets_orig.items()}

### Compute derived variables

In [None]:
# Add each derived variable to every dataset, in derived_variables order
# (a derived variable's derived dependencies must already be present).
for v in derived_variables:
    func = defined_derived_variables[v]['function']
    dsets = {key: func(ds) for key, ds in dsets.items()}

dsets

Concatenate the datasets in time, i.e. 20C + RCP8.5 experiments.

In [None]:
# TODO: Generalize this, we need the notion of "concatenatable groups"

# Build the group keys from the notebook parameters instead of hard-coding
# them; intake-esm names groups 'component.experiment.stream'. Datasets are
# concatenated in the order of `experiments`, which must therefore be
# chronological (e.g. 20C before RCP85).
ordered_dsets_keys = [f'{component}.{exp}.{stream}' for exp in experiments]

missing_keys = [key for key in ordered_dsets_keys if key not in dsets]
if missing_keys:
    raise KeyError(
        f'no dataset for keys {missing_keys}; available: {list(dsets)}'
    )

ds = xr.concat(
    [dsets[key] for key in ordered_dsets_keys],
    dim='time',
    data_vars='minimal',
)
ds

## Compute austral summer (DJF) means

In [None]:
%%time
# DJF seasonal means via the project-local util.ann_mean, using time_bound
# for the time bounds (its weighting details are not visible here -- TODO confirm).
ds_djf = util.ann_mean(ds, season='DJF', time_bnds_varname='time_bound')
ds_djf

## Compute regional means

In [None]:
%%time
# Horizontal weights per region: masked_area normalized by its regional total.
# Assumes masked_area is cell area inside each region and 0 outside -- TODO confirm.
dim = ['nlat', 'nlon']
area_total = masked_area.sum(dim)
weights = masked_area / area_total
weights_sum = weights.sum(dim)

# ensure that the weights add to 1.
# (sums that are 0 or NaN -- presumably regions with no area -- are mapped
# to 1 so they pass the check)
np.testing.assert_allclose(weights_sum.where(weights_sum != 0.).fillna(1.), 1.0, rtol=1e-7)

# Weighted sum over the horizontal dims of every data variable gives the
# regional mean; keep_attrs retains the variables' attributes for plotting.
with xr.set_options(keep_attrs=True):
    ds_djf_regional = (ds_djf[list(ds_djf.data_vars)] * weights).sum(dim).compute()
ds_djf_regional

Quick-look time-series plots for each region (thin lines: individual ensemble members; black line: ensemble mean).

In [None]:
for plot_region in masked_area.region.values:

    nvar = len(ds_djf_regional.data_vars)
    ncol = int(np.sqrt(nvar))
    nrow = int(nvar/ncol) + min(1, nvar%ncol)  # ceil(nvar / ncol)

    # squeeze=False keeps `ax` a 2D array even for a 1x1 grid, so ax.ravel()
    # also works when only one variable is requested
    fig, ax = plt.subplots(nrow, ncol, figsize=(4*ncol, 3*nrow),
                           constrained_layout=True, squeeze=False)
    axs = ax.ravel()

    for i, v in enumerate(ds_djf_regional.data_vars):
        axi = axs[i]

        var = ds_djf_regional[v].sel(region=plot_region)
        # variables with a 'length' dimension (e.g. KGP) are shown for 40 mm
        if 'length' in var.dims:
            var = var.sel(length=40.)

        # thin lines: individual ensemble members
        for m_id in ds_djf_regional.member_id:
            var.sel(member_id=m_id).plot(ax=axi, linewidth=0.5)

        # black line: ensemble mean
        with xr.set_options(keep_attrs=True):
            var.mean('member_id').plot(ax=axi, color='k', linewidth=1)
        axi.set_title(v)

    # hide panels left empty when nvar does not fill the grid
    for unused_ax in axs[nvar:]:
        unused_ax.set_visible(False)

    fig.suptitle(plot_region, fontsize=16, fontweight='bold')

## Compute temporal means

In [None]:
%%time

# TODO: this could be parameterized in the notebook
with xr.set_options(keep_attrs=True):   
    epoch_list = [
        ('1920-1950', ds_djf.sel(time=slice('1920', '1950')).mean('time')),
        ('2070-2100', ds_djf.sel(time=slice('2070', '2100')).mean('time')),
    ]
     
epoch = xr.DataArray(
    [t[0] for t in epoch_list],
    dims=('epoch'),
    name='epoch',
)    

ds_djf_epoch = xr.concat(
    [t[1] for t in epoch_list], 
    dim=epoch,
)
ds_djf_epoch = ds_djf_epoch.compute()
ds_djf_epoch

In [None]:
# loop variable named epoch_label so it does not shadow the `epoch` DataArray
for epoch_label in ds_djf_epoch.epoch.values:

    nvar = len(ds_djf_epoch.data_vars)
    ncol = int(np.sqrt(nvar))
    nrow = int(nvar/ncol) + min(1, nvar%ncol)  # ceil(nvar / ncol)

    # squeeze=False keeps `ax` a 2D array even for a 1x1 grid, so ax.ravel()
    # also works when only one variable is requested
    fig, ax = plt.subplots(nrow, ncol, figsize=(4*ncol, 3*nrow),
                           constrained_layout=True, squeeze=False)
    axs = ax.ravel()

    for i, v in enumerate(ds_djf_epoch.data_vars):
        # ensemble-mean map for this epoch
        with xr.set_options(keep_attrs=True):
            var = ds_djf_epoch[v].sel(epoch=epoch_label).mean('member_id')
        # variables with a 'length' dimension (e.g. KGP) are shown for 40 mm
        if 'length' in var.dims:
            var = var.sel(length=40.)
        var.plot(ax=axs[i])
        axs[i].set_title(v)

    # hide panels left empty when nvar does not fill the grid
    for unused_ax in axs[nvar:]:
        unused_ax.set_visible(False)

    fig.suptitle(epoch_label, fontsize=16, fontweight='bold')

In [None]:
# Change between epochs: diff along 'epoch' is second minus first epoch
# (labelled with the second); squeeze drops the resulting length-1 dimension.
with xr.set_options(keep_attrs=True):  
    ds_djf_epoch_diff = ds_djf_epoch.diff('epoch').squeeze('epoch')
ds_djf_epoch_diff

In [None]:
nvar = len(ds_djf_epoch_diff.data_vars)
ncol = int(np.sqrt(nvar))
nrow = int(nvar/ncol) + min(1, nvar%ncol)  # ceil(nvar / ncol)

# squeeze=False keeps `ax` a 2D array even for a 1x1 grid, so ax.ravel()
# also works when only one variable is requested
fig, ax = plt.subplots(nrow, ncol, figsize=(4*ncol, 3*nrow),
                       constrained_layout=True, squeeze=False)
axs = ax.ravel()

for i, v in enumerate(ds_djf_epoch_diff.data_vars):
    # ensemble-mean change map
    with xr.set_options(keep_attrs=True):
        var = ds_djf_epoch_diff[v].mean('member_id')
    # variables with a 'length' dimension (e.g. KGP) are shown for 40 mm
    if 'length' in var.dims:
        var = var.sel(length=40.)
    var.plot(ax=axs[i])
    axs[i].set_title(v)

# hide panels left empty when nvar does not fill the grid
for unused_ax in axs[nvar:]:
    unused_ax.set_visible(False)

# dedicated name instead of reusing `epoch`
epoch_labels = ds_djf_epoch.epoch.data
fig.suptitle(f'({epoch_labels[1]}) - ({epoch_labels[0]})', fontsize=16, fontweight='bold');

In [None]:
# Output zarr paths (relative to the notebook directory) -> dataset to write.
dso_map = {
    'data/cesm-le-fields-djf-regional-timeseries.zarr': ds_djf_regional, 
    'data/cesm-le-fields-djf-epoch-mean.zarr': ds_djf_epoch,  
}
# util.write_ds_out is project-local; its overwrite behavior is not visible here (TODO confirm).
for file_out, dso in dso_map.items():
    util.write_ds_out(dso, file_out)