# Preprocessing for analyzing CMIP6 OMZ data in oxygen coordinates

## What is blocking this?

- the new cmip6_pp masking would make this a lot easier (with the labels)


## General notes
- CM4 is not working due to the wonky chunks. I'll try to squeeze it through anyway, because I do not want to invest much more work here. This will all work better in the cloud.

- [x] Make sure to not add 'other variables' (MIROC has a ton of 💩 in there).
- [x] 10 deg lat bins
- [x] Build in check that zarr file has been written completely
- [x] Refine the o2_bins
    - [x] separate the negative bin
    - [x] outer bin to 1e5 mymol/kg to get the full ocean volume  
- [x]  Masking
    - [x] Split Arabian/Bob
    - [x] Give flag values (basin names) as attrs (actually as dimension labels)
    - [x] Only put out the major basins
- [x] `o2_bin` data in `mumol/kg`
    - [x] check the attrs
    - [x] convert them to integers
- [x] Add bounds as a coordinate

- [x] Make sure to carry the original lev_bounds or dz
- [x] Rename 'count' to 'bin_count'
- [x] Make sure the variables are properly masked with regards to nans in o2 (needs to be a perfect overlay)
- [ ] Maybe resave as nc, but for sure rechunk to larger chunks.
- [x] Make sure to reprocess from netcdf
- [ ] Don't combine experiments
- [ ] Include the other members

In [1]:
%load_ext autoreload
%autoreload 2

import cf_xarray
import intake
import xarray as xr
import numpy as np

from cmip6_preprocessing.utils import cmip6_dataset_id
from cmip6_preprocessing.preprocessing import combined_preprocessing
from cmip6_preprocessing.postprocessing import (
    match_metrics,
    interpolate_grid_label,
    merge_variables,
    concat_experiments,
)

from cmip6_preprocessing.drift_removal import match_and_remove_trend
from fastprogress.fastprogress import progress_bar

from xhistogram.xarray import histogram

from cmip6_omz.utils import cmip6_collection, o2_models
from cmip6_omz.upstream_stash import (
    pick_first_member,
    construct_static_dz,
    concat_time,
    zarr_exists,
    pick_first_member,
)
from cmip6_omz.units import convert_mol_m3_mymol_kg

from xarrayutils.file_handling import maybe_create_folder

### needs cleaning
from cmip6_omz.omz_tools import omz_thickness_efficient
import matplotlib.pyplot as plt
from cmip6_omz.upstream_stash import append_write_zarr

## Start the processing

In [5]:
# Version tag that is appended to the output folder names.
version = "v2.3"

# Main output folder; the histogram zarr stores are written here further down.
ofolder = maybe_create_folder(
    f"/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_{version}"
)
# Separate folder, named for control-run output.
ofolder_control = maybe_create_folder(
    f"/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_control_{version}"
)

In [3]:
# o2_models()

In [4]:
# Open the collection with zarr disabled.
col = cmip6_collection(zarr=False)

# Options forwarded to every `to_dataset_dict` call below.
kwargs = {
    "aggregate": False,
    "zarr_kwargs": {"decode_times": True, "use_cftime": True, "consolidated": True},
    "cdf_kwargs": {"decode_times": True, "use_cftime": True, "chunks": {"time": 1}},
    "preprocess": combined_preprocessing,
}

# Tracer variables to process.
variable_ids = ["thetao", "so", "o2", "agessc"]  # "mlotst"
# variable_ids = ["o2", "agessc"]

# Grid metrics that get matched to the tracers later on.
metric_variable_ids = ["thkcello", "areacello"]  # "mlotst"

# Models handled in this run; commented entries are currently disabled.
# models = o2_models()
models = [
    # 'ACCESS-ESM1-5', # all members done
    # 'CanESM5',
    # 'CanESM5-CanOE',
    # 'CNRM-ESM2-1',
    # 'IPSL-CM6A-LR',
    # 'MIROC-ES2L',
    "UKESM1-0-LL",
    "MPI-ESM1-2-HR",
    "MPI-ESM1-2-LR",
    "MRI-ESM2-0",
    "NorESM2-LM",
    "NorESM2-MM",
    # "GFDL-CM4",
    # "GFDL-ESM4",
]

# Tracer datasets: historical and ssp585, monthly ocean table, both grid labels.
cat = col.search(
    source_id=models,
    grid_label=["gr", "gn"],
    experiment_id=["historical", "ssp585"],
    table_id=["Omon"],
    variable_id=variable_ids,
)
ds_dict = cat.to_dataset_dict(**kwargs)

# # Trying to get the control runs going here, so do as little as possible?
# variable_ids = ["o2"]  # "mlotst"

# cat_control = col.search(
#     source_id=models,
#     grid_label=["gr", "gn"],
#     experiment_id=["piControl"],
#     table_id=["Omon"],
#     variable_id=variable_ids,
# )
# ds_dict_control = cat_control.to_dataset_dict(**kwargs)


# Metrics are searched without any experiment/table/grid restriction so that
# all possible metrics are caught.
cat_metrics = col.search(source_id=models, variable_id=metric_variable_ids)
ds_metric_dict = cat_metrics.to_dataset_dict(**kwargs)

Dataframe size before picking latest version: 2280
Getting latest version...

Dataframe size after picking latest version: 2258

Done....


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.dcpp_init_year.version.time_range.path'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.dcpp_init_year.version.time_range.path'


In [5]:
# Join the datasets along time for both the tracers and the metrics
# (only needed for the netcdf collection).
ds_dict, ds_metric_dict = concat_time(ds_dict), concat_time(ds_metric_dict)

## brute force the GFDL age in the dict

In [6]:
# # Brute Force add the GFDL age
# # TODO: Make this nicer with the original netcdf files (not tonight though)
# col_gfdl = cmip6_collection(zarr=True)
# # BUG: There is something weird going on in the reading process here
# # Just drop everything that is not GFDL
# df = col_gfdl.df
# df = df.iloc[ ['GFDL' in i for i in df['source_id']], :]
# df = df.iloc[ ['agessc' in i for i in df['variable_id']], :]
# col_gfdl.df = df
# cat_gfdl = col_gfdl.search(
#     source_id=[
#         "GFDL-CM4",
#         "GFDL-ESM4", 
#     ],
#     variable_id=["agessc"],
#     experiment_id=["historical", "ssp585"],
# )

# ddict_gfdl_age = cat_gfdl.to_dataset_dict(**kwargs)
# ddict_gfdl_age = {k:ds for k, ds in ddict_gfdl_age.items()}
# ds_dict.update(ddict_gfdl_age)

# # rechunk the GFDL models in depth
# def maybe_rechunk(ds):
#     if ds.source_id in ["GFDL-CM4","GFDL-ESM4"]:
#         ds = ds.chunk({'lev':5})
#     return ds
# ds_dict = {k:maybe_rechunk(ds) for k,ds in ds_dict.items()}

## Cleanup datasets early

I am currently: 
- dropping all variables except for the one specified in ds.variable_id
- Checking if datasets have the expected length (otherwise drop)

I am currently allowing longer ssp585 runs, but could cut them here!

In [7]:
import warnings

def _expected_length(ds):
    if ds.experiment_id == "historical":
        if ds.table_id == "Omon":
            return 1980
        else:
            warnings.warn(
                f"unknown table_id [{ds.table_id}] for {cmip6_dataset_id(ds)}"
            )
            return 1

    elif "ssp" in ds.experiment_id:
        if ds.table_id == "Omon":
            return 1032
        else:
            warnings.warn(
                f"unknown table_id [{ds.table_id}] for {cmip6_dataset_id(ds)}"
            )
            return 1

    elif "Control" in ds.experiment_id:
        if ds.table_id == "Omon":
            return (
                12 * 50
            )  # just give a low number here so none of the controls are dropped
        else:
            warnings.warn(
                f"unknown table_id [{ds.table_id}] for {cmip6_dataset_id(ds)}"
            )
            return 1
    else:
        warnings.warn(
            f"unknown experiment_id [{ds.experiment_id}] for {cmip6_dataset_id(ds)}"
        )
        return 1


def filter_ddict(ddict):
    """Reduce each dataset to its main variable and drop runs that are too short.

    For every dataset all data variables other than the one named in
    ``ds.variable_id`` are removed. Datasets without a ``time`` dimension are
    always kept. Datasets with a ``time`` dimension are kept only if they have
    at least ``_expected_length(ds)`` time steps; dropped datasets are
    reported via ``print``.

    Parameters
    ----------
    ddict : dict
        Mapping of name -> dataset.

    Returns
    -------
    dict
        New mapping containing only the retained (and reduced) datasets.
    """
    ddict_filtered = {}
    for name, ds in ddict.items():
        # drop everything but main variable
        # (`drop_vars` replaces the deprecated `Dataset.drop`)
        ds = ds.drop_vars([v for v in ds.data_vars if v != ds.variable_id])

        # datasets without a time dimension cannot be too short
        if "time" not in ds.dims:
            ddict_filtered[name] = ds
            continue

        # evaluate once, so a possible warning is only emitted once
        expected = _expected_length(ds)
        actual = len(ds.time)

        # filter out too short runs
        if actual < expected:
            print("---------DROPPED--------")
            print(name)
            print(expected)
            print(actual)
            print("---------DROPPED--------")
            continue

        ddict_filtered[name] = ds
    return ddict_filtered


# Run the same cleanup on the tracer datasets and on the metric datasets.
ds_dict_filtered, ds_metric_dict_filtered = (
    filter_ddict(ds_dict),
    filter_ddict(ds_metric_dict),
)

In [8]:
import pathlib

# TODO: new files (change in later and get rid of `load_trend_dict`
# (or refactor?) and `fix_trend_metadata`)

# Collect all trend files whose path mentions one of the processed variables.
flist = [
    f
    for f in list(
        pathlib.Path("../../data/external/cmip6_control_drifts/").absolute().glob("*.nc")
    )
    if any(v in str(f) for v in variable_ids)
]

# Open every trend file, keyed by the file name without its suffix.
trend_dict = {}
for f in progress_bar(flist):
    trend_dict[f.stem] = xr.open_mfdataset([f])

In [9]:
# These trend files are messed up and get excluded by name; a better way to
# deal with that is needed in the previous step.
# see https://github.com/jbusecke/cmip6_preprocessing/issues/175
# (I think this should be taken care of in the filtering step above...
# TODO check at a later point)
incomplete_keys = [
    "CMIP.IPSL.IPSL-CM6A-LR.historical.r3i1p1f1.Omon.gn.none.area_o2",
]
trend_dict = {
    key: trend
    for key, trend in trend_dict.items()
    if key not in incomplete_keys
}

# Match each filtered tracer dataset with its trend and remove the trend.
ddict_tracers_detrended = match_and_remove_trend(ds_dict_filtered, trend_dict)



## Match metrics

In [10]:
# Metric datasets that cause trouble during matching; they are removed up
# front. The "too long" ones could probably be fixed with a join='inner',
# but that is probably not worth it now.
problem_metrics = [
    "ACCESS-ESM1-5.gn.ssp585.Omon.r3i1p1f1.thkcello",  # metric too short
    "ACCESS-ESM1-5.gn.ssp585.Omon.r2i1p1f1.thkcello",  # metric too short
    "ACCESS-ESM1-5.gn.ssp585.Omon.r1i1p1f1.thkcello",  # metric too long
    "ACCESS-ESM1-5.gn.ssp585.Omon.r9i1p1f1.thkcello",  # metric too short
    "ACCESS-ESM1-5.gn.ssp585.Omon.r6i1p1f1.thkcello",  # metric too long
    "ACCESS-ESM1-5.gn.ssp585.Omon.r4i1p1f1.thkcello",  # metric too long
    "ACCESS-ESM1-5.gn.ssp585.Omon.r8i1p1f1.thkcello",  # metric too long
    "ACCESS-ESM1-5.gn.ssp585.Omon.r10i1p1f1.thkcello",  # metric too long
    "ACCESS-ESM1-5.gn.ssp585.Omon.r5i1p1f1.thkcello",  # metric too long
]
for key in problem_metrics:
    # keys that are absent are silently skipped
    ds_metric_dict_filtered.pop(key, None)

In [11]:
# Disabled workaround: excluding ACCESS-ESM1-5 r3i1p1f1 before matching.
# That member causes problems because the time is not as long as the full
# data...apparently they stopped writing the thickness.
# ddict_tracers_detrended_filtered = {
#     k: ds.squeeze()
#     for k, ds in ddict_tracers_detrended.items()
#     if not ("ACCESS-ESM1-5" in k and "r3i1p1f1" in k)
# }

# Attach the cell area and cell thickness metrics to the detrended tracer
# datasets and print the matching statistics.
ddict_matched = match_metrics(
    ddict_tracers_detrended,
    ds_metric_dict_filtered,
    ["areacello", "thkcello"],
    print_statistics=True,
)



Processed 225 datasets.
Exact matches:{'areacello': 0, 'thkcello': 148}
Other matches:{'areacello': 174, 'thkcello': 27}
No match found:{'areacello': 51, 'thkcello': 50}


## Interpolate Grids and merge variables

- handle the Norwegian Models inside `interpolate_grid_label` (TODO: Check if this works)

In [12]:
import dask

# Splitting large chunks while slicing is only necessary for ACCESS
# (they are all different lengths?).
with dask.config.set({"array.slicing.split_large_chunks": True}):
    print("interpolate grids\n")
    # The merge_kwargs should be a default soon.
    ddict_matched_regrid = interpolate_grid_label(
        ddict_matched,
        merge_kwargs={"compat": "override"},
    )

# #patch the norwegian model in manually
# ddict_patch = merge_variables(ddict_matched)
# for name, ds in ddict_patch.items():
#     if 'Nor' in name and 'gr' in name:
#         patch_name = name.replace('.gr','')
#         ddict_matched_regrid[patch_name] = ds

interpolate grids



In [13]:
# Display the sorted keys of the regridded dictionary.
np.sort(list(ddict_matched_regrid))

array(['MPI-ESM1-2-HR.historical.Omon.r1i1p1f1',
       'MPI-ESM1-2-HR.ssp585.Omon.r1i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r10i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r1i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r2i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r3i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r4i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r5i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r6i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r7i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r8i1p1f1',
       'MPI-ESM1-2-LR.historical.Omon.r9i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r10i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r1i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r2i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r3i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r4i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r5i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r6i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r7i1p1f1',
       'MPI-ESM1-2-LR.ssp585.Omon.r8i1p1f1',
       'M

In [14]:
# # manually remove GFDL (delete again, only necessary because of the brute force gFDL age)
# ddict_matched_regrid = {k:ds for k,ds in ddict_matched_regrid.items() if 'GFDL' not in k}

## Concatenate experiments and pick only 'full (both hist and ssp)' runs

In [15]:
# somehow xarray cannot deal with comparing list/int attrs (Occurs in CM4)

# def _clean(obj):
#     for a, attr in obj.attrs.items():
#         if isinstance(attr, np.integer):
#             obj.attrs[a] = [int(attr)]
# #             print('converted to int', a, attr)
#         elif isinstance(attr, np.floating):
#             obj.attrs[a] = [float(attr)]
# #         elif isinstance(attr, list):
# #             print([type(i) for i in attr])
#     return obj
#             # I should raise that, but lets fix it quickly here

# def clean_attrs(ds):
#     ds = _clean(ds)
#     for va in ds.variables:
#         ds[va] = _clean(ds[va])    
#     return ds

# Concatenate the experiments of the regridded datasets. The concat options
# drop conflicting attributes, override on comparison and use minimal
# coordinate handling.
ddict_ex_combined = concat_experiments(
    ddict_matched_regrid,
    concat_kwargs={
        'combine_attrs': 'drop_conflicts',
        'compat': 'override',
        'coords': 'minimal'
    }
)



Still need to deal with the ACCESS datasets here...

In [16]:
# Only pick full runs (historical and ssp585). The minimum lengths used in
# the early cleanup are 1980 (historical) and 1032 (ssp), i.e. 3012 combined,
# so anything with more than 3000 time steps contains both experiments.
ddict_ex_combined_full = {
    name: combined
    for name, combined in ddict_ex_combined.items()
    if len(combined.time) > 3000
}

## Check for problems and fix missing area/thickness manually

This should be wrapped and brought upstream

In [17]:
from cmip6_preprocessing.grids import combine_staggered_grid

# Record of every issue found, keyed by issue type. The 'missing_*' entries
# lead to a dataset being excluded; the 'reconstructed_*' entries are purely
# informational.
problems = {
    'missing_variables': [],
    'missing_area': [],
    'missing_thickness': [],
    'reconstructed_area': [],
    'reconstructed_thickness': [],
}
ddict_filtered = {}
for name, ds in ddict_ex_combined_full.items():
    flag = False  # becomes True as soon as the dataset is unusable

    # Check that all necessary variables are given
    missing_variables = [va for va in ["thetao", "so", "o2"] if va not in ds.variables]
    if missing_variables:
        flag = True
        problems['missing_variables'].append((name, missing_variables))

    # Check for area
    if 'areacello' not in ds.coords:
        if ds.attrs['grid_label'] == 'gr':  # only reconstruct for regular grids
            # `grid` is not used afterwards, only the dataset with the
            # recalculated metrics.
            grid, ds = combine_staggered_grid(ds, recalculate_metrics=True)
            # I am dropping dz_t here so it can be uniformly reconstructed
            # (`drop_vars` replaces the deprecated `Dataset.drop`)
            ds = ds.drop_vars('dz_t')
            ds = ds.assign_coords(areacello=(ds.dx_t * ds.dy_t).reset_coords(drop=True))
            problems['reconstructed_area'].append(name)
            assert 'areacello' in ds.coords
        else:
            flag = True
            problems['missing_area'].append(name)

    # Check for thickness (and rename)
    # TODO: We should probably not rename and just refactor to use `thkcello`
    if "thkcello" in ds.coords:
        ds = ds.rename({'thkcello': 'dz_t'})
    else:
        # Try to reconstruct the thickness from static info. (An earlier
        # variant built it from
        # `cf_xarray.bounds_to_vertices(ds.lev_bounds, 'bnds')` directly.)
        try:
            ds = construct_static_dz(ds).rename({'thkcello': 'dz_t'})
            problems['reconstructed_thickness'].append(name)
        except Exception as e:
            # Deliberately best effort: a failed reconstruction only excludes
            # this dataset, it does not stop the loop.
            print(f'{name} thickness reconstruction failed with {e}')
            print(ds)
            problems['missing_thickness'].append(name)
            flag = True

    if not flag:
        ddict_filtered[name] = ds
problems

{'missing_variables': [('MRI-ESM2-0.gn.Omon.r1i1p1f1', ['o2'])],
 'missing_area': [],
 'missing_thickness': [],
 'reconstructed_area': ['NorESM2-MM.gr.Omon.r1i1p1f1',
  'NorESM2-LM.gr.Omon.r1i1p1f1'],
 'reconstructed_thickness': ['MRI-ESM2-0.gn.Omon.r1i1p1f1',
  'MRI-ESM2-0.gn.Omon.r1i2p1f1']}

In [18]:
# Alternative (disabled): restrict the output to the first member per model.
# ddict_final = pick_first_member(ddict_filtered)

# Final version: put out all full members
ddict_final = ddict_filtered

# Display the sorted names of everything that will be processed.
list(np.sort(list(ddict_final)))


['MPI-ESM1-2-HR.gn.Omon.r1i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r10i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r1i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r2i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r3i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r4i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r5i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r6i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r7i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r8i1p1f1',
 'MPI-ESM1-2-LR.gn.Omon.r9i1p1f1',
 'MRI-ESM2-0.gn.Omon.r1i2p1f1',
 'NorESM2-LM.gr.Omon.r1i1p1f1',
 'NorESM2-MM.gr.Omon.r1i1p1f1',
 'UKESM1-0-LL.gn.Omon.r1i1p1f2',
 'UKESM1-0-LL.gn.Omon.r2i1p1f2',
 'UKESM1-0-LL.gn.Omon.r3i1p1f2',
 'UKESM1-0-LL.gn.Omon.r4i1p1f2',
 'UKESM1-0-LL.gn.Omon.r8i1p1f2']

## Prep Basin mask
- Needs a separated Indian Ocean (Sam uses: 78E)
- Refactor with the new masking using cf-xarray

In [19]:
import regionmask
from cmip6_preprocessing.regionmask import _default_merge_dict, merged_mask

# load ocean basin data
basins = regionmask.defined_regions.natural_earth.ocean_basins_50

# Map each basin label to an integer index, following the key order of the
# default merge dictionary.
mask_labels_raw = {
    label: index for index, label in enumerate(_default_merge_dict())
}

In [20]:
def mask_split_indian_labels(mask_labels):
    """Return a label mapping in which index 5 is split into two entries.

    The label that carries index 5 is replaced by ``'Indian_AS'`` (index 5)
    and ``'Indian_BOB'`` (index 6). Every label with an index above 5 is
    shifted up by one to make room; labels below 5 are kept as they are.
    The order of the input mapping is preserved and the input is not
    modified.

    Parameters
    ----------
    mask_labels : dict
        Mapping of label -> index.

    Returns
    -------
    dict
        New mapping of label -> index.
    """
    relabelled = {}
    for name, index in mask_labels.items():
        if index == 5:
            # the original entry is dropped in favour of the two halves
            relabelled["Indian_AS"] = 5
            relabelled["Indian_BOB"] = 6
        elif index > 5:
            relabelled[name] = index + 1
        else:
            relabelled[name] = index
    return relabelled

def mask_split_indian(ds, mask, mask_labels, plot=True):
    """Split mask index 5 into a western (5) and an eastern (6) part at lon 78.

    Parameters
    ----------
    ds : dataset providing ``lon``
    mask : array of basin indices (needs a ``where`` method)
    mask_labels : dict
        Mapping of label -> index that belongs to ``mask``.
    plot : bool, optional
        If True, draw one panel per label of the modified mask.

    Returns
    -------
    tuple
        ``(mask_modified, mask_labels_modified)``
    """
    # Shift every index >= 5 up by one; this moves index 5 itself to 6.
    shifted = mask.where(mask < 5, mask + 1)

    # Cells that now carry index 6 and lie at lon < 78 (the Arabian Sea side)
    # are moved back down to index 5.
    west_of_split = np.logical_and(shifted == 6, ds.lon < 78)
    mask_modified = shifted.where(~west_of_split, 5)

    mask_labels_modified = mask_split_indian_labels(mask_labels)

    if plot:
        fig, axarr = plt.subplots(ncols=4, nrows=4, figsize=[15, 10])
        labelled_axes = zip(axarr.flat, mask_labels_modified.items())
        for ax, (label, i) in labelled_axes:
            selected = mask_modified.where(mask_modified == i)
            selected.plot(ax=ax)
            ax.set_title(label)
        fig.subplots_adjust(hspace=0.7, wspace=0.7)
        plt.show()

    return mask_modified, mask_labels_modified

## Let's check how to split the Indian Ocean

## This needs to go within the loop later

In [21]:
def unify_nan(ds, sub_slices=None):
    """Adjust all data variables to have NaNs in the same spots.

    A common mask is built from a subset of each data variable: a position is
    masked if any variable is NaN there in any step of the inspected subset.
    The mask is then applied to the whole dataset.

    Parameters
    ----------
    ds : xarray.Dataset
    sub_slices : dict, optional
        Mapping of dimension name -> slice selecting the subset inspected for
        NaNs. Dimensions a variable does not have are ignored for that
        variable. Defaults to ``{"time": slice(12, 24)}``.

    Returns
    -------
    xarray.Dataset
        ``ds`` with the common mask applied. If ``ds`` has no data variables
        it is returned unchanged.
    """
    # TODO: The GFDL age has all nan slices at the beginning and end due to interpolation
    if sub_slices is None:
        # built per call instead of using a shared mutable default argument
        sub_slices = {"time": slice(12, 24)}

    # NaN indicator of the inspected subset, one per data variable
    nan_flags = []
    for name in ds.data_vars:
        da = ds[name]
        indexers = {dim: sl for dim, sl in sub_slices.items() if dim in da.dims}
        nan_flags.append(np.isnan(da.isel(**indexers)))

    # nothing to unify (and `sum([])` would yield a plain int)
    if not nan_flags:
        return ds

    # True wherever at least one variable is NaN, collapsed over the
    # inspected dimensions
    reduce_dims = [dim for dim in sub_slices if dim in ds.dims]
    mask = sum(nan_flags).astype(bool).any(reduce_dims)
    return ds.where(~mask)


def full_wrapper(ds):
    """Bin a dataset into oxygen / latitude / basin histograms.

    For every combination of oxygen bin, latitude bin and basin this computes
    the number of cells (``bin_count``), the summed cell volume (``volume``)
    and, for each data variable of ``ds``, the sum of that variable weighted
    by cell volume. The histograms are taken over the ``x`` and ``y``
    dimensions.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``o2``, ``lat``, ``lon``, ``dz_t``, ``areacello`` and
        ``lev_bounds``. The input object is not modified.

    Returns
    -------
    xarray.Dataset
        Histogram dataset with the attributes of ``ds`` (minus
        ``intake_esm_varname``) and the additional coordinates
        ``lev_bounds``, ``o2_bin_bounds`` and ``lat_bounds``. The ``o2_bin``
        coordinate is converted with ``convert_mol_m3_mymol_kg`` and the
        ``basin_mask_bin`` coordinate carries the basin names.
    """
    ## Define bins

    o2_bins = np.hstack(
        [-100, np.arange(0, 25, 5), np.arange(30, 180, 10), 1e5]
    )  # in mymol/kg
    # Divide by the conversion of a unit value to express the bin edges in
    # the units of the data (presumably mol/m^3, going by the helper's name).
    o2_bins_converted = o2_bins / convert_mol_m3_mymol_kg(xr.DataArray([1])).data

    # define mask bins
    mask_bins = np.arange(-0.5, 7.0, 1)
    # this does not include all basins
    # but it will save some space and we (for now) dont care about OMZs
    # in the Caspian Sea
    mask_bin_labels = [
        label
        for label, i in mask_split_indian_labels(mask_labels_raw).items()
        if i < mask_bins.max()
    ]
    assert len(mask_bin_labels) == len(mask_bins) - 1

    lat_bins = np.arange(-60, 61, 10)

    ## Prepare the data

    # `assign` returns a new dataset, so the caller's object does not get a
    # "vol" variable attached as a side effect.
    ds = ds.assign(vol=ds.dz_t * ds.areacello)

    # The volume is a data variable at this point and therefore receives the
    # same NaN mask as the tracers.
    ds = unify_nan(ds)

    # Masking
    mask = merged_mask(basins, ds)
    mask, _ = mask_split_indian(ds, mask, mask_labels_raw, plot=False)
    mask.name = "basin_mask"

    vol = ds["vol"]
    # drop volume or it will get combined as tracer.
    ds = ds.drop_vars(["vol"])

    ## Histograms

    def _binned(weights=None):
        # One histogram over o2 / lat / basin; only the weights differ
        # between the count, the volume and the tracer sums.
        return histogram(
            ds.o2,
            ds.lat,
            mask,
            bins=[o2_bins_converted, lat_bins, mask_bins],
            weights=weights,
            dim=["x", "y"],
        )

    count = _binned()
    volume = _binned(weights=vol)
    tracers = {tr: _binned(weights=ds[tr] * vol) for tr in ds.data_vars}

    ds_hist = xr.Dataset(dict(bin_count=count, volume=volume, **tracers))
    ds_hist.attrs = {
        k: v for k, v in ds.attrs.items() if k != "intake_esm_varname"
    }

    # Add more coordinates etc to the output
    ds_hist = ds_hist.assign_coords(o2_bin=convert_mol_m3_mymol_kg(ds_hist.o2_bin).data)
    # raw string: "\m" is not a valid escape sequence in a normal string
    ds_hist["o2_bin"].attrs["units"] = r"$\mu mol$/kg"
    ds_hist = ds_hist.assign_coords(lev_bounds=ds.lev_bounds)
    ds_hist = ds_hist.assign_coords(basin_mask_bin=mask_bin_labels)
    # Bin edges (in the original, unconverted values) as bounds coordinates
    ds_hist = ds_hist.assign_coords(
        o2_bin_bounds=cf_xarray.vertices_to_bounds(
            xr.DataArray(o2_bins, dims="o2_bin"),
            ["bnds", "o2_bin"],
        )
    )
    ds_hist = ds_hist.assign_coords(
        lat_bounds=cf_xarray.vertices_to_bounds(
            xr.DataArray(lat_bins, dims="lat_bin"),
            ["bnds", "lat_bin"],
        )
    )
    return ds_hist

## Process the models

In [22]:
# Set to True to rewrite stores that already exist.
overwrite = False
# Set to True to run the sanity checks on every histogram before writing.
output_checks = False

# TODO, rerun these all for v3 (probably want to forego the experiment concat)
for name, ds in ddict_final.items():
    print(name)
    ds_hist = full_wrapper(ds)

    # output checks
    if output_checks:
        # none of the variables may contain NaNs in time steps 6-8
        for va in ds_hist.data_vars:
            test = ~np.isnan(ds_hist[va].isel(time=slice(6, 9))).any().load()
            assert test

        # the total bin count of each of the first five time steps must exceed 10
        non_time_dims = [di for di in ds_hist.bin_count.dims if di != 'time']
        bin_timeseries = ds_hist.bin_count.sum(non_time_dims)
        bin_timeseries = bin_timeseries.isel(time=slice(0, 5)).load()
        assert (bin_timeseries > 10).all('time')

    # only save the mean over consecutive blocks of 12 time steps
    ds_hist = ds_hist.coarsen(time=12).mean()

    path = ofolder.joinpath(f"{cmip6_dataset_id(ds_hist)}.zarr")
    print(path)

    # Skip the write only if a store exists, no overwrite is requested and
    # the stored time axis has the expected length.
    if (
        zarr_exists(path)
        and not overwrite
        and len(xr.open_zarr(path).time) == len(ds_hist.time)
    ):
        print("Exists already")
        continue

    print(f"{ds_hist.nbytes/1e9} GB")
    if 'GFDL' in ds.source_id:
        split = 5
    else:
        split = 10
    append_write_zarr(ds_hist, path, split)  # I got CM4 to run with 2.

MPI-ESM1-2-HR.gn.Omon.r1i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/none.none.MPI-ESM1-2-HR.none.r1i1p1f1.Omon.gn.none.none.zarr
Exists already
NorESM2-MM.gr.Omon.r1i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/none.NCC.NorESM2-MM.none.r1i1p1f1.Omon.gr.none.none.zarr
Exists already
MPI-ESM1-2-LR.gn.Omon.r4i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/none.MPI-M.MPI-ESM1-2-LR.none.r4i1p1f1.Omon.gn.none.none.zarr
Exists already
MPI-ESM1-2-LR.gn.Omon.r1i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/none.MPI-M.MPI-ESM1-2-LR.none.r1i1p1f1.Omon.gn.none.none.zarr
Exists already
NorESM2-LM.gr.Omon.r1i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.3/none.NCC.NorESM2-LM.none.r1i1p1f1.Omon.gr.none.none.zarr
Exists already
MPI-ESM1-2-LR.gn.Omon.r2i1p1f1
/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histo

## Process the observations

In [23]:
from cmip6_omz.datasets import load_bianchi
from cmip6_omz.upstream_stash import construct_static_dz

# Load the observations and give them a static cell thickness named `dz_t`.
ds_obs = load_bianchi()
# NOTE: the bound coordinate has to be passed explicitly here; it is unclear
# why the function does not pick it up by itself.
ds_obs = construct_static_dz(ds_obs, bound_coord="lev_bounds")
ds_obs = ds_obs.rename({"thkcello": "dz_t"})

# Same histogram processing as for the models.
ds_obs_hist = full_wrapper(ds_obs)

# Write next to the model output, replacing any existing store.
path = ofolder.joinpath("obs.zarr")
ds_obs_hist.to_zarr(path, consolidated=True, mode="w")

<xarray.Dataset>
Dimensions:     (x: 360, y: 180, lev: 33, bnds: 2)
Coordinates:
  * x           (x) float64 0.5 1.5 2.5 3.5 4.5 ... 356.5 357.5 358.5 359.5
  * y           (y) float64 -89.5 -88.5 -87.5 -86.5 ... 86.5 87.5 88.5 89.5
  * lev         (lev) float64 0.0 10.0 20.0 30.0 ... 4e+03 4.5e+03 5e+03 5.5e+03
    lev_bounds  (lev, bnds) float64 ...
    lon         (x, y) float64 0.5 0.5 0.5 0.5 0.5 ... 359.5 359.5 359.5 359.5
    lat         (x, y) float64 -89.5 -88.5 -87.5 -86.5 ... 86.5 87.5 88.5 89.5
Dimensions without coordinates: bnds
Data variables:
    TIME_bnds   (bnds) float64 165.7 196.0
    o2          (lev, y, x) float32 nan nan nan nan nan ... nan nan nan nan nan




<xarray.backends.zarr.ZarrStore at 0x7f51fd852eb0>

In [24]:
stop here

SyntaxError: invalid syntax (4067800170.py, line 1)

## Test the output against something I know

This seems fine. The bin count is not exact, but that is probably due to some changes in the numerical precision in one of the methods.

Nice, the mean tracer values also line up.

In [None]:
# Scratch cell: recompute the histograms by hand to compare against the
# output of the processing above.
# might need to rename a few here so that it works and the dataset should probably be synth 
# Then I can move this to the tests.
# NOTE(review): `o2_bins_converted` and `mask` are only defined as local
# variables inside `full_wrapper` in the visible notebook, and `ds` is
# whatever the processing loop left behind, so this cell needs those names
# defined at top level before it can run.

# define mask bins
mask_bins = np.arange(-0.5, 13.0, 1)# for now manual, but maybe there is a clever way to do this?
mask_bins

# latitude bins (differ from the ones used in `full_wrapper`)
lat_bins = np.arange(-90, 91, 20)
lat_bins

# depth bins; not used by the histogram calls below
lev_bins = np.arange(0, 7000, 500)
lev_bins
# cell volume used as the histogram weight
vol = (ds.dz_t*ds.areacello) 

# cell count, summed volume and volume-weighted tracer sums per bin
count = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], dim=['x','y'])
volume = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], weights=vol, dim=['x','y'])
tracers = {}
for tr in ds.data_vars:
    tracers[tr] = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], weights=ds[tr]*vol, dim=['x','y'])

In [None]:
from cmip6_omz.omz_tools import mask_basin

In [None]:
# expected_threshold = 0.082
cutoff = o2_bins_converted[-1]
test_full_pacific = test
# expected_full_pacific = mask_basin(ds.o2.isel(time=0), drop=False)
expected_full_pacific = ds.o2
expected_full_pacific = xr.ones_like(expected_full_pacific).where(expected_full_pacific<=cutoff)

In [None]:
test

In [None]:
expected_full_pacific

In [None]:
ds.o2.shape

In [None]:
expected_full_pacific.shape

In [None]:
test_full_pacific.sum().data/expected_full_pacific.sum().load().data

I suspect this is due to the numerical precision going wrong somewhere...

In [None]:
test

In [None]:
# example with masked basin
cutoff = o2_bins_converted[-3]
test_full_pacific = test.sel(basin_mask_bin=slice(1.5, 3.5))
expected_full_pacific = mask_basin(ds.o2, drop=False)
expected_full_pacific = xr.ones_like(expected_full_pacific).where(expected_full_pacific<=cutoff)

In [None]:
test_full_pacific.sum().data/expected_full_pacific.sum().load().data

In [None]:
# example with masked basin
cutoff = o2_bins_converted[-3]
test_full_pacific = test3.sel(basin_mask_bin=slice(1.5, 3.5)).sum()/test2.sel(basin_mask_bin=slice(1.5, 3.5)).sum()
expected_full_pacific = mask_basin(ds, drop=False)
expected_full_pacific = expected_full_pacific.where(expected_full_pacific<=cutoff)
expected_full_pacific = expected_full_pacific.o2.weighted((expected_full_pacific.areacello*expected_full_pacific.dz_t).fillna(0)).mean()

In [None]:
expected_full_pacific.load()

In [None]:
test_full_pacific