# Ocean Color

## Import packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pymannkendall as mk
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Analysis period (inclusive, "YYYY-MM")
start = "1997-09"
stop = "2023-03"

# Variables to analyse: must be a subset of the names listed in the check below
variables = ["chlor_a", "Rrs_443", "Rrs_560"]
assert set(variables).issubset(
    {
        "chlor_a",
        "Rrs_412",
        "Rrs_443",
        "Rrs_490",
        "Rrs_510",
        "Rrs_560",
        "Rrs_665",
    }
)

# Regions to plot, keyed by the label used for each region
regions = {
    "Global": {"lon_slice": slice(-180, 180), "lat_slice": slice(90, -90)},
    "50S-50N": {"lon_slice": slice(-180, 180), "lat_slice": slice(50, -50)},
    "NASTG": {"lon_slice": slice(-80, 0), "lat_slice": slice(50, 0)},
}
for region, slices in regions.items():
    # Slices must follow the ordering of the original data:
    # latitude bounds descending, every other slice ascending.
    for name, bounds in slices.items():
        if name == "lat_slice":
            is_sorted = bounds.start >= bounds.stop
        else:
            is_sorted = bounds.start <= bounds.stop
        assert is_sorted, (region, name)

## Define request

In [None]:
# Identifier of the collection to download from
collection_id = "satellite-ocean-colour"

# Request fields shared by every download; the variable and the dates are
# added later, when the per-variable requests are built.
request = dict(
    projection="regular_latitude_longitude_grid",
    version="6_0",
    format="zip",
)

## Functions to cache

In [None]:
def regionalised_spatial_weighted_mean(ds, variable, lon_slice, lat_slice):
    """Reduce one variable to its log-space spatial mean over a region.

    The variable is cut to the requested longitude/latitude slices, its
    base-10 logarithm is averaged with ``diagnostics.spatial_weighted_mean``
    and the result is transformed back with ``10 ** mean``.

    Parameters
    ----------
    ds : dataset containing ``variable``
    variable : name of the variable to reduce
    lon_slice, lat_slice : slices forwarded to ``utils.regionalise``

    Returns
    -------
    Dataset holding the reduced data under the name ``variable``.
    """
    regional = utils.regionalise(
        ds[variable], lon_slice=lon_slice, lat_slice=lat_slice
    )

    if variable == "chlor_a":
        # Keep only chlorophyll values strictly between 1e-3 and 1e2
        in_range = (regional > 1.0e-3) & (regional < 1.0e2)
        regional = regional.where(in_range)

    # Attributes are kept so that "long_name" is still available below
    with xr.set_options(keep_attrs=True):
        log_mean = diagnostics.spatial_weighted_mean(np.log10(regional))
        mean = 10**log_mean

    long_name = mean.attrs["long_name"]
    mean.attrs["long_name"] = long_name.replace(" (not log-transformed)", "")
    return mean.to_dataset(name=variable)

## Download and transform data

In [None]:
# Download every (variable, region) combination and reduce it with
# regionalised_spatial_weighted_mean. Each result is labelled with its region
# name along a new "latitudes" dimension so that everything can be merged.
datasets = []
for variable in variables:
    # The requests depend only on the variable, so build them once per
    # variable instead of once per (variable, region) pair.
    cds_variable = (
        "remote_sensing_reflectance"
        if variable.startswith("Rrs")
        else "mass_concentration_of_chlorophyll_a"
    )
    requests = download.update_request_date(
        request | {"variable": cds_variable},
        start=start,
        stop=stop,
        stringify_dates=True,
    )
    for region, slices in regions.items():
        ds_region = download.download_and_transform(
            collection_id,
            requests,
            transform_func=regionalised_spatial_weighted_mean,
            transform_func_kwargs={"variable": variable} | slices,
            chunks={"year": 1, "month": 1, "variable": 1},
        )
        datasets.append(ds_region.expand_dims(latitudes=[region]))
ds = xr.merge(datasets).compute()

# Extract global and regional
# (the list selection keeps "latitudes" as a length-1 dimension)
ds_global = ds.sel(latitudes=["Global"])
ds_regional = ds.drop_sel(latitudes="Global")

## Define plotting functions

In [None]:
def plot_timeseries(da):
    """Plot daily values, a 48-month running mean and a yearly trend line.

    ``da`` must have a ``time`` dimension and a ``latitudes`` dimension (used
    in this notebook to hold region labels). When ``latitudes`` has more than
    one entry, one panel per entry is drawn; otherwise a single plot is drawn.

    The trend line uses the slope and intercept returned by
    ``mk.original_test`` on the yearly means, evaluated at 0, 1, 2, ... for
    successive years.

    Returns the object returned by the initial ``.plot`` call on the daily
    means.
    """
    daily_style = dict(label="daily", color="tab:grey", ls=" ", marker=".")
    running_style = dict(
        label="48-month running", color="tab:red", ls="--", marker=" "
    )
    slope_style = dict(label="yearly slope", color="tab:blue", ls="-", marker=" ")

    daily = da.resample(time="D").mean()
    yearly = da.resample(time="Y").mean()
    monthly = da.resample(time="MS").mean()
    running = monthly.rolling(time=48, center=True, min_periods=1).mean()
    year_index = np.arange(yearly.sizes["time"])

    def trend_line(series):
        # The last two fields of the test result are slope and intercept
        *_, slope, intercept = mk.original_test(series)
        return year_index * slope + intercept

    faceted = da.sizes["latitudes"] > 1
    plt_obj = daily.plot(col="latitudes" if faceted else None, **daily_style)

    if faceted:
        panels = zip(plt_obj.axs.flatten(), plt_obj.name_dicts.flatten())
        for position, (ax, sel_dict) in enumerate(panels):
            running.sel(sel_dict).plot(ax=ax, add_legend=False, **running_style)
            ax.plot(
                yearly["time"],
                trend_line(yearly.sel(sel_dict)),
                **slope_style,
            )
            ax.grid()
            # Only the first panel keeps its y-axis label
            if position > 0:
                ax.set_ylabel("")
        # One legend, attached outside the last panel
        ax.legend(bbox_to_anchor=(1, 1), loc="upper left")
    else:
        running.plot(add_legend=False, **running_style)
        plt.plot(yearly["time"], trend_line(yearly.squeeze()), **slope_style)
        plt.legend()
        plt.grid()

    return plt_obj

## Plot timeseries

In [None]:
# One figure per variable for the global series, then one for the regions.
# The loop variable is deliberately not called ``ds`` so that the merged
# dataset ``ds`` is not overwritten when this cell runs.
for variable in variables:
    for ds_subset in (ds_global, ds_regional):
        da = ds_subset[variable]
        plot_timeseries(da)
        plt.show()