# Ocean Color

## Import packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Notebook-wide plot style. The property cycle pairs the first line with solid
# blue and the second with dashed red (two cyclers added together are zipped).
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams["axes.prop_cycle"] = (
    plt.cycler(color=["tab:blue", "tab:red"]) + plt.cycler(ls=["-", "--"])
)

## Define parameters

In [None]:
# Time period, as year-month strings
start = "1997-09"
stop = "2023-03"

# Latitude bands, given as paired edges: lats0[i] is the first bound and
# lats1[i] the second bound of band i, i.e. the bands are
# (90, -90), (50, -50), (50, 0) and (0, -50).
lats0 = [90, 50, *range(50, -1, -50)]
lats1 = [-90, -50, *range(0, -51, -50)]
assert len(lats0) == len(lats1)

## Define request

In [None]:
collection_id = "satellite-ocean-colour"

# Base request: chlorophyll-a on the regular lat/lon grid, dataset version "6_0"
request = dict(
    variable="mass_concentration_of_chlorophyll_a",
    projection="regular_latitude_longitude_grid",
    version="6_0",
    format="zip",
)

# Attach the start..stop date selection to the base request
# (presumably split into several requests, hence the plural name — verify
# against download.update_request_date).
requests = download.update_request_date(
    request,
    start=start,
    stop=stop,
    stringify_dates=True,
)

## Functions to cache

In [None]:
def regionalised_spatial_weighted_mean(
    ds, lon_slice, lat_slice, *, valid_min=1.0e-3, valid_max=1.0e2
):
    """Return the spatially weighted geometric mean of chlorophyll-a in a region.

    The ``chlor_a`` variable is cut to the requested lon/lat box. Values
    outside the open interval (``valid_min``, ``valid_max``) are masked. The
    weighted mean is then computed on ``log10`` of the data and transformed
    back with ``10 **``, which yields a geometric mean.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing a ``chlor_a`` variable.
    lon_slice, lat_slice : slice
        Longitude/latitude bounds forwarded to ``utils.regionalise``.
    valid_min, valid_max : float, optional
        Exclusive bounds of the values kept before averaging. The defaults
        (1e-3 and 1e2) reproduce the original hard-coded mask.

    Returns
    -------
    xarray.Dataset
        Dataset holding the regional mean ``chlor_a``.
    """
    da = ds["chlor_a"]
    da = utils.regionalise(da, lon_slice=lon_slice, lat_slice=lat_slice)
    # Mask out-of-range values; this also keeps log10 away from values <= 0.
    da = da.where((da > valid_min) & (da < valid_max))
    with xr.set_options(keep_attrs=True):
        da = 10 ** diagnostics.spatial_weighted_mean(np.log10(da))
    # Strip the " (not log-transformed)" qualifier from the label. Skip this
    # step (rather than raising KeyError) if the attribute did not survive.
    long_name = da.attrs.get("long_name")
    if long_name is not None:
        da.attrs["long_name"] = long_name.replace(" (not log-transformed)", "")
    return da.to_dataset()

## Download and transform data

In [None]:
# Full longitude range. Each latitude band is downloaded/transformed separately
# and tagged with the string form of its (lats0, lats1) pair, e.g. "(90, -90)".
lon_slice = slice(-180, 180)
datasets = []
for lats in zip(lats0, lats1):
    print(f"{lats=}")
    band = download.download_and_transform(
        collection_id,
        requests,
        chunks={"year": 1, "month": 1},
        transform_func=regionalised_spatial_weighted_mean,
        transform_func_kwargs=dict(lon_slice=lon_slice, lat_slice=slice(*lats)),
    )
    datasets.append(band.expand_dims(latitudes=[str(lats)]))
ds = xr.concat(datasets, "latitudes")

# Extract the DataArray and replace its metadata with plot-friendly labels
da = ds["chlor_a"]
da.attrs = dict(long_name="Chl-a", units="$mg/m^3$")
# Regional bands = everything except the global band; the global band is kept
# as a length-1 "latitudes" dimension by selecting with a list.
da_regional = da.drop_sel(latitudes="(90, -90)")
da_global = da.sel(latitudes=["(90, -90)"])

## Define plotting functions

In [None]:
def plot_timeseries(da, freq, window, labels, **kwargs):
    """Plot a resampled timeseries together with its centred running mean.

    Parameters
    ----------
    da : xarray.DataArray
        Data with a ``time`` dimension and, optionally, a ``latitudes``
        dimension.
    freq : str
        Resampling frequency passed to ``da.resample(time=...)``.
    window : int
        Number of resampled steps in the centred rolling-mean window.
    labels : sequence of str
        Labels for the resampled series and the running mean, in that order.
    **kwargs
        Extra keyword arguments for ``DataArray.plot``. They override the
        defaults chosen here.

    Returns
    -------
    The object returned by ``DataArray.plot`` (a FacetGrid when faceting with
    ``col``/``row``).
    """
    # Resample, then stack the result with its running mean along a new
    # "timeseries" dimension so both can be drawn with one plot call.
    resampled = da.resample(time=freq).mean()
    running = resampled.rolling(time=window, center=True, min_periods=1).mean()
    da = xr.concat(
        [
            resampled.expand_dims(timeseries=[labels[0]]),
            running.expand_dims(timeseries=[labels[1]]),
        ],
        "timeseries",
    )

    # One line per timeseries; one panel per latitude band when there are several.
    default_kwargs = {"hue": "timeseries"}
    # Default of 1: a DataArray without a "latitudes" dim must not hit `None > 1`.
    if da.sizes.get("latitudes", 1) > 1:
        default_kwargs["col"] = "latitudes"
    kwargs = default_kwargs | kwargs
    plot_obj = da.plot(**kwargs)

    if kwargs.get("col") or kwargs.get("row"):
        for ax in plot_obj.axs.flatten():
            ax.grid()
    else:
        # Grid the axes actually drawn on; a caller-supplied `ax` need not be
        # the current axes.
        ax = kwargs.get("ax")
        (ax if ax is not None else plt.gca()).grid()
    return plot_obj


def plot_daily(da, window=180, **plot_kwargs):
    """Plot daily means alongside a ``window``-day running mean.

    Thin wrapper around ``plot_timeseries`` with daily ("D") resampling. Any
    extra keyword arguments are forwarded to it unchanged.
    """
    return plot_timeseries(
        da,
        "D",
        window,
        ["daily", f"{window}-day running"],
        **plot_kwargs,
    )


def plot_monthly(da, window=3, **plot_kwargs):
    """Plot monthly means alongside a ``window``-month running mean.

    Thin wrapper around ``plot_timeseries`` with month-start ("MS")
    resampling. Any extra keyword arguments are forwarded to it unchanged.
    """
    return plot_timeseries(
        da,
        "MS",
        window,
        ["monthly", f"{window}-month running"],
        **plot_kwargs,
    )

## Plot timeseries

In [None]:
# Global band first, then the regional bands; each shown at daily and monthly resolution.
for da_to_plot in (da_global, da_regional):
    for plot_func in (plot_daily, plot_monthly):
        plot_func(da_to_plot)
        plt.show()