# WP4: Climatology and Bias - Precipitation

## Import libraries

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import skill_metrics
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Time period to analyse (year_stop is included: requests use
# range(year_start, year_stop + 1))
year_start = 1985
year_stop = 1987

# Choose CORDEX or CMIP6. Supported values:
# "projections-cordex-domains-single-levels" or "projections-cmip6"
collection_id = "projections-cordex-domains-single-levels"  # "projections-cmip6"

# Define region for analysis (longitude/latitude bounds passed to utils.regionalise)
lon_slice = slice(-4, 20)
lat_slice = slice(35, 50)

# Define region for request (value of the "domain" key in the CORDEX request)
cordex_domain = "europe"

# Chunks for download (passed to download.split_request and
# download.download_and_transform)
chunks = {"year": 1}
assert "month" not in chunks, "Do not use chunks smaller than 1y"

## Define models

In [None]:
# CORDEX models: each name is inserted in the request under the "rcm_model" key
models_cordex = [
    "clmcom_clm_cclm4_8_17",
    "clmcom_eth_cosmo_crclim",
    "cnrm_aladin63",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mohc_hadrem3_ga7_05",
    "mpi_csc_remo2009",
    "smhi_rca4",
    "uhoh_wrf361h",
]

# CMIP6 models: each name is inserted in the request under the "model" key
models_cmip6 = [
    "access_cm2",
    "awi_esm_1_1_lr",
    "bcc_esm1",
    "cesm2_fv2",
    "cnrm_cm6_1",
    "fgoals_g3",
]

## Define ERA5 request

In [None]:
# ERA5 request stored as a (collection ID, request parameters) pair; the pair is
# unpacked into download.download_and_transform. ERA5 is the reference dataset
# against which the model bias is computed.
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "mean_total_precipitation_rate",
        # All years of the analysis period (year_stop included) and all 12 months
        "year": [str(year) for year in range(year_start, year_stop + 1)],
        "month": [f"{month:02d}" for month in range(1, 12 + 1)],
    },
)

## Define model requests

In [None]:
# Base CORDEX request. The "start_year"/"end_year" keys and the model key
# ("rcm_model") are added later, when the per-model requests are built.
request_cordex = {
    "format": "zip",
    "domain": cordex_domain,
    "experiment": "historical",
    "horizontal_resolution": "0_11_degree_x_0_11_degree",
    "temporal_resolution": "monthly_mean",
    "variable": "mean_precipitation_flux",
    "gcm_model": "mpi_m_mpi_esm_lr",
    "ensemble_member": "r1i1p1",
}

# Base CMIP6 request. It is split with download.split_request and the model key
# ("model") is added later, when the per-model requests are built.
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "monthly",
    "experiment": "historical",
    "variable": "precipitation",
    # All years of the analysis period (year_stop included) and all 12 months
    "year": [str(year) for year in range(year_start, year_stop + 1)],
    "month": [f"{month:02d}" for month in range(1, 12 + 1)],
}


def get_cordex_years(
    year_start,
    year_stop,
    start_years=[1971, 1981, 1991, 2001],
    end_years=[1980, 1990, 2000, 2005],
):
    """Select the (start, end) year blocks that overlap the requested period.

    Parameters
    ----------
    year_start, year_stop : int
        First and last year of the requested period (both inclusive).
    start_years, end_years : list of int
        Paired first/last years of the available blocks (both inclusive).

    Returns
    -------
    tuple of (list, list)
        Start years and end years of every block sharing at least one year
        with the requested period, in the order they were given.
    """
    requested = set(range(year_start, year_stop + 1))
    # A block is kept as soon as it has one year in common with the request
    overlapping = [
        (first, last)
        for first, last in zip(start_years, end_years)
        if not requested.isdisjoint(range(first, last + 1))
    ]
    selected_starts = [first for first, _ in overlapping]
    selected_ends = [last for _, last in overlapping]
    return selected_starts, selected_ends


# Collection-specific settings:
# - weights: passed to the spatial statistics/errors diagnostics and used to
#   decide whether the KDE is latitude-weighted
# - periodic: forwarded to diagnostics.regrid through the transform function
# - models / model_key: model names and the request key they are stored under
# - request_sim: (collection ID, list of request dictionaries)
if collection_id == "projections-cordex-domains-single-levels":
    weights = False  # Do not weight spatial statistics/errors
    periodic = False
    models = models_cordex
    model_key = "rcm_model"
    request_sim = (
        collection_id,
        [
            # One request per (start_year, end_year) block overlapping the
            # analysis period
            {
                **request_cordex,
                "start_year": start_year,
                "end_year": end_year,
            }
            for start_year, end_year in zip(*get_cordex_years(year_start, year_stop))
        ],
    )
elif collection_id == "projections-cmip6":
    weights = True  # Weight spatial statistics/errors
    periodic = True
    models = models_cmip6
    model_key = "model"
    request_sim = (
        "projections-cmip6",
        download.split_request(request_cmip6, chunks=chunks),
    )
else:
    # Fail early on unsupported collections
    raise ValueError(f"{collection_id=} is not supported for this notebook.")

## Define transform function

In [None]:
def compute_annual_regridded_precipitation(
    ds, model, year_start, year_stop, grid_out=None, **kwargs
):
    """Compute annual precipitation for one model, optionally regridded.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a "time" coordinate and the precipitation variable
        ("mtpr" when ``model == "ERA5"``, "pr" otherwise).
    model : str
        Model label; it selects the variable name and becomes the value of
        the new "model" dimension.
    year_start, year_stop : int
        First and last year to keep (both inclusive).
    grid_out : xr.Dataset, optional
        Target grid passed to ``diagnostics.regrid``. No regridding is done
        when it is None.
    **kwargs
        Extra keyword arguments forwarded to ``diagnostics.regrid``.

    Returns
    -------
    xr.Dataset
        Dataset with the variable renamed to "precipitation", units set to
        "mm/day", and a leading "model" dimension of size one.
    """
    varname = "mtpr" if model == "ERA5" else "pr"

    # Keep only the precipitation variable and the requested years
    mask = ds["time"].dt.year.isin(range(year_start, year_stop + 1))
    ds = ds[[varname]].where(mask.compute(), drop=True)

    ds = diagnostics.annual_weighted_mean(ds)
    if grid_out is not None:
        ds = diagnostics.regrid(ds, grid_out, **kwargs)

    # Change units: multiply by the number of seconds in a day
    # (assumes the input is a flux in kg m-2 s-1, i.e. mm/s — TODO confirm)
    with xr.set_options(keep_attrs=True):
        ds[varname] = ds[varname] * 3600 * 24
    ds[varname].attrs["units"] = "mm/day"

    return ds.rename({varname: "precipitation"}).expand_dims(model=[model])

## Download data

In [None]:
# ERA5 is processed first: its grid is the regridding target for the models
print("Downloading and processing ERA5")
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=compute_annual_regridded_precipitation,
    transform_func_kwargs={
        "model": "ERA5",
        "year_start": year_start,
        "year_stop": year_stop,
    },
)

datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    # Build fresh request dictionaries for this model instead of mutating the
    # shared ones in request_sim (a plain assignment would only alias them)
    request_model = (
        request_sim[0],
        [{**request, model_key: model} for request in request_sim[1]],
    )
    ds = download.download_and_transform(
        *request_model,
        transform_func=compute_annual_regridded_precipitation,
        transform_func_kwargs={
            # Regrid every model onto the ERA5 grid
            "grid_out": ds_era[["longitude", "latitude"]],
            "model": model,
            "year_start": year_start,
            "year_stop": year_stop,
            "method": "bilinear",
            "periodic": periodic,
        },
    )
    datasets.append(ds)
# One entry per model along the "model" dimension
ds_sim = xr.concat(datasets, "model")

## Create a single dataset and compute bias

In [None]:
# Stack the simulations, their multi-model mean (labelled "ensemble"), and ERA5
# along the "model" dimension
ds_annual = xr.concat(
    [ds_sim, ds_sim.mean("model").expand_dims(model=["ensemble"]), ds_era], "model"
)
# Restrict the data to the analysis region
ds_annual = utils.regionalise(ds_annual, lon_slice=lon_slice, lat_slice=lat_slice)

# Mean over the years, then bias of every model (and the ensemble) with respect
# to ERA5; keep_attrs is enabled so attributes are kept by these operations
with xr.set_options(keep_attrs=True):
    ds = ds_annual.mean("year")
    ds_bias = ds.drop_sel(model="ERA5") - ds.sel(model="ERA5")
# Mark the bias variables as such in their long_name (requires a "long_name"
# attribute to be present)
for da in ds_bias.data_vars.values():
    da.attrs["long_name"] += " Bias"

## Plot Maps

In [None]:
# Choose projection: Robinson when the analysis region spans the whole globe,
# PlateCarree otherwise; the map is centred on the middle longitude
Projection = (
    ccrs.Robinson
    if abs(lon_slice.stop - lon_slice.start) == 360
    and abs(lat_slice.stop - lat_slice.start) == 180
    else ccrs.PlateCarree
)
projection = Projection(central_longitude=(lon_slice.stop + lon_slice.start) / 2)

da_to_plot = ds["precipitation"].sel(model=["ensemble", "ERA5"])
# Derive one set of colormap parameters from both maps so they share the same
# colour scale.
# NOTE(review): _determine_cmap_params is a private xarray helper and may
# change between xarray versions.
plot_kwargs = xr.plot.utils._determine_cmap_params(
    da_to_plot.values, robust=True, levels=10, cmap="Blues"
)
plot_kwargs["projection"] = projection
# Use a distinct loop name so the array being iterated is not rebound
for _, da_model in da_to_plot.groupby("model"):
    plot.projected_map(da_model, **plot_kwargs)
    plt.show()

## Plot Bias

In [None]:
# Colormap parameters are derived from the bias of all models and the ensemble,
# so every bias map below shares the same colour scale.
# NOTE(review): _determine_cmap_params is a private xarray helper and may
# change between xarray versions.
plot_kwargs = xr.plot.utils._determine_cmap_params(
    ds_bias["precipitation"].values, robust=True, levels=11, cmap="bwr_r"
)
plot_kwargs["projection"] = projection
# Ensemble bias on its own figure
plot.projected_map(ds_bias["precipitation"].sel(model=["ensemble"]), **plot_kwargs)
plt.show()
# One panel per individual model, at most three panels per row
# ("- 1" excludes the ensemble entry from the model count)
facet = plot.projected_map(
    ds_bias["precipitation"].drop_sel(model="ensemble"),
    col="model",
    col_wrap=min(3, ds_bias.sizes["model"] - 1),
    **plot_kwargs,
)

## Plot KDE and Statistics of Bias

In [None]:
# Create dataframe with the spatial statistics of the bias
da = ds_bias["precipitation"]
df_stats = diagnostics.spatial_weighted_statistics(da, weights=weights).to_pandas()

# Plot
fig, ax = plt.subplots(1, 1)
# Evaluation points: ensemble mean +/- 3 ensemble standard deviations
x = np.linspace(
    df_stats["ensemble"]["mean"] - 3 * df_stats["ensemble"]["std"],
    df_stats["ensemble"]["mean"] + 3 * df_stats["ensemble"]["std"],
    1_000,
)
for model, values in da.groupby("model"):
    # Flatten all dimensions into one and discard missing values
    values = values.stack(dim=values.dims).dropna("dim")
    # Kernel density estimate, weighted by |cos(latitude)| when weights is True
    y = scipy.stats.gaussian_kde(
        values,
        weights=np.abs(np.cos(np.deg2rad(values["latitude"]))) if weights else None,
    ).evaluate(x)
    # The ensemble is drawn as a dashed black line
    plot_kwargs = {"color": "k", "ls": "--"} if model == "ensemble" else {}
    ax.plot(x, y, label=model, **plot_kwargs)
ax.grid()
ax.set_xlim(x[[0, -1]])
ax.set_xlabel(f"{da.attrs['long_name']} [{da.attrs['units']}]")
ax.legend()

# Add stats as a table above the plot (transposed df_stats, rounded to 5 decimals)
table = plt.table(
    cellText=df_stats.round(5).T.values.tolist(),
    colLabels=df_stats.T.columns.values.tolist(),
    rowLabels=df_stats.T.index.values.tolist(),
    loc="top",
)

## Compute and Show Statistics

In [None]:
# Spatial statistics of every entry along "model" (including ERA5)
ds_stats = diagnostics.spatial_weighted_statistics(ds, weights=weights)
# Spatial errors of every entry except ERA5, computed against ERA5
ds_error = diagnostics.spatial_weighted_errors(
    ds.drop_sel(model="ERA5"), ds.sel(model="ERA5"), weights=weights
)
# Merge statistics and errors into a single table and display it
df_stats_and_error = xr.merge([ds_stats, ds_error])["precipitation"].to_pandas()
df_stats_and_error

## Taylor Diagram

In [None]:
# Taylor diagram built from the "std", "crmse" and "corr" rows of
# df_stats_and_error; the columns provide the marker labels.
# NOTE(review): skill_metrics documents the first element of each input array
# as the reference series — confirm that ERA5 is the first column of
# df_stats_and_error.
skill_metrics.taylor_diagram(
    df_stats_and_error.loc["std"].values,
    df_stats_and_error.loc["crmse"].values,
    df_stats_and_error.loc["corr"].values,
    alpha=0.0,
    # Axis limit and ticks below are hard-coded to the range 0-4
    axismax=4,
    colCOR="k",
    colOBS="k",
    colRMS="m",
    colSTD="b",
    # A single marker colour is only set when there are at least 9 columns
    markerColor="r" if len(df_stats_and_error.columns) >= 9 else None,  # TODO
    markerLabel=list(df_stats_and_error.columns),
    markerLegend="on",
    markerSize=10,
    markerobs="o",
    styleCOR="--",
    styleOBS="--",
    styleRMS=":",
    styleSTD="-.",
    tickRMS=np.linspace(0, 4, 5),
    tickSTD=np.linspace(0, 4, 5),
    titleCOR="on",
    titleOBS="ERA5",
    titleRMS="on",
    titleRMSDangle=40.0,
    titleSTD="on",
    widthCOR=0.5,
    widthOBS=2,
    widthRMS=2,
    widthSTD=1.0,
)