# Bias of seasonal forecast

## Import packages

In [None]:
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils

# Apply matplotlib's "seaborn-v0_8-notebook" style sheet to the figures created later
plt.style.use("seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time
# First and last hindcast years (both are included in the requests)
year_start_hindcast, year_stop_hindcast = 1993, 2016
# Month used as "month" in the seasonal forecast request
reference_month = 9
# Forecast year
year_forecast = 2023

# Centres
# Originating centres to process, each mapped to the extra request parameters
# (the forecast "system") merged into the seasonal request for that centre.
centres = {
    "cmcc": {"system": "35"},
    "dwd": {"system": "21"},
    "eccc": {"system": "3"},
    "ecmwf": {"system": "51"},
    "jma": {"system": "3"},
    "ncep": {"system": "2"},
    "ukmo": {"system": "601"},
}
# Every centre must declare which system to request
assert all(
    "system" in v for v in centres.values()
), "each centre needs a 'system' entry"
# Only these originating centres are accepted
assert set(centres) <= {
    "cmcc",
    "dwd",
    "eccc",
    "ecmwf",
    "jma",
    "meteo_france",
    "ncep",
    "ukmo",
}, "unknown originating centre"

# Regions
# Each region maps to the longitude and latitude slices that delimit it
regions = {
    "Australia": {
        "lon_slice": slice(101, 179),
        "lat_slice": slice(5, -52),
    },
    "East Asia": {
        "lon_slice": slice(67, 150),
        "lat_slice": slice(61, 5),
    },
    "Europe": {
        "lon_slice": slice(-22, 44),
        "lat_slice": slice(72, 27),
    },
    "North America": {
        "lon_slice": slice(-142, -53),
        "lat_slice": slice(76, 17),
    },
    "Sahel": {
        "lon_slice": slice(-22, 54),
        "lat_slice": slice(27, 0),
    },
    "South America": {
        "lon_slice": slice(-95, -27),
        "lat_slice": slice(17, -57),
    },
    "Southern Africa": {
        "lon_slice": slice(0, 54),
        "lat_slice": slice(0, -45),
    },
}
# slice(W, E) for best performance
assert all(
    box["lon_slice"].start <= box["lon_slice"].stop for box in regions.values()
)
# slice(N, S) for best performance
assert all(
    box["lat_slice"].start >= box["lat_slice"].stop for box in regions.values()
)

## Define requests

In [None]:
# Chunking passed to download_and_transform
# NOTE(review): presumably one chunk per requested year - confirm
chunks = {"year": 1}

# Settings shared by the reanalysis and the seasonal forecast requests;
# the years cover the whole hindcast period, end year included
common_request = {
    "format": "grib",
    "variable": "2m_temperature",
    "area": [89.5, -179.5, -89.5, 179.5],
    "year": list(map(str, range(year_start_hindcast, year_stop_hindcast + 1))),
}

# Reanalysis: monthly averages for all twelve calendar months
collection_id_reanalysis = "reanalysis-era5-single-levels-monthly-means"
request_reanalysis = {
    **common_request,
    "product_type": "monthly_averaged_reanalysis",
    "month": [format(month, "02d") for month in range(1, 13)],
    "time": "00:00",
}

# Seasonal forecast: monthly means for lead times 1 to 6 of the reference month
collection_id_seasonal = "seasonal-monthly-single-levels"
request_seasonal = {
    **common_request,
    "product_type": "monthly_mean",
    "leadtime_month": [str(leadtime) for leadtime in range(1, 7)],
    "month": format(reference_month, "02d"),
    "grid": "1/1",
}

## Functions to cache

In [None]:
def regionalised_spatial_weighted_mean(ds, lon_slice, lat_slice, weights):
    """Regionalise ``ds`` and reduce it with a spatial weighted mean.

    Parameters
    ----------
    ds
        Object to process; passed to ``utils.regionalise``.
    lon_slice, lat_slice
        Longitude and latitude selections forwarded to ``utils.regionalise``.
    weights
        Forwarded as ``weights`` to ``diagnostics.spatial_weighted_mean``.

    Returns
    -------
    The output of ``diagnostics.spatial_weighted_mean`` applied to the
    regionalised object.
    """
    return diagnostics.spatial_weighted_mean(
        utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice),
        weights=weights,
    )

## Download and transform ERA5

In [None]:
# Get the reanalysis data: one download/transform per region
datasets = []
for region, region_slices in regions.items():
    print(f"{region=}")
    ds = download.download_and_transform(
        collection_id_reanalysis,
        request_reanalysis,
        chunks=chunks,
        transform_func=regionalised_spatial_weighted_mean,
        transform_func_kwargs={**region_slices, "weights": False},
        combine="nested",
    )
    # Tag the result with its region and compute it before storing
    ds = ds.expand_dims(region=[region])
    datasets.append(ds.compute())
# Stack the regions, then average over each calendar month
ds_reanalysis = (
    xr.concat(datasets, "region").groupby("forecast_reference_time.month").mean()
)

## Download and transform seasonal forecast

In [None]:
# Get the seasonal forecast data: one download/transform per centre and region
datasets = []
for centre, request_kwargs in centres.items():
    # Some centres use "indexing_time" instead of "time" as second time dimension
    time_dim = "indexing_time" if centre in ["ukmo", "jma", "ncep"] else "time"
    for region, region_slices in regions.items():
        print(f"{centre=} {region=}")
        ds = download.download_and_transform(
            collection_id_seasonal,
            {
                **request_seasonal,
                "originating_centre": centre,
                **request_kwargs,
            },
            chunks=chunks,
            transform_func=regionalised_spatial_weighted_mean,
            transform_func_kwargs={**region_slices, "weights": False},
            backend_kwargs={"time_dims": ("forecastMonth", time_dim)},
            cached_open_mfdataset_kwargs={
                "combine": "nested",
                "concat_dim": "forecast_reference_time",
            },
        )
        # Average over realizations and forecast reference times
        ds = ds.mean(["realization", "forecast_reference_time"], keep_attrs=True)
        # Tag the result with its centre and region and compute it before storing
        ds = ds.expand_dims(centre=[centre], region=[region])
        datasets.append(ds.compute())
ds_seasonal = xr.merge(datasets)

## Compute bias

In [None]:
# Forecast minus reanalysis for "t2m", with attributes kept through the subtraction
with xr.set_options(keep_attrs=True):
    bias = (ds_seasonal - ds_reanalysis)["t2m"]
bias.attrs.update(long_name="Bias of " + bias.attrs["long_name"])
# Avoid ticks interpolation
bias = bias.assign_coords(
    {name: bias.coords[name].astype(str) for name in bias.coords}
)

## Plot all biases together

In [None]:
# Grid of panels: centres along the rows, regions along the columns
facet = bias.plot(row="centre", col="region")
# Assign to "_" so the notebook does not echo the returned title object
_ = facet.fig.suptitle("Seasonal forecast: " + bias.attrs["long_name"], y=1.01)

## Plot regional biases

In [None]:
# One figure per region, one panel per centre, wrapped after four columns
for region, regional_bias in bias.groupby("region"):
    facet = regional_bias.plot(col="centre", col_wrap=4)
    facet.fig.suptitle(
        f"Seasonal forecast: {bias.attrs['long_name']}\nArea: {region}",
        y=1.05,
    )