# The final plots

The following cell downloads the preprocessed data from [zenodo](https://sandbox.zenodo.org/record/812722#.YJQ5NC9h2wI).
> During the submission this is still kept in the 'zenodo sandbox' repository, but it will be submitted as a permanent archive before submission.

In [None]:
# Run the repository's download script from the parent directory of this notebook.
!cd ../ && ./scripts/download_zenodo_files.sh
# Unpack the downloaded archive into ../data/binder (the folder later used as `dfolder`).
!cd ../data && tar -zxf busecke_etal_2021_aguadv.tar --directory binder

In [None]:
import pathlib
%load_ext autoreload
%autoreload 2

import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'


import cmocean.cm as cmo
import cartopy.crs as ccrs


from dask.diagnostics import ProgressBar

from cmip6_preprocessing.postprocessing import concat_members

from xarrayutils.utils import xr_linregress, linear_trend, sign_agreement
from xarrayutils.filtering import filter_2D
from xarrayutils.plotting import shaded_line_plot, linear_piecewise_scale
from xgcm import Grid

from cmip6_omz.units import convert_mol_m3_mymol_kg
from cmip6_omz.omz_tools import mask_basin
# TODO: This should all be migrated upstream in the final version
from cmip6_omz.upstream_stash import zarr_exists, construct_static_dz

from busecke_etal_2021_aguadv.plotting import (
    o2_model_colors,
    model_color_legend,
    mask_multi_model,
    finish_map_plot,
    ScientificManualFormatter,
)
from busecke_etal_2021_aguadv.utils import (
    load_zarr_directory, fail_age, slope,
    trend_slice, hist_slice, regrid, member_treatment
)

from xhistogram.xarray import histogram

from fastprogress import progress_bar

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

## Global Parameters

In [None]:
# Keyword arguments for opening zarr stores (not referenced by the cells visible below).
zkwargs = {"use_cftime": True, "consolidated": True}

# Oxygen thresholds used throughout the notebook. They are compared against values
# returned by `convert_mol_m3_mymol_kg` and used as `o2_threshold` selectors.
core_o2 = 20
outer_o2 = 120

# Shared map settings: all map axes are created with `projection`; the data are
# plotted on their lon/lat coordinates through the PlateCarree transform.
projection = ccrs.Robinson(210)
map_kwargs = dict(x='lon', y='lat', transform=ccrs.PlateCarree())
# Base figure size of a single map panel; multi-panel figures multiply these values.
single_figsize_y = 4
single_figsize_x = 8

# Folder the figures are saved to (empty path, i.e. the current working directory).
pfolder = pathlib.Path('')
# Folder the preprocessed input data are read from.
dfolder = pathlib.Path('../data/binder/')

## Custom Functions

In [None]:
def make_mask(ds, threshold, var='historical_o2_min_value'):
    """Return a boolean mask that is True where oxygen is below `threshold`.

    Parameters
    ----------
    ds : dataset-like
        Indexed with `var` to obtain the oxygen field.
    threshold : number
        Limit that the unit-converted oxygen values are compared against.
    var : str
        Name of the variable in `ds` that holds the oxygen field.

    Returns
    -------
    The result of `converted_oxygen < threshold`. If the field has a `time`
    dimension, it is averaged over time before the comparison.
    """
    converted = convert_mol_m3_mymol_kg(ds[var])
    # Collapse the time dimension (if there is one) so the mask is based on the time mean.
    time_mean = converted.mean('time') if 'time' in converted.dims else converted
    return time_mean < threshold

def smooth(ds):
    """Return `ds` passed through `filter_2D` along the "y" and "x" dimensions.

    NOTE(review): the `2` is presumably the filter width - confirm against xarrayutils.
    """
    return filter_2D(ds, 2, ["y", 'x'])

## Load Data

In [None]:
# # depth histogram data
# dfolder = '/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/'
# ddict_histogram = load_zarr_directory(dfolder)
def _mask_fail_age(ds):
    """Replace `agessc` with NaN when `fail_age(ds)` is truthy.

    The dataset is modified in place (item assignment) and also returned.
    Datasets for which `fail_age` is falsy are returned untouched.
    """
    if not fail_age(ds):
        return ds
    # Same layout as the original variable, but every value is NaN.
    ds['agessc'] = xr.ones_like(ds.agessc) * np.nan
    return ds

In [None]:
# Load every store in the data folder that matches '*_summed_histogram.zarr' (dict-like result).
ddict_histogram_pre = load_zarr_directory(dfolder, pattern='*_summed_histogram.zarr')
# the obs need a bit of special treatment (they have missing attrs)
# Their entry is stored under an all-'none' key; it is removed from the dict here,
# presumably so that only model output remains in `ddict_histogram_pre`.
obs_histgram_pre = ddict_histogram_pre.pop('none.none.none.none.none.none.none.none.none')

In [None]:
# load map data
# Depth-range tag that is part of the file name of each combined store.
depth_pattern = '100-3000m'

# One variable per combined zarr store.
ds_combo_o2 = xr.open_zarr(dfolder.joinpath(f"combined_o2_{depth_pattern}.zarr")).o2
ds_combo_epc100 = xr.open_zarr(dfolder.joinpath(f"combined_epc100_{depth_pattern}.zarr")).epc100
ds_combo_age = xr.open_zarr(dfolder.joinpath(f"combined_agessc_{depth_pattern}.zarr")).agessc

#Only for SM (just one member per model)
ds_combo_o2sat = xr.open_zarr(dfolder.joinpath(f"combined_o2sat_{depth_pattern}.zarr")).o2sat
# AOU is loaded with its sign flipped, so that adding it to o2sat gives o2sat - AOU.
ds_combo_aou = -xr.open_zarr(dfolder.joinpath(f"combined_aou_{depth_pattern}.zarr")).aou
ds_combo_o2residual = ds_combo_o2sat+ds_combo_aou

In [None]:
# boundary
# Combined boundary dataset; the cells below select from it by `model` and `o2_threshold`.
ds_boundary_combined = xr.open_zarr(dfolder.joinpath('boundary_combined.zarr'), use_cftime=True)

## Figure 1: OMZ validation

- The masking is really cumbersome, the check for an actual local min should be built into the algo...
- TODO: Add obs to last panel for SI
- TODO: Force extend the colorbar, so that 20 corresponds to the transition between gray and red


In [None]:
# Observational entry of the boundary dataset, passed through `regrid`;
# used for the green contour overlays below.
obs = regrid(ds_boundary_combined.sel(model='obs_bianchi'))

# For every (o2 threshold, plotted quantity) combination two figures are made:
#   * `fig_final`: observations next to the multi-model median
#   * `fig`: one panel per model (saved with the `_SI` suffix)
for o2_threshold in [core_o2, outer_o2]:
    for title, vmin, vmax, cmap, levels in [
        ("OMZ boundary", 0, 1500, cmo.dense, 16),
        ("Minimum o2", 0, 130, cmo.oxy, 14),
#         ("Minimum o2", 0, 100, cmo.oxy, 33),
#         ("Minimum o2 Depth", 0, 1500, cmo.deep, 16),
    ]:
        p_datasets = []
        p_contour_datasets = []

        print(f"{title} {o2_threshold}")
        fig_final, (ax_obs, ax_mmm) = plt.subplots(
            ncols=2,
            figsize=[single_figsize_x * 2, single_figsize_y],
            subplot_kw={"projection": projection},
        )

        # 3 x 5 grid: room for up to 15 individual model panels.
        fig, axarr = plt.subplots(
            ncols=3,
            nrows=5,
            figsize=[single_figsize_x * 3, single_figsize_y * 5],
            subplot_kw={"projection": projection},
        )

        ax_counter = 0

        for name in ds_boundary_combined.model.data:
            print(name)
            # Observations go to the overview figure, models to the next free SI panel.
            if 'obs' in name:
                ax = ax_obs
            else:
                ax = axarr.flat[ax_counter]
                ax_counter += 1

            # specific to the plot

            ds = ds_boundary_combined.sel(model=name)
            averaged_members = ds.averaged_members.data
            # mask land
            ds = ds.where(ds.historical_o2_min_value < 1000)

            broad_o2_threshold = outer_o2+40
            # mask broadly for all plots
            ds = ds.where(make_mask(ds, broad_o2_threshold))
            # further mask for areas with near zero thickness (problem with the original algo)
            thickness = (ds.historical_lower_boundary - ds.historical_upper_boundary)
            another_mask = thickness.sel(o2_threshold=outer_o2)>1

            ds = ds.where(another_mask)

            map_kwargs = dict(x="lon", y="lat", transform=ccrs.PlateCarree())
            plot_kwargs = dict(
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    levels=levels,
                    cbar_kwargs={'extend':'neither'},
                    rasterized=True,
                    **map_kwargs
                )

            # Pick the field to plot for the current quantity.
            if title == 'Minimum o2':
                p = convert_mol_m3_mymol_kg(ds['historical_o2_min_value'])

            elif title == "Minimum o2 Depth":
                p = ds['historical_o2_min_lev']

            elif title == 'OMZ boundary':
                # The boundary is only shown where the minimum o2 is below the current threshold.
                ds = ds.where(make_mask(ds, o2_threshold))
                p = ds['historical_upper_boundary'].sel(o2_threshold=o2_threshold)
            else:
                raise RuntimeError('Nope')

            # regrid
            p = regrid(p)

            # raise values near zero by a small amount to avoid problems with the colorbar
            p = p.where(np.logical_or(p>1, np.isnan(p)), 1)

            p = mask_basin(p)
            # Only model fields enter the multi-model median.
            if 'obs' not in name:
                p_datasets.append(p)

            # main plot
            p.plot.contourf(ax=ax, **plot_kwargs)

            if title == 'Minimum o2':
                # plot obs_values
                # `add_colorbar` is the xarray keyword (the former `add_cbar` is not one).
                convert_mol_m3_mymol_kg(obs.historical_o2_min_value).plot.contour(
                    ax=ax,
                    add_colorbar=False,
                    add_labels=False,
                    colors="g",
                    levels=[core_o2, outer_o2],
                    linestyles=["-", "--"],
                    **map_kwargs
                )

            axtitle = "Observations" if "obs" in name else name
            ax.set_title(f"{axtitle} ({averaged_members.compute()} members)")

        # print multi model mean
        da_mmm = xr.concat(p_datasets, 'model')
        assert all(['obs' not in model for model in da_mmm.model.data])
        #filter grid points where only a few models have values
        da_mmm = da_mmm.where(mask_multi_model(da_mmm))
        da_mmm = da_mmm.median('model')

        # TODO Check that the smoothing is done in the same way everywhere
        da_mmm = smooth(da_mmm)

        da_mmm.plot.contourf(ax=ax_mmm, **plot_kwargs)

        ax_mmm.set_title('Multi ESM Median')

        # plot obs_values
        if title == 'Minimum o2':
            convert_mol_m3_mymol_kg(obs.historical_o2_min_value).plot.contour(
                ax=ax_mmm,
                add_colorbar=False,
                add_labels=False,
                colors="g",
                levels=[core_o2, outer_o2],
                linestyles=["-", "--"],
                **map_kwargs
            )

        # Mapping extras for all axes
        for ax in list(axarr.flat)+[ax_obs, ax_mmm]:
            finish_map_plot(ax)


        title_print = f'{title} (o2_threshold:{o2_threshold})'
        fig.suptitle(title_print, y=0.95, fontsize=20)
        fig.savefig(pfolder.joinpath(f"{title_print}_SI.pdf"))
        # Title the overview figure itself (this call used to target `fig` a second
        # time, which left `fig_final` without a title).
        fig_final.suptitle(title_print, y=0.95, fontsize=20)
        fig_final.savefig(pfolder.joinpath(f"{title_print}.pdf"))
        plt.show()

## Fig 3: Relationship to OMZ boundary

In [None]:
from xarrayutils.utils import sign_agreement
# Models only: the observational entry is dropped.
ds = ds_boundary_combined.drop_sel(model='obs_bianchi')

o2_th = [core_o2, outer_o2]
fig, axarr_full = plt.subplots(
    ncols=len(o2_th),
    nrows=2,
    subplot_kw={'projection':projection},
    figsize=[single_figsize_x*len(o2_th),single_figsize_y*2],
)

for ai, (var, vmax, levels, cmap) in enumerate(
    [
        ('trend_thickness', 50, 25, 'BrBG_r'),
        ('trend_upper_boundary', 50, 25, 'RdGy')
    ]):
    da = ds[var]
    # Column `ai` of the axes grid: each variable fills one column,
    # and the rows are iterated together with the o2 thresholds below.
    axarr = axarr_full[:,ai]
    
    # mask out the areas that do not have many models with an OMZ there.
    nan_mask = mask_multi_model(da)
    
    # Multi-model median, masked and smoothed.
    m = da.median('model').where(nan_mask).load()
    m = smooth(m)
    # Sign agreement between the individual models and the median (threshold 0.6).
    s = sign_agreement(da, m, 'model', count_nans=False, threshold=0.6).where(nan_mask)
    # values shown where at least 10 models have an OMZ
    # areas that are not dotted have at least 7 models with that sign

    for o2, ax in zip(o2_th, axarr):
        m.sel(o2_threshold=o2).plot.contourf(ax=ax, vmax=vmax, cmap=cmap, levels=levels,rasterized=True, **map_kwargs)
        # Hatching overlay: the interval between the levels 0 and 0.5 is dotted, nothing is colored.
        s.sel(o2_threshold=o2).plot.contourf(
            ax=ax,
            levels=[0,0.5],
            hatches=['....', None],
            colors='none',
            add_labels=False, 
            add_colorbar=False,
            **map_kwargs)
        finish_map_plot(ax)
fig.savefig(pfolder.joinpath(f"Thickness_Boundary.pdf"))

In [None]:
# Models only: the observational entry is dropped.
ds = ds_boundary_combined.drop_sel(model='obs_bianchi')

o2_th = [core_o2, outer_o2]
n_models = len(ds.model.data)
# Each row holds two models (a thickness panel and an upper-boundary panel each).
# Round up, so that an odd number of models does not silently drop the last one.
n_rows = -(-n_models // 2)
for oi, o2 in enumerate(o2_th):
    # One figure per o2 threshold. `squeeze=False` keeps `axarr` two-dimensional
    # even when there is only a single row.
    fig, axarr = plt.subplots(
        ncols=4,
        nrows=n_rows,
        squeeze=False,
        subplot_kw={'projection':projection},
        figsize=[single_figsize_x*len(o2_th)*2,single_figsize_y*n_rows],
    )

    # Thickness trends fill columns 0 and 2 (top to bottom), boundary trends columns 1 and 3.
    for vi, (var, vmax, levels, cmap, axarr_sub) in enumerate(
        [
        ('trend_thickness', 50, 25, 'BrBG_r', list(axarr[:,0].flat)+list(axarr[:,2].flat)),
        ('trend_upper_boundary', 50, 25, 'RdGy',list(axarr[:,1].flat)+list(axarr[:,3].flat))
        ]):
        da = ds[var]
        for mi, (ax, model) in enumerate(zip(axarr_sub, da.model.data)):
            m = da.sel(model=model)
#             m = smooth(m)
            m.sel(o2_threshold=o2).plot(ax=ax, vmax=vmax, cmap=cmap, levels=levels,rasterized=True, **map_kwargs)
            finish_map_plot(ax)
            ax.set_title(f"{model} members:{m.averaged_members.load().data}")

    if n_models % 2:
        # With an odd model count the last slot of columns 2 and 3 stays unused.
        for ax in axarr[-1, 2:]:
            ax.set_visible(False)

    plt.show()
    # NOTE(review): `var` still holds the last variable of the loop above, so the file
    # name always contains 'trend_upper_boundary' although the figure shows both variables.
    fig.savefig(pfolder.joinpath(f"{var}_{o2}_SI.pdf"))

TODO make SI plots of the changes in each model

In [None]:
# `co` and `slope_dict` are not used by the code visible in this cell.
co = 'k'
slope_dict = {}

ds_scatter = ds_boundary_combined.drop_sel(model=['obs_bianchi'])#'MRI-ESM2-0', 'CNRM-ESM2-1',
# print('TEST')
# ds_scatter = ds_boundary_combined.drop_sel(model=['CNRM-ESM2-1', 'IPSL-CM6A-LR', 'MIROC-ES2L', 'obs_bianchi'])

ds_scatter = mask_basin(ds_scatter)

# 2 x 2 panels: one row per o2 threshold, one column per boundary (upper, lower).
fig, axarr = plt.subplots(nrows=2, ncols=2, figsize=[10, 10])
# for th, axarr_sub in [(core_o2, axarr[0,:]), (outer_o2, axarr[1,:])]:
for th, axarr_sub in [(core_o2, axarr[0,:]), (outer_o2, axarr[1,:])]:
    for boundary, ax in zip(['upper', 'lower'], axarr_sub.flat):
        # `diagonal_sign` sets the slope of the dashed reference line plotted further down.
        if boundary == 'lower':
            y_var = 'trend_lower_boundary'
            diagonal_sign = 1
        else:
            y_var = 'trend_upper_boundary'
            diagonal_sign = -1
        # Both axes share the same symmetric bin edges.
        outer_bin = 200
        x_bins = np.linspace(-outer_bin, outer_bin, 60)
        y_bins = np.linspace(-outer_bin, outer_bin, 60)
        x = ds_scatter.sel(o2_threshold=th).trend_thickness
        y = ds_scatter.sel(o2_threshold=th)[y_var]

        # 2D histogram of thickness trend vs boundary trend, shown as the background.
        hist = histogram(x, y, bins=[x_bins, y_bins])
        hist.plot(
            ax=ax,
            vmin=0,
            vmax=500 if th == outer_o2 else 200,
            add_colorbar=False,
            x='trend_thickness_bin',
            cmap='BuGn',
            rasterized=True,
        )

        # flattern the arrays
#         w_flat = w.load().data.reshape(-1, 1)
        x_flat = x.load().data.reshape(-1, 1)
        y_flat = y.load().data.reshape(-1, 1)
        
        # remove the nans
        # The mask is built from NaNs in x only and then applied to both x and y.
        nan_mask = np.isnan(x_flat).reshape(-1, 1)
        x_flat = x_flat[~nan_mask]
        y_flat = y_flat[~nan_mask]

        # remove values that show very large thickness increase to prevent outliers
        large_value_mask = abs(x_flat) < outer_bin
        x_flat = x_flat[large_value_mask]
        y_flat = y_flat[large_value_mask]

        # Linear regression of y on x over all remaining points.
        reg = xr_linregress(
            xr.DataArray(x_flat, dims=['dummy']),
            xr.DataArray(y_flat, dims=['dummy']),
            'dummy'
        ).load()

        # Fitted line evaluated at the bin edges.
        line = x_bins * reg.slope.data + reg.intercept.data
        print('----------')
        print(f'{boundary}-{th}')
        print(reg)
        ax.text(10, 100, f"$r^2$: {reg.r_value.data**2:.2f} \n slope:{reg.slope.data:.2f}", color='0.5')
        ax.plot(x_bins, line, color='k', ls='-.')


        outer_bin_plot = outer_bin
        ax.set_xlim(-outer_bin_plot, outer_bin_plot)
        # The y limits are given in reverse order, so the y axis is inverted.
        ax.set_ylim(outer_bin_plot, -outer_bin_plot)
        ax.axhline(0, color='0.5', ls='--')
        ax.axvline(0, color='0.5', ls='--')
        # plot expected line
        ax.plot(x_bins, diagonal_sign * x_bins, ls='--', lw=1, color='0.5')
        ax.set_xlabel('OMZ thickness trend [m/century]')
        ax.set_ylabel(f'OMZ {boundary} boundary trend [m/century]')
        ax.set_title(f'{boundary} vs thickness for {th} mumol')
fig.subplots_adjust(hspace=0.4, wspace=0.4)
fig.savefig(pfolder.joinpath(f"Thickness_Boundary_Scatter.pdf"))

## Fig 4/5: Drivers - age

Now let's look at the age inside the OMZ core and the age outside of the full OMZ.

In [None]:
# One figure per domain: the full OMZ ('' / outer threshold) and the OMZ core.
# Each figure has a scatter panel (age change vs volume change) and two time series panels.
for domain, o2 in [('', outer_o2), ('core', core_o2)]:#
    fig, (ax_scatter, ax_age, ax_vol) = plt.subplots(ncols=3, nrows=1, figsize=np.array([22, 6])*0.8)

    # Comparison sign used in the axis labels. It depends only on the domain, so it is
    # set here rather than inside the model loop; otherwise the labels below would fail
    # (or reuse a stale value) if no plotted model provides `agessc`.
    inner_outer_age = '<=' if domain == 'core' else '>'

    for name, ds_pre in progress_bar(ddict_histogram_pre.items()):
        print(name)
        # For the core, three models are excluded; for the full OMZ all models are used.
        if (domain=='core' and ds_pre.source_id not in ['IPSL-CM6A-LR', 'MIROC-ES2L', 'CNRM-ESM2-1']) or domain=='':
            # Sum the histogram over the o2 bins up to / from the threshold.
            ds_inner = ds_pre.sel(o2_bin=slice(None, o2)).sum('o2_bin')
            ds_outer = ds_pre.sel(o2_bin=slice(o2, None)).sum('o2_bin')

            if 'agessc' in ds_pre.data_vars:
                # Age is taken inside the threshold for the core, outside of it otherwise.
                if domain == 'core':
                    age = (ds_inner.agessc/ds_inner.volume)
                else:
                    age = (ds_outer.agessc/ds_outer.volume)
            else:
                print(f'Insert nans for age in {name}')
                age = xr.ones_like(ds_outer.volume) * np.nan

            # Keep only entries whose age is positive at every time step.
            age = age.where((age>0).all('time'))

            vol = ds_inner.volume

            # align the members
            age, vol = xr.align(age, vol, join='inner')

            age = age.load()
            vol = vol.load()

            # TODO: Fix this in the final version
            # remove GFDL age artifacts (from interpolating yr->month)
            # this should be unnescessary with proper processing

            if "GFDL" in name:
                age = age.where(age>age.quantile(0.03, 'time'))

            age_hist = age.sel(time=hist_slice).mean('time')
            vol_hist = vol.sel(time=hist_slice).mean('time')

            # NOTE(review): the anomalies are taken relative to the 1850-1860 mean, while
            # the axis labels below say "(1950-2000)" - confirm which baseline is intended.
            age_anom = (age - age.sel(time=slice('1850', '1860')).mean('time'))
            vol_anom = (vol - vol.sel(time=slice('1850', '1860')).mean('time'))

            age_slope, p_value_age = slope(age)
            vol_slope, p_value_vol = slope(vol)

            # Slopes as a percentage of the mean over `hist_slice`.
            age_perc_slope = age_slope/age_hist*100
            vol_perc_slope = vol_slope/vol_hist*100

            # Thick white marker edge if any p-value exceeds 0.05, thin black edge otherwise.
            if (p_value_age>0.05).any() or (p_value_vol>0.05).any():
                edgecolor = 'w'
                lw = 2
            else:
                edgecolor = 'k'
                lw=1

            # smooth the timeseries
            # (10-step centered rolling mean; the first and last 5 steps are cut off)
            age_anom = age_anom.rolling(time=10, center=True).mean().isel(time=slice(5,-5))
            vol_anom = vol_anom.rolling(time=10, center=True).mean().isel(time=slice(5,-5))

            # plot timeseries of age
            for da, ax in ([age_anom, ax_age], [vol_anom, ax_vol]):
                shaded_line_plot(
                    da,
                    'member_id',
                    ax=ax,
                    line_kwargs={'color':o2_model_colors()[ds_pre.source_id]},
                )
            color = o2_model_colors()[ds_pre.source_id]
            # Individual members: small, transparent markers.
            ax_scatter.scatter(
                age_perc_slope,
                vol_perc_slope,
                c=color,
                alpha=0.3,
                s=25,
                edgecolor='none'
            )
            # Member mean: opaque marker with the edge style chosen above.
            ax_scatter.scatter(
                age_perc_slope.mean('member_id'),
                vol_perc_slope.mean('member_id'),
                c=color,
                edgecolor=edgecolor,
                linewidths=lw
            )

    ax_scatter.set_ylabel(f'Change in {domain} OMZ Volume [%/century]')
    ax_scatter.set_xlabel(f'Change in Age (o2 {inner_outer_age} {o2}) [%/century]')

    if domain=='core':
        ax_scatter.set_xlim(-40, 40)
        ax_scatter.set_ylim(-40, 40)
    else:
        ax_scatter.set_xlim(-20, 20)
        ax_scatter.set_ylim(-20, 20)

    ax_scatter.axhline(0, color='0.5')
    ax_scatter.axvline(0, color='0.5')
    ax_scatter.set_aspect(1)

    ax_age.set_ylabel(f'Age anomaly to (1950-2000) (o2 {inner_outer_age} {o2}) [yr]')
    ax_age.set_title(f'{domain} Ideal Age')
    ax_vol.set_ylabel(f'{domain} OMZ Volume anomaly to (1950-2000) [m^3]')
    ax_vol.set_title(f'{domain} OMZ Volume')
    fig.savefig(pfolder.joinpath(f"age_{domain}_final.pdf"))
    plt.show()

## Average maps






In [None]:
kwargs = dict(**map_kwargs)

# One map per variable: the smoothed multi-model median, with contours of the
# historical minimum o2 and hatching from `sign_agreement` on top.
for ds, title_var, vmax in [
    (ds_combo_o2, "Oxygen", 10),
    (ds_combo_epc100, "Particulate Export", 4e-8),
    (ds_combo_age, "Ideal Age", 30),
    (ds_combo_aou, "-AOU", 10),
    (ds_combo_o2sat, "o2sat", 10),
    (ds_combo_o2residual, "o2sat-AOU", 10),
]:

    fig, ax = plt.subplots(subplot_kw=dict(projection=projection),figsize=[12,5])

    mmm = smooth(ds.median('model'))
    # Per-variable unit conversion, colorbar label (name/units) and title.
    # The unit strings are raw strings so that `\mu` is not read as an escape sequence.
    if title_var == 'Oxygen':
        mmm = convert_mol_m3_mymol_kg(mmm)
        mmm.name = 'Oxygen change'
        mmm.attrs['units'] = r'$\mu$mol/kg/century'
        full_title = f'{title_var} averaged between {depth_pattern}'
    elif title_var == 'Ideal Age':
        full_title = f'{title_var} averaged between {depth_pattern}'
        mmm.name = 'Ideal Age'
        mmm.attrs['units'] = 'yrs/century'
    elif title_var == '-AOU':
        mmm = convert_mol_m3_mymol_kg(mmm)
        full_title = f'{title_var} averaged between {depth_pattern}'
        mmm.name = 'AOU'
        mmm.attrs['units'] = r'$\mu$mol/kg/century'
    elif 'o2sat' in title_var:
        mmm = convert_mol_m3_mymol_kg(mmm)
        full_title = f'{title_var} averaged between {depth_pattern}'
        mmm.name = title_var
        mmm.attrs['units'] = r'$\mu$mol/kg/century'
    else:
        full_title = f'{title_var} at 100m'
        mmm.name = 'Particulate Organic Carbon Flux'
        mmm.attrs['units'] = 'mol/$m^2$/s/century'

    sign_threshold = 0.6
    n_models = len(ds.model)
    # Whole number of models, so the text reads e.g. "<9/15" instead of "<9.0/15".
    dotted = int(np.ceil(n_models*sign_threshold))
    print(f'Dots shows where less than {dotted} models agree on the sign')

    mmm.plot(
        ax=ax,
        vmax=vmax,
        rasterized=True,
        **kwargs)

    # plot an indicator of the historical OMZ?
    omz_ref = convert_mol_m3_mymol_kg(ds_boundary_combined.historical_o2_min_value).median('model')
    omz_ref = omz_ref.where(omz_ref<1e10)#.fillna(0)
    omz_ref.plot.contour(
        ax=ax,
        levels=[core_o2, outer_o2],
        colors='k',
        linewidths=[1,2],
        add_labels=False,
        add_colorbar=False,
        **map_kwargs
    ) # TODO change to 120

    # Hatching overlay: the interval between the levels 0 and 0.5 is dotted, nothing is colored.
    sign_agreement(ds, mmm, 'model', threshold=sign_threshold).plot.contourf(
        ax=ax,
        levels=[0,0.5],
        hatches=['...',None],
        colors='none',
        add_labels=False,
        add_colorbar=False,
        **map_kwargs
    )
    ax = plt.gca()
    #plot info about models and stipling in the corner
    ax.text(0.85, 0.9, f'Dotted: \n <{dotted}/{n_models} \n models agree on sign',
            horizontalalignment='right',
            verticalalignment='center',
            transform=ax.transAxes)
    finish_map_plot(ax)
    ax.set_title(full_title)
    fig.savefig(pfolder.joinpath(f"Average_{title_var}_map.pdf"))
    plt.show()

In [None]:
# just a blank legend
# An otherwise empty figure that only holds the model color legend, saved on its own.
plt.figure()
model_color_legend(bbox_to_anchor=[0.5, 1])
plt.savefig(pfolder.joinpath('legend.pdf'))

## Plots from Sam

In [None]:
#models processed with xhistogram, so need cumsum to recover o2_bins
def convert_to_cumulative_volume(hist):
    """Return the cumulative sum of `hist` along the "o2_bin" dimension.

    The `o2_bin` coordinate of the input is re-attached to the result.
    """
    return hist.cumsum("o2_bin").assign_coords(o2_bin=hist.o2_bin)

#function to recover attributes of time series data from key name
def recover_attrs_ts(ddict):
    """Fill dataset attrs from the underscore-separated parts of each dict key.

    The first three parts of a key are written to the `source_id`,
    `variant_label` and `mask_type` attrs (in that order) of the matching value.
    The values are modified in place and the same dict is returned.
    """
    attr_names = ('source_id', 'variant_label', 'mask_type')
    for key, dataset in ddict.items():
        fields = key.split('_')
        for position, attr_name in enumerate(attr_names):
            dataset.attrs[attr_name] = fields[position]
        ddict[key] = dataset
    return ddict

def resample_yearly(ds_in, freq="1AS"):
    """Resample `ds_in` along time to means over `freq` and label each step by its year.

    Coordinates with a time dimension (other than `time` and `time_bounds`) are
    carried through the resampling. The attrs of `ds_in`, except `table_id`,
    are copied onto the result.
    """
    # The resampling drops some coordinates, so the time-dependent ones are
    # temporarily turned into data variables and restored as coordinates afterwards.
    protected = ("time", "time_bounds")
    time_dependent = [
        name
        for name in list(ds_in.coords)
        if "time" in ds_in[name].dims and name not in protected
    ]
    as_variables = ds_in.reset_coords(time_dependent)
    ds_out = as_variables.resample(time=freq).mean()
    restored = {name: ds_out[name] for name in time_dependent}
    ds_out = ds_out.assign_coords(restored)

    kept_attrs = {key: value for key, value in ds_in.attrs.items() if key != "table_id"}
    ds_out.attrs.update(kept_attrs)

    # Replace the time stamps by the plain year.
    ds_out['time'] = ds_out.time.dt.year
    return ds_out

In [None]:
# reload data
from busecke_etal_2021_aguadv.utils import (
    load_zarr_directory
)
# Absolute path to the folder with the processed plotting data loaded below.
ifolder = pathlib.Path('/projects/GEOCLIM/LRGROUP/jbusecke/projects/busecke_etal_2021_aguadv/data/processed/plotting')

def select_dummy(ds):
    """Return the `dummy` member of `ds` after assigning the attrs of `ds` to it."""
    dummy = ds.dummy
    dummy.attrs = ds.attrs
    return dummy

# Trend and historical-volume results: keyed by the original dict key,
# reduced to their `dummy` variable and loaded into memory.
results_trend = {k:select_dummy(ds).load() for k,ds in load_zarr_directory(ifolder, pattern='*_sam_results_trend.zarr').items()}
results_hist_volume = {k:select_dummy(ds).load() for k,ds in load_zarr_directory(ifolder, pattern='*_sam_results_hist_volume.zarr').items()}
# The observations are stored under an all-'none' key; take them out of the dict
# and give them an explicit source_id.
obs_hist = results_hist_volume.pop('none.none.none.none.none.none.none.none.none')
obs_hist.attrs['source_id'] = 'obs_bianchi'

# Time series datasets: re-keyed by their `source_id` attribute and loaded into memory.
omz_time_series = {ds.source_id:ds.load() for k,ds in load_zarr_directory(ifolder, pattern='*_sam_omz_time_series.zarr').items()}
CORE = {ds.source_id:ds.load() for k,ds in load_zarr_directory(ifolder, pattern='*_sam_CORE.zarr').items()}
OUTER = {ds.source_id:ds.load() for k,ds in load_zarr_directory(ifolder, pattern='*_sam_OUTER.zarr').items()}

In [None]:
#Define plotting functions

#Plot OMZ Area vs Depth
def validation_depth_plot(ddict, ds_obs, o2_thresh, lat_bins, ax):
    """Draw depth profiles for every model, their ensemble median and the observations.

    Parameters
    ----------
    ddict : dict
        Model results; each value needs a `source_id` attribute and the
        dimensions/coordinates `lat_bin`, `o2_bin`, `member_id`, `lev` and `dz`.
    ds_obs :
        Observational counterpart (no `member_id` handling is applied to it).
    o2_thresh : number
        Value selected along `o2_bin` (with `method='pad'`).
    lat_bins : slice
        Selection along `lat_bin`, summed over afterwards.
    ax : matplotlib axes
        Axes to draw into; it is also formatted (limits, labels, ticks, title).
    """
    palette = o2_model_colors()

    # ---- individual model curves -------------------------------------
    member_means = []
    for model_ds in ddict.values():
        summed = model_ds.sel(lat_bin=lat_bins).sum('lat_bin')
        # Divide by the layer thickness and pick the requested o2 bin.
        profile = (summed / summed.dz).sel(o2_bin=o2_thresh, method='pad')
        shaded_line_plot(
            profile,
            'member_id',
            line_kwargs=dict(color=palette[model_ds.source_id]),
            horizontal=False,
            ax=ax,
        )
        member_means.append(profile.mean(['member_id']))

    # ---- ensemble median and interquartile range ---------------------
    # The models are linearly interpolated onto one common set of levels first.
    target_lev = np.arange(0., 7000., 50.)

    def _to_common_levels(da):
        grid = Grid(da, coords={'LEV': {'center': 'lev'}}, periodic=False)
        return grid.transform(
            da,
            'LEV',
            target_lev,
            method='linear',
            target_data=da.lev,
            mask_edges=False,
        )

    stacked = xr.concat([_to_common_levels(da) for da in member_means], 'model')

    ensemble_median = stacked.median('model')
    ax.plot(ensemble_median, ensemble_median.lev, color='black', linewidth=3, ls='--')

    lower_quartile = stacked.quantile(.25, dim='model')
    upper_quartile = stacked.quantile(.75, dim='model')
    ax.fill_betweenx(
        ensemble_median.lev,
        lower_quartile,
        upper_quartile,
        alpha=0.5,
        color='gray',
        edgecolor='none',
    )

    # ---- observations ------------------------------------------------
    obs_summed = ds_obs.sel(lat_bin=lat_bins).sum('lat_bin')
    # careful here with histogram bins
    obs_profile = (obs_summed / obs_summed.dz).sel(o2_bin=o2_thresh, method='pad')
    ax.plot(obs_profile, obs_profile.lev, color='black', linewidth=3)
    # Shaded band from half to one and a half times the observed values.
    ax.fill_betweenx(
        obs_profile.lev,
        obs_profile * 0.5,
        obs_profile * 1.5,
        alpha=0.4,
        color='black',
        edgecolor='none',
    )

    # ---- formatting --------------------------------------------------
    ax_lim_format = 13
    ax.xaxis.set_major_formatter(ScientificManualFormatter(ax_lim_format, "%1.1f"))
    ax.ticklabel_format(
        axis="x", style="sci", scilimits=(-ax_lim_format, ax_lim_format)
    )

    ax.grid()
    ax.axvline(0, color='0.5', ls='--')
    ax.set_ylim(7000, 0)
    ax.set_xlim(0, 1e14)
    ax.set_xlabel(r'Area [$\rm{m}^2$]', fontsize=15)
    ax.set_ylabel(r'Depth [m]', fontsize=15)
    ax.set_title(f'o2_bin = {o2_thresh}', fontsize=15)
    linear_piecewise_scale(1500, 4, ax=ax)
    ax.set_yticks([0, 100, 200, 300, 400, 500, 1000, 2000, 4000, 6000])
    
#Plot integrated OMZ volume vs O2 threshold
def validation_oxygen_plot(ddict, ds_obs, lat_bins, lev, ax):
    """Plot the summed volume as a function of `o2_bin` for models and observations.

    Parameters
    ----------
    ddict : dict
        Model results; each value needs a `source_id` attribute and the
        dimensions `lat_bin`, `o2_bin`, `lev` and `member_id`.
    ds_obs :
        Observational counterpart with `lat_bin`, `o2_bin` and `lev`.
    lat_bins : slice
        Selection along `lat_bin`.
    lev : number
        Upper end of the `lev` selection (`slice(0, lev)`). It is applied to the
        models and to the observations alike, so both curves cover the same
        depth range.
    ax : matplotlib axes
        Axes to draw into; it is also formatted (limits, labels).
    """
    colors = o2_model_colors()
    Median_list = []
    o2_range = slice(0, 160)
    lev_range = slice(0, lev)

    ######## PLOT MODEL CURVES ########################
    for ds in ddict.values():
        source_id = ds.source_id
        out = ds.sel(lat_bin=lat_bins, o2_bin=o2_range, lev=lev_range).sum(['lat_bin', 'lev'])
        shaded_line_plot(out, 'member_id', line_kwargs=dict(color=colors[source_id]), ax=ax)
        Median_list.append(out.mean(['member_id']))
    ###################################################

    ######## PLOT MEDIAN CURVE ########################
    # Concatenate the member means once and reuse them for median and quartiles.
    stacked = xr.concat(Median_list, 'model')
    Median = stacked.median('model')
    ax.plot(Median.o2_bin, Median, color='black', linewidth=3, ls='--')

    M1 = stacked.quantile(.25, dim='model')
    M2 = stacked.quantile(.75, dim='model')
    ax.fill_between(Median.o2_bin, M1, M2, alpha=0.5, color='gray', edgecolor='none')
    ###################################################

    ######### PLOT OBSERVATIONS #######################
    # Same depth restriction as for the models (previously the observations were
    # summed over all levels regardless of `lev`).
    out_obs = ds_obs.sel(lat_bin=lat_bins, o2_bin=o2_range, lev=lev_range).sum(['lat_bin', 'lev'])
    ax.plot(out_obs.o2_bin, out_obs, color='black', linewidth=3)

    # Shaded band from half to one and a half times the observed values.
    y1 = out_obs*1.5
    y2 = out_obs*0.5
    ax.fill_between(out_obs.o2_bin, y1, y2, alpha=0.4, color='black', edgecolor='none')

    ############ FORMATING ###################################
    ax_lim_format = 16
    ax.yaxis.set_major_formatter(ScientificManualFormatter(ax_lim_format, "%1.1f"))
    ax.ticklabel_format(
        axis="y", style="sci", scilimits=(-ax_lim_format, ax_lim_format)
    )

    ax.axhline(0., color='0.5', ls='--')
    ax.set_xlim(5, 160)
    ax.set_xlabel(r'Oxygen Threshold [$\mu$mol/kg]', fontsize=15)
    ax.set_ylabel(r'Volume [$\rm{m}^3$]', fontsize=15)
    #model_color_legend()
    
    
#Plot integrated OMZ volume trend vs O2 threshold
def oxygen_space_trend_plot(ddict,lat_bins, lev, ax):
    """Plot the summed volume trend as a function of `o2_bin` for every model.

    Parameters
    ----------
    ddict : dict
        Model results; each value needs a `source_id` attribute and the
        dimensions `lat_bin`, `o2_bin`, `lev` and `member_id`.
    lat_bins : slice
        Selection along `lat_bin`.
    lev : number
        Upper end of the `lev` selection (`slice(0, lev)`).
    ax : matplotlib axes
        Axes to draw into; it is also formatted (limits, labels).
    """
    colors = o2_model_colors()
    Median_list = []
    o2_range = slice(0, 160)
    ######## PLOT MODEL CURVES ########################
    for ds in ddict.values():
        source_id = ds.source_id
        out = ds.sel(lat_bin=lat_bins, o2_bin=o2_range, lev=slice(0, lev)).sum(['lat_bin', 'lev'])
        shaded_line_plot(out, 'member_id', line_kwargs=dict(color=colors[source_id]), ax=ax)
        Median_list.append(out.mean(['member_id']))
    ###################################################

    ######## PLOT MEDIAN CURVE ########################
    # Concatenate the member means once and reuse them for median and quartiles.
    stacked = xr.concat(Median_list, 'model')
    Median = stacked.median('model')
    ax.plot(Median.o2_bin, Median, color='black', linewidth=3, ls='--')

    M1 = stacked.quantile(.25, dim='model')
    M2 = stacked.quantile(.75, dim='model')
    ax.fill_between(Median.o2_bin, M1, M2, alpha=0.5, color='gray', edgecolor='none')
    ###################################################

    ############ FORMATING ###################################
    ax.axhline(0., color='0.5', ls='--')
    ax_lim_format = 16
    ax.yaxis.set_major_formatter(ScientificManualFormatter(ax_lim_format, "%1.1f"))
    ax.ticklabel_format(
        axis="y", style="sci", scilimits=(-ax_lim_format, ax_lim_format)
    )
    ax.set_xlim(5, 160)
    # `\n` is a real line break; the backslash of `\mathrm` is escaped explicitly
    # so that it is not an invalid escape sequence.
    ax.set_ylabel('Volume Trend\n[$\\mathrm{m}^3$/century]', fontsize=15)
    ax.set_xlabel(r'Oxygen Threshold [$\mu$mol/kg]', fontsize=15)

In [None]:
#Figure 2
# Depth profiles for the two o2 thresholds (120 and 20), side by side.

%matplotlib inline
fig, axes = plt.subplots(1,2, figsize = [10,5])#, sharey = True)

# Selection along `lat_bin` used for both panels.
lat_bin = slice(-30,30)

validation_depth_plot(results_hist_volume, obs_hist, 120, lat_bin, axes[0])
validation_depth_plot(results_hist_volume, obs_hist, 20, lat_bin, axes[1])
# Override the labels/titles set inside `validation_depth_plot`.
axes[1].set_ylabel('')
axes[0].set_title('OMZ', fontsize = 15)
axes[1].set_title('OMZ core', fontsize = 15)

plt.tight_layout()

In [None]:
#Figure 3
# Left: volume vs o2 threshold (models and observations); right: volume trend vs o2 threshold.

%matplotlib inline
fig, axes = plt.subplots(1,2, figsize = [12,4.5])

lat_bin = slice(-30,30)

# The depth limit of 8000 is larger than the deepest limit (7000) used in the depth plots.
validation_oxygen_plot(results_hist_volume, obs_hist, lat_bin, 8000, axes[0])
oxygen_space_trend_plot(results_trend, lat_bin, 8000, axes[1])
plt.tight_layout()

In [None]:
# 3 x 3 grid of volume-vs-threshold plots:
# columns iterate over the latitude bands, rows over the depth limits.
n = 3
m = 3
fig, axes = plt.subplots(n,m, figsize = [20,20])
lev = [2000, 3000, 8000]
lats = [10, 30, 50]

for i in range(m):
    for j in range(n):
        axes[j,i].set_title(f'LAT = (-{lats[i]},{lats[i]}), DEPTH = (0,{lev[j]})')
        validation_oxygen_plot(results_hist_volume, obs_hist, slice(-lats[i], lats[i]), lev[j], axes[j,i])
        
plt.tight_layout()

In [None]:
# Figure S4
# 3 x 3 grid of volume-trend-vs-threshold plots:
# columns iterate over the latitude bands, rows over the depth limits.
n = 3
m = 3
fig, axes = plt.subplots(n,m, figsize = [20,20])
lev = [1000, 3000, 8000]
lats = [10, 30, 50]

for i in range(m):
    for j in range(n):
        axes[j,i].set_title(f'LAT = (-{lats[i]},{lats[i]}), DEPTH = (0,{lev[j]})')
        oxygen_space_trend_plot(results_trend, slice(-lats[i], lats[i]), lev[j], axes[j,i])
        
plt.tight_layout()

## Figures 5,6 (export)

In [None]:
#rolling window size for smoothing
wnd = 10
# A centered rolling mean is undefined near both ends of the series, so half
# a window is trimmed from each side. `or None` keeps the full series when the
# half-window is 0 (a plain `slice(0, -0)` would be empty).
trim = slice(wnd // 2, -(wnd // 2) or None)

#convert from s-1 to yr-1
convert_export = (86400)*365.25 #s/day * day/yr

fig, ax = plt.subplots(1,3, figsize = [15, 5])
# Threshold taken from the global parameters, matching the outer-OMZ figure.
plt.suptitle(f'o2_threshold = {core_o2}')
colors = o2_model_colors()

# Axis decorations do not depend on the model, so they are applied once
# instead of on every loop iteration.
ax[0].set_ylabel('Change in volume [%/century]')
ax[0].set_xlabel('Change in export [%/century]')
ax[0].axvline(0, color = 'k', ls = '--')
ax[0].axhline(0, color = 'k', ls = '--')
# NOTE(review): both labels below say "(1950-2000)", but the anomalies are
# computed relative to the 1850-1860 mean. Confirm which one is intended.
ax[1].set_ylabel(r'Export anomaly to (1950-2000) [$\rm{mmolC/m^2/yr}$]')
ax[2].set_ylabel(r'Volume anomaly to (1950-2000) [$\rm{m^3}$]')

for name, ds in CORE.items():
    model_color = colors[ds.source_id]

    ########SCATTER PLOT#####################
    # Individual members: small, transparent markers.
    ax[0].scatter(
        ds.epc_perc_slope*100,
        ds.vol_perc_slope*100,
        c=model_color,
        alpha=0.3,
        s=20
    )
    # Member mean: black edge when both member-mean p-values are below 0.05,
    # red edge (and a printed notice) otherwise.
    epc_p = ds.epc_p_value.mean('member_id')
    vol_p = ds.vol_p_value.mean('member_id')
    if bool(epc_p < 0.05) and bool(vol_p < 0.05):
        edge = 'k'
    else:
        edge = 'r'
        print(name, epc_p.values, vol_p.values)
    ax[0].scatter(
        ds.epc_perc_slope.mean('member_id')*100,
        ds.vol_perc_slope.mean('member_id')*100,
        c=model_color,
        edgecolor=edge,
    )

    #########EXPORT TIME SERIES ###############
    epc_out = (ds.epc100 - ds.epc100.sel(time=slice('1850','1860')).mean(['time']))*convert_export
    epc_out = epc_out.sel(time = slice('1850', '2100')).rolling(time = wnd, center = True).mean().isel(time=trim)
    shaded_line_plot(epc_out, 'member_id', line_kwargs=dict(color=model_color), ax = ax[1])

    ###########VOLUME TIME SERIES###############
    vol_out = (ds.omz_vol - ds.omz_vol.sel(time=slice('1850','1860')).mean(['time']))
    vol_out = vol_out.sel(time = slice('1850', '2100')).rolling(time = wnd, center = True).mean().isel(time=trim)
    shaded_line_plot(vol_out, 'member_id',  line_kwargs=dict(color=model_color), ax = ax[2])

ax[0].set_ylim(-40, 40)
ax[0].set_xlim(-50, 50)

#model_color_legend()
plt.tight_layout()

In [None]:

#rolling window size for smoothing
wnd = 10
# A centered rolling mean is undefined near both ends of the series, so half
# a window is trimmed from each side. `or None` keeps the full series when the
# half-window is 0 (a plain `slice(0, -0)` would be empty).
trim = slice(wnd // 2, -(wnd // 2) or None)

#convert from s-1 to yr-1
convert_export = (86400)*365.25 #s/day * day/yr

fig, ax = plt.subplots(1,3, figsize = [15, 5])
plt.suptitle(f'o2_threshold = {outer_o2}')
colors = o2_model_colors()

# Axis decorations do not depend on the model, so they are applied once
# instead of on every loop iteration.
ax[0].set_ylabel('Change in volume [%/century]')
ax[0].set_xlabel('Change in export [%/century]')
ax[0].axvline(0, color = 'k', ls = '--')
ax[0].axhline(0, color = 'k', ls = '--')
# NOTE(review): both labels below say "(1950-2000)", but the anomalies are
# computed relative to the 1850-1860 mean. Confirm which one is intended.
ax[1].set_ylabel(r'Export anomaly to (1950-2000) [$\rm{mmolC/m^2/yr}$]')
ax[2].set_ylabel(r'Volume anomaly to (1950-2000) [$\rm{m^3}$]')

for name, ds in OUTER.items():
    model_color = colors[ds.source_id]

    ########SCATTER PLOT#####################
    # Individual members: small, transparent markers.
    ax[0].scatter(
        ds.epc_perc_slope*100,
        ds.vol_perc_slope*100,
        c=model_color,
        alpha=0.3,
        s=20
    )
    # Member mean: black edge when both member-mean p-values are below 0.05,
    # red edge (and a printed notice) otherwise.
    epc_p = ds.epc_p_value.mean('member_id')
    vol_p = ds.vol_p_value.mean('member_id')
    if bool(epc_p < 0.05) and bool(vol_p < 0.05):
        edge = 'k'
    else:
        edge = 'r'
        print(name, epc_p.values, vol_p.values)
    ax[0].scatter(
        ds.epc_perc_slope.mean('member_id')*100,
        ds.vol_perc_slope.mean('member_id')*100,
        c=model_color,
        edgecolor=edge,
    )

    #########EXPORT TIME SERIES ###############
    epc_out = (ds.epc100 - ds.epc100.sel(time=slice('1850','1860')).mean(['time']))*convert_export
    epc_out = epc_out.sel(time = slice('1850', '2100')).rolling(time = wnd, center = True).mean().isel(time=trim)
    shaded_line_plot(epc_out, 'member_id', line_kwargs=dict(color=model_color), ax = ax[1])

    ###########VOLUME TIME SERIES###############
    vol_out = (ds.omz_vol - ds.omz_vol.sel(time=slice('1850','1860')).mean(['time']))
    vol_out = vol_out.sel(time = slice('1850', '2100')).rolling(time = wnd, center = True).mean().isel(time=trim)
    shaded_line_plot(vol_out, 'member_id',  line_kwargs=dict(color=model_color), ax = ax[2])

ax[0].set_ylim(-20, 20)
ax[0].set_xlim(-50, 50)

#model_color_legend()
plt.tight_layout()

## Table S1,S2

In [None]:
# Observed volume summed over the latitude bins between -30 and 30 and over
# all levels, kept as a dataset with a single `volume` variable.
lat_bins = slice(-30,30)
ds_obs = obs_hist.sel(lat_bin = lat_bins).sum(['lat_bin', 'lev']).to_dataset(name='volume')

for name, ds in omz_time_series.items():
    # Fit the regression once and reuse it for both slope and p-value
    # (it was previously evaluated twice on identical input).
    fit = linear_trend(ds.volume.sel(time = slice('2000', '2100')), 'time')
    # NOTE(review): the factor 101 presumably scales a per-year slope to the
    # 2000-2100 window (101 yearly steps). Confirm the time resolution.
    ds['trend'] = fit.slope.load()*101.
    ds['p_value'] = fit.p_value.load()
    # Time-mean volume over 1950-2000, used as the reference volume in the tables.
    ds['hist_vol'] = ds.volume.sel(time=slice('1950','2000')).mean(['time']).load()
    omz_time_series[name] = ds

In [None]:
def make_table(o2_thresh):
    """Build the volume / volume-trend summary table for one oxygen threshold.

    Reads the notebook-level ``ds_obs`` (observed volume per ``o2_bin``) and
    ``omz_time_series`` (mapping model name -> dataset holding ``hist_vol``,
    ``trend`` and ``p_value`` with a ``member_id`` dimension).

    Parameters
    ----------
    o2_thresh : number
        Oxygen threshold; the nearest available ``o2_bin`` is selected.

    Returns
    -------
    pandas.DataFrame
        Row 0 holds the observations ('WOA'), row 1 the multi-model median
        ('MMM'), followed by one row per model. Volumes and trends are stored
        as formatted strings, percentages as ints, and '--' marks cells that
        do not apply.
    """
    vol_col = r'volume ($\rm{m}^3$)'
    trend_col = r'trend ($\rm{m^3/cen}$)'
    columns = [
        'name',
        '# mem.',
        vol_col,
        'vol. std.',
        'volume (% obs.)',
        trend_col,
        'trend std.',
        'trend (% change)',
        'trend p-value',
    ]
    missing = '--'

    obs = ds_obs.volume.sel(o2_bin = o2_thresh, method = 'nearest')

    model_rows = []
    for name, ds in omz_time_series.items():
        ds_thresh = ds.sel(o2_bin = o2_thresh, method = 'nearest')
        # Each model is represented by its member median.
        ds_ = ds_thresh.median(['member_id'])
        n_members = len(ds.member_id)

        # A spread across members is only meaningful with more than one member.
        if n_members > 1:
            vol_std = '{:.1e}'.format(ds_thresh.hist_vol.std(['member_id']).values)
            trend_std = '{:.1e}'.format(ds_thresh.trend.std(['member_id']).values)
        else:
            vol_std = missing
            trend_std = missing

        # 'o' only if every member has p < 0.05 (the largest p-value decides).
        all_significant = ds_thresh.p_value.max(['member_id']).values < 0.05

        model_rows.append({
            'name': name,
            '# mem.': n_members,
            vol_col: '{:.2e}'.format(ds_.hist_vol.values),
            'vol. std.': vol_std,
            'volume (% obs.)': int(np.rint((ds_.hist_vol/obs).values*100)),
            trend_col: '{:.2e}'.format(ds_.trend.values),
            'trend std.': trend_std,
            'trend (% change)': int(np.rint(ds_.trend.values/ds_.hist_vol.values*100)),
            'trend p-value': 'o' if all_significant else 'x',
        })

    DF_models = pd.DataFrame(model_rows, columns=columns)

    #Add Observations
    WOA = pd.DataFrame(
        [{
            'name': 'WOA',
            '# mem.': missing,
            vol_col: '{:.1e}'.format(obs.values),
            'vol. std.': missing,
            'volume (% obs.)': missing,
            trend_col: missing,
            'trend std.': missing,
            'trend (% change)': missing,
            'trend p-value': missing,
        }],
        columns=columns,
    )

    #Add Multi-model median
    # The volume/trend columns hold formatted strings; convert them to float
    # explicitly, because recent pandas refuses to reduce string columns.
    vol_median = DF_models[vol_col].astype(float).median()
    trend_median = DF_models[trend_col].astype(float).median()

    MMM = pd.DataFrame(
        [{
            'name': 'MMM',
            # Number of models entering the median (was hard-coded to 14).
            '# mem.': len(DF_models),
            vol_col: '{:.2e}'.format(vol_median),
            # The multi-model standard deviations are intentionally not reported.
            'vol. std.': missing,
            'volume (% obs.)': int(np.rint(DF_models['volume (% obs.)'].median())),
            trend_col: '{:.2e}'.format(trend_median),
            'trend std.': missing,
            # Ratio of the medians, not the median of the per-model percent
            # changes; the original author was unsure which is preferable.
            'trend (% change)': int(np.rint(trend_median/vol_median*100)),
            'trend p-value': missing,
        }],
        columns=columns,
    )

    DF = pd.concat([WOA, MMM, DF_models]).reset_index(drop = True)

    return DF

#truncate output table display
def trunc_display(df):
    """Shorten the volume and trend columns of a summary table to one decimal.

    Both columns hold numbers formatted as strings; each entry is re-formatted
    with ``'{:.1e}'``. The placeholder ``'--'`` is left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns ``r'volume ($\\rm{m}^3$)'`` and
        ``r'trend ($\\rm{m^3/cen}$)'``. It is modified in place.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, for convenient display.
    """
    def _one_decimal(value):
        # '--' marks a cell that does not apply; there is nothing to re-format.
        if value == '--':
            return value
        return '{:.1e}'.format(float(value))

    # Assign whole columns rather than writing into the rows yielded by
    # `iterrows()`: those rows may be copies, so such writes can be lost.
    for column in (r'volume ($\rm{m}^3$)', r'trend ($\rm{m^3/cen}$)'):
        df[column] = [_one_decimal(value) for value in df[column]]
    return df

In [None]:
#table for CORE OMZ
# Threshold taken from the global parameters instead of repeating the literal.
trunc_display(make_table(core_o2))

In [None]:
#table for OUTER OMZ
# Threshold taken from the global parameters instead of repeating the literal.
trunc_display(make_table(outer_o2))