# Energy-consumption-related indices from CMIP6 Global Climate Models

## Import packages

In [None]:
import tempfile

import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download
from xarrayMannKendall import Mann_Kendall_test

# Global matplotlib look: seaborn notebook style, then thinner hatch lines
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define parameters

In [None]:
# Time period: first and last year of the analysis (both inclusive)
year_start, year_stop = 1971, 2000
assert year_start >= 1971

# Period of each index: either "annual" for every index, or one season label
# (e.g. "DJF", "JJA") per index -- annual and seasonal cannot be mixed.
index_timeseries = {
    "HDD15.5": "DJF",  # heating degree days, 15.5 degC threshold, winter
    "CDD22": "JJA",  # cooling degree days, 22 degC threshold, summer
}
if "annual" in index_timeseries.values():
    assert set(index_timeseries.values()) == {"annual"}

# Interpolation method used to regrid the models onto the ERA5 grid
interpolation_method = "bilinear"

# Area to download, in CDS order: [north, west, south, east]
area = [72, -22, 27, 45]

# Chunks for download (presumably one request per year)
chunks = {"year": 1}

# CMIP6 models to analyse
models_cmip6 = (
    "access_cm2",
    "awi_cm_1_1_mr",
    "cmcc_esm2",
    "cnrm_cm6_1_hr",
    "ec_earth3_cc",
    "gfdl_esm4",
    "inm_cm5_0",
    "miroc6",
    "mpi_esm1_2_lr",
)

## Define ERA5 request

In [None]:
# Hourly ERA5 2 m temperature over `area`. The year before `year_start` is
# requested too, so the first DJF season has its December.
request_era = (
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "time": [f"{h:02d}:00" for h in range(24)],
        "variable": "2m_temperature",
        "year": [str(y) for y in range(year_start - 1, year_stop + 1)],
        "month": [f"{m:02d}" for m in range(1, 13)],
        "day": [f"{d:02d}" for d in range(1, 32)],
        "area": area,
    },
)

# Same ERA5 request reduced to a single time step of the land-sea mask;
# it is used as the target grid when regridding the models.
request_lsm = (
    request_era[0],
    {
        **request_era[1],
        "year": "1940",
        "month": "01",
        "day": "01",
        "time": "00:00",
        "variable": "land_sea_mask",
    },
)

## Define CMIP6 request

In [None]:
# Daily near-surface air temperature from the CMIP6 "historical" experiment,
# over the same years (including year_start - 1) and area as ERA5.
request_cmip6 = {
    "format": "zip",
    "temporal_resolution": "daily",
    "experiment": "historical",
    "variable": "near_surface_air_temperature",
    # Start one year early so the first DJF season includes its December
    "year": list(map(str, range(year_start - 1, year_stop + 1))),
    "month": [f"{month:02d}" for month in range(1, 13)],
    "day": [f"{day:02d}" for day in range(1, 32)],
    "area": area,
}

# Per model: (dataset name, request split according to `chunks`)
model_requests = {}
for model in models_cmip6:
    model_requests[model] = (
        "projections-cmip6",
        download.split_request({**request_cmip6, "model": model}, chunks=chunks),
    )

## Define functions to cache (used as `transform_func` by the cached downloads)

In [None]:
def select_timeseries(ds, index_timeseries, year_start, year_stop):
    """Trim ``ds`` along ``time`` to the period the indices need.

    If every index is ``"annual"``, keep ``year_start`` through ``year_stop``.
    Otherwise all indices must be seasonal; keep December of
    ``year_start - 1`` through November of ``year_stop``, so every DJF
    season is complete.

    Raises AssertionError if annual and seasonal periods are mixed.
    """
    periods = set(index_timeseries.values())
    if periods != {"annual"}:
        assert "annual" not in periods
        start, stop = f"{year_start - 1}-12", f"{year_stop}-11"
    else:
        start, stop = str(year_start), str(year_stop)
    return ds.sel(time=slice(start, stop))


def compute_indices(ds, index_timeseries, tmpdir):
    """Compute the degree-day indices with icclim, normalised per day.

    The input is written to ``tmpdir`` as one NetCDF file per year. It is then
    rechunked into a single Zarr store with the whole ``time`` axis in one
    chunk, and icclim reads that store via ``in_files``.

    Parameters
    ----------
    ds : xarray.Dataset
        Input data with a ``time`` coordinate (daily values are presumably
        expected -- confirm with callers).
    index_timeseries : dict
        Maps an index name (``"HDD15.5"`` or ``"CDD22"``) to ``"annual"`` or
        to a season label (``"DJF"``, ``"MAM"``, ``"JJA"``, ``"SON"``).
    tmpdir : str
        Scratch directory for the intermediate files and icclim outputs.

    Returns
    -------
    xarray.Dataset
        One variable per index. Each is divided by the fixed number of days
        in its period, and the ``" d"`` part is removed from its units.

    Raises
    ------
    NotImplementedError
        If an index name other than ``"HDD15.5"`` or ``"CDD22"`` is requested.
    """
    # Fixed day count per season; the annual divisor is their sum (365).
    # NOTE(review): leap years (a 91-day DJF, 366-day year) are not
    # accounted for.
    num_days = {"DJF": 90, "MAM": 92, "JJA": 92, "SON": 91}

    # Write one NetCDF file per year, each year loaded as a single chunk.
    labels, yearly = zip(*ds.groupby("time.year"))
    paths = [f"{tmpdir}/{label}.nc" for label in labels]
    xr.save_mfdataset([ds_year.chunk(-1) for ds_year in yearly], paths)

    # Rechunk into Zarr with the full time axis in one chunk. The context
    # manager closes the NetCDF files before the caller removes tmpdir.
    in_files = f"{tmpdir}/rechunked.zarr"
    with xr.open_mfdataset(paths) as ds_yearly:
        rechunk = {dim: -1 if dim == "time" else "auto" for dim in ds_yearly.dims}
        ds_yearly.chunk(rechunk).to_zarr(in_files)

    datasets = []
    for index_name, timeseries in index_timeseries.items():
        kwargs = {
            "in_files": in_files,
            "out_file": f"{tmpdir}/{index_name}.nc",
            "slice_mode": "year" if timeseries == "annual" else timeseries,
        }
        if index_name == "HDD15.5":
            # icclim generic "deficit" index with a 15.5 degC threshold
            ds_index = icclim.index(
                **kwargs,
                index_name="deficit",
                threshold=icclim.build_threshold("15.5 degC"),
            )
        elif index_name == "CDD22":
            # icclim generic "excess" index with a 22 degC threshold
            ds_index = icclim.excess(
                **kwargs,
                threshold=icclim.build_threshold("22 degC"),
            )
        else:
            raise NotImplementedError(f"{index_name=}")

        ds_index = ds_index.drop_dims("bounds")
        divisor = (
            sum(num_days.values())
            if timeseries == "annual"
            else num_days[timeseries]
        )
        with xr.set_options(keep_attrs=True):
            ds_index /= divisor
        datasets.append(ds_index)

    ds_out = xr.merge(datasets)
    for da in ds_out.data_vars.values():
        # Values are now per-day averages, so drop the " d" day factor.
        da.attrs["units"] = da.attrs["units"].replace(" d", "")
    return ds_out


def compute_trends(ds):
    """Run a Mann-Kendall test with Theil-Sen slopes on each data variable.

    Each variable's anomaly from its time mean is passed to
    ``Mann_Kendall_test`` (alpha=0.05). The per-variable results are
    stacked along a new ``index`` dimension that holds the variable names.
    Latitude/longitude dimensions are found through the cf_xarray Y/X axes.
    """
    (lat,) = set(ds.dims) & set(ds.cf.axes["Y"])
    (lon,) = set(ds.dims) & set(ds.cf.axes["X"])
    coords_name = {"time": "time", "y": lat, "x": lon}

    per_index = []
    for index, da in ds.data_vars.items():
        ds_mk = Mann_Kendall_test(
            da - da.mean("time"),
            alpha=0.05,
            method="theilslopes",
            coords_name=coords_name,
        ).compute()
        # The output may use the generic dim names; map them back to the
        # input's names and reattach the input coordinates.
        generic_to_actual = {
            generic: actual
            for generic, actual in coords_name.items()
            if generic in ds_mk.dims
        }
        ds_mk = ds_mk.rename(generic_to_actual)
        ds_mk = ds_mk.assign_coords({dim: da[dim] for dim in ds_mk.dims})
        per_index.append(ds_mk.expand_dims(index=[index]))
    return xr.concat(per_index, "index")


def add_bounds(ds):
    """Return ``ds`` with CF bounds added for latitude/longitude if missing."""
    missing = {"latitude", "longitude"}.difference(ds.cf.bounds)
    for coord in missing:
        ds = ds.cf.add_bounds(coord)
    return ds


def get_grid_out(request_grid_out, method):
    """Download the target grid and keep only its horizontal coordinates.

    Parameters
    ----------
    request_grid_out : tuple
        Arguments unpacked into ``download.download_and_transform``.
    method : str
        Regridding method. For ``"conservative"``, the latitude/longitude
        bounds are added (if missing) and kept.

    Returns
    -------
    xarray.Dataset
        ``latitude``/``longitude`` (plus bounds when conservative). All other
        non-dimension coordinates are dropped and the attributes are cleared.
    """
    ds_regrid = download.download_and_transform(*request_grid_out)
    coords = ["latitude", "longitude"]
    if method == "conservative":
        ds_regrid = add_bounds(ds_regrid)
        for coord in ("latitude", "longitude"):
            coords.extend(ds_regrid.cf.bounds[coord])
    grid_out = ds_regrid[coords]
    # Drop any extra non-dimension coordinates carried along by the selection.
    coords_to_drop = set(grid_out.coords) - set(coords) - set(grid_out.dims)
    grid_out = grid_out.reset_coords(coords_to_drop, drop=True)
    grid_out.attrs = {}
    return grid_out


def compute_indices_and_trends(
    ds,
    index_timeseries,
    year_start,
    year_stop,
    resample_reduction=None,
    request_grid_out=None,
    **regrid_kwargs,
):
    """Compute the indices, their time mean and their trends (optionally regridded).

    Used as ``transform_func`` for ``download.download_and_transform``.

    Parameters
    ----------
    ds : xarray.Dataset
        Downloaded data. Only data variables with exactly 3 dimensions are
        kept.
    index_timeseries : dict
        Index name -> ``"annual"`` or a season label.
    year_start, year_stop : int
        First and last year of the analysis period.
    resample_reduction : str, optional
        Name of a resampler reduction (e.g. ``"mean"``), applied at ``"1D"``
        frequency before computing the indices. With ``"sum"`` the units get
        a ``" / day"`` suffix.
    request_grid_out : tuple, optional
        Request for the target grid. It must be given if and only if
        ``regrid_kwargs`` is non-empty.
    **regrid_kwargs
        Forwarded to ``diagnostics.regrid``. ``method`` also decides whether
        source cell bounds are collected.

    Returns
    -------
    xarray.Dataset
        Time-mean indices merged with the trend statistics. The result is
        regridded when ``request_grid_out`` is given.
    """
    # A target grid and regrid options must be given together, or not at all.
    assert bool(request_grid_out) == bool(regrid_kwargs)

    # Keep only 3-D data variables (presumably time, lat, lon).
    ds = ds.drop_vars([var for var, da in ds.data_vars.items() if len(da.dims) != 3])
    ds = ds[list(ds.data_vars)]

    # Conservative regridding needs the source cell bounds; collect them now.
    bounds = []
    if regrid_kwargs.get("method") == "conservative":
        ds = add_bounds(ds)
        bounds = [
            ds.cf.get_bounds(coord).reset_coords(drop=True)
            for coord in ("latitude", "longitude")
        ]

    ds = select_timeseries(ds, index_timeseries, year_start, year_stop)
    if resample_reduction:
        reducer = getattr(ds.resample(time="1D"), resample_reduction)
        ds = reducer(keep_attrs=True)
        if resample_reduction == "sum":
            for da in ds.data_vars.values():
                da.attrs["units"] = f"{da.attrs['units']} / day"

    with tempfile.TemporaryDirectory() as tmpdir:
        ds_indices = compute_indices(ds, index_timeseries, tmpdir).compute()
        ds_trends = compute_trends(ds_indices)
        ds_out = ds_indices.mean("time", keep_attrs=True).merge(ds_trends)
        if not request_grid_out:
            return ds_out
        with_bounds = ds_out.merge({da.name: da for da in bounds})
        return diagnostics.regrid(
            with_bounds,
            grid_out=get_grid_out(request_grid_out, regrid_kwargs["method"]),
            **regrid_kwargs,
        )

## Download and transform ERA5

In [None]:
# Shared transform arguments. The index mapping is sorted, presumably so the
# cached result does not depend on dict insertion order.
transform_func_kwargs = dict(
    index_timeseries=dict(sorted(index_timeseries.items())),
    year_start=year_start,
    year_stop=year_stop,
)

# ERA5 is requested hourly, so it is averaged to daily values first.
ds_era5 = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_chunks=False,
    transform_func=compute_indices_and_trends,
    transform_func_kwargs={**transform_func_kwargs, "resample_reduction": "mean"},
)

## Download and transform models

In [None]:
# Download arguments shared by every model (loop-invariant)
model_kwargs = {
    "chunks": chunks,
    "transform_chunks": False,
    "transform_func": compute_indices_and_trends,
}
interpolated_datasets = []
model_datasets = {}
for model, requests in model_requests.items():
    print(f"{model=}")

    # Indices and trends on the model's native grid
    model_datasets[model] = download.download_and_transform(
        *requests,
        **model_kwargs,
        transform_func_kwargs=transform_func_kwargs,
    )

    # Same computation, regridded onto the ERA5 land-sea-mask grid
    ds = download.download_and_transform(
        *requests,
        **model_kwargs,
        transform_func_kwargs={
            **transform_func_kwargs,
            "request_grid_out": request_lsm,
            "method": interpolation_method,
            "skipna": True,
        },
    )
    interpolated_datasets.append(ds.expand_dims(model=[model]))

# Stack the regridded results along a new "model" dimension
ds_interpolated = xr.concat(interpolated_datasets, "model")