# WP4: Climatology and Bias - Precipitation

## Import libraries

In [None]:
import warnings

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skill_metrics
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Silence every warning for the rest of the notebook.
# NOTE(review): this blanket filter also hides deprecation and numerical
# warnings; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Matplotlib style sheet applied to the figures produced below
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time: first and last year of the analysis period (both inclusive)
year_start = 1985
year_stop = 1987

# Choose CORDEX or cmip6: only these two collection IDs are handled when the
# simulation request is built; any other value raises a ValueError there.
collection_id = "projections-cordex-domains-single-levels"  # "projections-cmip6"

# Define region: longitude/latitude window used to regionalise the data and
# to choose/centre the map projection
lon_slice = slice(-4, 20)
lat_slice = slice(35, 50)
# Bounding box sent as "area" in the ERA5 request
# (presumably [north, west, south, east] in degrees; TODO confirm against the CDS API)
era5_area = [73, -45, 20, 65]

# Chunks for download: passed to the ERA5 download and used to split the
# CMIP6 request
chunks = {"year": 1}

## Define models

In [None]:
# Regional climate models looped over when the CORDEX collection is selected;
# each name is written into the request under the "rcm_model" key.
models_cordex = [
    "smhi_rca4",
    "clmcom_clm_cclm4_8_17",
    "clmcom_eth_cosmo_crclim",
    "mpi_csc_remo2009",
    "knmi_racmo22e",
    "dmi_hirham5",
    "uhoh_wrf361h",
    "cnrm_aladin63",
    "mohc_hadrem3_ga7_05",
]

# Global models looped over when the CMIP6 collection is selected;
# each name is written into the request under the "model" key.
models_cmip6 = [
    "access_cm2",
    "awi_esm_1_1_lr",
    "bcc_esm1",
    "cesm2_fv2",
    "cnrm_cm6_1",
    "fgoals_g3",
]

## Define ERA5 request

In [None]:
# (collection ID, request) pair for the ERA5 reference data; it is unpacked
# positionally into download.download_and_transform in the download cell.
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "mean_total_precipitation_rate",
        "area": era5_area,
        # Every year of the analysis period (year_stop included), as strings
        "year": [str(year) for year in range(year_start, year_stop + 1)],
        # All twelve months, zero-padded ("01" ... "12")
        "month": [f"{month:02d}" for month in range(1, 12 + 1)],
    },
)

## Define model requests

In [None]:
# Template shared by all CORDEX requests. The model name and the
# "start_year"/"end_year" of each simulation period are added later,
# when the simulation request is built and in the download loop.
request_cordex = {
    "format": "zip",
    "domain": "europe",
    "experiment": "historical",
    "horizontal_resolution": "0_11_degree_x_0_11_degree",
    "temporal_resolution": "monthly_mean",
    "variable": "mean_precipitation_flux",
    "gcm_model": "mpi_m_mpi_esm_lr",
    "ensemble_member": "r1i1p1",
}

# Template shared by all CMIP6 requests. The model name is added in the
# download loop; the request is split with download.split_request before use.
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "monthly",
    "experiment": "historical",
    "variable": "precipitation",
    # Every year of the analysis period (year_stop included), as strings
    "year": [str(year) for year in range(year_start, year_stop + 1)],
    # All twelve months, zero-padded ("01" ... "12")
    "month": [f"{month:02d}" for month in range(1, 12 + 1)],
}


def get_sim_years(
    year_start,
    year_stop,
    start_years=(1971, 1981, 1991, 2001),
    end_years=(1980, 1990, 2000, 2005),
):
    """Select the simulation periods that overlap the requested years.

    Parameters
    ----------
    year_start, year_stop : int
        First and last year (both inclusive) of the period of interest.
    start_years, end_years : sequence of int
        First and last year (both inclusive) of each available period,
        paired by position. If the sequences differ in length, the extra
        entries of the longer one are ignored.
        NOTE(review): the defaults presumably mirror the periods in which the
        CORDEX historical files are distributed; verify against the catalogue.

    Returns
    -------
    tuple of (list of int, list of int)
        Start years and end years of the periods sharing at least one year
        with ``[year_start, year_stop]``, in input order. Both lists are empty
        when nothing overlaps (including when ``year_start > year_stop``).
    """
    start_year = []
    end_year = []
    for start, end in zip(start_years, end_years):
        # Two inclusive integer ranges share a year exactly when the later
        # start does not exceed the earlier end.
        if max(start, year_start) <= min(end, year_stop):
            start_year.append(start)
            end_year.append(end)
    return start_year, end_year


if collection_id == "projections-cordex-domains-single-levels":
    # CORDEX: one request per simulation period overlapping the analysis years
    models, model_key = models_cordex, "rcm_model"
    weights = False  # Do not weight spatial statistics/errors
    period_starts, period_stops = get_sim_years(year_start, year_stop)
    request_sim = (
        collection_id,
        [
            {**request_cordex, "start_year": first_year, "end_year": last_year}
            for first_year, last_year in zip(period_starts, period_stops)
        ],
    )
elif collection_id == "projections-cmip6":
    # CMIP6: the request template is split according to `chunks`
    models, model_key = models_cmip6, "model"
    weights = True  # Weight spatial statistics/errors
    request_sim = (
        "projections-cmip6",
        download.split_request(request_cmip6, chunks=chunks),
    )
else:
    raise ValueError(f"{collection_id=} is not supported for this notebook.")

## Define transform function

In [None]:
def resample_and_regrid_and_rescale(
    ds, model, year_start, year_stop, grid_out=None, **kwargs
):
    """Reduce a precipitation dataset to a time mean in mm/day.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding the variable "mtpr" (when ``model == "ERA5"``) or
        "pr" (any other model) and a "time" coordinate.
    model : str
        Model label; selects the variable name and becomes the value of the
        new "model" dimension of the result.
    year_start, year_stop : int
        First and last year (both inclusive) kept before averaging.
    grid_out : optional
        Target grid handed to ``diagnostics.regrid``; regridding is skipped
        when it is None.
    **kwargs
        Forwarded to ``diagnostics.regrid``.

    Returns
    -------
    xr.Dataset
        Dataset with a single variable "precipitation" (units attribute set to
        "mm/day") and a length-one "model" dimension.
    """
    varname = "mtpr" if model == "ERA5" else "pr"
    ds = ds[[varname]].sel(
        time=ds["time"].dt.year.isin(range(year_start, year_stop + 1))
    )

    ds = diagnostics.time_weighted_mean(ds)
    if grid_out is not None:
        ds = diagnostics.regrid(ds, grid_out, **kwargs)

    # Change units: scale by the number of seconds in a day
    # (assumes the input is a per-second rate in kg m-2 s-1; TODO confirm)
    with xr.set_options(keep_attrs=True):
        ds[varname] = ds[varname] * 3600 * 24
    ds[varname].attrs["units"] = "mm/day"

    return ds.rename({varname: "precipitation"}).expand_dims(model=[model])

## Download data

In [None]:
print("Downloading and processing ERA5")
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=resample_and_regrid_and_rescale,
    transform_func_kwargs={
        "model": "ERA5",
        "year_start": year_start,
        "year_stop": year_stop,
    },
    transform_chunks=False,
)

# Every simulation is regridded onto the ERA5 grid
grid_out = ds_era[["longitude", "latitude"]]
sim_collection, sim_requests = request_sim

datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    # Build per-model copies of the requests instead of writing the model name
    # into the shared `request_sim` dictionaries, so `request_sim` stays a
    # model-agnostic template and the cell can be re-run safely.
    requests_model = [{**request, model_key: model} for request in sim_requests]
    ds_model = download.download_and_transform(
        sim_collection,
        requests_model,
        transform_func=resample_and_regrid_and_rescale,
        transform_func_kwargs={
            "grid_out": grid_out,
            "model": model,
            "year_start": year_start,
            "year_stop": year_stop,
            "method": "bilinear",
        },
        transform_chunks=False,
    )
    datasets.append(ds_model)
ds_sim = xr.concat(datasets, "model")

## Create a single dataset and compute bias

In [None]:
# Mean over all simulations, labelled "ensemble" along the model dimension
ds_ensemble = ds_sim.mean("model").expand_dims(model=["ensemble"])

# Simulations, ensemble mean and ERA5 side by side, restricted to the region
ds = utils.regionalise(
    xr.concat([ds_sim, ds_ensemble, ds_era], "model"),
    lon_slice=lon_slice,
    lat_slice=lat_slice,
)

# Bias of every simulation (and of the ensemble mean) with respect to ERA5
with xr.set_options(keep_attrs=True):
    ds_bias = ds.drop_sel(model="ERA5") - ds.sel(model="ERA5")
for da in ds_bias.data_vars.values():
    da.attrs["long_name"] += " Bias"

## Plot Maps

In [None]:
# Choose projection: Robinson when the region spans the whole globe,
# PlateCarree otherwise; centred on the middle longitude of the region
lon_span = lon_slice.stop - lon_slice.start
lat_span = lat_slice.stop - lat_slice.start
if lon_span == 360 and lat_span == 180:
    Projection = ccrs.Robinson
else:
    Projection = ccrs.PlateCarree
projection = Projection(central_longitude=(lon_slice.stop + lon_slice.start) / 2)

plot_kwargs = {
    "vmin": 0,
    "robust": True,
    "cmap": "Blues",
    "projection": projection,
}
# One map for the ensemble mean and one for the ERA5 reference
for model in ("ensemble", "ERA5"):
    plot.projected_map(ds["precipitation"].sel(model=[model]), **plot_kwargs)
    plt.show()

## Plot Bias

In [None]:
plot_kwargs = {
    "center": True,
    "robust": True,
    "cmap": "RdBu_r",
    "projection": projection,
}

# Bias of the ensemble mean
plot.projected_map(ds_bias["precipitation"].sel(model=["ensemble"]), **plot_kwargs)
plt.show()

# Bias of each individual model, three panels per row
facet = plot.projected_map(
    ds_bias["precipitation"].drop_sel(model="ensemble"),
    col="model",
    col_wrap=3,
    **plot_kwargs,
)

# Give every panel the extent spanned by the first and last grid coordinates
lon_bounds = ds_bias["longitude"][[0, -1]].values.tolist()
lat_bounds = ds_bias["latitude"][[0, -1]].values.tolist()
for ax in facet.axs.flat:
    ax.set_extent(lon_bounds + lat_bounds)

## Plot KDE and Statistics of Bias

TODO: The KDE plot is unweighted, whereas the statistics in the table can be spatially weighted.

In [None]:
# Create dataframes: one column of flattened bias values per model,
# with the ensemble mean kept apart so it can be drawn in its own style
da = ds_bias["precipitation"]
da_dict = {k: v.values.flatten() for k, v in da.groupby("model")}
df_ensemble = pd.DataFrame({"ensemble": da_dict.pop("ensemble")})
df_models = pd.DataFrame(da_dict)
# Honour the collection-specific `weights` choice, as the statistics/errors
# cell does. The KDE curves themselves remain unweighted.
df_stats = diagnostics.spatial_weighted_statistics(da, weights=weights).to_pandas()

# Plot: x-axis limited to mean +/- 3 standard deviations of the ensemble bias
plot_kwargs = {
    "xlim": (
        df_stats["ensemble"]["mean"] - 3 * df_stats["ensemble"]["std"],
        df_stats["ensemble"]["mean"] + 3 * df_stats["ensemble"]["std"],
    ),
    "grid": True,
}
ax = df_models.plot.kde(**plot_kwargs)
ax = df_ensemble.plot.kde(color="k", ls="--", ax=ax, **plot_kwargs)
ax.set_xlabel(f"{da.attrs['long_name']} [{da.attrs['units']}]")

# Add stats: one table row per model, one column per statistic
df_stats_by_model = df_stats.T
table = ax.table(
    cellText=df_stats_by_model.round(5).values.tolist(),
    colLabels=df_stats_by_model.columns.values.tolist(),
    rowLabels=df_stats_by_model.index.values.tolist(),
    loc="top",
)

## Compute and Show Statistics

In [None]:
# Split the reference (ERA5) from everything that is compared against it
ds_reference = ds.sel(model="ERA5")
ds_compared = ds.drop_sel(model="ERA5")

# Spatial statistics of every entry, and errors with respect to ERA5
ds_stats = diagnostics.spatial_weighted_statistics(ds, weights=weights)
ds_error = diagnostics.spatial_weighted_errors(
    ds_compared, ds_reference, weights=weights
)

# One table: a row per diagnostic, a column per model
df_stats_and_error = xr.merge([ds_stats, ds_error])["precipitation"].to_pandas()
df_stats_and_error

## Taylor Diagram

In [None]:
# Draw the Taylor diagram from the "std", "crmse" and "corr" rows of the
# statistics table, one marker label per table column.
# NOTE(review): skill_metrics presumably treats the FIRST element of each array
# as the reference; confirm that "ERA5" is the first column of
# `df_stats_and_error` (the errors were computed without ERA5, so its
# "crmse"/"corr" entries are expected to be missing after the merge).
skill_metrics.taylor_diagram(
    df_stats_and_error.loc["std"].values,
    df_stats_and_error.loc["crmse"].values,
    df_stats_and_error.loc["corr"].values,
    alpha=0.0,
    axismax=4,
    colCOR="k",
    colOBS="k",
    colRMS="m",
    colSTD="b",
    markerColor="r",
    markerLabel=list(df_stats_and_error.columns),
    markerLegend="on",
    markerSize=10,
    markerobs="o",
    styleCOR="--",
    styleOBS="--",
    styleRMS=":",
    styleSTD="-.",
    # Ticks at 0, 1, 2, 3, 4, matching axismax above
    tickRMS=np.linspace(0, 4, 5),
    tickSTD=np.linspace(0, 4, 5),
    titleCOR="on",
    titleOBS="ERA5",
    titleRMS="on",
    titleRMSDangle=40.0,
    titleSTD="on",
    widthCOR=0.5,
    widthOBS=2,
    widthRMS=2,
    widthSTD=1.0,
)