# Regional Bias

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import warnings
from warnings import warn

import cf_xarray as cfxr
import cordex as cx
import dask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import regionmask
import xarray as xr
import xesmf as xe
from dask.distributed import Client
from evaltools import obs
from evaltools.eval import regional_means
from evaltools.obs import eobs_mapping
from evaltools.source import get_source_collection, open_and_sort
from evaltools.utils import short_iid

# Use dask's single-threaded scheduler as the default for computations.
# NOTE(review): a dask.distributed ``Client`` is created in the next cell;
# confirm which of the two schedulers is meant to run the heavy work.
dask.config.set(scheduler="single-threaded")

In [None]:
# Create a dask.distributed client with one thread per worker; the dashboard
# is requested on localhost:8000. The bare ``client`` expression displays the
# client summary in the notebook.
client = Client(dashboard_address="localhost:8000", threads_per_worker=1)
client

In [None]:
# Names of the (longitude, latitude) vertex variables, keyed by MIP era.
vertices = dict(
    CMIP6=("vertices_lon", "vertices_lat"),
    CMIP5=("lon_vertices", "lat_vertices"),
)


def add_bounds(ds, mip_era="CMIP6"):
    """Add ``lon_b``/``lat_b`` corner coordinates when the dataset has no bounds.

    The dataset is returned unchanged as soon as cf_xarray reports bounds for
    either "longitude" or "latitude". Otherwise the bounds are computed with
    ``cx.transform_bounds`` using the vertex names of ``mip_era`` and converted
    to counterclockwise vertices.

    Parameters:
        ds (xarray.Dataset): The input dataset.
        mip_era (str): Key into ``vertices`` ("CMIP6" or "CMIP5").

    Returns:
        xarray.Dataset: The dataset, possibly with ``lon_b`` and ``lat_b`` added.
    """
    known_bounds = ds.cf.bounds
    if "longitude" in known_bounds or "latitude" in known_bounds:
        return ds

    lon_name, lat_name = vertices[mip_era]
    ds = cx.transform_bounds(ds, trg_dims=vertices[mip_era])
    corners = {
        coord: cfxr.bounds_to_vertices(
            ds[vertex_name], bounds_dim="vertices", order="counterclockwise"
        )
        for coord, vertex_name in (("lon_b", lon_name), ("lat_b", lat_name))
    }
    return ds.assign_coords(**corners)


def mask_with_sftlf(ds, sftlf=None):
    """Mask all data variables of ``ds`` where the land fraction is not positive.

    The dataset is modified in place (callers may ignore the return value):
    every data variable except ``sftlf`` is set to missing where ``sftlf > 0``
    is false, and a boolean ``mask`` variable holding ``sftlf > 0`` is added.

    Parameters:
        ds (xarray.Dataset): Dataset to mask.
        sftlf (xarray.DataArray, optional): Land fraction to use. When omitted,
            ``ds["sftlf"]`` is used if present.

    Returns:
        xarray.Dataset: The same dataset object. If no land fraction is
        available a warning is emitted and the dataset is returned unmasked.
    """
    if sftlf is None and "sftlf" in ds:
        sftlf = ds["sftlf"]

    if sftlf is None:
        # Name the dataset by the first identifying attribute that exists;
        # fall back to a placeholder instead of failing on an empty list.
        source = next(
            (
                ds.attrs[attr]
                for attr in ("source_id", "model_id", "source")
                if attr in ds.attrs
            ),
            "unknown",
        )
        warn(f"sftlf not found in dataset: {source}")
        return ds

    land = sftlf > 0
    # Snapshot the names first because variables are reassigned in the loop.
    for var in list(ds.data_vars):
        if var != "sftlf":
            ds[var] = ds[var].where(land)
    ds["mask"] = land
    return ds


def open_datasets(
    variables,
    frequency="mon",
    driving_source_id="ERA5",
    mask=True,
    add_missing_bounds=False,
    **kargs,
):
    """Open the catalogued datasets and optionally mask them and add bounds.

    Parameters:
        variables (list): Variable ids passed to ``get_source_collection``.
        frequency (str): Frequency passed to ``get_source_collection``.
        driving_source_id (str): Driving source passed to
            ``get_source_collection``; also selects the MIP era used for
            bounds ("ERA5" -> CMIP6, "ECMWF-ERAINT" -> CMIP5).
        mask (bool): If True, apply ``mask_with_sftlf`` to every dataset.
        add_missing_bounds (bool): If True, apply ``add_bounds``.
        **kargs: ``merge`` (forwarded to ``open_and_sort``) and ``time_range``
            (forwarded to ``open_and_sort``, default None).

    Returns:
        dict: Mapping of dataset id to dataset as returned by ``open_and_sort``.
    """
    catalog = get_source_collection(
        variables, frequency, driving_source_id, add_fx=["areacella", "sftlf"]
    )
    # Prefer the value the caller passed; only fall back to the notebook-level
    # ``merge`` variable (the previous behaviour) when none was given.
    try:
        merge_attrs = kargs["merge"]
    except KeyError:
        merge_attrs = merge
    dsets = open_and_sort(
        catalog,
        merge=merge_attrs,
        concat=False,
        time_range=kargs.get("time_range", None),
    )
    if mask is True:
        # mask_with_sftlf modifies each dataset in place.
        for ds in dsets.values():
            mask_with_sftlf(ds)
    if add_missing_bounds is True:
        for dset_id, ds in dsets.items():
            if driving_source_id == "ERA5":
                dsets[dset_id] = add_bounds(ds, mip_era="CMIP6")
            elif driving_source_id == "ECMWF-ERAINT":
                dsets[dset_id] = add_bounds(ds, mip_era="CMIP5")
    return dsets


def create_cordex_grid(domain_id, mip_era="CMIP6"):
    """Build a CORDEX domain grid with ``lon_b``/``lat_b`` corner coordinates.

    Parameters:
        domain_id (str): Domain id passed to ``cx.domain``.
        mip_era (str): MIP era passed to ``cx.domain`` and used to look up the
            vertex variable names in ``vertices``.

    Returns:
        xarray.Dataset: The domain grid with counterclockwise corner vertices.
    """
    grid = cx.domain(domain_id, bounds=True, mip_era=mip_era)
    lon_name, lat_name = vertices[mip_era]
    to_vertices = dict(bounds_dim="vertices", order="counterclockwise")
    return grid.assign_coords(
        lon_b=cfxr.bounds_to_vertices(grid[lon_name], **to_vertices),
        lat_b=cfxr.bounds_to_vertices(grid[lat_name], **to_vertices),
    )


def create_regridder(source, target, method="bilinear"):
    """Return an ``xe.Regridder`` from ``source`` to ``target`` using ``method``."""
    return xe.Regridder(source, target, method=method)


def regrid(ds, regridder):
    """Regrid ``ds`` and re-apply the regridded ``mask`` to the data variables.

    Every data variable except ``mask`` and ``sftlf`` is kept only where the
    regridded ``mask`` is greater than zero.
    """
    untouched = ("mask", "sftlf")
    ds_regrid = regridder(ds)
    for name in ds.data_vars:
        if name in untouched:
            continue
        ds_regrid[name] = ds_regrid[name].where(ds_regrid["mask"] > 0.0)
    return ds_regrid


def regrid_dsets(dsets, target_grid, method="bilinear"):
    """Bring every dataset in ``dsets`` onto ``target_grid``.

    Datasets whose grid mapping is ``rotated_latitude_longitude`` only get
    their coordinates rewritten (``ds.cx.rewrite_coords``); all others are
    regridded with a regridder built for ``target_grid``.

    Parameters:
        dsets (dict): Mapping of dataset id to dataset; updated in place.
        target_grid (xarray.Dataset): Grid passed to ``create_regridder``.
        method (str): Regridding method passed to ``create_regridder``.

    Returns:
        dict: The same ``dsets`` mapping.
    """
    for dset_id, ds in dsets.items():
        try:
            mapping = ds.cf["grid_mapping"].grid_mapping_name
        except Exception:
            # Best effort: report which dataset is affected and move on.
            # The skipped dataset stays in ``dsets`` on its original grid.
            print(f"problems with grid_mapping definition in {dset_id}")
            continue
        if mapping == "rotated_latitude_longitude":
            dsets[dset_id] = ds.cx.rewrite_coords(coords="all")
        else:
            print(f"regridding {dset_id} with grid_mapping: {mapping}")
            regridder = create_regridder(ds, target_grid, method=method)
            print(regridder)
            dsets[dset_id] = regrid(ds, regridder)
    return dsets


def mask_invalid(ds, vars=None, threshold=0.1):
    """Mask grid points whose share of missing time steps reaches ``threshold``.

    For each selected variable the fraction of null values along ``time`` is
    computed, and the variable is kept only where that fraction is below
    ``threshold``. The dataset is modified in place.

    Parameters:
        ds (xarray.Dataset): Dataset with a ``time`` dimension.
        vars (str | list | None): Variable name(s) to process; all data
            variables when None.
        threshold (float): Upper limit (exclusive) for the null fraction.

    Returns:
        xarray.Dataset: The same dataset object.
    """
    if vars is None:
        # Default to every data variable (previously assigned to the wrong
        # name, which made the loop below iterate over None).
        vars = list(ds.data_vars)
    elif isinstance(vars, str):
        vars = [vars]
    for var in vars:
        nan_fraction = ds[var].isnull().sum(dim="time") / ds.time.size
        ds[var] = ds[var].where(nan_fraction < threshold)
    return ds

In [None]:
def convert_celsius_to_kelvin(ds, threshold=200):
    """
    Converts all temperature variables in an xarray Dataset from degrees Celsius to Kelvin
    based on the 'units' attribute or, for variables with a temperature
    'standard_name', on the magnitude of their values.

    Parameters:
        ds (xarray.Dataset): The input dataset.
        threshold (float): A heuristic threshold (default=200) to assume temperatures
                           below this value might be in Celsius.

    Returns:
        xarray.Dataset: A new dataset with converted temperature values. Converted
                        variables keep their attributes, with 'units' set to "K".
    """
    celsius_units = {"c", "°c", "celsius", "degc"}
    temperature_names = {
        "air_temperature",
        "sea_surface_temperature",
        "surface_temperature",
    }

    ds = ds.copy()  # Avoid modifying the original dataset

    # Snapshot the names first because variables are reassigned in the loop.
    for var in list(ds.data_vars):
        attrs = ds[var].attrs
        units = str(attrs.get("units", "")).lower()
        standard_name = str(attrs.get("standard_name", "")).lower()

        if units in celsius_units:
            # Units explicitly indicate Celsius
            in_celsius = True
        elif standard_name in temperature_names:
            # Units do not say Celsius: decide from the value magnitude.
            # Reducing through xarray avoids materialising the whole array.
            in_celsius = float(ds[var].max()) < threshold
        else:
            in_celsius = False

        if not in_celsius:
            continue

        converted = ds[var] + 273.15
        # Copy the attributes explicitly so metadata such as standard_name is
        # kept regardless of xarray's keep_attrs setting for arithmetic.
        converted.attrs = {**attrs, "units": "K"}
        ds[var] = converted
        print("Convert celsius to kelvin")

    return ds

In [None]:
def seasonal_mean(da):
    """Day-weighted seasonal averages from a time series of monthly means.

    Each month is weighted by its number of days divided by the total number
    of days of its season, so the weights of every season add up to one.

    based on: https://xarray.pydata.org/en/stable/examples/monthly-means.html
    """
    # Number of days in each month of the time axis
    days_per_month = da.time.dt.days_in_month

    # Normalise the month lengths by the total length of their season
    season_totals = days_per_month.groupby("time.season").sum()
    weights = days_per_month.groupby("time.season") / season_totals

    # Weighted sum per season; a season without any valid value stays missing
    weighted = da * weights
    return weighted.groupby("time.season").sum(dim="time", skipna=True, min_count=1)

In [None]:
def check_equal_period(ds, period):
    """Return True when the years present in ``ds.time`` are exactly the years
    from ``period.start`` to ``period.stop`` (both inclusive)."""
    first_year, last_year = int(period.start), int(period.stop)
    wanted = np.arange(first_year, last_year + 1)
    present = np.unique(ds.time.dt.year.values)
    return np.array_equal(present, wanted)

In [None]:
# Full list of dataset attribute names. Later cells derive ``default_attrs``
# from it by dropping the names listed in ``merge`` and pass the result to
# ``short_iid``.
default_attrs_ = [
    "project_id",
    "domain_id",
    "institution_id",
    "driving_source_id",
    "driving_experiment_id",
    "driving_variant_label",
    "source_id",
    "version_realization",
    "frequency",
    "variable_id",
    "version",
]

In [None]:
# Per-index configuration, keyed by the ``index`` parameter:
#   "variable": variable id that is opened and evaluated
#   "name":     y-axis label of the bias plots
#   "diff":     "abs" (model minus reference) or "rel" (percent of reference)
#   "range":    value range; its lower end positions the median text in the plots
#   "aggr":     aggregation passed to ``regional_means``
var_dic = {
    "tas": {
        "variable": "tas",
        "name": "Temperature BIAS [K]",
        "diff": "abs",
        "range": [-4, 4],
        "aggr": "mean",
    },
    "pr": {
        "variable": "pr",
        "name": "Precipitation BIAS [%]",
        "diff": "rel",
        "range": [-60, 180],
        "aggr": "mean",
    },
    "tas95": {
        "variable": "tas",
        "name": "Temperature 95%-P [K]",
        "diff": "abs",
        "range": [-2, 10],
        "aggr": "P95",
    },
    "pr95": {
        "variable": "pr",
        "name": "Precipitation 95%-P [%]",
        "diff": "rel",
        "range": [0, 400],
        "aggr": "P95",
    },
}

In [None]:
# Notebook parameters (papermill parameters cell)
# ``index`` must be a key of ``var_dic``; ``period`` holds the first and last
# year as strings and is used both for time selection and in output file names.
# NOTE(review): ``domain`` is not referenced by any later cell shown here; the
# target grid is created with a hard-coded "EUR-11".
index = "pr95"
frequency = "mon"
domain = "EUR-11"
regridding = "bilinear"
period = slice("1989", "2008")

In [None]:
# Figures go to the "plots" directory one level above the working directory.
plots_dir_unresolved = os.path.join(os.getcwd(), "..", "plots")
save_figure_path = os.path.abspath(plots_dir_unresolved)

In [None]:
# Variable id evaluated for the selected index (e.g. "pr" for "pr95").
variable = var_dic[index]["variable"]

In [None]:
# Per-model plotting table, read relative to the working directory. The plots
# below use its "model", "color" and "parent" columns.
eur_colors = pd.read_csv("eurocordex_models.csv")

In [None]:
# PRUDENCE regions as predefined in regionmask; used for the regional means.
regions = regionmask.defined_regions.prudence

In [None]:
# Common target grid for E-OBS and all simulations.
# NOTE(review): the domain is hard-coded here although a ``domain`` parameter exists.
rotated_grid = create_cordex_grid("EUR-11", mip_era="CMIP5")  # Author's note: the choice of CMIP5 or CMIP6 does not matter here

## E-OBS

In [None]:
# Load the E-OBS reference for the evaluated variable and period
eobs = obs.eobs(variable, add_mask=False)
eobs = eobs.sel(time=period)

# Name of the E-OBS variable that maps onto the evaluated variable
matching_eobs_vars = [
    eobs_name for eobs_name, mapped in eobs_mapping.items() if mapped == variable
]
eobs_var = matching_eobs_vars[0]

# Drop grid points where 10% or more of the time steps are missing
eobs = mask_invalid(eobs, vars=eobs_var, threshold=0.1)

In [None]:
# Regrid E-OBS onto the common rotated grid with the configured method.
# ``unmapped_to_nan=True`` is passed so that, per the xESMF documentation,
# target cells without source data become NaN — TODO confirm for the installed version.
regridder = xe.Regridder(eobs, rotated_grid, method=regridding, unmapped_to_nan=True)
eobs_on_rotated = regridder(eobs)

In [None]:
# Report when E-OBS does not cover exactly the requested years
eobs_covers_period = check_equal_period(eobs_on_rotated, period)
if not eobs_covers_period:
    print(f"Temporal coverage of dataset does not match with {period}")

In [None]:
# Seasonal means of the regridded E-OBS variable, loaded into memory
eobs_in_period = eobs_on_rotated[eobs_var].sel(time=period)
eobs_seasmean = seasonal_mean(eobs_in_period).compute()

## CMIP6

In [None]:
# Settings for the CMIP6 part: simulations driven by ERA5
mip_era, driving_source_id = "CMIP6", "ERA5"
# Attributes along which files are merged in xarray
merge = ["variable_id", "frequency"]
# Attributes that remain available for building short dataset ids
default_attrs = [attr for attr in default_attrs_ if attr not in merge]

In [None]:
# Open the simulations for the selected variable, masked by land fraction
dsets = open_datasets(
    [variable],
    frequency=frequency,
    driving_source_id=driving_source_id,
    mask=True,
    add_missing_bounds=False,
    merge=merge,
    time_range=period,
)

In [None]:
# Report every simulation that does not cover exactly the requested years
for dset, dset_data in dsets.items():
    if check_equal_period(dset_data, period):
        continue
    print(f"Temporal coverage of {dset} does not match with {period}")

In [None]:
# Make sure temperature variables are in Kelvin (values are replaced in place)
for dset, dset_data in dsets.items():
    dsets[dset] = convert_celsius_to_kelvin(dset_data)

In [None]:
# Bring all simulations onto the common rotated grid (see ``regrid_dsets``).
dsets = regrid_dsets(dsets, rotated_grid, method=regridding)

In [None]:
# Seasonal bias of every simulation against E-OBS.
bias_kind = var_dic[index]["diff"]
if bias_kind == "abs":
    # Absolute difference; 273.15 is added to the E-OBS seasonal means first
    eobs_reference = eobs_seasmean + 273.15
    diffs = {
        dset_id: seasonal_mean(ds[[variable]].sel(time=period)).compute()
        - eobs_reference
        for dset_id, ds in dsets.items()
        if variable in ds.variables
    }
elif bias_kind == "rel":
    # Relative difference in percent; model values are multiplied by 86400
    diffs = {
        dset_id: 100
        * (
            seasonal_mean(ds[[variable]].sel(time=period)).compute() * 86400
            - eobs_seasmean
        )
        / eobs_seasmean
        for dset_id, ds in dsets.items()
        if variable in ds.variables
    }

# Stack the per-simulation biases along a new "dset_id" dimension labelled
# with the short dataset ids
short_ids = [
    short_iid(dset_id, ["source_id"], default_attrs=default_attrs)
    for dset_id in diffs
]
seasonal_bias = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(short_ids, dims="dset_id"),
    compat="override",
    coords="minimal",
)

In [None]:
# Aggregate the seasonal bias over the regions and store the result on disk
region_means_file = (
    f"{mip_era}-CORDEX_{index}_region_means_{period.start}-{period.stop}.nc"
)
dset_id_regions = regional_means(seasonal_bias, regions, aggr=var_dic[index]["aggr"])
dset_id_regions.to_netcdf(region_means_file)

## CMIP5

In [None]:
# Settings for the CMIP5 part: simulations driven by ECMWF-ERAINT
mip_era, driving_source_id = "CMIP5", "ECMWF-ERAINT"
# Attributes along which files are merged in xarray
merge = ["variable_id", "frequency", "driving_variant_label", "version"]
# Attributes that remain available for building short dataset ids
default_attrs = [attr for attr in default_attrs_ if attr not in merge]

In [None]:
# Open the simulations for the selected variable, masked by land fraction
dsets = open_datasets(
    [variable],
    frequency=frequency,
    driving_source_id=driving_source_id,
    mask=True,
    add_missing_bounds=False,
    merge=merge,
    time_range=period,
)

In [None]:
# Report every simulation that does not cover exactly the requested years
for dset, dset_data in dsets.items():
    if check_equal_period(dset_data, period):
        continue
    print(f"Temporal coverage of {dset} does not match with {period}")

In [None]:
# Make sure temperature variables are in Kelvin (values are replaced in place)
for dset, dset_data in dsets.items():
    dsets[dset] = convert_celsius_to_kelvin(dset_data)

In [None]:
# Bring all simulations onto the common rotated grid (see ``regrid_dsets``).
dsets = regrid_dsets(dsets, rotated_grid, method=regridding)

In [None]:
# Seasonal bias of every simulation against E-OBS.
bias_kind = var_dic[index]["diff"]
if bias_kind == "abs":
    # Absolute difference; 273.15 is added to the E-OBS seasonal means first
    eobs_reference = eobs_seasmean + 273.15
    diffs = {
        dset_id: seasonal_mean(ds[[variable]].sel(time=period)).compute()
        - eobs_reference
        for dset_id, ds in dsets.items()
        if variable in ds.variables
    }
elif bias_kind == "rel":
    # Relative difference in percent; model values are multiplied by 86400
    diffs = {
        dset_id: 100
        * (
            seasonal_mean(ds[[variable]].sel(time=period)).compute() * 86400
            - eobs_seasmean
        )
        / eobs_seasmean
        for dset_id, ds in dsets.items()
        if variable in ds.variables
    }

# Stack the per-simulation biases along a new "dset_id" dimension labelled
# with the short dataset ids
short_ids = [
    short_iid(dset_id, ["source_id"], default_attrs=default_attrs)
    for dset_id in diffs
]
seasonal_bias = xr.concat(
    list(diffs.values()),
    dim=xr.DataArray(short_ids, dims="dset_id"),
    compat="override",
    coords="minimal",
)

In [None]:
# Ad-hoc regional mean for a single simulation: the second entry along
# ``dset_id`` is averaged over X/Y with the 3D region mask as weights
# (``weights`` itself is all ones).
# NOTE(review): ``result`` is not used by any later cell shown here, and
# ``isel(dset_id=1)`` needs at least two simulations — confirm this cell is
# still required.
ds = seasonal_bias.isel(dset_id=1)
weights = xr.ones_like(ds.lon)
mask = regions.mask_3D(ds.lon, ds.lat, drop=False)
result = ds.cf.weighted(mask * weights).mean(dim=("X", "Y"))

In [None]:
# Aggregate the seasonal bias over the regions and store the result on disk
region_means_file = (
    f"{mip_era}-CORDEX_{index}_region_means_{period.start}-{period.stop}.nc"
)
dset_id_regions = regional_means(seasonal_bias, regions, aggr=var_dic[index]["aggr"])
dset_id_regions.to_netcdf(region_means_file)

#### Load results for both CMIP5 and CMIP6 simulations

In [None]:
# Season labels in plotting order; the position in this list is the x
# coordinate of a season in the figures below.
seasons = ["DJF", "MAM", "JJA", "SON"]

In [None]:
# Read the stored regional means for both ensembles.
# ``load_dataset`` loads the data into memory and closes the file again, so no
# handle is left open on files that the cells above rewrite when re-run.
dset_id_regions_CMIP6 = xr.load_dataset(
    f"CMIP6-CORDEX_{index}_region_means_{period.start}-{period.stop}.nc"
)
dset_id_regions_CMIP5 = xr.load_dataset(
    f"CMIP5-CORDEX_{index}_region_means_{period.start}-{period.stop}.nc"
)

In [None]:
# Flatten the regional means into long-format tables with the former index
# levels as ordinary columns. The plots below read the "abbrevs", "season"
# and "dset_id" columns and the column named after ``variable``.
df_CMIP6 = dset_id_regions_CMIP6.to_dataframe().reset_index()
df_CMIP5 = dset_id_regions_CMIP5.to_dataframe().reset_index()

In [None]:
# Value range of the regional biases, CMIP5 first and CMIP6 second
for ensemble_table in (df_CMIP5, df_CMIP6):
    ensemble_bias = ensemble_table[variable]
    print(np.min(ensemble_bias), np.max(ensemble_bias))

In [None]:
# Seasonal regional-mean bias per simulation for four regions (2x2 panels).
# Filled markers are the CMIP6 table, hollow markers the CMIP5 table; a line
# links each CMIP6 entry to its "parent" from ``eur_colors`` when that parent
# is present in the CMIP5 table.
regs = ["EA", "IP", "ME", "SC"]

# The filter has to wrap the plotting code: a ``with`` body that only contains
# ``simplefilter`` restores the previous filters as soon as it exits and
# therefore suppresses nothing.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    axes = axes.flatten()

    handles = []
    labels = []

    for i, region in enumerate(regs):
        ax = axes[i]

        # Copy the selections so the helper columns below are added to
        # independent frames instead of slices of the full tables.
        df_CMIP6_region = df_CMIP6[df_CMIP6["abbrevs"] == region].copy()
        df_CMIP5_region = df_CMIP5[df_CMIP5["abbrevs"] == region].copy()

        # x position: index of the season, CMIP6 shifted right, CMIP5 left
        df_CMIP6_region["season_num"] = df_CMIP6_region["season"].apply(
            seasons.index
        )
        df_CMIP5_region["season_num"] = df_CMIP5_region["season"].apply(
            seasons.index
        )
        df_CMIP6_region["season_shifted"] = df_CMIP6_region["season_num"] + 0.1
        df_CMIP5_region["season_shifted"] = df_CMIP5_region["season_num"] - 0.1

        # Absolute bias values per season, used for the medians below
        cmip6_biases = {season: [] for season in seasons}
        cmip5_biases = {season: [] for season in seasons}

        for _, row in df_CMIP6_region.iterrows():
            dset_id = row["dset_id"]
            color = eur_colors["color"][eur_colors["model"] == dset_id].values[0]
            scatter = ax.scatter(
                row["season_shifted"],
                row[variable],
                color=color,
                edgecolors=color,
                marker="o",
                s=80,
            )

            cmip6_biases[row["season"]].append(abs(row[variable]))

            # One legend entry per simulation
            if dset_id not in labels:
                handles.append(scatter)
                labels.append(dset_id)

            # Connect the CMIP6 point to the same season of its CMIP5 parent
            parent = eur_colors["parent"][eur_colors["model"] == dset_id].values[0]
            if not pd.isnull(parent):
                row_cmip5 = df_CMIP5_region[df_CMIP5_region["dset_id"] == parent]
                if not row_cmip5.empty:
                    row_cmip5 = row_cmip5[row_cmip5["season"] == row.season].iloc[0]
                    ax.plot(
                        [row_cmip5["season_shifted"], row["season_shifted"]],
                        [row_cmip5[variable], row[variable]],
                        color=color,
                        linestyle="-",
                        zorder=0,
                    )

        for _, row in df_CMIP5_region.iterrows():
            dset_id = row["dset_id"]
            color = eur_colors["color"][eur_colors["model"] == dset_id].values[0]
            scatter = ax.scatter(
                row["season_shifted"],
                row[variable],
                color=color,
                edgecolors=color,
                facecolor="none",
                marker="o",
                s=80,
            )

            cmip5_biases[row["season"]].append(abs(row[variable]))

            if dset_id not in labels:
                handles.append(scatter)
                labels.append(dset_id)

        # Region label in the top-left corner of the panel
        ax.text(
            0.05,
            0.95,
            region,
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            horizontalalignment="left",
            color="black",
            weight="bold",
        )

        ax.set_xticks([0, 1, 2, 3])  # One tick per season (unshifted position)
        ax.set_xticklabels(seasons)  # Set the names of the seasons as labels

        ax.grid(True)
        ax.axhline(0, color="black", linestyle="--")

        if variable == "pr":
            # Shaded band between 0 and 25
            ax.fill_between([-0.5, 3.5], 0, 25, color="#cceeff", alpha=0.5)

        # Median of the absolute biases per season, printed as "CMIP5  CMIP6"
        for j, season in enumerate(seasons):
            cmip6_median = (
                np.nanmedian(cmip6_biases[season]) if cmip6_biases[season] else np.nan
            )
            cmip5_median = (
                np.nanmedian(cmip5_biases[season]) if cmip5_biases[season] else np.nan
            )
            ax.text(
                j,
                var_dic[index]["range"][0] + 0.5,
                f"{cmip5_median:.1f}  {cmip6_median:.1f}",
                fontsize=10,
                verticalalignment="top",
                horizontalalignment="center",
                color="black",
            )

    # Only the left-hand panels carry the y-axis label
    axes[0].set_ylabel(var_dic[index]["name"])
    axes[2].set_ylabel(var_dic[index]["name"])

    fig.legend(
        handles,
        labels,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.05),
        ncol=5,
        fontsize=10,
    )

    # Fixed y-limits from var_dic are intentionally left disabled:
    # plt.ylim([var_dic[index]['range'][0],
    #          var_dic[index]['range'][1]])
    plt.tight_layout()
    plt.show()
    fig.savefig(
        f"{save_figure_path}/CMIP6-CMIP5_regionsA_bias_{index}_{period.start}-{period.stop}.png",
        bbox_inches="tight",
        transparent=True,
        pad_inches=0,
    )

In [None]:
# Seasonal regional-mean bias per simulation for four regions (2x2 panels).
# Filled markers are the CMIP6 table, hollow markers the CMIP5 table; a line
# links each CMIP6 entry to its "parent" from ``eur_colors`` when that parent
# is present in the CMIP5 table.
regs = ["AL", "BI", "FR", "MD"]

# The filter has to wrap the plotting code: a ``with`` body that only contains
# ``simplefilter`` restores the previous filters as soon as it exits and
# therefore suppresses nothing.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    axes = axes.flatten()

    handles = []
    labels = []

    for i, region in enumerate(regs):
        ax = axes[i]

        # Copy the selections so the helper columns below are added to
        # independent frames instead of slices of the full tables.
        df_CMIP6_region = df_CMIP6[df_CMIP6["abbrevs"] == region].copy()
        df_CMIP5_region = df_CMIP5[df_CMIP5["abbrevs"] == region].copy()

        # x position: index of the season, CMIP6 shifted right, CMIP5 left
        df_CMIP6_region["season_num"] = df_CMIP6_region["season"].apply(
            seasons.index
        )
        df_CMIP5_region["season_num"] = df_CMIP5_region["season"].apply(
            seasons.index
        )
        df_CMIP6_region["season_shifted"] = df_CMIP6_region["season_num"] + 0.1
        df_CMIP5_region["season_shifted"] = df_CMIP5_region["season_num"] - 0.1

        # Absolute bias values per season, used for the medians below
        cmip6_biases = {season: [] for season in seasons}
        cmip5_biases = {season: [] for season in seasons}

        for _, row in df_CMIP6_region.iterrows():
            dset_id = row["dset_id"]
            color = eur_colors["color"][eur_colors["model"] == dset_id].values[0]
            scatter = ax.scatter(
                row["season_shifted"],
                row[variable],
                color=color,
                edgecolors=color,
                marker="o",
                s=80,
            )

            cmip6_biases[row["season"]].append(abs(row[variable]))

            # One legend entry per simulation
            if dset_id not in labels:
                handles.append(scatter)
                labels.append(dset_id)

            # Connect the CMIP6 point to the same season of its CMIP5 parent
            parent = eur_colors["parent"][eur_colors["model"] == dset_id].values[0]
            if not pd.isnull(parent):
                row_cmip5 = df_CMIP5_region[df_CMIP5_region["dset_id"] == parent]
                if not row_cmip5.empty:
                    row_cmip5 = row_cmip5[row_cmip5["season"] == row.season].iloc[0]
                    ax.plot(
                        [row_cmip5["season_shifted"], row["season_shifted"]],
                        [row_cmip5[variable], row[variable]],
                        color=color,
                        linestyle="-",
                        zorder=0,
                    )

        for _, row in df_CMIP5_region.iterrows():
            dset_id = row["dset_id"]
            color = eur_colors["color"][eur_colors["model"] == dset_id].values[0]
            scatter = ax.scatter(
                row["season_shifted"],
                row[variable],
                color=color,
                edgecolors=color,
                facecolor="none",
                marker="o",
                s=80,
            )

            cmip5_biases[row["season"]].append(abs(row[variable]))

            if dset_id not in labels:
                handles.append(scatter)
                labels.append(dset_id)

        # Region label in the top-left corner of the panel
        ax.text(
            0.05,
            0.95,
            region,
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            horizontalalignment="left",
            color="black",
            weight="bold",
        )

        ax.set_xticks([0, 1, 2, 3])  # One tick per season (unshifted position)
        ax.set_xticklabels(seasons)  # Set the names of the seasons as labels

        ax.grid(True)
        ax.axhline(0, color="black", linestyle="--")

        if variable == "pr":
            # Shaded band between 0 and 25
            ax.fill_between([-0.5, 3.5], 0, 25, color="#cceeff", alpha=0.5)

        # Median of the absolute biases per season, printed as "CMIP5  CMIP6"
        for j, season in enumerate(seasons):
            cmip6_median = (
                np.nanmedian(cmip6_biases[season]) if cmip6_biases[season] else np.nan
            )
            cmip5_median = (
                np.nanmedian(cmip5_biases[season]) if cmip5_biases[season] else np.nan
            )
            ax.text(
                j,
                var_dic[index]["range"][0] + 0.5,
                f"{cmip5_median:.1f}  {cmip6_median:.1f}",
                fontsize=10,
                verticalalignment="top",
                horizontalalignment="center",
                color="black",
            )

    # Only the left-hand panels carry the y-axis label
    axes[0].set_ylabel(var_dic[index]["name"])
    axes[2].set_ylabel(var_dic[index]["name"])

    fig.legend(
        handles,
        labels,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.05),
        ncol=5,
        fontsize=10,
    )

    # Fixed y-limits from var_dic are intentionally left disabled:
    # plt.ylim([var_dic[index]['range'][0],
    #          var_dic[index]['range'][1]])
    plt.tight_layout()
    plt.show()
    fig.savefig(
        f"{save_figure_path}/CMIP6-CMIP5_regionsB_bias_{index}_{period.start}-{period.stop}.png",
        bbox_inches="tight",
        transparent=True,
        pad_inches=0,
    )