# Ocean color reflectance

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply the "seaborn-v0_8-notebook" Matplotlib style sheet to the figures below.
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time period (both bounds are inclusive: the download loop runs up to year_stop)
year_start = 1998
year_stop = 1999
if year_start > year_stop:
    raise ValueError(
        f"year_start ({year_start}) must not be greater than year_stop ({year_stop})"
    )

# Variable to analyse: selects the "Rrs_<wavelength>" variable of the dataset
wavelength = 443
# Raise explicitly instead of asserting: assert statements are removed when
# Python runs with -O, which would silently skip this validation.
if wavelength not in (412, 443, 490, 510, 560, 665):
    raise ValueError(f"Unsupported wavelength: {wavelength!r}")

## Define request

In [None]:
# Identifier of the collection handed to the download helper
collection_id = "satellite-ocean-colour"

# Base request; the date range is added later by download.update_request_date
request = dict(
    variable="remote_sensing_reflectance",
    projection="regular_latitude_longitude_grid",
    version="6_0",
    format="zip",
)

# Extra keyword arguments forwarded with every download to speed up I/O
open_mfdataset_kwargs = dict(
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    compat="override",
    parallel=True,
)

## Functions to cache

In [None]:
def rechunk(obj):
    """Rechunk ``obj`` along each known dimension that it actually has.

    Target chunk sizes are ``-1`` for ``time``, ``1`` for ``year`` and ``270``
    for both ``longitude`` and ``latitude``. Dimensions not listed in
    ``obj.dims`` are left out of the call. Returns whatever ``obj.chunk``
    returns.
    """
    target_sizes = {"time": -1, "year": 1, "longitude": 270, "latitude": 270}
    applicable = {}
    for dim, size in target_sizes.items():
        if dim in obj.dims:
            applicable[dim] = size
    return obj.chunk(**applicable)


def rrs_annual_weighted_log_mean(ds, wavelength):
    """Compute the annual mean of ``Rrs_<wavelength>`` in log10 space.

    The variable is rechunked, staged in a temporary Zarr store, multiplied by
    ``|cos(latitude)|``, log10-transformed, reduced with
    ``diagnostics.annual_weighted_mean(..., weights=False)`` and mapped back
    with ``10 ** ...``.

    Parameters
    ----------
    ds
        Object indexable by variable name; ``ds[f"Rrs_{wavelength}"]`` must
        carry a ``latitude`` coordinate. Presumably an ``xarray.Dataset``.
    wavelength
        Suffix of the variable to process (``Rrs_<wavelength>``).

    Returns
    -------
    The in-memory, rechunked result converted to a dataset under the same
    variable name.
    """
    name = f"Rrs_{wavelength}"
    da = rechunk(ds[name])
    with tempfile.TemporaryDirectory() as tmpdir:
        # Stage the rechunked array on disk and reopen it with the same chunk
        # sizes, so that the computation below reads from the Zarr store.
        da.to_zarr(tmpdir)
        da = xr.open_dataarray(tmpdir, engine="zarr", chunks=dict(da.chunksizes))
        # Absolute cosine of the latitude (latitude is treated as degrees).
        weights = np.abs(np.cos(np.deg2rad(da["latitude"])))
        # NOTE(review): the weights are multiplied into the data *before* the
        # log, and the helper is called with weights=False -- confirm that
        # scaling the values this way is the intended weighting.
        # keep_attrs=True preserves attributes through the arithmetic below.
        with xr.set_options(keep_attrs=True):
            da = 10 ** diagnostics.annual_weighted_mean(
                np.log10(da * weights), weights=False
            )
        # Load the result into memory while the temporary store still exists.
        da = rechunk(da.compute())
        # Record one chunk size per dimension: the largest chunk along it.
        da.encoding["chunksizes"] = tuple(map(max, da.chunks))
        return da.to_dataset(name=name)

## Download and transform data

In [None]:
# Process one year at a time, then join the yearly results along "year".
datasets = []
for year in range(year_start, year_stop + 1):
    print(f"{year=}")
    # Restrict the base request to January-December of the current year.
    first_month, last_month = f"{year}-01", f"{year}-12"
    requests = download.update_request_date(
        request,
        start=first_month,
        stop=last_month,
        stringify_dates=True,
    )
    ds_year = download.download_and_transform(
        collection_id,
        requests,
        transform_chunks=False,
        transform_func=rrs_annual_weighted_log_mean,
        transform_func_kwargs={"wavelength": wavelength},
        chunks={"year": 1, "month": 1},
        **open_mfdataset_kwargs,
    )
    datasets.append(rechunk(ds_year))
ds = xr.concat(datasets, dim="year")

## Plot global maps

In [None]:
# One panel per year, two panels per row, robust colour limits.
plot_kwargs = dict(col="year", col_wrap=2, robust=True)

# Average over 5x5 blocks of grid cells before plotting.
da = ds[f"Rrs_{wavelength}"].coarsen(latitude=5, longitude=5).mean()
facet = plot.projected_map(
    da,
    projection=ccrs.Robinson(),
    **plot_kwargs,
)

## Plot regional maps

In [None]:
# Bounds of the region to extract; `da` is the 5x5-coarsened array from the
# global-map cell. The latitude slice runs from 30 down to 15.
lon_slice = slice(-55, -40)
lat_slice = slice(30, 15)
da_region = utils.regionalise(da, lon_slice=lon_slice, lat_slice=lat_slice)
facet = plot.projected_map(
    da_region,
    projection=ccrs.PlateCarree(),
    **plot_kwargs,
)