# E-OBS seasonal means and model biases

In [None]:
# IPython autoreload extension (mode 2): reload imported modules such as
# evaltools when their source changes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from warnings import warn

import cf_xarray as cfxr
import cordex as cx
import dask
import xarray as xr
import xesmf as xe
from dask.distributed import Client
from evaltools import obs
from evaltools.source import get_source_collection, open_and_sort
from evaltools.utils import short_iid

# Request dask's single-threaded scheduler for this session.
# NOTE(review): the dask.distributed Client created in the next cell presumably
# registers itself as the default scheduler and overrides this setting —
# confirm which scheduler is actually in effect.
dask.config.set(scheduler="single-threaded")

In [None]:
# Local dask.distributed client with one thread per worker; the dashboard is
# served on localhost:8787.
client = Client(dashboard_address="localhost:8787", threads_per_worker=1)

In [None]:
# Fixed ("fx") fields requested alongside every dataset and merged into it by
# open_datasets(); regrid() leaves them unmasked when re-applying the mask.
add_fx = [
    "orog",  # model surface elevation, used for the temperature height correction
    "areacella",
    "sftlf",  # used by mask_with_sftlf() to build the land mask
]


def add_bounds(ds):
    """Attach xESMF vertex coordinates ``lon_b``/``lat_b`` to ``ds``.

    If ``ds`` lacks CF cell bounds for longitude or latitude, bounds are first
    derived with ``cordex.transform_bounds`` (as ``vertices_lon`` /
    ``vertices_lat``). The longitude/latitude bounds are then converted with
    ``cf_xarray.bounds_to_vertices`` (counterclockwise order) into the vertex
    layout used by xESMF and assigned as coordinates.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with CF-identifiable longitude and latitude coordinates.

    Returns
    -------
    xarray.Dataset
        ``ds`` (possibly with derived bounds) plus ``lon_b`` and ``lat_b``.
    """
    bounds = ds.cf.bounds
    # Derive bounds when EITHER axis lacks them: get_bounds() below needs both,
    # so a dataset with bounds on only one axis must be handled too.
    if "longitude" not in bounds or "latitude" not in bounds:
        ds = cx.transform_bounds(ds, trg_dims=("vertices_lon", "vertices_lat"))
    lon_bounds = ds.cf.get_bounds("longitude")
    lat_bounds = ds.cf.get_bounds("latitude")
    # The vertex dimension is the bounds dimension that is not an index of ds.
    bounds_dim = [dim for dim in lon_bounds.dims if dim not in ds.indexes][0]
    # Reshape bounds into the vertex arrays xESMF expects.
    ds = ds.assign_coords(
        lon_b=cfxr.bounds_to_vertices(
            lon_bounds, bounds_dim=bounds_dim, order="counterclockwise"
        ),
        lat_b=cfxr.bounds_to_vertices(
            lat_bounds, bounds_dim=bounds_dim, order="counterclockwise"
        ),
    )
    return ds


def mask_with_sftlf(ds, sftlf=None):
    """Mask the data variables of ``ds`` to cells where ``sftlf > 0``.

    Every data variable except ``sftlf`` (and an existing ``mask``) is set to
    NaN where ``sftlf`` is not strictly positive, and a boolean ``mask``
    variable (``sftlf > 0``) is added. ``ds`` is modified in place and also
    returned.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to mask.
    sftlf : xarray.DataArray, optional
        Field used to build the mask. Defaults to ``ds["sftlf"]``. If neither
        is available, a warning is emitted and ``ds`` is returned unchanged.

    Returns
    -------
    xarray.Dataset
        The (in-place) masked ``ds``.
    """
    if sftlf is None:
        if "sftlf" not in ds:
            # attrs.get: the warning itself must not fail when source_id is absent.
            warn(f"sftlf not found in dataset: {ds.attrs.get('source_id')}")
            return ds
        sftlf = ds["sftlf"]
    land = sftlf > 0
    # Snapshot the names: the loop reassigns variables of ds.
    for var in list(ds.data_vars):
        if var not in ("sftlf", "mask"):
            ds[var] = ds[var].where(land)
    ds["mask"] = land
    return ds


def open_datasets(
    variables,
    frequency="mon",
    mask=True,
    add_missing_bounds=True,
    rewrite_grid=True,
    **kwargs,
):
    """Open the source datasets for ``variables`` and prepare them for regridding.

    The catalog is built by ``get_source_collection`` (requesting the fx
    fields in ``add_fx``) and opened with ``open_and_sort`` (fx merged, fixes
    applied). The following steps then run over all datasets, each one only
    when its flag is exactly ``True``:

    1. ``rewrite_grid``: rewrite coordinates of rotated-pole datasets.
    2. ``mask``: mask data variables with ``sftlf``.
    3. ``add_missing_bounds``: attach xESMF vertex coordinates.

    Extra keyword arguments are forwarded to ``get_source_collection``.

    Returns
    -------
    dict
        Mapping of dataset id to prepared ``xarray.Dataset``.
    """
    catalog = get_source_collection(variables, frequency, add_fx=add_fx, **kwargs)
    datasets = open_and_sort(catalog, merge_fx=True, apply_fixes=True)

    if rewrite_grid is True:
        for key in list(datasets):
            datasets[key] = rewrite_coords(datasets[key])

    if mask is True:
        # mask_with_sftlf modifies the dataset in place and returns that same
        # object, so storing the result back is equivalent.
        for key in list(datasets):
            datasets[key] = mask_with_sftlf(datasets[key])

    if add_missing_bounds is True:
        for key in list(datasets):
            datasets[key] = add_bounds(datasets[key])

    return datasets


def create_cordex_grid(domain_id):
    """Return the CORDEX ``domain_id`` grid with xESMF vertex coordinates.

    The grid is created by ``cordex.domain`` with cell bounds and CMIP6
    conventions. Its ``vertices_lon``/``vertices_lat`` bounds are converted
    into ``lon_b``/``lat_b`` vertex coordinates (counterclockwise order).
    """
    grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6")
    vertices = {
        f"{axis}_b": cfxr.bounds_to_vertices(
            grid[f"vertices_{axis}"],
            bounds_dim="vertices",
            order="counterclockwise",
        )
        for axis in ("lon", "lat")
    }
    return grid.assign_coords(**vertices)


def create_regridder(source, target, method="bilinear"):
    """Build and return an ``xesmf.Regridder`` from ``source`` onto ``target``."""
    return xe.Regridder(source, target, method=method)


def regrid(ds, regridder, mask_after_regrid=True):
    """Apply ``regridder`` to ``ds`` and optionally re-apply the regridded mask.

    Parameters
    ----------
    ds : xarray.Dataset
        Source dataset. For masking it should carry a ``mask`` variable
        (as added by ``mask_with_sftlf``).
    regridder : callable
        Typically an ``xesmf.Regridder`` built for ``ds``'s grid.
    mask_after_regrid : bool, optional
        If true, every regridded data variable except the fx fields in
        ``add_fx`` and ``mask`` itself is set to NaN where the regridded
        ``mask`` is not positive. If no ``mask`` is present, a warning is
        emitted and the regridded data is returned unmasked.

    Returns
    -------
    xarray.Dataset
        The regridded (and possibly masked) dataset.
    """
    ds_regrid = regridder(ds)
    if not mask_after_regrid:
        return ds_regrid
    if "mask" not in ds_regrid:
        warn(
            "no 'mask' variable after regridding "
            f"{ds.attrs.get('source_id')}; skipping mask_after_regrid"
        )
        return ds_regrid
    valid = ds_regrid["mask"] > 0.0
    # Iterate the regridded variables: the regridder may drop variables it
    # cannot regrid, so names from the source dataset are not guaranteed.
    for var in list(ds_regrid.data_vars):
        if var in add_fx or var == "mask":
            continue
        ds_regrid[var] = ds_regrid[var].where(valid)
    return ds_regrid


def regrid_dsets(dsets, target_grid, method="bilinear"):
    """Regrid every non-rotated-pole dataset in ``dsets`` onto ``target_grid``.

    Datasets whose CF ``grid_mapping`` is ``rotated_latitude_longitude`` are
    left untouched. All others are regridded with a fresh regridder (the
    regridder is printed for inspection). ``dsets`` is updated in place and
    returned.

    NOTE(review): skipping rotated datasets assumes they already share the
    target grid — confirm.
    """
    for key in list(dsets):
        source = dsets[key]
        mapping_name = source.cf["grid_mapping"].grid_mapping_name
        if mapping_name == "rotated_latitude_longitude":
            continue
        print(f"regridding {key} with grid_mapping: {mapping_name}")
        remapper = create_regridder(source, target_grid, method=method)
        print(remapper)
        dsets[key] = regrid(source, remapper)
    return dsets


def mask_invalid(ds, vars=None, threshold=0.1):
    """Mask grid cells whose fraction of missing time steps is too large.

    For each selected variable, the fraction of NaN values along ``time`` is
    computed per cell. Cells where that fraction is ``>= threshold`` are set
    to NaN for the whole series. ``ds`` is modified in place and returned.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` dimension.
    vars : str or list of str, optional
        Variable(s) to mask. Defaults to every data variable of ``ds`` that
        has a ``time`` dimension (others cannot be reduced over time).
    threshold : float, optional
        Maximum tolerated (exclusive) missing fraction, default 0.1.

    Returns
    -------
    xarray.Dataset
        The masked ``ds``.
    """
    if vars is None:
        # The original assigned to a misspelled name here, so vars stayed None
        # and the loop below raised TypeError.
        vars = [name for name, da in ds.data_vars.items() if "time" in da.dims]
    elif isinstance(vars, str):
        vars = [vars]
    n_times = ds.time.size
    for var in vars:
        missing_fraction = ds[var].isnull().sum(dim="time") / n_times
        ds[var] = ds[var].where(missing_fraction < threshold)
    return ds


def rewrite_coords(ds, coords="all"):
    """Rewrite the coordinates of a rotated-pole dataset via the ``cx`` accessor.

    Datasets whose CF ``grid_mapping`` is anything other than
    ``rotated_latitude_longitude`` are returned as-is.
    """
    mapping_name = ds.cf["grid_mapping"].grid_mapping_name
    if mapping_name != "rotated_latitude_longitude":
        return ds
    return ds.cx.rewrite_coords(coords=coords)


def height_temperature_correction(model_elev, obs_elev, lapse_rate=0.0065):
    """Return the additive correction that moves model temperature to obs elevation.

    Temperature is assumed to decrease linearly with height at ``lapse_rate``.
    A model cell lying *above* the observation elevation is therefore too
    cold and gets a positive correction, and vice versa::

        T_at_obs_elev = T_model + height_temperature_correction(z_model, z_obs)

    Parameters
    ----------
    model_elev : float or array-like
        Model surface elevation in metres.
    obs_elev : float or array-like
        Observation elevation in metres.
    lapse_rate : float, optional
        Temperature decrease per metre of height (°C or K per metre),
        default 0.0065.

    Returns
    -------
    float or array-like
        Correction in °C (equivalently K) to add to the model temperature.
    """
    # T(z_obs) = T(z_model) - lapse_rate * (z_obs - z_model)
    # The original returned the opposite sign, which doubles rather than
    # removes the elevation effect.
    return lapse_rate * (model_elev - obs_elev)

In [None]:
# Open monthly CMIP6 tas and pr datasets with the fx fields merged in; with the
# open_datasets defaults, rotated grids are rewritten, data are masked with
# sftlf and missing cell bounds are added.
dsets = open_datasets(
    ["tas", "pr"],
    frequency="mon",
    mip_era="CMIP6",
    add_missing_bounds=True,
)

In [None]:
# Target grid: the EUR-12 CORDEX domain with xESMF vertex coordinates. Datasets
# not on a rotated-pole grid are regridded onto it bilinearly.
rotated_grid = create_cordex_grid("EUR-12")
dsets = regrid_dsets(dsets, rotated_grid, method="bilinear")

In [None]:
# Analysis period shared by observations and models.
period = slice("1980", "2020")

# Load each E-OBS variable on its own, restrict it to the analysis period and
# mask cells where 10 % or more of the time steps are missing.
eobs_tg, eobs_rr = (
    mask_invalid(
        obs.eobs(variables=[name], add_mask=False).sel(time=period),
        vars=name,
        threshold=0.1,
    )
    for name in ("tg", "rr")
)
eobs = xr.merge([eobs_tg, eobs_rr], join="override")

In [None]:
# Regridders from the E-OBS grid onto the EUR-12 rotated grid: bilinear (used
# for tg and elevation) and conservative_normed (used for rr).
# unmapped_to_nan=True makes unmapped target cells NaN rather than 0,
# see https://github.com/pangeo-data/xESMF/issues/56
regridder_bil = xe.Regridder(
    eobs, rotated_grid, method="bilinear", unmapped_to_nan=True
)
regridder_cons = xe.Regridder(
    eobs, rotated_grid, method="conservative_normed", unmapped_to_nan=True
)

In [None]:
# E-OBS on the rotated grid: tg and elevation regridded bilinearly, rr
# conservatively; join="override" reuses the first object's index labels.
eobs_on_rotated = xr.merge(
    [regridder_bil(eobs[["tg", "elevation"]]), regridder_cons(eobs[["rr"]])],
    join="override",
)

In [None]:
def seasonal_mean(da):
    """Return the day-weighted climatological mean for each season.

    Every time step is weighted by the number of days in its month. All steps
    in the same ``time.season`` group (DJF, MAM, JJA, SON) across the whole
    record are then averaged, so the result has a ``season`` dimension in
    place of ``time``. Intended for time series of monthly means.

    Missing values are excluded from both the weighted sum and the sum of
    weights. A season with a few missing months is averaged over the months
    that are present rather than being biased toward zero. Cells with no
    valid value in a season are NaN.

    Based on: https://xarray.pydata.org/en/stable/examples/monthly-means.html

    Parameters
    ----------
    da : xarray.DataArray or xarray.Dataset
        Data with a datetime-like ``time`` coordinate.

    Returns
    -------
    xarray.DataArray or xarray.Dataset
        Seasonal means with a ``season`` dimension.
    """
    days = da.time.dt.days_in_month
    weighted_sum = (da * days).groupby("time.season").sum(
        dim="time", skipna=True, min_count=1
    )
    # Only count the weights of time steps that actually hold data.
    weight_sum = (da.notnull() * days).groupby("time.season").sum(dim="time")
    # Seasons without any valid step -> NaN instead of a division by zero.
    return weighted_sum / weight_sum.where(weight_sum > 0)

In [None]:
%%time

# Seasonal means of E-OBS tg and rr on the rotated grid, computed eagerly.
# (The period selection repeats the one applied when loading E-OBS.)
eobs_seasmean = seasonal_mean(eobs_on_rotated[["tg", "rr"]].sel(time=period)).compute()

In [None]:
%%time

# Seasonal-mean temperature bias (model - E-OBS) per dataset id. The model
# climatology gets height_temperature_correction(orog, E-OBS elevation) added,
# which is intended to move it to the E-OBS elevation.
# NOTE(review): + 273.15 presumably converts E-OBS tg from degC to K to match
# model tas — confirm units.
bias_tas = {
    dset_id: (
        seasonal_mean(ds["tas"].sel(time=period)).compute()
        + height_temperature_correction(ds.orog, eobs_on_rotated.elevation)
        - (eobs_seasmean.tg + 273.15)
    ).rename("tas")
    for dset_id, ds in dsets.items()
}

In [None]:
%%time

# Relative seasonal-mean precipitation bias in percent per dataset id:
# 100 * (model - E-OBS) / E-OBS. Model pr is multiplied by 86400
# (presumably kg m-2 s-1 -> mm/day to match E-OBS rr — confirm units).
# Cells where observed precipitation is zero become NaN instead of +/-inf.
bias_pr = {
    dset_id: (
        100
        * (
            seasonal_mean(ds["pr"].sel(time=period)).compute() * 86400
            - eobs_seasmean.rr
        )
        / eobs_seasmean.rr.where(eobs_seasmean.rr > 0)
    ).rename("pr")
    for dset_id, ds in dsets.items()
}

In [None]:
def _concat_by_dset_id(biases):
    """Stack a ``{dataset id: DataArray}`` mapping along a new ``dset_id`` dimension.

    Each ``dset_id`` label is the key shortened by ``short_iid`` to
    ``<institution_id>-<source_id>``.
    NOTE(review): datasets differing only in other facets would get duplicate
    labels — confirm that cannot happen here.
    """
    labels = [
        short_iid(key, ["institution_id", "source_id"], delimiter="-")
        for key in biases
    ]
    return xr.concat(
        list(biases.values()),
        dim=xr.DataArray(labels, dims="dset_id"),
        compat="override",
        coords="minimal",
    )


seasonal_bias_tas = _concat_by_dset_id(bias_tas)
seasonal_bias_pr = _concat_by_dset_id(bias_pr)

seasonal_bias = xr.merge([seasonal_bias_tas, seasonal_bias_pr])

In [None]:
%%time
# Materialise the merged tas/pr bias dataset before plotting.
seasonal_bias_ = seasonal_bias.compute()

In [None]:
import numpy as np


def plot_all(ds, plot_var, levels, cmap, grid=None):
    """Plot ``ds[plot_var]`` as a map panel grid: one row per dataset, one column per season.

    Each panel is a filled contour plot in the rotated-pole projection of
    ``grid``. Season names are written above the first row, dataset ids to
    the right of the last column, and a shared horizontal colorbar is added
    below the panels. Diagnostic information is printed while plotting.

    Parameters
    ----------
    ds : xarray.Dataset
        Must have ``dset_id`` and ``season`` dimensions and CF-identifiable
        ``X``/``Y`` coordinates.
    plot_var : str
        Name of the variable to plot.
    levels : array-like
        Contour levels passed to ``contourf``.
    cmap : str or matplotlib colormap
        Colormap passed to ``contourf``.
    grid : xarray.Dataset, optional
        Dataset whose CF ``grid_mapping`` provides the rotated-pole
        parameters. Defaults to the module-level ``rotated_grid``.

    Raises
    ------
    ValueError
        If ``ds`` has no datasets or no seasons to plot.
    """
    import matplotlib.pyplot as plt
    from cartopy import crs as ccrs
    from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter

    if grid is None:
        grid = rotated_grid

    nrows = ds.dset_id.size
    ncols = ds.season.size
    if nrows == 0 or ncols == 0:
        raise ValueError("plot_all: ds has no dset_id or season entries to plot")

    aspect = ds.cf.dims["Y"] / ds.cf.dims["X"]
    print(f"aspect: {aspect}")

    grid_mapping = grid.cf["grid_mapping"]
    transform = ccrs.RotatedPole(
        pole_latitude=grid_mapping.grid_north_pole_latitude,
        pole_longitude=grid_mapping.grid_north_pole_longitude,
    )
    # The maps are drawn in the same rotated-pole CRS the data live in.
    projection = transform

    # One row per dataset, one column per season.
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        subplot_kw={"projection": projection},
        figsize=(18, 18 * 1.05 * nrows / ncols),
        sharex=True,
        sharey=True,
        gridspec_kw={"wspace": 0, "hspace": 0},
        squeeze=False,  # always a 2-D array, even for a single row or column
    )
    # Row-major flattening: panel (row j, column i) is at index i + j * ncols.
    axs = axs.flatten()

    for i, season in enumerate(ds.season.values):
        for j, dset_id in enumerate(ds.dset_id.values):
            pos = i + j * ncols
            print(i, j, pos, season, dset_id)
            ax = axs[pos]
            data = ds[plot_var].isel(season=i, dset_id=j)

            cs = ax.contourf(
                ds.cf["X"],
                ds.cf["Y"],
                data,
                transform=transform,
                levels=levels,
                extend="both",
                cmap=cmap,
            )

            ax.set_aspect(round(aspect, 3))
            ax.coastlines(resolution="50m", color="black", linewidth=0.5)
            ax.gridlines(
                draw_labels=False,
                linewidth=0.3,
                color="gray",
                xlocs=range(-180, 180, 10),
                ylocs=range(-90, 90, 10),
            )

            # Dataset id to the right of the last column (previously hard-coded
            # as column 3, which only worked for exactly four seasons).
            # https://stackoverflow.com/questions/35479508/cartopy-set-xlabel-set-ylabel-not-ticklabels
            if i == ncols - 1:
                ax.text(
                    1.1,
                    0.55,
                    dset_id,
                    va="bottom",
                    ha="center",
                    rotation="vertical",
                    rotation_mode="anchor",
                    transform=ax.transAxes,
                )
            # Season name above the first row.
            if j == 0:
                ax.text(
                    0.55,
                    1.05,
                    season,
                    va="bottom",
                    ha="center",
                    rotation="horizontal",
                    rotation_mode="anchor",
                    transform=ax.transAxes,
                )

            # Fresh formatter instances per axis: matplotlib formatters are
            # bound to the axis they are set on and should not be shared.
            ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True))
            ax.yaxis.set_major_formatter(LatitudeFormatter())

    # Shared horizontal colorbar below the panel grid, based on the last panel.
    cbar_ax = fig.add_axes([0.2, 0.05, 0.6, 0.02])
    fig.colorbar(cs, cax=cbar_ax, orientation="horizontal")


# (variable, contour levels, colormap) for each bias figure: relative
# precipitation bias in percent, then temperature bias in K.
for plot_var, levels, cmap in (
    ("pr", np.arange(-100, 110, 10), "BrBG"),
    ("tas", np.arange(-8, 9, 1), "RdBu_r"),
):
    plot_all(seasonal_bias_[[plot_var]], plot_var, levels, cmap)