# Temporal Taylor Diagrams

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import dask
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regionmask
import xarray as xr
import xesmf as xe
from dask.distributed import Client
from evaltools import obs
from evaltools.obs import eobs_mapping
from evaltools.utils import short_iid
from tools import (
    TaylorDiagram,
    check_equal_period,
    create_cordex_grid,
    fix_360_longitudes,
    height_temperature_correction,
    load_obs,
    mask_invalid,
    open_datasets,
    regional_mean,
    regional_means,
    regrid_dsets,
    select_season,
    standardize_unit,
    var_dic,
    variable_mapping,
)

# Make dask's default scheduler the synchronous, single-threaded one.
dask.config.set(scheduler="single-threaded")

<dask.config.set at 0x7d0daf101390>

In [None]:
# Start a dask.distributed client with one thread per worker; the dashboard
# address is pinned to localhost:8000.
client = Client(dashboard_address="localhost:8000", threads_per_worker=1)
# Bare expression so the notebook renders the client summary as cell output.
client

In [4]:
def ensure_uniform_cftime(ds):
    """Return ``ds`` with every ``time`` value rebuilt as a single date class.

    When the ``time`` values mix several Python classes (for example two
    different cftime calendars), each value is re-created with the class of
    the first element from its year, month, day, hour, minute and second.
    Sub-second precision is therefore dropped on the rebuilt values.

    Parameters:
    ds: Object exposing ``ds.time.values`` and ``assign_coords`` (the callers
        in this notebook pass xarray objects).

    Returns:
    ``ds`` itself when the time values already share one class (or there are
    none); otherwise a new object with the homogenised ``time`` coordinate.
    The input is never modified in place, so the caller's data keeps its
    original time axis.
    """
    time_vals = ds.time.values

    # Nothing to do when zero or one distinct class is present.
    if len({type(t) for t in time_vals}) <= 1:
        return ds

    # Use the class of the first element to rebuild all the others.
    base_type = type(time_vals[0])
    uniform_time = [
        base_type(t.year, t.month, t.day, t.hour, t.minute, t.second)
        for t in time_vals
    ]
    # assign_coords returns a new object instead of rewriting the coordinate
    # of the object owned by the caller.
    return ds.assign_coords(time=("time", uniform_time))


def compute_tcoiav(model_ds, reference_ds):
    """
    Temporal Correlation of Interannual Variability (TCOIAV) of model vs. reference.

    Both inputs are reduced to one mean value per calendar year; if they still
    carry a ``lon`` coordinate they are additionally averaged over ``lat`` and
    ``lon``. The two yearly series are then correlated along ``year``.

    Values close to 1 mean the model reproduces the year-to-year variations of
    the reference; low or negative values mean it does not.

    Parameters:
    model_ds: xarray object with a ``time`` coordinate (model data).
    reference_ds: xarray object with a ``time`` coordinate (reference data).

    Returns:
    The result of ``xr.corr`` along ``year`` (an xarray object that keeps any
    remaining dimensions, not a plain float).
    """

    def _yearly_mean(data):
        # One mean value per calendar year.
        return data.groupby("time.year").mean("time")

    try:
        model_mean = _yearly_mean(model_ds)
        reference_mean = _yearly_mean(reference_ds)
    except (TypeError, AttributeError, ValueError):
        # Grouping failed: homogenise mixed cftime classes and try once more.
        model_ds = ensure_uniform_cftime(model_ds)
        reference_ds = ensure_uniform_cftime(reference_ds)
        model_mean = _yearly_mean(model_ds)
        reference_mean = _yearly_mean(reference_ds)

    # Gridded input: collapse the horizontal dimensions over the subregion.
    if "lon" in reference_mean.coords:
        horizontal = ["lat", "lon"]
        model_mean = model_mean.mean(dim=horizontal)
        reference_mean = reference_mean.mean(dim=horizontal)

    # Correlation of the two interannual series.
    return xr.corr(model_mean, reference_mean, dim="year")


# RIAV: ratio (model over reference) of temporal standard
# deviations of interannual time series of spatially
# averaged annual or seasonal mean values of a selected
# subregion.


def compute_riav(model_ds, reference_ds):
    """
    Ratio of Interannual Variability (RIAV) of model vs. reference.

    Both inputs are reduced to one mean value per calendar year; if they still
    carry a ``lon`` coordinate they are additionally averaged over ``lat`` and
    ``lon``. The standard deviation along ``year`` of the model series is then
    divided by that of the reference series.

    A value above 1 means the model varies more from year to year than the
    reference, a value below 1 means it varies less.

    Parameters:
    model_ds: xarray object with a ``time`` coordinate (model data).
    reference_ds: xarray object with a ``time`` coordinate (reference data).

    Returns:
    The ratio ``std(model) / std(reference)`` as an xarray object that keeps
    any remaining dimensions (not a plain float).
    """

    def _yearly_mean(data):
        # One mean value per calendar year.
        return data.groupby("time.year").mean("time")

    try:
        model_mean = _yearly_mean(model_ds)
        reference_mean = _yearly_mean(reference_ds)
    except (TypeError, AttributeError, ValueError):
        # Grouping failed: homogenise mixed cftime classes and try once more.
        model_ds = ensure_uniform_cftime(model_ds)
        reference_ds = ensure_uniform_cftime(reference_ds)
        model_mean = _yearly_mean(model_ds)
        reference_mean = _yearly_mean(reference_ds)

    # Gridded input: collapse the horizontal dimensions over the subregion.
    if "lon" in reference_mean.coords:
        horizontal = ["lat", "lon"]
        model_mean = model_mean.mean(dim=horizontal)
        reference_mean = reference_mean.mean(dim=horizontal)

    # Interannual spread of each series, then model over reference.
    return model_mean.std(dim="year") / reference_mean.std(dim="year")

In [5]:
# Papermill parameters: the defaults below are presumably overridden when the
# notebook is executed through papermill.
index = "pr"  # key into var_dic (selects the variable and its regional aggregation)
frequency = "mon"  # frequency passed to open_datasets
domain = "EUR-11"  # CORDEX domain passed to create_cordex_grid
regridding = "bilinear"  # regridding method handed to xESMF / regrid_dsets
# NOTE(review): "period_star" looks like a typo of "period_start"; the name is
# kept because external papermill invocations may reference it.
period_star = "1991"
period_stop = "2020"
reference_regions = "PRUDENCE"  # not read by the cells visible here
parent = True  # when True, CMIP5 simulations are drawn too and figures are tagged "parent"

In [6]:
# Label-based time window applied with `.sel(time=period)` throughout the notebook.
period = slice(period_star, period_stop)

In [7]:
# Output folders, resolved relative to the parent of the current working
# directory: netCDF results and final figures respectively.
save_results_path = os.path.abspath(
    os.path.join(os.getcwd(), "..", "intermediate-results")
)
save_figure_path = os.path.abspath(os.path.join(os.getcwd(), "..", "plots"))

In [8]:
# Variable name associated with the selected index in var_dic.
variable = var_dic[index]["variable"]

In [9]:
# Model table read from the working directory; the plotting cells use its
# "model", "model_version", "mip_era" and "color" columns.
eur_colors = pd.read_csv("eurocordex_models.csv")

In [10]:
# PRUDENCE regions as predefined by regionmask; used for the regional means.
# NOTE: the plotting cells at the end rebind `regions` to a list of
# abbreviations, so this cell must be re-run before recomputing regional means.
regions = regionmask.defined_regions.prudence

In [11]:
# Target grid for the reference and observational datasets. It is built from
# the `domain` parameter (same call as in the CMIP5/CMIP6 sections) so that all
# datasets share one grid when `domain` is overridden; the grid does not
# depend on whether the simulations are CMIP5 or CMIP6.
rotated_grid = create_cordex_grid(domain)

## E-OBS is used as the reference dataset for all the analysis
It is used to calculate biases not only with respect to CORDEX, but also in comparison with other reanalyses and observational datasets, to assess the uncertainty of the observational datasets.

In [None]:
# Load E-OBS, regrid it to the rotated grid and compute seasonal regional means.
# Reverse lookup: the E-OBS name whose mapped value equals `variable`
# (raises IndexError when no entry matches).
eobs_var = [key for key, value in eobs_mapping.items() if value == variable][0]
eobs = obs.eobs(variables=eobs_var, add_mask=False).sel(time=period)
eobs = mask_invalid(eobs, vars=eobs_var, threshold=0.1)
eobs = standardize_unit(eobs, variable)
# eobs = load_eobs(add_mask=False, to_cf=False, variable = variable)
# unmapped_to_nan, see https://github.com/pangeo-data/xESMF/issues/56
regridder = xe.Regridder(eobs, rotated_grid, method=regridding, unmapped_to_nan=True)
ref_on_rotated = regridder(eobs)
# A period mismatch is only reported; processing continues regardless.
if not check_equal_period(ref_on_rotated, period):
    print(f"Temporal coverage of dataset does not match with {period}")
ref_regions = regional_mean(
    ref_on_rotated[eobs_var], regions, aggr=var_dic[index]["aggr"]
)
# Reference series used by every TCOIAV/RIAV computation below.
ref_regions_seasons = select_season(ref_regions).compute()

## CERRA and ERA5

In [None]:
# Load every additional observational/reanalysis dataset listed for the variable.
dsets = {}
# NOTE(review): this indexes var_dic by `variable`, whereas the other cells use
# var_dic[index]; the two agree for index == "pr" — confirm for other indices.
for dset in var_dic[variable]["datasets"]:
    ds = load_obs(variable, dset, add_fx=True, mask=True)
    ds = ds.sel(time=period).compute()
    ds = fix_360_longitudes(ds, lonname="longitude")
    # Rename the dataset-specific variable name to the common one when they differ.
    if not variable_mapping[dset][variable] == variable:
        ds = ds.rename_vars({variable_mapping[dset][variable]: variable})
    ds = standardize_unit(ds, variable)
    dsets[dset] = ds

In [14]:
# Report (without stopping) every dataset whose time axis does not cover `period`.
for dset in dsets:
    if check_equal_period(dsets[dset], period):
        continue
    print(f"Temporal coverage of {dset} does not match with {period}")

In [15]:
# Regrid each observational dataset onto the rotated grid; a new regridder is
# built per dataset and the dictionary entry is replaced by the regridded data.
for dset, ds in dsets.items():
    regridder = xe.Regridder(ds, rotated_grid, method=regridding, unmapped_to_nan=True)
    dsets[dset] = regridder(ds)

In [16]:
# Temperature only: correct for the difference between each dataset's orography
# and the E-OBS elevation; cells without a correction value are left unchanged
# (fillna(0)).
if variable == "tas":
    for dset in dsets:
        h_c = height_temperature_correction(dsets[dset].orog, ref_on_rotated.elevation)
        dsets[dset]["tas"] = dsets[dset].tas - h_c.fillna(0)

In [17]:
# Regional aggregates of all observational datasets, loaded into memory.
obs_regions = regional_means(dsets, regions, aggr=var_dic[index]["aggr"]).compute()

In [18]:
# Seasonal selection of the observational regional series.
obs_regions_seasons = select_season(obs_regions).compute()

In [19]:
# Working copies of the observational and reference series.
# NOTE(review): neither name is read by the cells visible below — confirm
# whether this cell is still needed.
model_ds = obs_regions_seasons.copy()
reference_ds = ref_regions_seasons.copy()

In [20]:
# TCOIAV of every observational dataset against the E-OBS reference.
diffs = {}
for dset_id in map(str, obs_regions_seasons.iid.values):
    model_id = obs_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_tcoiav(model_id, ref_regions_seasons).compute()

# Stack the per-dataset results along a new "dset_id" dimension labelled with
# the dataset identifiers.
obs_tcoiav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(list(diffs), dims="dset_id"),
    compat="override",
    coords="minimal",
)

# Same procedure for RIAV.
diffs = {}
for dset_id in map(str, obs_regions_seasons.iid.values):
    model_id = obs_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_riav(model_id, ref_regions_seasons).compute()

obs_riav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(list(diffs), dims="dset_id"),
    compat="override",
    coords="minimal",
)

## CMIP6

In [21]:
# Settings for the CMIP6 section; `driving_source_id` is passed to open_datasets.
mip_era = "CMIP6"
driving_source_id = "ERA5"
# Define how to merge the files in xarray

In [None]:
# Open the simulations for `variable` at the requested frequency and driving
# source; the result is used as a mapping from dataset id to dataset.
dsets = open_datasets(
    [variable],
    frequency=frequency,
    driving_source_id=driving_source_id,
    mask=True,
    add_missing_bounds=False,
)

In [23]:
# Cut every simulation to the analysis period and load it into memory.
for dset in dsets:
    dsets[dset] = dsets[dset].sel(time=period).compute()

In [24]:
# Report (without stopping) every simulation whose time axis does not cover `period`.
for dset in dsets:
    if check_equal_period(dsets[dset], period):
        continue
    print(f"Temporal coverage of {dset} does not match with {period}")

Temporal coverage of CORDEX-CMIP6.EUR-12.ICTP.ERA5.evaluation.r1i1p1f1.RegCM5-0.v1-r1.mon.v20250415 does not match with slice('1991', '2020', None)


In [25]:
# Bring every simulation to the common unit of `variable`.
for dset in dsets:
    dsets[dset] = standardize_unit(dsets[dset], variable)

Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.
Convert precipitation from kg/m/s² to mm/day.


In [26]:
# Rebuild the target grid for `domain` and regrid all simulations onto it.
rotated_grid = create_cordex_grid(domain)
dsets = regrid_dsets(dsets, rotated_grid, method=regridding)

regridding CORDEX-CMIP6.EUR-12.RMIB-UGent.ERA5.evaluation.r1i1p1f1.ALARO1-SFX.v1-r1.mon.v20241009 with grid_mapping: lambert_conformal_conic
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_483x483_412x424.nc 
Reuse pre-computed weights? False 
Input grid shape:           (483, 483) 
Output grid shape:          (412, 424) 
Periodic in longitude?      False
regridding CORDEX-CMIP6.EUR-12.HCLIMcom-SMHI.ERA5.evaluation.r1i1p1f1.HCLIM43-ALADIN.v1-r1.mon.v20241205 with grid_mapping: lambert_conformal_conic
xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_453x453_412x424.nc 
Reuse pre-computed weights? False 
Input grid shape:           (453, 453) 
Output grid shape:          (412, 424) 
Periodic in longitude?      False
regridding CORDEX-CMIP6.EUR-12.CESAM-UA.ERA5.evaluation.r1i1p1f1.WRF451Q.v1-r2.mon.v20250630 with grid_mapping: rotated_latitude_longitude
xESMF Regridder 
Regridding algorithm:       bilin

In [27]:
# Temperature only: correct for the difference between each simulation's
# orography and the E-OBS elevation; cells without a correction value are left
# unchanged (fillna(0)).
if variable == "tas":
    for dset in dsets:
        h_c = height_temperature_correction(dsets[dset].orog, ref_on_rotated.elevation)
        dsets[dset]["tas"] = dsets[dset].tas - h_c.fillna(0)

In [28]:
# Regional aggregates of all CMIP6 simulations, loaded into memory.
dset_id_regions = regional_means(dsets, regions, aggr=var_dic[index]["aggr"]).compute()

In [29]:
# Seasonal selection of the CMIP6 regional series (no explicit .compute() here,
# unlike the CMIP5 and observational counterparts; the input is already computed).
dset_id_regions_seasons = select_season(dset_id_regions)

In [None]:
# TCOIAV of every CMIP6 simulation against the E-OBS reference.
diffs = {}
for dset_id in map(str, dset_id_regions_seasons.iid.values):
    print(dset_id)
    model_id = dset_id_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_tcoiav(model_id, ref_regions_seasons).compute()

# Stack along "dset_id", labelled with the shortened
# "<source_id>_<version_realization>" identifiers.
dset_id_tcoiav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(
        [
            short_iid(iid, ["source_id", "version_realization"], delimiter="_")
            for iid in diffs
        ],
        dims="dset_id",
    ),
    compat="override",
    coords="minimal",
)

In [None]:
# RIAV of every CMIP6 simulation against the E-OBS reference.
diffs = {}
for dset_id in map(str, dset_id_regions_seasons.iid.values):
    print(dset_id)
    model_id = dset_id_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_riav(model_id, ref_regions_seasons).compute()

# Stack along "dset_id", labelled with the shortened
# "<source_id>_<version_realization>" identifiers.
dset_id_riav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(
        [
            short_iid(iid, ["source_id", "version_realization"], delimiter="_")
            for iid in diffs
        ],
        dims="dset_id",
    ),
    compat="override",
    coords="minimal",
)

In [32]:
# Persist the CMIP6 metrics; they are read back in the "Load results" section.
# Assumes `save_results_path` already exists — TODO confirm.
dset_id_tcoiav.to_netcdf(
    f"{save_results_path}/{variable}_CMIP6_tcoiav_{period.start}-{period.stop}.nc"
)
dset_id_riav.to_netcdf(
    f"{save_results_path}/{variable}_CMIP6_riav_{period.start}-{period.stop}.nc"
)

## CMIP5

In [33]:
# Settings for the CMIP5 section; `driving_source_id` is passed to open_datasets.
mip_era = "CMIP5"
driving_source_id = "ERAINT"

In [None]:
# Open the simulations for `variable` at the requested frequency and driving
# source; the result is used as a mapping from dataset id to dataset.
dsets = open_datasets(
    [variable],
    frequency=frequency,
    driving_source_id=driving_source_id,
    mask=True,
    add_missing_bounds=False,
)

In [None]:
# Three models are opened at daily frequency and turned into monthly means
# below (presumably because they lack monthly files — confirm).
dsets_day = open_datasets(
    [variable],
    frequency='day',
    source_id = ['WRF381P', 'HIRHAM5', 'RegCM4-6'],
    driving_source_id=driving_source_id,
    mask=True,
    add_missing_bounds=False,
)
# Time-independent fields that are carried over unchanged when present.
static_vars = ["orog", "sftlf", "areacella", "mask"]
# Resample each daily dataset to monthly means ("ME" = month-end frequency alias).
dsets_mon = {}
for dset in dsets_day.keys():
    dsets_copy = dsets_day[dset].copy()
    # Only the analysed variable is averaged in time.
    ds_var_mon = dsets_copy[[variable]].resample(time="ME").mean()
    for var in static_vars:
        if var in dsets_copy.variables:
            ds_var_mon[var] = dsets_copy[var]
    # Store under the id a monthly dataset would have.
    dsets_mon[dset.replace('.day.', '.mon.')] = ds_var_mon
# Merge; on identical keys the resampled dataset replaces the opened one.
dsets = dsets | dsets_mon

In [36]:
# Cut every simulation to the analysis period and load it into memory.
for dset in dsets:
    dsets[dset] = dsets[dset].sel(time=period).compute()

In [None]:
# Report (without stopping) every simulation whose time axis does not cover `period`.
for dset in dsets:
    if check_equal_period(dsets[dset], period):
        continue
    print(f"Temporal coverage of {dset} does not match with {period}")

In [None]:
# Bring every simulation to the common unit of `variable`.
for dset in dsets:
    dsets[dset] = standardize_unit(dsets[dset], variable)

In [None]:
# Rebuild the target grid for `domain` and regrid all simulations onto it.
rotated_grid = create_cordex_grid(domain)
dsets = regrid_dsets(dsets, rotated_grid, method=regridding)

In [40]:
# Temperature only: correct for the difference between each simulation's
# orography and the E-OBS elevation; cells without a correction value are left
# unchanged (fillna(0)).
if variable == "tas":
    for dset in dsets:
        h_c = height_temperature_correction(dsets[dset].orog, ref_on_rotated.elevation)
        dsets[dset]["tas"] = dsets[dset].tas - h_c.fillna(0)

In [41]:
# Regional aggregates of all CMIP5 simulations, loaded into memory.
dset_id_regions = regional_means(dsets, regions, aggr=var_dic[index]["aggr"]).compute()

In [42]:
# Seasonal selection of the CMIP5 regional series.
dset_id_regions_seasons = select_season(dset_id_regions).compute()

In [None]:
# TCOIAV of every CMIP5 simulation against the E-OBS reference.
diffs = {}
for dset_id in map(str, dset_id_regions_seasons.iid.values):
    print(dset_id)
    model_id = dset_id_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_tcoiav(model_id, ref_regions_seasons).compute()

# Stack along "dset_id", labelled with the shortened
# "<source_id>_<version_realization>" identifiers.
dset_id_tcoiav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(
        [
            short_iid(iid, ["source_id", "version_realization"], delimiter="_")
            for iid in diffs
        ],
        dims="dset_id",
    ),
    compat="override",
    coords="minimal",
)

In [None]:
# RIAV of every CMIP5 simulation against the E-OBS reference.
diffs = {}
for dset_id in map(str, dset_id_regions_seasons.iid.values):
    print(dset_id)
    model_id = dset_id_regions_seasons[variable].sel(iid=dset_id)
    diffs[dset_id] = compute_riav(model_id, ref_regions_seasons).compute()

# Stack along "dset_id", labelled with the shortened
# "<source_id>_<version_realization>" identifiers.
dset_id_riav = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(
        [
            short_iid(iid, ["source_id", "version_realization"], delimiter="_")
            for iid in diffs
        ],
        dims="dset_id",
    ),
    compat="override",
    coords="minimal",
)

In [45]:
# Persist the CMIP5 metrics; they are read back in the "Load results" section.
# Assumes `save_results_path` already exists — TODO confirm.
dset_id_tcoiav.to_netcdf(
    f"{save_results_path}/{variable}_CMIP5_tcoiav_{period.start}-{period.stop}.nc"
)
dset_id_riav.to_netcdf(
    f"{save_results_path}/{variable}_CMIP5_riav_{period.start}-{period.stop}.nc"
)

#### Load results for both CMIP5 and CMIP6 simulations

In [53]:
# Re-open the stored CMIP6 metrics. open_dataset returns Datasets, which is why
# the plotting cells pick the first data variable of each.
dset_id_tcoiav_CMIP6 = xr.open_dataset(
    f"{save_results_path}/{variable}_CMIP6_tcoiav_{period.start}-{period.stop}.nc"
)
dset_id_riav_CMIP6 = xr.open_dataset(
    f"{save_results_path}/{variable}_CMIP6_riav_{period.start}-{period.stop}.nc"
)

In [54]:
# Tag every CMIP6 entry with a constant "mip_era" coordinate along "dset_id".
# The coordinate is built from the TCOIAV ids and reused for RIAV.
CMIP6_coord = xr.DataArray(
    np.full(len(dset_id_tcoiav_CMIP6["dset_id"]), "CMIP6"),
    dims="dset_id",
    coords={"dset_id": dset_id_tcoiav_CMIP6["dset_id"]},
    name="mip_era",
)
dset_id_tcoiav_CMIP6 = dset_id_tcoiav_CMIP6.assign_coords(mip_era=CMIP6_coord)
dset_id_riav_CMIP6 = dset_id_riav_CMIP6.assign_coords(mip_era=CMIP6_coord)

In [55]:
# Re-open the stored CMIP5 metrics. open_dataset returns Datasets, which is why
# the plotting cells pick the first data variable of each.
dset_id_tcoiav_CMIP5 = xr.open_dataset(
    f"{save_results_path}/{variable}_CMIP5_tcoiav_{period.start}-{period.stop}.nc"
)
dset_id_riav_CMIP5 = xr.open_dataset(
    f"{save_results_path}/{variable}_CMIP5_riav_{period.start}-{period.stop}.nc"
)

In [56]:
# Tag every CMIP5 entry with a constant "mip_era" coordinate along "dset_id".
# The coordinate is built from the TCOIAV ids and reused for RIAV.
CMIP5_coord = xr.DataArray(
    np.full(len(dset_id_tcoiav_CMIP5["dset_id"]), "CMIP5"),
    dims="dset_id",
    coords={"dset_id": dset_id_tcoiav_CMIP5["dset_id"]},
    name="mip_era",
)
dset_id_tcoiav_CMIP5 = dset_id_tcoiav_CMIP5.assign_coords(mip_era=CMIP5_coord)
dset_id_riav_CMIP5 = dset_id_riav_CMIP5.assign_coords(mip_era=CMIP5_coord)

In [57]:
# Matplotlib marker per season drawn in the Taylor diagrams.
seasons_marker = {"winter": "o", "summer": "^"}

In [58]:
# Default file-name tag; the plotting cells switch it to "parent" when `parent` is True.
parent_str = "no-parent"

In [59]:
# "<model>_<model_version>" label per row of the model table, as a NumPy array
# so it can be compared element-wise against a dset_id to build a row mask.
list_model_version = np.array((eur_colors["model"].astype(str) + "_" + eur_colors["model_version"].astype(str)).tolist())

In [None]:
from PIL import Image


def _region_season(data, region, season):
    """Return the part of ``data`` for one region abbreviation and one season."""
    return data.isel(
        region=np.where(data.abbrevs == region)[0],
        season=np.where(data.season == season)[0],
    ).squeeze()


def _add_simulations(dia, tcoiav, riav, region, season, mark, filled):
    """Add one marker per simulation of an ensemble to the Taylor diagram ``dia``.

    ``tcoiav`` and ``riav`` are the Datasets re-opened from netCDF; the first
    data variable of each holds the metric. ``filled`` selects a half
    transparent face colour instead of an empty marker.
    """
    rho = _region_season(tcoiav, region, season)
    std = _region_season(riav, region, season)
    rho_name = list(rho.data_vars)[0]
    std_name = list(std.data_vars)[0]
    for model in tcoiav.dset_id.data:
        # Colour of the table row whose "<model>_<model_version>" label matches.
        color = eur_colors["color"][list_model_version == model].values[0]
        dia.add_sample(
            std.sel(dset_id=model)[std_name].item(),
            rho.sel(dset_id=model)[rho_name].item(),
            marker=mark,
            ms=5,
            ls="",
            mfc=mcolors.to_rgba(color, 0.5) if filled else "none",
            mec=color,
            label=f"{model}_{season}",
        )


# Reference std
stdref = 1

# NOTE: this rebinds the notebook-level `regions` (previously the regionmask
# PRUDENCE object) to a plain list of abbreviations.
regions = ["EA", "IP", "ME", "SC"]

# The file-name tag depends only on `parent`, so it is set once up front.
if parent is True:
    parent_str = "parent"

for n_r, region in enumerate(regions):
    fig = plt.figure()

    dia = TaylorDiagram(stdref, fig=fig, label="Reference")
    # dia.samplePoints[0].set_color('r')  # Mark reference point as a red star

    for season, mark in seasons_marker.items():
        # cmip5 (empty markers), only when the parent ensemble is requested
        if parent is True:
            _add_simulations(
                dia,
                dset_id_tcoiav_CMIP5,
                dset_id_riav_CMIP5,
                region,
                season,
                mark,
                filled=False,
            )

        # cmip6 (half-transparent filled markers)
        _add_simulations(
            dia,
            dset_id_tcoiav_CMIP6,
            dset_id_riav_CMIP6,
            region,
            season,
            mark,
            filled=True,
        )

        # obs: larger empty markers, magenta for ERA5-based datasets, black otherwise
        rho = _region_season(obs_tcoiav, region, season)
        std = _region_season(obs_riav, region, season)
        for model in obs_tcoiav.dset_id.data:
            color = "magenta" if "era5" in model else "black"
            dia.add_sample(
                std.sel(dset_id=model),
                rho.sel(dset_id=model),
                marker=mark,
                ms=8,
                ls="",
                mfc="none",
                mec=color,
                label=f"{model}_{season}",
            )

    # Add correlation lines
    dia.add_correlation_lines()

    # Add RMS contours in grey (levels=2), and label them
    contours = dia.add_contours(levels=2, colors="0.5")
    plt.clabel(contours, inline=1, fontsize=10, fmt="%.1f")

    if n_r == 1:
        # Only the second panel (top-right of the 2x2 mosaic) carries the legend.
        fig.legend(
            dia.samplePoints,
            [p.get_label() for p in dia.samplePoints],
            numpoints=1,
            prop=dict(size=5),
            loc="upper right",
        )

    # Region abbreviation inside the panel.
    fig.text(0.25, 0.83, region, fontsize=12, fontweight="bold", va="top", ha="left")

    # Intermediate per-region panel, written to the current working directory.
    fig.savefig(
        f"taylor_{parent_str}_{region}_{period.start}-{period.stop}.png", dpi=300
    )
    plt.close(fig)

# Re-open the four panels and assemble them into a 2x2 mosaic.
imgs = [
    Image.open(f"taylor_{parent_str}_{r}_{period.start}-{period.stop}.png")
    for r in regions
]
imgs = [img.crop(img.getbbox()) for img in imgs]

# The mosaic cell size is taken from the first panel.
w, h = imgs[0].size

final_img = Image.new("RGB", (2 * w, 2 * h), "white")

final_img.paste(imgs[0], (0, 0))
final_img.paste(imgs[1], (w, 0))
final_img.paste(imgs[2], (0, h))
final_img.paste(imgs[3], (w, h))

final_img.save(
    f"{save_figure_path}/PRUDENCE_A_taylor_temporal_{parent_str}_{variable}_{period.start}-{period.stop}.png"
)

In [None]:
from PIL import Image


def _region_season(data, region, season):
    """Return the part of ``data`` for one region abbreviation and one season."""
    return data.isel(
        region=np.where(data.abbrevs == region)[0],
        season=np.where(data.season == season)[0],
    ).squeeze()


def _add_simulations(dia, tcoiav, riav, region, season, mark, filled):
    """Add one marker per simulation of an ensemble to the Taylor diagram ``dia``.

    ``tcoiav`` and ``riav`` are the Datasets re-opened from netCDF; the first
    data variable of each holds the metric. ``filled`` selects a half
    transparent face colour instead of an empty marker.
    """
    rho = _region_season(tcoiav, region, season)
    std = _region_season(riav, region, season)
    rho_name = list(rho.data_vars)[0]
    std_name = list(std.data_vars)[0]
    for model in tcoiav.dset_id.data:
        # Colour of the table row whose "<model>_<model_version>" label matches.
        color = eur_colors["color"][list_model_version == model].values[0]
        dia.add_sample(
            std.sel(dset_id=model)[std_name].item(),
            rho.sel(dset_id=model)[rho_name].item(),
            marker=mark,
            ms=5,
            ls="",
            mfc=mcolors.to_rgba(color, 0.5) if filled else "none",
            mec=color,
            label=f"{model}_{season}",
        )


# Reference std
stdref = 1

# NOTE: this rebinds the notebook-level `regions` (previously the regionmask
# PRUDENCE object) to a plain list of abbreviations.
regions = ["AL", "BI", "FR", "MD"]

# The file-name tag depends only on `parent`, so it is set once up front.
if parent is True:
    parent_str = "parent"

for n_r, region in enumerate(regions):
    fig = plt.figure()

    dia = TaylorDiagram(stdref, fig=fig, label="Reference")
    # dia.samplePoints[0].set_color('r')  # Mark reference point as a red star

    for season, mark in seasons_marker.items():
        # cmip5 (empty markers), only when the parent ensemble is requested
        if parent is True:
            _add_simulations(
                dia,
                dset_id_tcoiav_CMIP5,
                dset_id_riav_CMIP5,
                region,
                season,
                mark,
                filled=False,
            )

        # cmip6 (half-transparent filled markers)
        _add_simulations(
            dia,
            dset_id_tcoiav_CMIP6,
            dset_id_riav_CMIP6,
            region,
            season,
            mark,
            filled=True,
        )

        # obs: larger empty markers, magenta for ERA5-based datasets, black otherwise
        rho = _region_season(obs_tcoiav, region, season)
        std = _region_season(obs_riav, region, season)
        for model in obs_tcoiav.dset_id.data:
            color = "magenta" if "era5" in model else "black"
            dia.add_sample(
                std.sel(dset_id=model),
                rho.sel(dset_id=model),
                marker=mark,
                ms=8,
                ls="",
                mfc="none",
                mec=color,
                label=f"{model}_{season}",
            )

    # Add correlation lines
    dia.add_correlation_lines()

    # Add RMS contours in grey (levels=2), and label them
    contours = dia.add_contours(levels=2, colors="0.5")
    plt.clabel(contours, inline=1, fontsize=10, fmt="%.1f")

    if n_r == 1:
        # Only the second panel (top-right of the 2x2 mosaic) carries the legend.
        fig.legend(
            dia.samplePoints,
            [p.get_label() for p in dia.samplePoints],
            numpoints=1,
            prop=dict(size=3.5),
            loc="upper right",
        )

    # Region abbreviation inside the panel.
    fig.text(0.25, 0.83, region, fontsize=12, fontweight="bold", va="top", ha="left")

    # Intermediate per-region panel, written to the current working directory.
    fig.savefig(
        f"taylor_{parent_str}_{region}_{period.start}-{period.stop}.png", dpi=300
    )
    plt.close(fig)

# Re-open the four panels and assemble them into a 2x2 mosaic.
imgs = [
    Image.open(f"taylor_{parent_str}_{r}_{period.start}-{period.stop}.png")
    for r in regions
]
imgs = [img.crop(img.getbbox()) for img in imgs]

# The mosaic cell size is taken from the first panel.
w, h = imgs[0].size

final_img = Image.new("RGB", (2 * w, 2 * h), "white")

final_img.paste(imgs[0], (0, 0))
final_img.paste(imgs[1], (w, 0))
final_img.paste(imgs[2], (0, h))
final_img.paste(imgs[3], (w, h))

final_img.save(
    f"{save_figure_path}/PRUDENCE_B_taylor_temporal_{parent_str}_{variable}_{period.start}-{period.stop}.png"
)