# Seasonal forecast/hindcast: Scale Awareness

## Import libraries

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply matplotlib's bundled "seaborn-v0_8-notebook" style to all figures created below.
plt.style.use("seaborn-v0_8-notebook")

## Set parameters

In [None]:
# CDS variable name, used in both the ERA5 and the seasonal requests
var_api = "2m_temperature"

# First and last year to download; both ends are included in the request
year_start, year_stop = 1993, 2024

## Define requests

In [None]:
# CDS dataset identifiers
collection_id_reanalysis = "reanalysis-era5-single-levels-monthly-means"
collection_id_seasonal = "seasonal-monthly-single-levels"

# Fields shared by both requests: GRIB output on a "1/1" grid over the area
# below, for every month (01-12) of every year from year_start to year_stop
# inclusive.
common_request = {
    "format": "grib",
    "area": [89.5, -179.5, -89.5, 179.5],
    "variable": var_api,
    "grid": "1/1",
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [str(month).zfill(2) for month in range(1, 13)],
}

# ERA5 monthly averaged reanalysis
request_reanalysis = {
    **common_request,
    "product_type": "monthly_averaged_reanalysis",
    "time": "00:00",
}

# ECMWF system 51 monthly means; "leadtime_month" is added per download later
request_seasonal = {
    **common_request,
    "product_type": "monthly_mean",
    "system": 51,
    "originating_centre": "ecmwf",
}

## Define functions to cache

In [None]:
def compute_anomaly(obj):
    """Return ``obj`` minus its climatology.

    The climatology is ``diagnostics.time_weighted_mean(obj, weights=False)``
    (presumably a plain, unweighted mean over time — confirm in diagnostics).
    If ``obj`` has a ``realization`` dimension, the climatology is also
    averaged over it. Every ensemble member is therefore compared with the
    same ensemble-mean climatology.
    """
    # A set containing "realization" only when obj actually has that dimension;
    # reducing over an empty set leaves the climatology unchanged.
    ensemble_dims = {"realization"}.intersection(obj.dims)
    reference = diagnostics.time_weighted_mean(obj, weights=False)
    reference = reference.mean(ensemble_dims)
    return obj - reference


def detrend(obj):
    """Remove a degree-1 polynomial fit along ``time`` from ``obj``.

    The straight line is fitted with ``obj.polyfit("time", deg=1)``. It is
    evaluated on the same ``time`` coordinate with ``xr.polyval`` and then
    subtracted.
    """
    fit = obj.polyfit("time", deg=1)
    baseline = xr.polyval(obj["time"], fit.polyfit_coefficients)
    return obj - baseline


def compute_monthly_anomalies(ds):
    """Return raw and detrended monthly anomalies stacked along ``detrend``.

    ``ds`` must hold exactly one data variable; otherwise the unpacking raises
    ``ValueError``. Steps:

    1. ``compute_anomaly`` is applied separately to each calendar month
       (``time.month``), so each month gets its own climatology.
    2. ``detrend`` is applied separately to each calendar year (``time.year``).
       NOTE(review): this fits a line through the months of a single year,
       not across years — confirm that is the intended time scale.
    3. Both results are concatenated along a new ``detrend`` dimension:
       ``False`` for the anomalies, ``True`` for the detrended anomalies.

    Attributes are kept through the two groupby steps (``keep_attrs=True``).
    The ``chunksizes`` encoding uses 1 for ``realization``/``detrend`` and
    the full length for every other dimension.
    """
    (variable,) = ds.data_vars.values()

    with xr.set_options(keep_attrs=True):
        monthly_anomaly = variable.groupby("time.month").map(compute_anomaly)
        detrended_anomaly = monthly_anomaly.groupby("time.year").map(detrend)

    flagged = [
        monthly_anomaly.expand_dims(detrend=[False]),
        detrended_anomaly.expand_dims(detrend=[True]),
    ]
    combined = xr.concat(flagged, "detrend")

    singleton_dims = ("realization", "detrend")
    per_dim_chunk = {
        name: (1 if name in singleton_dims else length)
        for name, length in combined.sizes.items()
    }
    combined.encoding["chunksizes"] = tuple(per_dim_chunk.values())
    return combined.to_dataset()

## Download and transform

In [None]:
# Options shared by every download_and_transform call in this cell.
kwargs = {
    # presumably splits the request into one-year pieces — confirm in c3s_eqc docs
    "chunks": {"year": 1},
    "n_jobs": 1,
    # passed to the GRIB backend; presumably indexes data by valid time — TODO confirm
    "backend_kwargs": {"time_dims": ["valid_time"]},
    "transform_func": compute_monthly_anomalies,
    # NOTE(review): presumably applies transform_func to the merged dataset rather than
    # per chunk, which the all-years monthly climatology needs — confirm in c3s_eqc docs.
    "transform_chunks": False,
}

# Reanalysis
# A single download; the result holds one data variable, which is unpacked.
(da_reanalysis,) = download.download_and_transform(
    collection_id_reanalysis,
    request_reanalysis,
    **kwargs,
).data_vars.values()

# Seasonal forecast
# One download per lead month 1-6, each tagged with a new ``leadtime_month``
# dimension and then concatenated along it.
dataarrays = []
for leadtime_month in range(1, 7):
    print(f"{leadtime_month = }")  # progress indicator
    (da,) = download.download_and_transform(
        collection_id_seasonal,
        request_seasonal | {"leadtime_month": leadtime_month},
        **kwargs,
    ).data_vars.values()
    dataarrays.append(da.expand_dims(leadtime_month=[leadtime_month]))
da_seasonal = xr.concat(dataarrays, "leadtime_month")

## Plot ERA5

In [None]:
# Spatially weighted mean of the ERA5 anomalies (diagnostics.spatial_weighted_mean),
# plotted as one line per ``detrend`` value.
da_reanalysis_mean = diagnostics.spatial_weighted_mean(da_reanalysis)
da_reanalysis_mean.plot(hue="detrend")
plt.title("ERA5")
plt.grid()  # with no arguments this toggles grid-line visibility on the current axes

## Plot seasonal forecast/hindcast

In [None]:
# Ensemble mean over realizations, loaded into memory with .compute(), then
# spatially weighted mean (diagnostics.spatial_weighted_mean).
da_seasonal_mean = diagnostics.spatial_weighted_mean(
    da_seasonal.mean("realization", keep_attrs=True).compute()
)
# One panel per lead month (three panels per row), one line per ``detrend`` value.
facet = da_seasonal_mean.plot(col="leadtime_month", col_wrap=3, hue="detrend")
for ax in facet.axs.flatten():
    ax.grid()  # with no arguments this toggles grid-line visibility on each panel