In [150]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [151]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import yaml
from xgcm import Grid

In [152]:
import xwmt
import xwmb

### Load grids and data

In [153]:
# Experiment selection (edit here):
#   gridname -- vertical grid of the diagnostics: 'natv', 'rho2' or 'zstr'
#   dt       -- output frequency of the diagnostics: 'monthly', 'daily' or 'hourly'
gridname, dt = 'natv', 'hourly'

In [None]:
# Load the MOM6 diagnostics on the `gridname` vertical grid at `dt` frequency:
# time averages, snapshots (for the mass-tendency term) and the static grid.
sim = "wmt_incsurffluxes.natv_rho2_zstr.monthly_daily_hourly.13months"
rootdir = f"/archive/Graeme.Macgilchrist/MOM6-examples/ice_ocean_SIS2/Baltic_OM4_025/{sim}/Baltic_OM4_025/"
prefix = '19000101.ocean_'+dt+'_'
time = "190*"  # glob over file date stamps; e.g. "1900_02_01" selects a single file

# Time-averaged diagnostics are split across several files.
suffixs = ['surf', 'thck', 'heat', 'salt', 'xtra']
Zprefixes = {'rho2': 'rho2_', 'zstr': 'z_', 'natv': 'z'}
Zprefix = Zprefixes[gridname]  # vertical dims are f"{Zprefix}l" (centres) / f"{Zprefix}i" (interfaces)


def _diag_filename(suffix):
    """Return the file-name glob for diagnostic group `suffix`.

    The 'surf' files carry no vertical-grid tag in their name; every other
    group is tagged with `gridname`.
    """
    grid_tag = "" if suffix == "surf" else f"{gridname}_"
    return f"{prefix}{grid_tag}{suffix}_{time}.nc"


# Merge all groups in a single call instead of re-merging a growing dataset in a loop.
ds = xr.merge([xr.open_mfdataset(rootdir + _diag_filename(s)) for s in suffixs])

# Load snapshot data (for mass tendency term)
snap = xr.open_mfdataset(rootdir + _diag_filename('snap'))

# Align N+1 snapshots so they bound N averages, and select year-long subset
analysis_period = slice('1900-02-01 00', '1901-02-01 00')
ds = ds.sel(time=slice(snap.time[0], snap.time[-1]))
ds = ds.sel(time=analysis_period)
snap = snap.sel(time=analysis_period)
# The mass-tendency term needs every averaging interval bounded by two snapshots;
# fail loudly here rather than produce a silently misaligned budget.
assert snap.sizes['time'] == ds.sizes['time'] + 1, (
    f"expected N+1 snapshots bounding N averages, got {snap.sizes['time']} "
    f"snapshots and {ds.sizes['time']} averages"
)

#  Load grid
oceangridname = '19000101.ocean_static.nc'
ocean_grid = xr.open_dataset(rootdir+oceangridname).squeeze()

# Some renaming to match hdrake conventions
ocean_grid = ocean_grid.rename({'depth_ocean': 'deptho'})
ds = ds.rename({'temp': 'thetao'})
snap = snap.rename({'temp': 'thetao'})

# Merge snapshots with time-averages: snapshot time becomes `time_bounds` and
# each snapshot variable gets a `_bounds` suffix so it cannot clash with its average.
snap = snap.rename({'time': 'time_bounds', **{v: f"{v}_bounds" for v in snap.data_vars}})
ds = xr.merge([ds, snap])

# Add core coordinates of ocean_grid to ds. Taking `.values` keeps only the bare
# arrays, which are re-labelled here with ds's dimension names.
static_coord_dims = {
    'wet': ('yh', 'xh'),
    'areacello': ('yh', 'xh'),
    'xq': ('xq',),
    'yq': ('yq',),
    'geolon': ('yh', 'xh'),
    'geolat': ('yh', 'xh'),
    'geolon_u': ('yh', 'xq'),
    'geolat_u': ('yh', 'xq'),
    'geolon_v': ('yq', 'xh'),
    'geolat_v': ('yq', 'xh'),
    'geolon_c': ('yq', 'xq'),
    'geolat_c': ('yq', 'xq'),
    'dxCv': ('yq', 'xh'),
    'dyCu': ('yh', 'xq'),
    'deptho': ('yh', 'xh'),
}
ds = ds.assign_coords({
    name: xr.DataArray(ocean_grid[name].values, dims=dims)
    for name, dims in static_coord_dims.items()
})

In [None]:
# Tracer whose water-mass budget is computed.
#   lam         -- budget name; also used in the output filename, so it must match the
#                  tracer passed to wmb.mass_budget (TODO: confirm the names xwmb accepts)
#   lambda_name -- the corresponding variable in `ds` ('temp' was renamed to 'thetao')
lam = 'heat'
lambda_name = 'thetao'

# Lambda target grid: bin centres (`_l`) from lam_min (inclusive) to lam_max
# (exclusive) spaced dlam, and bin edges (`_i`) half a bin either side.
# The edges are derived from the centres so there is always exactly one more
# edge than centre, as xgcm's 'outer' position requires. Two independent
# np.arange calls can disagree in length when dlam is not exactly representable.
dlam = 0.25
lam_min, lam_max = -2, 30
lam_centers = np.arange(lam_min, lam_max, dlam)
lam_edges = np.append(lam_centers - dlam/2, lam_centers[-1] + dlam/2)
ds = ds.assign_coords({
    f"{lambda_name}_l": lam_centers,
    f"{lambda_name}_i": lam_edges,
})

# xgcm axes: horizontal grid (tracer 'center' points xh/yh, 'right' face points
# xq/yq), the model vertical grid, and the lambda target grid defined above.
coords = {
    'X': {'center': 'xh', 'right': 'xq'},
    'Y': {'center': 'yh', 'right': 'yq'},
    'Z': {'center': f'{Zprefix}l', 'outer': f'{Zprefix}i'},
    'lam': {'center': f'{lambda_name}_l', 'outer': f'{lambda_name}_i'}
}
metrics = {
    ('X','Y'): "areacello",
}
# NOTE(review): periodic=None is kept as in the original; how None is interpreted
# varies across xgcm versions — `periodic=False` may state the intent more clearly.
grid = Grid(ds, coords=coords, metrics=metrics, periodic=None)

In [None]:
# Variable-name conventions for MOM6 budget terms, passed to xwmb.WaterMassBudget below.
# YAML errors are deliberately not caught: catching and printing them left
# `budgets_dict` undefined and surfaced later as a confusing NameError.
with open("../conventions/MOM6.yaml", "r") as stream:
    budgets_dict = yaml.safe_load(stream)

In [None]:
# Region over which the budget is computed. Two forms were set up here:
#   * a polygon given by (lons, lats) vertex arrays inside the Baltic domain
#     (assumes xwmb accepts this form — TODO confirm), or
#   * a mask on the tracer grid; a mask of ones selects the whole model domain.
# Previously the polygon was built and then silently overwritten by the mask.
# The choice is now explicit, and the default (whole domain) is unchanged.
use_polygon_region = False

if use_polygon_region:
    lons = np.array([15.,   20.,  29., 24.5, 24.5, 26.1, 17.5, 11.5])
    lats = np.array([53.5, 53.5, 54.5,  59.,  61.,  63., 64.5,  62.])
    region = (lons, lats)
else:
    # `ds` is the dataset the xgcm grid was built from; using it avoids
    # reaching into the private `grid._ds` attribute.
    region = xr.ones_like(ds['deptho'])

wmb = xwmb.WaterMassBudget(
    grid,
    budgets_dict,
    region
)

In [None]:
# Compute the water-mass budget for tracer `lam`, load it into memory and write it
# to netCDF. Using `lam` for both the computation and the filename keeps the file
# label consistent with the budget actually computed.
# FutureWarnings raised during the computation are suppressed for this call only.
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    wmb.mass_budget(lam)

# Load each variable before writing (presumably dask-backed, since the inputs
# came from open_mfdataset).
for v in wmb.wmt:
    wmb.wmt[v].load()

budget_outdir = "/work/hfd/codedev/xwmb/data"  # output location; edit for your environment
wmb.wmt.to_netcdf(f"{budget_outdir}/budget_{lam}_{dt}_{gridname}.nc", mode="w")