## Import packages

In [None]:
import functools
import warnings

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import skill_metrics
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot
from xarrayMannKendall import Mann_Kendall_test

# Notebook-wide settings: hide every warning, switch to the seaborn "notebook"
# style and use thin hatch lines for the significance hatching drawn later.
warnings.simplefilter("ignore")
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Settings

In [None]:
# Analysis period (inclusive years)
year_start, year_stop = 1959, 2014

# CMIP6 models compared against ERA5
models = """
access_cm2
awi_esm_1_1_lr
bcc_esm1
cnrm_cm6_1
fgoals_g3
hadgem3_gc31_ll
""".split()

# Chunking passed to ``download.download_and_transform`` (10 along ``year``)
chunks = dict(year=10)

# Reference period (inclusive years) for the climatology
clima_year_start, clima_year_stop = 1959, 1979

## Define request

In [None]:
# Year/month selection shared by the ERA5 and CMIP6 requests: every year of the
# analysis period (as strings) and the twelve zero-padded months "01".."12".
common_request = dict(
    year=list(map(str, range(year_start, year_stop + 1))),
    month=[str(month).zfill(2) for month in range(1, 13)],
)

# ERA5 monthly means of the mean total precipitation rate, as (collection, request).
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    dict(
        product_type="monthly_averaged_reanalysis",
        format="netcdf",
        time="00:00",
        variable="mean_total_precipitation_rate",
        **common_request,
    ),
)

# CMIP6 historical monthly precipitation, as (collection, request). The
# "model" key is filled in per model by the download loops.
request_sim = (
    "projections-cmip6",
    dict(
        format="zip",
        temporal_resolution="monthly",
        experiment="historical",
        variable="precipitation",
        **common_request,
    ),
)

## Functions to cache

In [None]:
# Mann-Kendall trend test with alpha=0.05 (the same 0.05 threshold used for the
# p-value hatching in the plots) and the "theilslopes" slope method
# (presumably the Theil-Sen estimator — confirm in xarrayMannKendall).
original_test = functools.partial(
    Mann_Kendall_test,
    method="theilslopes",
    alpha=0.05,
)


def get_annual_precipitation(ds, model):
    """Return the annual weighted-mean precipitation of ``ds`` in mm/day.

    Parameters
    ----------
    ds : xarray.Dataset
        Monthly data. ERA5 provides the variable ``mtpr``; the CMIP6 models
        provide ``pr``.
    model : str
        ``"ERA5"`` selects ``mtpr``; any other value selects ``pr``.

    Returns
    -------
    xarray.DataArray
        ``diagnostics.annual_weighted_mean`` of the selected variable,
        multiplied by 3600 * 24 with its attributes kept and ``units`` set to
        "mm/day". NOTE(review): this assumes the input is a per-second rate
        (e.g. kg m-2 s-1); confirm the source units.
    """
    if model == "ERA5":
        varname = "mtpr"
    else:
        varname = "pr"

    annual = diagnostics.annual_weighted_mean(ds[varname])

    # Per-second rate -> per-day amount, keeping the original attributes
    with xr.set_options(keep_attrs=True):
        annual = annual * 3600 * 24
    annual.attrs["units"] = "mm/day"
    return annual


def compute_climatology(da, clima_year_start, clima_year_stop):
    """Return the mean of ``da`` over the inclusive year range given.

    Years outside ``clima_year_start``..``clima_year_stop`` are masked to NaN
    with ``where`` and the ``year`` dimension is then averaged away. The
    attributes of ``da`` are preserved.
    """
    with xr.set_options(keep_attrs=True):
        years = da["year"]
        in_period = (years >= clima_year_start) & (years <= clima_year_stop)
        return da.where(in_period).mean("year")


def compute_anomaly(da, clima_year_start, clima_year_stop):
    """Return the anomaly of ``da`` as a percentage of its climatology.

    The climatology is the mean of ``da`` over the inclusive years
    ``clima_year_start``..``clima_year_stop`` (see ``compute_climatology``)::

        anomaly = (da - climatology) * 100 / climatology

    Parameters
    ----------
    da : xarray.DataArray
        Data with a ``year`` dimension.
    clima_year_start, clima_year_stop : int
        Inclusive bounds of the reference period.

    Returns
    -------
    xarray.DataArray
        Percentage anomaly with ``long_name`` and ``units`` ("%") attributes.
    """
    clima = compute_climatology(da, clima_year_start, clima_year_stop)
    # Normalise by the reference-period mean, not by ``da`` itself. The result
    # is then the departure from the climatology in % of that climatology,
    # consistent with the relative trends (trend / climatology) used for the maps.
    anoma = (da - clima) * 100 / clima
    anoma.attrs.update({"long_name": "precipitation anomaly", "units": "%"})
    return anoma


def spatial_weighted_trends(ds, model, clima_year_start, clima_year_stop):
    """Return trend-test results for the spatially weighted annual series.

    If ``ds`` already holds a ``precipitation`` series (as the ensemble mean
    built from earlier results does), that series is used as-is. Otherwise the
    annual precipitation is computed and reduced with
    ``diagnostics.spatial_weighted_mean``.

    The Mann-Kendall test (``original_test``) runs on the percentage anomaly
    relative to the ``clima_year_start``..``clima_year_stop`` climatology. The
    returned dataset holds the test outputs (the notebook later reads
    ``trend``), plus the ``precipitation`` and ``precipitation_anomaly``
    series. It gains a leading ``model`` dimension labelled ``model``.
    """
    series = (
        ds["precipitation"]
        if "precipitation" in ds
        else diagnostics.spatial_weighted_mean(get_annual_precipitation(ds, model))
    )
    anomaly = compute_anomaly(series, clima_year_start, clima_year_stop)

    # A dummy "x" dimension is added for the Mann-Kendall call and removed
    # again once the result has been computed.
    tested = original_test(
        anomaly.expand_dims("x"), coords_name={"time": "year", "x": "x"}
    ).compute()
    result = tested.squeeze("x", drop=True)

    result["precipitation"] = series
    result["precipitation_anomaly"] = anomaly
    return result.expand_dims(model=[model])


def regridded_trends(
    ds, model, clima_year_start, clima_year_stop, grid_out=None, **kwargs
):
    """Return gridded precipitation trends and climatology for one source.

    Runs the Mann-Kendall test (``original_test``) on the annual precipitation
    (mm/day) at every grid point. This is the precipitation itself, not its
    anomaly. The ``x``/``y`` dimensions of the result are renamed back to
    ``longitude``/``latitude``, and ``precipitation_climatology`` (mean over
    ``clima_year_start``..``clima_year_stop``) is added.

    If ``grid_out`` is given, the result is regridded onto it with
    ``diagnostics.regrid(..., **kwargs)``. Otherwise ``kwargs`` are unused.
    The result gains a leading ``model`` dimension labelled ``model``.
    """
    annual = get_annual_precipitation(ds, model)

    dims = {"time": "year", "x": "longitude", "y": "latitude"}
    spatial_names = {key: dim for key, dim in dims.items() if key != "time"}
    trends = original_test(annual, coords_name=dims).compute()
    trends = trends.rename(spatial_names)

    trends["precipitation_climatology"] = compute_climatology(
        annual, clima_year_start, clima_year_stop
    )

    if grid_out is None:
        return trends.expand_dims(model=[model])
    return diagnostics.regrid(trends, grid_out, **kwargs).expand_dims(model=[model])

## Compute spatially weighted trends

In [None]:
# Download every model plus ERA5 and reduce each to spatially weighted annual
# series and anomaly trends.
datasets = []
transform_func_kwargs = {
    "clima_year_start": clima_year_start,
    "clima_year_stop": clima_year_stop,
}
for model in models + ["ERA5"]:
    print(f"Downloading and processing {model}")
    if model == "ERA5":
        request_model = request_era
    else:
        # Build a per-model copy of the CMIP6 request instead of writing the
        # "model" key into the shared ``request_sim`` dict in place.
        collection_id, request = request_sim
        request_model = (collection_id, {**request, "model": model})
    datasets.append(
        download.download_and_transform(
            *request_model,
            chunks=chunks,
            transform_func=spatial_weighted_trends,
            transform_func_kwargs={"model": model, **transform_func_kwargs},
            transform_chunks=False,
        )
    )

# Combine and add ensemble: the unweighted mean over all models (ERA5
# excluded), whose trends are then recomputed from its mean precipitation series.
ds_spatial_weighted = xr.concat(datasets, "model")
with xr.set_options(keep_attrs=True):
    ds_ensemble = ds_spatial_weighted.drop_sel(model="ERA5").mean("model")
ds_spatial_weighted = ds_spatial_weighted.merge(
    spatial_weighted_trends(ds_ensemble, model="ensemble", **transform_func_kwargs)
)

## Plot spatially weighted precipitation and anomaly

In [None]:
# One figure per variable: the individual models as thin grey lines, with the
# ensemble mean and ERA5 drawn on top in colour.
for var in ["precipitation", "precipitation_anomaly"]:
    da = ds_spatial_weighted[var]
    # Individual models only (ensemble and ERA5 removed), all in grey
    da.drop_sel(model=["ensemble", "ERA5"]).plot(
        hue="model", linewidth=0.5, color="grey"
    )
    # Ensemble mean and ERA5, plotted after the grey lines
    da.sel(model=["ensemble", "ERA5"]).plot(hue="model")
    plt.show()

## Spatially weighted trends boxplot

In [None]:
# Box plot of the spatially weighted anomaly trends across all members.
# ``trend`` is the slope against the ``year`` coordinate; x10 gives per decade.
df_slope = (ds_spatial_weighted["trend"] * 10).to_dataframe()

ax = df_slope.boxplot()
# Every member (models, ensemble and ERA5) as a grey dot at x=1, presumably the
# position of the single box drawn by pandas/matplotlib
ax.scatter(x=[1] * len(df_slope), y=df_slope, color="grey", marker=".")
# Ensemble mean and ERA5 highlighted with their own markers and legend entries
for model in ["ensemble", "ERA5"]:
    ax.scatter(x=1, y=df_slope[df_slope.index == model], label=model, marker="o")
ax.set_ylabel("% / decade")
plt.legend()

## Compute trend maps

In [None]:
print("Downloading and processing ERA5")
# Shared keyword arguments for ``regridded_trends``. ``method`` is only
# forwarded to ``diagnostics.regrid`` when ``grid_out`` is given, so it is
# unused for ERA5, which stays on its native grid.
transform_func_kwargs = {
    "clima_year_start": clima_year_start,
    "clima_year_stop": clima_year_stop,
    "method": "conservative",
}

ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=regridded_trends,
    transform_func_kwargs={"model": "ERA5", **transform_func_kwargs},
    transform_chunks=False,
)

datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    # Per-model copy of the CMIP6 request; the shared ``request_sim`` dict is
    # not modified.
    collection_id, request = request_sim
    request_model = (collection_id, {**request, "model": model})
    ds = download.download_and_transform(
        *request_model,
        chunks=chunks,
        transform_func=regridded_trends,
        transform_func_kwargs={
            "model": model,
            # Regrid every model onto the ERA5 grid
            "grid_out": ds_era[["longitude", "latitude"]],
            **transform_func_kwargs,
        },
        transform_chunks=False,
    )
    datasets.append(ds)

# Stack the models, add their unweighted mean as "ensemble", and order the
# members as ERA5, ensemble, then the individual models.
ds_sim_regr = xr.concat(datasets, "model")
ds_ens = ds_sim_regr.mean("model").expand_dims(model=["ensemble"])
ds_all_regr = xr.concat([ds_era, ds_ens, ds_sim_regr], "model")

# Trend relative to the climatology in %/decade: x10 (years per decade) x100 (%)
da_trend = ds_all_regr["trend"] * 1.0e3 / ds_all_regr["precipitation_climatology"]
da_trend.name = ""
da_trend.attrs.update({"units": "%/decade"})

## Define plotting kwargs

In [None]:
# Filled trend maps: Robinson projection, 9 levels from -10 to 10, and the
# reversed blue-white-red colormap.
shading_kwargs = dict(
    projection=ccrs.Robinson(),
    levels=np.linspace(-10, 10, 9),
    cmap="bwr_r",
)
# Hatching overlays: contourf with no fill colour, no colorbar and no stats.
hatches_kwargs = dict(
    plot_func="contourf",
    show_stats=False,
    cmap="none",
    add_colorbar=False,
)

## Plot ERA5 trends

In [None]:
# ERA5 trend map (%/decade), hatched where the Mann-Kendall p-value lies in
# [0.05, 1], i.e. where the trend is not significant at the 5% level.
model = "ERA5"
plot.projected_map(da_trend.sel(model=model), **shading_kwargs)
# Hatching-only overlay. NOTE(review): no ``ax`` is passed, so this presumably
# draws on the current axes of the previous call; confirm for projected_map.
plot.projected_map(
    ds_all_regr["p"].sel(model=model),
    levels=[0, 0.05, 1],
    hatches=["", "/" * 4],
    **hatches_kwargs,
)
plt.suptitle(f"Precipitation trend ({year_start}-{year_stop})")

## Plot ensemble trends

In [None]:
model = "ensemble"
# A model "agrees" at a grid point when its trend is significant (p <= 0.05)
# and has the same sign as the ensemble-mean trend. The sign test uses the
# product rather than the ratio: a zero ensemble trend then never counts as
# agreement (the ratio turns into +/-inf there), and no division is needed.
is_signif = xr.where(
    ds_sim_regr["p"] <= 0.05,
    ds_sim_regr["trend"] * ds_ens["trend"].squeeze() > 0,
    False,
)
# Fraction of models that agree, per grid point
is_signif_ratio = is_signif.sum("model") / is_signif.sizes["model"]
is_signif_ratio = is_signif_ratio.expand_dims(model=[model])
plot.projected_map(da_trend.sel(model=model), **shading_kwargs)
# Hatch where fewer than 80% of the models agree (ratio in [0, 0.8))
plot.projected_map(
    is_signif_ratio.squeeze(),
    levels=[0, 0.8, 1],
    hatches=["/" * 5, ""],
    **hatches_kwargs,
)
plt.suptitle(f"Precipitation trend ({year_start}-{year_stop})")

## Plot precipitation trends for all models

In [None]:
# Trend maps (%/decade) for the individual models only, three panels per row,
# with a horizontal colorbar.
facet = plot.projected_map(
    da_trend.drop_sel(model=["ERA5", "ensemble"]),
    col="model",
    col_wrap=3,
    cbar_kwargs={"orientation": "horizontal"},
    **shading_kwargs,
)
# Add significance hatching to each panel: hatched where p lies in [0.05, 1].
# Falsy entries (presumably the unused slots of the facet grid) are skipped.
for ax, sel in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
    if not sel:
        continue
    da = ds_all_regr["p"].sel(**sel)
    plot.projected_map(
        da, ax=ax, levels=[0, 0.05, 1], hatches=["", "/" * 5], **hatches_kwargs
    )
plt.suptitle(f"Precipitation trend ({year_start}-{year_stop})")

## Plot precipitation trend bias for all models

In [None]:
# Trend bias relative to ERA5 for the ensemble and every model:
# (trend - ERA5 trend) / ERA5 climatology, x1000 -> %/decade
# (x10 years per decade, x100 for percent).
da_bias = (
    (ds_all_regr["trend"].drop_sel(model="ERA5") - ds_era["trend"].squeeze())
    * 1_000
    / ds_era["precipitation_climatology"].squeeze()
)
da_bias.name = ""
da_bias.attrs.update({"units": "%/decade"})
# Individual models only; the ensemble bias is plotted in its own cell
plot.projected_map(
    da_bias.drop_sel(model="ensemble"),
    col="model",
    col_wrap=3,
    cbar_kwargs={"orientation": "horizontal"},
    **shading_kwargs,
)
plt.suptitle(f"Precipitation trend bias ({year_start}-{year_stop})")

## Plot precipitation trend bias for ensemble

In [None]:
# Map of the ensemble-mean trend bias relative to ERA5 (%/decade).
model = "ensemble"
plot.projected_map(da_bias.sel(model=model), **shading_kwargs)
plt.suptitle(f"Precipitation trend bias ({year_start}-{year_stop})")

## Compute statistics

In [None]:
# Spatially weighted statistics of every trend map (ERA5 included), plus error
# metrics of every other map against ERA5 as the reference, in one table.
# The Taylor diagram below reads the "std", "crmse" and "corr" rows and uses
# the columns as marker labels.
ds_trend = da_trend.to_dataset(name="trend")
ds_stats = diagnostics.spatial_weighted_statistics(ds_trend)
ds_error = diagnostics.spatial_weighted_errors(
    ds_trend.drop_sel(model="ERA5"), ds_trend.sel(model="ERA5")
)
df_stats_and_error = xr.merge([ds_stats, ds_error])["trend"].to_pandas()
df_stats_and_error

## Taylor Diagram

In [None]:
# Taylor diagram of the trend maps against ERA5: standard deviation, centred
# RMS difference and correlation for each column of ``df_stats_and_error``.
# NOTE(review): the original comment here was truncated ("Bug in"), so its
# intent is unknown. skill_metrics presumably treats the first array element
# as the reference (observation). Confirm that ERA5 is the first column and how
# its crmse/corr entries are handled: ERA5 is excluded from the error
# computation, so after the merge they are presumably NaN.
skill_metrics.taylor_diagram(
    df_stats_and_error.loc["std"].values,
    df_stats_and_error.loc["crmse"].values,
    df_stats_and_error.loc["corr"].values,
    alpha=0.0,
    axismax=4,
    colCOR="k",
    colOBS="k",
    colRMS="m",
    colSTD="b",
    markerLabel=list(df_stats_and_error.columns),
    markerLegend="on",
    markerSize=10,
    markerobs="o",
    styleCOR="--",
    styleOBS="--",
    styleRMS=":",
    styleSTD="-.",
    tickRMS=np.linspace(0, 4, 5),
    tickSTD=np.linspace(0, 4, 5),
    titleCOR="on",
    titleOBS="ERA5",
    titleRMS="on",
    titleRMSDangle=40.0,
    titleSTD="on",
    widthCOR=0.5,
    widthOBS=2,
    widthRMS=2,
    widthSTD=1.0,
)