## Extreme temperature indices: historical vs. future

## Import packages

In [None]:
import math
import tempfile

import cartopy.crs as ccrs
import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot

# Notebook-wide matplotlib look: seaborn "notebook" style and thin hatch lines
# (hatching is used by the plotting helpers to overlay p-value / robustness masks).
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define Parameters

In [None]:
# Historical reference period (inclusive years)
year_start_historical = 1971
year_stop_historical = 2000

# Future period per model as ``slice(first_year, last_year)`` (inclusive years).
# NOTE(review): the windows differ per model — presumably each one is centred on
# the year the model reaches a given global-warming level; confirm the source.
models = {
    "access_cm2": slice(2023, 2052),
    "awi_cm_1_1_mr": slice(2022, 2051),
    "canesm5": slice(2015, 2044),
    "cmcc_esm2": slice(2024, 2053),
    "cnrm_cm6_1_hr": slice(2016, 2045),
    "ec_earth3_cc": slice(2020, 2049),
    "gfdl_esm4": slice(2038, 2067),
    "miroc6": slice(2039, 2068),
    "mpi_esm1_2_lr": slice(2034, 2063),
}
# Every future window must start after the historical period and be non-empty.
assert all(
    year_stop_historical < year_slice.start <= year_slice.stop
    for year_slice in models.values()
), "Future windows must be non-empty and start after the historical period"

# Model whose grid is the common target for interpolation
model_regrid = "gfdl_esm4"
assert model_regrid in models, f"{model_regrid=} is not one of the selected models"

# Choose annual or seasonal timeseries
timeseries = "JJA"
assert timeseries in ("annual", "DJF", "MAM", "JJA", "SON")

# Region for analysis as [north, west, south, east]
area = [72, -22, 27, 45]

# Define region for request
cordex_domain = "europe"

# icclim index names
index_names = ("SU", "TX90p")

# Interpolation method forwarded to the regridder
interpolation_method = "bilinear"

# Chunks for download (requests are split one year at a time)
chunks = {"year": 1}

## Define land-sea mask (LSM) request

In [None]:
# (collection_id, request) pair for the ERA5 land-sea mask, unpacked into
# ``download.download_and_transform`` when the mask is applied. A single
# timestamp is requested — presumably because the mask is time-invariant.
# NOTE(review): left byte-identical in case the request content is used as a
# cache key by ``download``.
request_lsm = (
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "time": "00:00",
        "variable": "land_sea_mask",
        "year": "1940",
        "month": "01",
        "day": "01",
        "area": area,
    },
)

## Define model requests

In [None]:
def get_years(year_start, year_stop, timeseries):
    """Return the years to request, as strings, up to ``year_stop`` inclusive.

    For ``timeseries == "DJF"`` the year before ``year_start`` is included too,
    so that the December of the first winter is downloaded.
    """
    first_year = year_start - 1 if timeseries == "DJF" else year_start
    return list(map(str, range(first_year, year_stop + 1)))


collection_id = "projections-cmip6"

# Settings shared by every CMIP6 request: daily maximum near-surface air
# temperature for all days of the year over ``area``.
request = {
    "format": "zip",
    "temporal_resolution": "daily",
    "variable": "daily_maximum_near_surface_air_temperature",
    "month": [f"{month:02d}" for month in range(1, 13)],
    "day": [f"{day:02d}" for day in range(1, 32)],
    "area": area,
}

# Historical years are the same for every model (plus the previous year for DJF).
request_historical = request | {
    "year": get_years(year_start_historical, year_stop_historical, timeseries),
    "experiment": "historical",
}
# Future years depend on the model and are filled in per model below.
request_future = request | {
    "experiment": "ssp5_8_5",
}

# For each model: historical + future requests, split according to ``chunks``.
model_requests = {}
for model, year_slice in models.items():
    years_future = get_years(year_slice.start, year_slice.stop, timeseries)
    model_request_historical = download.split_request(
        request_historical | {"model": model}, chunks=chunks
    )
    model_request_future = download.split_request(
        request_future | {"model": model, "year": years_future}, chunks=chunks
    )
    model_requests[model] = model_request_historical + model_request_future

# Requests of the model whose grid is the interpolation target
request_grid_out = model_requests[model_regrid]

## Functions to cache

In [None]:
def separate_historical_and_future(ds, year_stop_historical):
    """Split ``ds`` along ``time`` into its historical and future parts.

    Parameters
    ----------
    ds : xarray.Dataset
        Data with a datetime-like ``time`` coordinate.
    year_stop_historical : int
        Last year (inclusive) that belongs to the historical part.

    Returns
    -------
    list of xarray.Dataset
        ``[historical, future]`` in chronological order; a part without any
        time step is omitted.
    """
    is_historical = (ds["time"].dt.year <= year_stop_historical).values
    # Built explicitly rather than with ``groupby`` on the boolean mask, whose
    # groups come out ``False`` (future) first and so break chronological order
    # when the parts are concatenated.
    parts = [ds.isel(time=is_historical), ds.isel(time=~is_historical)]
    return [part for part in parts if part.sizes["time"]]


def select_timeseries(ds, timeseries, year_stop_historical):
    """Restrict ``ds`` to the months of ``timeseries`` (all months for "annual").

    The historical and future parts are handled separately. For ``"DJF"`` each
    part is first trimmed from the December of its first year to the November
    of its last year, so only whole winters remain once the season is selected.
    """

    def _trim_to_whole_winters(part):
        if timeseries != "DJF":
            return part
        part_years = part["time"].dt.year
        first = part_years.min().values
        last = part_years.max().values
        return part.sel(time=slice(f"{first}-12", f"{last}-11"))

    combined = xr.concat(
        [
            _trim_to_whole_winters(part)
            for part in separate_historical_and_future(ds, year_stop_historical)
        ],
        "time",
    )
    if timeseries == "annual":
        return combined
    in_season = combined["time"].dt.season == timeseries
    return combined.where(in_season, drop=True)


def compute_indices(ds, index_names, timeseries, tmpdir, year_stop_historical):
    """Compute icclim indices comparing the historical and future parts of ``ds``.

    The input is written to ``tmpdir`` as one NetCDF file per year and then
    rewritten as a Zarr store with the whole ``time`` axis in one chunk; that
    store is what icclim reads (``in_files``).

    ``TX90p`` is computed over the future period with the historical period as
    percentile base period, so it keeps a ``time`` dimension. Every other index
    is computed for both periods and returned as the difference of the time
    means (future minus historical), i.e. without a ``time`` dimension.

    Parameters
    ----------
    ds : xarray.Dataset
        Daily data covering both periods.
    index_names : iterable of str
        icclim index names, e.g. ``"SU"`` or ``"TX90p"``.
    timeseries : str
        ``"annual"`` or a season code, forwarded to icclim as ``slice_mode``
        (``"annual"`` becomes ``"year"``).
    tmpdir : str
        Scratch directory for the intermediate and icclim output files; the
        returned dataset may still be backed by them.
    year_stop_historical : int
        Last year of the historical period; the future period starts on
        January 1st of the following year.

    Returns
    -------
    xarray.Dataset
        One variable per index, without any ``bounds`` dimension.
    """
    # Stage 1: one NetCDF file per year, each stored as a single chunk.
    years, yearly_datasets = zip(*ds.groupby("time.year"))
    paths = [f"{tmpdir}/{year}.nc" for year in years]
    xr.save_mfdataset([ds_year.chunk(-1) for ds_year in yearly_datasets], paths)

    # Stage 2: re-chunk into a Zarr store with the full time axis in one chunk.
    # The multi-file dataset is closed once the store has been written.
    in_files = f"{tmpdir}/rechunked.zarr"
    with xr.open_mfdataset(paths) as ds_all_years:
        chunks = {
            dim: -1 if dim == "time" else "auto" for dim in ds_all_years.dims
        }
        ds_all_years.chunk(chunks).to_zarr(in_files)
        start = str(ds_all_years["time"].min().values)
        stop = str(ds_all_years["time"].max().values)

    historical_range = (start, f"{year_stop_historical}-12-31")
    future_range = (f"{year_stop_historical + 1}-01-01", stop)
    slice_mode = "year" if timeseries == "annual" else timeseries

    results = []
    for index_name in index_names:
        kwargs = {
            "index_name": index_name,
            "in_files": in_files,
            "slice_mode": slice_mode,
        }
        if index_name == "TX90p":
            # Percentile index: future period against the historical base period.
            results.append(
                icclim.index(
                    out_file=f"{tmpdir}/{index_name}.nc",
                    time_range=future_range,
                    base_period_time_range=historical_range,
                    **kwargs,
                )
            )
            continue

        # Other indices: difference of the period means (future - historical).
        ds_historical = icclim.index(
            out_file=f"{tmpdir}/{index_name}_historical.nc",
            time_range=historical_range,
            **kwargs,
        )
        ds_future = icclim.index(
            out_file=f"{tmpdir}/{index_name}_future.nc",
            time_range=future_range,
            **kwargs,
        )
        with xr.set_options(keep_attrs=True):
            results.append(ds_future.mean("time") - ds_historical.mean("time"))

    # The time-mean differences may have no ``bounds`` dimension (e.g. when
    # TX90p is not requested), so a missing dimension must not raise.
    return xr.merge(results).drop_dims("bounds", errors="ignore")


def add_bounds(ds):
    """Return ``ds`` with cf-xarray bounds added for latitude/longitude where missing."""
    missing_bounds = {"latitude", "longitude"}.difference(ds.cf.bounds)
    for coord_key in missing_bounds:
        ds = ds.cf.add_bounds(coord_key)
    return ds


def get_grid_out(request_grid_out, method):
    """Download a dataset and reduce it to the target grid for regridding.

    The result holds only the latitude/longitude coordinates (plus their
    bounds when ``method == "conservative"``), with every other non-dimension
    coordinate and all global attributes removed.
    """
    ds_target = download.download_and_transform(*request_grid_out)
    keep = ["latitude", "longitude"]
    if method == "conservative":
        ds_target = add_bounds(ds_target)
        keep += [
            bounds_name
            for coord in ("latitude", "longitude")
            for bounds_name in ds_target.cf.bounds[coord]
        ]
    subset = ds_target[keep]
    extra_coords = set(subset.coords) - set(keep) - set(subset.dims)
    grid_out = subset.reset_coords(extra_coords, drop=True)
    grid_out.attrs = {}
    return grid_out


def compute_indices_and_trends_historical_vs_future(
    ds,
    index_names,
    timeseries,
    year_stop_historical,
    resample,
    request_grid_out=None,
    **regrid_kwargs,
):
    """Transform function: compute index changes and optionally regrid them.

    Despite the name, no trend is fitted: the indices come from
    ``compute_indices`` and any remaining ``time`` axis is averaged out.

    Parameters
    ----------
    ds : xarray.Dataset
        Downloaded model data; only 3-D data variables are used.
    index_names : iterable of str
        icclim index names.
    timeseries : str
        ``"annual"`` or a season code (``"DJF"``, ``"MAM"``, ``"JJA"``, ``"SON"``).
    year_stop_historical : int
        Last year of the historical period.
    resample : bool
        If true, take daily maxima before computing the indices.
    request_grid_out : tuple, optional
        ``(collection_id, requests)`` of the target grid. When given,
        ``regrid_kwargs`` must be given too (and vice versa).
    **regrid_kwargs
        Forwarded to ``diagnostics.regrid``; must contain ``"method"`` when
        regridding.

    Returns
    -------
    xarray.Dataset
        Index fields, on the target grid when ``request_grid_out`` is given.
    """
    assert (request_grid_out and regrid_kwargs) or not (
        request_grid_out or regrid_kwargs
    )
    # Keep only the 3-D data variables
    ds = ds.drop_vars([var for var, da in ds.data_vars.items() if len(da.dims) != 3])
    ds = ds[list(ds.data_vars)]

    # Original bounds for conservative interpolation
    if regrid_kwargs.get("method") == "conservative":
        ds = add_bounds(ds)
        bounds = [
            ds.cf.get_bounds(coord).reset_coords(drop=True)
            for coord in ("latitude", "longitude")
        ]
    else:
        bounds = []

    ds = select_timeseries(ds, timeseries, year_stop_historical)
    if resample:
        ds = ds.resample(time="1D").max(keep_attrs=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Load eagerly: the intermediate files are deleted when the block exits.
        ds_indices = compute_indices(
            ds, index_names, timeseries, tmpdir, year_stop_historical
        ).compute()

    # Only indices computed as a time series (TX90p in ``compute_indices``)
    # keep ``time``; reducing over a missing dimension would raise, e.g. when
    # only "SU" is requested.
    if "time" in ds_indices.dims:
        ds = ds_indices.mean("time", keep_attrs=True)
    else:
        ds = ds_indices

    if request_grid_out:
        ds = diagnostics.regrid(
            ds.merge({da.name: da for da in bounds}),
            grid_out=get_grid_out(request_grid_out, regrid_kwargs["method"]),
            **regrid_kwargs,
        )
    return ds

## Download and transform the model used as the regridding target

In [None]:
# Arguments shared by every ``download.download_and_transform`` call below;
# each downloaded request is transformed by
# ``compute_indices_and_trends_historical_vs_future``.
kwargs = {
    "collection_id": collection_id,
    "chunks": chunks,
    "transform_chunks": False,
    "transform_func": compute_indices_and_trends_historical_vs_future,
}
# Keyword arguments forwarded to the transform function.
# NOTE(review): ``index_names`` is sorted, presumably so the cached result does
# not depend on the order the indices were listed in — confirm.
transform_func_kwargs = {
    "index_names": sorted(index_names),
    "timeseries": timeseries,
    "year_stop_historical": year_stop_historical,
    "resample": False,
}
# Indices of the target-grid model on its native grid.
# NOTE(review): ``ds_regrid`` is not referenced later in this notebook; the call
# presumably pre-populates the download cache — confirm it is still needed.
ds_regrid = download.download_and_transform(
    requests=request_grid_out,
    **kwargs,
    transform_func_kwargs=transform_func_kwargs,
)

## Download and transform models

In [None]:
# For every model compute the indices on its native grid and, unless it is
# ``model_regrid``, again interpolated onto the grid of ``model_regrid``.
interpolated_datasets = []
model_datasets = {}
for model, requests in model_requests.items():
    print(f"{model=}")
    # Original model
    ds = download.download_and_transform(
        requests=requests,
        **kwargs,
        transform_func_kwargs=transform_func_kwargs,
    )
    model_datasets[model] = ds

    if model != model_regrid:
        # Interpolated model; "method" and "skipna" end up in ``regrid_kwargs``
        # of the transform function and are forwarded to ``diagnostics.regrid``.
        ds = download.download_and_transform(
            requests=requests,
            **kwargs,
            transform_func_kwargs=transform_func_kwargs
            | {
                "request_grid_out": (collection_id, request_grid_out),
                "method": interpolation_method,
                "skipna": True,
            },
        )
    # For ``model_regrid`` the native-grid result already is on the target grid.
    interpolated_datasets.append(ds.expand_dims(model=[model]))

# All models on the common grid, stacked along a new ``model`` dimension
ds_interpolated = xr.concat(interpolated_datasets, "model")

## Apply the land-sea mask and edit attributes

In [None]:
# Mask
lsm = download.download_and_transform(*request_lsm)["lsm"].squeeze(drop=True)
ds_interpolated = ds_interpolated.where(
    diagnostics.regrid(lsm, ds_interpolated, method="bilinear")
)
model_datasets = {
    model: ds.where(diagnostics.regrid(lsm, ds, method="bilinear"))
    for model, ds in model_datasets.items()
}

# Edit attributes
for ds in (ds_interpolated, *model_datasets.values()):
    for index in index_names:
        ds[index].attrs = {"long_name": "", "units": "days"}

## Plotting functions

In [None]:
def hatch_p_value(da, ax, **kwargs):
    """Overlay hatching from ``da`` on ``ax`` while keeping the current title.

    With the defaults, values below 0.05 are left blank and values between
    0.05 and 1 are hatched with ``///``. Any keyword argument overrides the
    corresponding default passed to ``plot.projected_map``.
    """
    options = {
        "plot_func": "contourf",
        "show_stats": False,
        "cmap": "none",
        "add_colorbar": False,
        "levels": [0, 0.05, 1],
        "hatches": ["", "/" * 3],
    }
    options.update(kwargs)

    # projected_map may replace the title; restore the one already on the axis.
    saved_title = ax.get_title()
    artist = plot.projected_map(da, ax=ax, **options)
    ax.set_title(saved_title)
    return artist


def hatch_p_value_ensemble(trend, p_value, ax):
    """Hatch the areas where the model ensemble is not robust.

    Two overlays are drawn:

    * ``///`` where fewer than 66% of the models have ``p_value <= 0.05``
      (only where at least one model has a p-value);
    * ``\\\\\\`` where, among the points that pass the first test, fewer than
      80% of the models agree on the sign of ``trend``.
    """
    n_models = trend.sizes["model"]
    robust_threshold = 0.66
    sign_threshold = 0.8

    significant_share = (p_value <= 0.05).sum("model") / n_models
    robust_ratio = significant_share.where(p_value.notnull().any("model"))

    n_positive = (trend > 0).sum("model")
    n_negative = (trend < 0).sum("model")
    agreeing = xr.concat([n_positive, n_negative], "sign").max("sign")
    sign_ratio = (agreeing / n_models).where(robust_ratio > robust_threshold)

    overlays = (
        (robust_ratio, robust_threshold, "/"),
        (sign_ratio, sign_threshold, "\\"),
    )
    for ratio, threshold, symbol in overlays:
        hatch_p_value(
            ratio, ax=ax, levels=[0, threshold, 1], hatches=[symbol * 3, ""]
        )


def set_extent(da, axs, area):
    """Apply one map extent to every axis in ``axs``.

    ``area`` is ``[north, west, south, east]``. The extent passed to
    ``ax.set_extent`` is ``[west, east, south, north]``, moved 1 unit inside
    each edge. ``da`` is not used.

    NOTE(review): the inward shift may be deliberate (e.g. to hide edge effects
    of interpolation) — confirm it is not meant to be an outward margin.
    """
    extent = [area[1] + 1, area[3] - 1, area[2] + 1, area[0] - 1]
    for ax in axs:
        ax.set_extent(extent)


def plot_models(
    data,
    da_for_kwargs=None,
    p_values=None,
    col_wrap=3,
    subplot_kw=None,
    figsize=None,
    layout="constrained",
    area=area,
    **kwargs,
):
    """Plot one map per model on a shared colour scale with a single colour bar.

    Parameters
    ----------
    data : dict or xarray.DataArray
        Either ``{model: DataArray}`` or a DataArray with a ``model`` dimension.
    da_for_kwargs : xarray.DataArray, optional
        Data used to derive the colour-map parameters and the colour-bar label.
        Required when ``data`` is a dict; defaults to ``data`` otherwise.
    p_values : dict or xarray.DataArray, optional
        p-values laid out like ``data``; each panel is hatched with
        ``hatch_p_value`` when given.
    col_wrap : int
        Number of panel columns.
    subplot_kw : dict, optional
        Passed to ``plt.subplots``; defaults to a PlateCarree projection.
    figsize, layout
        Passed to ``plt.subplots``.
    area : list
        ``[north, west, south, east]`` region forwarded to ``set_extent``.
    **kwargs
        Colour-map options (defaults: ``robust=True``, ``extend="both"``).

    Returns
    -------
    matplotlib.figure.Figure
    """
    if subplot_kw is None:
        subplot_kw = {"projection": ccrs.PlateCarree()}

    if isinstance(data, dict):
        assert da_for_kwargs is not None
        model_dataarrays = data
    else:
        # Explicit ``is None``: ``da_for_kwargs or data`` calls bool() on a
        # DataArray, which raises for arrays with more than one element.
        if da_for_kwargs is None:
            da_for_kwargs = data
        model_dataarrays = dict(data.groupby("model"))

    if p_values is None:
        model_p_dataarrays = None
    elif isinstance(p_values, dict):
        model_p_dataarrays = p_values
    else:
        model_p_dataarrays = dict(p_values.groupby("model"))

    # Get kwargs
    default_kwargs = {"robust": True, "extend": "both"}
    kwargs = default_kwargs | kwargs
    kwargs = xr.plot.utils._determine_cmap_params(da_for_kwargs.values, **kwargs)

    # ``col_wrap`` is the number of columns; ``squeeze=False`` keeps a 2-D
    # array of axes even for a single panel so ``flatten`` always works.
    fig, axs = plt.subplots(
        nrows=math.ceil(len(model_dataarrays) / col_wrap),
        ncols=col_wrap,
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
        squeeze=False,
    )
    axs = axs.flatten()
    for (model, da), ax in zip(model_dataarrays.items(), axs):
        pcm = plot.projected_map(
            da, ax=ax, show_stats=False, add_colorbar=False, **kwargs
        )
        ax.set_title(model)
        if model_p_dataarrays is not None:
            hatch_p_value(model_p_dataarrays[model], ax)
    set_extent(da_for_kwargs, axs, area)
    fig.colorbar(
        pcm,
        ax=axs,
        extend=kwargs["extend"],
        location="right",
        label=f"{da_for_kwargs.attrs.get('long_name', '')} [{da_for_kwargs.attrs.get('units', '')}]",
    )
    return fig


def plot_ensemble(
    da_models,
    da_era5=None,
    p_value_era5=None,
    p_value_models=None,
    subplot_kw=None,
    figsize=None,
    layout="constrained",
    cbar_kwargs=None,
    area=area,
    **kwargs,
):
    """Plot ensemble summary maps.

    Panels, in order: ERA5 (if ``da_era5`` is given), ensemble median, bias of
    the median against ERA5 (if ``da_era5`` is given) and ensemble standard
    deviation. The ERA5 and median panels share a colour scale derived from
    ``da_era5`` when given, otherwise from all values of ``da_models``; the
    bias and standard-deviation panels only use the default colour options.

    Parameters
    ----------
    da_models : xarray.DataArray
        Model fields with a ``model`` dimension.
    da_era5 : xarray.DataArray, optional
        Reference field.
    p_value_era5, p_value_models : xarray.DataArray, optional
        p-values used for hatching the ERA5 and median panels.
    subplot_kw : dict, optional
        Passed to ``plt.subplots``; defaults to a PlateCarree projection.
    figsize, layout
        Passed to ``plt.subplots``.
    cbar_kwargs : dict, optional
        Colour-bar options; horizontal bars by default when there is no ERA5.
    area : list
        ``[north, west, south, east]`` region forwarded to ``set_extent``.
    **kwargs
        Colour-map options (defaults: ``robust=True``, ``extend="both"``).

    Returns
    -------
    matplotlib.figure.Figure
    """
    # None default avoids sharing one mutable dict across calls.
    if subplot_kw is None:
        subplot_kw = {"projection": ccrs.PlateCarree()}

    # Get kwargs
    default_kwargs = {"robust": True, "extend": "both"}
    kwargs = default_kwargs | kwargs
    kwargs = xr.plot.utils._determine_cmap_params(
        da_models.values if da_era5 is None else da_era5.values, **kwargs
    )
    if da_era5 is None and cbar_kwargs is None:
        cbar_kwargs = {"orientation": "horizontal"}

    # Figure: one row without ERA5, two rows with it
    fig, axs = plt.subplots(
        *(1 if da_era5 is None else 2, 2),
        subplot_kw=subplot_kw,
        figsize=figsize,
        layout=layout,
    )
    axs = axs.flatten()
    axs_iter = iter(axs)

    # ERA5
    if da_era5 is not None:
        ax = next(axs_iter)
        plot.projected_map(
            da_era5, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
        )
        if p_value_era5 is not None:
            hatch_p_value(p_value_era5, ax=ax)
        ax.set_title("ERA5")

    # Median
    ax = next(axs_iter)
    median = da_models.median("model", keep_attrs=True)
    plot.projected_map(
        median, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **kwargs
    )
    if p_value_models is not None:
        hatch_p_value_ensemble(trend=da_models, p_value=p_value_models, ax=ax)
    ax.set_title("Ensemble Median")

    # Bias
    if da_era5 is not None:
        ax = next(axs_iter)
        with xr.set_options(keep_attrs=True):
            bias = median - da_era5
        plot.projected_map(
            bias,
            ax=ax,
            show_stats=False,
            center=0,
            cbar_kwargs=cbar_kwargs,
            **default_kwargs,
        )
        ax.set_title("Ensemble Median Bias")

    # Std
    ax = next(axs_iter)
    std = da_models.std("model", keep_attrs=True)
    plot.projected_map(
        std, ax=ax, show_stats=False, cbar_kwargs=cbar_kwargs, **default_kwargs
    )
    ax.set_title("Ensemble Standard Deviation")

    set_extent(da_models, axs, area)
    return fig


# Shared figure subtitle, e.g.
# "year_start_historical=1971 year_stop_historical=2000 timeseries='JJA'"
common_title = " ".join(
    [f"{year_start_historical=}", f"{year_stop_historical=}", f"{timeseries=}"]
)

## Plot ensembles

In [None]:
for index in index_names:
    # Ensemble median and spread of the index on the common (interpolated) grid
    fig = plot_ensemble(da_models=ds_interpolated[index])
    fig.suptitle("\n".join((index, common_title)), y=0.8)
    plt.show()

## Plot models

In [None]:
for index in index_names:
    # One panel per model on its native grid; the colour scale comes from the
    # interpolated ensemble so that all panels are comparable.
    native_fields = {model: ds[index] for model, ds in model_datasets.items()}
    fig = plot_models(data=native_fields, da_for_kwargs=ds_interpolated[index])
    fig.suptitle("\n".join((index, common_title)))
    plt.show()

## Boxplot

In [None]:
# Spatially weighted mean of every index for each model (native grids)
mean_datasets = [
    diagnostics.spatial_weighted_mean(ds.expand_dims(model=[model]), weights=True)
    for model, ds in model_datasets.items()
]
mean_ds = xr.concat(mean_datasets, "model")
for index, da in mean_ds.data_vars.items():
    # One column holding the spatial mean of the index for each model
    df_index = da.to_dataframe()[[index]]

    # Draw on explicit axes rather than pyplot's implicit "current axes".
    fig, ax = plt.subplots()
    df_index.boxplot(ax=ax)
    ax.scatter(
        x=[1] * len(df_index),
        y=df_index,
        color="grey",
        marker=".",
        label="models",
    )

    # Ensemble mean
    ax.scatter(
        x=1,
        y=da.mean("model"),
        marker="o",
        label="CMIP6 Ensemble Mean",
    )

    labels = ["CMIP6 Ensemble"]
    ax.set_xticks(range(1, len(labels) + 1), labels)
    # Attributes may not survive the spatial mean / concat; don't fail on a label.
    ax.set_ylabel(da.attrs.get("units", ""))
    # The values compare the future and historical periods (difference of
    # period means, or TX90p against the historical base period); no trend is
    # fitted anywhere, so the title must not say "Trend".
    fig.suptitle(f"{index}: future vs historical")
    ax.legend()
    plt.show()