# Compute seasonal net outgassing of O2 and APO

Start with the [autoreload](https://ipython.org/ipython-doc/3/config/extensions/autoreload.html) magic. It reloads modules automatically before executing code. This enables development in modules like [util.py](util.py) without restarting the kernel.

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# (e.g. util) before each cell executes, so edits to util.py are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import xarray as xr
import pandas as pd

import matplotlib.pyplot as plt

import util

import intake

## Connect to catalog

This notebook uses a [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) to describe file locations. This is the basis of [intake-esm](https://intake-esm.readthedocs.io/en/latest/). The intake-esm datastore is opened below for inspection, but the catalog is also read directly into a DataFrame (`df`) rather than queried through `intake-esm`.

In [None]:
# Open the intake-esm datastore described by util.catalog_json and display it.
catalog_path = util.catalog_json
cat = intake.open_esm_datastore(catalog_path)
cat

In [None]:
%%time
df = pd.read_csv(util.catalog_csv)
df

## Specify a subset of models

In [None]:
# CMIP6 source_ids to process. Entries must be bare source_ids (no
# "institution." prefix) so they match the catalog's source_id column.
models = [
    'CanESM5', ## gives non-monotonic coord error for fgco2 in combine_by_coords below 
    #'CanESM5-CanOE', ## no fgo2 (somehow PM was plotting)
    'CNRM-ESM2-1', 
    'ACCESS-ESM1-5', 
    'MPI-ESM-1-2-HAM', 
    'IPSL-CM6A-LR',
    'MPI-ESM1-2-HR',  # was 'MPI-M.MPI-ESM1-2-HR': institution-prefixed key, not a source_id
    'MPI-ESM1-2-LR', 
    'NorCPM1', 
    'NorESM2-LM', 
    'UKESM1-0-LL',
    'MIROC-ES2L',
    #'MRI-ESM2-0', ## missing intpp
]
## others on ESGF showing historical+Omon+fgo2 = EC-Earth3-CC, GFDL-CM4, GFDL-ESM4, IPSL-CM5A2-INCA, IPSL-CM6A-LR-INCA, NorESM2-MM (but of the above ESGF not showing CNRM-ESM2-1, UKESM1-0-LL, MIROC-ES2L, MRI-ESM2-0)


### Test read single model

In [None]:
# Smoke test: open fgco2 and fgo2 for a single model before looping over all models.
source_id = 'UKESM1-0-LL'
variable_id = ['fgco2', 'fgo2']
time_slice = slice('2005', '2014')
experiment_id = 'historical'
nmax_members = 4

open_kwargs = dict(
    source_id=source_id,
    variable_id=variable_id,
    experiment_id=experiment_id,
    time_slice=time_slice,
    table_id='Omon',
    nmax_members=nmax_members,
)
dsi = util.open_cmip_dataset(**open_kwargs)
dsi

## Get grid data from each model

Skip models where the grid data is not available

In [None]:
# Collect the grid variables for each model into one dataset per model.
# Models for which none of the grid variables are available are left out.
dsets_fix = {}
grid_variables = ['areacello',]
for model in models:
    found = [
        gridvar
        for gridvar in (util.get_gridvar(df, model, v) for v in grid_variables)
        if gridvar is not None
    ]
    if not found:
        continue  # no grid data: the mask/integration cells below skip this model
    merged = xr.merge(found)
    merged.attrs['source_id'] = model
    dsets_fix[model] = merged

list(dsets_fix)
## somehow PM was getting areacello for MPI-M.MPI-ESM1-2-HR from Ofx - I tried specifying table_id as Ofx here but that did not help

## Compute a region mask for integration

In [None]:
rmask_definition = 'SET_NET'
#rmask_definition = 'global' ### grid-cell area is maxing at 20 degrees in each hemisphere

# Build the region masks (via util.get_rmask_dict, with plot=True) for every
# model that has grid data in dsets_fix; the others are skipped.
rmask_dict = {}
for model in models:
    if model in dsets_fix:
        grid_ds = dsets_fix[model]
        rmask_dict[model] = util.get_rmask_dict(
            grid_ds,
            mask_definition=rmask_definition,
            plot=True,
        )

## Assemble monthly-mean climatology

This code takes the following steps:
- Read a dataset for each model
- Compute the regional integral 
- Compute the mean for each month and average across ensemble members
- Merge the variables for each model, then concatenate the results along a `source_id` dimension

Note that the code caches the resulting dataset for each model and variable. If a cache file already exists and `clobber` is `False`, the file is read instead of being recomputed.

In [None]:
%%time

#variable_ids = ['fgco2', 'fgo2'] # , 'intpp', 'fgn2:heatflux,sst,sss']
variable_ids = ['intpp']
#variable_ids = ['fgn2:tos,sos']
    
time_slice = slice('2005', '2014') ## for comparison to HIPPO/ORCAS/ATom 2009-2018, pick closest decade
experiment_id = 'historical' 
nmax_members = 4
clobber = True

# specify models for each variable that have reverse sign convention
models_flipsign = {v: [] for v in variable_ids}
models_flipsign['fgo2'] = ['NorESM2-LM',]


ds_list = []
source_id_list = []
for source_id in models:    
    if source_id not in rmask_dict:
        continue
        
    ds_list_variable_ids = []
    for variable_name in variable_ids:
        
        variable_id = variable_name
        derived_var = variable_name
        if ':' in variable_name:
            variable_id = variable_name.split(':')[-1].split(',')
            derived_var = variable_name.split(':')[0]        
                
        cache_file = (
            f'data/cmip6'
            f'.{source_id}'
            f'.{experiment_id}'
            f'.{derived_var}'
            f'.{rmask_definition}'
            f'.monclim_{time_slice.start}-{time_slice.stop}.zarr'
        )
        print(cache_file)
        if os.path.exists(cache_file) and not clobber:
            ds = xr.open_zarr(cache_file)

        else:
        
            dsi = util.open_cmip_dataset(
                source_id=source_id, 
                variable_id=variable_id, 
                experiment_id=experiment_id, 
                table_id='Omon',
                time_slice=time_slice, 
                nmax_members=nmax_members,
            )
            if dsi is None:
                print(f'missing: {source_id}, {experiment_id}, {variable_id}')
                continue
            ### put code below in an elif to allow some models to be missing some variables? as now, crashes if any missing

            # compute derived variables
            ### need to read in source fields before trying this:
            if derived_var == 'fgn2':
                print(dsi)
                dsi = util.compute_fgn2(dsi)
            
            elif derived_var == 'fgo2_thermal':
                dsi = util.compute_fgo2_thermal(dsi)
                
            # compute the regional integrals
            flipsign = True if source_id in models_flipsign[variable_id] else False            
            da = util.compute_regional_integral(
                ds=dsi, 
                variable_id=variable_id,
                rmasks=rmask_dict[source_id],
                flipsign=flipsign,
            )    
            
            with xr.set_options(keep_attrs=True):
                da = da.groupby('time.month').mean().mean('member_id')
            
            try:
                ds = da.to_dataset().drop(['depth']).compute()
            except:
                print('Depth does not exist in dataset')
                
            ds.to_zarr(cache_file, mode='w');            
            
        ds_list_variable_ids.append(ds)
    
    if ds_list_variable_ids:
        source_id_list.append(source_id)

    # merge across variables
    if ds_list_variable_ids:
        ds_list.append(xr.merge(ds_list_variable_ids,))

ds = xr.concat(ds_list, dim=xr.DataArray(source_id_list, dims=('source_id'), name='source_id'))    
ds

## Make some plots

In [None]:
# Single-letter month labels, January through December.
monlabs = np.array(list("JFMAMJJASOND"))

In [None]:
field = 'fgo2'

fig, axs = plt.subplots(2, 1, figsize=(6, 7), facecolor='w')

# One panel per region, one line per model. Values are drawn at month - 0.5 so
# that the integer ticks 0..12 mark month boundaries.
month_mid = ds.month - 0.5
for ax, region in zip(axs.ravel(), ds.region.values):
    by_model = ds[field].sel(region=region)
    for source_id in ds.source_id.values:
        ax.plot(
            month_mid,
            by_model.sel(source_id=source_id),
            marker='.',
            linestyle='-',
            label=source_id,
        )

    field_attrs = ds[field].attrs
    ax.set_xticks(np.arange(13))
    ax.set_ylabel(f"{field_attrs['long_name']} [{field_attrs['units']}]")
    ax.set_title(region)
    ax.set_xticklabels([])

# Only the last (bottom) panel gets month labels; the leading spaces shift each
# letter to sit between its boundary ticks. The legend goes right of that panel.
ax.set_xticklabels([f'        {m}' for m in monlabs] + [''])
ax.legend(loc=(1.02, 0));

In [None]:
field = 'fgco2'

fig, axs = plt.subplots(2, 1, figsize=(6, 7), facecolor='w')

# Month initials (also defined in an earlier cell; repeated here).
monlabs = np.array(list("JFMAMJJASOND"))

for region, ax in zip(ds.region.values, axs.ravel()):
    da_field = ds[field]
    for source_id in ds.source_id.values:
        curve = da_field.sel(source_id=source_id, region=region)
        ax.plot(ds.month - 0.5, curve, marker='.', linestyle='-', label=source_id)

    ylabel = "{} [{}]".format(da_field.attrs['long_name'], da_field.attrs['units'])
    ax.set_xticks(np.arange(13))
    ax.set_ylabel(ylabel)
    ax.set_title(region)
    ax.set_xticklabels([])

# Month initials on the bottom panel only, padded to sit between the boundary ticks.
ax.set_xticklabels(['        ' + m for m in monlabs] + [''])
ax.legend(loc=(1.02, 0));