# Assessment of the SST climatology and variability

## Import packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pymannkendall as mk
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Use the "seaborn-v0_8-notebook" Matplotlib style sheet for every figure below.
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time: first and last year of the analysed period (both ends are used to
# build the download requests).
year_start, year_stop = 1982, 2011

# Regions: each entry maps a region name to the longitude/latitude slices
# that are forwarded as keyword arguments to the transform function.
regions = {}
regions["northern hemisphere"] = dict(
    lon_slice=slice(-180, 180),
    lat_slice=slice(0, 90),
)
regions["southern hemisphere"] = dict(
    lon_slice=slice(-180, 180),
    lat_slice=slice(-90, 0),
)

## Define requests

In [None]:
# Requests: one entry per SST product, holding the CDS collection id, the
# base request (dates are filled in later) and the chunking of the downloads.
request_dicts = {}
request_dicts["ESACCI"] = {
    "collection_id": "satellite-sea-surface-temperature",
    "request": dict(
        processinglevel="level_4",
        format="zip",
        variable="all",
        sensor_on_satellite="combined_product",
        version="2_1",
    ),
    "chunks": dict(year=1, month=1),
}
request_dicts["GMPE"] = {
    "collection_id": "satellite-sea-surface-temperature-ensemble-product",
    "request": dict(format="zip", variable="all"),
    "chunks": dict(year=1, month=1, day=12),  # CDS limit is 12
}

# Parameters to speed up I/O (passed to the dataset-opening calls)
open_mfdataset_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

## Functions to cache

In [None]:
def get_masked_sst(ds):
    """Return ``analysed_sst`` in °C, masked when the dataset provides a mask.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing ``analysed_sst`` and, optionally, ``mask``.
        273.15 is subtracted from the SST, i.e. the input is treated as kelvin.

    Returns
    -------
    xr.DataArray
        A new array with the original attributes, except that ``units`` is set
        to ``"°C"``. ``ds`` itself is not modified.
    """
    da = ds["analysed_sst"]
    if "mask" in ds:
        # Keep only the cells whose mask value equals 1.
        da = da.where(ds["mask"] == 1)
    with xr.set_options(keep_attrs=True):
        # Out-of-place on purpose: an in-place ``da -= 273.15`` would write
        # through to the variable held by ``ds`` whenever no mask is applied,
        # silently converting the caller's data as well.
        da = da - 273.15
    da.attrs["units"] = "°C"
    return da


def add_chunksizes(da):
    """Store the largest chunk length of every dimension in ``da.encoding``.

    ``da.chunks`` is read as one sequence of chunk lengths per dimension; the
    maximum of each sequence is written to ``da.encoding["chunksizes"]``.
    The same object is returned after being updated in place.
    """
    largest_per_dim = tuple(max(lengths) for lengths in da.chunks)
    da.encoding["chunksizes"] = largest_per_dim
    return da


def rechunk(obj):
    """Rechunk ``obj`` with fixed per-dimension sizes and record them in the encoding.

    Only the dimensions actually present in ``obj`` are rechunked. The
    resulting chunk sizes are stored under ``encoding["chunksizes"]`` of the
    array (or of every data variable for a dataset), presumably so they are
    reused as the NetCDF chunk layout when the object is written to disk.
    """
    target_sizes = {
        "time": 1,
        "year": 1,
        "season": 1,
        "latitude": 1_200,
        "longitude": 2_400,
    }
    applicable = {}
    for dim, size in target_sizes.items():
        if dim in obj.dims:
            applicable[dim] = size
    obj = obj.chunk(**applicable)

    if not isinstance(obj, xr.DataArray):
        # Dataset: annotate every data variable individually.
        for variable in obj.data_vars.values():
            add_chunksizes(variable)
        return obj
    return add_chunksizes(obj)


def compute_regionalised_spatial_weighted_mean(ds, lon_slice, lat_slice):
    """Reduce a dataset to the spatial weighted mean SST of one region.

    Steps: rechunk the input, cut it to ``lon_slice``/``lat_slice`` with
    ``utils.regionalise``, extract the masked SST in °C and average it with
    ``diagnostics.spatial_weighted_mean(..., weights=True)``. The result is
    returned as a dataset.
    """
    regional_ds = utils.regionalise(
        rechunk(ds),
        lon_slice=lon_slice,
        lat_slice=lat_slice,
    )
    sst = get_masked_sst(regional_ds)
    weighted_mean = diagnostics.spatial_weighted_mean(sst, weights=True)
    return weighted_mean.to_dataset()

## Download and transform

In [None]:
# Collect one regional-mean dataset per (product, region) pair
datasets = []
for product, request_dict in request_dicts.items():
    # Product-level settings, looked up once per product
    collection_id = request_dict["collection_id"]
    chunks = request_dict["chunks"]
    # Expand the base request over the whole analysis period
    requests = download.update_request_date(
        request_dict["request"],
        start=f"{year_start}-01",
        stop=f"{year_stop}-12",
        stringify_dates=True,
    )
    for region, slices in regions.items():
        print(f"{product=} {region=}")
        # The regional mean is computed per download chunk (transform_chunks=True);
        # the region slices are forwarded to the transform function.
        ds = download.download_and_transform(
            collection_id=collection_id,
            requests=requests,
            chunks=chunks,
            transform_chunks=True,
            transform_func=compute_regionalised_spatial_weighted_mean,
            transform_func_kwargs=slices,
            cached_open_mfdataset_kwargs=open_mfdataset_kwargs,
            **open_mfdataset_kwargs,
        )
        # Label the result so that all pairs can be merged along new dimensions
        datasets.append(ds.expand_dims(product=[product], region=[region]))

# Merge everything and load the SST timeseries into memory
merged = xr.merge(datasets)
da = merged["analysed_sst"].compute()

## Plot timeseries and trend

In [None]:
colors = ["green", "blue"]
# One figure per region, one line (plus its trend) per product.
# Regions/products are selected by label instead of iterating GroupBy objects:
# recent xarray releases no longer squeeze the grouped dimension, which would
# leave length-1 "region"/"product" dimensions and break both the 1-D line
# plot and the input of the Mann-Kendall test.
for region in da["region"].values.tolist():
    da_region = da.sel(region=region)
    fig, ax = plt.subplots()
    for color, product in zip(colors, da["product"].values.tolist()):
        da_product = da_region.sel(product=product)
        da_product.plot(ax=ax, label=product, color=color)
        # Mann-Kendall trend test; only p, slope and intercept are used below.
        trend, h, p, z, tau, s, var_s, slope, intercept = mk.original_test(da_product)
        # The trend line is evaluated on the sample index (0, 1, 2, ...),
        # which is how the slope/intercept are applied here.
        ax.plot(
            da_product["time"],
            np.arange(da_product.sizes["time"]) * slope + intercept,
            label=f"{product} trend {p=}",
            color=color,
            ls="--",
        )
    # Set the title once, after plotting: xarray's plot call writes its own
    # title from the scalar coordinates and would otherwise overwrite it.
    ax.set_title(f"{region.title()} ({year_start}-{year_stop})")
    ax.legend()
    ax.grid()
    plt.show()

## Plot annual cycle

In [None]:
# One figure per region showing, for each product, the day-of-year mean and a
# band of +/- one standard deviation computed over all years.
# Regions/products are selected by label instead of iterating GroupBy objects:
# recent xarray releases no longer squeeze the grouped dimension, which would
# leave length-1 "region"/"product" dimensions and break the 1-D plots.
for region in da["region"].values.tolist():
    da_region = da.sel(region=region)
    fig, ax = plt.subplots()
    for color, product in zip(colors, da["product"].values.tolist()):
        da_product = da_region.sel(product=product)
        grouped = da_product.groupby("time.dayofyear")
        mean = grouped.mean()
        std = grouped.std()
        mean.plot(ax=ax, label=product, color=color)
        ax.fill_between(
            mean["dayofyear"],
            mean - std,
            mean + std,
            alpha=0.5,
            label=f"{product.upper()} ± std",
            color=color,
        )
    # Set the title once, after plotting: xarray's plot call writes its own
    # title from the scalar coordinates and would otherwise overwrite it.
    ax.set_title(f"{region.title()} ({year_start}-{year_stop})")
    ax.legend()
    ax.grid()
    plt.show()