# Seasonal forecast monthly averages of ocean variables

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time
# First and last forecast start dates (YYYY-MM) and the pandas frequency
# alias used to step between them when building the requests.
start, stop = "1993-05", "2025-01"
freq = "12MS"
# Lead times: used as month offsets for the reanalysis requests and as
# positional indexes along the forecast "leadtime" dimension.
leadtimes = list(range(1, 4))

## Define requests

In [None]:
collection_id_reanalysis = "reanalysis-oras5"
collection_id_seasonal = "seasonal-monthly-ocean"

# Forecast start dates inside this window (inclusive) are not requested.
excluded_start = pd.to_datetime("2019-01")
excluded_stop = pd.to_datetime("2023-03")

# One list of reanalysis requests per lead time, and a single list of
# seasonal requests (one request per forecast start date).
requests_reanalysis = {leadtime: [] for leadtime in leadtimes}
requests_seasonal = []
for date in pd.date_range(start, stop, freq=freq):
    if excluded_start <= date <= excluded_stop:
        continue

    requests_seasonal.append(
        {
            "originating_centre": "meteo_france",
            "system": "8",
            "variable": ["depth_average_potential_temperature_of_upper_300m"],
            "forecast_type": ["forecast" if date.year > 2018 else "hindcast"],
            "year": date.strftime("%Y"),
            "month": date.strftime("%m"),
        }
    )
    for leadtime in leadtimes:
        # Reanalysis month that the forecast started on `date` is compared with.
        date_reanalysis = date + pd.DateOffset(months=leadtime)
        requests_reanalysis[leadtime].append(
            {
                # The product type must follow the month actually requested
                # (date_reanalysis), not the forecast start date: the two can
                # fall in different years when start + leadtime crosses
                # a year boundary.
                "product_type": [
                    "operational" if date_reanalysis.year > 2014 else "consolidated"
                ],
                "vertical_resolution": "single_level",
                "variable": ["ocean_heat_content_for_the_upper_300m"],
                "year": date_reanalysis.strftime("%Y"),
                "month": date_reanalysis.strftime("%m"),
            }
        )

## Functions to cache

In [None]:
def add_bounds(ds):
    """Attach cell-corner latitude/longitude bounds to ``ds``.

    The corners are built from the ``glamf``/``gphif`` fields of a remote
    ORCA025 ``mesh_mask.nc`` file; ``ds`` itself is not used to derive them
    (assumes ``ds`` lives on that same (y, x) grid -- TODO confirm).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``latitude`` and ``longitude`` variables. Their
        ``bounds`` attribute is set in place.

    Returns
    -------
    xarray.Dataset
        ``ds`` with ``latitude_bounds`` and ``longitude_bounds`` coordinates
        on dimensions ``(y, x, bound)``.

    Raises
    ------
    AssertionError
        If any computed corner is NaN.
    """
    # Corner construction adapted from
    # https://github.com/COSIMA/ocean-regrid/blob/master/nemo_grid.py
    mesh_url = (
        "https://icdc.cen.uni-hamburg.de/thredds/dodsC/ftpthredds/"
        "EASYInit/oras5/ORCA025/mesh/mesh_mask.nc"
    )
    mesh = xr.open_dataset(mesh_url, chunks={}).isel(t=0, z=0)
    n_rows, n_cols = mesh.tmask.shape[:2]

    def pad_south_west(f_points, southern_row):
        # f_points are the top right-hand corners of the t cells. One extra
        # row (south) and one extra column (west) are added so that every
        # cell has four corners available.
        padded = np.ndarray((f_points.shape[0] + 1, f_points.shape[1] + 1))
        padded[1:, 1:] = f_points[:]
        padded[0, 1:] = southern_row
        # The western column is a copy of the last (eastern) column.
        padded[:, 0] = padded[:, -1]
        return padded

    def cell_corners(padded):
        # Corners of t points. Index 0 is bottom left and then
        # anti-clockwise.
        corner_slices = (
            (slice(0, -1), slice(0, -1)),
            (slice(0, -1), slice(1, None)),
            (slice(1, None), slice(1, None)),
            (slice(1, None), slice(0, -1)),
        )
        corners = np.full((n_rows, n_cols, 4), np.nan)
        for index, (rows, cols) in enumerate(corner_slices):
            corners[:, :, index] = padded[rows, cols]
        assert not np.isnan(np.sum(corners))
        return corners

    # Extend south so that the southernmost cells can have bottom corners:
    # latitude is extrapolated by one row spacing, longitude is repeated.
    gphif = pad_south_west(
        mesh.gphif,
        mesh.gphif[0, :] - abs(mesh.gphif[1, :] - mesh.gphif[0, :]),
    )
    glamf = pad_south_west(mesh.glamf, mesh.glamf[0, :])

    clon = cell_corners(glamf)
    clat = cell_corners(gphif)

    # Point the coordinate variables at their bounds (modifies ds in place).
    for coord in ("latitude", "longitude"):
        ds[coord].attrs["bounds"] = f"{coord}_bounds"
    bound_dims = ["y", "x", "bound"]
    return ds.assign_coords(
        latitude_bounds=(bound_dims, clat),
        longitude_bounds=(bound_dims, clon),
    )


# Seasonal
def preprocess_seasonal(ds):
    """Prepare one seasonal-forecast dataset before it is combined with others.

    Every data variable that has a ``bnds`` dimension is turned into a
    coordinate, spaces are removed from the ``realization`` labels, and
    ``realization`` and ``reftime`` are added as dimensions.
    """
    # TODO: How to combine? Use first leadtime only for now
    # NOTE(review): no leadtime selection happens in this function.
    bounds_names = [
        name for name in ds.data_vars if "bnds" in ds.data_vars[name].dims
    ]
    ds = ds.set_coords(bounds_names)

    stripped = ds["realization"].str.replace(" ", "")
    ds["realization"] = stripped.astype(str)

    return ds.expand_dims(["realization", "reftime"])


def regrid_reanalysis(ds, grid_request, **xesmf_kwargs):
    """Regrid a reanalysis dataset onto the grid of a seasonal-forecast request.

    Parameters
    ----------
    ds : xarray.Dataset
        Reanalysis data containing ``sohtc300`` with a ``time`` dimension.
    grid_request : sequence
        Positional arguments forwarded to ``download.download_and_transform``
        to obtain the dataset that defines the output grid.
    **xesmf_kwargs
        Extra keyword arguments forwarded to ``diagnostics.regrid``.

    Returns
    -------
    The result of ``diagnostics.regrid`` for the masked input and output grids.
    """
    target = download.download_and_transform(
        *grid_request,
        preprocess=preprocess_seasonal,
    )

    # Output mask: valid points of the first member/start date/lead time.
    first_slice = dict.fromkeys(
        ("realization", "forecast_reference_time", "leadtime"), 0
    )
    target_field = target["thetaot300"].isel(first_slice).reset_coords(drop=True)
    mask_out = target_field.notnull()
    grid_out = target.cf[["latitude", "longitude"]].assign_coords(mask=mask_out)

    # Input mask: valid points of the first reanalysis time step.
    source_field = ds["sohtc300"].isel(time=0).reset_coords(drop=True)
    mask_in = source_field.notnull()
    ds = add_bounds(ds).assign_coords(mask=mask_in)

    return diagnostics.regrid(ds, grid_out, **xesmf_kwargs)


def compute_detrended_anomalies(ds, grid_request, **xesmf_kwargs):
    """Return linearly detrended anomalies of the single variable in ``ds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with exactly one data variable and exactly one of the
        dimensions ``time`` or ``forecast_reference_time``.
    grid_request : sequence or None
        When not None, ``ds`` is first regridded with ``regrid_reanalysis``.
    **xesmf_kwargs
        Forwarded to ``regrid_reanalysis``; must be empty when
        ``grid_request`` is None.

    Returns
    -------
    xarray.Dataset
        Dataset holding the anomalies under the original variable name.

    Raises
    ------
    AssertionError
        If regridding options are passed without a ``grid_request``.
    ValueError
        If ``ds`` does not contain exactly one data variable or exactly one
        recognised time dimension (raised by the tuple unpacking).
    """
    if grid_request is not None:
        ds = regrid_reanalysis(ds, grid_request, **xesmf_kwargs)
    else:
        assert not xesmf_kwargs, "xesmf_kwargs require a grid_request"

    (da,) = ds.data_vars.values()
    name = da.name
    with xr.set_options(keep_attrs=True):
        # 2.0: Calculating the ensemble mean for each lead time, and year
        ensemble_mean = da.mean("realization") if "realization" in da.dims else da

        # 3.0: Calculating a trend for each lead time based on 2.0: the ensemble means.
        (time_dim,) = set(ds.dims) & {"time", "forecast_reference_time"}
        linear_fit = ensemble_mean.polyfit(time_dim, deg=1)["polyfit_coefficients"]
        # Evaluate the fitted line at every time step. Subtracting only the
        # degree-1 coefficient (the slope) would remove a constant and leave
        # the trend in the data. The fit is centred in time so that its mean
        # is not removed here: the mean is removed once, by the climatology
        # in step 4.1.
        fitted = xr.polyval(da[time_dim], linear_fit)
        trend = fitted - fitted.mean(time_dim)

        # 3.1: Subtracting the lead-time specific trend for each ensemble member, year, and lead time
        detrended = da - trend

        # 4.0: Calculating the climatology based on 2.0: the ensemble mean for each lead time
        climatology = ensemble_mean.mean(time_dim)

        # 4.1: Subtracting the lead-time specific climatology for each ensemble member, year, and lead time
        da = detrended - climatology

    # Chunk hint for writing: one chunk per realization and lead time,
    # full extent along every other dimension.
    da.encoding["chunksizes"] = tuple(
        1 if dim in ("realization", "leadtime") else size
        for dim, size in da.sizes.items()
    )
    return da.to_dataset(name=name)


def reindex(ds):
    """Replace the ``time`` dimension of ``ds`` by ``year`` and ``month``.

    Integer ``year`` and ``month`` coordinates are derived from ``time``,
    combined into a multi-index on ``time`` and then unstacked, so the
    result is indexed by (year, month) instead of by timestamp.
    """
    time = ds["time"]
    calendar_coords = {
        part: ("time", getattr(time.dt, part).astype(int).values)
        for part in ("year", "month")
    }
    return (
        ds.assign_coords(calendar_coords)
        .set_index(time=("year", "month"))
        .unstack("time")
    )

## Download and transform

In [None]:
# Seasonal
# Download the forecasts and compute detrended anomalies on their own grid
# (grid_request=None means no regridding).
seasonal_kwargs = {
    "preprocess": preprocess_seasonal,
    "drop_variables": "hcrs",
    "transform_func": compute_detrended_anomalies,
    "transform_func_kwargs": {"grid_request": None},
    "transform_chunks": False,
}
ds_seasonal = download.download_and_transform(
    collection_id_seasonal, requests_seasonal, **seasonal_kwargs
)

# Re-index each selected lead time by (year, month) and merge them back.
selected_leadtimes = ds_seasonal.isel(leadtime=leadtimes)
datasets = []
for leadtime, ds in selected_leadtimes.groupby("leadtime"):
    ds = reindex(ds.squeeze("leadtime").swap_dims(forecast_reference_time="time"))
    # Give the lead time a "month" dimension so it survives the combine step.
    ds["leadtime"] = ds["leadtime"].expand_dims("month")
    datasets.append(ds)
ds_seasonal = xr.combine_by_coords(datasets)

In [None]:
# Reanalysis
# Regrid onto the grid of the first seasonal request, then compute the
# detrended anomalies; the remaining options are passed on to the regridder.
reanalysis_transform_kwargs = {
    "grid_request": (collection_id_seasonal, requests_seasonal[0]),
    "method": "conservative_normed",
    "periodic": True,
    "ignore_degenerate": True,
}

# One download per lead time, each re-indexed by (year, month).
datasets = []
for leadtime, requests in requests_reanalysis.items():
    ds = download.download_and_transform(
        collection_id_reanalysis,
        requests,
        transform_func=compute_detrended_anomalies,
        drop_variables="time_counter_bnds",
        transform_func_kwargs=reanalysis_transform_kwargs,
        transform_chunks=False,
    )
    ds = reindex(ds)
    datasets.append(ds)
ds_reanalysis = xr.combine_by_coords(datasets)

## Quick and dirty plots

## Reanalysis

In [None]:
# The reanalysis dataset holds a single variable.
(da,) = ds_reanalysis.data_vars.values()

# Maps of the mean over all years, one panel per month.
da_year_mean = da.mean("year", keep_attrs=True)
_ = plot.projected_map(da_year_mean, col="month")
plt.show()

# Spatially weighted mean, one line per month.
da_spatial_mean = diagnostics.spatial_weighted_mean(da)
da_spatial_mean.plot(hue="month")
plt.grid()

## Seasonal Forecast

In [None]:
# Average over the ensemble members, then take the single variable.
ds_ensemble_mean = ds_seasonal.mean("realization")
(da,) = ds_ensemble_mean.data_vars.values()

# Maps of the mean over all years, one panel per month.
da_year_mean = da.mean("year", keep_attrs=True)
_ = plot.projected_map(da_year_mean, col="month")
plt.show()

# Spatially weighted mean against year, one line per month.
da_spatial_mean = diagnostics.spatial_weighted_mean(da)
da_spatial_mean.plot(hue="month", x="year")
plt.grid()