## Import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Silence every warning category process-wide so notebook output stays clean.
warnings.simplefilter("ignore")

## Define time period and models

In [None]:
# First and last year of the analysis period (both ends are included when
# the request year list is built).
year_start, year_stop = 1985, 1987

# CMIP6 model identifiers; each one is sent as the "model" key of a
# projections-cmip6 request.
models = "access_cm2 awi_esm_1_1_lr bcc_esm1 cesm2_fv2 cnrm_cm6_1 fgoals_g3".split()

## Define requests

In [None]:
# Year/month selection shared by both requests: every year in
# [year_start, year_stop] and all twelve months, zero-padded ("01".."12").
common_request = {
    "year": list(map(str, range(year_start, year_stop + 1))),
    "month": [str(month).zfill(2) for month in range(1, 13)],
}

# ERA5 monthly-mean total precipitation rate, delivered as NetCDF.
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    dict(
        product_type="monthly_averaged_reanalysis",
        format="netcdf",
        time="00:00",
        variable="mean_total_precipitation_rate",
        **common_request,
    ),
)

# CMIP6 historical monthly precipitation; the "model" key is filled in per
# model at download time.
request_sim = (
    "projections-cmip6",
    dict(
        format="zip",
        temporal_resolution="monthly",
        experiment="historical",
        variable="precipitation",
        **common_request,
    ),
)

In [None]:
def resample_and_regrid_and_rescale(ds, model, grid_out=None, **kwargs):
    """Annual-mean *ds*, optionally regrid it, and convert precipitation to mm/day.

    Parameters
    ----------
    ds : xarray.Dataset
        Input precipitation dataset. The variable is read as ``mtpr`` when
        ``model == "ERA5"`` and as ``pr`` otherwise.
    model : str
        Label stored in the new ``model`` dimension; ``"ERA5"`` also selects
        the ERA5 variable name.
    grid_out : xarray.Dataset, optional
        Target grid forwarded to ``diagnostics.regrid``. When ``None`` no
        regridding is done.
    **kwargs
        Extra keyword arguments forwarded to ``diagnostics.regrid``.

    Returns
    -------
    xarray.Dataset
        Dataset whose precipitation variable is renamed to ``precipitation``
        with ``units = "mm/day"``, and with a new length-one ``model``
        dimension inserted first.
    """
    ds = diagnostics.annual_weighted_mean(ds)
    # Compare with None explicitly: an xarray Dataset is falsy when it has no
    # data variables, so a coordinates-only grid (e.g. built from
    # ``ds[["longitude", "latitude"]]``) would silently skip regridding.
    if grid_out is not None:
        ds = diagnostics.regrid(ds, grid_out, **kwargs)

    # Change unit: assumes the input is a per-second rate (kg m-2 s-1, i.e.
    # mm/s) -- TODO confirm for every source; seconds per day gives mm/day.
    varname = "mtpr" if model == "ERA5" else "pr"
    with xr.set_options(keep_attrs=True):
        ds[varname] *= 3600 * 24
    ds[varname].attrs["units"] = "mm/day"

    # Rename the variable held in ``varname`` (not the literal string
    # "varname", which is absent from the dataset and makes rename() raise).
    return ds.rename({varname: "precipitation"}).expand_dims(model=[model])

## Download data

In [None]:
# Download-request chunking, also reused for the CMIP6 downloads below.
chunks = {"year": 1}

# ERA5: annual means converted to mm/day by resample_and_regrid_and_rescale.
# ``model`` is a required argument of that transform ("ERA5" also selects the
# ``mtpr`` variable); omitting it makes the transform raise TypeError.
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=resample_and_regrid_and_rescale,
    transform_func_kwargs={"model": "ERA5"},
    transform_chunks=False,
)

In [None]:
def resample_and_regrid(ds, grid_out, model, **kwargs):
    """Annual-mean *ds*, regrid it onto *grid_out*, and tag it with *model*.

    Extra keyword arguments are forwarded to ``diagnostics.regrid``. Variable
    names and units are left as they are; the result gains a new length-one
    ``model`` dimension holding *model*.
    """
    annual = diagnostics.annual_weighted_mean(ds)
    regridded = diagnostics.regrid(annual, grid_out, **kwargs)
    return regridded.expand_dims(model=[model])


# Every CMIP6 model is regridded onto the ERA5 grid so the results align.
era_grid = ds_era[["longitude", "latitude"]]  # loop-invariant target grid

datasets = []
for model in models:
    # Build a fresh per-model request rather than aliasing ``request_sim``:
    # assigning into the shared dict would mutate it on every iteration and
    # leave it pointing at the last model afterwards.
    collection_id, base_request = request_sim
    request_model = (collection_id, {**base_request, "model": model})
    ds = download.download_and_transform(
        *request_model,
        chunks=chunks,
        transform_func=resample_and_regrid,
        transform_func_kwargs={
            "grid_out": era_grid,
            "model": model,
            "method": "bilinear",
            "periodic": True,
        },
        transform_chunks=False,
    )
    datasets.append(ds)
ds_sim = xr.merge(datasets)

# Add the multi-model ensemble mean as an extra entry along ``model``.
# NOTE: the label is spelled "ensamble"; the plotting cell selects it by name.
ds_sim = ds_sim.merge(ds_sim.mean("model").expand_dims(model=["ensamble"]))

In [None]:
plot_kwargs = {"levels": range(0, 10, 1), "cmap": "Blues"}

# The CMIP6 ensemble mean is still the raw "pr" rate (assumed kg m-2 s-1 --
# TODO confirm), so convert it to mm/day here. keep_attrs retains its
# metadata, which a bare multiplication would drop, and the units are updated.
with xr.set_options(keep_attrs=True):
    sim_mm_day = ds_sim["pr"].sel(model=["ensamble"]) * 3600 * 24
sim_mm_day.attrs["units"] = "mm/day"

# ERA5 goes through resample_and_regrid_and_rescale, which already converts
# to mm/day and names the variable "precipitation" -- do not scale it again.
for da in (sim_mm_day, ds_era["precipitation"]):
    plot.global_map(da, **plot_kwargs)
    plt.show()