# Compute a global mean, annual mean timeseries from the CESM Large Ensemble

In [None]:
%matplotlib inline
import os
import socket

from tqdm import tqdm

import dask
import dask.distributed
import ncar_jobqueue

import xarray as xr
import numpy as np
import esmlab

import intake
import intake_esm

import matplotlib.pyplot as plt

## Connect to the `intake-esm` data catalog

An input file `cesm1-le-collection.yml` specifies where to look for files and assembles a database for the CESM-LE. `intake-esm` configuration settings are stored by default in `~/.intake_esm/config.yaml` or locally in `.intake_esm/config.yaml`. The key setting to specify is `database_directory`, which is where the catalog data file (csv) is written to disk.

In [None]:
# Open the CESM-LE catalog described by the collection definition file.
# NOTE(review): `overwrite_existing=False` presumably reuses a previously built
# catalog instead of rebuilding it -- confirm against the intake-esm version in use.
col = intake.open_esm_metadatastore(
    collection_input_definition='cesm1-le-collection.yml',
    overwrite_existing=False)
# Print a summary of the catalog's DataFrame.
col.df.info()

## Compute grid weights for a global mean

### Load a dataset and read in the grid variables
To compute a properly-weighted spatial mean, we need a cell-volume array. We'll pick out the necessary grid variables from a single file. First, let's get an arbitrary POP history file from the catalog.

In [None]:
# Only these variables are kept from the file; everything else is dropped.
grid_vars = ['KMT', 'z_t', 'TAREA', 'dz']

# Take the first file the catalog returns for the 20C POP history stream.
arbitrary_pop_file = (
    col.search(experiment='20C', stream='pop.h')
    .query_results.file_fullpath
    .tolist()[0]
)

# Open without decoding times or coordinates, strip every non-grid variable,
# and load what remains into memory.
ds = xr.open_dataset(arbitrary_pop_file, decode_times=False, decode_coords=False)
ds = ds.drop([name for name in ds.variables if name not in grid_vars]).compute()
ds

### Compute a 3D topography mask
Now we'll compute the 3D volume field, masked appropriately by the topography.

First step is to create the land mask.

In [None]:
# Grid sizes: number of z_t levels and the first two axes of KMT.
nk = len(ds.z_t)
nj, ni = ds.KMT.shape[:2]

# Zero-based level index (0 .. nk-1) along z_t, multiplied by a block of ones
# so that each (nlat, nlon) column carries its own copy of the index.
k_vector_one_to_km = xr.DataArray(np.arange(nk), dims='z_t', coords={'z_t': ds.z_t})
ONES_3d = xr.DataArray(
    np.ones((nk, nj, ni)),
    dims=('z_t', 'nlat', 'nlon'),
    coords={'z_t': ds.z_t},
)
MASK = k_vector_one_to_km * ONES_3d

# Keep cells whose level index is at most KMT - 1; the rest become NaN and are
# then mapped to 0, while retained cells are mapped to 1.
MASK = MASK.where(MASK <= ds.KMT - 1)
MASK = xr.where(MASK.notnull(), 1., 0.)

# Horizontal view of the mask at the first z_t level.
plt.figure()
MASK.isel(z_t=0).plot()
plt.title('Surface mask')

# Vertical section at nlon index 200, drawn with the y axis inverted.
plt.figure()
MASK.isel(nlon=200).plot(yincrease=False)
plt.title('Pacific transect')

### Compute the 3D volume field

Now we'll compute the masked volume field by multiplying `dz` by `TAREA` by the mask created above.

In [None]:
# Masked volume: the product of dz, TAREA and the 0/1 MASK, so cells where
# MASK is 0 contribute zero volume.
# NOTE(review): the 'cm^3' unit assumes dz is in cm and TAREA in cm^2 -- confirm
# against the attributes of the grid file.
MASKED_VOL = ds.dz * ds.TAREA * MASK
MASKED_VOL.attrs['units'] = 'cm^3'
MASKED_VOL.attrs['long_name'] = 'masked volume'

# First z_t level of the volume field. The title names the plotted quantity
# (this panel shows MASKED_VOL, not the mask itself).
plt.figure()
MASKED_VOL.isel(z_t=0).plot()
plt.title('Surface volume')

# Vertical section at nlon index 200, drawn with the y axis inverted.
plt.figure()
MASKED_VOL.isel(nlon=200).plot(yincrease=False)
plt.title('Pacific transect')

## Compute global-mean, annual-means across the ensemble

### Find the ensemble members that have ocean biogeochemistry 
(several of the runs had corrupted BGC fields)

In [None]:
# Unique ensemble identifiers among catalog entries for the 20C and RCP85
# experiments that are flagged as having ocean BGC output.
member_ids = (
    col.search(experiment=['20C', 'RCP85'], has_ocean_bgc=True)
    .query_results.ensemble
    .unique()
    .tolist()
)
print(member_ids)

### Spin up a dask cluster

We are using `ncar_jobqueue.NCARCluster`; this just passes through to `dask_jobqueue.PBSCluster` or `dask_jobqueue.SLURMCluster` depending on whether you are on Cheyenne or a DAV machine.

**Note**: `dask_jobqueue.SLURMCluster` does not work on Cheyenne compute nodes, though the cluster jobs will start, giving the appearance of functionality.

Default arguments to `ncar_jobqueue.NCARCluster` are set in `~/.config/dask/jobqueue.yaml`; you can override these defaults by passing in arguments directly here.

In [None]:
# Describe the worker jobs to request; these arguments override the defaults
# from the jobqueue configuration file.
cluster = ncar_jobqueue.NCARCluster(
    walltime="00:20:00",
    cores=36,
    memory='109GB',
    processes=9,
)

# Attach a client so subsequent dask computations run on this cluster.
client = dask.distributed.Client(cluster)

# NOTE(review): presumably 9 worker processes per job x 10 jobs -- confirm how
# `scale` counts workers in the installed dask_jobqueue version.
n_workers = 9 * 10
cluster.scale(n_workers)

After the worker jobs have started, it's possible to view the client attributes.

In [None]:
# Shell escape: list the current user's jobs in the batch queue.
!qstat -u $USER

Paste the dashboard link into the `DASK DASHBOARD URL` field in the `dask-labextension` at right, replacing the part that looks sort of IP-address-ish with the URL in your browser, excluding the `/lab...` part.

In [None]:
# Display the client summary in the notebook output.
client

### Compute 

We'll loop over the ensemble and compute one at a time. In theory it should be possible to compute all at once, but in practice this doesn't seem to work.

In [None]:
# Catalog search criteria: O2 from the POP history stream, for both
# experiments and every ensemble member selected earlier.
variable = ['O2']
query = {
    'ensemble': member_ids,
    'experiment': ['20C', 'RCP85'],
    'stream': 'pop.h',
    'variable': variable,
    'direct_access': True,
}


In [None]:
# Narrow the catalog to the entries matching `query`.
col_subset = col.search(**query)

In [None]:
# Print a summary of the matching catalog rows.
col_subset.query_results.info()

In [None]:
# Open the matching files with undecoded times, chunked 10 steps along `time`,
# and report how long the call takes.
# Note: this rebinds `ds`, which previously held the grid dataset; the result
# is indexed by key (e.g. 'pop.h.ocn.20C') in the cells that follow.
%time ds = col_subset.to_xarray(decode_times=False, chunks={'time': 10})
ds

In [None]:
# Number of keys in the `to_xarray` result.
len(ds.keys())

In [None]:
# List the keys of the `to_xarray` result.
ds.keys()

In [None]:
# Inspect the entry stored under the 20C key.
ds['pop.h.ocn.20C']

In [None]:
# Work on a copy of the RCP85 entry so the original stays untouched.
ds_1 = ds['pop.h.ocn.RCP85'].copy()

In [None]:
# Compute annual means of the RCP85 copy and report how long the call takes.
%time dso = esmlab.climatology.compute_ann_mean(ds_1)

In [None]:
# Display the annual-mean result.
dso

In [None]:
# Reduce the annual means over z_t, nlat and nlon, weighting by MASKED_VOL,
# and report how long the call takes.
%time dso = esmlab.statistics.weighted_mean(dso, weights=MASKED_VOL, dim=['z_t', 'nlat', 'nlon'])
dso

In [None]:
#cluster.close()

%%time
# Process the ensemble one member at a time: query the catalog, reduce to a
# volume-weighted annual-mean series, and collect the computed result.
# Note: the loop rebinds `query`, `col_subset`, `ds` and `dso` on every pass.
variable = ['O2']
dsets = []
for member_id in member_ids:
    print(f'working on ensemble member {member_id}')
    
    # search criteria restricted to the current ensemble member
    query = dict(ensemble=member_id, experiment=['20C', 'RCP85'], 
                 stream='pop.h', variable=variable, direct_access=True)

    col_subset = col.search(**query)

    # get a dataset
    # NOTE(review): other cells in this notebook call `to_xarray` with
    # decode_times=False / chunks and index its result by key; here it is
    # called with defaults and passed on directly -- confirm the return type.
    ds = col_subset.to_xarray()

    # compute annual means
    dso = esmlab.climatology.compute_ann_mean(ds)

    # compute global average, weighted by the masked cell volume
    dso = esmlab.statistics.weighted_mean(dso, weights=MASKED_VOL, dim=['z_t', 'nlat', 'nlon'])

    # evaluate the result now and keep it for concatenation
    dso = dso.compute()
    dsets.append(dso)


# stack the per-member results along a new `member_id` dimension labelled
# with the member ids
ensemble_dim = xr.DataArray(member_ids, dims='member_id', name='member_id')    
ds = xr.concat(dsets, dim=ensemble_dim)
ds

# Shut the cluster down now that the per-member results have been computed.
cluster.close()

# Draw one O2 line per ensemble member on the same axes.
for member_id in member_ids:
    ds.O2.sel(member_id=member_id).plot()

# Coordinate names that are not also dimension names.
set(ds.coords) - set(ds.dims)