# Satellite SST spectrum

## Import libraries

In [None]:
import collections

import matplotlib.pyplot as plt
import scipy.signal
import xarray as xr
from c3s_eqc_automatic_quality_control import download, utils

plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time: inclusive range of years to process
year_start, year_stop = 1982, 1982

# Space: one analysis box per named region, given as longitude/latitude
# slices (forwarded as keyword arguments to ``compute_spectrum``)
regions = {
    name: {"lon_slice": slice(*lon_bounds), "lat_slice": slice(*lat_bounds)}
    for name, lon_bounds, lat_bounds in (
        ("Gulf Stream", (-64, -34), (34, 42)),
        ("Kuroshio", (150, 180), (34, 44)),
        ("Agulhas", (5, 35), (-45, -35)),
        ("Brazilian-Malvinas", (-56, -26), (-47, -37)),
    )
}

## Define request

In [None]:
# Per-product download settings: the collection to query, the base request
# (dates are added per year later by ``download.update_request_date``) and the
# ``chunks`` passed to ``download.download_and_transform``.
request_dicts = dict(
    ESACCI=dict(
        collection_id="satellite-sea-surface-temperature",
        request=dict(
            processinglevel="level_4",
            format="zip",
            variable="all",
            sensor_on_satellite="combined_product",
            version="2_1",
        ),
        chunks=dict(year=1, month=1, day=7),
    ),
    GMPE=dict(
        collection_id="satellite-sea-surface-temperature-ensemble-product",
        request=dict(format="zip", variable="all"),
        chunks=dict(year=1, month=1, day=12),
    ),
)

## Define functions to cache

In [None]:
def _welch(x, **kwargs):
    """Return the Welch power spectral density of ``x`` as a DataArray.

    ``kwargs`` are forwarded unchanged to :func:`scipy.signal.welch`. The
    sample frequencies it returns become the ``"wavenumber"`` coordinate of
    the single output dimension, so ``x`` must yield a 1-D PSD (i.e. be 1-D,
    as supplied per-slice by ``welch``).
    """
    frequencies, power = scipy.signal.welch(x, **kwargs)
    return xr.DataArray(
        data=power,
        dims=["wavenumber"],
        coords={"wavenumber": frequencies},
    )


def welch(da, dim, **kwargs):
    """Apply :func:`scipy.signal.welch` to ``da`` along ``dim``.

    ``da`` is loaded into memory, and a PSD is computed independently for
    every position along the remaining dimensions.

    Parameters
    ----------
    da : xarray.DataArray
        Input data.
    dim : str
        Dimension along which the spectrum is estimated. In the output it is
        replaced by a new ``"wavenumber"`` dimension.
    **kwargs
        Forwarded to :func:`scipy.signal.welch` (e.g. ``fs``, ``window``,
        ``nperseg``, ``nfft``, ``detrend``).

    Returns
    -------
    xarray.DataArray
        Power spectral density with core dimension ``"wavenumber"``. It has a
        ``"frequency"`` coordinate along that dimension holding the sample
        frequencies returned by :func:`scipy.signal.welch`. These are in
        cycles per unit of ``dim`` when ``fs`` is samples per unit of ``dim``.
    """
    da = da.compute()
    # ``vectorize=True`` runs ``_welch`` through ``np.vectorize``, which copies
    # only the returned values into a plain ndarray. The coordinate built in
    # ``_welch`` is therefore lost. The frequencies depend only on the length
    # of ``dim`` and on ``kwargs``, so compute them once from a dummy signal
    # of the same length.
    frequency, _ = scipy.signal.welch([0.0] * da.sizes[dim], **kwargs)
    psd = xr.apply_ufunc(
        _welch,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[["wavenumber"]],
        vectorize=True,
        kwargs=kwargs,
    )
    # Attach the frequencies as a non-index coordinate. Spectra computed with
    # a slightly different ``fs`` then still combine positionally along
    # "wavenumber" instead of being outer-joined on near-identical values.
    return psd.assign_coords(frequency=("wavenumber", frequency))


def compute_spectrum(ds, lon_slice, lat_slice):
    """Estimate the along-longitude SST spectrum of one region.

    ``ds["analysed_sst"]`` is cut to the region with ``utils.regionalise``.
    Its Welch power spectral density along ``"longitude"`` is then computed
    with a Blackman-Harris window, linear detrending, a segment length equal
    to the largest power of two that fits, and zero-padding to twice that
    length.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain an ``"analysed_sst"`` variable with a ``"longitude"``
        dimension.
    lon_slice, lat_slice : slice
        Region bounds, forwarded to ``utils.regionalise``.

    Returns
    -------
    xarray.Dataset
        A dataset whose single variable ``"spectrum"`` holds the PSD.
    """
    along = "longitude"
    sst = utils.regionalise(ds["analysed_sst"], lon_slice, lat_slice)

    # Largest power of two not exceeding the number of longitude points.
    segment = 1 << (sst.sizes[along].bit_length() - 1)

    # Sampling frequency = 1 / mean grid spacing, so wavenumbers come out in
    # cycles per longitude unit (presumably degrees; confirm with the data).
    spacing = sst[along].diff(along).mean().item()

    spectrum = welch(
        sst,
        along,
        window="blackmanharris",
        detrend="linear",
        fs=1 / spacing,
        nperseg=segment,
        nfft=segment * 2,
    )
    return spectrum.to_dataset(name="spectrum")

## Download and transform data

In [None]:
# For every product, gather the date-stamped requests covering December of
# the previous year through November of each selected year. Then compute one
# spectrum dataset per region and store it as datasets[product][region].
datasets = {}
for product, request_kwargs in request_dicts.items():
    requests = [
        request
        for year in range(year_start, year_stop + 1)
        for request in download.update_request_date(
            request_kwargs["request"],
            start=f"{year - 1}-12",
            stop=f"{year}-11",
            stringify_dates=True,
        )
    ]
    for region, region_kwargs in regions.items():
        print(f"{product=} {region=}")
        # The right-hand side is evaluated before the target, so a product
        # entry is only created once a download/transform has succeeded.
        datasets.setdefault(product, {})[region] = download.download_and_transform(
            request_kwargs["collection_id"],
            requests,
            chunks=request_kwargs["chunks"],
            transform_func=compute_spectrum,
            transform_func_kwargs=region_kwargs,
        )

## Quick-look plot of the latitude-averaged spectra

In [None]:
# One figure per product. Average each region's spectrum over latitude, stack
# the regions along a new "region" dimension, and draw one panel per region
# with time on the x axis and robust colour limits.
for product, values in datasets.items():
    latitude_means = [
        ds.mean("latitude").expand_dims(region=[name])
        for name, ds in values.items()
    ]
    ds_mean = xr.combine_by_coords(latitude_means)
    ds_mean["spectrum"].plot(col="region", x="time", robust=True)
    plt.suptitle(f"{product=}", y=1.05)
    plt.show()