# Altimetry

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Apply the "seaborn-v0_8-notebook" matplotlib style globally to every figure drawn below.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Analysis period as "YYYY-MM" strings: passed to the request builder and
# echoed in the map titles.
start, stop = "1993-01", "2019-12"

## Define request

In [None]:
# CDS catalogue entry to download from.
collection_id = "satellite-sea-level-global"

# Base request; the date fields are filled in below.
request = dict(version="vDT2021", variable="daily", format="zip")

# Expand the base request over the [start, stop] period. The result is named
# ``requests`` -- presumably one request per sub-period; TODO confirm against
# ``download.update_request_date``.
requests = download.update_request_date(
    request,
    start=start,
    stop=stop,
    stringify_dates=True,
)

## Functions to cache

In [None]:
def compute_extreme(da, q=0.95):
    """Return the time-mean of the values of ``da`` above its ``q`` quantile.

    Parameters
    ----------
    da : xarray.DataArray
        Input with a ``time`` dimension and ``long_name``/``units`` attrs.
    q : float, optional
        Quantile (between 0 and 1) used as the per-location threshold along
        ``time``. Defaults to 0.95, the value this function always used.

    Returns
    -------
    xarray.DataArray
        Array named ``"extreme"`` with ``time`` reduced; its ``long_name`` is
        derived from ``da`` and its ``units`` are copied from ``da``.
    """
    # Put the whole time axis in a single chunk so the quantile can be
    # computed along it.
    threshold = (
        da.chunk({"time": -1}).quantile(q=q, dim="time").drop_vars("quantile")
    )
    # Values not strictly above the threshold become NaN and are skipped by
    # the mean.
    da_extreme = da.where(da > threshold).mean("time")
    # The reduction dropped the attrs, so rebuild the ones used for plotting.
    da_extreme.attrs["long_name"] = f"Mean extreme values of {da.attrs['long_name']}"
    da_extreme.attrs["units"] = da.attrs["units"]
    return da_extreme.rename("extreme")


def compute_time_statistics(ds, reduction):
    """Reduce the ``sla`` variable of ``ds`` along time.

    ``reduction == "extreme"`` delegates to ``compute_extreme``. Any other
    name ``X`` dispatches to ``diagnostics.time_weighted_X`` (called with
    ``weights=False``). In that case the attrs are replaced by the reduced
    array's ``units`` and a generated ``long_name``.

    Returns a Dataset holding a single variable named ``reduction``.
    """
    sla = ds["sla"]
    if reduction != "extreme":
        reducer = getattr(diagnostics, f"time_weighted_{reduction}")
        reduced = reducer(sla, weights=False)
        label = reduction.capitalize().replace("_", " ")
        reduced.attrs = {
            "units": reduced.attrs["units"],
            "long_name": f"{label} of {sla.attrs['long_name']}",
        }
    else:
        reduced = compute_extreme(sla)
    return reduced.to_dataset(name=reduction)


def compute_spatial_weighted_statistics(ds):
    """Compute spatially weighted statistics of ``ds["sla"]``.

    Each entry of the ``diagnostic`` dimension returned by
    ``diagnostics.spatial_weighted_statistics`` becomes its own variable.
    That variable's attrs are replaced by the input's ``units`` and a
    generated ``long_name``.
    """
    sla = ds["sla"]
    stats = diagnostics.spatial_weighted_statistics(sla).to_dataset(
        dim="diagnostic"
    )
    for name in list(stats.data_vars):
        # Overwrite the attrs directly on the dataset's own Variable.
        stats.variables[name].attrs = {
            "units": sla.attrs["units"],
            "long_name": f"Spatial weighted {name} of {sla.attrs['long_name']}",
        }
    return stats

In [None]:
# Granularity passed as ``chunks`` to ``download.download_and_transform``:
# one year per piece.
chunks = {"year": 1}

# Options for combining the downloaded files (presumably forwarded to
# ``xarray.open_mfdataset``): concatenate along "time" in the given order and
# skip variable/coordinate comparisons to speed up IO.
xarray_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

download_and_transform_kwargs = {"chunks": chunks, **xarray_kwargs}

## Compute time reductions

In [None]:
# Compute every per-pixel time reduction separately, then merge the resulting
# maps into one dataset.
datasets = []
for reduction in ("mean", "std", "coverage", "extreme", "linear_trend"):
    # TODO: Work in progress
    # Do not compute extreme and linear trend, it crashes on the VM.
    if reduction not in ("extreme", "linear_trend"):
        print(f"Computing {reduction=}")
        # transform_chunks=False: presumably so the statistic is computed over
        # the whole period rather than per chunk -- confirm in the library.
        ds = download.download_and_transform(
            collection_id,
            requests,
            transform_func=compute_time_statistics,
            transform_func_kwargs={"reduction": reduction},
            transform_chunks=False,
            **download_and_transform_kwargs,
        )
        datasets.append(ds)
ds_maps = xr.merge(datasets)

## Compute spatially weighted reductions

In [None]:
# Spatially weighted statistics of SLA at every time step.
ds_timeseries = download.download_and_transform(
    collection_id,
    requests,
    transform_func=compute_spatial_weighted_statistics,
    cached_open_mfdataset_kwargs=True,
    **download_and_transform_kwargs,
)
# Ensure chronological order. Presumably the per-chunk outputs are not
# guaranteed to come back sorted -- TODO confirm.
ds_timeseries = ds_timeseries.sortby("time")

## Plot maps

In [None]:
# One Robinson-projected map per time reduction. Only the "mean" map uses a
# colour scale centred on zero.
for var, da in ds_maps.data_vars.items():
    plot.projected_map(
        da,
        projection=ccrs.Robinson(),
        robust=True,
        center=0 if var == "mean" else None,
    )
    plt.title(f"{da.attrs['long_name']} ({start}, {stop})")
    plt.show()

## Plot timeseries

In [None]:
# Rolling-window length in time steps (days, given the "daily" request).
window = 365
ds_to_plot = {
    "Daily": ds_timeseries,
    f"{window}-day rolling mean": ds_timeseries.rolling(time=window).mean(
        keep_attrs=True
    ),
}

for title, ds in ds_to_plot.items():
    fig, ax = plt.subplots(1, 1)
    for var, da in ds.data_vars.items():
        da.plot(label=var, ax=ax)
    # All spatial diagnostics share the same units and the same base
    # long_name ("Spatial weighted <stat> of <name>"). Label the figure from
    # the first variable explicitly instead of the leaked loop variable.
    attrs = next(iter(ds.data_vars.values())).attrs
    # Split on the first " of " only, so a base name that itself contains
    # " of " is kept intact.
    base_name = attrs["long_name"].split(" of ", 1)[-1]
    ax.grid()
    ax.set_title("\n".join([base_name, title]))
    ax.set_ylabel(f"[{attrs['units']}]")
    ax.legend(title="Spatial weighted")
    plt.show()