# Eddy Kinetic Energy (EKE)

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils
from scipy import ndimage

# Apply one matplotlib style to every figure produced by this notebook.
plt.style.use(style="seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Analysis period, as "YYYY-MM" strings
start = "1993-01"
stop = "2022-12"

# Domain used for the download / EKE computation
lat_slice = slice(-65, 65)
lon_slice = slice(-180, 180)

# Product versions to compare
versions = ["vdt2021", "vdt2024"]

# Regional boxes as (lat_min, lat_max, lon_min, lon_max).
# NOTE(review): the keys look like current-system acronyms
# (e.g. KE = Kuroshio Extension, GS = Gulf Stream) -- confirm.
_region_bounds = {
    "GWSE": (-0.593297, 21.390730, 42.156397, 65.139224),
    "AC": (-49.705567, -10.668587, -3.992485, 88.354845),
    "KE": (27.214451, 44.797421, 126.707357, 178.966069),
    "GS": (24.655584, 52.309994, -79.479982, -25.198502),
    "LC": (17.631824, 32.110687, -98.185393, -81.340371),
    "BMC": (-53.048704, -32.907111, -59.082424, -27.689903),
    "EAC": (-43.201439, -21.725291, 143.651991, 167.912370),
}
regions = {
    name: {
        "lat_slice": slice(lat_min, lat_max),
        "lon_slice": slice(lon_min, lon_max),
    }
    for name, (lat_min, lat_max, lon_min, lon_max) in _region_bounds.items()
}

## Define request

In [None]:
# Catalogue entry and base request; version and dates are merged in per download.
collection_id = "satellite-sea-level-global"
request = dict(variable="daily")

## Define functions to cache

In [None]:
def compute_eke(ds, lon_slice, lat_slice):
    """Return monthly-mean eddy kinetic energy for a lon/lat box.

    EKE is 0.5 * (ugosa**2 + vgosa**2), scaled by 1e4. Assuming the velocity
    anomalies are in m s^-1, that scaling gives the cm^2 s^-2 stated in the
    ``units`` attribute.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``ugosa`` and ``vgosa`` and a ``time`` coordinate.
    lon_slice, lat_slice : slice
        Box passed to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        Single variable ``EKE`` resampled to month-start means.
    """
    box = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    squared_speed = box["ugosa"] ** 2 + box["vgosa"] ** 2
    energy = 0.5 * 1.0e4 * squared_speed
    monthly = energy.resample(time="1MS").mean()
    monthly.attrs = dict(
        long_name="Eddy Kinetic Energy",
        units="cm^2 s^{-2}",
        description="Eddy Kinetic Energy from geostrophic velocity anomalies (ugosa, vgosa)",
    )
    return monthly.to_dataset(name="EKE")


def smooth_regions(da, lim):
    """Clean up a 2-D binary mask.

    Holes are filled, connected regions with ``lim`` cells or fewer are
    dropped, and a binary closing smooths what remains.

    Parameters
    ----------
    da : xarray.DataArray
        2-D 0/1 (or boolean) mask.
    lim : int
        Connected regions whose cell count is not greater than ``lim``
        are removed.

    Returns
    -------
    xarray.DataArray
        Integer 0/1 mask with the same coordinates and dims as ``da``.
    """
    # 3x3 all-True structuring element (rank 2, connectivity 2).
    struct = ndimage.generate_binary_structure(2, 2)
    filled = ndimage.binary_fill_holes(da, structure=struct)
    # ndimage.label is called with its default structure here.
    labels, num_features = ndimage.label(filled)
    sizes = ndimage.sum(filled, labels, range(num_features + 1))
    # Label 0 is the background; its size is 0, so it is never kept for lim >= 0.
    kept = (sizes > lim)[labels]
    # binary_closing erodes with border_value=0, i.e. cells outside the array
    # count as False, which would strip every True cell on the array edge.
    # Pad by the structure's radius (1) so the closing behaves as if the mask
    # simply continued with False outside, then crop back to the original shape.
    padded = np.pad(kept, 1, mode="constant", constant_values=False)
    closed = ndimage.binary_closing(padded, structure=struct)[1:-1, 1:-1]
    return xr.DataArray(closed.astype(int), coords=da.coords, dims=da.dims)

## Download and transform

In [None]:
def _download_version(version):
    """Download one product version as monthly EKE with a ``version`` dim."""
    print(f"{version=}")
    version_requests = download.update_request_date(
        request | {"version": version}, start, stop, stringify_dates=True
    )
    version_ds = download.download_and_transform(
        collection_id,
        version_requests,
        chunks={"year": 1, "month": 1},
        transform_func=compute_eke,
        transform_func_kwargs={"lon_slice": lon_slice, "lat_slice": lat_slice},
    )
    return version_ds.expand_dims(version=[version])


datasets = [_download_version(version) for version in versions]
ds = xr.concat(datasets, "version")

## Compute time-mean EKE

In [None]:
# Time-mean EKE of the first entry in ``versions``; this field is the
# reference used below to build the high-EKE masks.
eke = ds["EKE"].isel(version=0).mean("time", keep_attrs=True).compute()
plot.projected_map(eke, robust=True)

## Compute high-EKE region masks

In [None]:
# Cells above the domain-wide 90th percentile of mean EKE are flagged as 1.
threshold = eke.quantile(0.9, dim=["latitude", "longitude"]).drop_vars("quantile")
global_mask = xr.where(eke > threshold, 1, 0)


def _regional_mask(label, slices):
    """Cut ``global_mask`` to one region box and clean it with smooth_regions."""
    # The LC box is the smallest region, so a lower minimum blob size is used.
    min_size = 100 if label == "LC" else 1_000
    cleaned = smooth_regions(utils.regionalise(global_mask, **slices), min_size)
    return cleaned.expand_dims(region=[label])


masks = xr.concat(
    [_regional_mask(label, slices) for label, slices in regions.items()],
    "region",
    join="outer",
)
# Union of all regional masks, plus the cells where mean EKE is defined.
union_mask = masks.max("region").expand_dims(region=["ALL"])
valid_mask = eke.notnull().expand_dims(region=["no ice"])
masks = xr.concat(
    [masks, union_mask, valid_mask], "region", join="outer"
).fillna(0)
plot.projected_map(masks, col="region", col_wrap=3)

## Compute time series

In [None]:
# ``masks`` stores 0/1 floats; DataArray.where documents a boolean ``cond``,
# so cast explicitly instead of relying on numpy's truthiness of floats.
masked_eke = ds["EKE"].where(masks.astype(bool))
# Spatially weighted mean inside each region mask, one line per version.
eke_timeseries = diagnostics.spatial_weighted_mean(masked_eke)
facet = eke_timeseries.plot(col="region", hue="version", col_wrap=3)
for ax in facet.axs.flat:
    ax.grid()