In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
import matplotlib.pyplot as plt

In [None]:
import numpy as np
import glob
import xarray as xr
import xbudget
import regionate
import xwmt
import xwmb
import xgcm
import cartopy.crs as ccrs
import CM4Xutils #needed to run pip install nc-time-axis
from regionate import MaskRegions, GriddedRegion

In [None]:
# Report the installed versions of the packages this analysis depends on,
# one label/version pair per argument line.
print(
    'xgcm version', xgcm.__version__,
    '\nregionate version', regionate.__version__,
    '\nxwmt version', xwmt.__version__,
    '\nxwmb version', xwmb.__version__,
)

### Request HPC Resources

In [None]:
from dask_jobqueue import SLURMCluster  # setup dask cluster 
from dask.distributed import Client

# Directory passed to SLURMCluster as its `log_directory`.
log_directory = "/vortexfs1/home/anthony.meza/scratch/CM4X/CM4XTransientTracers/WaterMassBudgets/logs"

# Keyword arguments forwarded unchanged to SLURMCluster.
slurm_job_options = dict(
    cores=36,
    processes=1,
    memory='190GB',
    walltime='02:00:00',
    queue='compute',
    interface='ib0',
    log_directory=log_directory,
)
cluster = SLURMCluster(**slurm_job_options)

# Print the job script the cluster object generates, then ask for 4 jobs.
print(cluster.job_script())
cluster.scale(jobs=4)

# Connect a client to the cluster; the bare name on the last line is the
# cell's displayed output.
client = Client(cluster)
client

### Load in data

In [None]:
# Names of the water mass transformation budget terms to pull out of the
# computed budget. Two lists are declared and then merged into a single
# sorted list of unique names.

budget_terms = [
    'Eulerian_tendency',
    'advection',
    'diffusion',
    'boundary_fluxes',
    'convergent_mass_transport',
    'mass_tendency',
    'mass_source',
    'spurious_numerical_mixing',
    "surface_exchange_flux",
    "bottom_flux",
    "frazil_ice",
    "surface_ocean_flux_advective_negative_rhs",
]

# Additional terms (mostly the *_heat / *_salt variants); a few entries
# repeat names already listed above and are de-duplicated by the union below.
other_budget_terms = [
    "surface_ocean_flux_advective_negative_rhs_heat",
    "surface_ocean_flux_advective_negative_rhs_salt",
    "surface_exchange_flux_heat",
    "surface_exchange_flux_salt",
    "frazil_ice_heat",
    "bottom_flux_heat",
    "boundary_fluxes",
    "mass_tendency",
    "diffusion_heat",
    "diffusion_salt",
    "spurious_numerical_mixing",
    "convergent_mass_transport",
]

# Union of both lists with duplicates removed, in sorted order.
budget_terms = sorted({*budget_terms, *other_budget_terms})

In [None]:
# Term names selected from the decomposed water mass budget.
decomp_budget_terms = [
    # names ending in "_salt"
    "surface_exchange_flux_advective_evaporation_salt",
    "surface_exchange_flux_advective_rain_and_ice_salt",
    "surface_exchange_flux_advective_snow_salt",
    "surface_exchange_flux_advective_rivers_salt",
    "surface_exchange_flux_advective_icebergs_salt",
    "surface_exchange_flux_advective_virtual_precip_restoring_salt",
    "surface_exchange_flux_advective_sea_ice_salt",
    # names ending in "_heat"
    "surface_exchange_flux_nonadvective_longwave_heat",
    "surface_exchange_flux_nonadvective_shortwave_heat",
    "surface_exchange_flux_nonadvective_sensible_heat",
    "surface_exchange_flux_nonadvective_latent_heat",
    "surface_exchange_flux_advective_mass_transfer_heat",
]

In [None]:
def datadir(x=""):
    """Return the path of `x` inside the sigma2 budget data directory.

    Parameters
    ----------
    x : str, optional
        File name, glob pattern, or relative path appended to the directory.
        With the default empty string the directory path itself (including
        its trailing slash) is returned.

    Returns
    -------
    str
        The directory path concatenated with `x`.
    """
    return "/vortexfs1/home/anthony.meza/scratch/CM4XTransientTracers/data/model/budgets_sigma2/" + x


# glob.glob returns matches in arbitrary (filesystem-dependent) order, so the
# list must be sorted BEFORE slicing; slicing first would drop an arbitrary
# set of 20 files rather than the first 20 in sorted order.
datafiles = sorted(glob.glob(datadir("CM4Xp125*")))[20:]

# Collects one merged water-mass-transformation Dataset per input file.
wmts = []

for file in datafiles:
    print(file)
    ds = xr.open_mfdataset(
        file,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        parallel=True,
        engine="zarr")
    # Replace missing values with zero before building the budget.
    ds = ds.fillna(0.)

    # Boolean mask, True where geolat <= -40 (presumably latitude in degrees
    # north, i.e. the region poleward of 40S — confirm against the dataset).
    ds['mask'] = (ds['geolat'] <= -40)

    # Build the grid object with "sigma2" as the vertical prefix, then add
    # renamed copies of the sigma2 coordinates as the "Z_target" axis.
    grid = CM4Xutils.ds_to_grid(ds, Zprefix = "sigma2")
    grid._ds = grid._ds.assign_coords({
        "sigma2_l_target": grid._ds['sigma2_l'].rename({"sigma2_l":"sigma2_l_target"}),
        "sigma2_i_target": grid._ds['sigma2_i'].rename({"sigma2_i":"sigma2_i_target"}),
    })
    grid = xwmt.add_gridcoords(
        grid,
        {"Z_target": {"center": "sigma2_l_target", "outer": "sigma2_i_target"}},
        {"Z_target": "extend"}
    )

    # Turn the mask into region objects and rebuild the first one as a
    # GriddedRegion named "antarctic".
    regions = MaskRegions(ds.mask, grid).region_dict
    antarctic = regions[0] #there are more in this list if there are multiple contours
    region = GriddedRegion("antarctic", antarctic.lons, antarctic.lats, grid,
                           ij=(antarctic.i, antarctic.j))

    with warnings.catch_warnings():
        # Silence FutureWarning for everything inside this context.
        warnings.simplefilter(action='ignore', category=FutureWarning)

        # Alternative source: xbudget.load_preset_budget(model="MOM6_3Donly")
        # NOTE(review): the budget dict is reloaded for every file, presumably
        # because collect_budgets modifies it in place — confirm before
        # hoisting this out of the loop.
        budgets_dict = xbudget.load_yaml(datadir("../../MOM6_AABW.yaml"))

        xbudget.collect_budgets(grid, budgets_dict)

        # Budget without decomposition.
        wmb = xwmb.WaterMassBudget(
            grid,
            budgets_dict,
            region
        ) #if region not passed, the whole globe is taken
        wmb.mass_budget("sigma2", greater_than=True, default_bins=False,
                        integrate=True, along_section=False)

        # Load the selected terms and re-attach the computed target coordinate.
        wmt = wmb.wmt[budget_terms].compute()
        wmt = wmt.assign_coords({"sigma2_i_target": wmb.wmt["sigma2_i_target"].compute()})

        # Same budget with the surface exchange flux decomposed.
        wmb_decomp = xwmb.WaterMassBudget(
            grid,
            budgets_dict,
            region,
            decompose=["surface_exchange_flux", "nonadvective", "advective"]
        ) #if region not passed, the whole globe is taken
        wmb_decomp.mass_budget("sigma2", greater_than=True, default_bins=False,
                               integrate=True, along_section=False)
        wmt_decomp = wmb_decomp.wmt[decomp_budget_terms].compute()
        wmt_decomp = wmt_decomp.assign_coords({"sigma2_i_target": wmb_decomp.wmt["sigma2_i_target"].compute()})

        # NOTE(review): the multiplication by 1 on the merged Dataset is kept
        # from the original; its purpose is not evident from this code
        # (possibly to force a new object) — confirm before removing.
        wmts.append(1 * xr.merge([wmt_decomp, wmt]))

In [None]:
# Directory the combined budget file is written to.
savedir = "/vortexfs1/home/anthony.meza/scratch/CM4XTransientTracers/data/model/"

# Join the per-file datasets along "time" and write the result as NetCDF.
wmts_ds = xr.concat(wmts, dim="time")
wmts_ds.to_netcdf(savedir + "Southern_Ocean_WMT_Budget.nc")