# Ocean Color

## Import packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Notebook-wide plot appearance: seaborn notebook theme plus a two-entry
# property cycle (solid blue first, dashed red second).
plt.style.use("seaborn-v0_8-notebook")
plt.rc(
    "axes",
    prop_cycle=plt.cycler(color=["tab:blue", "tab:red"], ls=["-", "--"]),
)

## Define parameters

In [None]:
# Time period (inclusive, "YYYY-MM")
start = "1997-09"
stop = "2023-03"

# Variable to analyse
variable = "chlor_a"
allowed_variables = (
    "chlor_a",
    "Rrs_412",
    "Rrs_443",
    "Rrs_490",
    "Rrs_510",
    "Rrs_560",
    "Rrs_665",
)
# Explicit raise rather than ``assert`` so the check survives ``python -O``.
if variable not in allowed_variables:
    raise ValueError(
        f"Unsupported variable {variable!r}; choose one of {allowed_variables}"
    )

# Regions to plot
regions = {
    "Global": {"lon_slice": slice(-180, 180), "lat_slice": slice(90, -90)},
    "50S-50N": {"lon_slice": slice(-180, 180), "lat_slice": slice(50, -50)},
    "15-30N 40-55W": {"lon_slice": slice(-55, -40), "lat_slice": slice(30, 15)},
    "NASTG": {"lon_slice": slice(-80, 0), "lat_slice": slice(50, 0)},
}


def check_region_slices(regions):
    """Validate the ordering of every slice in a regions mapping.

    Slices must follow the sorting of the original data: ``lat_slice`` must
    be non-increasing (``start >= stop``) and every other slice must be
    non-decreasing (``start <= stop``).

    Parameters
    ----------
    regions : dict
        Mapping of region name to a mapping of slice name to ``slice``.

    Raises
    ------
    ValueError
        If any slice is ordered the wrong way round. The message names the
        offending region and slice.
    """
    for region, slices in regions.items():
        for key, bounds in slices.items():
            if key == "lat_slice":
                ordered = bounds.start >= bounds.stop
            else:
                ordered = bounds.start <= bounds.stop
            if not ordered:
                raise ValueError(
                    f"Slice {key!r} of region {region!r} is not sorted "
                    f"as the original data: {bounds}"
                )


check_region_slices(regions)

## Define request

In [None]:
collection_id = "satellite-ocean-colour"

# Every "Rrs_*" band is requested through the reflectance variable; anything
# else (i.e. "chlor_a") through the chlorophyll-a concentration variable.
if variable.startswith("Rrs"):
    request_variable = "remote_sensing_reflectance"
else:
    request_variable = "mass_concentration_of_chlorophyll_a"

request = {
    "variable": request_variable,
    "projection": "regular_latitude_longitude_grid",
    "version": "6_0",
    "format": "zip",
}

# Expand the base request over the start/stop period.
requests = download.update_request_date(
    request,
    start=start,
    stop=stop,
    stringify_dates=True,
)

## Functions to cache

In [None]:
def regionalised_spatial_weighted_mean(ds, variable, lon_slice, lat_slice):
    """Reduce one variable to its regional, log-space spatial weighted mean.

    The variable is cut to the requested region, log10-transformed, reduced
    with ``diagnostics.spatial_weighted_mean`` and transformed back with
    ``10 ** x``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing ``variable``.
    variable : str
        Name of the variable to process. For ``"chlor_a"`` only values
        strictly between 1e-3 and 1e2 are kept; the rest are masked.
    lon_slice, lat_slice : slice
        Region bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xr.Dataset
        Dataset holding the reduced data under the name ``variable``.
    """
    regional = utils.regionalise(
        ds[variable], lon_slice=lon_slice, lat_slice=lat_slice
    )
    if variable == "chlor_a":
        within_bounds = (regional > 1.0e-3) & (regional < 1.0e2)
        regional = regional.where(within_bounds)

    # Keep attributes through the arithmetic so "long_name" is still
    # available afterwards.
    with xr.set_options(keep_attrs=True):
        log_mean = diagnostics.spatial_weighted_mean(np.log10(regional))
        result = 10**log_mean

    # Drop the "(not log-transformed)" note from the label, if present.
    long_name = result.attrs["long_name"]
    result.attrs["long_name"] = long_name.replace(" (not log-transformed)", "")
    return result.to_dataset(name=variable)

## Download and transform data

In [None]:
# One reduced dataset per region, each tagged with its region name.
# NOTE(review): the new dimension is called "latitudes" although it holds
# region labels; the plotting code keys on this exact name.
datasets = [
    download.download_and_transform(
        collection_id,
        requests,
        transform_func=regionalised_spatial_weighted_mean,
        transform_func_kwargs={"variable": variable, **slices},
        chunks={"year": 1, "month": 1},
    ).expand_dims(latitudes=[region])
    for region, slices in regions.items()
]
ds = xr.concat(datasets, dim="latitudes")

# Extract DataArray and separate the global series from the regional ones.
# Selecting with a list keeps "latitudes" as a length-one dimension.
da = ds[variable]
da_global = da.sel(latitudes=["Global"])
da_regional = da.drop_sel(latitudes="Global")

## Define plotting functions

In [None]:
def plot_timeseries(da, freq, window, labels, **kwargs):
    """Plot a resampled timeseries together with its centred running mean.

    Parameters
    ----------
    da : xr.DataArray
        Data with a ``time`` dimension and, optionally, a ``latitudes``
        dimension. When ``latitudes`` has more than one entry, each entry is
        drawn in its own column by default.
    freq : str
        Resampling frequency, passed to ``da.resample(time=freq)``.
    window : int
        Rolling-window size, in resampled time steps.
    labels : sequence of str
        Two legend labels: the first for the resampled series, the second
        for its running mean.
    **kwargs
        Forwarded to ``da.plot``; these take precedence over the defaults
        chosen here (``hue`` and, if applicable, ``col``).

    Returns
    -------
    object
        Whatever ``da.plot`` returns.
    """
    # Build a DataArray holding both series along a new "timeseries" dimension
    resampled = da.resample(time=freq).mean()
    # min_periods=1 lets the running mean be computed from partially filled
    # windows instead of requiring a full window.
    smoothed = resampled.rolling(time=window, center=True, min_periods=1).mean()
    da = xr.concat(
        [
            resampled.expand_dims(timeseries=[labels[0]]),
            smoothed.expand_dims(timeseries=[labels[1]]),
        ],
        "timeseries",
    )

    # Plot
    default_kwargs = {"hue": "timeseries"}
    # Default to 1 so that input without a "latitudes" dimension is treated
    # as a single panel instead of failing on ``None > 1``.
    if da.sizes.get("latitudes", 1) > 1:
        default_kwargs |= {"col": "latitudes"}
    kwargs = default_kwargs | kwargs
    plot_obj = da.plot(**kwargs)

    if kwargs.get("col") or kwargs.get("row"):
        # Faceted plot: add a grid to every panel.
        for ax in plot_obj.axs.flatten():
            ax.grid()
    else:
        plt.grid()
    return plot_obj


def plot_daily(da, window=180, **plot_kwargs):
    """Plot daily means of ``da`` with a ``window``-day running mean.

    Extra keyword arguments are forwarded to ``plot_timeseries``, whose
    result is returned unchanged.
    """
    return plot_timeseries(
        da,
        "D",
        window,
        ["daily", f"{window}-day running"],
        **plot_kwargs,
    )


def plot_monthly(da, window=3, **plot_kwargs):
    """Plot month-start means of ``da`` with a ``window``-month running mean.

    Extra keyword arguments are forwarded to ``plot_timeseries``, whose
    result is returned unchanged.
    """
    return plot_timeseries(
        da,
        "MS",
        window,
        ["monthly", f"{window}-month running"],
        **plot_kwargs,
    )

## Plot timeseries

In [None]:
# Draw one figure per (dataset, frequency) pair: the global series first,
# then the regional ones, each at daily and monthly resolution.
plot_funcs = (plot_daily, plot_monthly)
for da_to_plot in (da_global, da_regional):
    for plot_func in plot_funcs:
        plot_func(da_to_plot)
        plt.show()