## Import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Silence warnings of every category for the rest of the notebook.
warnings.simplefilter("ignore")

# Apply matplotlib's bundled "seaborn-v0_8-notebook" style sheet to all figures.
plt.style.use("seaborn-v0_8-notebook")

## Define time period and models

In [None]:
# Years analysed; both ends are included in the requests built below.
year_start, year_stop = 1985, 1987

# CMIP6 models to compare, each sent as the "model" key of the CMIP6 request.
models = [
    "access_cm2",
    "awi_esm_1_1_lr",
    "bcc_esm1",
    "cesm2_fv2",
    "cnrm_cm6_1",
    "fgoals_g3",
]

## Define requests

In [None]:
# Keys shared by both requests: every year from year_start to year_stop
# (inclusive) and all twelve months, as zero-padded strings.
common_request = {
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [str(month).zfill(2) for month in range(1, 13)],
}

# (dataset name, request options) for the ERA5 monthly-mean precipitation rate.
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    dict(
        product_type="monthly_averaged_reanalysis",
        format="netcdf",
        time="00:00",
        variable="mean_total_precipitation_rate",
        **common_request,
    ),
)

# (dataset name, request options) template for CMIP6 historical monthly
# precipitation; the "model" key is supplied per model at download time.
request_sim = (
    "projections-cmip6",
    dict(
        format="zip",
        temporal_resolution="monthly",
        experiment="historical",
        variable="precipitation",
        **common_request,
    ),
)

In [None]:
def resample_and_regrid_and_rescale(
    ds, model, grid_out=None, *, varname=None, **kwargs
):
    """Compute the annual weighted mean, optionally regrid, and rescale to mm/day.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset containing a precipitation-rate variable.
    model : str
        Label of the data source. It becomes the single value of a new
        ``model`` dimension and, when ``varname`` is not given, selects the
        input variable name: ``"mtpr"`` for ``"ERA5"``, ``"pr"`` otherwise.
    grid_out : xarray.Dataset, optional
        Target grid forwarded to ``diagnostics.regrid``. No regridding is
        performed when ``None``.
    varname : str, optional
        Name of the precipitation variable in ``ds``. Overrides the
        ``model``-based default, so sources that use other variable names
        can be processed too.
    **kwargs
        Extra options forwarded to ``diagnostics.regrid``.

    Returns
    -------
    xarray.Dataset
        Processed dataset with the variable renamed to ``"precipitation"``,
        its ``units`` attribute set to ``"mm/day"``, and expanded along a
        length-1 ``model`` dimension.
    """
    if varname is None:
        varname = "mtpr" if model == "ERA5" else "pr"

    ds = diagnostics.annual_weighted_mean(ds)
    if grid_out is not None:
        ds = diagnostics.regrid(ds, grid_out, **kwargs)

    # Convert a per-second rate to a per-day amount (s/h * h/day).
    # NOTE(review): assumes the input is kg m-2 s-1 (numerically mm/s);
    # confirm before feeding in new sources.
    # keep_attrs=True carries long_name and other attributes through the
    # arithmetic.
    with xr.set_options(keep_attrs=True):
        ds[varname] = ds[varname] * 3600 * 24
    ds[varname].attrs["units"] = "mm/day"

    return ds.rename({varname: "precipitation"}).expand_dims(model=[model])

## Download data

In [None]:
# Process the downloaded data one year-chunk at a time.
chunks = {"year": 1}

print("Downloading and processing ERA5")
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=resample_and_regrid_and_rescale,
    transform_func_kwargs={"model": "ERA5"},
    transform_chunks=False,
)

# Every model is regridded onto the ERA5 grid so that all datasets share
# coordinates and can be merged and differenced point by point.
grid_out = ds_era[["longitude", "latitude"]]

datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    # Build a fresh request for each model instead of mutating the shared
    # ``request_sim`` template. That keeps the template model-agnostic, so
    # re-running this cell or reusing ``request_sim`` starts from the same
    # request.
    request_model = (request_sim[0], {**request_sim[1], "model": model})
    ds = download.download_and_transform(
        *request_model,
        chunks=chunks,
        transform_func=resample_and_regrid_and_rescale,
        transform_func_kwargs={
            "grid_out": grid_out,
            "model": model,
            "method": "bilinear",
            "periodic": True,
        },
        transform_chunks=False,
    )
    datasets.append(ds)

# One dataset with all models stacked along the ``model`` dimension.
ds_sim = xr.merge(datasets)

## Merge into a single dataset and compute the bias relative to ERA5

In [None]:
# Stack the individual models, their multi-model mean (labelled "ensemble")
# and ERA5 along the ``model`` dimension.
ds = xr.merge([ds_sim, ds_sim.mean("model").expand_dims(model=["ensemble"]), ds_era])

# TODO: speed up development
# Substitute with regionalization option
# 10x10 block average on the lon/lat grid (unweighted inside each block);
# cells that do not fill a complete block are trimmed.
ds = ds.coarsen(longitude=10, latitude=10, boundary="trim").mean()

# Bias of every model (and of the ensemble) relative to ERA5.
# drop=True removes the scalar ``model`` coordinate that ``sel`` would leave
# behind, so it cannot clash with the ``model`` dimension of the left operand.
with xr.set_options(keep_attrs=True):
    ds_bias = ds.drop_sel(model="ERA5") - ds.sel(model="ERA5", drop=True)

# Tag every bias variable's long_name (used in plot labels). Fall back to the
# variable name when no long_name attribute exists instead of raising KeyError.
for name, da in ds_bias.data_vars.items():
    da.attrs["long_name"] = f"{da.attrs.get('long_name', name)} Bias"

## Plots

In [None]:
# Absolute precipitation maps (mm/day), drawn for the ensemble and for ERA5.
plot_kwargs = dict(levels=range(10), cmap="Blues")

for model in ("ensemble", "ERA5"):
    # A one-element list keeps the ``model`` dimension in the selection.
    selection = ds["precipitation"].sel(model=[model])
    plot.global_map(selection, **plot_kwargs)
    plt.show()

In [None]:
# Bias maps (model minus ERA5) on a symmetric diverging colour scale.
plot_kwargs = {"levels": np.linspace(-2, 2, 11), "cmap": "RdBu_r"}

# Bias of the multi-model ensemble mean.
plot.global_map(ds_bias["precipitation"].sel(model=["ensemble"]), **plot_kwargs)
plt.show()

# Bias of each individual model, one panel per model, three per row.
plot.global_map(
    ds_bias["precipitation"].drop_sel(model="ensemble"),
    col="model",
    col_wrap=3,
    **plot_kwargs,
)
# Show explicitly, as every other plotting cell does; this also stops the
# cell from echoing the plot object's repr.
plt.show()

## Plot KDE and statistics

TODO: The KDE curves are computed from unweighted grid-cell values, whereas the statistics in the table are spatially weighted; make the two consistent.

In [None]:
# Flatten each model's bias field into a 1-D array of grid-cell values.
da = ds_bias["precipitation"]
da_dict = {}
for label, group in da.groupby("model"):
    da_dict[label] = group.values.flatten()
# Keep the ensemble separate so it can be drawn with its own line style.
df_ensemble = pd.DataFrame({"ensemble": da_dict.pop("ensemble")})
df_models = pd.DataFrame(da_dict)
# Spatially weighted statistics: rows are statistics, columns are models.
df_stats = diagnostics.spatial_weighted_statistics(da).to_pandas()

# Plot: restrict the x-axis to the ensemble mean +/- 3 standard deviations.
ens_mean = df_stats["ensemble"]["mean"]
ens_std = df_stats["ensemble"]["std"]
plot_kwargs = {
    "xlim": (ens_mean - 3 * ens_std, ens_mean + 3 * ens_std),
    "grid": True,
}
ax = df_models.plot.kde(**plot_kwargs)
ax = df_ensemble.plot.kde(color="k", ls="--", ax=ax, **plot_kwargs)
ax.set_xlabel(f"{da.attrs['long_name']} [{da.attrs['units']}]")

# Add stats: one row per model, one column per statistic, above the axes.
stats_by_model = df_stats.T
table = plt.table(
    cellText=stats_by_model.round(5).values.tolist(),
    colLabels=stats_by_model.columns.values.tolist(),
    rowLabels=stats_by_model.index.values.tolist(),
    loc="top",
)