# Preprocessing for analyzing CMIP6 OMZ data in oxygen coordinates

## What is blocking this?

- the new cmip6_pp masking would make this a lot easier (with the labels)


## General notes
- CM4 is not working due to the wonky chunks. I'll try to squeeze it through anyway, because I do not want to invest much more work here. This will all work better in the cloud.

In [1]:
import cf_xarray
import intake
import xarray as xr
import numpy as np

from cmip6_preprocessing.utils import cmip6_dataset_id
from cmip6_preprocessing.preprocessing import combined_preprocessing
from cmip6_preprocessing.postprocessing import (
    match_metrics,
    interpolate_grid_label,
    merge_variables,
    concat_experiments,
)
from cmip6_preprocessing.drift_removal import match_and_remove_trend
from fastprogress.fastprogress import progress_bar

from xhistogram.xarray import histogram

from cmip6_omz.utils import cmip6_collection, o2_models
from cmip6_omz.upstream_stash import (
    pick_first_member,
    construct_static_dz,
    concat_time,
    zarr_exists
)
from cmip6_omz.units import convert_mol_m3_mymol_kg

from xarrayutils.file_handling import maybe_create_folder

### needs cleaning
from cmip6_omz.omz_tools import omz_thickness_efficient
import matplotlib.pyplot as plt


In [2]:
from cmip6_omz.upstream_stash import append_write_zarr

In [3]:
# Spin up a local dask cluster and connect a client to it
from dask.distributed import Client, LocalCluster

mem_total = 384  # split evenly between the workers below (in GiB)
workers = 5
# 4*6 seemed to work quite well, but I would like this to perform a bit better
threads = 5
cluster = LocalCluster(
    n_workers=workers,
    threads_per_worker=threads,
    memory_limit=f"{int(mem_total/workers)}GiB",
    dashboard_address=9999,
)
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:9999/status,

0,1
Dashboard: http://127.0.0.1:9999/status,Workers: 5
Total threads: 25,Total memory: 380.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:43116,Workers: 5
Dashboard: http://127.0.0.1:9999/status,Total threads: 25
Started: Just now,Total memory: 380.00 GiB

0,1
Comm: tcp://127.0.0.1:44582,Total threads: 5
Dashboard: http://127.0.0.1:33773/status,Memory: 76.00 GiB
Nanny: tcp://127.0.0.1:41543,
Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-jfmbjfwd,Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-jfmbjfwd

0,1
Comm: tcp://127.0.0.1:38507,Total threads: 5
Dashboard: http://127.0.0.1:38501/status,Memory: 76.00 GiB
Nanny: tcp://127.0.0.1:32902,
Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-wa2qohw1,Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-wa2qohw1

0,1
Comm: tcp://127.0.0.1:40549,Total threads: 5
Dashboard: http://127.0.0.1:39107/status,Memory: 76.00 GiB
Nanny: tcp://127.0.0.1:36719,
Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-5b3e3svl,Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-5b3e3svl

0,1
Comm: tcp://127.0.0.1:37838,Total threads: 5
Dashboard: http://127.0.0.1:39409/status,Memory: 76.00 GiB
Nanny: tcp://127.0.0.1:33747,
Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-bym5__cl,Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-bym5__cl

0,1
Comm: tcp://127.0.0.1:40044,Total threads: 5
Dashboard: http://127.0.0.1:46199/status,Memory: 76.00 GiB
Nanny: tcp://127.0.0.1:42952,
Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-uuac89mn,Local directory: /projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip6_omz/notebooks/processing/dask-worker-space/worker-uuac89mn


In [4]:
ofolder = maybe_create_folder('/projects/GEOCLIM/LRGROUP/jbusecke/projects_data/cmip6_depth_histogram_v2.1')

In [5]:
o2_models()

['ACCESS-ESM1-5',
 'CESM2',
 'CESM2-WACCM',
 'CMCC-ESM2',
 'CNRM-ESM2-1',
 'CanESM5',
 'CanESM5-CanOE',
 'EC-Earth3-CC',
 'GFDL-CM4',
 'GFDL-ESM4',
 'IPSL-CM5A2-INCA',
 'IPSL-CM6A-LR',
 'KIOST-ESM',
 'MIROC-ES2L',
 'MPI-ESM-1-2-HAM',
 'MPI-ESM1-2-HR',
 'MPI-ESM1-2-LR',
 'MRI-ESM2-0',
 'NorESM2-LM',
 'NorESM2-MM',
 'UKESM1-0-LL']

In [6]:
col = intake.open_esm_datastore(cmip6_collection(zarr=False))

# keyword arguments used when opening zarr stores / netcdf files
z_kwargs = dict(decode_times=True, use_cftime=True, consolidated=True)
n_kwargs = dict(decode_times=True, use_cftime=True, chunks={"time": 1})

variable_ids = ["thetao", "so", "o2", "agessc"]  # "mlotst"
metric_variable_ids = ["thkcello", "areacello"]  # "mlotst"

# models = o2_models()
models = [
#      'CanESM5-CanOE',
#      'CanESM5',
#      'CNRM-ESM2-1',
#      'ACCESS-ESM1-5',
#      'MPI-ESM-1-2-HAM',
#      'IPSL-CM6A-LR',
#      'MIROC-ES2L',
#      'UKESM1-0-LL',
     'MPI-ESM1-2-HR',
#      'MPI-ESM1-2-LR',
#      'MRI-ESM2-0',
#      'NorCPM1',
#      'NorESM1-F',
#      'NorESM2-LM',
#      'NorESM2-MM',
     'GFDL-CM4',
     'GFDL-ESM4'
]

# options shared by both catalog loads below
load_kwargs = dict(
    aggregate=False,
    zarr_kwargs=z_kwargs,
    cdf_kwargs=n_kwargs,
    preprocess=combined_preprocessing,
)

# tracer data: monthly ocean output for both experiments on either grid label
cat = col.search(
    source_id=models,
    grid_label=["gr", "gn"],
    experiment_id=["historical", "ssp585"],
    table_id=["Omon"],
    variable_id=variable_ids,
)
ds_dict = cat.to_dataset_dict(**load_kwargs)

# make a separate metric dict to catch all possible metrics!
cat_metrics = col.search(source_id=models, variable_id=metric_variable_ids)
ds_metric_dict = cat_metrics.to_dataset_dict(**load_kwargs)


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.dcpp_init_year.version.time_range.path'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.dcpp_init_year.version.time_range.path'


In [7]:
# clean out the variables: keep only the data variable named by each dataset's
# `variable_id`. `Dataset.drop` is deprecated in xarray; `drop_vars` is the
# explicit replacement for dropping variables by name.
ds_dict_filtered = {
    k: ds.drop_vars([v for v in ds.data_vars if v != ds.variable_id])
    for k, ds in ds_dict.items()
}

In [8]:
# TODO: still some errors in the concat_time

In [None]:
# combine in time (only needed for netcdf collection)
ds_dict = concat_time(ds_dict_filtered)
ds_metric_dict = concat_time(ds_metric_dict)

In [None]:
import pathlib

# new files (change in later and get rid of `load_trend_dict` (or refactor?) and `fix_trend_metadata`)
# Load all trend files whose path mentions one of the requested variables
trend_folder = pathlib.Path('../../data/external/cmip6_control_drifts/').absolute()
flist = [
    f
    for f in trend_folder.glob('*.nc')
    if any(v in str(f) for v in variable_ids)
]
# one lazily opened dataset per file, keyed by the file name without suffix
trend_dict = {}
for f in progress_bar(flist):
    trend_dict[f.stem] = xr.open_mfdataset([f])

In [None]:
# these ones are messed up...need a better way to deal with that in the previous step
# see https://github.com/jbusecke/cmip6_preprocessing/issues/175
incomplete_keys = ['CMIP.IPSL.IPSL-CM6A-LR.historical.r3i1p1f1.Omon.gn.none.area_o2']
trend_dict = {
    key: trend
    for key, trend in trend_dict.items()
    if key not in incomplete_keys
}

# match each dataset with its trend and remove it
ddict_tracers_detrended = match_and_remove_trend(ds_dict, trend_dict)

## Match metrics

In [None]:
# Skip CNRM-ESM2-1 member r6i1p1f2:
# this one causes problems because the time is not as long as the full data...
ddict_tracers_detrended_filtered = {
    key: dataset.squeeze()
    for key, dataset in ddict_tracers_detrended.items()
    if "CNRM-ESM2-1" not in key or "r6i1p1f2" not in key
}

# attach cell area and cell thickness to the tracer datasets
metric_names = ["areacello", "thkcello"]
ddict_matched = match_metrics(
    ddict_tracers_detrended_filtered,
    ds_metric_dict,
    metric_names,
    print_statistics=True,
)

## Interpolate Grids and merge variables

- handle the Norwegian Models inside `interpolate_grid_label`

In [None]:
print("interpolate grids\n")
ddict_matched_regrid = interpolate_grid_label(
    ddict_matched, merge_kwargs={"compat": "override"}
)  # This should be a default soon

In [None]:
# patch the norwegian model in manually: take the merged (not regridded) 'gr'
# datasets and store them under the key without the '.gr' part
ddict_patch = merge_variables(ddict_matched)
for name, ds in ddict_patch.items():
    if not ('Nor' in name and 'gr' in name):
        continue
    patch_name = name.replace('.gr','')
    ddict_matched_regrid[patch_name] = ds

In [None]:
np.sort(list(ddict_matched_regrid.keys()))

## Concatenate experiments and pick the first full one

In [None]:
# somehow xarray cannot deal with comparing list/int attrs (Occurs in CM4)

# def _clean(obj):
#     for a, attr in obj.attrs.items():
#         if isinstance(attr, np.integer):
#             obj.attrs[a] = [int(attr)]
# #             print('converted to int', a, attr)
#         elif isinstance(attr, np.floating):
#             obj.attrs[a] = [float(attr)]
# #         elif isinstance(attr, list):
# #             print([type(i) for i in attr])
#     return obj
#             # I should raise that, but lets fix it quickly here

# def clean_attrs(ds):
#     ds = _clean(ds)
#     for va in ds.variables:
#         ds[va] = _clean(ds[va])    
#     return ds

# options passed through as `concat_kwargs`
concat_kwargs = dict(
    combine_attrs='drop_conflicts',
    compat='override',
    coords='minimal',
)
ddict_ex_combined = concat_experiments(
    ddict_matched_regrid,
    concat_kwargs=concat_kwargs,
)

In [None]:
ddict_ex_combined.keys()

Still need to deal with the access stuff here...

In [None]:
# only pick full runs (historical and ssp585)
# NOTE(review): the 3000-step cutoff presumably relies on monthly (Omon) output,
# where historical + ssp585 together are longer than 3000 time steps — confirm
# that this holds for every model before relying on it.
ddict_ex_combined_full = {k:ds for k,ds in ddict_ex_combined.items() if len(ds.time)>3000}

## Check for problems and fix missing area/thickness manually

This should be wrapped and brought upstream

In [None]:
from cmip6_preprocessing.grids import combine_staggered_grid

# Bookkeeping of every dataset that had to be dropped or patched, by category
problems = {
    'missing_variables': [],
    'missing_area': [],
    'missing_thickness': [],
    'reconstructed_area': [],
    'reconstructed_thickness': [],
}
ddict_filtered = {}
for name, ds in ddict_ex_combined_full.items():
    flag = False  # becomes True as soon as the dataset turns out to be unusable

    # Check that all necessary variables are given
    missing_variables = [va for va in ["thetao", "so", "o2"] if va not in ds.variables]
    if missing_variables:
        flag = True
        problems['missing_variables'].append((name, missing_variables))

    # Check for area
    if 'areacello' not in ds.coords:
        if ds.attrs['grid_label'] == 'gr':  # only reconstruct for regular grids
            grid, ds = combine_staggered_grid(ds, recalculate_metrics=True)
            # I am dropping dz_t here so it can be uniformly reconstructed
            # (`Dataset.drop` is deprecated in xarray, hence `drop_vars`)
            ds = ds.drop_vars('dz_t')
            ds = ds.assign_coords(areacello=(ds.dx_t * ds.dy_t).reset_coords(drop=True))
            problems['reconstructed_area'].append(name)
            assert 'areacello' in ds.coords
        else:
            flag = True
            problems['missing_area'].append(name)

    # Check for thickness (and rename) TODO: We should probably not rename and just refactor to use `thkcello`
    if "thkcello" in ds.coords:
        ds = ds.rename({'thkcello': 'dz_t'})
    else:
        # try to reconstruct the thickness from static info
        try:
#             lev_vertices = cf_xarray.bounds_to_vertices(ds.lev_bounds, 'bnds').load()
#             dz_t = lev_vertices.diff('lev_vertices')
#             ds = ds.assign_coords(dz_t=('lev', dz_t.data))
            ds = construct_static_dz(ds).rename({'thkcello': 'dz_t'})
        except Exception as e:
            # deliberate best effort: report the failure and skip this dataset
            print(f'{name} thickness reconstruction failed with {e}')
            print(ds)
            problems['missing_thickness'].append(name)
            flag = True
        else:
            problems['reconstructed_thickness'].append(name)

    # only datasets without any unresolved problem are processed further
    if not flag:
        ddict_filtered[name] = ds
problems

In [None]:
ddict_final = pick_first_member(ddict_filtered)#
list(np.sort(list(ddict_final.keys())))

## Prep Basin mask
- Needs a separated Indian Ocean (Sam uses: 78E)
- Refactor with the new masking using cf-xarray

In [None]:
# load ocean basin data
import regionmask
from cmip6_preprocessing.regionmask import _default_merge_dict, merged_mask

basins = regionmask.defined_regions.natural_earth.ocean_basins_50
# number the merged basins in the order the default merge dict lists them
mask_labels_raw = {
    label: index for index, label in enumerate(_default_merge_dict())
}

In [None]:
def mask_split_indian_labels(mask_labels):
    """Return new mask labels in which the basin with value 5 is split in two.

    The entry whose value is 5 is replaced by 'Indian_AS' (5) and
    'Indian_BOB' (6). Every value above 5 is shifted up by one to make room
    for the additional label; smaller values are copied unchanged.
    The input mapping is not modified.
    """
    relabeled = {}
    for basin, value in mask_labels.items():
        if value == 5:
            relabeled['Indian_AS'] = 5
            relabeled['Indian_BOB'] = 6
        elif value > 5:
            relabeled[basin] = value + 1
        else:
            relabeled[basin] = value
    return relabeled

def mask_split_indian(mask, mask_labels, plot=True, lon=None):
    """Split basin 5 of `mask` at 78E into a western and an eastern part.

    The function previously read the longitude from a notebook-global `ds`;
    it now uses the longitude that belongs to `mask` (or an explicit `lon`).

    Parameters
    ----------
    mask : xr.DataArray
        Numeric basin mask. Values of 5 and above are shifted up by one;
        cells of the former basin 5 that lie west of 78E get the value 5 again.
    mask_labels : dict
        Mapping of basin label to mask value, passed to
        `mask_split_indian_labels` to build the matching labels.
    plot : bool, optional
        If True (default), draw one panel per label of the modified mask.
    lon : xr.DataArray, optional
        Longitude matching `mask`. Defaults to the `lon` coordinate of `mask`.

    Returns
    -------
    tuple
        `(mask_modified, mask_labels_modified)`.

    Raises
    ------
    ValueError
        If `lon` is not given and `mask` has no `lon` coordinate.
    """
    if lon is None:
        if "lon" not in mask.coords:
            raise ValueError(
                "`mask` has no `lon` coordinate; pass the longitude via `lon`."
            )
        lon = mask["lon"]

    # first move all labels (including Indian) one up
    mask_modified = mask.where(mask < 5, mask + 1)

    # Now move the Arabian sea one down again
    is_arabian_sea = np.logical_and(mask_modified == 6, lon < 78)
    mask_modified = mask_modified.where(~is_arabian_sea, 5)

    mask_labels_modified = mask_split_indian_labels(mask_labels)

    if plot:
        # one panel per label; zip stops at the 16 available axes
        fig, axarr = plt.subplots(ncols=4, nrows=4, figsize=[15,10])
        for ax, (label, i) in zip(axarr.flat, mask_labels_modified.items()):
            mask_modified.where(mask_modified == i).plot(ax=ax)
            ax.set_title(label)
        fig.subplots_adjust(hspace=0.7, wspace=0.7)
        plt.show()
    return mask_modified, mask_labels_modified

## Define bins

In [None]:
# oxygen bin edges in mymol/kg: one bin below zero, 5 mymol/kg steps from 0 to 155
# and a last bin reaching up to 1e5
o2_bins = np.hstack([-100, np.arange(0, 160, 5), 1e5])
# the same edges divided by what `convert_mol_m3_mymol_kg` returns for a value of 1
conversion_factor = convert_mol_m3_mymol_kg(xr.DataArray([1])).data
o2_bins_converted = bins_converted = o2_bins / conversion_factor
print(o2_bins)

# define mask bins
# this does not include all basins
# but it will save some space and we (for now) dont care about OMZs
# in the Caspian Sea
mask_bins = np.arange(-0.5, 7.0, 1)
split_labels = mask_split_indian_labels(mask_labels_raw)
mask_bin_labels = [
    label for label, i in split_labels.items() if i < mask_bins.max()
]
# exactly one label per mask bin
assert len(mask_bin_labels) == len(mask_bins)-1

lat_bins = np.arange(-60, 61, 10)
lat_bins

## Let's check how to split the Indian Ocean

## This needs to go within the loop later

- [x] Make sure to not add 'other variables' (MIROC has a ton of 💩 in there).
- [x] 10 deg lat bins
- [x] Build in check that zarr file has been written completely
- [x] Refine the o2_bins
    - [x] separate the negative bin
    - [x] outer bin to 1e5 mymol/kg to get the full ocean volume  
- [x]  Masking
    - [x] Split Arabian/Bob
    - [x] Give flag values (basin names) as attrs(actually as dimension label)
    - [x] Only put out the major basins
- [x] `o2_bin` data in `mumol/kg`
    - [x] check the attrs
    - [x] convert them to integers
- [x] Add bounds as a coordinate

- [x] Make sure to carry the original lev_bounds or dz
- [x] Rename 'count' to 'bin_count'
- [x] Make sure the variables are properly masked with regards to nans in o2 (needs to be a perfect overlay)
- [ ] Maybe resave as nc, but for sure rechunk to larger chunks.
- [x] Make sure to reprocess from netcdf
- [ ] Dont combine experiments
- [ ] Include the other members

In [None]:
def unify_nan(ds, sub_slices=None):
    """Adjust all data variables of `ds` to have nans in the same spots.

    A position is masked in every variable if any data variable is nan there
    within the inspected subset.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose data variables are masked consistently.
    sub_slices : dict, optional
        Mapping of dimension name to the slice that is inspected for nans.
        Those dimensions are reduced afterwards, so the resulting mask is
        applied along their full length. Defaults to `{'time': slice(0, 6)}`.
        Dimensions a variable (or the combined mask) does not have are ignored.

    Returns
    -------
    xr.Dataset
        `ds` with the common nan mask applied; `ds` itself when it has no
        data variables.
    """
    # avoid a mutable default argument
    if sub_slices is None:
        sub_slices = {'time': slice(0, 6)}

    # one boolean nan-map per data variable, restricted to the inspected subset
    nan_flags = []
    for va in ds.data_vars:
        da = ds[va]
        da = da.isel({di: v for di, v in sub_slices.items() if di in da.dims})
        nan_flags.append(np.isnan(da))

    if not nan_flags:
        # nothing to unify
        return ds

    # nan in at least one variable ...
    mask = sum(nan_flags).astype(bool)
    # ... at any of the inspected positions along the sliced dimensions
    reduce_dims = [di for di in sub_slices if di in mask.dims]
    if reduce_dims:
        mask = mask.any(reduce_dims)
    return ds.where(~mask)

In [None]:
# Bin every model by oxygen, latitude and basin and write the result to zarr
for name, ds in ddict_final.items():
    # NOTE(review): `vol` is assigned into the dataset and the tracer loop below
    # iterates over `ds.data_vars`, so `vol` itself presumably ends up binned as
    # a tracer (weighted by vol * vol) — confirm that this is intended.
    ds['vol'] = ds.dz_t * ds.areacello

    ds = unify_nan(ds)

    # Masking
    mask = merged_mask(basins, ds)
    mask, _ = mask_split_indian(mask, mask_labels_raw)
    mask.name = "basin_mask"

    # Create a dataset
    vol = ds.vol

    # every histogram uses the same binning variables, bin edges and reduced dims
    hist_args = (ds.o2, ds.lat, mask)
    hist_kwargs = dict(
        bins=[o2_bins_converted, lat_bins, mask_bins],
        dim=["x", "y"],
    )
    count = histogram(*hist_args, **hist_kwargs)
    volume = histogram(*hist_args, weights=vol, **hist_kwargs)
    tracers = {}
    for tr in ds.data_vars:
        # volume-weighted sum of each variable per bin
        tracers[tr] = histogram(*hist_args, weights=ds[tr] * vol, **hist_kwargs)
    ds_hist = xr.Dataset(dict(bin_count=count, volume=volume, **tracers))
    ds_hist.attrs = {
        k: v for k, v in ds.attrs.items() if k != "intake_esm_varname"
    }

    # Add more coordinates etc to the output
    ds_hist = ds_hist.assign_coords(o2_bin=convert_mol_m3_mymol_kg(ds_hist.o2_bin).data)
    # raw string: `\m` is an invalid escape sequence in a normal string literal
    ds_hist['o2_bin'].attrs['units'] = r'$\mu mol$/kg'
    ds_hist = ds_hist.assign_coords(lev_bounds=ds.lev_bounds)
    ds_hist = ds_hist.assign_coords(basin_mask_bin=mask_bin_labels)
    ds_hist = ds_hist.assign_coords(
        o2_bin_bounds=cf_xarray.vertices_to_bounds(
            xr.DataArray(o2_bins, dims="o2_bin"),
            ["bnds", "o2_bin"],
        )
    )
    ds_hist = ds_hist.assign_coords(
        lat_bounds=cf_xarray.vertices_to_bounds(
            xr.DataArray(lat_bins, dims="lat_bin"),
            ["bnds", "lat_bin"],
        )
    )

    # activate later (takes a lot of time now)
#     print('Checking for nans')
#     for va in ds_hist.data_vars:
#         test = ~np.isnan(ds_hist[va].isel(time=slice(0,3))).any().load()
#         assert test

    print(f"{ds_hist.nbytes/1e9} GB")
    path = ofolder.joinpath(f"{cmip6_dataset_id(ds_hist)}.zarr")
    # only write stores that `zarr_exists` does not report as present
    if not zarr_exists(path):
        print(path)
        # just for testing
        #         append_write_zarr(ds_hist, path, 60)
        append_write_zarr(ds_hist, path, 120)  # just for CM4

In [None]:
stop here

## Test the output against something I know

This seems fine. The bin count is not exact, but that is probably due to some changes in the numerical precision in one of the methods.

Nice, the mean tracer values also line up.

In [None]:
# might need to rename a few here so that it works and the dataset should probably be synth 
# Then I can move this to the tests.

# NOTE: this scratch cell rebinds `mask_bins` and `lat_bins`, which the processing
# loop also reads — re-run the "Define bins" cell before processing again.
# It uses `ds` and `mask` without defining them, i.e. whatever is left in the
# notebook namespace (presumably from the last iteration of the processing loop).

# define mask bins
mask_bins = np.arange(-0.5, 13.0, 1)# for now manual, but maybe there is a clever way to do this?
mask_bins

lat_bins = np.arange(-90, 91, 20)
lat_bins

# NOTE(review): `lev_bins` is not used by the histogram calls in this cell
lev_bins = np.arange(0, 7000, 500)
lev_bins
vol = (ds.dz_t*ds.areacello) 

# plain count, volume-weighted count and volume-weighted sum of each variable per bin
count = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], dim=['x','y'])
volume = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], weights=vol, dim=['x','y'])
tracers = {}
for tr in ds.data_vars:
    tracers[tr] = histogram(ds.o2, ds.lat, mask, bins=[o2_bins_converted, lat_bins, mask_bins], weights=ds[tr]*vol, dim=['x','y'])

In [None]:
from cmip6_omz.omz_tools import mask_basin

In [None]:
# expected_threshold = 0.082
cutoff = o2_bins_converted[-1]
test_full_pacific = test
# expected_full_pacific = mask_basin(ds.o2.isel(time=0), drop=False)
expected_full_pacific = ds.o2
expected_full_pacific = xr.ones_like(expected_full_pacific).where(expected_full_pacific<=cutoff)

In [None]:
test

In [None]:
expected_full_pacific

In [None]:
ds.o2.shape

In [None]:
expected_full_pacific.shape

In [None]:
test_full_pacific.sum().data/expected_full_pacific.sum().load().data

I suspect this is due to the numerical precision going wrong somewhere...

In [None]:
test

In [None]:
# example with masked basin
cutoff = o2_bins_converted[-3]
test_full_pacific = test.sel(basin_mask_bin=slice(1.5, 3.5))
expected_full_pacific = mask_basin(ds.o2, drop=False)
expected_full_pacific = xr.ones_like(expected_full_pacific).where(expected_full_pacific<=cutoff)

In [None]:
test_full_pacific.sum().data/expected_full_pacific.sum().load().data

In [None]:
# example with masked basin
cutoff = o2_bins_converted[-3]
test_full_pacific = test3.sel(basin_mask_bin=slice(1.5, 3.5)).sum()/test2.sel(basin_mask_bin=slice(1.5, 3.5)).sum()
expected_full_pacific = mask_basin(ds, drop=False)
expected_full_pacific = expected_full_pacific.where(expected_full_pacific<=cutoff)
expected_full_pacific = expected_full_pacific.o2.weighted((expected_full_pacific.areacello*expected_full_pacific.dz_t).fillna(0)).mean()

In [None]:
expected_full_pacific.load()

In [None]:
test_full_pacific

## synthetic example for xhistogram

I somehow cannot bin over 'lev'...ok for now. Problem is described in detail [here](https://github.com/xgcm/xhistogram/issues/16)

In [None]:
# I cant get the count numbers to line up. What am I doing wrong here?
da = xr.DataArray(np.random.rand(400,76), name='test')
cutoff = 0.4
hist = histogram(da, bins=np.array([-1e3, cutoff]))
hist

In [None]:
test = xr.ones_like(da).where(da<=cutoff).sum()
test

In [None]:
da.plot()