# Bias of seasonal forecast

## Import packages

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Apply the "seaborn-v0_8-notebook" matplotlib style to every figure drawn below.
plt.style.use(style="seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time
# NOTE(review): year_forecast is not referenced by any cell shown here;
# confirm whether it is still needed.
year_forecast = 2023
year_start_hindcast = 1993
year_stop_hindcast = 2016  # inclusive: requests use range(start, stop + 1)

# Variables
variables = ["2m_temperature", "total_precipitation"]

# Centres: originating centre -> extra request keys (each must set "system")
centres = {
    "cmcc": {"system": "35"},
    "dwd": {"system": "21"},
    "eccc": {"system": "3"},
    "ecmwf": {"system": "51"},
    "jma": {"system": "3"},
    "meteo_france": {"system": "8"},
    "ncep": {"system": "2"},
    "ukmo": {"system": "601"},
}
_supported_centres = {
    "cmcc",
    "dwd",
    "eccc",
    "ecmwf",
    "jma",
    "meteo_france",
    "ncep",
    "ukmo",
}

# Regions: name -> lon/lat slices handed to the regionalisation transform
regions = {
    "Australia": {"lon_slice": slice(101, 179), "lat_slice": slice(5, -52)},
    "East Asia": {"lon_slice": slice(67, 150), "lat_slice": slice(61, 5)},
    "Europe": {"lon_slice": slice(-22, 44), "lat_slice": slice(72, 27)},
    "North America": {"lon_slice": slice(-142, -53), "lat_slice": slice(76, 17)},
    "Sahel": {"lon_slice": slice(-22, 54), "lat_slice": slice(27, 0)},
    "South America": {"lon_slice": slice(-95, -27), "lat_slice": slice(17, -57)},
    "Southern Africa": {"lon_slice": slice(0, 54), "lat_slice": slice(0, -45)},
}

# Validate with explicit exceptions rather than ``assert`` so the checks
# survive ``python -O`` and report which entries are wrong.
if year_start_hindcast > year_stop_hindcast:
    raise ValueError(
        f"Empty hindcast period: {year_start_hindcast=} > {year_stop_hindcast=}"
    )
_missing_system = sorted(
    name for name, kwargs in centres.items() if "system" not in kwargs
)
if _missing_system:
    raise ValueError(f"Centres without a 'system': {_missing_system}")
_unsupported = sorted(set(centres) - _supported_centres)
if _unsupported:
    raise ValueError(f"Unsupported centres: {_unsupported}")
_badly_oriented = sorted(
    name
    for name, kwargs in regions.items()
    # slice(W, E) and slice(N, S) for best performance
    if not (
        kwargs["lon_slice"].start <= kwargs["lon_slice"].stop
        and kwargs["lat_slice"].start >= kwargs["lat_slice"].stop
    )
)
if _badly_oriented:
    raise ValueError(
        "Regions must use slice(W, E) for lon and slice(N, S) for lat: "
        f"{_badly_oriented}"
    )

## Define requests

In [None]:
# Passed to download_and_transform as ``chunks``; presumably splits the
# requests per year and per variable.
chunks = {"year": 1, "variable": 1}

# Fields shared by the ERA5 and seasonal requests: all variables, a global
# area, and every hindcast year with the stop year included.
hindcast_years = list(map(str, range(year_start_hindcast, year_stop_hindcast + 1)))
common_request = {
    "format": "grib",
    "variable": variables,
    "area": [89.5, -179.5, -89.5, 179.5],
    "year": hindcast_years,
}
calendar_months = [f"{m:02d}" for m in range(1, 13)]

collection_id_reanalysis = "reanalysis-era5-single-levels-monthly-means"
request_reanalysis = {
    **common_request,
    "product_type": "monthly_averaged_reanalysis",
    "month": list(calendar_months),
    "time": "00:00",
}

collection_id_seasonal = "seasonal-monthly-single-levels"
request_seasonal = {
    **common_request,
    "product_type": "monthly_mean",
    "leadtime_month": [str(lead) for lead in range(1, 7)],
    "month": list(calendar_months),
    "grid": "1/1",
}

## Functions to cache

In [None]:
def regionalised_spatial_weighted_mean(
    ds, lon_slice, lat_slice, weights, mean_dims=None
):
    """Cut a region out of ``ds`` and reduce it to a spatial mean.

    Parameters
    ----------
    ds : dataset to process.
    lon_slice, lat_slice : slices forwarded to ``utils.regionalise``.
    weights : forwarded to ``diagnostics.spatial_weighted_mean``.
    mean_dims : optional dimension(s) to average over afterwards (attributes
        kept); skipped when falsy.

    Returns
    -------
    The regional spatial mean, additionally averaged over ``mean_dims`` if given.
    """
    regional = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    reduced = diagnostics.spatial_weighted_mean(regional, weights=weights)
    if not mean_dims:
        return reduced
    return reduced.mean(mean_dims, keep_attrs=True)

## Download and transform ERA5

In [None]:
def preprocess(ds):
    """Return ``ds`` with a new ``leadtime`` dimension added via ``expand_dims``.

    NOTE(review): not referenced by any cell shown in this notebook; confirm
    whether it is still needed.
    """
    expanded = ds.expand_dims("leadtime")
    return expanded


# Get the reanalysis data: one regional spatial mean per region, stacked
# along a new "region" dimension.
datasets = []
for region, transform_func_kwargs in regions.items():
    print(f"{region=}")
    reanalysis_transform_kwargs = {**transform_func_kwargs, "weights": False}
    ds = download.download_and_transform(
        collection_id_reanalysis,
        request_reanalysis,
        chunks=chunks,
        transform_func=regionalised_spatial_weighted_mean,
        transform_func_kwargs=reanalysis_transform_kwargs,
        combine="nested",
        cached_open_mfdataset_kwargs={"drop_variables": ["leadtime"]},
    )
    datasets.append(ds.expand_dims(region=[region]).compute())
ds_reanalysis = xr.concat(datasets, "region")

## Download and transform seasonal forecast

In [None]:
# Get the seasonal forecast data: for every centre and region, the regional
# spatial mean averaged over "realization", tagged with "centre" and "region"
# and merged into one dataset.
# These centres are decoded with "indexing_time" instead of "time" as the
# second GRIB time dimension.
centres_with_indexing_time = {"ukmo", "jma", "ncep"}
datasets = []
for centre, request_kwargs in centres.items():
    # Depends only on the centre, so compute it once per centre.
    time_dims = (
        "verifying_time",
        "indexing_time" if centre in centres_with_indexing_time else "time",
    )
    for region, transform_func_kwargs in regions.items():
        print(f"{centre=} {region=}")
        ds = download.download_and_transform(
            collection_id_seasonal,
            request_seasonal | {"originating_centre": centre} | request_kwargs,
            chunks=chunks,
            transform_func=regionalised_spatial_weighted_mean,
            transform_func_kwargs=transform_func_kwargs
            | {"weights": False, "mean_dims": ("realization",)},
            backend_kwargs={"time_dims": time_dims},
        )
        # If a plain "time" dimension is still present, rename it so every
        # dataset shares the "verifying_time" dimension before merging.
        if "time" in ds.dims:
            ds = ds.rename(time="verifying_time")
        datasets.append(ds.expand_dims(centre=[centre], region=[region]).compute())
ds_seasonal = xr.merge(datasets)

## Convert units

In [None]:
# Precipitation conversion factors.
m_to_mm = 1.0e3
day_to_s = 60 * 60 * 24  # seconds per day

with xr.set_options(keep_attrs=True):
    if "tp" in ds_reanalysis:
        # Scaled by days-per-month with no seconds factor, so ERA5 "tp" is
        # treated as a per-day amount in metres (assumed m/day - confirm).
        # The request is for monthly averaged reanalysis, so
        # forecast_reference_time identifies the averaged month itself.
        ds_reanalysis["tp"] *= (
            m_to_mm * ds_reanalysis["forecast_reference_time"].dt.days_in_month
        )
        ds_reanalysis = ds_reanalysis.rename(tp="tp_month")
        ds_reanalysis["tp_month"].attrs["units"] = "mm/month"

    if "tprate" in ds_seasonal:
        # tprate is scaled by seconds-per-day, i.e. treated as a rate in m/s.
        # The month it averages over is the verifying month, so the day count
        # must come from verifying_time. forecast_reference_time is the start
        # date and differs from the verified month for lead times > 1.
        ds_seasonal["tprate"] *= (
            m_to_mm * day_to_s * ds_seasonal["verifying_time"].dt.days_in_month
        )
        ds_seasonal = ds_seasonal.rename(tprate="tp_month")
        ds_seasonal["tp_month"].attrs["units"] = "mm/month"

## Monthly climatologies

In [None]:
# Average the reanalysis over all years for each calendar month; the month
# becomes the "valid_month" dimension.
ds_reanalysis = (
    ds_reanalysis.groupby("forecast_reference_time.month")
    .mean(keep_attrs=True)
    .rename(month="valid_month")
)

# Collapse each seasonal time axis to its calendar month in turn: start dates
# become "starting_month", verifying dates become "valid_month".
month_dim_names = {"forecast_reference": "starting", "verifying": "valid"}
for time_prefix, month_prefix in month_dim_names.items():
    grouped = ds_seasonal.groupby(f"{time_prefix}_time.month")
    ds_seasonal = grouped.mean(keep_attrs=True).rename(
        month=f"{month_prefix}_month"
    )

## Compute bias

In [None]:
# Bias = seasonal forecast minus reanalysis, keeping the forecast attributes.
with xr.set_options(keep_attrs=True):
    bias = ds_seasonal - ds_reanalysis
# Prefix each long name so plot titles say what is shown; fall back to the
# variable name when no "long_name" attribute is present (avoids KeyError).
for varname, da in bias.data_vars.items():
    da.attrs["long_name"] = "Bias of " + da.attrs.get("long_name", varname)
# Avoid ticks interpolation
bias = bias.assign_coords(
    {name: coord.astype(str) for name, coord in bias.coords.items()}
)

## Plot all biases together

In [None]:
# One facet grid per variable: regions as rows, centres as columns, valid
# month on the x axis.
for varname, da in bias.data_vars.items():
    plot_kwargs = {}
    if varname == "t2m":
        plot_kwargs["cmap"] = "RdBu_r"
        # A temperature difference has the same magnitude in K and °C.
        da.attrs["units"] = "°C"
    elif varname == "tp_month":
        plot_kwargs["cmap"] = "BrBG"
    facet = da.plot(col="centre", row="region", x="valid_month", **plot_kwargs)
    title = f"Seasonal forecast: {da.attrs['long_name']}"
    _ = facet.fig.suptitle(title, y=1.01)