# Ocean color reflectance

## Import packages

In [None]:
import cacholote
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time period: first and last year to process (both are included)
year_start, year_stop = 1998, 2022

# Variable to analyse: chlorophyll-a or one of the reflectance bands
variable = "Rrs_443"
assert variable in (
    "chlor_a", "Rrs_412", "Rrs_443", "Rrs_490", "Rrs_510", "Rrs_560", "Rrs_665",
)

# Regions to plot, keyed by the label used in figure titles
regions = {
    "Global": dict(lon_slice=slice(-180, 180), lat_slice=slice(90, -90)),
    "50S-50N": dict(lon_slice=slice(-180, 180), lat_slice=slice(50, -50)),
    "15-30N 40-55W": dict(lon_slice=slice(-55, -40), lat_slice=slice(30, 15)),
    "NASTG": dict(lon_slice=slice(-80, 0), lat_slice=slice(50, 0)),
}
# Slices must follow the ordering of the original data:
# latitude runs from high to low, every other axis from low to high.
for region, slices in regions.items():
    for k, v in slices.items():
        if k == "lat_slice":
            assert v.start >= v.stop, (region, k)
        else:
            assert v.start <= v.stop, (region, k)

# Whether to write each figure to a PNG file
savefig = False
assert isinstance(savefig, bool)

## Define request

In [None]:
collection_id = "satellite-ocean-colour"

# Every "Rrs*" band is requested through the reflectance variable;
# anything else falls back to the chlorophyll-a variable.
request = dict(
    variable=(
        "remote_sensing_reflectance"
        if variable.startswith("Rrs")
        else "mass_concentration_of_chlorophyll_a"
    ),
    projection="regular_latitude_longitude_grid",
    version="6_0",
    format="zip",
)

# Parameters to speed up I/O (forwarded as keyword arguments by the
# functions below)
open_mfdataset_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

## Functions to cache

In [None]:
def rechunk(obj):
    """Rechunk ``obj`` with the notebook's standard chunk layout.

    Only the dimensions actually present in ``obj.dims`` are passed to
    ``obj.chunk``; the other entries of the layout are ignored.

    Parameters
    ----------
    obj
        Object exposing ``dims`` and ``chunk(**sizes)`` (used here with
        xarray objects).

    Returns
    -------
    Whatever ``obj.chunk`` returns.
    """
    layout = {"year": 1, "time": -1, "longitude": 270, "latitude": 270}
    sizes = {}
    for dim, size in layout.items():
        if dim in obj.dims:
            sizes[dim] = size
    return obj.chunk(**sizes)


def rrs_monthly_weighted_log_reductions(ds, variable):
    """Reduce ``ds[variable]`` to "mean" and "std" fields per year.

    The data are log10-transformed, reduced year by year with
    ``diagnostics.monthly_weighted_mean``, transformed back to linear space
    and multiplied by ``|cos(latitude)|``. A mean and a log-space standard
    deviation over the ``month`` dimension are then concatenated along a new
    ``reduction`` dimension labelled ``"mean"`` and ``"std"``.

    Parameters
    ----------
    ds
        Dataset containing ``variable`` with ``time`` and ``latitude``
        coordinates. The variable must carry a ``long_name`` attribute.
    variable
        Name of the variable to reduce. ``"chlor_a"`` gets extra range
        masking.

    Returns
    -------
    Dataset holding the reduced array under the name ``variable``.
    """
    da = rechunk(ds[variable])
    if variable == "chlor_a":
        # Keep only chlor_a values strictly between 1e-3 and 1e2;
        # everything else is masked out by ``where``.
        da = da.where((da > 1.0e-3) & (da < 1.0e2))
    # |cos(latitude)|; deg2rad implies latitude is expressed in degrees.
    weights = np.abs(np.cos(np.deg2rad(da["latitude"])))
    with xr.set_options(keep_attrs=True):
        # Average in log10 space.
        da = np.log10(da)
        # The helper is applied to each calendar year separately.
        # NOTE(review): weights=False is forwarded to the helper, presumably
        # to disable its own weighting -- confirm in ``diagnostics``.
        da = da.groupby("time.year").map(
            diagnostics.monthly_weighted_mean, weights=False
        )
        # Back to linear space: adding log10(weights) before exponentiating
        # is the same as multiplying the back-transformed values by weights.
        da = 10 ** (da + np.log10(weights))
        # ``da`` feeds both reductions below, so it is persisted once here.
        da = da.persist()
        da_mean = da.mean("month").expand_dims(reduction=["mean"])
        # 10 ** (std of the log10 values over month): a multiplicative
        # spread factor rather than a standard deviation in data units.
        da_std = (10 ** np.log10(da).std("month")).expand_dims(reduction=["std"])
    da = xr.concat([da_mean, da_std], "reduction")
    # Drop the " (not log-transformed)" marker from the long name, if present.
    da.attrs["long_name"] = da.attrs["long_name"].replace(" (not log-transformed)", "")
    # Store the largest chunk length along each dimension in the encoding.
    # NOTE(review): presumably consumed when the result is written to
    # disk -- confirm.
    da.encoding["chunksizes"] = tuple(map(max, da.chunks))
    return da.to_dataset(name=variable)


def get_yearly_mean_and_std(
    collection_id, request, year_start, year_stop, variable, **open_mfdataset_kwargs
):
    """Download and reduce every year, then concatenate along ``year``.

    For each year from ``year_start`` to ``year_stop`` (both included) the
    January-December request is downloaded and transformed with
    ``rrs_monthly_weighted_log_reductions``. A progress bar labelled "year"
    tracks the loop.

    Parameters
    ----------
    collection_id
        Identifier of the collection to download from.
    request
        Base request; its dates are replaced for each year.
    year_start, year_stop
        First and last year to process.
    variable
        Variable name forwarded to the transform function.
    **open_mfdataset_kwargs
        Extra keyword arguments forwarded to
        ``download.download_and_transform``.

    Returns
    -------
    The rechunked yearly datasets concatenated along ``year``.
    """

    def _single_year(year):
        # One request set covering January to December of ``year``.
        yearly_requests = download.update_request_date(
            request, start=f"{year}-01", stop=f"{year}-12", stringify_dates=True
        )
        reduced = download.download_and_transform(
            collection_id,
            yearly_requests,
            transform_chunks=False,
            transform_func=rrs_monthly_weighted_log_reductions,
            transform_func_kwargs={"variable": variable},
            chunks={"year": 1, "month": 1},
            **open_mfdataset_kwargs,
        )
        return rechunk(reduced)

    years = range(year_start, year_stop + 1)
    return xr.concat(
        [_single_year(year) for year in tqdm.tqdm(years, desc="year")], "year"
    )


@cacholote.cacheable
def get_overall_mean_and_std(
    collection_id, request, year_start, year_stop, variable, **open_mfdataset_kwargs
):
    """Collapse the yearly "mean" and "std" fields over the whole period.

    The overall mean is the average of the yearly means. The overall std is
    the square root of the average of the squared yearly stds. The result is
    cached through ``cacholote``.

    Parameters
    ----------
    collection_id, request, year_start, year_stop, variable
        Forwarded unchanged to ``get_yearly_mean_and_std``.
    **open_mfdataset_kwargs
        Forwarded unchanged to ``get_yearly_mean_and_std``.

    Returns
    -------
    Dataset with ``variable`` holding both reductions along ``reduction``.
    """
    yearly = get_yearly_mean_and_std(
        collection_id=collection_id,
        request=request,
        year_start=year_start,
        year_stop=year_stop,
        variable=variable,
        **open_mfdataset_kwargs,
    )
    da = yearly[variable]
    with xr.set_options(keep_attrs=True):
        overall_mean = da.sel(reduction=["mean"]).mean("year")
        # Root of the mean of the squared yearly stds.
        squared_std = da.sel(reduction=["std"]) ** 2
        overall_std = squared_std.mean("year") ** 0.5
    combined = xr.concat([overall_mean, overall_std], "reduction")
    return rechunk(combined).to_dataset(name=variable)

## Download and transform data

In [None]:
# Arguments shared by both the overall and the yearly computations
kwargs = {
    "collection_id": collection_id,
    "request": request,
    "year_start": year_start,
    "year_stop": year_stop,
    "variable": variable,
    **open_mfdataset_kwargs,
}

# Statistics over the whole period
da = get_overall_mean_and_std(**kwargs)[variable]
da = rechunk(da)

# Statistics per year, plus a 5x5 block-averaged copy
da_year = get_yearly_mean_and_std(**kwargs)[variable]
da_year = rechunk(da_year)
da_year_low_res = da_year.coarsen(latitude=5, longitude=5).mean()

## Plot maps

In [None]:
# Colour-map settings are derived once per region and reduction from the
# overall statistics, then reused for the yearly maps.
plot_kwargs = {}
for region, regionalise_kwargs in regions.items():
    da_region = utils.regionalise(da, **regionalise_kwargs)
    plot_kwargs[region] = {
        reduction: xr.plot.utils._determine_cmap_params(
            da_reduction.values, robust=True
        )
        for reduction, da_reduction in da_region.groupby("reduction")
    }

for da_to_plot in (da, da_year_low_res):
    # Arrays with a "year" dimension are drawn as one panel per year
    is_yearly = "year" in da_to_plot.dims
    for region, regionalise_kwargs in regions.items():
        da_region = utils.regionalise(da_to_plot, **regionalise_kwargs)
        for reduction, da_reduction in da_region.groupby("reduction"):
            plot.projected_map(
                da_reduction,
                col="year" if is_yearly else None,
                col_wrap=5,
                show_stats=False,
                **plot_kwargs[region][reduction],
            )

            title_parts = [reduction.capitalize(), region, f"{year_start}-{year_stop}"]
            if is_yearly:
                title_parts.append("Yearly")
            title = " ".join(title_parts)

            # Multi-panel figures get a figure-level title, single maps an
            # axes-level one
            if is_yearly:
                plt.suptitle(title)
            else:
                plt.title(title)

            if savefig:
                plt.savefig(title.replace(" ", "_").lower() + ".png")
            plt.show()