# Hit rate of seasonal forecasts

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import regionmask
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply matplotlib's "seaborn-v0_8-notebook" style sheet to all figures created below.
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time
year_forecast = 2023
year_start_hindcast = 1993
year_stop_hindcast = 2016
if year_start_hindcast > year_stop_hindcast:
    raise ValueError(
        f"Hindcast start year ({year_start_hindcast}) is after "
        f"the stop year ({year_stop_hindcast})."
    )

# Variables
variables = [
    "2m_temperature",
    # "total_precipitation",
    # "surface_solar_radiation_downwards",
    # "10m_wind_speed",
    # "2m_dewpoint_temperature",
]

# Define centres with missing variables
missing_variables = {"eccc": ["2m_dewpoint_temperature"]}

# Centres: each entry must specify the forecasting "system" to request.
centres = {
    "cmcc": {"system": "35"},
    # "dwd": {"system": "21"},
    # "eccc": {"system": "3"},
    # "ecmwf": {"system": "51"},
    # "jma": {"system": "3"},
    # "meteo_france": {"system": "8"},
    # "ncep": {"system": "2"},
    # "ukmo": {"system": "601"},
}
# Explicit checks (not `assert`) so validation still runs under `python -O`.
_centres_without_system = sorted(
    name for name, settings in centres.items() if "system" not in settings
)
if _centres_without_system:
    raise ValueError(f"Centres missing a 'system': {_centres_without_system}")
_unknown_centres = sorted(
    set(centres)
    - {
        "cmcc",
        "dwd",
        "eccc",
        "ecmwf",
        "jma",
        "meteo_france",
        "ncep",
        "ukmo",
    }
)
if _unknown_centres:
    raise ValueError(f"Unknown centres: {_unknown_centres}")

# Regions (SREX abbreviations)
regions = [
    "EAF",
    # "ENA",
    # "MED",
    # "NEB",
    # "SAS",
    # "SEA",
    # "WNA",
    # "WSA",
]
_unknown_regions = sorted(
    set(regions) - set(regionmask.defined_regions.srex.abbrevs)
)
if _unknown_regions:
    raise ValueError(f"Unknown SREX regions: {_unknown_regions}")

# Plotting settings
plot_kwargs = {
    "total_precipitation": {"cmap": "BrBG"},
}

## Plot selected regions

In [None]:
# Draw the selected SREX regions, labelled by abbreviation, on a global
# Robinson map with the ocean feature added.
selected_srex = regionmask.defined_regions.srex[regions]
ax = selected_srex.plot(projection=ccrs.Robinson(), label="abbrev", add_ocean=True)
ax.set_global()

## Define requests

In [None]:
# Chunking passed to download_and_transform.
chunks = {"year": 1, "variable": 1}

# Request keys shared by the ERA5 and the seasonal-forecast requests:
# GRIB format, a global area, and every hindcast year as a string.
common_request = dict(
    format="grib",
    area=[89.5, -179.5, -89.5, 179.5],
    year=list(map(str, range(year_start_hindcast, year_stop_hindcast + 1))),
)

# ERA5 monthly means for all twelve calendar months.
collection_id_reanalysis = "reanalysis-era5-single-levels-monthly-means"
request_reanalysis = {
    **common_request,
    "product_type": "monthly_averaged_reanalysis",
    "month": ["%02d" % month for month in range(1, 13)],
    "time": "00:00",
}

# Seasonal monthly means: lead months 1-6, every start month, 1x1 degree grid.
collection_id_seasonal = "seasonal-monthly-single-levels"
request_seasonal = {
    **common_request,
    "product_type": "monthly_mean",
    "leadtime_month": [str(lead) for lead in range(1, 7)],
    "month": [str(month).zfill(2) for month in range(1, 13)],
    "grid": "1/1",
}

## Functions to cache

In [None]:
def compute_tercile_occupation(ds, region):
    """Classify each (year, month) regional mean into hindcast terciles.

    The anomaly of ``ds`` (relative to ``diagnostics.time_weighted_mean``) is
    reindexed from ``forecast_reference_time`` onto separate ``year`` and
    ``month`` dimensions, masked to one SREX region, averaged spatially, and
    compared with its 1/3 and 2/3 quantiles taken along ``year`` (i.e. one
    pair of thresholds per month and per any other remaining dimension).

    Parameters
    ----------
    ds : xarray.Dataset
        Data with a ``forecast_reference_time`` coordinate and spatial
        coordinates usable by ``regionmask`` (presumably a Dataset, since the
        caller reads ``.data_vars`` from the transformed result — confirm).
    region : str
        SREX region abbreviation, e.g. ``"EAF"``.

    Returns
    -------
    Same type as ``ds``
        -1 below the lower tercile, 1 above the upper tercile, 0 otherwise,
        and NaN where the regional mean itself is missing.
    """
    # Anomaly (presumably the mean is taken over time; tercile categories are
    # unaffected by subtracting a time-constant offset).
    ds = ds - diagnostics.time_weighted_mean(ds)

    # Reindex using year/month so quantiles can be taken across years.
    time = ds["forecast_reference_time"]
    ds = ds.assign_coords(
        year=(time.name, time.dt.year.data),
        month=(time.name, time.dt.month.data),
    )
    ds = ds.set_index({time.name: ("year", "month")}).unstack(time.name)

    # Mask region
    mask = regionmask.defined_regions.srex.mask(ds)
    index = regionmask.defined_regions.srex.map_keys(region)
    ds = ds.where((mask == index).compute(), drop=True)

    # Spatial mean
    # NOTE(review): weights=False presumably disables area weighting — confirm intended.
    ds = diagnostics.spatial_weighted_mean(ds, weights=False)

    # Tercile thresholds along "year"; the whole year axis must be one chunk.
    quantiles = ds.chunk(year=-1).quantile([1 / 3, 2 / 3], "year")
    # drop=True avoids carrying two conflicting scalar "quantile" coords.
    lower = quantiles.sel(quantile=1 / 3, drop=True)
    upper = quantiles.sel(quantile=2 / 3, drop=True)

    category = xr.zeros_like(ds)
    category = xr.where(ds < lower, -1, category)
    category = xr.where(ds > upper, 1, category)

    # Comparisons with NaN are False, which would otherwise label missing
    # values as the middle tercile; keep them missing instead.
    return category.where(ds.notnull())

## Download and transform ERA5

In [None]:
# Download ERA5 once per (region, variable), reduce it to tercile categories
# with compute_tercile_occupation, then concatenate all regions.
datasets = []
for region in regions:
    dataarrays = []
    for variable in variables:
        print(f"{region=} {variable=}")
        ds = download.download_and_transform(
            collection_id_reanalysis,
            request_reanalysis | {"variable": variable},
            chunks=chunks,
            transform_chunks=False,
            transform_func=compute_tercile_occupation,
            transform_func_kwargs=dict(region=region),
        )
        # The transformed download must contain exactly one data variable.
        [da] = ds.data_vars.values()
        dataarrays.append(da.rename(variable))
    # Merge this region's variables, label them with the region, and load
    # them into memory before the next region is processed.
    ds = xr.merge(dataarrays)
    datasets.append(ds.expand_dims({"region": [region]}).compute())
ds_reanalysis = xr.concat(datasets, dim="region")
del datasets

In [None]:
# Quick-look plot of the ERA5 tercile categories, one figure per selected
# variable, so the cell works whatever `variables` contains.
for var_name, var_da in ds_reanalysis.reset_coords(drop=True).data_vars.items():
    plt.figure()
    var_da.plot()
    plt.title(var_name)