# Hit-rate of seasonal forecasts

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import regionmask
import scipy.stats
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to the figures created below
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time: first and last hindcast year (both included when the request years are built)
year_start_hindcast = 1993
year_stop_hindcast = 2016

# Variables to process; uncomment entries to download and analyse more of them
variables = [
    "2m_temperature",
    # "total_precipitation",
    # "surface_solar_radiation_downwards",
    # "10m_wind_speed",
    # "2m_dewpoint_temperature",
]

# Centres with missing variables: maps a centre name to the variables that are
# skipped for that centre in the seasonal download loop
missing_variables = {"eccc": ["2m_dewpoint_temperature"]}

# Centres: the key is sent as "originating_centre" and the value dictionary is
# merged into the seasonal request; uncomment entries to include more centres
centres = {
    "cmcc": {"system": "35"},
    # "dwd": {"system": "21"},
    # "eccc": {"system": "3"},
    # "ecmwf": {"system": "51"},
    # "jma": {"system": "3"},
    # "meteo_france": {"system": "8"},
    # "ncep": {"system": "2"},
    # "ukmo": {"system": "601"},
}
# Every selected centre must declare a "system" entry
assert all("system" in v for v in centres.values())
# Only the centre names listed below are accepted
assert set(centres) <= {
    "cmcc",
    "dwd",
    "eccc",
    "ecmwf",
    "jma",
    "meteo_france",
    "ncep",
    "ukmo",
}

# Regions: abbreviations of the regionmask SREX regions to analyse
regions = [
    "EAF",
    # "ENA",
    # "MED",
    # "NEB",
    # "SAS",
    # "SEA",
    # "WNA",
    # "WSA",
]
# Every selected abbreviation must be a known SREX region
assert set(regions) <= set(regionmask.defined_regions.srex.abbrevs)

# Plotting settings: extra plot keyword arguments keyed by variable name
# NOTE(review): plot_kwargs is not referenced by any cell visible in this
# notebook - confirm whether it is still needed or should be wired into the plots
plot_kwargs = {
    "total_precipitation": {"cmap": "BrBG"},
}

## Plot selected regions

In [None]:
# Plot the selected SREX regions, labelled by abbreviation, on a Robinson map
selected_regions = regionmask.defined_regions.srex[regions]
ax = selected_regions.plot(
    add_ocean=True,
    label="abbrev",
    projection=ccrs.Robinson(),
)
# Extend the map to the whole globe
ax.set_global()

## Define requests

In [None]:
# Forwarded as `chunks` to download.download_and_transform
# (presumably one request per year and variable - confirm against the library)
chunks = dict(year=1, variable=1)

# Keys shared by the reanalysis and the seasonal-forecast requests
common_request = {
    "format": "grib",
    "area": [89.5, -179.5, -89.5, 179.5],
    "year": list(map(str, range(year_start_hindcast, year_stop_hindcast + 1))),
}

# ERA5 monthly means: all twelve calendar months
collection_id_reanalysis = "reanalysis-era5-single-levels-monthly-means"
request_reanalysis = {
    **common_request,
    "product_type": "monthly_averaged_reanalysis",
    "month": [str(month).zfill(2) for month in range(1, 13)],
    "time": "00:00",
}

# Seasonal forecasts: all twelve start months, lead months 1 to 6
collection_id_seasonal = "seasonal-monthly-single-levels"
request_seasonal = {
    **common_request,
    "product_type": "monthly_mean",
    "leadtime_month": [str(lead) for lead in range(1, 7)],
    "month": [str(month).zfill(2) for month in range(1, 13)],
    "grid": "1/1",
}

## Functions to cache

In [None]:
def mode(*args, axis=None, **kwargs):
    """Return only the modal values computed by ``scipy.stats.mode``.

    All positional and keyword arguments, including ``axis``, are forwarded
    unchanged; the ``count`` part of the scipy result is discarded so the
    function can be used as a plain reduction.
    """
    result = scipy.stats.mode(*args, axis=axis, **kwargs)
    return result.mode


def reindex_seasonal_forecast(ds):
    """Re-express a seasonal forecast on ``valid_time`` x ``starting_month``.

    The input is expected to have the dimensions ``forecast_reference_time``
    and ``leadtime_month``. Each (start date, lead month) pair is mapped to
    the date it is valid for, and the result is unstacked so that the two
    original dimensions are replaced by ``valid_time`` and ``starting_month``.
    """
    # Flatten (start date, lead month) into a single "time" dimension
    # without building a MultiIndex
    stacked = ds.rename(forecast_reference_time="starting_time").stack(
        time=("starting_time", "leadtime_month"),
        create_index=False,
    )

    # Index the flat dimension by the forecast start date
    stacked = stacked.set_index(time="starting_time")
    start_index = stacked.indexes["time"]
    lead = stacked["leadtime_month"]

    # Valid date = start date moved forward by (lead - 1) month starts,
    # i.e. lead month 1 is the start month itself
    valid_time = start_index
    for lead_value in set(lead.values):
        valid_time = valid_time.where(
            lead != lead_value,
            start_index.shift(lead_value - 1, "MS"),
        )

    # Swap the flat dimension for (valid_time, starting_month) and unstack it
    stacked = stacked.assign_coords(
        valid_time=("time", valid_time),
        starting_month=("time", stacked["time"].dt.month.data),
    )
    return stacked.set_index(time=("valid_time", "starting_month")).unstack("time")


def compute_tercile_occupation(ds, region):
    """Classify regional-mean anomalies into terciles per valid year and month.

    Parameters
    ----------
    ds : xr.Dataset
        Gridded data with a ``forecast_reference_time`` dimension and, for
        seasonal forecasts, a ``leadtime_month`` dimension.
    region : str
        SREX region key understood by ``regionmask``.

    Returns
    -------
    xr.Dataset
        -1 below the lower tercile, 0 between the terciles (bounds included)
        and 1 above the upper tercile, with terciles taken along
        ``valid_year``. Entries matching none of these comparisons keep the
        initial fill of ``xr.full_like(ds, None)``. When a ``realization``
        dimension is present it is reduced to its most frequent category.
        Non-index coordinates are dropped.
    """
    srex = regionmask.defined_regions.srex

    # Keep only the grid points that belong to the requested region
    region_mask = srex.mask(ds)
    region_index = srex.map_keys(region)
    ds = ds.where((region_mask == region_index).compute(), drop=True)

    # Without a lead-time dimension the reference time already is the valid
    # time; forecasts are re-indexed by valid time instead
    if "leadtime_month" not in ds.dims:
        ds = ds.rename(forecast_reference_time="valid_time")
    else:
        ds = reindex_seasonal_forecast(ds)

    # Anomaly with respect to the time mean (and the ensemble mean, if any)
    climatology = diagnostics.time_weighted_mean(ds, time_name="valid_time")
    ensemble_dims = set(climatology.dims) & {"realization"}
    climatology = climatology.mean(ensemble_dims)
    ds -= climatology

    # Split valid_time into separate valid_year and valid_month dimensions
    valid_time = ds["valid_time"]
    ds = ds.assign_coords(
        valid_year=("valid_time", valid_time.dt.year.data),
        valid_month=("valid_time", valid_time.dt.month.data),
    )
    ds = ds.set_index(valid_time=("valid_year", "valid_month")).unstack("valid_time")

    # Spatial mean over the region (called with weights=False)
    ds = diagnostics.spatial_weighted_mean(ds, weights=False)

    # Terciles along valid_year; the dimension must sit in a single chunk
    lower, upper = 1 / 3, 2 / 3
    quantiles = ds.chunk(valid_year=-1).quantile([lower, upper], "valid_year")
    low = quantiles.sel(quantile=lower)
    high = quantiles.sel(quantile=upper)

    # Fill the category mask one tercile at a time
    mask = xr.full_like(ds, None)
    categories = (
        (-1, ds < low),
        (0, (ds >= low) & (ds <= high)),
        (1, ds > high),
    )
    for category, condition in categories:
        mask = xr.where(condition, category, mask)

    # Collapse the ensemble to its most frequent category
    if "realization" in mask.dims:
        mask = mask.reduce(mode, dim="realization")

    return mask.reset_coords(drop=True)

## Download and transform ERA5

In [None]:
# Tercile categories of the ERA5 reanalysis for every region and variable
datasets = []
for region in regions:
    dataarrays = []
    for variable in variables:
        print(f"{region=} {variable=}")
        request = request_reanalysis | {"variable": variable}
        ds = download.download_and_transform(
            collection_id_reanalysis,
            request,
            chunks=chunks,
            transform_chunks=False,
            transform_func=compute_tercile_occupation,
            transform_func_kwargs=dict(region=region),
        )
        # The unpacking fails unless exactly one data variable came back
        (da,) = ds.data_vars.values()
        dataarrays.append(da.rename(variable))
    # One dataset per region, loaded into memory and tagged with its region
    ds = xr.merge(dataarrays)
    datasets.append(ds.expand_dims({"region": [region]}).compute())
ds_reanalysis = xr.concat(datasets, dim="region")
del datasets

## Download and transform seasonal forecast

In [None]:
# Get the seasonal forecast data: tercile categories for every centre, region
# and variable
# Centres for which "indexing_time" is used as the start-date dimension
# instead of "time"
indexing_time_centres = {"ukmo", "jma", "ncep"}
datasets = []
for centre, request_kwargs in centres.items():
    # The choice only depends on the centre, so make it once per centre
    start_time_dim = "indexing_time" if centre in indexing_time_centres else "time"
    for region in regions:
        dataarrays = []
        for variable in variables:
            print(f"{centre=} {region=} {variable=}")
            if variable in missing_variables.get(centre, []):
                print("SKIP")
                continue

            # No temporary directory is needed here: the call never used one
            ds = download.download_and_transform(
                collection_id_seasonal,
                request_seasonal
                | {"originating_centre": centre, "variable": variable}
                | request_kwargs,
                chunks=chunks,
                transform_chunks=False,
                transform_func=compute_tercile_occupation,
                transform_func_kwargs={"region": region},
                backend_kwargs={"time_dims": ("forecastMonth", start_time_dim)},
            )
            # The unpacking fails unless exactly one data variable came back
            (da,) = ds.data_vars.values()
            dataarrays.append(da.rename(variable))
        # One dataset per (centre, region), loaded into memory and tagged
        ds = xr.merge(dataarrays)
        datasets.append(ds.expand_dims(centre=[centre], region=[region]).compute())
ds_seasonal = xr.merge(datasets)
del datasets

## Compute and plot hit-rate

In [None]:
# Count the years in which forecast and reanalysis share the same tercile
hit_rate = (ds_seasonal == ds_reanalysis).sum(dim="valid_year")
# Blank out entries for which the forecast has no value in any year
hit_rate = hit_rate.where(ds_seasonal.notnull().any(dim="valid_year"))
for var, da in hit_rate.data_vars.items():
    # String month labels for the x axis
    da["valid_month"] = da["valid_month"].astype(str)
    da.attrs["long_name"] = "Hit-Rate"
    facet = da.plot(x="valid_month", row="region", col="centre")
    # The variable name, with spaces instead of underscores, is the title
    facet.fig.suptitle(var.replace("_", " "), y=1.01)
    plt.show()

## Debug Plots

In [None]:
def plot_mask(da, **kwargs):
    """Plot a tercile-category mask with a three-colour, labelled colorbar.

    Parameters
    ----------
    da : xr.DataArray
        Mask with a ``valid_month`` coordinate. It is not modified: the plot
        is made from a shallow copy with string month labels.
    **kwargs
        Forwarded to ``da.plot``. ``cmap`` (default ``"RdYlBu_r"``) is
        resampled to three colours. ``cbar_kwargs`` is accepted and merged
        with the fixed ticks ``[-1, 0, 1]``, which always take precedence
        because the tick labels depend on them.

    Returns
    -------
    The object returned by ``da.plot`` (a FacetGrid for faceted plots).
    """
    cmap = kwargs.pop("cmap", "RdYlBu_r")
    kwargs["cmap"] = plt.get_cmap(cmap, 3)
    # Passing cbar_kwargs used to raise a duplicate-keyword TypeError
    cbar_kwargs = {**(kwargs.pop("cbar_kwargs", None) or {}), "ticks": [-1, 0, 1]}

    # assign_coords works on a copy, so the caller's object keeps its
    # original valid_month coordinate
    da = da.assign_coords(valid_month=da["valid_month"].astype(str))
    plot_obj = da.plot(cbar_kwargs=cbar_kwargs, **kwargs)

    # Faceted plots expose the colorbar as `.cbar`, single plots as `.colorbar`
    if isinstance(plot_obj, xr.plot.facetgrid.FacetGrid):
        cbar = plot_obj.cbar
    else:
        cbar = plot_obj.colorbar
    cbar.ax.set_yticklabels(["Low", "Medium", "High"])
    return plot_obj

In [None]:
# Plot the reanalysis mask of every processed variable instead of a
# hard-coded name, which fails as soon as `variables` is edited
for da in ds_reanalysis.data_vars.values():
    plot_mask(da)
    plt.show()

In [None]:
# Plot the forecast mask of every processed variable, one panel per valid
# year, instead of a hard-coded name that fails once `variables` is edited
for da in ds_seasonal.data_vars.values():
    plot_mask(da, col="valid_year", col_wrap=5)
    plt.show()