## WIP: Extreme temperature indices

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot
from xarrayMannKendall import Mann_Kendall_test

# Plot appearance: apply the notebook-oriented seaborn style sheet first,
# then override the line width used for hatched overlays.
plt.style.use("seaborn-v0_8-notebook")
plt.rcParams.update({"hatch.linewidth": 0.5})

## Define Parameters

In [None]:
# Time period (both years are included in the analysis)
year_start = 1971
year_stop = 2005

# Choose annual or seasonal timeseries
timeseries = "JJA"
if timeseries not in ("annual", "DJF", "MAM", "JJA", "SON"):
    # Explicit raise instead of assert: the check must survive `python -O`
    raise ValueError(f"{timeseries=}")

# Choose CORDEX or CMIP6
collection_id = "CMIP6"
if collection_id not in ("CORDEX", "CMIP6"):
    raise ValueError(f"{collection_id=}")

# Define region for analysis (passed as "area" in every request)
# NOTE(review): presumably [north, west, south, east] -- confirm against the CDS API
area = [72, -22, 27, 45]

# Define region for request (CORDEX "domain" key)
cordex_domain = "europe"

# Define index names (forwarded to icclim as `index_name`)
index_names = ("SU", "TX90p")

# Interpolation method used when regridding models onto the ERA5 grid
interpolation_method = "bilinear"

# Chunks for download (one request per year)
chunks = {"year": 1}

## Define models

In [None]:
# Regional climate models, used as values of the "rcm_model" request key
models_cordex = [
    "clmcom_eth_cosmo_crclim",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mpi_csc_remo2009",
    "uhoh_wrf361h",
]

# Global climate models, used as values of the "model" request key.
# All entries use the lowercase/underscore request identifier; the first one
# was previously spelled "EC-Earth3-CC", unlike every other entry.
models_cmip6 = [
    "ec_earth3_cc",
    "mpi_esm1_2_lr",
    "access_cm2",
    "awi_esm_1_1_lr",
    "cnrm_cm6_1",
]

## Define ERA5 request

In [None]:
# ERA5 request as a (dataset name, parameters) pair: 2 m temperature at all
# 24 hours of every day in the analysis period.
request_era = (
    "reanalysis-era5-single-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "time": [f"{hour:02d}:00" for hour in range(24)],
        "variable": "2m_temperature",
        # Start one year early so December of (year_start - 1) is available
        "year": list(map(str, range(year_start - 1, year_stop + 1))),
        "month": [f"{month:02d}" for month in range(1, 13)],
        "day": [f"{day:02d}" for day in range(1, 32)],
        "area": area,
    },
)

# Land-sea mask request: same dataset and base parameters as above, but
# narrowed to a single time step and switched to the mask variable.
request_lsm = (
    request_era[0],
    {
        **request_era[1],
        "year": "1940",
        "month": "01",
        "day": "01",
        "time": "00:00",
        "variable": "land_sea_mask",
    },
)

## Define model requests

In [None]:
###################################################
# TODO
# Base CORDEX request; start_year/end_year and the model key are added later.
request_cordex = dict(
    format="zip",
    domain=cordex_domain,
    experiment="historical",
    horizontal_resolution="0_11_degree_x_0_11_degree",
    temporal_resolution="monthly_mean",
    variable="daily_maximum_near_surface_air_temperature",
    gcm_model="mpi_m_mpi_esm_lr",
    ensemble_member="r1i1p1",
    area=area,
)
###################################################

# Base CMIP6 request for daily maximum temperature; the model key is added later.
request_cmip6 = dict(
    format="zip",
    temporal_resolution="daily",
    experiment="historical",
    variable="daily_maximum_near_surface_air_temperature",
    # Start one year early so December of (year_start - 1) is available
    year=list(map(str, range(year_start - 1, year_stop + 1))),
    month=[f"{month:02d}" for month in range(1, 13)],
    day=[f"{day:02d}" for day in range(1, 32)],
    area=area,
)


def get_cordex_years(
    year_start,
    year_stop,
    start_years=(1971, 1981, 1991, 2001),
    end_years=(1980, 1990, 2000, 2005),
):
    """Select the year blocks that overlap the requested period.

    The requested period is ``year_start - 1`` to ``year_stop`` inclusive;
    the extra leading year makes the preceding December available.

    Parameters
    ----------
    year_start, year_stop : int
        First and last year of the analysis period.
    start_years, end_years : sequence of int
        First and last year of each available block, paired element-wise.
        Immutable tuples are used as defaults so that the defaults cannot be
        altered between calls.

    Returns
    -------
    tuple of (list of int, list of int)
        Start years and end years of the overlapping blocks, in input order.
    """
    requested = set(range(year_start - 1, year_stop + 1))  # Include D(year-1)
    blocks = [
        (start, end)
        for start, end in zip(start_years, end_years)
        if not requested.isdisjoint(range(start, end + 1))
    ]
    return [start for start, _ in blocks], [end for _, end in blocks]


# Pick, for the chosen collection: the list of models, the request key that
# identifies a model, and the (dataset name, list of requests) pair.
if collection_id == "CORDEX":
    # CORDEX is not supported yet: this raise makes the remainder of the
    # branch unreachable (the CORDEX request above is still marked TODO).
    raise NotImplementedError(f"{collection_id=}")
    models = models_cordex
    model_key = "rcm_model"
    request_sim = (
        "projections-cordex-domains-single-levels",
        [
            {
                **request_cordex,
                "start_year": start_year,
                "end_year": end_year,
            }
            # One request per year block overlapping the analysis period
            for start_year, end_year in zip(*get_cordex_years(year_start, year_stop))
        ],
    )
elif collection_id == "CMIP6":
    models = models_cmip6
    model_key = "model"
    request_sim = (
        "projections-cmip6",
        # Split the base request according to `chunks`
        download.split_request(request_cmip6, chunks=chunks),
    )
else:
    raise ValueError(f"{collection_id=}")

## Functions to cache

In [None]:
def select_timeseries(ds, timeseries, year_start, year_stop):
    """Restrict ``ds`` to the requested years and, if seasonal, to one season.

    For ``timeseries == "annual"`` the full years ``year_start`` to
    ``year_stop`` are kept. Otherwise the window runs from December of
    ``year_start - 1`` to November of ``year_stop``, and only time steps whose
    season label equals ``timeseries`` are retained.
    """
    if timeseries != "annual":
        # December-to-November window so that a DJF season is not split
        window = slice(f"{year_start-1}-12", f"{year_stop}-11")
        subset = ds.sel(time=window)
        in_season = subset["time"].dt.season == timeseries
        return subset.where(in_season, drop=True)
    return ds.sel(time=slice(str(year_start), str(year_stop)))


def compute_indices(ds, index_names, timeseries, tmpdir):
    """Compute the requested icclim indices for ``ds``.

    ``ds`` is written to ``tmpdir`` as one NetCDF file per calendar year;
    icclim is then run on those files once per entry of ``index_names``
    (also writing ``{tmpdir}/{index_name}.nc``). The results are merged into
    a single dataset with the ``bounds`` dimension dropped.
    """
    # One single-chunk dataset and one output path per calendar year
    paths = []
    yearly_datasets = []
    for year, ds_year in ds.groupby("time.year"):
        paths.append(f"{tmpdir}/{year}.nc")
        yearly_datasets.append(ds_year.chunk(-1))
    xr.save_mfdataset(yearly_datasets, paths)

    # icclim aggregates per year for annual series, per season otherwise
    slice_mode = "year" if timeseries == "annual" else timeseries
    index_datasets = []
    for index_name in index_names:
        index_datasets.append(
            icclim.index(
                index_name=index_name,
                in_files=paths,
                out_file=f"{tmpdir}/{index_name}.nc",
                slice_mode=slice_mode,
            )
        )
    return xr.merge(index_datasets).drop_dims("bounds")


def compute_trends(ds):
    """Run a Mann-Kendall trend test on every data variable of ``ds``.

    Each variable has its time mean removed before being tested
    (``alpha=0.05``, ``method="theilslopes"``). The per-variable results are
    given the original latitude/longitude coordinates and concatenated along
    a new ``index`` dimension labelled with the variable names.
    """
    coords_name = {"time": "time", "y": "latitude", "x": "longitude"}

    def _trend_for(name, da):
        # Test the anomaly series and evaluate the result eagerly
        result = Mann_Kendall_test(
            da - da.mean("time"),
            alpha=0.05,
            method="theilslopes",
            coords_name=coords_name,
        ).compute()
        # Map the generic output dims (e.g. "y"/"x") back to the input names
        renames = {
            old: new for old, new in coords_name.items() if old in result.dims
        }
        result = result.rename(renames)
        result = result.assign_coords({dim: da[dim] for dim in result.dims})
        return result.expand_dims(index=[name])

    return xr.concat(
        [_trend_for(name, da) for name, da in ds.data_vars.items()],
        "index",
    )


def compute_indices_and_trends(
    ds,
    index_names,
    timeseries,
    year_start,
    year_stop,
    resample,
    **regrid_kwargs,
):
    """Compute time-mean indices and their trends, optionally regridded.

    Steps: select the requested period/season, optionally resample to daily
    maxima (``resample``), compute the indices, compute their trends, merge
    the time-mean indices with the trends, and, when ``regrid_kwargs`` is
    non-empty, pass the result to ``diagnostics.regrid`` with those keyword
    arguments. For ``method="conservative"`` the latitude/longitude bounds of
    the input are captured up front and merged back in before regridding.
    """
    # Grab cell bounds from the untouched input; only needed for conservative
    # regridding. NOTE(review): relies on the `.cf` accessor being registered.
    bounds = []
    if regrid_kwargs.get("method") == "conservative":
        for coord in ("latitude", "longitude"):
            if coord in ds.cf.bounds:
                bounds.append(ds.cf.get_bounds(coord).reset_coords(drop=True))

    ds = select_timeseries(ds, timeseries, year_start, year_stop)
    if resample:
        ds = ds.resample(time="1D").max(keep_attrs=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # persist() so that the indices are computed while the temporary
        # files still exist
        ds_indices = compute_indices(ds, index_names, timeseries, tmpdir).persist()
        ds_trends = compute_trends(ds_indices)
        ds_out = ds_indices.mean("time", keep_attrs=True).merge(ds_trends)
        if not regrid_kwargs:
            return ds_out
        ds_with_bounds = ds_out.merge({da.name: da for da in bounds})
        return diagnostics.regrid(ds_with_bounds, **regrid_kwargs)

## Download and transform ERA5

In [None]:
# Keyword arguments shared by every call to compute_indices_and_trends
transform_func_kwargs = dict(
    index_names=sorted(index_names),
    timeseries=timeseries,
    year_start=year_start,
    year_stop=year_stop,
)
# ERA5 is requested at 24 times per day, so it is resampled to daily maxima
ds_era5 = download.download_and_transform(
    request_era[0],
    request_era[1],
    chunks=chunks,
    transform_chunks=False,
    transform_func=compute_indices_and_trends,
    transform_func_kwargs={**transform_func_kwargs, "resample": True},
)

## Download and transform model

In [None]:
# Target grid for interpolation: ERA5 latitude/longitude only, no attributes
grid_out = ds_era5[["latitude", "longitude"]].reset_coords(drop=True)
grid_out.attrs = {}

model_datasets = {}
interpolated_datasets = []
for model in models:
    print(f"{model=}")
    requests = request_sim[1]

    # Result on the model's original grid (no regrid keyword arguments)
    model_datasets[model] = download.download_and_transform(
        request_sim[0],
        [{**request, model_key: model} for request in requests],
        chunks=chunks,
        transform_chunks=False,
        transform_func=compute_indices_and_trends,
        transform_func_kwargs={**transform_func_kwargs, "resample": False},
    )

    # Same computation, regridded onto the ERA5 grid
    ds = download.download_and_transform(
        request_sim[0],
        [{**request, model_key: model} for request in requests],
        chunks=chunks,
        transform_chunks=False,
        transform_func=compute_indices_and_trends,
        transform_func_kwargs={
            **transform_func_kwargs,
            "resample": False,
            "grid_out": grid_out,
            "method": interpolation_method,
        },
    )
    interpolated_datasets.append(ds.expand_dims(model=[model]))

# Stack the interpolated results along a new "model" dimension
ds_models = xr.concat(interpolated_datasets, "model")

## Define plotting function

In [None]:
def plot_maps(ds_era5, ds_models, index, trend, model, model_datasets=None, **kwargs):
    """Plot a 2x2 panel of maps comparing ERA5 with one model or the ensemble.

    Parameters
    ----------
    ds_era5 : xr.Dataset
        ERA5 results with an ``index`` dimension.
    ds_models : xr.Dataset
        Model results on the ERA5 grid with ``model`` and ``index`` dimensions.
    index : str
        Index to select along the ``index`` dimension; also the name of the
        variable plotted when ``trend`` is false.
    trend : bool
        If true, plot the ``trend`` variable scaled by 10 (labelled
        "days / decade") and overlay hatching based on the ``p`` variable.
    model : str
        ``"ensemble"`` (case-insensitive) for ensemble statistics, otherwise
        a key of ``model_datasets`` and a label of the ``model`` dimension.
    model_datasets : dict, optional
        Per-model datasets on their original grid. Only read when ``model``
        is not the ensemble.
    **kwargs
        Forwarded to ``plot.projected_map`` for the first two panels; the
        remaining panels are drawn with ``robust=True``.

    Returns
    -------
    fig, axes
        The matplotlib figure and the 2x2 array of axes.

    Raises
    ------
    KeyError
        If ``model`` is not the ensemble and is missing from ``model_datasets``.
    """
    # None instead of a shared mutable default; an empty dict keeps the
    # original KeyError for unknown models.
    if model_datasets is None:
        model_datasets = {}
    is_ensemble = model.lower() == "ensemble"

    # Hide interpolated edges
    isel_dict = {coord: slice(5, -5) for coord in ("latitude", "longitude")}
    ds_era5 = ds_era5.isel(isel_dict).sel(index=index)
    ds_models = ds_models.isel(isel_dict).sel(index=index)

    if is_ensemble:
        median = ds_models.median("model", keep_attrs=True)
        with xr.set_options(keep_attrs=True):
            bias = median - ds_era5
        std = ds_models.std("model", keep_attrs=True)
        datasets = {
            "ERA5": ds_era5,
            "Ensemble Median": median,
            "Ensemble Median Bias": bias,
            "Ensemble Standard Deviation": std,
        }
    else:
        # The model itself is shown on its original grid; the bias uses the
        # version interpolated onto the ERA5 grid.
        ds_model = model_datasets[model].sel(index=index)
        with xr.set_options(keep_attrs=True):
            bias = ds_models.sel(model=model) - ds_era5
        datasets = {
            "ERA5": ds_era5,
            model: ds_model,
            f"{model} Bias": bias,
        }

    # Initialize figure
    fig, axes = plt.subplots(
        *(2, 2),
        subplot_kw={"projection": ccrs.PlateCarree()},
        figsize=(14, 7),
    )
    for i, (ax, ds) in enumerate(zip(axes.flatten(), datasets.values())):
        da = ds["trend" if trend else index]
        if trend:
            # Scale out of place: an in-place `*=` can write through to the
            # caller's data when the selection above is a view of it, which
            # would compound the factor on every call.
            with xr.set_options(keep_attrs=True):
                da = da * 10
            da.attrs["units"] = "days / decade"
        # Explicit colour limits only for the first two panels
        plot_kwargs = kwargs if i <= 1 else {"robust": True}
        plot.projected_map(da, show_stats=False, ax=ax, **plot_kwargs)

    if trend:
        hatches_kwargs = {
            "plot_func": "contourf",
            "show_stats": False,
            "cmap": "none",
            "add_colorbar": False,
        }
        # ERA5 panel: hatch where p lies in the (0.05, 1] contour band
        plot.projected_map(
            ds_era5["p"],
            ax=axes[0, 0],
            levels=[0, 0.05, 1],
            hatches=["", "/" * 3],
            **hatches_kwargs,
        )
        if is_ensemble:
            n_models = ds_models.sizes["model"]
            # Fraction of models with p <= 0.05, masked where every model is NaN
            robust_ratio = (ds_models["p"] <= 0.05).sum("model") / n_models
            robust_ratio = robust_ratio.where(ds_models["p"].notnull().any("model"))
            # Fraction of models sharing the most common trend sign
            sign_ratio = (
                xr.concat(
                    [
                        (ds_models["trend"] > 0).sum("model"),
                        (ds_models["trend"] < 0).sum("model"),
                    ],
                    "sign",
                ).max("sign")
                / n_models
            )
            robust_threshold = 0.66
            sign_ratio = sign_ratio.where(robust_ratio > robust_threshold)
            # Hatch the band below each threshold on the ensemble-median panel
            for da, threshold, character in zip(
                [robust_ratio, sign_ratio], [robust_threshold, 0.8], ["\\", "/"]
            ):
                plot.projected_map(
                    da,
                    ax=axes[0, 1],
                    levels=[0, threshold, 1],
                    hatches=[character * 3, ""],
                    **hatches_kwargs,
                )
        else:
            plot.projected_map(
                ds_model["p"],
                ax=axes[0, 1],
                levels=[0, 0.05, 1],
                hatches=["", "/" * 3],
                **hatches_kwargs,
            )
    for ax, title in zip(axes.flatten(), datasets.keys()):
        ax.set_title(title)
    fig.suptitle(f"Trend of {index}" if trend else index)
    # NOTE(review): acts on pyplot's current axes, presumably the last
    # subplot (unused in the single-model layout) -- confirm.
    plt.axis("off")
    return fig, axes

## Apply land-sea mask (keep land points)

In [None]:
# Apply the land-sea mask: where() keeps cells whose mask value is non-zero
# and sets the remaining cells to NaN.
lsm = download.download_and_transform(*request_lsm)["lsm"].squeeze(drop=True)
ds_era5, ds_models = ds_era5.where(lsm), ds_models.where(lsm)
# Native-grid model results need the mask interpolated onto their own grid
model_datasets = {
    name: native.where(diagnostics.regrid(lsm, native, method="bilinear"))
    for name, native in model_datasets.items()
}

## Plot maps

In [None]:
# Colorbar limits per index: absolute values for the index maps, and a lower
# bound plus zero centre for the trend maps.
index_kwargs = dict(
    SU=dict(vmin=10, vmax=80),
    TX90p=dict(vmin=8, vmax=9),
)
trend_kwargs = dict(
    SU=dict(vmin=-8, center=0),
    TX90p=dict(vmin=-6, center=0),
)
for index in index_names:
    for trend in (False, True):
        kwargs = {
            **(trend_kwargs if trend else index_kwargs)[index],
            "extend": "both",
        }
        # Ensemble summary first ...
        fig, axes = plot_maps(ds_era5, ds_models, index, trend, "ensemble", **kwargs)
        plt.show()
        # ... then one figure per individual model
        for model in model_datasets:
            fig, axes = plot_maps(
                ds_era5, ds_models, index, trend, model, model_datasets, **kwargs
            )
            plt.show()

## Boxplot

In [None]:
for index in index_names:
    # Models: spatially averaged trend per model, scaled by 10 (per decade)
    da = diagnostics.spatial_weighted_mean(ds_models["trend"].sel(index=index)) * 10
    df_slope = da.to_dataframe()[["trend"]]
    ax = df_slope.boxplot()
    # Overlay the individual model values on the box at x = 1
    ax.scatter(
        x=[1] * len(df_slope), y=df_slope, color="grey", marker=".", label="models"
    )

    # ERA5: same averaging and scaling, drawn as a single point at x = 2
    da = diagnostics.spatial_weighted_mean(ds_era5["trend"].sel(index=index)) * 10
    ax.scatter(x=2, y=da.values, color="orange", marker="o", label="ERA5")

    # Figure settings
    ax.set_xticks([1, 2], [f"{collection_id} Ensemble", "ERA5"])
    ax.set_ylabel("days / decade")
    plt.suptitle(f"Trend of {index}")
    plt.legend()
    plt.show()