# Uncertainty in seasonal forecasts for a single model system

## Import packages

In [None]:
import numpy as np
import plotly.figure_factory as ff
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, utils
from dateutil.relativedelta import relativedelta

## Define parameters

In [None]:
# Model
# Originating centre and system identifier of the seasonal forecasting system.
centre = "ecmwf"
system = "51"

# Time
# Single forecast year, plus the first and last hindcast years (the hindcast
# range is built inclusively of year_stop_hindcast further down).
year_forecast = 2023
year_start_hindcast = 1993
year_stop_hindcast = 2016
# Month sent in the request, zero-padded to two digits there
# (presumably the forecast start month -- confirm against the dataset docs).
month = 6

# Region
# region_name is only used as a label in the plot title.
region_name = "Southern Norway"
# NOTE(review): latitude bounds are given high-to-low; presumably this matches
# a north-to-south latitude ordering in the downloaded data -- confirm.
lat_slice = slice(64, 58)
lon_slice = slice(4, 14)

# Download parameters
# NOTE(review): presumably splits the request into one chunk per year -- confirm
# against download.download_and_transform.
chunks = {"year": 1}
n_jobs = 1  # Number of concurrent requests for parallel download

## Define request

In [None]:
# Collection to query for the monthly seasonal forecast data
collection_id = "seasonal-monthly-single-levels"

# Base request shared by hindcast and forecast downloads; the "year" entry is
# added separately for each of them at download time.
request = {
    "format": "grib",
    "originating_centre": centre,
    "system": system,
    "variable": "2m_temperature",
    "product_type": "monthly_mean",
    # Lead times 1 to 6, as strings
    "leadtime_month": [str(lead) for lead in range(1, 7)],
    "area": [89.5, -179.5, -89.5, 179.5],
    "grid": "1/1",
    "month": f"{month:02d}",
}

## Functions to cache

In [None]:
def regionalised_mean(ds, lon_slice, lat_slice, weights):
    """Reduce a dataset to the spatial mean of "t2m" over a region, in °C.

    Parameters
    ----------
    ds
        Dataset containing a "t2m" variable.
        NOTE(review): assumed to be in kelvin, given the 273.15 offset -- confirm.
    lon_slice, lat_slice
        Longitude and latitude slices forwarded to ``utils.regionalise``.
    weights
        Forwarded to ``diagnostics.spatial_weighted_mean``.

    Returns
    -------
    The spatially averaged dataset, with "t2m" shifted by -273.15 and its
    "units" attribute set to "°C".
    """
    regional = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    averaged = diagnostics.spatial_weighted_mean(regional, weights=weights)

    # Keep the variable attributes (e.g. long_name) through the arithmetic;
    # only the units are overwritten afterwards.
    with xr.set_options(keep_attrs=True):
        averaged["t2m"] -= 273.15
    averaged["t2m"].attrs["units"] = "°C"
    return averaged

## Download and transform

In [None]:
# Years to request for each dataset; the hindcast range includes its last year.
years_by_model = {
    "hindcast": range(year_start_hindcast, year_stop_hindcast + 1),
    "forecast": [year_forecast],
}

# Some centres need "indexing_time" rather than "time" as the second time dimension.
reference_time_dim = "indexing_time" if centre in ["ukmo", "jma", "ncep"] else "time"

# Download each dataset and reduce it to the regional mean via regionalised_mean.
datasets = {}
for model, years in years_by_model.items():
    ds = download.download_and_transform(
        collection_id,
        request | {"year": [str(year) for year in years]},
        chunks=chunks,
        n_jobs=n_jobs,
        transform_func=regionalised_mean,
        transform_func_kwargs={
            "lon_slice": lon_slice,
            "lat_slice": lat_slice,
            "weights": False,
        },
        backend_kwargs={"time_dims": ("forecastMonth", reference_time_dim)},
        cached_open_mfdataset_kwargs={
            "combine": "nested",
            "concat_dim": "forecast_reference_time",
        },
    )
    datasets[model] = ds

## Density plot

In [None]:
def get_limits(data, xfactor, yfactor):
    """Compute padded axis limits covering every trace in ``data``.

    Parameters
    ----------
    data
        Iterable container of objects exposing ``x`` and ``y`` sequences.
    xfactor
        Fraction of the total x extent added as margin on each side.
    yfactor
        Multiplier applied to the largest y value to get the upper y limit.

    Returns
    -------
    tuple of two lists
        ``([xmin - margin, xmax + margin], [0, ymax * yfactor])``.
    """
    y_peak = max(max(trace.y) for trace in data)
    x_low = min(min(trace.x) for trace in data)
    x_high = max(max(trace.x) for trace in data)

    # Widen the x range symmetrically by a fraction of its extent
    margin = abs(x_high - x_low) * xfactor
    return [x_low - margin, x_high + margin], [0, y_peak * yfactor]


# Density plot for each lead time.
# The curve colours and legend labels do not depend on the lead time, so they
# are built once rather than on every iteration.
colors = [(26, 150, 65), (100, 50, 150)]
labels = [centre + " climatology", centre + " forecast"]

for leadtime_month, ds_forecast in datasets["forecast"].groupby("leadtime_month"):
    ds_hindcast = datasets["hindcast"].sel(leadtime_month=leadtime_month)

    # Flatten each dataset to one 1-D sample: index 0 is the hindcast
    # (climatology), index 1 is the forecast.
    values = [
        ds_hindcast["t2m"].values.flatten(),
        ds_forecast["t2m"].values.flatten(),
    ]
    fig = ff.create_distplot(
        values,
        labels,
        show_hist=False,
        show_rug=True,
        colors=[f"rgb{color}" for color in colors],
        curve_type="kde",
    )
    # zip() stops after the two colours, so only the first two traces get a fill.
    # NOTE(review): presumably these are the two KDE curves because
    # show_hist=False -- confirm against plotly's trace ordering.
    for color, data in zip(colors, fig.data):
        # Fill area under distline
        fig.add_scatter(
            x=data.x,
            y=data.y,
            fill="tozeroy",
            mode="none",
            fillcolor=f"rgba{color + (0.4,)}",
            showlegend=False,
        )

    xlim, ylim = get_limits(fig.data[:2], xfactor=0.03, yfactor=1.4)

    # Tercile boundaries come from the hindcast sample; the forecast sample is
    # then classified against them.
    quantiles = np.quantile(values[0], [1 / 3, 2 / 3]).tolist()
    scatter_dicts = {
        "lower tercile": {
            "color": (0, 180, 250, 0.1),
            "text": "COLD",
            "mask": values[1] <= quantiles[0],
            "xlim": [xlim[0], quantiles[0]],
        },
        "middle tercile": {
            "color": (230, 230, 0, 0.1),
            "text": "NEAR AVERAGE",
            "mask": (values[1] > quantiles[0]) & (values[1] <= quantiles[1]),
            "xlim": quantiles,
        },
        "upper tercile": {
            "color": (250, 50, 0, 0.1),
            "text": "WARM",
            "mask": values[1] > quantiles[1],
            "xlim": [quantiles[1], xlim[1]],
        },
    }
    for name, scatter_dict in scatter_dicts.items():
        # Add background color and text
        fig.add_scatter(
            x=scatter_dict["xlim"],
            y=[ylim[1]] * 2,
            fill="tozeroy",
            mode="none",
            fillcolor=f"rgba{scatter_dict['color']}",
            name=name,
        )
        # Share of forecast values falling inside this tercile
        percentage = 100 * scatter_dict["mask"].sum() / values[1].size
        # Darker, more opaque variant of the band colour for the label
        text_color = tuple(c // 2 for c in scatter_dict["color"][:-1]) + (0.4,)
        fig.add_scatter(
            x=[sum(scatter_dict["xlim"]) / 2],
            y=[ylim[1] * 0.98],
            mode="text",
            name="",
            text=f"{scatter_dict['text']}<br>{round(percentage)}%",
            textfont=dict(size=18, color=f"rgba{text_color}"),
            textposition="bottom center",
            showlegend=False,
        )

    # Title and labels
    # The tuple unpacking requires exactly one forecast reference time.
    (forecast_reference_time,) = ds_forecast["forecast_reference_time"].dt.date.values
    # Lead month 1 is the start month itself, hence the "- 1".
    valid_time = forecast_reference_time + relativedelta(months=leadtime_month - 1)
    title = (
        f"Density plot of {ds_forecast['t2m'].attrs['long_name']}"
        f" over {region_name} for {valid_time.strftime('%b %Y')},"
        f"<br>from {centre} with start time {forecast_reference_time}."
        f" Hindcast period from {year_start_hindcast} to {year_stop_hindcast}."
    )
    fig.update_layout(
        title=dict(text=title, font={"size": 22}),
        yaxis_range=ylim,
        xaxis_range=xlim,
        font_size=17,
        autosize=False,
        width=900,
        height=500,
    )
    fig.update_xaxes(
        title_text=(
            f"Mean {ds_forecast['t2m'].attrs['long_name']}"
            f" ({ds_forecast['t2m'].attrs['units']})"
        )
    )
    fig.show()