## Import packages

In [None]:
import functools

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import skill_metrics
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils
from xarrayMannKendall import Mann_Kendall_test

# Global matplotlib settings: notebook seaborn style and a hatch line width of 0.5
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define Parameters

In [None]:
# Analysis period (first and last year)
year_start, year_stop = 1985, 1987

# Reference period used to compute the climatology
clima_year_start, clima_year_stop = 1985, 1986

# Timeseries to analyse: "annual" or one of the four seasons
timeseries = "MAM"
assert timeseries in ("annual", "DJF", "MAM", "JJA", "SON")

# Field to analyse
variable = "temperature"
assert variable in ("precipitation", "temperature")

# Source of the simulations: CORDEX or CMIP6
collection_id = "cordex"
assert collection_id in ("cordex", "cmip6")

# Longitude/latitude window used for the analysis
lon_slice, lat_slice = slice(-4, 20), slice(35, 50)

# CORDEX domain used in the request
cordex_domain = "europe"

# Download chunking (one request per year)
chunks = dict(year=1)
assert "month" not in chunks, "Do not use chunks smaller than 1y"

## Define models

In [None]:
# Model identifiers used when `collection_id` is "cordex"; each one is written
# into the simulation request under the "rcm_model" key before downloading.
models_cordex = [
    "clmcom_clm_cclm4_8_17",
    "clmcom_eth_cosmo_crclim",
    "cnrm_aladin63",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mohc_hadrem3_ga7_05",
    "mpi_csc_remo2009",
    "smhi_rca4",
    "uhoh_wrf361h",
]

# Model identifiers used when `collection_id` is "cmip6"; each one is written
# into the simulation request under the "model" key before downloading.
models_cmip6 = [
    "access_cm2",
    "awi_esm_1_1_lr",
    "bcc_esm1",
    "cesm2_fv2",
    "cnrm_cm6_1",
    "fgoals_g3",
]

## Define ERA5 request

In [None]:
# ERA5 request name for each supported field
era5_variables = dict(
    precipitation="mean_total_precipitation_rate",
    temperature="2m_temperature",
)

# (collection, request) pair for the ERA5 monthly means
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    {
        "product_type": "monthly_averaged_reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": era5_variables[variable],
        # Cover both the analysis and the climatology periods, starting one
        # year earlier so that D(year-1) is included
        "year": list(
            map(
                str,
                range(
                    min(year_start, clima_year_start) - 1,
                    max(year_stop, clima_year_stop) + 1,
                ),
            )
        ),
        "month": [str(month).zfill(2) for month in range(1, 13)],
    },
)

## Define model requests

In [None]:
# CORDEX request name for each supported field
cordex_variables = dict(
    precipitation="mean_precipitation_flux",
    temperature="2m_air_temperature",
)

# Base CORDEX request (start/end years and the RCM are filled in later)
request_cordex = {
    "format": "zip",
    "domain": cordex_domain,
    "experiment": "historical",
    "horizontal_resolution": "0_11_degree_x_0_11_degree",
    "temporal_resolution": "monthly_mean",
    "variable": cordex_variables[variable],
    "gcm_model": "mpi_m_mpi_esm_lr",
    "ensemble_member": "r1i1p1",
}

# CMIP6 request name for each supported field
cmip6_variables = dict(
    precipitation="precipitation",
    temperature="near_surface_air_temperature",
)

# Base CMIP6 request (the model is filled in later); years and months are
# the same objects used by the ERA5 request
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "monthly",
    "experiment": "historical",
    "variable": cmip6_variables[variable],
    **{key: request_era[1][key] for key in ("year", "month")},
}


def get_cordex_years(
    year_start,
    year_stop,
    start_years=(1971, 1981, 1991, 2001),
    end_years=(1980, 1990, 2000, 2005),
):
    """Select the CORDEX file periods that overlap the requested years.

    The requested range is extended by one year at the beginning so that
    D(year-1) is included.

    Parameters
    ----------
    year_start, year_stop : int
        First and last year of interest (inclusive).
    start_years, end_years : sequence of int
        Paired first/last years (inclusive) of the available periods.
        Immutable tuples are used as defaults so the defaults can never be
        modified between calls.

    Returns
    -------
    tuple of (list of int, list of int)
        Start years and end years of every period sharing at least one year
        with ``year_start - 1 .. year_stop``.
    """
    years = set(range(year_start - 1, year_stop + 1))  # Include D(year-1)
    start_year = []
    end_year = []
    for start, end in zip(start_years, end_years):
        if not years.isdisjoint(range(start, end + 1)):
            start_year.append(start)
            end_year.append(end)
    return start_year, end_year


# Collection-specific settings and simulation requests
if collection_id == "cordex":
    models, model_key = models_cordex, "rcm_model"
    weights = periodic = False  # Do not weight spatial statistics/errors
    request_sim = (
        "projections-cordex-domains-single-levels",
        [
            # One request per CORDEX period overlapping the years of interest
            request_cordex | {"start_year": first, "end_year": last}
            for first, last in zip(
                *get_cordex_years(
                    min(year_start, clima_year_start), max(year_stop, clima_year_stop)
                )
            )
        ],
    )
elif collection_id == "cmip6":
    models, model_key = models_cmip6, "model"
    weights = periodic = True  # Weight spatial statistics/errors
    request_sim = (
        "projections-cmip6",
        download.split_request(request_cmip6, chunks=chunks),
    )
else:
    raise ValueError

## Functions to cache

In [None]:
# Mann-Kendall test pre-configured with alpha=0.05 and the "theilslopes" method;
# used by the trend functions below.
original_test = functools.partial(Mann_Kendall_test, alpha=0.05, method="theilslopes")


def get_timeseries(ds, year_start, year_stop, timeseries):
    """Build the annual or seasonal timeseries of precipitation or temperature.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a "time" coordinate and exactly one of the variables
        "mtpr"/"pr" (precipitation) or "tas"/"t2m" (temperature).
    year_start, year_stop : int
        First and last year to keep (inclusive).
    timeseries : str
        "annual", or the season label to extract (e.g. "DJF").

    Returns
    -------
    xr.DataArray
        Array named "precipitation" (scaled by 3600 * 24, units "mm/day") or
        "temperature" (offset by -273.15, units "°C") with a "year" dimension.

    Raises
    ------
    ValueError
        If no precipitation or temperature variable is found.
    """
    is_annual = timeseries == "annual"

    # Discard the time steps outside the requested years (and season)
    if is_annual:
        years = ds["time"].dt.year
        mask = (years >= year_start) & (years <= year_stop)
    else:
        # Shift the years by -1 along time to get D(year-1)J(year)F(year)
        shifted_years = ds["time"].dt.year.shift(time=-1).astype(int)
        ds = ds.assign_coords(year=shifted_years)
        in_period = (ds["year"] >= year_start) & (ds["year"] <= year_stop)
        mask = in_period & (ds["time"].dt.season == timeseries)
    ds = ds.where(mask.compute(), drop=True)

    # Identify the field from the variable names available in the dataset
    available = set(ds.data_vars)
    known_fields = (
        ("precipitation", {"mtpr", "pr"}),
        ("temperature", {"tas", "t2m"}),
    )
    for field, candidates in known_fields:
        var_names = available & candidates
        if var_names:
            break
    else:
        raise ValueError("Unable to find precipitation or temperature variable.")
    (var_name,) = var_names  # exactly one matching variable is expected
    da = ds[var_name].rename(field)

    # Aggregate to one value per year
    if is_annual:
        da = diagnostics.annual_weighted_mean(da)
    else:
        seasonal = da.groupby("year").map(diagnostics.seasonal_weighted_mean)
        da = seasonal.sel(season=timeseries)

    # Unit conversion (attributes are kept while converting)
    with xr.set_options(keep_attrs=True):
        if da.name == "precipitation":
            da *= 3600 * 24
            da.attrs["units"] = "mm/day"
        elif da.name == "temperature":
            da -= 273.15
            da.attrs["units"] = "°C"
            height = ds["height"] if "height" in ds else None
            da = da.assign_coords(height=height)
        else:
            raise ValueError
    return da


def compute_climatology(da, clima_year_start, clima_year_stop):
    """Return the mean of ``da`` along "year" over the climatology period.

    The period is selected with a label slice from ``clima_year_start`` to
    ``clima_year_stop``; attributes are kept.
    """
    reference_period = da.sel(year=slice(clima_year_start, clima_year_stop))
    return reference_period.mean("year", keep_attrs=True)


def compute_anomaly(da, clima_year_start, clima_year_stop):
    """Compute the anomaly of ``da`` with respect to its climatology.

    Parameters
    ----------
    da : xr.DataArray
        Timeseries with a "year" dimension and a "units" attribute.
    clima_year_start, clima_year_stop : int
        First and last year of the climatology period.

    Returns
    -------
    xr.DataArray
        ``da - climatology`` in the units of ``da``. For an array named
        "precipitation" the anomaly is expressed in percent of the
        climatology (units "%").
    """
    clima = compute_climatology(da, clima_year_start, clima_year_stop)
    anoma = da - clima
    units = da.attrs["units"]
    if da.name == "precipitation":
        # Relative anomaly: normalise by the climatology (the reference value),
        # not by the field itself, consistently with the relative trends
        # computed elsewhere in this file.
        anoma *= 100 / clima
        units = "%"
    anoma.attrs.update({"long_name": f"{da.name} anomaly", "units": units})
    return anoma


def spatial_weighted_trends(
    obj,
    year_start,
    year_stop,
    clima_year_start,
    clima_year_stop,
    timeseries,
    weights,
    lon_slice,
    lat_slice,
):
    """Compute the trend of the anomaly of a spatially averaged timeseries.

    Parameters
    ----------
    obj : xr.DataArray or xr.Dataset
        A DataArray is used as the timeseries as-is. Anything else is
        regionalised with ``lon_slice``/``lat_slice``, converted to a
        timeseries and spatially averaged using ``weights``.
    year_start, year_stop, timeseries
        Forwarded to ``get_timeseries`` (unused when ``obj`` is a DataArray).
    clima_year_start, clima_year_stop : int
        Climatology period used for the anomaly.

    Returns
    -------
    xr.Dataset
        Output of the Mann-Kendall test merged with the timeseries and its
        anomaly (``<name>`` and ``<name>_anomaly``).
    """
    if not isinstance(obj, xr.DataArray):
        regional = utils.regionalise(obj, lon_slice=lon_slice, lat_slice=lat_slice)
        series = get_timeseries(regional, year_start, year_stop, timeseries)
        da = diagnostics.spatial_weighted_mean(series, weights=weights)
    else:
        da = obj
    anoma = compute_anomaly(da, clima_year_start, clima_year_stop)

    # The test is given an extra length-one "x" dimension, removed afterwards
    test = original_test(
        anoma.expand_dims("x"),
        coords_name={"time": "year", "x": "x"},
    )
    trend = test.compute().squeeze("x", drop=True)

    # Attach the timeseries and its anomaly to the test output
    extra_variables = {da.name: da, f"{da.name}_anomaly": anoma}
    return trend.merge(extra_variables)


def regridded_trends(
    ds,
    year_start,
    year_stop,
    clima_year_start,
    clima_year_stop,
    timeseries,
    grid_out=None,
    **kwargs,
):
    """Compute gridded trends and climatology, optionally regridded.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with "time", "longitude" and "latitude".
    year_start, year_stop, timeseries
        Forwarded to ``get_timeseries``.
    clima_year_start, clima_year_stop : int
        Climatology period.
    grid_out : optional
        Target grid. When None (default) no regridding is done.
    **kwargs
        Forwarded to ``diagnostics.regrid``; ignored when ``grid_out`` is None.

    Returns
    -------
    xr.Dataset
        Output of the Mann-Kendall test plus ``<name>_climatology``.
    """
    da = get_timeseries(ds, year_start, year_stop, timeseries)

    # Map the test's generic "x"/"y" names onto the dataset's spatial dimensions
    spatial_dims = dict(zip(("x", "y"), ds[["longitude", "latitude"]].dims))
    coords_name = {"time": "year", **spatial_dims}
    ds_trend = original_test(da, coords_name=coords_name).compute()

    # Restore the original dimension names and the time-independent coordinates
    ds_trend = ds_trend.rename(spatial_dims)
    ds_trend = ds_trend.assign_coords(ds.drop_dims("time").coords)

    # Add the climatology of the field
    ds_trend[f"{da.name}_climatology"] = compute_climatology(
        da, clima_year_start, clima_year_stop
    )

    if grid_out is None:
        return ds_trend

    # Add the bounds that are missing before regridding
    ds_trend = ds_trend.cf.add_bounds(
        (
            name
            for name in ("longitude", "latitude")
            if name not in ds_trend.cf.bounds
        )
    )
    return diagnostics.regrid(ds_trend, grid_out, **kwargs)

## Compute spatial weighted trends

In [None]:
# Arguments shared by every call to spatial_weighted_trends
transform_func_kwargs = dict(
    year_start=year_start,
    year_stop=year_stop,
    clima_year_start=clima_year_start,
    clima_year_stop=clima_year_stop,
    lon_slice=lon_slice,
    lat_slice=lat_slice,
    timeseries=timeseries,
)

datasets = []
for model in [*models, "ERA5"]:
    print(f"Downloading and processing {model}")
    is_era5 = model == "ERA5"
    request_model = request_era if is_era5 else request_sim
    if not is_era5:
        # Point every simulation request at the current model
        for request in request_model[1]:
            request.update({model_key: model})
    ds = download.download_and_transform(
        *request_model,
        chunks=chunks if is_era5 else {},
        transform_func=spatial_weighted_trends,
        # ERA5 is always weighted; simulations follow the collection setting
        transform_func_kwargs={"weights": True if is_era5 else weights}
        | transform_func_kwargs,
        transform_chunks=False,
    )
    datasets.append(ds.expand_dims(model=[model]))

# Combine all models, then add the ensemble mean of the simulations
ds_mean_trend = xr.concat(datasets, "model")
simulations = ds_mean_trend[variable].drop_sel(model="ERA5")
da_ensemble = simulations.mean("model", keep_attrs=True)
ds_ensemble_trend = spatial_weighted_trends(
    da_ensemble, weights=weights, **transform_func_kwargs
)
ds_mean_trend = ds_mean_trend.merge(ds_ensemble_trend.expand_dims(model=["ensemble"]))

## Plot field and anomaly

In [None]:
# One figure per quantity: thin grey lines for the individual models,
# default-styled lines for the ensemble and ERA5
highlighted = ["ensemble", "ERA5"]
for var in (variable, f"{variable}_anomaly"):
    fig, ax = plt.subplots(1, 1)
    da = ds_mean_trend[var]
    individual_models = da.drop_sel(model=highlighted)
    individual_models.plot(hue="model", linewidth=0.5, color="grey", ax=ax)
    da.sel(model=highlighted).plot(hue="model", ax=ax)
    ax.set_title(timeseries.upper())
    ax.grid()

## Trends boxplot

In [None]:
# Trends multiplied by 10 (per decade), one row per model
trend_per_decade = ds_mean_trend["trend"] * 10
df_slope = trend_per_decade.to_dataframe()[["trend"]]

ax = df_slope.boxplot()
# Every model as a grey dot on top of the box
ax.scatter(x=[1] * len(df_slope), y=df_slope, color="grey", marker=".")
# Ensemble and ERA5 as labelled markers
for model in ["ensemble", "ERA5"]:
    selected = df_slope[df_slope.index == model]
    ax.scatter(x=1, y=selected, label=model, marker="o")

units = ds_mean_trend[f"{variable}_anomaly"].attrs["units"] + " / decade"
ax.set_ylabel(units)
ax.set_title(timeseries.upper())
_ = plt.legend(title="model")

## Compute trend maps

In [None]:
print("Downloading and processing ERA5")
# Arguments shared by every call to regridded_trends
transform_func_kwargs = dict(
    year_start=year_start,
    year_stop=year_stop,
    clima_year_start=clima_year_start,
    clima_year_stop=clima_year_stop,
    timeseries=timeseries,
    method="conservative",
)

# ERA5 is processed without a target grid
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_func=regridded_trends,
    transform_func_kwargs={"model": "ERA5"} | transform_func_kwargs,
    transform_chunks=False,
)

# Simulations are regridded onto the ERA5 longitude/latitude grid
datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    request_model = request_sim
    for request in request_model[1]:
        request.update({model_key: model})
    ds = download.download_and_transform(
        *request_model,
        transform_func=regridded_trends,
        transform_func_kwargs={
            "grid_out": ds_era[["longitude", "latitude"]],
            "periodic": periodic,
        }
        | transform_func_kwargs,
        transform_chunks=False,
    )
    # Cache global trends, then regionalise
    ds = utils.regionalise(ds, lon_slice=lon_slice, lat_slice=lat_slice)
    datasets.append(ds.expand_dims(model=[model]))

# Regionalise ERA5
ds_era = utils.regionalise(ds_era, lon_slice=lon_slice, lat_slice=lat_slice)
ds_era = ds_era.expand_dims(model=["ERA5"])

# Stack ERA5, the ensemble mean and the individual simulations along "model"
ds_sim_regr = xr.concat(datasets, "model")
ds_ens = ds_sim_regr.mean("model").expand_dims(model=["ensemble"])
ds_all_regr = xr.concat([ds_era, ds_ens, ds_sim_regr], "model")

# Trends per decade; precipitation is expressed relative to the climatology
climatology = ds_all_regr[f"{variable}_climatology"]
da_trend = ds_all_regr["trend"] * 10
units = climatology.attrs["units"] + "/decade"
if variable == "precipitation":
    da_trend *= 100 / climatology
    units = "%/decade"
da_trend.name = ""
da_trend.attrs["units"] = units

## Define plotting kwargs

In [None]:
# Map projection: Robinson only when the region spans at least 360 degrees of
# longitude and 180 degrees of latitude, PlateCarree otherwise.
Projection = (
    ccrs.Robinson
    if abs(lon_slice.stop - lon_slice.start) >= 360
    and abs(lat_slice.stop - lat_slice.start) >= 180
    else ccrs.PlateCarree
)
# Colour-map parameters derived once from all trend values and reused by every
# shaded map below.
# NOTE(review): `_determine_cmap_params` is a private xarray helper (leading
# underscore) — confirm it is available in the installed xarray version.
shading_kwargs = xr.plot.utils._determine_cmap_params(
    da_trend.values,
    levels=11,
    robust=True,
    cmap="bwr_r" if variable == "precipitation" else "bwr",
    extend="both",
)
# Centre the projection on the middle longitude of the region
shading_kwargs["projection"] = Projection(
    central_longitude=(lon_slice.stop + lon_slice.start) / 2
)
# Axes rectangle passed to `plt.axes` for the shared horizontal colorbars
cbar_ax = [0.05, -0.04, 0.95, 0.04]

# NOTE(review): `hatches` is not referenced anywhere else in the visible code.
hatches = ["", "/" * 5]
# Options shared by the hatched overlays (contourf without colours/colorbar)
hatches_kwargs = {
    "plot_func": "contourf",
    "show_stats": False,
    "cmap": "none",
    "add_colorbar": False,
}
# Hatching of the p-value maps, split at 0.05; the hatch order is reversed
# unless the variable is precipitation.
# NOTE(review): `"" * 5` evaluates to "" so the list is ["/", ""]; the other
# hatch lists use "/" * 5 — confirm whether ["/" * 5, ""] was intended.
p_hatches_kwargs = hatches_kwargs | {
    "levels": [0, 0.05, 1],
    "hatches": ["/", "" * 5][:: 1 if variable == "precipitation" else -1],
}
# Hatching of the significance-ratio map, split at 0.8; the hatch order is
# reversed unless the variable is precipitation.
is_signif_ratio_hatches_kwargs = hatches_kwargs | {
    "levels": [1, 0.8, 0],
    "hatches": ["", "/" * 5][:: 1 if variable == "precipitation" else -1],
}

## Plot ERA5 trends

In [None]:
model = "ERA5"
# Shaded trend map with the p-value hatching drawn on top
plot.projected_map(da_trend.sel(model=model), **shading_kwargs)
p_values = ds_all_regr["p"].sel(model=model)
plot.projected_map(
    p_values.drop_vars(["height", "season"], errors="ignore"),
    **p_hatches_kwargs,
)
_ = plt.suptitle(
    f"{variable.title()} trend ({year_start}-{year_stop}) - {timeseries.upper()}"
)

## Plot ensemble trends

In [None]:
model = "ensemble"
# A model counts when its p-value is at most 0.05 and its trend has the same
# sign as the ensemble trend (positive ratio)
same_sign = ds_sim_regr["trend"] / ds_ens["trend"].squeeze() > 0
is_signif = xr.where(ds_sim_regr["p"] <= 0.05, same_sign, False)
# Fraction of models meeting both conditions
is_signif_ratio = is_signif.sum("model") / is_signif.sizes["model"]
is_signif_ratio = is_signif_ratio.expand_dims(model=[model])

# Shaded ensemble trend with the agreement hatching drawn on top
ensemble_trend = da_trend.sel(model=model)
plot.projected_map(
    ensemble_trend.drop_vars(["height", "season"], errors="ignore"),
    stats_weights=weights,
    **shading_kwargs,
)
plot.projected_map(
    is_signif_ratio.drop_vars(["height", "season"], errors="ignore").squeeze(),
    **is_signif_ratio_hatches_kwargs,
)
_ = plt.suptitle(
    f"{variable.title()} trend ({year_start}-{year_stop}) - {timeseries.upper()}"
)

## Plot trends for all models

In [None]:
# One panel per simulation (ERA5 and the ensemble are excluded)
facet = plot.projected_map(
    da_trend.drop_sel(model=["ERA5", "ensemble"]),
    col="model",
    col_wrap=3,
    add_colorbar=False,
    **shading_kwargs,
)
# Overlay the p-value hatching on every non-empty panel
for ax, sel in zip(facet.axs.flatten(), facet.name_dicts.flatten()):
    if sel:
        p_values = ds_all_regr["p"].sel(**sel)
        plot.projected_map(
            p_values.drop_vars(["height", "season"], errors="ignore"),
            ax=ax,
            **p_hatches_kwargs,
        )
plt.suptitle(
    f"{variable.title()} trend ({year_start}-{year_stop}) - {timeseries.upper()}"
)
# Shared horizontal colorbar taken from the first panel
cax = plt.axes(cbar_ax)
first_panel = facet.axs[0][0]
_ = plt.colorbar(
    first_panel.collections[0],
    cax=cax,
    orientation="horizontal",
    label=xr.plot.utils.label_from_attrs(da_trend),
)

## Plot trend bias for all models

In [None]:
# Trend bias per decade: simulations (and ensemble) minus ERA5
sim_trend = ds_all_regr["trend"].drop_sel(model="ERA5")
da_bias = (sim_trend - ds_era["trend"].squeeze()) * 10
units = ds_all_regr[f"{variable}_climatology"].attrs["units"] + "/decade"
if variable == "precipitation":
    # Express the bias relative to the ERA5 climatology
    da_bias *= 100 / ds_era[f"{variable}_climatology"].squeeze()
    units = "%/decade"
da_bias.name = ""
da_bias.attrs["units"] = units

# One panel per simulation (the ensemble is excluded)
bias_per_model = da_bias.drop_sel(model="ensemble")
facet = plot.projected_map(
    bias_per_model.drop_vars(["height", "season"], errors="ignore"),
    col="model",
    col_wrap=3,
    add_colorbar=False,
    **shading_kwargs,
)
plt.suptitle(
    f"{variable.title()} trend bias ({year_start}-{year_stop}) - {timeseries.upper()}"
)
# Shared horizontal colorbar taken from the first panel
cax = plt.axes(cbar_ax)
first_panel = facet.axs[0][0]
_ = plt.colorbar(
    first_panel.collections[0],
    cax=cax,
    orientation="horizontal",
    label=xr.plot.utils.label_from_attrs(da_bias),
)

## Plot trend bias for ensemble

In [None]:
model = "ensemble"
# Shaded map of the ensemble trend bias
ensemble_bias = da_bias.sel(model=model)
plot.projected_map(
    ensemble_bias.drop_vars(["height", "season"], errors="ignore"),
    stats_weights=weights,
    **shading_kwargs,
)
_ = plt.suptitle(
    f"{variable.title()} trend bias ({year_start}-{year_stop}) - {timeseries.upper()}"
)

## Compute statistics

In [None]:
ds_trend = da_trend.to_dataset(name="trend")
sim_trend = ds_trend.drop_sel(model="ERA5")
era_trend = ds_trend.sel(model="ERA5")

# Spatial statistics: simulations follow the collection weighting,
# ERA5 is always weighted
sim_stats = diagnostics.spatial_weighted_statistics(sim_trend, weights=weights)
era_stats = diagnostics.spatial_weighted_statistics(era_trend, weights=True)
ds_stats = xr.concat([sim_stats, era_stats], "model")

# Spatial errors of the simulations with respect to ERA5
ds_error = diagnostics.spatial_weighted_errors(sim_trend, era_trend, weights=weights)

df_stats_and_error = xr.merge([ds_stats, ds_error])["trend"].to_pandas()
df_stats_and_error

## Taylor Diagram

In [None]:
# Five evenly spaced ticks from 0 to the largest value, cast to integers
n_ticks = 5
tickRMS = np.linspace(0, df_stats_and_error.loc["crmse"].max(), n_ticks, dtype=int)
tickSTD = np.linspace(0, df_stats_and_error.loc["std"].max(), n_ticks, dtype=int)

# Appearance options of the diagram
taylor_options = {
    "alpha": 0.0,
    "colCOR": "k",
    "colOBS": "k",
    "colRMS": "m",
    "colSTD": "b",
    "markerColor": "r" if len(df_stats_and_error.columns) >= 9 else None,  # TODO
    "markerLabel": list(df_stats_and_error.columns),
    "markerLegend": "on",
    "markerSize": 10,
    "markerobs": "o",
    "styleCOR": "--",
    "styleOBS": "--",
    "styleRMS": ":",
    "styleSTD": "-.",
    "tickRMS": tickRMS,
    "tickSTD": tickSTD,
    "titleCOR": "on",
    "titleOBS": "ERA5",
    "titleRMS": "on",
    "titleRMSDangle": 40.0,
    "titleSTD": "on",
    "widthCOR": 0.5,
    "widthOBS": 2,
    "widthRMS": 2,
    "widthSTD": 1.0,
}
# Positional arguments: standard deviations, centred RMS errors, correlations
skill_metrics.taylor_diagram(
    *(df_stats_and_error.loc[stat].values for stat in ("std", "crmse", "corr")),
    **taylor_options,
)