# Precipitation and temperature bias

## Import packages

In [None]:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import scipy
import skill_metrics
import xarray as xr
from c3s_eqc_automatic_quality_control import diagnostics, download, plot, utils

# Apply matplotlib's "seaborn-v0_8-notebook" style sheet to the figures below
plt.style.use("seaborn-v0_8-notebook")

## Define Parameters

In [None]:
# Analysis period: first and last year (both inclusive)
year_start, year_stop = 1985, 1987

# Timeseries to analyse: annual means or one of the four seasons
timeseries = "annual"
assert timeseries in {"annual", "DJF", "MAM", "JJA", "SON"}

# Variable to analyse
variable = "temperature"
assert variable in {"precipitation", "temperature"}

# Model collection: CORDEX or CMIP6
collection_id = "cordex"
assert collection_id in {"cordex", "cmip6"}

# Region for analysis (longitude and latitude bounds)
lon_slice, lat_slice = slice(-4, 20), slice(35, 50)

# Region for request (CORDEX domain)
cordex_domain = "europe"

# Chunks for download
chunks = dict(year=1)
assert "month" not in chunks, "Do not use chunks smaller than 1y"

## Define models

In [None]:
# Models processed when collection_id is "cordex"; each entry is used as the
# value of the "rcm_model" request key.
models_cordex = [
    "clmcom_clm_cclm4_8_17",
    "clmcom_eth_cosmo_crclim",
    "cnrm_aladin63",
    "dmi_hirham5",
    "knmi_racmo22e",
    "mohc_hadrem3_ga7_05",
    "mpi_csc_remo2009",
    "smhi_rca4",
    "uhoh_wrf361h",
]

# Models processed when collection_id is "cmip6"; each entry is used as the
# value of the "model" request key.
models_cmip6 = [
    "access_cm2",
    "awi_esm_1_1_lr",
    "bcc_esm1",
    "cesm2_fv2",
    "cnrm_cm6_1",
    "fgoals_g3",
]

## Define ERA5 request

In [None]:
# ERA5 request name of each supported variable
era5_variables = dict(
    precipitation="mean_total_precipitation_rate",
    temperature="2m_temperature",
)

# (collection ID, request) pair for the ERA5 monthly means
request_era = (
    "reanalysis-era5-single-levels-monthly-means",
    dict(
        product_type="monthly_averaged_reanalysis",
        format="netcdf",
        time="00:00",
        variable=era5_variables[variable],
        # Start one year early to include D(year-1)
        year=list(map(str, range(year_start - 1, year_stop + 1))),
        month=[f"{month:02d}" for month in range(1, 13)],
    ),
)

## Define model requests

In [None]:
# CORDEX request name of each supported variable
cordex_variables = dict(
    precipitation="mean_precipitation_flux",
    temperature="2m_air_temperature",
)

# CORDEX request template (the model and the file periods are added later)
request_cordex = dict(
    format="zip",
    domain=cordex_domain,
    experiment="historical",
    horizontal_resolution="0_11_degree_x_0_11_degree",
    temporal_resolution="monthly_mean",
    variable=cordex_variables[variable],
    gcm_model="mpi_m_mpi_esm_lr",
    ensemble_member="r1i1p1",
)

# CMIP6 request name of each supported variable
cmip6_variables = dict(
    precipitation="precipitation",
    temperature="near_surface_air_temperature",
)

# CMIP6 request template; years and months are the same as in the ERA5 request
request_cmip6 = dict(
    format="zip",
    temporal_resolution="monthly",
    experiment="historical",
    variable=cmip6_variables[variable],
    **{key: request_era[1][key] for key in ("year", "month")},
)


def get_cordex_years(
    year_start,
    year_stop,
    start_years=(1971, 1981, 1991, 2001),
    end_years=(1980, 1990, 2000, 2005),
):
    """Select the file periods that overlap the requested years.

    The requested range is extended by one year at the start, so that the
    December preceding ``year_start`` is covered (D(year-1)).

    Parameters
    ----------
    year_start, year_stop : int
        First and last year of the analysis period (both inclusive).
    start_years, end_years : sequence of int
        First and last year (inclusive) of each available period. The two
        sequences are paired element by element.

    Returns
    -------
    tuple of two lists
        Start years and end years of the periods containing at least one of
        the years ``year_start - 1`` to ``year_stop``. Both lists are empty
        when no period overlaps.
    """
    # Include D(year-1)
    years = set(range(year_start - 1, year_stop + 1))
    periods = [
        (start, end)
        for start, end in zip(start_years, end_years)
        if not years.isdisjoint(range(start, end + 1))
    ]
    start_year = [start for start, _ in periods]
    end_year = [end for _, end in periods]
    return start_year, end_year


# Collection-specific settings: model list, request key holding the model name,
# whether spatial statistics/errors are weighted, and the `periodic` flag that
# is forwarded to the regridding.
if collection_id == "cmip6":
    models, model_key = models_cmip6, "model"
    weights = periodic = True  # Weight spatial statistics/errors
    request_sim = (
        "projections-cmip6",
        download.split_request(request_cmip6, chunks=chunks),
    )
elif collection_id == "cordex":
    models, model_key = models_cordex, "rcm_model"
    weights = periodic = False  # Do not weight spatial statistics/errors
    # One request per file period overlapping the analysis years
    request_sim = (
        "projections-cordex-domains-single-levels",
        [
            dict(request_cordex, start_year=first, end_year=last)
            for first, last in zip(*get_cordex_years(year_start, year_stop))
        ],
    )
else:
    raise ValueError

## Function to cache

In [None]:
def compute_regridded_timeseries(
    ds, year_start, year_stop, annual, grid_out=None, **kwargs
):
    """Reduce a dataset to an annual or seasonal timeseries, optionally regridded.

    Parameters
    ----------
    ds
        Dataset with a "time" dimension and exactly one of the variables
        "mtpr"/"pr" (precipitation) or "tas"/"t2m" (temperature).
    year_start, year_stop
        First and last year to keep (both inclusive).
    annual
        If true, time steps are selected by calendar year and reduced with
        ``diagnostics.annual_weighted_mean``. Otherwise each time step is
        labelled with the year of the following time step and each year is
        reduced with ``diagnostics.seasonal_weighted_mean``.
    grid_out
        Target grid passed to ``diagnostics.regrid``; no regridding when None.
    **kwargs
        Forwarded to ``diagnostics.regrid``.

    Returns
    -------
    Dataset holding a single variable named "precipitation" or "temperature",
    or an empty dataset when no time step falls in the requested years.

    Raises
    ------
    ValueError
        If no precipitation or temperature variable is found, or if more than
        one candidate variable is present (the single-element unpacking fails).
    """
    # Drop useless data
    if annual:
        mask = (ds["time"].dt.year >= year_start) & (ds["time"].dt.year <= year_stop)
    else:
        # Select years (shift -1 to get D(year-1)J(year)F(year))
        # NOTE(review): this labelling assumes monthly time steps - TODO confirm
        ds = ds.assign_coords(year=ds["time"].dt.year.shift(time=-1))
        mask = (ds["year"] >= year_start) & (ds["year"] <= year_stop)
    ds = ds.where(mask.compute(), drop=True)
    if not ds.sizes["time"]:
        # Return empty dataset. Previous year needed for DJF only.
        return xr.Dataset()

    # Select variable
    if var_names := set(ds.data_vars) & {"mtpr", "pr"}:
        field = "precipitation"
    elif var_names := set(ds.data_vars) & {"tas", "t2m"}:
        field = "temperature"
    else:
        raise ValueError("Unable to find precipitation or temperature variable.")
    # Exactly one matching variable is expected; unpacking fails otherwise
    (var_name,) = var_names
    da = ds[var_name].rename(field)

    # Create timeseries
    if annual:
        da = diagnostics.annual_weighted_mean(da)
    else:
        # Year labels are cast to int before grouping (presumably because the
        # shift above leaves them as floats - TODO confirm)
        da["year"] = da["year"].astype(int)
        da = da.groupby("year").map(diagnostics.seasonal_weighted_mean)

    # Regrid
    if grid_out is not None:
        da = diagnostics.regrid(da, grid_out, **kwargs)

    # Convert units
    # NOTE(review): presumably the inputs are in kg m-2 s-1 (precipitation) and
    # K (temperature) - confirm against the source datasets
    with xr.set_options(keep_attrs=True):
        if da.name == "precipitation":
            # Multiply by the number of seconds in a day
            da *= 3600 * 24
            da.attrs["units"] = "mm/day"
        elif da.name == "temperature":
            da -= 273.15
            da.attrs["units"] = "°C"
            # NOTE(review): when ds has no "height", None is passed as the
            # coordinate value - confirm the intended effect
            da = da.assign_coords(height=ds["height"] if "height" in ds else None)
        else:
            raise ValueError
    return da.to_dataset()

## Download data

In [None]:
# Annual timeseries are computed chunk by chunk; seasonal timeseries are not,
# as they need data from the previous year (D(year-1)).
annual = timeseries == "annual"
transform_chunks = annual
cached_open_mfdataset_kwargs = {"chunks": chunks}

# Reference dataset (ERA5); it also provides the target grid for the models
print("Downloading and processing ERA5")
ds_era = download.download_and_transform(
    *request_era,
    chunks=chunks,
    transform_chunks=transform_chunks,
    transform_func=compute_regridded_timeseries,
    transform_func_kwargs={
        "year_start": year_start,
        "year_stop": year_stop,
        "annual": annual,
    },
    cached_open_mfdataset_kwargs=cached_open_mfdataset_kwargs,
).expand_dims(model=["ERA5"])

# Simulations: one download per model, regridded to the ERA5 grid
datasets = []
for model in models:
    print(f"Downloading and processing {model}")
    # Build new per-model request dicts instead of writing the model into the
    # dicts of request_sim, which are shared by all iterations.
    request_model = (
        request_sim[0],
        [{**request, model_key: model} for request in request_sim[1]],
    )
    ds = download.download_and_transform(
        *request_model,
        transform_chunks=transform_chunks,
        transform_func=compute_regridded_timeseries,
        transform_func_kwargs={
            "grid_out": ds_era[["longitude", "latitude"]],
            "year_start": year_start,
            "year_stop": year_stop,
            "annual": annual,
            "method": "bilinear",
            "periodic": periodic,
        },
        cached_open_mfdataset_kwargs=cached_open_mfdataset_kwargs,
    )
    datasets.append(ds.expand_dims(model=[model]))
ds_sim = xr.concat(datasets, "model")

## Create a single dataset and compute bias

In [None]:
# Stack the simulations, their ensemble mean and ERA5 along "model"
ds_members = ds_sim.drop_vars("height", errors="ignore")
ds_ensemble = ds_sim.mean("model").expand_dims(model=["ensemble"])
ds_reference = ds_era.drop_vars("height", errors="ignore")
ds_timeseries = xr.concat([ds_members, ds_ensemble, ds_reference], "model")

# Keep the requested season only, then cut out the analysis region
if timeseries != "annual":
    ds_timeseries = ds_timeseries.sel(season=timeseries)
ds_timeseries = utils.regionalise(
    ds_timeseries,
    lon_slice=lon_slice,
    lat_slice=lat_slice,
)

# Time mean of every model and its difference from ERA5
with xr.set_options(keep_attrs=True):
    ds = ds_timeseries.mean("year").compute()
    ds_bias = ds.drop_sel(model="ERA5") - ds.sel(model="ERA5")
for name in ds_bias.data_vars:
    ds_bias[name].attrs["long_name"] += " Bias"

## Plot Maps

In [None]:
# Choose projection: Robinson when the region spans the whole globe,
# PlateCarree otherwise; either way centred on the middle longitude
lon_extent = abs(lon_slice.stop - lon_slice.start)
lat_extent = abs(lat_slice.stop - lat_slice.start)
if lon_extent >= 360 and lat_extent >= 180:
    Projection = ccrs.Robinson
else:
    Projection = ccrs.PlateCarree
projection = Projection(central_longitude=(lon_slice.stop + lon_slice.start) / 2)

# Colour settings shared by the ensemble and the ERA5 maps
da_to_plot = ds[variable].sel(model=["ensemble", "ERA5"])
plot_kwargs = {
    **xr.plot.utils._determine_cmap_params(
        da_to_plot.values,
        robust=True,
        levels=10,
        cmap="Reds" if variable != "precipitation" else "Blues",
    ),
    "projection": projection,
}

# One map per model; the ERA5 statistics are always weighted
for _, da_model in da_to_plot.groupby("model"):
    is_era5 = bool(da_model["model"] == "ERA5")
    plot.projected_map(da_model, stats_weights=is_era5 or weights, **plot_kwargs)
    plt.show()

## Plot Bias

In [None]:
# Colour settings shared by all bias maps
da_bias = ds_bias[variable]
plot_kwargs = {
    **xr.plot.utils._determine_cmap_params(
        da_bias.values,
        robust=True,
        levels=11,
        cmap="bwr" if variable != "precipitation" else "bwr_r",
    ),
    "projection": projection,
}

# Bias of the ensemble mean
plot.projected_map(
    da_bias.sel(model=["ensemble"]),
    stats_weights=weights,
    **plot_kwargs,
)
plt.show()

# Bias of the individual models, at most three panels per row
facet = plot.projected_map(
    da_bias.drop_sel(model="ensemble"),
    col="model",
    col_wrap=min(3, ds_bias.sizes["model"] - 1),
    **plot_kwargs,
)
map_extent = (lon_slice.start, lon_slice.stop, lat_slice.start, lat_slice.stop)
for axis in facet.axs.flatten():
    axis.set_extent(map_extent)

## Plot KDE and Statistics of Bias

In [None]:
# Spatial statistics of the bias
da = ds_bias[variable]
df_stats = diagnostics.spatial_weighted_statistics(da, weights=weights).to_pandas()

# Evaluate the densities within three standard deviations of the ensemble mean
ensemble_stats = df_stats["ensemble"]
centre, spread = ensemble_stats["mean"], ensemble_stats["std"]
x = np.linspace(centre - 3 * spread, centre + 3 * spread, 1_000)

# One KDE curve per model; the ensemble is drawn as a dashed black line
fig, ax = plt.subplots()
for label, values in da.groupby("model"):
    values = values.stack(dim=values.dims).dropna("dim")
    if weights:
        kde_weights = np.abs(np.cos(np.deg2rad(values["latitude"])))
    else:
        kde_weights = None
    y = scipy.stats.gaussian_kde(values, weights=kde_weights).evaluate(x)
    line_style = {"color": "k", "ls": "--"} if label == "ensemble" else {}
    ax.plot(x, y, label=label, **line_style)
ax.grid()
ax.set_xlim(x[[0, -1]])
ax.set_xlabel(f"{da.attrs['long_name']} [{da.attrs['units']}]")
ax.legend()

# Show the statistics in a table above the plot (one row per model)
df_table = df_stats.round(5).T
table = plt.table(
    cellText=df_table.values.tolist(),
    colLabels=df_table.columns.values.tolist(),
    rowLabels=df_table.index.values.tolist(),
    loc="top",
)

## Compute and Show Statistics

In [None]:
# Split the time means into the simulations (incl. ensemble) and the reference
ds_simulated_mean = ds.drop_sel(model="ERA5")
ds_era_mean = ds.sel(model="ERA5")

# Spatial statistics; the ERA5 statistics are always weighted
ds_stats = xr.concat(
    [
        diagnostics.spatial_weighted_statistics(ds_simulated_mean, weights=weights),
        diagnostics.spatial_weighted_statistics(ds_era_mean, weights=True),
    ],
    "model",
)

# Spatial errors of the simulations with respect to ERA5
ds_error = diagnostics.spatial_weighted_errors(
    ds_simulated_mean,
    ds_era_mean,
    weights=weights,
)

# Single table with statistics and errors of the selected variable
df_stats_and_error = xr.merge([ds_stats, ds_error])[variable].to_pandas()
df_stats_and_error

## Taylor Diagram

In [None]:
# Rows of the statistics table used by the diagram
std, crmse, corr = (df_stats_and_error.loc[stat] for stat in ("std", "crmse", "corr"))

# Five ticks from zero to the largest value
# NOTE(review): dtype=int truncates the tick values, so small ranges can give
# repeated ticks - confirm this is intended
tickRMS = np.linspace(0, crmse.max(), 5, dtype=int)
tickSTD = np.linspace(0, std.max(), 5, dtype=int)

# Diagram options, grouped by the element they control
colors = dict(colCOR="k", colOBS="k", colRMS="m", colSTD="b")
markers = dict(
    markerColor="r" if len(df_stats_and_error.columns) >= 9 else None,  # TODO
    markerLabel=list(df_stats_and_error.columns),
    markerLegend="on",
    markerSize=10,
    markerobs="o",
)
styles = dict(styleCOR="--", styleOBS="--", styleRMS=":", styleSTD="-.")
ticks = dict(tickRMS=tickRMS, tickSTD=tickSTD)
titles = dict(
    titleCOR="on",
    titleOBS="ERA5",
    titleRMS="on",
    titleRMSDangle=40.0,
    titleSTD="on",
)
widths = dict(widthCOR=0.5, widthOBS=2, widthRMS=2, widthSTD=1.0)

skill_metrics.taylor_diagram(
    std.values,
    crmse.values,
    corr.values,
    alpha=0.0,
    **colors,
    **markers,
    **styles,
    **ticks,
    **titles,
    **widths,
)