In [1]:
# %load_ext autoreload
# %autoreload 2
%matplotlib inline
import numpy as np
import time
import shutil

import warnings
import intake
import pathlib
import xarray as xr
import pandas as pd
import cf_xarray
import dask
# dask.config.set({"array.slicing.split_large_chunks": True}) # avoid large chunks to be created.

import matplotlib.pyplot as plt
from fastjmd95 import rho

from dask.diagnostics import ProgressBar
import matplotlib.pyplot as plt

from fastprogress.fastprogress import progress_bar

from xarrayutils.file_handling import (
    write,
    maybe_create_folder,
    file_exist_check,
    temp_write_split,
)
from xarrayutils.utils import (
    remove_bottom_values,
    #mask_mixedlayer
)
from cmip6_preprocessing.preprocessing import (
    combined_preprocessing
)
from cmip6_preprocessing.drift_removal import (
    remove_trend,
    match_and_remove_trend
)
from cmip6_preprocessing.utils import (
    cmip6_dataset_id
)

from cmip6_preprocessing.postprocessing import (
    combine_datasets,
    concat_experiments,
    match_metrics,
    merge_variables,
    interpolate_grid_label,
)
from cmip6_preprocessing.drift_removal import match_and_remove_trend

import sys
sys.path.append("../../")
from cmip6_omz.upstream_stash import (
    transform_wrapper,
    pick_first_member,
    construct_static_dz
)
from cmip6_omz.omz_tools import (
    omz_thickness,
    sigma_bins,
    align_missing,
    preprocessing_wrapper,
    vol_consistency_check_wrapper
)

from cmip6_omz.utils import (
    cmip6_collection,
    o2_models,
)

from cmip6_omz.plotting import plot_omz_results

## What adds to the number of dask tasks

- Detrending (might want to save out a temporary file afterwards?)

In [2]:
import dask
from multiprocessing.pool import ThreadPool

# Run dask's threaded scheduler on a fixed pool of worker threads.
n_dask_threads = 12
dask.config.set(pool=ThreadPool(n_dask_threads))

<dask.config.set at 0x2b638163cb20>

## What I have done:
- Remove all old refs to the other repos
- Refactoring of the metrics matching
- Using only the regridding to combine variables 
    - Need to patch in Norwegian models
- Single cell for filtering/checking all datasets for required vars/metrics
    - This also logs all the problems in one place


## TODO:

- Test with the netCDF archive (or at least update the zarr stores)?
- Test performance with strip encoding?
- [x] Try with new trends
- Fix the logic in the interpolate function so it just merges variables that all share the same grid label (NorESM)
- [x] **The damn Norwegian models have no area...**
- [x] Can I check each variable for the giant chunks after concatenating?
- [ ] CM4 age is chunked badly...
- [ ] Figure out how to deal with the ACCESS data properly (thickness concat fails)...

# Develop functions here

In [3]:
#This could go upstream in a more general form
## but for now let's keep it here and readable


In [4]:
#this function should go to upstream_stash
# def load_trend_dict(ds_dict, verbose = False):
    
#     path_jb = '/tigress/GEOCLIM/LRGROUP/jbusecke/projects/aguadv_omz_busecke_2021/data/processed/linear_regression_time_zarr_multimember'
#     trendfolder = pathlib.Path(path_jb)
#     trend_models = np.unique([ds.attrs['source_id'] for ds in ds_dict.values()])
#     flist = []
#     for tm in trend_models:
#         flist = flist + list(trendfolder.glob(f'*{tm}*_trend.nc'))
    
#     total = len(flist)
#     progress = progress_bar(range(total))
    
#     trend_dict = {}
#     for i,path in enumerate(flist):
#         key = path.stem
#         ds = xr.open_mfdataset([path])
#         # write the filename in the dataset
#         ds.attrs.update({'filepath':str(path)})
#         # exclude all nan slopes (why are these there in the first place?)
#         if not np.isnan(ds.slope).all():
#             trend_dict[key] = ds
#         else:
#             if verbose:
#                 print(f"found all nan regression data for {path}")
#         progress.update(i)
#     progress.update(total)
    
#     return trend_dict


# #These are fixes so that the trend data works with cmip6_pp match_and_remove_trend
# #these issues should be addressed in the next iteration of trend file production
# def fix_trend_metadata(trend_dict):
#     for name, ds in trend_dict.items():
#         #restore attributes to trend datasets using file names
#         fn = (ds.attrs['filepath']).rsplit("/")[-1]
#         fn_parse = fn.split('_')
#         ds.attrs['source_id'] = fn_parse[2]
#         ds.attrs['grid_label'] = fn_parse[5]
#         ds.attrs['experiment_id'] = fn_parse[3]
#         ds.attrs['table_id'] = fn_parse[4]
#         ds.attrs['variant_label'] = fn_parse[7]
#         ds.attrs['variable_id'] = fn_parse[8]
        
#         #rename 'slope' variable to variable_id
#         if "slope" in ds.variables:
#             ds = ds.rename({"slope":ds.attrs["variable_id"]})
        
#         #error was triggered in line 350 of cmip6_preprocessing.drift_removal
#         ##this is a temporary workaround, and the one part of this function that might
#         ##require an upstream fix (though it might just be an environment issue)
#         ds = ds.drop('trend_time_range')
        
#         trend_dict[name] = ds
        
#     return trend_dict

In [5]:
from fastprogress.fastprogress import progress_bar
from zarr.convenience import consolidate_metadata

def append_write_zarr(ds, store, split_chunks, split_dim='time', consolidate=True):
    """Save a dataset to zarr in slices along one dimension.

    Writing the data piece by piece keeps each individual dask graph small,
    which avoids blowing up complicated dask graphs.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to write.
    store : str or pathlib.Path
        Target zarr store. An existing store at this location is overwritten,
        because the first slice is written with ``mode='w'``.
    split_chunks : int
        Number of elements along ``split_dim`` written per slice.
    split_dim : str, optional
        Dimension to slice along and to append along (default ``'time'``).
    consolidate : bool, optional
        If True, consolidate the zarr metadata once all slices are written.

    Raises
    ------
    ValueError
        If ``split_chunks`` is not positive, or if ``ds`` has no elements along
        ``split_dim`` (there would be nothing to write).
    """
    n_total = len(ds[split_dim])
    if split_chunks < 1:
        raise ValueError(f"`split_chunks` must be a positive integer, got {split_chunks}")
    if n_total == 0:
        raise ValueError(f"Dataset has no elements along `{split_dim}`; nothing to write.")

    # the last slice may be shorter; slicing past the end is fine
    datasets = [
        ds.isel({split_dim: slice(start, start + split_chunks)})
        for start in range(0, n_total, split_chunks)
    ]

    # .to_zarr needs the first slice to be written without appending (this creates the store)
    datasets[0].to_zarr(store, mode='w')
    for ds_sub in progress_bar(datasets[1:]):
        ds_sub.to_zarr(store, mode='a', append_dim=split_dim)

    if consolidate:
        consolidate_metadata(str(store))

In [6]:
def resample_yearly(ds_in, freq="1AS", window=12):
    """Average consecutive blocks of ``window`` time steps (yearly means for monthly data).

    Some time-dependent coordinates get dropped by ``coarsen``. These are
    temporarily converted to data variables, averaged along with the data, and
    converted back to coordinates afterwards.

    Parameters
    ----------
    ds_in : xarray.Dataset
        Input dataset with a ``time`` dimension. Assumed to be monthly and to
        start at the beginning of a year (TODO: confirm for each model).
    freq : str, optional
        Ignored. It is kept only so existing calls keep working. The averaging
        is done by position with ``coarsen``, not by calendar with ``resample``.
    window : int, optional
        Number of consecutive time steps averaged together (default 12).

    Returns
    -------
    xarray.Dataset
        The averaged dataset. The global attributes of ``ds_in`` except
        ``table_id`` are copied onto it. Whether ``table_id`` survives depends
        on the attributes ``coarsen`` keeps itself; it is only excluded from
        this explicit copy.
    """
    time_coords = [
        co
        for co in list(ds_in.coords)
        if "time" in ds_in[co].dims and co not in ["time", "time_bounds"]
    ]
    ds_out = ds_in.reset_coords(time_coords).coarsen(time=window).mean()
    ds_out = ds_out.assign_coords({co: ds_out[co] for co in time_coords})
    ds_out.attrs.update({k: v for k, v in ds_in.attrs.items() if k not in ["table_id"]})
    return ds_out

In [7]:
def is_zarr(fn):
    """Return True for a ``.zarr`` path and False for a ``.nc`` path.

    Parameters
    ----------
    fn : str or os.PathLike
        File or store path. A trailing path separator, which is common for zarr
        directory stores (e.g. ``"out.zarr/"``), is ignored.

    Raises
    ------
    RuntimeError
        If the extension is neither ``.nc`` nor ``.zarr``.
    """
    # PurePath normalizes a trailing separator and accepts both str and Path input
    extension = pathlib.PurePath(fn).suffix
    if extension == '.nc':
        return False
    if extension == '.zarr':
        return True
    raise RuntimeError('Unrecognized File Extension')

def reload_preexisting(filename, overwrite = True, plot=False):
    """Reopen an output file that was already written, instead of recomputing it.

    Parameters
    ----------
    filename : str
        A path ending in ``.zarr`` is opened with ``xr.open_zarr`` (consolidated
        metadata). A path ending in ``.nc`` is opened with ``xr.open_dataset``.
    overwrite : bool, optional
        Unused; kept so existing calls keep working.
    plot : bool, optional
        If True, try to plot the reloaded data with ``plot_omz_results``.
        Plotting failures are printed, not raised. This applies to both formats.

    Returns
    -------
    xarray.Dataset
        The reopened dataset, decoded with cftime.
    """
    print("Skipping. File exists already.")
    if is_zarr(filename):
        ds_sigma_reloaded = xr.open_zarr(
            filename, use_cftime=True, consolidated=True
        )
    else:
        ds_sigma_reloaded = xr.open_dataset(
            filename, use_cftime = True
        )

    # Plotting used to happen only for netCDF files. It is now opt-in, because
    # callers that plot the result themselves would otherwise plot twice.
    if plot:
        try:
            plot_omz_results(ds_sigma_reloaded)
        except Exception as e:
            print(f"Plotting failed with: {e}")
    return ds_sigma_reloaded
    
def strip_encoding(ds):
    """Empty the ``encoding`` of ``ds`` and of all its variables, in place.

    Both data variables and coordinates are cleared. The same (mutated) object
    is returned. This is used to clean up chunk encodings carried over from the
    source files before writing. The need for it seems like a bug to me.
    """
    ds.encoding = {}
    for name in list(ds.variables):
        ds[name].encoding = {}
    return ds

### Local convenience functions for final cell

# Start pipeline here

In [8]:
foldername = "fine_density_tests_combined_2"
# ofolder = maybe_create_folder(f"../../data/external/{foldername}")
# folder for the final (sigma-space) outputs
ofolder = maybe_create_folder(f"../../data/processed/{foldername}")
# scratch folder for temporary stores (e.g. the rechunked high-res models in the final loop)
tempfolder = maybe_create_folder(f"../../data/temp/scratch_temp/{foldername}")

# global parameters
# o2 thresholds passed as `o2_bins` to the OMZ thickness computation (units not stated here -- TODO confirm)
o2_bins = np.array([10, 40, 60, 80, 100, 120])
# density bins used for the transformation to sigma-space
fine_sigma_bins = sigma_bins()



In [9]:
# open the intake-esm catalog of the zarr version of the CMIP6 collection
col = intake.open_esm_datastore(cmip6_collection(zarr=True)) #TODO: Check with nc files

In [10]:
# candidate models (presumably those providing o2 output -- see `cmip6_omz.utils.o2_models`)
o2_models()

['CanESM5-CanOE',
 'CanESM5',
 'CNRM-ESM2-1',
 'ACCESS-ESM1-5',
 'MPI-ESM-1-2-HAM',
 'IPSL-CM6A-LR',
 'MIROC-ES2L',
 'UKESM1-0-LL',
 'MPI-ESM1-2-HR',
 'MPI-ESM1-2-LR',
 'MRI-ESM2-0',
 'NorCPM1',
 'NorESM1-F',
 'NorESM2-LM',
 'NorESM2-MM',
 'GFDL-CM4',
 'GFDL-ESM4']

In [11]:
# NOTE(review): the original comment here was cut off: "if this does not work on jupyter.rc, we can add some logic to ..."
col = intake.open_esm_datastore(cmip6_collection(zarr=True)) #TODO: Check with nc files

z_kwargs = {"decode_times": True, "use_cftime": True, "consolidated": True}
n_kwargs = {"decode_times": True, "use_cftime": True, 'chunks': {'time': 3}}

variable_ids = ["thetao", "so", "o2", "agessc"]  # "mlotst"
metric_variable_ids = ["thkcello", "areacello"]  # "mlotst"

# Alternative model selections:
# models = o2_models()
# models = ['GFDL-ESM4', 'GFDL-CM4', 'ACCESS-ESM1-5']  # shorter test run....
# models = [m for m in o2_models() if 'GFDL-ESM4' in m or 'Nor' in m]
# models = [m for m in o2_models() if ('ACCESS' not in m and 'GFDL' not in m and 'HR' not in m)]
# (also toggled in/out by hand: 'MPI-ESM1-2-HR', 'GFDL-CM4')
models = [
    'MRI-ESM2-0',
    'NorESM2-LM',
    'GFDL-ESM4',
]

# keyword arguments shared by both catalog loads below
open_kwargs = dict(
    aggregate=False,
    zarr_kwargs=z_kwargs,
    cdf_kwargs=n_kwargs,
    preprocess=combined_preprocessing,
)

cat = col.search(
    source_id=models,
    grid_label=["gr", "gn"],
    experiment_id=["historical", "ssp585"],
    table_id=["Omon"],
    variable_id=variable_ids,
)
ds_dict = cat.to_dataset_dict(**open_kwargs)

# make a separate metric dict to catch all possible metrics!
cat_metrics = col.search(source_id=models, variable_id=metric_variable_ids)
ds_metric_dict = cat_metrics.to_dataset_dict(**open_kwargs)


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.version.zstore'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.member_id.table_id.variable_id.grid_label.version.zstore'


In [12]:
# all source_ids present in the catalog (not only the selected `models`)
col.df['source_id'].unique()

array(['CNRM-ESM2-1', 'UKESM1-0-LL', 'GFDL-ESM4', 'GFDL-CM4', 'CanESM5',
       'CanESM5-CanOE', 'MPI-ESM1-2-HR', 'ACCESS-ESM1-5', 'MRI-ESM2-0',
       'MIROC-ES2L', 'IPSL-CM6A-LR', 'NorESM2-LM', 'NorESM2-MM',
       'MPI-ESM1-2-LR', 'MPI-ESM-1-2-HAM', 'NorESM1-F', 'NorCPM1'],
      dtype=object)

## Rechunk the data

In [13]:
def rechunk(ds):
    """Rechunk ``ds`` to a single time step per chunk.

    Datasets without a ``time`` dimension (e.g. static metrics) are returned
    unchanged.
    """
    has_time = 'time' in ds.dims
    return ds.chunk({'time': 1}) if has_time else ds

# one time step per chunk for all tracer and metric datasets (datasets without `time` are left as they are)
ds_dict = {k: rechunk(ds) for k,ds in ds_dict.items()}
ds_metric_dict = {k: rechunk(ds) for k,ds in ds_metric_dict.items()}

In [14]:
# Load all control-run drift (trend) files for the requested variables.
# TODO: get rid of the old `load_trend_dict` / `fix_trend_metadata` helpers (or refactor them).
trend_folder = pathlib.Path('../../data/external/cmip6_control_drifts/').absolute()
# Match the variable names against the file *name* only. Matching against the full
# absolute path lets any directory name containing e.g. "so" or "o2" select every file.
flist = [f for f in trend_folder.glob('*.nc') if any(v in f.name for v in variable_ids)]
trend_dict = {}
for f in progress_bar(flist):
    trend_dict[f.stem] = xr.open_mfdataset([f])

In [15]:
# these ones are messed up...need a better way to deal with that in the previous step
# see https://github.com/jbusecke/cmip6_preprocessing/issues/175
# keys are the file stems used as keys of `trend_dict`
incomplete_keys = ['CMIP.IPSL.IPSL-CM6A-LR.historical.r3i1p1f1.Omon.gn.none.area_o2']
trend_dict = {k:ds for k,ds in trend_dict.items() if k not in incomplete_keys}

In [None]:
# Remove the control-run drift: datasets in `ds_dict` are matched with the trend datasets in `trend_dict`
ddict_tracers_detrended = match_and_remove_trend(
    ds_dict,
    trend_dict,
#     check_mask=False
)
# print('THIS IS DANGEROUS. CHECK THE MASKS!')



## Match metrics (there are still quite a few missing).

In [None]:
# this one causes problems because the time is not as long as the full data...
# NOTE: `problem_keys` is only kept for reference. The active filter below drops *every*
# CNRM-ESM2-1 r6i1p1f2 dataset (all variables), not only the `so` dataset listed here.
problem_keys = [
    #shorter run? Missing beginning?
    'CMIP.CNRM-CERFACS.CNRM-ESM2-1.historical.r6i1p1f2.Omon.so.gn.v20200117./projects/GEOCLIM/LRGROUP/jbusecke/projects/cmip_data_management_princeton/builder/../zarr_conversion/CMIP6/CMIP/CNRM-CERFACS/CNRM-ESM2-1/historical/r6i1p1f2/Omon/so/gn/v20200117/CMIP.CNRM-CERFACS.CNRM-ESM2-1.historical.r6i1p1f2.Omon.so.gn.v20200117.zarr'
]
# ddict_tracers_detrended_filtered = {k:ds.squeeze() for k, ds in ddict_tracers_detrended.items() if k not in problem_keys}
# `.squeeze()` additionally drops length-1 dimensions from every remaining dataset
ddict_tracers_detrended_filtered = {k:ds.squeeze() for k, ds in ddict_tracers_detrended.items() if not ("CNRM-ESM2-1" in k and "r6i1p1f2" in k)}

In [None]:
# match the metric variables ('areacello', 'thkcello') to the tracer datasets; prints matching statistics
print('matching metrics\n')
ddict_matched = match_metrics(ddict_tracers_detrended_filtered, ds_metric_dict, ['areacello', 'thkcello'], print_statistics=True)

Do I need to rechunk here for the high res models? I am currently doing this for CM4 and ESM4, but I might have to adjust the source data...

In [None]:
# combine the variables of each model via `interpolate_grid_label` (see `cmip6_preprocessing.postprocessing`)
print('interpolate grids\n')
ddict_matched_regrid = interpolate_grid_label(ddict_matched, merge_kwargs={'compat':'override'}) # This should be a default soon

In [None]:
# Patch the Norwegian models in manually: merge their variables directly with `merge_variables`
# and store them under the key without the grid label.
ddict_patch = merge_variables(ddict_matched)
for name, ds in ddict_patch.items():
    # test for the exact '.gr' token that is stripped below, not for any 'gr' substring
    # (otherwise a key would be re-inserted unchanged)
    if 'Nor' in name and '.gr' in name:
        patch_name = name.replace('.gr', '')
        ddict_matched_regrid[patch_name] = ds

In [None]:
# sorted overview of the combined dataset keys
np.sort(list(ddict_matched_regrid.keys()))

## Concatenate experiments and pick the first full one

In [None]:
# xarray fails when comparing list- and int-valued attrs while combining datasets
# (occurs in CM4). This should be raised upstream; quick fix here.
def clean_attrs(ds):
    """Wrap every ``int`` global attribute of ``ds`` in a one-element list, in place.

    ``bool`` values are ``int`` instances, so they are wrapped too. Returns the
    same (mutated) object.
    """
    wrapped = {key: [value] for key, value in ds.attrs.items() if isinstance(value, int)}
    ds.attrs.update(wrapped)
    return ds

ddict_matched_regrid = {k:clean_attrs(ds) for k, ds in ddict_matched_regrid.items()}

# combine the experiments (historical + ssp585) of each model/member; conflicting attributes are dropped
ddict_ex_combined = concat_experiments(
    ddict_matched_regrid,
    concat_kwargs={
        'combine_attrs': 'drop_conflicts',
        'compat': 'override',
        'coords': 'minimal'
    }
)

## Quick fix for inhomogeneous metrics
I have to think about this more. Some of the models (ACCESS) have a time-variable thickness for ssp585 and a static one for historical.
Concatenating them leads to huge dask chunks. For now I am taking the affected variables out, which will lead to a static recompute later...

In [None]:
def check_chunks(ds, dim='time', max_chunk=10):
    """List the variables whose chunks along ``dim`` are longer than ``max_chunk``.

    Concatenating experiments with inhomogeneous metrics (e.g. a static thickness
    in one experiment and a time-varying one in the other) can produce huge
    chunks along ``time``. This helper finds the affected variables.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to inspect (data variables and coordinates are checked).
    dim : str, optional
        Dimension whose chunk lengths are checked (default ``'time'``).
    max_chunk : int, optional
        Largest acceptable chunk length along ``dim`` (default 10).

    Returns
    -------
    list of str
        Names of chunked variables with at least one chunk longer than
        ``max_chunk`` along ``dim``. Non-chunked (in-memory) variables are
        never listed.
    """
    trigger_vars = []
    for var in ds.variables:
        da = ds[var]
        # `.chunks` is None for non-chunked data; this avoids relying on
        # `dask.array` having been imported implicitly
        if da.chunks is None:
            continue
        for di, chunk_sizes in zip(da.dims, da.chunks):
            if di == dim and any(c > max_chunk for c in chunk_sizes):
                trigger_vars.append(var)
    return trigger_vars

# drop the variables with huge time chunks (see `check_chunks`) and log them
ddict_ex_combined_filtered = {}
for name, ds in ddict_ex_combined.items():
    check = check_chunks(ds)
    if len(check) > 0:
        print(name)
        print(check)
    # `drop_vars` is the non-deprecated replacement of `Dataset.drop` for variable names
    ds = ds.drop_vars(check)
    ddict_ex_combined_filtered[name] = ds

## Outstanding issue: ACCESS can't be combined when only some experiments have thickness
For example:
```python
ds1 = ddict_matched_regrid['ACCESS-ESM1-5.historical.Omon.r1i1p1f1']
ds2 = ddict_matched_regrid['ACCESS-ESM1-5.ssp585.Omon.r1i1p1f1']
ds2

ds_combined = xr.concat([ds1.drop('thkcello'), ds2], 'time', **{'combine_attrs': 'drop_conflicts', 'compat': 'override', 'coords': 'minimal'})
ds_combined
```
I expected that dropping `thkcello` would be enough, but xarray still fails. Raise an issue about that; otherwise I'll have to handle it in the combination function...


In [None]:
# only pick full runs (historical and ssp585)
# only pick full runs (historical and ssp585)
# NOTE(review): 3000 is a heuristic threshold on the number of (presumably monthly) time steps;
# a complete historical + ssp585 run should have slightly more -- TODO confirm
ddict_ex_combined_full = {k:ds for k,ds in ddict_ex_combined_filtered.items() if len(ds.time)>3000}

In [None]:
# datasets that cover the full historical + ssp585 period
ddict_ex_combined_full.keys()

## Check datasets for completeness and log the ones with problems

In [None]:
from cmip6_preprocessing.grids import combine_staggered_grid
# Log every problem per category. Only datasets without a blocking problem end up in `ddict_filtered`.
problems = {'missing_variables':[], 'missing_area':[], 'missing_thickness':[], 'reconstructed_area':[], 'reconstructed_thickness':[]}
ddict_filtered = {}
for name, ds in ddict_ex_combined_full.items():
    # `flag` marks a blocking problem: the dataset is excluded at the end of the loop body
    flag = False
    # Check that all necessary variables are given (agessc is not required)
    missing_variables = [va for va in ["thetao", "so", "o2"] if va not in ds.variables]
    if len(missing_variables)>0:
        flag = True
        problems['missing_variables'].append((name, missing_variables))
        
    # Check for area
    if not 'areacello' in ds.coords:
        if ds.attrs['grid_label'] == 'gr': # only reconstruct for regular grids
            # recalculate the grid metrics (provides dx_t, dy_t and dz_t used below)
            grid, ds = combine_staggered_grid(ds, recalculate_metrics=True)
            # I am dropping dz_t here so it can be uniformly reconstructed
            ds = ds.drop('dz_t')
            # cell area as the product of the two horizontal cell widths
            ds = ds.assign_coords(areacello = (ds.dx_t * ds.dy_t).reset_coords(drop=True))
            problems['reconstructed_area'].append(name)
            assert 'areacello' in ds.coords
        else:
            # other grid labels (e.g. 'gn') without area are not reconstructed -> exclude
            flag = True
            problems['missing_area'].append(name)
    
    # Check for thickness (and rename) TODO: We should probably not rename and just refactor to use `thkcello`
    if "thkcello" in ds.coords:
        ds = ds.rename({'thkcello': 'dz_t'})
    else:
        # try to reconstruct the thickness from static info
        try:
#             lev_vertices = cf_xarray.bounds_to_vertices(ds.lev_bounds, 'bnds').load()
#             dz_t = lev_vertices.diff('lev_vertices')
#             ds = ds.assign_coords(dz_t=('lev', dz_t.data))
            ds = construct_static_dz(ds).rename({'thkcello': 'dz_t'})
            problems['reconstructed_thickness'].append(name)
        except Exception as e:
            # log the failure (and the dataset repr) and exclude the dataset
            print(f'{name} thickness reconstruction failed with {e}')
            print(ds)
            problems['missing_thickness'].append(name)
            flag=True
            
    if not flag:
        ddict_filtered[name] = ds

In [None]:
# datasets that passed all checks
list(np.sort(list(ddict_filtered.keys())))

In [None]:
# overview of all logged problems and reconstructions
problems

In [None]:
# reduce to a single member per model (see `cmip6_omz.upstream_stash.pick_first_member`)
ddict_final = pick_first_member(ddict_filtered)#
list(np.sort(list(ddict_final.keys())))

## Hacking time 😎

Not sure if this actually improved things...but it reduces the number of tasks...which is generally good.

Bring this over to xarrayutils (more info/test in `dev_efficient_bottom_removal`)

In [None]:
# just code that shit in numba
from numba import float64, guvectorize
import numpy as np
import xarray as xr

# Numba gufunc over one core dimension `n`: copy `data` into `output` and set every
# value that is directly followed by a NaN to NaN, i.e. remove the last valid value
# above missing data. Only a float64 signature is compiled.
# NOTE(review): this also removes values directly above *interior* NaNs, not only the
# deepest valid value -- fine as long as the columns have no gaps.
@guvectorize(
    [
        (float64[:], float64[:]),
    ],
    "(n)->(n)",
    nopython=True,
)
def _remove_last_value(data, output):
    # initialize output
    output[:] = data[:]
    for i in range(len(data)-1):
        # output[i+1] has not been modified yet at this point, so this effectively checks data[i+1]
        if np.isnan(output[i+1]):
            output[i] = np.nan
    # take care of boundaries
    # (the last element has nothing below it and always ends up NaN)
    if not np.isnan(output[-1]):
        output[-1] = np.nan

def remove_bottom_values_numba(da, dim='lev'):
    """Apply the `_remove_last_value` gufunc to `da` along `dim`.

    `dim` is the core dimension for `xr.apply_ufunc`, so it is kept in the
    output. Dask-backed input is processed lazily (``dask="parallelized"``),
    and the output keeps the dtype of `da`.
    """
    return xr.apply_ufunc(
        _remove_last_value,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )

def remove_bottom_values_recoded(ds, dim="lev", fill_val=-1e10):
    """Remove the deepest values that are not nan along the dimension `dim`.

    Every data variable is passed through ``remove_bottom_values_numba`` along
    ``dim``. The result is transposed back to the dimension order of ``ds``, and
    the coordinates and global attributes of ``ds`` are reattached.

    Parameters
    ----------
    ds : xarray.Dataset
        Every data variable is assumed to have the dimension ``dim``, because it
        is used as the core dimension for each of them.
    dim : str, optional
        Dimension along which the bottom values are removed (default ``"lev"``).
        Its values must increase along the dimension.
    fill_val : float, optional
        Unused; kept so existing calls keep working.

    Raises
    ------
    ValueError
        If the values of ``dim`` decrease along the dimension.
    """
    # for now assume that values of `dim` increase along the dimension
    if ds[dim][0] > ds[dim][-1]:
        raise ValueError(
            f"It seems like `{dim}` has decreasing values. This is not supported yet. Please sort before."
        )
    # pass `dim` through; previously the helper always used its default "lev"
    ds_masked = xr.Dataset({va: remove_bottom_values_numba(ds[va], dim=dim) for va in ds.data_vars})
    ds_masked = ds_masked.transpose(*tuple([di for di in ds.dims if di in ds_masked]))
    ds_masked = ds_masked.assign_coords({co: ds[co].transpose(*[di for di in ds.dims if di in ds[co]]) for co in ds.coords})
    ds_masked.attrs = ds.attrs
    return ds_masked

## The final loop to vertically transform to sigma-space and save output

In [None]:
from cmip6_omz.omz_tools import omz_thickness_efficient

In [None]:
# `IPython.display` is the public location of these helpers (`IPython.core.display` is deprecated for them)
from IPython.display import display, HTML


def print_html(ds):
    """Render the rich HTML representation of ``ds`` in the cell output.

    Useful inside loops, where an object is not displayed automatically.
    """
    display(HTML(ds._repr_html_()))

In [None]:
# I will have to process the control runs separately, e.g.:
#         if ds.attrs["experiment_id"] == "piControl":
#             ds = ds.isel(time=slice(-300 * 12, None))

# set to True to recompute and overwrite outputs that already exist
overwrite = False
with warnings.catch_warnings():
    warnings.filterwarnings('ignore')  # might need to remove later...
    # first pass: synthetic datasets (o2 fixed to its 1850-1900 mean), second pass: the actual data
    for synthetic in [True, False]:
        for mi, (name, ds) in enumerate(ddict_final.items()):

            t0 = time.time()
            synthetic_string = 'synthetic example' if synthetic else ' '
            # the counter refers to the datasets actually being looped over (`ddict_final`)
            print(f"######################{name} {synthetic_string} ({mi+1}/{len(ddict_final)}) ###############")

            dataset_id = f"{cmip6_dataset_id(ds)}"

            if synthetic:
                filename = ofolder.joinpath(f"{dataset_id}_synthetic.zarr")
            else:
                filename = ofolder.joinpath(f"{dataset_id}.zarr")

            if file_exist_check(filename) and not overwrite:
                ds_sigma_reloaded = reload_preexisting(str(filename))
            else:
                print(f"Writing to {filename}")
                tempfilelist = []

                ds = preprocessing_wrapper(ds)

                # clean up the chunk encoding (can probably be dropped in newer xarray versions but leave for now)
                ds = strip_encoding(ds)

                # the thickness has to be aligned/masked as well, so temporarily make it a data variable
                ds = ds.reset_coords(["dz_t"])
                # perform nan-masking functions
                ds = align_missing(ds)
                ds = remove_bottom_values_recoded(ds)
                ds = ds.set_coords("dz_t")

                # reconstruct the potential density (at pressure 0) as an anomaly from 1000
                ds["sigma_0"] = (rho(ds.so, ds.thetao, 0) - 1000)

                # If active, create a synthetic control dataset with constant historical o2
                if synthetic:
                    with ProgressBar():
                        o2_hist = ds.o2.sel(time=slice('1850', '1900')).mean('time').load()
                    o2_hist_broadcasted = xr.ones_like(ds.sigma_0) * o2_hist
                    ds = ds.assign(o2=o2_hist_broadcasted)

                    # o2 must now be constant in time, while the density still varies
                    assert np.allclose(ds.o2.isel(time=0).load(), ds.o2.isel(time=-100).load(), equal_nan=True)
                    assert not np.allclose(ds.sigma_0.isel(time=0), ds.sigma_0.isel(time=-100), equal_nan=True)

                o2_bin_chunks = -1

                if 'GFDL' in name or 'HR' in name:
                    #################################################################
                    # rechunk the high res models here, they always crash otherwise #
                    #################################################################
                    # also set some other parameters
                    o2_bin_chunks = 1

                    rechunk_store = tempfolder.joinpath(f"{name}_rechunked")
                    print(f"Temp saving to {rechunk_store}")
                    with ProgressBar():
                        ds_reloaded, tempfilelist_var = temp_write_split(
                            ds,
                            rechunk_store,
                            verbose=False,
                            method='dimension',
                            split_interval=24,
                        )
                        tempfilelist.extend(tempfilelist_var)
                    print_html(ds_reloaded)
                    ds = ds_reloaded

                ds["omz_thickness"] = omz_thickness_efficient(
                    ds, o2_bins=o2_bins, bin_chunks=o2_bin_chunks
                )

                ds_sigma_monthly = transform_wrapper(ds, sigma_bins=fine_sigma_bins)

                # Check that the total ocean volume has not changed in the transformation
                assert vol_consistency_check_wrapper(ds, ds_sigma_monthly)

                # average yearly (otherwise the outputs become huuuuge)
                ds_sigma_yearly = resample_yearly(ds_sigma_monthly)

                #################### write out results ########################
                # write 10 (yearly) time steps per append to keep the dask graphs small
                with ProgressBar():
                    append_write_zarr(ds_sigma_yearly, filename, 10)

                ds_sigma_reloaded = xr.open_zarr(
                    filename,
                    consolidated=True,
                    use_cftime=True
                )

                ###### delete temps ######
                print('removing temps')
                for tf in tempfilelist:
                    if tf.exists():
                        shutil.rmtree(tf)

                # Check metadata
                for ma in ['source_id', 'grid_label', 'table_id', 'variant_label']:
                    assert ds.attrs[ma] == ds_sigma_reloaded.attrs[ma]

            ##################### Verification plotting ##########################
            print('plotting results')
            try:
                plot_omz_results(ds_sigma_reloaded)
            except Exception as e:
                print(f"Plotting failed with: {e}")
            plt.show()
            t1 = time.time()
            print(f"Time passed: {(t1-t0)/60 } minutes")

In [None]:
# shutil.rmtree('/home/jbusecke/projects/cmip6_omz/data/processed/fine_density_tests_combined_2/none.none.MPI-ESM1-2-HR.none.r1i1p1f1.Omon.gn.none.none.zarr')

- ~~CanESM5 crapped out (only for the variable o2 case)~~

## Can I save the output?

import pathlib

In [None]:
# Guard cell: the cells below are manual recovery steps (reassembling a partially written
# output from temporary stores and rewriting it). Stop "Run All" here on purpose.
raise RuntimeError("Don't execute the cells below automatically; they are manual recovery steps.")

In [None]:
# import dask
# from multiprocessing.pool import ThreadPool
# dask.config.set(pool=ThreadPool(6))

In [None]:
# Re-create the imports and the temp folder path so the recovery cells below can run after a
# kernel restart (`append_write_zarr` from above is still needed).
import pathlib
import xarray as xr
import numpy as np
foldername = "fine_density_tests_combined_2"
tempfolder = pathlib.Path(f"../../data/temp/scratch_temp/{foldername}")

In [None]:
# Reopen the individual temporary stores `temp_write_split_<i>.zarr`.
# NOTE(review): 251 is hard-coded -- presumably the number of temp stores of the failed run;
# check it against the contents of `tempfolder`.
datasets = [
    xr.open_zarr(
        tempfolder.joinpath(f"temp_write_split_{str(i)}.zarr"),
        consolidated=False,
        use_cftime=True
    ) for i in range(251)
]

In [None]:
# reassemble the pieces along time
ds = xr.concat(datasets, 'time', compat='override', coords='minimal')

In [None]:
foldername = "fine_density_tests_combined_2"
# ofolder = maybe_create_folder(f"../../data/external/{foldername}")
# same output folder as the main pipeline (not created here -- it is expected to exist)
ofolder = pathlib.Path(f"../../data/processed/{foldername}")

In [None]:
from cmip6_preprocessing.utils import cmip6_dataset_id

In [None]:
# target store, named like the synthetic output of the main loop
manual_store = ofolder.joinpath(f"{cmip6_dataset_id(ds)}_synthetic.zarr")

In [None]:
import shutil
# Remove a previous (possibly incomplete) store. Skip this when there is none,
# so the cell doesn't fail with FileNotFoundError.
if manual_store.exists():
    shutil.rmtree(manual_store)

In [None]:
# rewrite the reassembled data in slices of 20 time steps
append_write_zarr(ds, manual_store, 20)

In [None]:
# quick sanity check: horizontal mean of thetao in density bin index 20, over time
xr.open_zarr(manual_store).thetao.isel(sigma_0=20).mean(['x', 'y']).plot()

I have to loop to write from one zarr to another? WTF is wrong with this machine?