# Sea surface temperature (SST) trends

## Import packages

In [None]:
import tempfile

import cacholote
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import tqdm
import xarray as xr
from c3s_eqc_automatic_quality_control import download, plot
from xarrayMannKendall import Mann_Kendall_test

# Apply matplotlib's notebook-sized seaborn style sheet to every figure below.
plt.style.use(style="seaborn-v0_8-notebook")

## Define time period

In [None]:
# Analysis period, inclusive on both ends.
# Note: each year runs from December of the previous year to November, so the
# full period spans December (year_start - 1) to November year_stop.
year_start, year_stop = 1982, 2011

## Define requests and I/O parameters

In [None]:
# Per-product CDS requests. "chunks" is handed to the downloader and presumably
# controls how each request is split (by year/month, and for GMPE also by day).
request_dicts = dict(
    ESACCI=dict(
        collection_id="satellite-sea-surface-temperature",
        request=dict(
            processinglevel="level_4",
            format="zip",
            variable="all",
            sensor_on_satellite="combined_product",
            version="2_1",
        ),
        chunks=dict(year=1, month=1),
    ),
    GMPE=dict(
        collection_id="satellite-sea-surface-temperature-ensemble-product",
        request=dict(format="zip", variable="all"),
        chunks=dict(year=1, month=1, day=12),  # CDS limit is 12
    ),
)

# Parameters to speed up I/O: files are stacked along "time" and only minimal
# variables/coordinates are compared across files.
open_mfdataset_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

## Functions to cache

In [None]:
def get_masked_sst(ds):
    """Return ``analysed_sst`` converted to degrees Celsius, optionally masked.

    If ``ds`` contains a ``mask`` variable, points where ``mask != 1`` become
    missing. The input dataset is not modified.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with an ``analysed_sst`` variable (assumed to be in Kelvin, since
        273.15 is subtracted) and, optionally, a ``mask`` variable.

    Returns
    -------
    xr.DataArray
        SST in °C, keeping the original attributes with ``units`` set to ``"°C"``.
    """
    da = ds["analysed_sst"]
    if "mask" in ds:
        da = da.where(ds["mask"] == 1)
    with xr.set_options(keep_attrs=True):
        # Out-of-place on purpose: ``ds["analysed_sst"]`` shares its underlying
        # variable with ``ds``, so ``-=`` would rewrite the caller's dataset
        # whenever there is no mask to force a copy.
        da = da - 273.15
    # assign_attrs returns a new object, so the caller's attrs dict stays intact.
    return da.assign_attrs(units="°C")


def add_chunksizes(da):
    """Store the array's chunk shape as its NetCDF ``chunksizes`` encoding.

    The largest chunk along each dimension is recorded. ``da`` is modified in
    place and also returned, so the call can be chained.
    """
    largest = tuple(max(sizes) for sizes in da.chunks)
    da.encoding["chunksizes"] = largest
    return da


def rechunk(obj):
    """Use NetCDF chunks.

    Rechunk ``obj`` (DataArray or Dataset) with fixed sizes for whichever of
    ``time``/``year``/``season``/``latitude``/``longitude`` it has, then record
    those chunks in the ``chunksizes`` encoding of the array, or of every data
    variable for a Dataset.
    """
    netcdf_chunks = {
        "time": 1,
        "year": 1,
        "season": 1,
        "latitude": 1_200,
        "longitude": 2_400,
    }
    present = {name: size for name, size in netcdf_chunks.items() if name in obj.dims}
    obj = obj.chunk(**present)
    if not isinstance(obj, xr.DataArray):
        for variable in obj.data_vars.values():
            add_chunksizes(variable)
        return obj
    return add_chunksizes(obj)


def compute_low_resolution(ds, freq):
    """Mask SST, convert it to °C, coarsen it spatially and average it in time.

    Spatial dimensions finer than a 720 x 1440 (latitude x longitude) grid are
    block-averaged by the integer factor ``size // target``. The result is then
    sorted by time and resampled with ``freq``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset accepted by ``get_masked_sst``, with ``time``, ``latitude`` and
        ``longitude`` coordinates.
    freq : str
        Resampling frequency passed to ``resample(time=...)``.

    Returns
    -------
    xr.Dataset
        Single-variable dataset with the input's coordinate attributes restored.
    """
    ds = rechunk(ds)
    sst = get_masked_sst(ds)
    # Integer coarsening factor per spatial dim; factors <= 1 are skipped.
    factors = {}
    for dim, target in {"latitude": 720, "longitude": 1440}.items():
        factor = sst.sizes[dim] // target
        if factor > 1:
            factors[dim] = factor
    with xr.set_options(keep_attrs=True):
        if factors:
            sst = sst.coarsen(factors).mean()
        sst = sst.sortby("time")
        sst = sst.resample(time=freq).mean()
    # Copy coordinate metadata back from the input (coarsen/resample may drop it).
    for name in sst.coords:
        sst[name].attrs = ds[name].attrs
    return rechunk(sst).to_dataset()


def _mann_kendall(da, **kwargs):
    """Run a per-pixel Mann-Kendall trend test along ``time``.

    ``time`` is put in a single chunk, the array is written to a temporary Zarr
    store and re-opened lazily, and ``Mann_Kendall_test`` is computed on it.

    Parameters
    ----------
    da : xr.DataArray
        Array with ``time``, ``latitude`` and ``longitude`` dimensions (the names
        mapped in ``coords_name``).
    **kwargs
        Forwarded to ``Mann_Kendall_test`` (e.g. ``alpha``, ``method``).

    Returns
    -------
    xr.Dataset
        In-memory test result with ``x``/``y`` renamed back to
        ``longitude``/``latitude`` and coordinate attributes copied from ``da``.
    """
    da = da.chunk({dim: -1 if dim == "time" else "auto" for dim in da.dims})
    coords_name = {"time": "time", "x": "longitude", "y": "latitude"}
    with tempfile.TemporaryDirectory() as tmpdir:
        # Presumably materialized to disk so the test reads stored chunks instead
        # of re-running the upstream dask graph.
        da.to_zarr(tmpdir)
        da = xr.open_dataarray(tmpdir, engine="zarr", chunks=dict(da.chunksizes))
        ds_trend = Mann_Kendall_test(da, coords_name=coords_name, **kwargs).compute()
        ds_trend = ds_trend.rename(
            {k: v for k, v in coords_name.items() if k != "time"}
        )
        # Read metadata from the Zarr-backed array while its store still exists;
        # the directory is deleted as soon as this ``with`` block exits.
        for coord in ds_trend.coords:
            ds_trend[coord].attrs = da[coord].attrs
        long_name = da.attrs.get("long_name", da.name)
    ds_trend["trend"].attrs["long_name"] = f"Trend of {long_name}"
    return ds_trend


@cacholote.cacheable
def compute_mann_kendall_trend(
    collection_id,
    request,
    chunks,
    year_start,
    year_stop,
    seasonal,
    open_mfdataset_kwargs,
    **mann_kendall_kwargs,
):
    """Download SST year by year and compute per-pixel Mann-Kendall trends.

    Each year ``y`` is downloaded from December ``y - 1`` to November ``y``,
    reduced with ``compute_low_resolution`` to seasonal (DJF/MAM/JJA/SON) or
    monthly means, and the concatenated series is grouped and passed to
    ``_mann_kendall``.

    Parameters
    ----------
    collection_id, request, chunks
        CDS collection id, base request and request chunking, forwarded to
        ``download.download_and_transform``.
    year_start, year_stop : int
        First and last December-November year (inclusive).
    seasonal : bool
        If True, one trend per season across years (``season`` dimension).
        If False, one trend per December-November year (``year`` dimension).
    open_mfdataset_kwargs : dict
        Extra keyword arguments forwarded to ``download.download_and_transform``.
    **mann_kendall_kwargs
        Forwarded to ``Mann_Kendall_test`` (e.g. ``alpha``, ``method``).

    Returns
    -------
    xr.Dataset
        Trend results, rechunked for NetCDF output.
    """
    # "QS-DEC" bins Dec-Feb, Mar-May, Jun-Aug, Sep-Nov labelled by their first
    # day, so ``time.season`` gives DJF/MAM/JJA/SON. "QE-DEC" would bin calendar
    # quarters labelled by their last day (Jan-Mar tagged "MAM", etc.) and split
    # each Dec-Nov download into five bins.
    freq = "QS-DEC" if seasonal else "MS"

    dataarrays = []
    for year in tqdm.tqdm(range(year_start, year_stop + 1), desc="annual"):
        requests = download.update_request_date(
            request, start=f"{year-1}-12", stop=f"{year}-11", stringify_dates=True
        )
        ds = download.download_and_transform(
            collection_id=collection_id,
            requests=requests,
            chunks=chunks,
            transform_chunks=False,
            transform_func=compute_low_resolution,
            transform_func_kwargs={"freq": freq},
            **open_mfdataset_kwargs,
        )
        dataarrays.append(rechunk(ds["analysed_sst"]))
    da = xr.concat(dataarrays, "time")

    if seasonal:
        group = "time.season"
    else:
        # Group by December-November year (December counts towards the next
        # year) to match the download periods. Calendar-year grouping would leave
        # a lone December in year_start - 1 and no December in year_stop.
        time = da["time"]
        group = (time.dt.year + (time.dt.month == 12)).rename("year")
    ds = da.groupby(group).map(_mann_kendall, **mann_kendall_kwargs)
    # NOTE(review): in the annual case each group holds 12 monthly means, so the
    # slope may be per month rather than per year depending on how
    # Mann_Kendall_test scales time; confirm the "/year" label.
    ds["trend"].attrs["units"] = f"{da.attrs['units']}/year"
    return rechunk(ds)

## Download and transform

In [None]:
# One trend dataset per mode, keyed "seasonal=True" / "seasonal=False", each
# stacking every product along a new "product" dimension.
# Loop variable names are kept: the f-string "=" specifier embeds them in the
# printed progress and in the dictionary keys.
maps_datasets = {}
for seasonal in (True, False):
    datasets = []
    for product, request_dict in request_dicts.items():
        print(f"{seasonal=} {product=}")
        ds = compute_mann_kendall_trend(
            **request_dict,
            year_start=year_start,
            year_stop=year_stop,
            seasonal=seasonal,
            open_mfdataset_kwargs=open_mfdataset_kwargs,
            alpha=0.05,  # Mann-Kendall significance level
            method="theilslopes",  # Mann-Kendall slope estimator
        )
        # Round the grid coordinates (presumably so both products align exactly
        # when concatenated along "product").
        for coord in ("longitude", "latitude"):
            ds[coord] = ds[coord].round(3)
        # Non-seasonal results carry one map per year; average them into one map.
        ds = ds if seasonal else ds.mean("year", keep_attrs=True)
        ds = ds.expand_dims(product=[product])
        datasets.append(rechunk(ds))
    maps_datasets[f"{seasonal=}"] = xr.concat(datasets, "product")
del datasets

## Plot trend maps

In [None]:
# Draw one figure per entry of maps_datasets: trend maps faceted by product
# (columns) and, when a "season" dimension exists, by season (rows), with a
# p-value hatching overlay on every panel.
projection = ccrs.Robinson()
for ds in maps_datasets.values():
    facet = plot.projected_map(
        ds["trend"],
        projection=projection,
        row="season" if "season" in ds.dims else None,
        col="product",
        robust=True,
        center=0,
    )
    # Pair each panel with the coordinate selection (product/season) it shows.
    for ax, sel_dict in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
        # Save the panel title so it can be restored after the overlay below.
        title = ax.get_title()
        p_value = ds["p"].sel(**sel_dict).reset_coords(drop=True)
        # Overlay p-values as filled contours on levels [0, 0.05, 1]: the interval
        # below 0.05 gets no hatch, the interval above 0.05 gets "/////".
        # NOTE(review): hatching therefore marks NON-significant trends; confirm
        # this is the intended convention. cmap="none" presumably keeps the fill
        # transparent so only the hatches are visible; confirm.
        plot.projected_map(
            p_value,
            projection=projection,
            show_stats=False,
            ax=ax,
            cmap="none",
            add_colorbar=False,
            plot_func="contourf",
            levels=[0, 0.05, 1],
            hatches=["", "/" * 5],
        )
        ax.set_title(title)
    plt.show()