# Energy-consumption-related indices from CMIP6 Global Climate Models

## Import packages

In [None]:
import math
import tempfile

import cartopy.crs as ccrs
import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils
from xarrayMannKendall import Mann_Kendall_test

# Global matplotlib appearance used by every figure below.
plt.style.use("seaborn-v0_8-notebook")
# Line width of hatch patterns (hatching is drawn by the p-value overlays).
plt.rcParams["hatch.linewidth"] = 0.5

## Define parameters

In [None]:
# Time period (first and last year of the analysis)
year_start = 1971
year_stop = 2000

# Choose annual or seasonal timeseries:
# maps each index name to "annual" or to a season label (e.g. "DJF", "JJA")
index_timeseries = {
    "HDD15.5": "DJF",
    "CDD22": "JJA",
}
# "annual" cannot be mixed with seasons: either all indices are annual or none
if "annual" in index_timeseries.values():
    assert set(index_timeseries.values()) == {"annual"}

# Interpolation method (forwarded to the regridding of the model output)
interpolation_method = "bilinear"

# Area to show.
# Below, elements 0 and 2 are used as the latitude slice and elements 1 and 3
# as the longitude slice, i.e. the order is [north, west, south, east].
area = [72, -22, 27, 45]

# Chunks for download (forwarded to the download helpers)
chunks = {"year": 1}

# Define models (CMIP6 model identifiers, one request is built per model)
models_cmip6 = (
    "access_cm2",
    "awi_cm_1_1_mr",
    "cmcc_esm2",
    "cnrm_cm6_1_hr",
    "ec_earth3_cc",
    "gfdl_esm4",
    "inm_cm5_0",
    "miroc6",
    "mpi_esm1_2_lr",
)

# Colormaps: one per index for the index maps; trend and bias maps share the
# same (diverging) colormap dictionary object
cmaps = {"HDD15.5": "Blues", "CDD22": "Reds"}
cmaps_trend = cmaps_bias = {"HDD15.5": "RdBu", "CDD22": "RdBu_r"}

## Define ERA5 request

In [None]:
# ERA5 request as a (collection id, request dictionary) pair:
# 2m temperature at all 24 hours of every day over the selected area.
request_era = (
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "time": [f"{hour:02d}:00" for hour in range(24)],
        "variable": "2m_temperature",
        "year": [
            str(year) for year in range(year_start - 1, year_stop + 1)
        ],  # Include D(year-1)
        "month": [f"{month:02d}" for month in range(1, 13)],
        "day": [f"{day:02d}" for day in range(1, 32)],
        "area": area,
    },
)

# Land-sea mask request: same collection and base request as ERA5, but
# overridden to a single timestamp and to the "land_sea_mask" variable.
request_lsm = (
    request_era[0],
    request_era[1]
    | {
        "year": "1940",
        "month": "01",
        "day": "01",
        "time": "00:00",
        "variable": "land_sea_mask",
    },
)

## Define CMIP6 request

In [None]:
# Base CMIP6 request shared by all models (the model name is added below)
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "daily",
    "experiment": "historical",
    "variable": "near_surface_air_temperature",
    # Include D(year-1)
    "year": [str(yr) for yr in range(year_start - 1, year_stop + 1)],
    "month": [f"{mm:02d}" for mm in range(1, 13)],
    "day": [f"{dd:02d}" for dd in range(1, 32)],
    "area": area,
}

# One (collection id, split requests) pair per model
model_requests = {
    model: (
        "projections-cmip6",
        download.split_request(request_cmip6 | {"model": model}, chunks=chunks),
    )
    for model in models_cmip6
}

## Functions to cache

In [None]:
def select_timeseries(ds, index_timeseries, year_start, year_stop):
    """Select the time window needed to compute the requested indices.

    Parameters
    ----------
    ds : object with a ``sel`` method accepting a ``time`` slice
        Input data.
    index_timeseries : dict
        Mapping from index name to ``"annual"`` or a season label.
    year_start, year_stop : int
        First and last year of the analysis.

    Returns
    -------
    The result of ``ds.sel(time=...)``. When every index is annual, the
    window runs from ``year_start`` to ``year_stop``. Otherwise it runs from
    December of ``year_start - 1`` to November of ``year_stop``.

    Raises
    ------
    ValueError
        If "annual" is mixed with seasonal timeseries.
    """
    timeseries = set(index_timeseries.values())
    if timeseries == {"annual"}:
        return ds.sel(time=slice(str(year_start), str(year_stop)))
    # Explicit check instead of ``assert``: assertions are removed with -O,
    # which would silently apply the seasonal window to annual indices.
    if "annual" in timeseries:
        raise ValueError(
            "'annual' cannot be combined with seasonal timeseries: "
            f"{sorted(timeseries)}"
        )
    return ds.sel(time=slice(f"{year_start-1}-12", f"{year_stop}-11"))


def compute_indices(ds, index_timeseries, tmpdir):
    """Compute the requested degree-day indices with icclim.

    The input is first written to ``tmpdir`` as one NetCDF file per year,
    then re-opened and stored as a zarr store with a single chunk along
    "time"; icclim reads that store.

    Parameters
    ----------
    ds : xr.Dataset
        Input data with a "time" coordinate.
    index_timeseries : dict
        Mapping from index name ("HDD15.5" or "CDD22") to "annual" or a
        season label ("DJF", "MAM", "JJA", "SON").
    tmpdir : str
        Directory used for all intermediate and output files.

    Returns
    -------
    xr.Dataset
        One variable per index, divided by the nominal number of days of the
        period (season length, or their sum for "annual").

    Raises
    ------
    NotImplementedError
        If an index name other than "HDD15.5" or "CDD22" is requested.
    """
    # One single-chunk NetCDF file per calendar year
    years, yearly_datasets = zip(*ds.groupby("time.year"))
    yearly_paths = [f"{tmpdir}/{year}.nc" for year in years]
    xr.save_mfdataset([yearly.chunk(-1) for yearly in yearly_datasets], yearly_paths)

    # Re-open everything and store it with "time" in a single chunk
    ds = xr.open_mfdataset(yearly_paths)
    in_files = f"{tmpdir}/rechunked.zarr"
    rechunking = {dim: -1 if dim == "time" else "auto" for dim in ds.dims}
    ds.chunk(rechunking).to_zarr(in_files)

    # Nominal number of days per season (fixed values, no leap days)
    num_days = {"DJF": 90, "MAM": 92, "JJA": 92, "SON": 91}

    dataarrays = []
    for index_name, timeseries in index_timeseries.items():
        is_annual = timeseries == "annual"
        kwargs = {
            "in_files": in_files,
            "out_file": f"{tmpdir}/{index_name}.nc",
            "slice_mode": "year" if is_annual else timeseries,
        }
        if index_name == "HDD15.5":
            ds_index = icclim.index(
                **kwargs,
                index_name="deficit",
                threshold=icclim.build_threshold("15.5 degC"),
            )
        elif index_name == "CDD22":
            ds_index = icclim.excess(
                **kwargs,
                threshold=icclim.build_threshold("22 degC"),
            )
        else:
            raise NotImplementedError(f"{index_name=}")

        # Exactly one data variable is expected once "bounds" is dropped
        (da,) = ds_index.drop_dims("bounds").data_vars.values()

        # Normalise by the length of the period and drop " d" from the units
        divisor = sum(num_days.values()) if is_annual else num_days[timeseries]
        with xr.set_options(keep_attrs=True):
            da /= divisor
        da.attrs["units"] = da.attrs["units"].replace(" d", "")
        dataarrays.append(da.rename(index_name))
    return xr.merge(dataarrays)


def compute_trends(ds):
    """Run a Mann-Kendall test (Theil slopes, alpha=0.05) on every variable.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a "time" dimension and exactly one Y and one X
        dimension identifiable through the ``cf`` accessor.

    Returns
    -------
    xr.Dataset
        Test outputs of all variables concatenated along a new "index"
        dimension labelled with the variable names.
    """
    (lat,) = set(ds.dims) & set(ds.cf.axes["Y"])
    (lon,) = set(ds.dims) & set(ds.cf.axes["X"])
    coords_name = {"time": "time", "y": lat, "x": lon}

    def _trend(index, da):
        # The test is applied to anomalies with respect to the time mean
        result = Mann_Kendall_test(
            da - da.mean("time"),
            alpha=0.05,
            method="theilslopes",
            coords_name=coords_name,
        ).compute()
        # Restore the original dimension names and coordinate values
        renaming = {k: v for k, v in coords_name.items() if k in result.dims}
        result = result.rename(renaming)
        result = result.assign_coords({dim: da[dim] for dim in result.dims})
        return result.expand_dims(index=[index])

    per_index = [_trend(index, da) for index, da in ds.data_vars.items()]
    return xr.concat(per_index, "index")


def add_bounds(ds):
    """Return ``ds`` with bounds added for latitude and longitude if missing.

    Coordinates already listed in ``ds.cf.bounds`` are left untouched.
    """
    missing = {"latitude", "longitude"}.difference(ds.cf.bounds)
    for name in missing:
        ds = ds.cf.add_bounds(name)
    return ds


def get_grid_out(request_grid_out, method):
    """Build the target grid used for regridding.

    Parameters
    ----------
    request_grid_out : tuple
        Arguments forwarded to ``download.download_and_transform``.
    method : str
        Interpolation method. For "conservative", the latitude/longitude
        bounds are included in the grid as well.

    Returns
    -------
    xr.Dataset
        Latitude and longitude (plus bounds when requested) without any
        other non-dimension coordinate and without attributes.
    """
    ds_regrid = download.download_and_transform(*request_grid_out)
    coords = ["latitude", "longitude"]
    if method == "conservative":
        ds_regrid = add_bounds(ds_regrid)
        # Bounds variables of latitude first, then those of longitude
        coords += [
            bounds_name
            for coord in ("latitude", "longitude")
            for bounds_name in ds_regrid.cf.bounds[coord]
        ]

    grid_out = ds_regrid[coords]
    to_keep = set(coords) | set(grid_out.dims)
    coords_to_drop = set(grid_out.coords) - to_keep
    grid_out = grid_out.reset_coords(coords_to_drop, drop=True)
    grid_out.attrs = {}
    return grid_out


def compute_indices_and_trends(
    ds,
    index_timeseries,
    year_start,
    year_stop,
    resample_reduction=None,
    request_grid_out=None,
    **regrid_kwargs,
):
    """Compute time-mean indices and their trends, optionally regridded.

    Parameters
    ----------
    ds : xr.Dataset
        Input data. Only variables with exactly three dimensions are kept.
    index_timeseries : dict
        Mapping from index name to "annual" or a season label.
    year_start, year_stop : int
        First and last year of the analysis.
    resample_reduction : str, optional
        Name of the reduction (e.g. "mean", "sum") applied after resampling
        the data to daily frequency. No resampling if not provided.
    request_grid_out : tuple, optional
        Request used to download the target grid for regridding.
    **regrid_kwargs
        Keyword arguments forwarded to ``diagnostics.regrid``. They must
        include "method" and must be provided if and only if
        ``request_grid_out`` is provided.

    Returns
    -------
    xr.Dataset
        Time-mean of each index merged with the trend test outputs.
    """
    # Regridding target and regridding options go together
    assert (request_grid_out and regrid_kwargs) or not (
        request_grid_out or regrid_kwargs
    )
    ds = ds.drop_vars([var for var, da in ds.data_vars.items() if len(da.dims) != 3])
    ds = ds[list(ds.data_vars)]

    # Original bounds for conservative interpolation
    # (extracted here, before the time selection and index computation)
    if regrid_kwargs.get("method") == "conservative":
        ds = add_bounds(ds)
        bounds = [
            ds.cf.get_bounds(coord).reset_coords(drop=True)
            for coord in ("latitude", "longitude")
        ]
    else:
        bounds = []

    ds = select_timeseries(ds, index_timeseries, year_start, year_stop)
    if resample_reduction:
        # Reduce sub-daily data to daily values
        resampled = ds.resample(time="1D")
        ds = getattr(resampled, resample_reduction)(keep_attrs=True)
        if resample_reduction == "sum":
            # A daily sum changes the units to "<units> / day"
            for da in ds.data_vars.values():
                da.attrs["units"] = f"{da.attrs['units']} / day"
    with tempfile.TemporaryDirectory() as tmpdir:
        # Load the indices in memory: the files in tmpdir are deleted on exit
        ds_indices = compute_indices(ds, index_timeseries, tmpdir).compute()
        ds_trends = compute_trends(ds_indices)
        ds = ds_indices.mean("time", keep_attrs=True)
        ds = ds.merge(ds_trends)
        if request_grid_out:
            # Re-attach the original bounds (empty unless "conservative")
            ds = diagnostics.regrid(
                ds.merge({da.name: da for da in bounds}),
                grid_out=get_grid_out(request_grid_out, regrid_kwargs["method"]),
                **regrid_kwargs,
            )
        return ds

## Download and transform ERA5

In [None]:
# Arguments shared by the ERA5 and CMIP6 transforms.
# The indices are sorted by name, so the dictionary content does not depend
# on the order in which they were defined above.
transform_func_kwargs = {
    "index_timeseries": dict(sorted(index_timeseries.items())),
    "year_start": year_start,
    "year_stop": year_stop,
}
# ERA5 is hourly (24 times per day are requested): reduce it to daily means
# before computing the indices.
ds_era5 = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_chunks=False,
    transform_func=compute_indices_and_trends,
    transform_func_kwargs=transform_func_kwargs | {"resample_reduction": "mean"},
)

## Download and transform models

In [None]:
# Each model is processed twice: on its original grid (model_datasets) and
# regridded to the grid of the land-sea mask request (interpolated_datasets).
interpolated_datasets = []
model_datasets = {}
for model, requests in model_requests.items():
    print(f"{model=}")
    model_kwargs = {
        "chunks": chunks,
        "transform_chunks": False,
        "transform_func": compute_indices_and_trends,
    }
    # Original model
    model_datasets[model] = download.download_and_transform(
        *requests,
        **model_kwargs,
        transform_func_kwargs=transform_func_kwargs,
    )

    # Interpolated model
    # ("method" and "skipna" end up in the regrid keyword arguments)
    ds = download.download_and_transform(
        *requests,
        **model_kwargs,
        transform_func_kwargs=transform_func_kwargs
        | {
            "request_grid_out": request_lsm,
            "method": interpolation_method,
            "skipna": True,
        },
    )
    interpolated_datasets.append(ds.expand_dims(model=[model]))

# Interpolated models share a grid, so they can be stacked along "model"
ds_interpolated = xr.concat(interpolated_datasets, "model")

## Mask land and change attributes

In [None]:
# Land-sea mask ("lsm" variable) with singleton dimensions removed
lsm = download.download_and_transform(*request_lsm)["lsm"].squeeze(drop=True)

# Cutout
# (area[1], area[3] give the longitude range; area[0], area[2] the latitude range)
regionalise_kwargs = {
    "lon_slice": slice(area[1], area[3]),
    "lat_slice": slice(area[0], area[2]),
}
lsm = utils.regionalise(lsm, **regionalise_kwargs)
ds_interpolated = utils.regionalise(ds_interpolated, **regionalise_kwargs)
model_datasets = {
    model: utils.regionalise(ds, **regionalise_kwargs)
    for model, ds in model_datasets.items()
}

# Mask
# ERA5 and the interpolated models use the mask directly; for the models on
# their original grid the mask is first regridded to each model grid.
ds_era5 = ds_era5.where(lsm)
ds_interpolated = ds_interpolated.where(lsm)
model_datasets = {
    model: ds.where(diagnostics.regrid(lsm, ds, method="bilinear"))
    for model, ds in model_datasets.items()
}

# Edit attributes
for ds in (ds_era5, ds_interpolated, *model_datasets.values()):
    # NOTE(review): presumably converts the slope from per-year to per-decade
    # (the plots below label trends as "/ decade") - confirm
    ds["trend"] *= 10
    ds["trend"].attrs = {"long_name": "trend"}
    for index in index_timeseries:
        # Keep only the units and blank the long name
        ds[index].attrs = {"long_name": "", "units": ds[index].attrs["units"]}

## Plotting functions

In [None]:
def hatch_p_value(da, ax, **kwargs):
    """Overlay a hatched, colourless ``contourf`` of ``da`` on ``ax``.

    With the default levels ``[0, 0.05, 1]`` and hatches ``["", "///"]``,
    the interval between 0.05 and 1 is hatched. Any default can be
    overridden through ``kwargs``. The title of ``ax`` is preserved.

    Returns the object returned by ``plot.projected_map``.
    """
    plot_kwargs = {
        "plot_func": "contourf",
        "show_stats": False,
        "cmap": "none",
        "add_colorbar": False,
        "levels": [0, 0.05, 1],
        "hatches": ["", "/" * 3],
        **kwargs,
    }

    # Save the current title and restore it after drawing the overlay
    previous_title = ax.get_title()
    hatched = plot.projected_map(da, ax=ax, **plot_kwargs)
    ax.set_title(previous_title)
    return hatched


def hatch_p_value_ensemble(trend, p_value, ax):
    """Hatch the ensemble map where the trend is not robust across models.

    Two overlays are drawn on ``ax``:

    - "///" where the fraction of models with ``p_value <= 0.05`` is
      between 0 and 0.66;
    - "\\\\\\" where, among the remaining points, the fraction of models
      agreeing on the sign of ``trend`` is between 0 and 0.8.
    """
    n_models = trend.sizes["model"]
    robust_threshold = 0.66
    sign_threshold = 0.8

    # Fraction of models with a significant trend (NaN where no model has data)
    n_significant = (p_value <= 0.05).sum("model")
    has_data = p_value.notnull().any("model")
    robust_ratio = (n_significant / n_models).where(has_data)

    # Fraction of models sharing the dominant sign, only where robust
    n_positive = (trend > 0).sum("model")
    n_negative = (trend < 0).sum("model")
    sign_counts = xr.concat([n_positive, n_negative], "sign")
    sign_ratio = sign_counts.max("sign") / n_models
    sign_ratio = sign_ratio.where(robust_ratio > robust_threshold)

    hatch_p_value(
        robust_ratio, ax=ax, levels=[0, robust_threshold, 1], hatches=["/" * 3, ""]
    )
    hatch_p_value(
        sign_ratio, ax=ax, levels=[0, sign_threshold, 1], hatches=["\\" * 3, ""]
    )


def set_extent(da, axs, area):
    """Set the same map extent on every axis in ``axs``.

    The extent is ``[area[1], area[3], area[2], area[0]]`` shrunk by one unit
    on each side (+1 on the first and third values, -1 on the second and
    fourth). ``da`` is accepted for signature compatibility and not used.
    """
    offsets = (+1, -1, +1, -1)
    extent = [area[pos] + offset for pos, offset in zip((1, 3, 2, 0), offsets)]
    for ax in axs:
        ax.set_extent(extent)


def plot_models(
    data,
    da_for_kwargs=None,
    p_values=None,
    col_wrap=3,
    subplot_kw={"projection": ccrs.PlateCarree()},
    figsize=None,
    layout="constrained",
    area=area,
    **kwargs,
):
    """Plot one map per model on a grid of panels with a shared colorbar.

    Parameters
    ----------
    data : dict or xr.DataArray
        Either a mapping from model name to DataArray, or a single DataArray
        with a "model" dimension.
    da_for_kwargs : xr.DataArray, optional
        Array used to derive the colour scale and the colorbar label.
        Required when ``data`` is a dict; defaults to ``data`` otherwise.
    p_values : dict or xr.DataArray, optional
        P-values, in the same form as ``data``, drawn as hatching.
    col_wrap : int
        Number of columns of the panel grid.
    subplot_kw, figsize, layout
        Forwarded to ``plt.subplots`` (``subplot_kw`` is only read).
    area : list
        Area passed to ``set_extent``.
    **kwargs
        Overrides for the colour-scale parameters (e.g. ``cmap``, ``center``).

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``data`` is a dict and ``da_for_kwargs`` is missing, or if there
        is no model to plot.
    """
    if isinstance(data, dict):
        if da_for_kwargs is None:
            raise ValueError("da_for_kwargs is required when data is a dict")
        model_dataarrays = data
    else:
        # Explicit None check: `da_for_kwargs or data` would evaluate the
        # truth value of an array, which fails for multi-element arrays.
        if da_for_kwargs is None:
            da_for_kwargs = data
        model_dataarrays = dict(data.groupby("model"))
    if not model_dataarrays:
        raise ValueError("no model data to plot")

    if p_values is not None:
        model_p_dataarrays = (
            p_values if isinstance(p_values, dict) else dict(p_values.groupby("model"))
        )
    else:
        model_p_dataarrays = None

    # Get kwargs
    default_kwargs = {"robust": True, "extend": "both"}
    kwargs = default_kwargs | kwargs
    kwargs = xr.plot.utils._determine_cmap_params(da_for_kwargs.values, **kwargs)

    # col_wrap is the number of columns; the number of rows follows from it.
    # squeeze=False keeps a 2D array of axes even for a single panel.
    n_models = len(model_dataarrays)
    n_rows = math.ceil(n_models / col_wrap)
    fig, axs = plt.subplots(
        n_rows,
        col_wrap,
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
        squeeze=False,
    )
    axs = axs.flatten()
    for (model, da), ax in zip(model_dataarrays.items(), axs):
        pcm = plot.projected_map(
            da, ax=ax, show_stats=False, add_colorbar=False, **kwargs
        )
        ax.set_title(model)
        if model_p_dataarrays is not None:
            hatch_p_value(model_p_dataarrays[model], ax)
    # Hide the panels left empty when the grid has more slots than models
    for ax in axs[n_models:]:
        ax.set_visible(False)
    set_extent(da_for_kwargs, axs, area)
    fig.colorbar(
        pcm,
        ax=axs,
        extend=kwargs["extend"],
        location="right",
        label=f"{da_for_kwargs.attrs.get('long_name', '')} [{da_for_kwargs.attrs.get('units', '')}]",
    )
    return fig


def plot_ensemble(
    da_models,
    da_era5=None,
    p_value_era5=None,
    p_value_models=None,
    subplot_kw={"projection": ccrs.PlateCarree()},
    figsize=None,
    layout="constrained",
    cbar_kwargs=None,
    area=area,
    cmap_bias=None,
    cmap_std=None,
    **kwargs,
):
    """Plot ensemble statistics, optionally compared against ERA5.

    Without ``da_era5`` the figure has two panels (ensemble median and
    ensemble standard deviation). With ``da_era5`` it has four: ERA5,
    ensemble median, median bias with respect to ERA5, and standard deviation.

    Parameters
    ----------
    da_models : xr.DataArray
        Model data with a "model" dimension.
    da_era5 : xr.DataArray, optional
        Reference data; when given it also defines the colour scale.
    p_value_era5, p_value_models : xr.DataArray, optional
        P-values drawn as hatching on the ERA5 and median panels.
    subplot_kw, figsize, layout
        Forwarded to ``plt.subplots``.
    cbar_kwargs : dict, optional
        Colorbar options; defaults to a horizontal colorbar when there is
        no ERA5 panel.
    area : list
        Area passed to ``set_extent``.
    cmap_bias, cmap_std
        Colormaps of the bias and standard deviation panels.
    **kwargs
        Overrides for the colour-scale parameters of the ERA5/median panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    with_era5 = da_era5 is not None

    # Get kwargs: the colour scale comes from ERA5 when available
    default_kwargs = {"robust": True, "extend": "both"}
    scale_values = da_era5.values if with_era5 else da_models.values
    kwargs = xr.plot.utils._determine_cmap_params(
        scale_values, **(default_kwargs | kwargs)
    )
    if not with_era5 and cbar_kwargs is None:
        cbar_kwargs = {"orientation": "horizontal"}

    # Figure: one row of panels without ERA5, two rows with it
    n_rows = 2 if with_era5 else 1
    fig, axs = plt.subplots(
        n_rows,
        2,
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
    )
    axs = axs.flatten()
    panels = iter(axs)

    # ERA5
    if with_era5:
        ax = next(panels)
        plot.projected_map(
            da_era5, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
        )
        if p_value_era5 is not None:
            hatch_p_value(p_value_era5, ax=ax)
        ax.set_title("ERA5")

    # Median
    median = da_models.median("model", keep_attrs=True)
    ax = next(panels)
    plot.projected_map(
        median, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
    )
    if p_value_models is not None:
        hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=ax)
    ax.set_title("Ensemble Median")

    # Bias
    if with_era5:
        with xr.set_options(keep_attrs=True):
            bias = median - da_era5
        bias_kwargs = default_kwargs | {"cmap": cmap_bias}
        ax = next(panels)
        plot.projected_map(
            bias,
            ax=ax,
            show_stats=False,
            center=0,
            cbar_kwargs=cbar_kwargs,
            **bias_kwargs,
        )
        ax.set_title("Ensemble Median Bias")

    # Std
    std = da_models.std("model", keep_attrs=True)
    std_kwargs = default_kwargs | {"cmap": cmap_std}
    ax = next(panels)
    plot.projected_map(
        std,
        ax=ax,
        show_stats=False,
        cbar_kwargs=cbar_kwargs,
        **std_kwargs,
    )
    ax.set_title("Ensemble Standard Deviation")

    set_extent(da_models, axs, area)
    return fig


# Second line of every figure title, e.g. "year_start=1971 year_stop=2000"
common_title = f"{year_start=} {year_stop=}"

## Plot ensembles

In [None]:
# For each index: one figure for the index itself and one for its trend,
# both comparing ERA5 with the interpolated model ensemble.
for index in index_timeseries:
    # Index
    da = ds_interpolated[index]
    fig = plot_ensemble(
        da_models=da,
        da_era5=ds_era5[index],
        cmap=cmaps.get(index),
        cmap_bias=cmaps_bias.get(index),
    )
    fig.suptitle(f"{index}\n{common_title}")
    plt.show()

    # Trend
    # (units are derived from the units of the index: "<units> / decade")
    da_era5_trend = ds_era5["trend"].sel(index=index)
    da_era5_trend.attrs["units"] = f"{da.attrs['units']} / decade"
    da_trend = ds_interpolated["trend"].sel(index=index)
    da_trend.attrs["units"] = f"{da.attrs['units']} / decade"
    fig = plot_ensemble(
        da_models=da_trend,
        da_era5=da_era5_trend,
        p_value_era5=ds_era5["p"].sel(index=index),
        p_value_models=ds_interpolated["p"].sel(index=index),
        center=0,
        cmap=cmaps_trend.get(index),
        cmap_bias=cmaps_bias.get(index),
    )
    fig.suptitle(f"Trend of {index}\n{common_title}")
    plt.show()

## Plot models

In [None]:
# For each index: maps of every model on its original grid, for the index
# and for its trend. ERA5 is only used to define the colour scale and label.
for index in index_timeseries:
    # Index
    da_for_kwargs = ds_era5[index]
    fig = plot_models(
        data={model: ds[index] for model, ds in model_datasets.items()},
        da_for_kwargs=da_for_kwargs,
        cmap=cmaps.get(index),
    )
    fig.suptitle(f"{index}\n{common_title}")
    plt.show()

    # Trend
    # (units are derived from the units of the index: "<units> / decade")
    da_for_kwargs_trends = ds_era5["trend"].sel(index=index)
    da_for_kwargs_trends.attrs["units"] = f"{da_for_kwargs.attrs['units']} / decade"
    fig = plot_models(
        data={
            model: ds["trend"].sel(index=index) for model, ds in model_datasets.items()
        },
        da_for_kwargs=da_for_kwargs_trends,
        p_values={
            model: ds["p"].sel(index=index) for model, ds in model_datasets.items()
        },
        center=0,
        cmap=cmaps_trend.get(index),
    )
    fig.suptitle(f"Trend of {index}\n{common_title}")
    plt.show()

## Plot bias

In [None]:
# Bias of every interpolated model with respect to ERA5 (attributes are kept)
with xr.set_options(keep_attrs=True):
    bias = ds_interpolated - ds_era5

for index in index_timeseries:
    # Index bias
    da = bias[index]
    fig = plot_models(data=da, center=0, cmap=cmaps_bias.get(index))
    fig.suptitle(f"Bias of {index}\n{common_title}")
    plt.show()

    # Trend bias
    # (units are derived from the units of the index bias: "<units> / decade")
    da_trend = bias["trend"].sel(index=index)
    da_trend.attrs["units"] = f"{da.attrs['units']} / decade"
    fig = plot_models(data=da_trend, center=0, cmap=cmaps_bias.get(index))
    fig.suptitle(f"Trend bias of {index}\n{common_title}")
    plt.show()

## Boxplot

In [None]:
# Spatially averaged trends: one value per model (original grids) and one
# value per model for the bias with respect to ERA5 (interpolated grid).
weights = True
mean_datasets = [
    diagnostics.spatial_weighted_mean(ds.expand_dims(model=[model]), weights=weights)
    for model, ds in model_datasets.items()
]
mean_ds = xr.concat(mean_datasets, "model")
mean_bias_ds = diagnostics.spatial_weighted_mean(bias, weights=weights)
# First pass: trends (with ERA5 as reference); second pass: trend biases
for is_bias, ds in zip((False, True), (mean_ds, mean_bias_ds)):
    for index, da in ds["trend"].groupby("index"):
        # Box plot of the model values, with the individual models on top
        df_slope = da.to_dataframe()[["trend"]]
        ax = df_slope.boxplot()
        ax.scatter(
            x=[1] * len(df_slope),
            y=df_slope,
            color="grey",
            marker=".",
            label="models",
        )

        # Ensemble mean
        ax.scatter(
            x=1,
            y=da.mean("model"),
            marker="o",
            label="CMIP6 Ensemble Mean",
        )

        # ERA5
        # (only shown for the trends; `da` is re-bound to the ERA5 value here)
        labels = ["CMIP6 Ensemble"]
        if not is_bias:
            da = ds_era5["trend"].sel(index=index)
            da = diagnostics.spatial_weighted_mean(da)
            ax.scatter(
                x=2,
                y=da.values,
                marker="o",
                label="ERA5",
            )
            labels.append("ERA5")

        ax.set_xticks(range(1, len(labels) + 1), labels)
        ax.set_ylabel(f"{ds[index].attrs['units']} / decade")
        plt.suptitle(f"Trend{' bias ' if is_bias else ' '}of {index}")
        plt.legend()
        plt.show()