## Extreme temperature indices

## Import packages

In [None]:
import math
import tempfile

import cartopy.crs as ccrs
import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot
from xarrayMannKendall import Mann_Kendall_test

# Apply the "seaborn-v0_8-notebook" matplotlib style sheet to every figure below
plt.style.use("seaborn-v0_8-notebook")
# Line width used for hatch patterns (the plotting helpers below draw hatching)
plt.rcParams["hatch.linewidth"] = 0.5

## Define Parameters

In [None]:
# Historical (reference) and future periods, given as first and last year.
# The consumers of these slices treat both bounds as inclusive.
historical_slice = slice(1971, 2000)
future_slice = slice(2015, 2099)
assert historical_slice.stop < future_slice.start

# Temporal aggregation: "annual" or one of the meteorological seasons
timeseries = "JJA"
assert timeseries in {"annual", "DJF", "MAM", "JJA", "SON"}

# Source collection of the simulations: "CORDEX" or "CMIP6"
collection_id = "CMIP6"
assert collection_id in {"CORDEX", "CMIP6"}

# Analysis region; set_extent() reads it as [north, west, south, east]
area = [72, -22, 27, 45]

# CORDEX domain used in the CORDEX request
cordex_domain = "europe"

# Names of the indices to compute
index_names = ("SU", "TX90p")

# Method used when interpolating models onto the common grid
interpolation_method = "bilinear"

# Download chunking: one request per year
chunks = dict(year=1)

## Define models

In [None]:
# CORDEX models; each entry is used as the "rcm_model" value of a request
models_cordex = [
    "clmcom_clm_cclm4_8_17",
    "clmcom_eth_cosmo_crclim",
    "cnrm_aladin63",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mohc_hadrem3_ga7_05",
    "mpi_csc_remo2009",
    "smhi_rca4",
    "uhoh_wrf361h",
]

# CMIP6 models; each entry is used as the "model" value of a request
models_cmip6 = [
    "access_cm2",
    "awi_cm_1_1_mr",
    "cmcc_esm2",
    "cnrm_cm6_1_hr",
    "ec_earth3_cc",
    "gfdl_esm4",
    "inm_cm5_0",
    "miroc6",
    "mpi_esm1_2_lr",
]

# Model whose grid is the target grid when the other models are interpolated
model_regrid = "gfdl_esm4" if collection_id == "CMIP6" else "clmcom_eth_cosmo_crclim"

## Define land-sea mask (LSM) request

In [None]:
# (collection_id, request) pair for the ERA5 land-sea mask.
# A single field (1940-01-01 00:00) is requested, restricted to the analysis area.
request_lsm = (
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "land_sea_mask",
        "year": "1940",
        "month": "01",
        "day": "01",
        "area": area,
    },
)

## Define model requests

In [None]:
# Request template shared by all CORDEX downloads. The experiment, the
# start/end years and the regional model are added per request further below.
request_cordex = {
    "format": "zip",
    "domain": cordex_domain,
    "horizontal_resolution": "0_11_degree_x_0_11_degree",
    "temporal_resolution": "daily_mean",
    "variable": "maximum_2m_temperature_in_the_last_24_hours",
    "gcm_model": "mpi_m_mpi_esm_lr",
    "ensemble_member": "r1i1p1",
    "area": area,
}

# Request template shared by all CMIP6 downloads: every month and day of month.
# The years, the experiment and the model are added per request further below.
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "daily",
    "variable": "daily_maximum_near_surface_air_temperature",
    "month": [f"{month:02d}" for month in range(1, 13)],
    "day": [f"{day:02d}" for day in range(1, 32)],
    "area": area,
}


def get_cordex_years(
    year_slice,
    timeseries,
    start_years=None,
    end_years=None,
):
    """Return the CORDEX file periods that overlap the requested years.

    Parameters
    ----------
    year_slice : slice
        First and last year of interest (``start`` and ``stop``, both inclusive).
    timeseries : str
        Timeseries selection. For ``"DJF"`` the year before ``year_slice.start``
        is included as well, so that its December is available.
    start_years, end_years : iterable of int, optional
        First and last year of each available period, paired positionally.
        Default: 5-year periods 1951-1955, 1956-1960, ..., 2096-2100.

    Returns
    -------
    tuple of (list of int, list of int)
        Start years and end years of every period sharing at least one year
        with the requested range, in the order given.
    """
    # Defaults are built per call rather than as shared mutable default lists
    if start_years is None:
        start_years = range(1951, 2097, 5)
    if end_years is None:
        end_years = range(1955, 2101, 5)

    first_year = year_slice.start - int(timeseries == "DJF")  # Include D(year-1)
    years = set(range(first_year, year_slice.stop + 1))

    start_year = []
    end_year = []
    for start, end in zip(start_years, end_years):
        if not years.isdisjoint(range(start, end + 1)):
            start_year.append(start)
            end_year.append(end)
    return start_year, end_year


def get_cmip6_years(year_slice, timeseries=None):
    """Return the years to request from CMIP6 as a list of strings.

    Parameters
    ----------
    year_slice : slice
        First and last year of interest (``start`` and ``stop``, both inclusive).
    timeseries : str, optional
        Timeseries selection. For ``"DJF"`` the year before ``year_slice.start``
        is included as well, so that its December is available. When omitted,
        the module-level ``timeseries`` setting is used, which keeps existing
        one-argument calls working.

    Returns
    -------
    list of str
        Consecutive years, formatted as strings.
    """
    if timeseries is None:
        # Backward-compatible fallback to the notebook-wide setting
        timeseries = globals()["timeseries"]
    first_year = year_slice.start - int(timeseries == "DJF")  # Include D(year-1)
    return [str(year) for year in range(first_year, year_slice.stop + 1)]


# Assemble the simulation requests as a (collection_id, list of requests) pair:
# historical requests first, followed by the future-scenario requests.
if collection_id == "CORDEX":
    models = models_cordex
    model_key = "rcm_model"
    request_sim = (
        "projections-cordex-domains-single-levels",
        [
            {
                **request_cordex,
                "experiment": experiment,
                "start_year": start_year,
                "end_year": end_year,
            }
            for experiment, year_slice in (
                ("historical", historical_slice),
                ("rcp_8_5", future_slice),
            )
            for start_year, end_year in zip(
                *get_cordex_years(year_slice, timeseries)
            )
        ],
    )
elif collection_id == "CMIP6":
    models = models_cmip6
    model_key = "model"

    requests_historical = download.split_request(
        {
            **request_cmip6,
            "year": get_cmip6_years(historical_slice),
            "experiment": "historical",
        },
        chunks=chunks,
    )
    requests_future = download.split_request(
        {
            **request_cmip6,
            "year": get_cmip6_years(future_slice),
            "experiment": "ssp5_8_5",
        },
        chunks=chunks,
    )
    request_sim = ("projections-cmip6", requests_historical + requests_future)
else:
    raise ValueError(f"{collection_id=}")


# The first request of the regridding model is enough to provide the target grid
request_grid_out = (
    request_sim[0],
    {**request_sim[1][0], model_key: model_regrid},
)

## Functions to cache

In [None]:
def select_timeseries(ds, timeseries, year_slice):
    """Select the time steps of ``ds`` that belong to the requested period.

    For ``"annual"`` the full years ``year_slice.start`` to ``year_slice.stop``
    are selected. Otherwise the data is first restricted to December of the
    year before ``year_slice.start`` through November of ``year_slice.stop``,
    and only the time steps whose season equals ``timeseries`` are kept.
    """
    first_year, last_year = year_slice.start, year_slice.stop
    if timeseries != "annual":
        window = slice(f"{first_year-1}-12", f"{last_year}-11")
        ds = ds.sel(time=window)
        in_season = ds["time"].dt.season == timeseries
        return ds.where(in_season, drop=True)
    return ds.sel(time=slice(str(first_year), str(last_year)))


def compute_indices(
    ds,
    index_names,
    timeseries,
    tmpdir,
    historical_slice,
    future_slice,
):
    """Compute the requested icclim indices over the future period.

    The input is written to ``tmpdir`` as one NetCDF file per year, re-opened,
    and stored as a Zarr store with a single chunk along ``time``; icclim then
    reads that store.

    Parameters
    ----------
    ds : xarray.Dataset
        Input data with a ``time`` dimension.
    index_names : iterable of str
        icclim index names to compute.
    timeseries : str
        ``"annual"`` (mapped to icclim's ``"year"`` slice mode) or a season
        name, which is passed to icclim unchanged.
    tmpdir : str
        Existing directory used for all intermediate and output files.
    historical_slice, future_slice : slice
        First/last year of the base period (used for ``TX90p`` only) and of
        the period over which the indices are computed.

    Returns
    -------
    xarray.Dataset
        Merged icclim outputs with the ``bounds`` dimension dropped. The data
        may still reference files in ``tmpdir``, so load it before the
        directory is removed.
    """
    # One file per calendar year, each written as a single chunk
    labels, yearly_datasets = zip(*ds.groupby("time.year"))
    paths = [f"{tmpdir}/{label}.nc" for label in labels]
    xr.save_mfdataset([yearly.chunk(-1) for yearly in yearly_datasets], paths)

    # Rechunk into a Zarr store with the whole time axis in one chunk.
    # The context manager closes the yearly files once the store is written,
    # so no file handles are left open when the temporary directory is removed.
    in_files = f"{tmpdir}/rechunked.zarr"
    with xr.open_mfdataset(paths) as ds_yearly:
        zarr_chunks = {
            dim: -1 if dim == "time" else "auto" for dim in ds_yearly.dims
        }
        ds_yearly.chunk(zarr_chunks).to_zarr(in_files)

    time_range = (f"{future_slice.start}-01-01", f"{future_slice.stop}-12-31")
    base_range = (f"{historical_slice.start}-01-01", f"{historical_slice.stop}-12-31")

    datasets = [
        icclim.index(
            index_name=index_name,
            in_files=in_files,
            out_file=f"{tmpdir}/{index_name}.nc",
            slice_mode="year" if timeseries == "annual" else timeseries,
            time_range=time_range,
            # Only the percentile-based index is given a base period
            base_period_time_range=base_range if index_name == "TX90p" else None,
        )
        for index_name in index_names
    ]

    return xr.merge(datasets).drop_dims("bounds")


def compute_trends(ds):
    """Run a Mann-Kendall trend test on every data variable of ``ds``.

    Each variable has its time mean removed before being tested (``alpha=0.05``,
    ``method="theilslopes"``). The per-variable results are given the original
    dimension names and coordinates, and are concatenated along a new ``index``
    dimension labelled with the variable names.
    """
    # Exactly one Y and one X dimension are expected; the unpacking fails otherwise
    (lat,) = set(ds.dims) & set(ds.cf.axes["Y"])
    (lon,) = set(ds.dims) & set(ds.cf.axes["X"])
    coords_name = {"time": "time", "y": lat, "x": lon}

    trends = []
    for index, da in ds.data_vars.items():
        anomaly = da - da.mean("time")
        result = Mann_Kendall_test(
            anomaly,
            alpha=0.05,
            method="theilslopes",
            coords_name=coords_name,
        ).compute()
        # Map the generic dimension names of the test output back to the data's names
        renames = {
            generic: actual
            for generic, actual in coords_name.items()
            if generic in result.dims
        }
        result = result.rename(renames)
        result = result.assign_coords({dim: da[dim] for dim in result.dims})
        trends.append(result.expand_dims(index=[index]))
    return xr.concat(trends, "index")


def add_bounds(ds):
    """Return ``ds`` with bounds added for latitude/longitude where missing.

    Coordinates already listed in ``ds.cf.bounds`` are left untouched.
    """
    missing = {"latitude", "longitude"}.difference(ds.cf.bounds)
    for coord in missing:
        ds = ds.cf.add_bounds(coord)
    return ds


def get_grid_out(request_grid_out, method):
    """Download the regridding model and return its bare target grid.

    The result holds latitude and longitude only, plus their bounds when
    ``method`` is ``"conservative"``. Every other non-dimension coordinate and
    all dataset attributes are removed.
    """
    ds_regrid = download.download_and_transform(*request_grid_out)

    names = ["latitude", "longitude"]
    if method == "conservative":
        ds_regrid = add_bounds(ds_regrid)
        bounds_names = [
            bounds_name
            for coord in ("latitude", "longitude")
            for bounds_name in ds_regrid.cf.bounds[coord]
        ]
        names += bounds_names

    grid_out = ds_regrid[names]
    # Coordinates that are neither requested nor dimension coordinates
    extra_coords = set(grid_out.coords).difference(names, grid_out.dims)
    grid_out = grid_out.reset_coords(extra_coords, drop=True)
    grid_out.attrs = {}
    return grid_out


def compute_indices_and_trends_future(
    ds,
    index_names,
    timeseries,
    historical_slice,
    future_slice,
    resample,
    request_grid_out=None,
    **regrid_kwargs,
):
    """Compute time-mean indices and their trends, optionally regridded.

    Steps: keep the 3D variables, select the historical and future periods,
    optionally resample to daily maxima, compute the indices and their
    Mann-Kendall trends, average the indices over time, merge the trends in,
    and regrid when a target-grid request is given.

    ``request_grid_out`` and ``regrid_kwargs`` must be given together or not
    at all.
    """
    # Regridding needs both the target-grid request and the regrid options
    assert bool(request_grid_out) == bool(regrid_kwargs)

    non_3d = [name for name, da in ds.data_vars.items() if len(da.dims) != 3]
    ds = ds.drop_vars(non_3d)
    ds = ds[list(ds.data_vars)]

    # Original bounds for conservative interpolation
    bounds = []
    if regrid_kwargs.get("method") == "conservative":
        ds = add_bounds(ds)
        for coord in ("latitude", "longitude"):
            bounds.append(ds.cf.get_bounds(coord).reset_coords(drop=True))

    periods = [
        select_timeseries(ds, timeseries, year_slice)
        for year_slice in (historical_slice, future_slice)
    ]
    ds = xr.concat(periods, "time")
    if resample:
        ds = ds.resample(time="1D").max(keep_attrs=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Load the indices into memory while the temporary files still exist
        ds_indices = compute_indices(
            ds, index_names, timeseries, tmpdir, historical_slice, future_slice
        ).compute()
        ds_trends = compute_trends(ds_indices)
        ds = ds_indices.mean("time", keep_attrs=True).merge(ds_trends)
        if not request_grid_out:
            return ds
        return diagnostics.regrid(
            ds.merge({da.name: da for da in bounds}),
            grid_out=get_grid_out(request_grid_out, regrid_kwargs["method"]),
            **regrid_kwargs,
        )

## Download and transform regrid model

In [None]:
# Keyword arguments shared by every download_and_transform call for the models
kwargs = {
    "collection_id": request_sim[0],
    # The per-year chunking is only passed for CMIP6
    "chunks": chunks if collection_id == "CMIP6" else None,
    "transform_chunks": False,
    "transform_func": compute_indices_and_trends_future,
}
# Keyword arguments forwarded to compute_indices_and_trends_future
transform_func_kwargs = {
    "index_names": sorted(index_names),
    "timeseries": timeseries,
    "historical_slice": historical_slice,
    "future_slice": future_slice,
    "resample": False,
}
# Process the regridding model first, on its own grid (no regrid arguments).
# NOTE(review): ds_regrid is not read again in the cells shown here; presumably
# this call is made so the result is available before the model loop - confirm.
ds_regrid = download.download_and_transform(
    requests=[request | {model_key: model_regrid} for request in request_sim[1]],
    **kwargs,
    transform_func_kwargs=transform_func_kwargs,
)

## Download and transform models

In [None]:
# Extra arguments that switch on interpolation onto the regridding model's grid
interpolation_kwargs = {
    "request_grid_out": request_grid_out,
    "method": interpolation_method,
    "skipna": True,
}

interpolated_datasets = []
model_datasets = {}
for model in models:
    print(f"{model=}")
    model_requests = [request | {model_key: model} for request in request_sim[1]]

    # Original model
    ds_native = download.download_and_transform(
        requests=model_requests,
        transform_func_kwargs=transform_func_kwargs,
        **kwargs,
    )
    model_datasets[model] = ds_native

    if model == model_regrid:
        # The regridding model is already on the target grid
        ds = ds_native
    else:
        # Interpolated model
        ds = download.download_and_transform(
            requests=model_requests,
            transform_func_kwargs=transform_func_kwargs | interpolation_kwargs,
            **kwargs,
        )
    interpolated_datasets.append(ds.expand_dims(model=[model]))

ds_interpolated = xr.concat(interpolated_datasets, "model")

## Apply land-sea mask and edit attributes

In [None]:
# Mask: regrid the ERA5 land-sea mask onto each dataset's grid and use it as
# the condition of `where` (values are kept where the condition is truthy)
lsm = download.download_and_transform(*request_lsm)["lsm"].squeeze(drop=True)

lsm_interpolated = diagnostics.regrid(lsm, ds_interpolated, method="bilinear")
ds_interpolated = ds_interpolated.where(lsm_interpolated)

masked_datasets = {}
for model, ds in model_datasets.items():
    lsm_model = diagnostics.regrid(lsm, ds, method="bilinear")
    masked_datasets[model] = ds.where(lsm_model)
model_datasets = masked_datasets

# Edit attributes
# NOTE: the trend is rescaled in place, so re-running this cell rescales it again
for ds in [ds_interpolated, *model_datasets.values()]:
    ds["trend"] *= 10
    ds["trend"].attrs = {"long_name": "trend", "units": "days / decade"}
    for index in index_names:
        ds[index].attrs = {"long_name": "", "units": "days"}

## Plotting functions

In [None]:
def hatch_p_value(da, ax, **kwargs):
    """Overlay hatching on ``ax`` according to the values of ``da``.

    With the default settings, values between 0.05 and 1 are hatched with
    ``"///"`` and values between 0 and 0.05 are left clear. No colours and no
    colorbar are drawn. Any keyword argument overrides the corresponding
    default. The title of ``ax`` is restored after plotting.

    Returns the object returned by ``plot.projected_map``.
    """
    plot_kwargs = {
        "plot_func": "contourf",
        "show_stats": False,
        "cmap": "none",
        "add_colorbar": False,
        "levels": [0, 0.05, 1],
        "hatches": ["", "/" * 3],
        **kwargs,
    }

    # Keep the title that was set before this overlay is drawn
    original_title = ax.get_title()
    plot_obj = plot.projected_map(da, ax=ax, **plot_kwargs)
    ax.set_title(original_title)
    return plot_obj


def hatch_p_value_ensemble(trend, p_value, ax):
    """Hatch the areas of ``ax`` where the ensemble trend is not robust.

    Two overlays are drawn:

    * ``"///"`` where the fraction of models with ``p_value <= 0.05`` is
      below 0.66;
    * ``"\\\\\\"`` where that fraction exceeds 0.66 but fewer than 80% of the
      models share the same trend sign.

    Cells where every model has a missing p-value are not hatched.
    """
    robust_threshold = 0.66
    sign_threshold = 0.8
    n_models = trend.sizes["model"]

    # Fraction of models with a significant trend
    has_data = p_value.notnull().any("model")
    robust_ratio = ((p_value <= 0.05).sum("model") / n_models).where(has_data)

    # Largest fraction of models agreeing on the sign, only where robust
    n_positive = (trend > 0).sum("model")
    n_negative = (trend < 0).sum("model")
    signs = xr.concat([n_positive, n_negative], "sign")
    sign_ratio = (signs.max("sign") / n_models).where(robust_ratio > robust_threshold)

    overlays = [
        (robust_ratio, robust_threshold, "/"),
        (sign_ratio, sign_threshold, "\\"),
    ]
    for da, threshold, character in overlays:
        hatch_p_value(da, ax=ax, levels=[0, threshold, 1], hatches=[character * 3, ""])


def set_extent(da, axs, area):
    """Set the same map extent on every axis, one unit inside ``area``.

    ``area`` is read as ``[north, west, south, east]`` and passed to each axis
    as ``[west + 1, east - 1, south + 1, north - 1]``. ``da`` is accepted for
    interface compatibility and is not used.
    """
    west, east, south, north = (area[i] for i in (1, 3, 2, 0))
    extent = [west + 1, east - 1, south + 1, north - 1]
    for ax in axs:
        ax.set_extent(extent)


def plot_models(
    data,
    da_for_kwargs=None,
    p_values=None,
    col_wrap=3,
    subplot_kw=None,
    figsize=None,
    layout="constrained",
    area=area,
    **kwargs,
):
    """Plot one map per model on a shared colour scale.

    Parameters
    ----------
    data : dict or xarray.DataArray
        Either a mapping ``{model: DataArray}`` or a DataArray with a ``model``
        dimension, which is split into one panel per model.
    da_for_kwargs : xarray.DataArray, optional
        Array from which the shared colour-scale parameters and the colorbar
        label are derived. Required when ``data`` is a dict; defaults to
        ``data`` otherwise.
    p_values : dict or xarray.DataArray, optional
        P-values per model, in the same form as ``data``. When given, each
        panel is hatched with ``hatch_p_value``.
    col_wrap : int
        Number of columns of the panel grid; rows are added as needed.
    subplot_kw : dict, optional
        Passed to ``plt.subplots``. Default: a PlateCarree projection.
    figsize, layout
        Passed to ``plt.subplots``.
    area : list
        Region passed to ``set_extent``. The default is the module-level
        ``area`` at the time this function is defined.
    **kwargs
        Overrides for the colour-scale defaults (``robust=True``,
        ``extend="both"``).

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If there is no model to plot.
    """
    if subplot_kw is None:
        # Built per call instead of a shared mutable default
        subplot_kw = {"projection": ccrs.PlateCarree()}

    if isinstance(data, dict):
        assert da_for_kwargs is not None, "da_for_kwargs is required for dict data"
        model_dataarrays = data
    else:
        # Explicit None check: the truth value of a multi-element DataArray
        # is ambiguous, so `da_for_kwargs or data` raises
        if da_for_kwargs is None:
            da_for_kwargs = data
        model_dataarrays = dict(data.groupby("model"))
    if not model_dataarrays:
        raise ValueError("no models to plot")

    if p_values is None:
        model_p_dataarrays = None
    elif isinstance(p_values, dict):
        model_p_dataarrays = p_values
    else:
        model_p_dataarrays = dict(p_values.groupby("model"))

    # Get kwargs
    default_kwargs = {"robust": True, "extend": "both"}
    kwargs = default_kwargs | kwargs
    kwargs = xr.plot.utils._determine_cmap_params(da_for_kwargs.values, **kwargs)

    # col_wrap is the number of columns; squeeze=False keeps axs a 2D array
    # even for a single panel, so flatten() always works
    n_rows = math.ceil(len(model_dataarrays) / col_wrap)
    fig, axs = plt.subplots(
        n_rows,
        col_wrap,
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
        squeeze=False,
    )
    axs = axs.flatten()
    for (model, da), ax in zip(model_dataarrays.items(), axs):
        pcm = plot.projected_map(
            da, ax=ax, show_stats=False, add_colorbar=False, **kwargs
        )
        ax.set_title(model)
        if model_p_dataarrays is not None:
            hatch_p_value(model_p_dataarrays[model], ax)
    set_extent(da_for_kwargs, axs, area)
    fig.colorbar(
        pcm,
        ax=axs,
        extend=kwargs["extend"],
        location="right",
        label=f"{da_for_kwargs.attrs.get('long_name', '')} [{da_for_kwargs.attrs.get('units', '')}]",
    )
    return fig


def plot_ensemble(
    da_models,
    da_era5=None,
    p_value_era5=None,
    p_value_models=None,
    subplot_kw=None,
    figsize=None,
    layout="constrained",
    cbar_kwargs=None,
    area=area,
    **kwargs,
):
    """Plot ensemble statistics of ``da_models`` along its ``model`` dimension.

    Without ``da_era5`` two panels are drawn: ensemble median and ensemble
    standard deviation. With ``da_era5`` four panels are drawn: ERA5, ensemble
    median, median minus ERA5 (bias), and ensemble standard deviation.

    Parameters
    ----------
    da_models : xarray.DataArray
        Data with a ``model`` dimension.
    da_era5 : xarray.DataArray, optional
        Reference field. When given, it also determines the colour scale of
        the ERA5 and median panels.
    p_value_era5 : xarray.DataArray, optional
        P-values hatched on the ERA5 panel with ``hatch_p_value``.
    p_value_models : xarray.DataArray, optional
        Per-model p-values hatched on the median panel with
        ``hatch_p_value_ensemble``.
    subplot_kw : dict, optional
        Passed to ``plt.subplots``. Default: a PlateCarree projection.
    figsize, layout
        Passed to ``plt.subplots``.
    cbar_kwargs : dict, optional
        Colorbar options for every panel. Defaults to a horizontal colorbar
        when no ERA5 field is given.
    area : list
        Region passed to ``set_extent``. The default is the module-level
        ``area`` at the time this function is defined.
    **kwargs
        Overrides for the colour-scale defaults (``robust=True``,
        ``extend="both"``) of the ERA5 and median panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if subplot_kw is None:
        # Built per call instead of a shared mutable default
        subplot_kw = {"projection": ccrs.PlateCarree()}

    # Get kwargs
    default_kwargs = {"robust": True, "extend": "both"}
    kwargs = default_kwargs | kwargs
    kwargs = xr.plot.utils._determine_cmap_params(
        da_models.values if da_era5 is None else da_era5.values, **kwargs
    )
    if da_era5 is None and cbar_kwargs is None:
        cbar_kwargs = {"orientation": "horizontal"}

    # Figure: one row without ERA5, two rows with it
    fig, axs = plt.subplots(
        *(1 if da_era5 is None else 2, 2),
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
    )
    axs = axs.flatten()
    axs_iter = iter(axs)

    # ERA5
    if da_era5 is not None:
        ax = next(axs_iter)
        plot.projected_map(
            da_era5, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
        )
        if p_value_era5 is not None:
            hatch_p_value(p_value_era5, ax=ax)
        ax.set_title("ERA5")

    # Median
    ax = next(axs_iter)
    median = da_models.median("model", keep_attrs=True)
    plot.projected_map(
        median, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
    )
    if p_value_models is not None:
        hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=ax)
    ax.set_title("Ensemble Median")

    # Bias (own colour scale, centred on zero)
    if da_era5 is not None:
        ax = next(axs_iter)
        with xr.set_options(keep_attrs=True):
            bias = median - da_era5
        plot.projected_map(
            bias,
            ax=ax,
            show_stats=False,
            center=0,
            cbar_kwargs=cbar_kwargs,
            **default_kwargs,
        )
        ax.set_title("Ensemble Median Bias")

    # Std (own colour scale)
    ax = next(axs_iter)
    std = da_models.std("model", keep_attrs=True)
    plot.projected_map(
        std, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **default_kwargs
    )
    ax.set_title("Ensemble Standard Deviation")

    set_extent(da_models, axs, area)
    return fig


# Second title line shared by all figures: the future period, the collection
# and the timeseries selection, rendered as "name=value" pairs
common_title = (
    f"{future_slice.start=} {future_slice.stop=} {collection_id=} {timeseries=}"
)

## Plot ensembles

In [None]:
# For every index: one figure for the index itself, one for its trend
for index in index_names:
    # Index
    da_index = ds_interpolated[index]
    fig = plot_ensemble(da_models=da_index)
    fig.suptitle(f"{index}\n{common_title}", y=0.8)
    plt.show()

    # Trend, hatched where the ensemble trend is not robust
    da_trend = ds_interpolated["trend"].sel(index=index)
    da_p_value = ds_interpolated["p"].sel(index=index)
    fig = plot_ensemble(da_models=da_trend, p_value_models=da_p_value, center=0)
    fig.suptitle(f"Trend of {index}\n{common_title}", y=0.8)
    plt.show()

## Plot models

In [None]:
# For every index: per-model maps on their native grids. The colour scale is
# taken from the interpolated ensemble so that all panels are comparable.
for index in index_names:
    # Index
    index_by_model = {}
    trend_by_model = {}
    p_value_by_model = {}
    for model, ds in model_datasets.items():
        index_by_model[model] = ds[index]
        trend_by_model[model] = ds["trend"].sel(index=index)
        p_value_by_model[model] = ds["p"].sel(index=index)

    fig = plot_models(data=index_by_model, da_for_kwargs=ds_interpolated[index])
    fig.suptitle(f"{index}\n{common_title}")
    plt.show()

    # Trend
    fig = plot_models(
        data=trend_by_model,
        da_for_kwargs=ds_interpolated["trend"].sel(index=index),
        p_values=p_value_by_model,
        center=0,
    )
    fig.suptitle(f"Trend of {index}\n{common_title}")
    plt.show()

## Boxplot

In [None]:
# Spatial mean per model on its native grid; weighting is enabled for CMIP6 only
weights = collection_id == "CMIP6"
mean_datasets = []
for model, ds in model_datasets.items():
    ds_model = ds.expand_dims(model=[model])
    mean_datasets.append(diagnostics.spatial_weighted_mean(ds_model, weights=weights))
mean_ds = xr.concat(mean_datasets, "model")

# One boxplot of the spatially averaged trend per index
for index, da in mean_ds["trend"].groupby("index"):
    df_slope = da.to_dataframe()[["trend"]]
    n_points = len(df_slope)
    ensemble_mean = da.mean("model")
    labels = [f"{collection_id} Ensemble"]

    ax = df_slope.boxplot()
    # Individual models
    ax.scatter(
        x=[1] * n_points,
        y=df_slope,
        color="grey",
        marker=".",
        label="models",
    )
    # Ensemble mean
    ax.scatter(
        x=1,
        y=ensemble_mean,
        marker="o",
        label=f"{collection_id} Ensemble Mean",
    )

    ax.set_xticks(range(1, len(labels) + 1), labels)
    ax.set_ylabel(da.attrs["units"])
    plt.suptitle(f"Trend of {index}")
    plt.legend()
    plt.show()