# Seasonal forecast monthly averages of ocean variables

## Imports

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Apply matplotlib's bundled "seaborn-v0_8-notebook" style to every figure below.
plt.style.use("seaborn-v0_8-notebook")

In [None]:
import os

import cacholote
import dask.distributed

# CDS API credentials: point the CDS client at a specific .cdsapirc file.
# NOTE(review): this path is specific to one user's home layout — adjust before reuse.
os.environ["CDSAPI_RC"] = os.path.expanduser("~/almansi_mattia/.cdsapirc")

# Connection timeout for the cacholote cache database, forwarded to the engine's
# ``connect_args`` (presumably seconds for an SQLite backend — verify).
cacholote.config.set(create_engine_kwargs=dict(connect_args={"timeout": 30}))
# Parallelism setting read by ``download`` (presumably concurrent download jobs).
download.N_JOBS = 10

# Local Dask cluster: split a 28 GiB memory budget evenly across the workers.
n_workers = 2
memory_limit = f"{28 // n_workers}GiB"
# Keep a strong reference to the client. NOTE(review): dask's default-client
# registry appears to hold clients only weakly, so an unreferenced Client can be
# garbage-collected (shutting the cluster down) when not run inside IPython.
client = dask.distributed.Client(n_workers=n_workers, memory_limit=memory_limit)
client  # show the client summary as the cell output

## Define parameters

In [None]:
# Time window as "YYYY-MM" strings; both endpoints are included when the monthly
# requests are built with ``pd.date_range(start, stop, freq="1MS")``.
start, stop = "2023-03", "2025-01"

## Define requests

In [None]:
collection_id_reanalysis = "reanalysis-oras5"

# One ORAS5 request per month-start in [start, stop]; "year" and "month" are
# zero-padded strings taken from each month-start timestamp.
requests_reanalysis = [
    {
        "product_type": ["operational"],
        "vertical_resolution": "single_level",
        "variable": ["ocean_heat_content_for_the_upper_300m"],
        "year": f"{month_start:%Y}",
        "month": f"{month_start:%m}",
    }
    for month_start in pd.date_range(start, stop, freq="1MS")
]

collection_id_seasonal = "seasonal-monthly-ocean"
# Same monthly window for the CMCC system 35 seasonal forecasts.
requests_seasonal = [
    {
        "originating_centre": "cmcc",
        "system": "35",
        "variable": ["depth_average_potential_temperature_of_upper_300m"],
        "forecast_type": ["forecast"],
        "year": f"{month_start:%Y}",
        "month": f"{month_start:%m}",
    }
    for month_start in pd.date_range(start, stop, freq="1MS")
]

## Functions to cache

In [None]:
def _pad_f_points(f_points, south_row):
    """Return the 2-D (y, x) array ``f_points`` padded by one row and one column.

    ``f_points`` holds F-point values (the top right-hand corner of each T cell).
    Row 0 of the result is ``south_row``, so the southernmost T cells get bottom
    corners. Column 0 repeats the last column, so the westernmost T cells get left
    corners (x is treated as periodic).
    """
    ny, nx = f_points.shape
    padded = np.empty((ny + 1, nx + 1))
    padded[1:, 1:] = f_points
    padded[0, 1:] = south_row
    # Done after the south row so the new corner point wraps as well.
    padded[:, 0] = padded[:, -1]
    return padded


def _t_cell_corners(f_padded):
    """Return the (ny, nx, 4) T-cell corners from padded F points.

    Corner 0 is bottom-left, then anti-clockwise: bottom-right, top-right, top-left.
    """
    corners = np.stack(
        [
            f_padded[:-1, :-1],  # bottom-left
            f_padded[:-1, 1:],  # bottom-right
            f_padded[1:, 1:],  # top-right
            f_padded[1:, :-1],  # top-left
        ],
        axis=-1,
    )
    assert not np.isnan(corners).any(), "NaN found in T-cell corners"
    return corners


def add_bounds(ds):
    """Attach ORCA025 T-cell corner coordinates to ``ds`` as CF-style bounds.

    Corner longitudes/latitudes are derived from the F points (``glamf``/``gphif``)
    of the ORAS5 ORCA025 ``mesh_mask.nc`` on the ICDC THREDDS server, following
    https://github.com/COSIMA/ocean-regrid/blob/master/nemo_grid.py.

    NOTE(review): assumes ``ds`` is on the same ORCA025 grid, with horizontal dims
    named ``y``/``x`` and ``latitude``/``longitude`` variables — confirm.

    Returns
    -------
    xarray.Dataset
        A shallow copy of ``ds`` with these additions:

        - ``latitude_bounds``/``longitude_bounds`` coordinates of dims
          (y, x, bound), holding four corners per cell (bottom-left first, then
          anti-clockwise);
        - ``bounds`` attributes on ``latitude``/``longitude`` pointing at them.

        The input dataset is not modified.
    """
    url = (
        "https://icdc.cen.uni-hamburg.de/thredds/dodsC/ftpthredds/"
        "EASYInit/oras5/ORCA025/mesh/mesh_mask.nc"
    )
    # Load the F points eagerly inside the ``with`` so the remote handle is closed.
    with xr.open_dataset(url, chunks={}) as mesh_file:
        mesh = mesh_file.isel(t=0, z=0)
        glamf = mesh["glamf"].values  # F-point longitudes
        gphif = mesh["gphif"].values  # F-point latitudes

    # Extend south: latitudes mirror the spacing of the first two rows, while
    # longitudes simply repeat the first row.
    gphif = _pad_f_points(gphif, gphif[0] - np.abs(gphif[1] - gphif[0]))
    glamf = _pad_f_points(glamf, glamf[0])

    clon = _t_cell_corners(glamf)
    clat = _t_cell_corners(gphif)

    # Shallow copy (no data copied) so the attrs edits below don't leak into the caller.
    ds = ds.copy()
    ds["latitude"].attrs["bounds"] = "latitude_bounds"
    ds["longitude"].attrs["bounds"] = "longitude_bounds"
    return ds.assign_coords(
        latitude_bounds=(["y", "x", "bound"], clat),
        longitude_bounds=(["y", "x", "bound"], clon),
    )


# Seasonal forecast
def preprocess_seasonal(ds):
    """Per-file preprocessing passed as ``preprocess=`` to ``download.download_and_transform``.

    Steps:

    - promote every data variable that has a ``bnds`` dimension to a coordinate;
    - keep only the first lead time (as a length-1 dimension) and swap it for ``time``;
    - strip spaces from the ``realization`` labels and add ``realization`` as a
      new leading dimension.
    """
    # TODO: How to combine lead times? Only the first one is used for now.
    bounds_names = [name for name in ds.data_vars if "bnds" in ds[name].dims]
    ds = ds.set_coords(bounds_names)

    first_lead = ds.isel(leadtime=[0])
    ds = first_lead.swap_dims(leadtime="time")

    labels = ds["realization"].str.replace(" ", "")
    ds["realization"] = labels.astype(str)
    return ds.expand_dims("realization")


def regrid_reanalysis(ds, grid_request, **xesmf_kwargs):
    """Regrid one reanalysis dataset onto the seasonal-forecast grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Reanalysis dataset. ``add_bounds`` attaches its ORCA025 cell corners
        before regridding.
    grid_request : tuple
        ``(collection_id, request)`` for a seasonal-forecast download. Its
        latitude/longitude define the target grid, and its ``thetaot300`` field
        defines the target ocean mask.
    **xesmf_kwargs
        Forwarded to ``diagnostics.regrid`` (e.g. ``method``, ``periodic``).

    Returns
    -------
    The result of ``diagnostics.regrid``.
    """
    with warnings.catch_warnings():
        # Suppress decode_timedelta warning
        warnings.filterwarnings("ignore", message=".*decode_timedelta.*")
        ds_seasonal = download.download_and_transform(
            *grid_request,
            preprocess=preprocess_seasonal,
            drop_variables="hcrs",  # consistent with the main seasonal download
        )
    field = ds_seasonal["thetaot300"].isel(time=0, realization=0)
    # xESMF expects mask values of 1 (valid) / 0 (masked). The raw field is NaN
    # over land and would not mark land as masked, so convert it explicitly.
    mask = field.notnull().reset_coords(drop=True)
    grid_out = ds_seasonal.cf[["latitude", "longitude"]].assign_coords(mask=mask)
    return diagnostics.regrid(add_bounds(ds), grid_out, **xesmf_kwargs)

## Download and transform

In [None]:
with warnings.catch_warnings():
    # Silence the decode_timedelta warning emitted while loading the seasonal files;
    # the filter is scoped to this ``with`` block only.
    warnings.filterwarnings(action="ignore", message=".*decode_timedelta.*")
    ds_seasonal = download.download_and_transform(
        collection_id_seasonal,
        requests_seasonal,
        preprocess=preprocess_seasonal,
        drop_variables="hcrs",
    )

In [None]:
# Reanalysis: each ORAS5 download is regridded onto the grid of the first
# seasonal-forecast request using "conservative_normed" remapping.
ds_reanalysis = download.download_and_transform(
    collection_id_reanalysis,
    requests_reanalysis,
    transform_func=regrid_reanalysis,
    transform_func_kwargs=dict(
        grid_request=(collection_id_seasonal, requests_seasonal[0]),
        method="conservative_normed",
        periodic=True,
        ignore_degenerate=True,
    ),
)

## Quick-look maps of the first month

In [None]:
# One quick-look map per product for the first month of the requested window
# (``start``), averaged over every non-spatial dimension.
for ds, title in zip((ds_reanalysis, ds_seasonal), ("Reanalysis", "Seasonal Forecast")):
    # Each dataset is expected to hold exactly one data variable; unpacking fails otherwise.
    (da,) = ds.sel(time=start).data_vars.values()
    # Ordered list (not a set) so the reduction order is deterministic.
    da = da.mean([dim for dim in da.dims if dim not in {"latitude", "longitude"}])
    plot.projected_map(da, robust=True)
    plt.title(title)
    plt.show()