# Assessment of the SST climatology and variability

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply a matplotlib style sheet globally to every figure drawn in this notebook.
plt.style.use(style="seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time period, as "YYYY-MM" strings passed to download.update_request_date
start = "1982-01"
stop = "1991-12"

# Regions: name -> lon/lat slices forwarded to utils.regionalise.
# Both hemispheres span all longitudes; they differ only in latitude range.
regions = {
    f"{hemisphere} hemisphere": {"lon_slice": slice(-180, 180), "lat_slice": lat_slice}
    for hemisphere, lat_slice in (
        ("northern", slice(0, 90)),
        ("southern", slice(-90, 0)),
    )
}

## Define requests and I/O parameters

In [None]:
# Requests: one entry per product, each holding the CDS collection id,
# the base request and the chunking used to split it into smaller requests.
request_dicts = dict(
    esacci=dict(
        collection_id="satellite-sea-surface-temperature",
        request=dict(
            processinglevel="level_4",
            format="zip",
            variable="all",
            sensor_on_satellite="combined_product",
            version="2_1",
        ),
        chunks=dict(year=1, month=1),
    ),
    gmpe=dict(
        collection_id="satellite-sea-surface-temperature-ensemble-product",
        request=dict(
            format="zip",
            variable="all",
        ),
        chunks=dict(year=1, month=1, day=12),  # CDS limit is 12
    ),
)

# Parameters to speed up I/O (forwarded to download.download_and_transform)
open_mfdataset_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

## Functions to cache

In [None]:
def get_masked_sst(ds):
    """Return the ``analysed_sst`` variable of *ds*, masked when possible.

    If *ds* also carries a ``mask`` variable, values where ``mask != 1`` are
    replaced via ``.where`` (i.e. set to missing); otherwise ``analysed_sst``
    is returned untouched.
    """
    sst = ds["analysed_sst"]
    return sst.where(ds["mask"] == 1) if "mask" in ds else sst


def rechunk(obj):
    """Use NetCDF chunks.

    Calls ``obj.chunk`` with a fixed size for each of ``season`` (1),
    ``latitude`` (1200) and ``longitude`` (2400). Only the dimensions that
    *obj* actually has are passed; the sizes presumably mirror the chunking
    of the source NetCDF files (TODO confirm).
    """
    netcdf_chunks = (("season", 1), ("latitude", 1_200), ("longitude", 2_400))
    present = {dim: size for dim, size in netcdf_chunks if dim in obj.dims}
    return obj.chunk(**present)



def compute_time_reductions(ds, reductions, groups):
    """Reduce the masked SST over time for every (group, reduction) pair.

    For each ``group`` in *groups* and ``reduction`` in *reductions* (in that
    nesting order), ``diagnostics.<group>_weighted_<reduction>`` is applied to
    the masked SST with ``weights=False``. The result is rechunked and stored
    under the name ``<group>_<reduction>``, e.g. ``time_mean`` or ``seasonal_std``.

    Returns
    -------
    xarray.Dataset
        All reduced variables merged together.
    """
    sst = get_masked_sst(rechunk(ds))

    def _reduce(group, reduction):
        # Build the label before looking up the diagnostic so failures surface
        # in the same order as a straight-line loop would produce them.
        label = f"{reduction.title()} of {sst.attrs['long_name']}"
        reducer = getattr(diagnostics, f"{group}_weighted_{reduction}")
        reduced = rechunk(reducer(sst, weights=False))
        reduced.attrs["long_name"] = label
        # Record the largest chunk along each dimension as the on-disk chunk
        # sizes (presumably used when the cached result is written — confirm).
        reduced.encoding["chunksizes"] = tuple(max(sizes) for sizes in reduced.chunks)
        return reduced.rename(f"{group}_{reduction}")

    return xr.merge([_reduce(group, reduction) for group in groups for reduction in reductions])


def compute_spatial_weighted_reductions(ds, reductions, lon_slice, lat_slice):
    """Compute a day-of-year climatology of the regional, spatially weighted SST.

    The masked SST (see ``get_masked_sst``) is cropped to the requested region
    and averaged in space with ``diagnostics.spatial_weighted_mean(weights=True)``.
    The weighting scheme is defined in ``diagnostics``; presumably it uses
    latitude/area weights (TODO confirm). The resulting time series is grouped
    by ``time.dayofyear`` and each reduction is applied within every group.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding ``analysed_sst`` and, optionally, ``mask``.
    reductions : iterable of str
        Names of ``DataArrayGroupBy`` reduction methods (e.g. ``"mean"``,
        ``"std"``). Each result is stored under the reduction's name.
    lon_slice, lat_slice : slice
        Region bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        One variable per reduction, indexed by ``dayofyear``.
    """
    ds = rechunk(ds)
    ds = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    da = get_masked_sst(ds)
    da = diagnostics.spatial_weighted_mean(da, weights=True)
    grouped = da.groupby("time.dayofyear")
    dataarrays = []
    for reduction in reductions:
        long_name = f"{reduction.title()} of {da.attrs['long_name']}"
        func = getattr(grouped, reduction)
        # GroupBy reductions take ``dim`` as their first positional argument,
        # so the data must not be passed in. Without ``dim`` the reduction runs
        # over the grouped (time) dimension.
        da_reduced = rechunk(func(keep_attrs=True))
        da_reduced.attrs["long_name"] = long_name
        dataarrays.append(da_reduced.rename(reduction))
    return xr.merge(dataarrays)

## Download and transform

In [None]:
# Maps of time/seasonal reductions, keyed by product name
datasets_maps = {}
# Day-of-year climatologies, one Dataset per (product, region)
datasets_timeseries = []

reductions = ("mean", "std")
groups = ("time", "seasonal")
for product, spec in request_dicts.items():
    # Requests restricted to the analysis period
    requests = download.update_request_date(
        spec["request"], start=start, stop=stop, stringify_dates=True
    )
    # Arguments shared by every download_and_transform call for this product
    common_kwargs = {
        "collection_id": spec["collection_id"],
        "requests": requests,
        "transform_chunks": False,
        "chunks": spec["chunks"],
        **open_mfdataset_kwargs,
    }

    # Time reductions -> maps
    datasets_maps[product] = rechunk(
        download.download_and_transform(
            **common_kwargs,
            transform_func=compute_time_reductions,
            transform_func_kwargs={"reductions": reductions, "groups": groups},
        )
    )

    # Spatially weighted reductions -> one timeseries Dataset per region,
    # tagged with length-1 "region" and "product" dimensions for merging
    for region, slices in regions.items():
        ds_region = download.download_and_transform(
            **common_kwargs,
            transform_func=compute_spatial_weighted_reductions,
            transform_func_kwargs={"reductions": reductions, **slices},
        )
        datasets_timeseries.append(
            ds_region.expand_dims(region=[region], product=[product])
        )
ds_timeseries = xr.merge(datasets_timeseries)

## Plot mean and std maps

In [None]:
# Shared options for every projected map (also reused by the bias plots)
maps_kwargs = {"projection": ccrs.Robinson(), "plot_func": "contourf", "col_wrap": 2}

for product, ds in datasets_maps.items():
    for var, da in ds.data_vars.items():
        cmap = "Spectral_r" if var.endswith("mean") else "tab20b"
        if "season" in da.dims:
            # Seasonal fields: one facet per season
            data, col = da, "season"
        else:
            # Whole-period fields are computed eagerly before plotting
            data, col = da.compute(), None
        plot.projected_map(data, cmap=cmap, col=col, **maps_kwargs)
        plt.suptitle(f"{product.upper()} ({start}, {stop})")
        plt.show()

## Plot bias

In [None]:
# Bias = first product minus second product, the latter interpolated onto
# the first product's grid; attributes are taken from the first product.
products = list(datasets_maps)
for var in ("time_mean", "seasonal_mean"):
    minuend = datasets_maps[products[0]][var]
    subtrahend = datasets_maps[products[1]][var].interp_like(minuend)
    with xr.set_options(keep_attrs=True):
        da = minuend - subtrahend
    da.attrs["long_name"] = da.attrs["long_name"].replace("Mean", "Mean bias")
    seasonal = "season" in da.dims
    plot.projected_map(
        da if seasonal else da.compute(),
        cmap="PRGn",
        col="season" if seasonal else None,
        **maps_kwargs,
    )
    plt.suptitle(f"{' - '.join(products).upper()} ({start}, {stop})")
    plt.show()

## Plot timeseries

In [None]:
# One figure per region: day-of-year mean of each product, with a ±1 std band.
# Labels are selected with .sel, which drops the length-1 "region"/"product"
# dimensions. Iterating .groupby() only did that under xarray's deprecated
# squeeze=True behaviour, so newer versions would hand 3-D data to the plots.
for region in ds_timeseries["region"].values:
    ds_region = ds_timeseries.sel(region=region)
    fig, ax = plt.subplots()
    # Lines and fill bands each cycle through these colours in product order
    ax.set_prop_cycle(color=["green", "blue"])
    for product in ds_region["product"].values:
        ds_product = ds_region.sel(product=product)
        ds_product["mean"].plot(ax=ax, label=product.upper(), add_legend=False)
        ax.fill_between(
            ds_product["dayofyear"],
            ds_product["mean"] - ds_product["std"],
            ds_product["mean"] + ds_product["std"],
            alpha=0.5,
            label=f"{product.upper()} ± std",
        )
    # Set the title after the product loop so it is applied once per figure
    ax.set_title(f"{region.title()} ({start}, {stop})")
    ax.legend()
    ax.grid()
    plt.show()