# Ocean Colour

## Import libraries

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply the "seaborn-v0_8-notebook" matplotlib style to every figure drawn below.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# Time period ("YYYY-MM" strings passed to download.update_request_date)
start = "1998-01"
stop = "2023-12"

# Variable: chlorophyll-a or a remote-sensing-reflectance band
# (the Rrs suffix is presumably the wavelength in nm — confirm with the product docs)
variable = "chlor_a"
supported_variables = {"chlor_a"} | {f"Rrs_{wl}" for wl in (443, 560)}
if variable not in supported_variables:
    # Raise rather than assert so the check is not stripped under `python -O`.
    raise ValueError(
        f"Unsupported variable {variable!r}; "
        f"expected one of {sorted(supported_variables)}"
    )

# Regions
# Latitude slices run from -47.5 down to -63.5, presumably because the dataset's
# latitude axis is stored north-to-south — confirm with utils.regionalise.
# NOTE(review): PO_POOZ uses 0–360 longitudes (150..290) while AO_POOZ uses
# -180..180 (-70..20); presumably utils.regionalise handles both conventions.
regions_monthly = {
    "IO_POOZ": {
        "lon_slice": slice(20, 150),
        "lat_slice": slice(-47.5, -63.5),
    },
    "PO_POOZ": {
        "lon_slice": slice(150, 290),
        "lat_slice": slice(-47.5, -63.5),
    },
    "AO_POOZ": {
        "lon_slice": slice(-70, 20),
        "lat_slice": slice(-47.5, -63.5),
    },
}
# Region(s) for which a time-averaged map is produced (full longitude circle).
regions_map = {
    "SO": {
        "lon_slice": slice(-180, 180),
        "lat_slice": slice(-47.5, -63.5),
    }
}

# Define data request
collection_id = "satellite-ocean-colour"
request = {
    "projection": "regular_latitude_longitude_grid",
    "version": "6_0",
    "format": "zip",
}
# Chunking used when downloading: one request per year, month and variable.
chunks = {"year": 1, "month": 1, "variable": 1}

## Define transform functions

In [None]:
def monthly_weighted_log_mean(ds, variable, lon_slice, lat_slice):
    """Return a regional monthly geometric mean and valid-pixel percentage.

    The variable is cropped with ``utils.regionalise``; for ``chlor_a`` values
    outside the open interval (0.01, 100) are masked. Time is aggregated per
    year with ``diagnostics.monthly_weighted_mean(weights=False)`` — presumably
    producing ``year``/``month`` dimensions; confirm against that helper.

    Parameters
    ----------
    ds : xr.Dataset
        Input data containing ``variable`` with a ``time`` coordinate.
    variable : str
        Name of the variable to process (``"chlor_a"`` or ``"Rrs_*"``).
    lon_slice, lat_slice : slice
        Region bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xr.Dataset
        ``mean``: per-pixel geometric mean over time (10 ** mean of log10),
        then a cos(latitude)-weighted mean over latitude/longitude, normalised
        by the weights of the valid pixels only.
        ``valid_pixels``: percentage of non-missing samples, averaged
        (unweighted) over latitude/longitude.
    """
    da = ds[variable]
    da = utils.regionalise(da, lon_slice=lon_slice, lat_slice=lat_slice)
    if variable == "chlor_a":
        # Discard chlorophyll values outside the (0.01, 100) range.
        da = da.where((da > 0.01) & (da < 1.0e2))

    valid_pixels = 100 * da.notnull().groupby("time.year").map(
        diagnostics.monthly_weighted_mean, weights=False
    )
    valid_pixels.attrs = {"long_name": "Valid Pixels", "units": "%"}

    with xr.set_options(keep_attrs=True):
        # log10 is undefined for non-positive values (NaN, or -inf for 0,
        # which would collapse the geometric mean to 0), so mask them first.
        log_da = np.log10(da.where(da > 0))
        gmean = 10 ** log_da.groupby("time.year").map(
            diagnostics.monthly_weighted_mean, weights=False
        )

    # Area-weight AFTER the temporal mean and normalise by the sum of weights:
    # scaling the data itself by cos(lat) and taking a plain mean would scale
    # the result by the average cosine instead of producing a weighted mean.
    weights = np.cos(np.deg2rad(gmean["latitude"]))
    mean = gmean.weighted(weights).mean(["latitude", "longitude"], keep_attrs=True)
    valid_pixels = valid_pixels.mean(["latitude", "longitude"], keep_attrs=True)
    return xr.merge([mean.rename("mean"), valid_pixels.rename("valid_pixels")])


def weighted_log_map(ds, variable, lon_slice, lat_slice):
    """Return a per-pixel temporal geometric-mean map and valid-pixel percentage.

    Parameters
    ----------
    ds : xr.Dataset
        Input data containing ``variable`` with a ``time`` dimension.
    variable : str
        Name of the variable to process (``"chlor_a"`` or ``"Rrs_*"``).
    lon_slice, lat_slice : slice
        Region bounds forwarded to ``utils.regionalise``.

    Returns
    -------
    xr.Dataset
        ``mean``: 10 ** (time mean of log10(value)) at every pixel, i.e. the
        geometric mean over time. ``valid_pixels``: percentage of time steps
        with a non-missing value at every pixel.
    """
    da = ds[variable]
    da = utils.regionalise(da, lon_slice=lon_slice, lat_slice=lat_slice)
    if variable == "chlor_a":
        # Discard chlorophyll values outside the (0.01, 100) range.
        da = da.where((da > 0.01) & (da < 1.0e2))

    valid_pixels = 100 * da.notnull().mean("time")
    valid_pixels.attrs = {"long_name": "Valid Pixels", "units": "%"}

    with xr.set_options(keep_attrs=True):
        # A per-pixel temporal mean needs no area weighting: multiplying by
        # cos(latitude) here would rescale every map pixel by its latitude.
        # Non-positive values are masked because log10 is undefined there.
        log_da = np.log10(da.where(da > 0))
        da = 10 ** log_da.mean("time")
    return xr.merge([da.rename("mean"), valid_pixels.rename("valid_pixels")])

## Download and transform

In [None]:
# Product to request: every "Rrs*" variable comes from the remote-sensing
# reflectance product, anything else from the chlorophyll-a product.
cds_variable = (
    "remote_sensing_reflectance"
    if variable.startswith("Rrs")
    else "mass_concentration_of_chlorophyll_a"
)
# The request list depends only on the variable and the time period, not on
# the region, so build it once and reuse it for every region below.
requests = download.update_request_date(
    request | {"variable": cds_variable},
    start=start,
    stop=stop,
    stringify_dates=True,
)

# Time-averaged maps (weighted_log_map reduces over "time").
# transform_chunks=False presumably makes the transform see the whole period
# at once rather than one chunk at a time — confirm in the download helper.
maps = {}
for region, slices in regions_map.items():
    print(f"{region =}")
    maps[region] = download.download_and_transform(
        collection_id,
        requests,
        transform_func=weighted_log_map,
        transform_func_kwargs=slices | {"variable": variable},
        chunks=chunks,
        transform_chunks=False,
    )

# Monthly regional series, stacked along a new "region" dimension.
datasets = []
for region, slices in regions_monthly.items():
    print(f"{region =}")
    ds = download.download_and_transform(
        collection_id,
        requests,
        transform_func=monthly_weighted_log_mean,
        transform_func_kwargs=slices | {"variable": variable},
        chunks=chunks,
    )
    datasets.append(ds.expand_dims(region=[region]))
ds_monthly = xr.concat(datasets, "region")

## Plot maps

In [None]:
# Draw every variable ("mean", "valid_pixels") of each regional map on a
# south-polar stereographic projection, one figure per variable.
for region, ds in maps.items():
    # Only the DataArrays are needed; the variable names were never used.
    for da in ds.data_vars.values():
        plot.projected_map(
            da,
            robust=True,
            projection=ccrs.SouthPolarStereo(),
            show_stats=False,
        )
        plt.title(f"{region=}")
        plt.show()

## Plot monthly data

In [None]:
# One figure per variable, faceted by region: years along x, months along y.
# yincrease=False places the first month at the top of each panel.
for da in ds_monthly.data_vars.values():
    da.plot(col="region", x="year", y="month", yincrease=False, robust=True)
    plt.show()