## WIP: Extreme temperature indices

## Import packages

In [None]:
import tempfile

import cartopy.crs as ccrs
import icclim
import matplotlib.pyplot as plt
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot
from xarrayMannKendall import Mann_Kendall_test

# Notebook-sized matplotlib defaults for every figure below.
plt.style.use(style="seaborn-v0_8-notebook")

## Define parameters

In [None]:
# Time period (inclusive). The requests below also include year_start - 1 so
# that December of the previous year is available for the first DJF season.
year_start = 1971
year_stop = 1975
if year_start > year_stop:
    raise ValueError(f"{year_start=} must not be greater than {year_stop=}")

# Choose annual or seasonal timeseries.
# Explicit checks are used instead of `assert`, which is stripped under `python -O`.
timeseries = "JJA"
if timeseries not in ("annual", "DJF", "MAM", "JJA", "SON"):
    raise ValueError(f"{timeseries=}")

# Choose CORDEX or CMIP6
collection_id = "CMIP6"
if collection_id not in ("CORDEX", "CMIP6"):
    raise ValueError(f"{collection_id=}")

# Define region for analysis (sent as "area" in the CDS requests; the values
# suggest the [north, west, south, east] ordering in degrees — confirm)
area = [72, -22, 27, 45]

# Define region for request (CORDEX domain name)
cordex_domain = "europe"

# Define index names (icclim index identifiers)
index_names = ("SU", "TX90p")

# Interpolation method used when regridding models onto the ERA5 grid
interpolation_method = "bilinear"

# Chunks for download
chunks = {"year": 1}

## Define models

In [None]:
# CORDEX regional climate model identifiers (sent as "rcm_model" in requests).
models_cordex = [
    "clmcom_eth_cosmo_crclim",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mpi_csc_remo2009",
    "uhoh_wrf361h",
]

# CMIP6 model identifiers (sent as "model" in requests). All identifiers use
# the lowercase, underscore-separated form of the CDS projections-cmip6 form.
models_cmip6 = [
    "ec_earth3_cc",
    "mpi_esm1_2_lr",
    "access_cm2",
    "awi_esm_1_1_lr",
    "cnrm_cm6_1",
]

## Define ERA5 request

In [None]:
# ERA5 hourly 2 m temperature over the analysis area. The year list starts at
# year_start - 1 so December of the previous year is available for DJF.
request_era = (
    "reanalysis-era5-single-levels",
    dict(
        product_type="reanalysis",
        format="netcdf",
        time=[f"{h:02d}:00" for h in range(24)],
        variable="2m_temperature",
        year=[str(y) for y in range(year_start - 1, year_stop + 1)],
        month=[f"{m:02d}" for m in range(1, 13)],
        day=[f"{d:02d}" for d in range(1, 32)],
        area=area,
    ),
)

# Land-sea mask: the ERA5 request above, restricted to January 1940 and the
# "land_sea_mask" variable (all days/hours of that month are still requested).
request_lsm = (
    request_era[0],
    {**request_era[1], "year": "1940", "month": "01", "variable": "land_sea_mask"},
)

## Define model requests

In [None]:
###################################################
# TODO: finalize the CORDEX request (the CORDEX branch is not implemented yet).
# NOTE(review): "monthly_mean" looks inconsistent with daily indices such as
# SU/TX90p — confirm the intended temporal resolution.
request_cordex = dict(
    format="zip",
    domain=cordex_domain,
    experiment="historical",
    horizontal_resolution="0_11_degree_x_0_11_degree",
    temporal_resolution="monthly_mean",
    variable="daily_maximum_near_surface_air_temperature",
    gcm_model="mpi_m_mpi_esm_lr",
    ensemble_member="r1i1p1",
    area=area,
)
###################################################

# Daily maximum near-surface temperature from CMIP6 historical runs. The year
# list starts at year_start - 1 so December of the previous year is included.
request_cmip6 = dict(
    format="zip",
    temporal_resolution="daily",
    experiment="historical",
    variable="daily_maximum_near_surface_air_temperature",
    year=list(map(str, range(year_start - 1, year_stop + 1))),
    month=[f"{m:02d}" for m in range(1, 13)],
    day=[f"{d:02d}" for d in range(1, 32)],
    area=area,
)


def get_cordex_years(
    year_start,
    year_stop,
    start_years=(1971, 1981, 1991, 2001),
    end_years=(1980, 1990, 2000, 2005),
):
    """Return the (start, end) periods that overlap the requested years.

    The requested years are ``year_start - 1`` through ``year_stop``
    (inclusive); the extra leading year provides December for DJF.
    ``start_years``/``end_years`` are paired element-wise into inclusive
    periods (presumably the CORDEX file periods on the CDS — confirm).

    Defaults are tuples rather than lists so no mutable object is shared
    between calls.

    Returns:
        A pair of lists ``(start_year, end_year)`` holding the bounds of
        every period that shares at least one year with the request, in
        input order.
    """
    years = set(range(year_start - 1, year_stop + 1))  # Include D(year-1)
    overlapping = [
        (start, end)
        for start, end in zip(start_years, end_years)
        if not years.isdisjoint(range(start, end + 1))
    ]
    start_year = [start for start, _ in overlapping]
    end_year = [end for _, end in overlapping]
    return start_year, end_year


# Pick the model list, the request key that selects a model, and the
# (CDS dataset name, list of requests) pair for the chosen collection.
if collection_id == "CMIP6":
    models = models_cmip6
    model_key = "model"
    # Presumably one request per chunk of years (see ``chunks``) — confirm.
    cmip6_requests = download.split_request(request_cmip6, chunks=chunks)
    request_sim = ("projections-cmip6", cmip6_requests)
elif collection_id == "CORDEX":
    # WIP: CORDEX is not supported yet; the statements after the raise are a
    # draft of the intended setup and are never executed.
    raise NotImplementedError(f"{collection_id=}")
    models = models_cordex
    model_key = "rcm_model"
    cordex_periods = zip(*get_cordex_years(year_start, year_stop))
    request_sim = (
        "projections-cordex-domains-single-levels",
        [
            {**request_cordex, "start_year": first, "end_year": last}
            for first, last in cordex_periods
        ],
    )
else:
    raise ValueError(f"{collection_id=}")

## Functions to cache

In [None]:
def select_timeseries(ds, timeseries, year_start, year_stop):
    """Select the time steps used for the analysis.

    Args:
        ds: xarray Dataset with a ``time`` coordinate.
        timeseries: ``"annual"`` or a season name (``"DJF"``, ``"MAM"``,
            ``"JJA"``, ``"SON"``).
        year_start, year_stop: Inclusive range of years.

    Returns:
        For ``"annual"``, all time steps from ``year_start`` through
        ``year_stop``. Otherwise, only the time steps of the requested season
        between December ``year_start - 1`` and November ``year_stop``, so
        every DJF season is complete.
    """
    if timeseries == "annual":
        return ds.sel(time=slice(str(year_start), str(year_stop)))
    ds = ds.sel(time=slice(f"{year_start-1}-12", f"{year_stop}-11"))
    # Boolean indexing keeps the matching time steps without masking: dtypes
    # are preserved (no NaN promotion) and variables lacking a time dimension
    # are not broadcast along time, unlike ``Dataset.where(..., drop=True)``.
    return ds.sel(time=ds["time"].dt.season == timeseries)


def compute_indexes(ds, index_names, timeseries, dir):
    """Compute icclim indices from ``ds``, one output variable per index.

    ``ds`` is first written to one netCDF file per year in ``dir`` (each
    yearly dataset rechunked into a single chunk); those files are the input
    to ``icclim.index``, which also writes ``{dir}/{index_name}.nc``.

    Args:
        ds: xarray Dataset with a ``time`` coordinate.
        index_names: Iterable of icclim index names (e.g. ``"SU"``).
        timeseries: ``"annual"`` (mapped to icclim ``slice_mode="year"``) or
            a season name, passed through as ``slice_mode``.
        dir: Existing directory for the intermediate and output files.

    Returns:
        The merged icclim results, without the ``"bounds"`` dimension.

    Raises:
        ValueError: If ``ds`` has no time steps.
    """
    if ds.sizes.get("time", 0) == 0:
        raise ValueError("No time steps to compute indices from.")

    years, yearly = zip(*ds.groupby("time.year"))
    paths = [f"{dir}/{year}.nc" for year in years]
    xr.save_mfdataset([ds_year.chunk(-1) for ds_year in yearly], paths)

    slice_mode = "year" if timeseries == "annual" else timeseries
    indices = [
        icclim.index(
            index_name=index_name,
            in_files=paths,
            out_file=f"{dir}/{index_name}.nc",
            slice_mode=slice_mode,
        )
        for index_name in index_names
    ]
    # Drop the "bounds" dimension (presumably from icclim's time bounds);
    # errors="ignore" avoids a failure when the dimension is not present.
    return xr.merge(indices).drop_dims("bounds", errors="ignore")


def compute_trends(ds):
    """Run a Mann-Kendall / Theil-Sen trend test on every data variable.

    Each variable is converted to anomalies (its time mean is subtracted)
    and passed to ``Mann_Kendall_test`` with ``alpha=0.05`` and
    ``method="theilslopes"``. The test's ``y``/``x`` dimensions are renamed
    back to ``latitude``/``longitude`` and given the input's coordinates.

    Returns:
        The per-variable results concatenated along a new ``index``
        dimension labelled with the variable names.
    """
    coords_name = {"time": "time", "y": "latitude", "x": "longitude"}
    per_index = []
    for index, da in ds.data_vars.items():
        anomaly = da - da.mean("time")
        trend = Mann_Kendall_test(
            anomaly,
            alpha=0.05,
            method="theilslopes",
            coords_name=coords_name,
        ).compute()
        renames = {old: new for old, new in coords_name.items() if old in trend.dims}
        trend = trend.rename(renames)
        trend = trend.assign_coords({dim: da[dim] for dim in trend.dims})
        per_index.append(trend.expand_dims(index=[index]))
    return xr.concat(per_index, "index")


def compute_index_trends(
    ds, index_names, timeseries, year_start, year_stop, resample, **interp_kwargs
):
    """Select the period, compute the icclim indices and add their trends.

    Args:
        ds: xarray Dataset with a ``time`` coordinate.
        index_names: icclim index names to compute.
        timeseries: ``"annual"`` or a season name.
        year_start, year_stop: Inclusive range of years.
        resample: If true, reduce the input to daily maxima first (used for
            the hourly ERA5 data).
        **interp_kwargs: If given, forwarded to ``diagnostics.regrid``.

    Returns:
        The indices merged with their trend-test results.
    """
    ds = select_timeseries(ds, timeseries, year_start, year_stop)
    if resample:
        ds = ds.resample(time="1D").max(keep_attrs=True)
    if interp_kwargs:
        ds = diagnostics.regrid(ds, **interp_kwargs)
    # The intermediate files are only needed while the indices are computed
    # and persisted. The context manager removes them deterministically
    # instead of leaving it to the garbage-collector finalizer (which emits a
    # ResourceWarning).
    with tempfile.TemporaryDirectory() as tmpdir:
        ds = compute_indexes(ds, index_names, timeseries, tmpdir).persist()
        ds = ds.merge(compute_trends(ds))
    return ds

## Download and transform ERA5

In [None]:
# Keyword arguments for compute_index_trends shared by ERA5 and the models.
transform_func_kwargs = dict(
    index_names=sorted(index_names),
    timeseries=timeseries,
    year_start=year_start,
    year_stop=year_stop,
)
# ERA5 is requested hourly, so resample to daily maxima before the indices.
ds_era5 = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_chunks=False,
    transform_func=compute_index_trends,
    transform_func_kwargs={**transform_func_kwargs, "resample": True},
)

## Download and transform models

In [None]:
# Target grid for regridding the models: the ERA5 latitude/longitude.
grid_out = ds_era5[["latitude", "longitude"]].reset_coords(drop=True)
grid_out.attrs = {}

# Unpack once, into names that do not overwrite the ``collection_id``
# parameter ("CORDEX"/"CMIP6"); overwriting it broke re-running earlier cells.
sim_collection_id, sim_requests = request_sim

datasets = []
for model in models:
    print(f"{model=}")
    ds = download.download_and_transform(
        sim_collection_id,
        [request | {model_key: model} for request in sim_requests],
        chunks=chunks,
        transform_chunks=False,
        transform_func=compute_index_trends,
        # Model data is requested daily (no resampling) and regridded to ERA5.
        transform_func_kwargs=transform_func_kwargs
        | {
            "resample": False,
            "grid_out": grid_out,
            "method": interpolation_method,
        },
    )
    datasets.append(ds.expand_dims(model=[model]))
ds_models = xr.concat(datasets, "model")

In [None]:
for index in index_names:
    # Time-mean index for ERA5 and for each model.
    da_era5 = ds_era5[index].mean("time", keep_attrs=True)
    da_models = ds_models[index].mean("time", keep_attrs=True)

    # Ensemble statistics across models, and the median's bias vs ERA5.
    ensemble_median = da_models.median("model", keep_attrs=True)
    with xr.set_options(keep_attrs=True):
        median_bias = ensemble_median - da_era5
    ensemble_std = da_models.std("model", keep_attrs=True)

    panels = {
        "ERA5": da_era5,
        "Ensemble Median": ensemble_median,
        "Ensemble Median Bias": median_bias,
        "Ensemble Standard Deviation": ensemble_std,
    }
    n_rows = 2
    fig, axes = plt.subplots(
        n_rows,
        len(panels) // n_rows,
        subplot_kw={"projection": ccrs.PlateCarree()},
        figsize=(15, 15),
    )
    for ax, (title, da) in zip(axes.flatten(), panels.items()):
        # TODO: Hide boundaries (for now the 5 outermost grid points are trimmed)
        trimmed = da.isel(latitude=slice(5, -5), longitude=slice(5, -5))
        plot.projected_map(trimmed, show_stats=False, ax=ax)
        ax.set_title(title)
    fig.suptitle(index)