In [None]:
import numpy as np
import xarray as xr
import xgcm
import eddytools as et

In [None]:
# `path`: root directory of the model output (the zarr store is opened from here).
# `eddypath`: directory the interpolated, zonally extended NetCDF files are written to.
path, eddypath = "/path/to/model/output/", "/path/to/tracked/eddies/"

### Interpolation of fields for eddy detection
We interpolate some variables and coordinates to the F-points of the grid. Having all fields on the same grid points simplifies the detection and tracking. The F-point is chosen because vorticity and the Okubo-Weiss parameter naturally fall on this grid point.

First, we open the dataset and define a grid for the interpolation.

In [None]:
# Lazily open the 5-day diagnostics zarr store below the model-output root.
ds = xr.open_zarr(f"{path}zarr_Diags/output.5d.zarr/")

In [None]:
# Grid metrics handed to xgcm, keyed by axis (or axis pair):
# distances along X, Y and Z, and horizontal cell areas in the X-Y plane.
# Note: single-axis keys are plain strings; ('X') in a dict literal is just 'X'.
metrics = dict([
    ('X', ['dxC', 'dxG', 'dxF', 'dxV']),
    ('Y', ['dyC', 'dyG', 'dyF', 'dyU']),
    ('Z', ['drF', 'drW', 'drS', 'drC']),
    (('X', 'Y'), ['rAw', 'rAs', 'rA', 'rAz']),
])

# The domain is treated as periodic in X only.
grid = xgcm.Grid(ds, metrics=metrics, periodic=["X"])

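We need some additional masks at the F-points, where they are not natively defined. Both are derived from `maskS` by interpolating in x: `maskF` is 1 wherever at least one of the two x-neighbours is wet, while `maskZ` is 1 only where both are wet.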

In [None]:
# F-point masks derived from maskS by interpolating along X (assuming maskS is 0/1):
# - maskF: 1 wherever the plain X-interpolation is non-zero (at least one wet neighbour),
# - maskZ: 1 only where the area-weighted X-interpolation is exactly 1 (both neighbours wet).
maskS_at_f = grid.interp(ds['maskS'], "X")
ds['maskF'] = maskS_at_f.where(maskS_at_f == 0, other=1)
maskS_at_z = grid.interp(ds['maskS'], "X", to="left", metric_weighted=["X", "Y"])
ds['maskZ'] = maskS_at_z.where(maskS_at_z == 1, other=0)

# Keep the masks as coordinates, then rechunk (same chunk sizes for C- and G-points).
ds = ds.set_coords(["maskZ", "maskF"])
ds = ds.chunk({'XC': 240, 'XG': 240, 'YC': 320, 'YG': 320})

We compute the Okubo-Weiss parameter and add it to the dataset.

In [None]:
# Okubo-Weiss parameter and vorticity from the velocity fields; both are rechunked
# on (YG, XG) with the same chunk sizes used for `ds` before merging them in.
data_OW = et.okuboweiss.calc(ds, grid, 'UVEL', 'VVEL')
fpoint_chunks = {"YG": 320, "XG": 240}
data_OW = data_OW.assign(
    OW=data_OW["OW"].chunk(fpoint_chunks),
    vort=data_OW["vort"].chunk(fpoint_chunks),
)
ds_mod = xr.merge([ds, data_OW], compat='override')

To track eddies across the periodic boundary in the x-direction later on, we apply a small workaround: we extend the domain in x by appending the westernmost 800 km to its eastern end.

In [None]:
# Width (km) of the western strip that is copied onto the eastern edge of the domain.
extension = 800
# Convert to a number of grid points using dxC (m -> km) at the first grid point.
# NOTE(review): `extend` later indexes fields interpolated at 'res' = 10 km; this
# presumes dxC equals that interpolation resolution — confirm.
dx_km = ds["dxC"][0, 0].values / 1000.
extend = int(extension / dx_km)

Define which variables to interpolate and which to exclude from interpolation.

In [None]:
# Fields interpolated to the common grid.
variables = ['OW', 'vort']
# Names excluded from interpolation (not referenced in the cells visible here).
exclude = ["XC", "XG", "iter", "layer_center"]
# Everything written to the extended output: the interpolated fields plus the
# OW standard deviation computed per year.
all_variables = [*variables, "OW_std"]

We loop over every year. For each year, we first interpolate the data and then artificially extend the domain as described above.

In [None]:
# Index along z at which the spatial standard deviation of OW is evaluated.
ow_std_level = 15


def extend_lon_coordinate(lon, n_extend):
    """Return ``lon`` with its first ``n_extend`` values appended, shifted by one zonal period.

    The zonal period is the span of the coordinate plus one grid spacing,
    ``lon[-1] - lon[0] + (lon[1] - lon[0])``. For a coordinate starting at 0
    this reduces to ``lon[-1] + lon[1]``, but unlike that shortcut it also
    stays correct when ``lon[0] != 0``.
    """
    lon_values = np.asarray(lon)
    period = lon_values[-1] - lon_values[0] + (lon_values[1] - lon_values[0])
    return np.concatenate([lon_values, lon_values[:n_extend] + period])


def extend_variable(da, n_extend):
    """Return ``da`` as a bare DataArray with its first ``n_extend`` lon columns appended.

    Dimensions are first put in (time, z, lat, lon) order (keeping only those
    present), so the returned dimension labels always match the data axes.
    Coordinates and attributes are not carried over. Variables without a
    ``lon`` dimension are returned unextended.
    """
    dims = [dim for dim in ("time", "z", "lat", "lon") if dim in da.dims]
    da = da.transpose(*dims)
    data = da.data
    if "lon" in dims:
        # Concatenating the underlying arrays (not DataArrays) lets np.concatenate
        # dispatch to dask when the data are dask-backed instead of loading them.
        data = np.concatenate((data, da.isel(lon=slice(0, n_extend)).data),
                              axis=dims.index("lon"))
    return xr.DataArray(data, dims=dims)


for yy in range(201, 301):
    year = f"{yy:04}"
    print(year)
    time_range_start = year + '-01-01'
    # With the 360_day calendar every month, including December, has 30 days.
    time_range_end = year + '-12-30'
    # interpolation parameters for eddytools
    interpolation_parameters = {'model': 'MITgcm',
                                'grid': 'cartesian',
                                'start_time': time_range_start,  # time range start
                                'end_time': time_range_end,  # time range end
                                'calendar': '360_day',  # calendar, must be either 360_day or standard
                                'lon1': 0.0e6,  # minimum longitude of detection region
                                'lon2': 2.4e6,  # maximum longitude
                                'lat1': 0.0e6,  # minimum latitude
                                'lat2': 3.2e6,  # maximum latitude
                                'res': 10.,  # resolution of the fields in km
                                'vars_to_interpolate': variables,
                                # masks to interpolate
                                'mask_to_interpolate': ['maskZ', 'maskC', 'maskW', 'maskS', 'Depth'],
                                'vars_to_filter': [],  # variables to apply spatial filter to
                                'cut_lon': 1500,
                                'cut_lat': 1500
                                }
    # interpolate data
    data_int, _ = et.interp.horizontal(ds_mod, metrics, interpolation_parameters)

    # Spatial standard deviation of the Okubo-Weiss parameter at level `ow_std_level`
    # (zeros treated as missing), averaged over time and added to the dataset.
    print("Computing spatial standard deviation of OW.")
    OW_level = data_int['OW'].isel(z=ow_std_level)
    OW_level = OW_level.where(OW_level != 0).persist()
    # Window: 40 points in lat; 2*N-1 points in lon, centred with min_periods=1,
    # so at every point the lon window covers the entire zonal row.
    mean_OW_spatial_std = OW_level.rolling(
        lat=40, lon=(len(OW_level.lon) * 2) - 1, center=True, min_periods=1
    ).std(skipna=True).mean('time')
    data_int = data_int.assign(
        OW_std=(['lat', 'lon'], mean_OW_spatial_std.transpose('lat', 'lon').values)
    ).chunk({'lon': 240, 'lat': 320})

    # Extend the domain in x-direction: append the first `extend` lon columns
    # (shifted by one zonal period) to the eastern end.
    print("Extending variables:")
    ds_extended = xr.Dataset({"time": (["time"], data_int.time.data),
                              "z": (["z"], data_int.z.data),
                              "lat": (["lat"], data_int.lat.data),
                              "lon": (["lon"], extend_lon_coordinate(data_int["lon"], extend))})
    for var in all_variables:
        if var not in ds_extended.coords:
            print("-", var)
            ds_extended[var] = extend_variable(data_int[var], extend)
        if var in variables:
            ds_extended[var].attrs = dict(ds_mod[var].attrs)
    # Copy (rather than share) the attribute dicts of the source coordinates.
    ds_extended["time"].attrs = dict(ds["time"].attrs)
    ds_extended["z"].attrs = dict(ds["Z"].attrs)
    ds_extended["lat"].attrs = dict(ds["YG"].attrs)
    ds_extended["lon"].attrs = dict(ds["XG"].attrs)
    ds_extended["layer_center"] = xr.DataArray(ds["layer_center"].data, dims=["z"])
    ds_extended["layer_center"].attrs = dict(ds["layer_center"].attrs)

    trs = time_range_start.replace('-', '')
    tre = time_range_end.replace('-', '')
    # Derive the tag from `extension` so the filename cannot drift from the actual extension.
    outfile = f"{eddypath}interp_data.extend{extension}.{trs}_{tre}.nc"
    print("Saving to disk:", outfile)
    ds_extended.sel(time=slice(time_range_start, time_range_end)).to_netcdf(outfile, mode='w')