# Altimetry

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import tqdm
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Apply matplotlib's bundled "seaborn-v0_8-notebook" style sheet to every figure below.
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time period, as "YYYY-MM" strings (passed to download.update_request_date)
start = "1993-01"
stop = "2022-12"

# Region of interest; also selects the collection "satellite-sea-level-{region}".
region = "global"
# Explicit check instead of `assert`: asserts are stripped when Python runs with -O,
# which would let an unsupported region silently reach the download request.
if region not in ("black-sea", "global", "mediterranean"):
    raise ValueError(
        "region must be one of 'black-sea', 'global', 'mediterranean'; "
        f"got {region!r}"
    )

## Define Request

In [None]:
# CDS collection for the selected region (e.g. "satellite-sea-level-global").
collection_id = f"satellite-sea-level-{region}"

# Base request; update_request_date combines it with the start/stop period.
# Presumably it returns one request per sub-period — confirm in the library.
request = dict(version="vDT2021", variable="daily", format="zip")
requests = download.update_request_date(
    request,
    start=start,
    stop=stop,
    stringify_dates=True,
)

## Functions to cache

In [None]:
def rechunk(da, tmpdir):
    """Rewrite ``da`` through two Zarr stores in ``tmpdir`` and reopen it lazily.

    Stage 1 appends the time-sorted data to ``tmpdir/temporary.zarr`` in
    consecutive 5-day windows, with each window stored as a single time chunk.
    Stage 2 re-chunks that store to (time=5000, latitude=10, longitude=10) and
    writes it to ``tmpdir/target.zarr``. The two-stage approach presumably
    keeps each dask graph small compared with rechunking the source directly.

    The returned array is opened with ``chunks={}``, so it stays backed by the
    files in ``tmpdir``. ``tmpdir`` must outlive every use of the result.

    Parameters
    ----------
    da : xarray.DataArray
        Input with ``time``, ``latitude`` and ``longitude`` dimensions.
    tmpdir : str
        Writable directory; the two stores are created inside it.

    Returns
    -------
    xarray.DataArray
        Lazily loaded view of ``tmpdir/target.zarr``.
    """
    # Sorted copy with cleared encoding, so the source files' encoding
    # (e.g. chunking) is not carried into the Zarr writes.
    sorted_da = da.sortby("time")
    sorted_da.encoding = {}

    # Stage 1: the first window creates the store; later windows append along time.
    staging_store = f"{tmpdir}/temporary.zarr"
    windows = tqdm.tqdm(sorted_da.resample(time="5D"))
    for position, (_, window) in enumerate(windows):
        window.chunk(time=-1).to_zarr(
            staging_store, append_dim="time" if position else None
        )
    staged = xr.open_dataarray(staging_store, chunks={}, engine="zarr")
    # Drop the Zarr chunk encoding so the new dask chunks below are written as-is.
    staged.encoding = {}

    # Stage 2: long time chunks over small spatial tiles.
    final_store = f"{tmpdir}/target.zarr"
    staged.chunk(time=5_000, latitude=10, longitude=10).to_zarr(final_store)
    return xr.open_dataarray(final_store, chunks={}, engine="zarr")


def compute_extreme(da, *, q=0.95):
    """Return the time mean of the values above the ``q`` quantile, per point.

    The quantile threshold is computed independently for each non-time
    position. The values of ``da`` strictly greater than that threshold are
    then averaged over ``time``.

    Parameters
    ----------
    da : xarray.DataArray
        Input with a ``time`` dimension and ``units`` / ``long_name`` attrs.
    q : float, optional
        Quantile in [0, 1] used as the threshold. Defaults to 0.95, the value
        previously hard-coded, so existing calls behave exactly as before.

    Returns
    -------
    xarray.DataArray
        Named ``"extreme"``, without a ``time`` dimension. Its attrs are only
        ``units`` (copied from ``da``) and a "Mean extrema of ..." ``long_name``.

    Raises
    ------
    KeyError
        If ``da`` has no ``units`` or ``long_name`` attribute.
    """
    # A dask-backed quantile needs the whole ``time`` axis in one chunk.
    # Every other dimension keeps its current chunking.
    single_time_chunk = {
        dim: -1 if dim == "time" else chunksize
        for dim, chunksize in da.chunksizes.items()
    }
    threshold = (
        da.chunk(single_time_chunk)
        .quantile(q=q, dim="time")
        # Drop the scalar ``quantile`` coordinate so it does not reach the result.
        .drop_vars("quantile")
    )
    # Values not strictly above the threshold become NaN; mean() skips NaNs by default.
    da_extreme = da.where(da > threshold).mean("time")
    da_extreme.attrs = {
        "units": da.attrs["units"],
        "long_name": f"Mean extrema of {da.attrs['long_name']}",
    }
    return da_extreme.rename("extreme")


def compute_time_statistics(ds, reductions):
    """Reduce ``ds["sla"]`` over time with each requested reduction.

    ``ds["sla"]`` is first rewritten into a temporary directory by
    ``rechunk``. Each name in ``reductions`` is then handled as follows:

    * ``"extreme"`` calls ``compute_extreme``.
    * Any other name calls ``diagnostics.time_weighted_<name>(da, weights=False)``.
      The result keeps its own ``units`` and gets a "<Name> of <long_name>" label.

    Every result is renamed to its reduction name and computed eagerly. The
    lazy inputs live in the temporary directory, which is deleted when the
    ``with`` block exits.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain an ``sla`` variable with a ``long_name`` attribute.
    reductions : iterable of str
        Reduction names, e.g. ``("mean", "std", "extreme")``.

    Returns
    -------
    xarray.Dataset
        One variable per reduction, merged with ``xr.merge``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        print(f"{tmpdir=}")
        sla = rechunk(ds["sla"], tmpdir)
        computed = []
        for reduction in reductions:
            print(f"{reduction=}")
            if reduction != "extreme":
                reducer = getattr(diagnostics, f"time_weighted_{reduction}")
                reduced = reducer(sla, weights=False)
                label = reduction.capitalize().replace("_", " ")
                # Keep the units set by the diagnostic: they may differ from sla's
                # (e.g. for a trend).
                reduced.attrs = {
                    "units": reduced.attrs["units"],
                    "long_name": f"{label} of {sla.attrs['long_name']}",
                }
            else:
                reduced = compute_extreme(sla)
            # Compute now: the backing Zarr store disappears with tmpdir.
            computed.append(reduced.rename(reduction).compute())
    return xr.merge(computed)


def compute_spatial_weighted_statistics(ds):
    """Return spatially weighted statistics of ``ds["sla"]`` as a Dataset.

    ``diagnostics.spatial_weighted_statistics`` returns an array with a
    ``diagnostic`` dimension. Each entry along that dimension becomes a
    separate variable. Every variable is labelled
    "Spatial weighted <name> of <long_name>" and given the units of ``sla``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``sla`` with ``units`` and ``long_name`` attributes and
        ``latitude`` / ``longitude`` dimensions.

    Returns
    -------
    xarray.Dataset
        One variable per diagnostic.
    """
    # Put the whole latitude/longitude grid in a single chunk.
    sla = ds["sla"].chunk(latitude=-1, longitude=-1)
    statistics = diagnostics.spatial_weighted_statistics(sla)
    ds_diag = statistics.to_dataset(dim="diagnostic")
    for name in ds_diag.data_vars:
        # ds_diag[name] shares its Variable with ds_diag, so this updates the dataset.
        ds_diag[name].attrs = {
            "units": sla.attrs["units"],
            "long_name": f"Spatial weighted {name} of {sla.attrs['long_name']}",
        }
    return ds_diag

## Chunking and I/O keyword arguments

In [None]:
# Presumably tells download_and_transform to split the requests per year — confirm.
chunks = {"year": 1}

# Speedup IO: multi-file open options (presumably forwarded to xr.open_mfdataset).
# Files are concatenated along "time" in the given order. Only variables that
# already carry the concat dimension are concatenated, and other variables are
# taken from the first file without comparison.
xarray_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
)

# Keyword arguments shared by every download_and_transform call below.
download_and_transform_kwargs = {"chunks": chunks, **xarray_kwargs}

## Compute time reductions

In [None]:
# Per-grid-point time statistics of SLA (one variable per reduction).
# transform_chunks=False presumably makes download_and_transform apply
# compute_time_statistics once to the whole period rather than per chunk — confirm.
ds_maps = download.download_and_transform(
    collection_id,
    requests,
    transform_func=compute_time_statistics,
    transform_func_kwargs={
        "reductions": ("mean", "std", "coverage", "linear_trend", "extreme")
    },
    transform_chunks=False,
    **download_and_transform_kwargs,
)

# Convert the linear trend to mm/year; keep_attrs preserves long_name through the multiply.
# NOTE(review): the factor assumes the trend is in m/s and uses a 365-day year —
# confirm against diagnostics.time_weighted_linear_trend.
with xr.set_options(keep_attrs=True):
    ds_maps["linear_trend"] *= 1.0e3 * 60 * 60 * 24 * 365
ds_maps["linear_trend"].attrs["units"] = "mm/year"

## Compute spatially weighted reductions

In [None]:
# Spatially weighted statistics of SLA at each time step.
# The pieces returned by download_and_transform are not assumed to be in time
# order, so the result is sorted explicitly. (The meaning of
# cached_open_mfdataset_kwargs=True is defined by the library — not visible here.)
ds_timeseries = download.download_and_transform(
    collection_id,
    requests,
    transform_func=compute_spatial_weighted_statistics,
    cached_open_mfdataset_kwargs=True,
    **download_and_transform_kwargs,
)
ds_timeseries = ds_timeseries.sortby("time")

## Plot maps

In [None]:
# Centre the projection on the mean longitude of the data.
central_longitude = ds_maps["longitude"].mean().values
projection = (
    ccrs.Robinson(central_longitude=central_longitude)
    if region == "global"
    # Regional domains: Mercator clipped to the data's latitude range.
    else ccrs.Mercator(
        central_longitude=central_longitude,
        min_latitude=ds_maps["latitude"].min().values,
        max_latitude=ds_maps["latitude"].max().values,
    )
)

# One map per statistic. Only the signed fields get a colour scale centred on zero.
for var, da in ds_maps.data_vars.items():
    plot.projected_map(
        da,
        projection=projection,
        robust=True,
        center=None if var not in ("mean", "linear_trend") else 0,
    )
    plt.title(f"{da.attrs['long_name']} ({start}, {stop})")
    plt.show()

## Plot timeseries

In [None]:
# Rolling-mean window, in time steps (days for the daily product).
window = 365
# Raw daily series and its centred rolling mean; attrs are kept for the labels below.
ds_to_plot = {
    "Daily": ds_timeseries,
    f"{window}-day rolling mean": ds_timeseries.rolling(time=window, center=True).mean(
        keep_attrs=True
    ),
}

# One figure per series; each spatially weighted statistic is drawn as a separate line.
for title, ds in ds_to_plot.items():
    fig, ax = plt.subplots(1, 1)
    for var, da in ds.data_vars.items():
        da.plot(label=var, ax=ax)
    ax.grid()
    # NOTE(review): `da` here is the last variable from the inner loop. All
    # variables share units and the text after " of " in long_name, so any one works.
    # If ds had no data variables, this would silently reuse a stale `da`.
    ax.set_title("\n".join([da.attrs["long_name"].split(" of ")[-1], title]))
    ax.set_ylabel(f"[{da.attrs['units']}]")
    ax.legend(title="Spatial weighted", bbox_to_anchor=(1, 1))