# EOBS seasonal means

In [None]:
import dask
import intake
import xesmf as xe
from dask.distributed import Client
from evaltools import obs

# Make the single-threaded scheduler dask's configured default.
dask.config.set({"scheduler": "single-threaded"})

In [None]:
# Start a dask.distributed client with one thread per worker and no dashboard
# address. The bare name on the last line makes the notebook display it.
client = Client(threads_per_worker=1, dashboard_address=None)
client

In [None]:
# Open the CORDEX-CMIP6 catalog hosted in the euro-cordex/joint-evaluation
# repository and show the dataset keys it provides.
catalog_url = "https://raw.githubusercontent.com/euro-cordex/joint-evaluation/refs/heads/main/CORDEX-CMIP6.json"
catalog = intake.open_esm_datastore(catalog_url)
catalog.keys()

In [None]:
# Narrow the catalog to near-surface temperature (tas), orography (orog) and
# land fraction (sftlf), then load every match into a dict of datasets.
selection = catalog.search(variable_id=["tas", "orog", "sftlf"])
dataset_dict = selection.to_dataset_dict()

In [None]:
# Load the E-OBS observations through the evaltools helper, asking it to add a
# mask (add_mask=True). The bare name displays the result in the notebook.
# NOTE(review): the mask is presumably picked up by the xESMF regridder built
# further down -- confirm against evaltools.obs.eobs.
eobs = obs.eobs(add_mask=True)
eobs

In [None]:
import xarray as xr

# Pick the REMO2020 evaluation run (driven by ERA5, EUR-12 domain) out of the
# loaded datasets: monthly tas plus the two fixed fields.
tas = dataset_dict[
    "CORDEX.EUR-12.GERICS.ERA5.evaluation.r1i1p1f1.REMO2020.v1.mon.tas.v20241120"
]
orog = dataset_dict[
    "CORDEX.EUR-12.GERICS.ERA5.evaluation.r1i1p1f1.REMO2020.v1.fx.orog.v20241120"
]
sftlf = dataset_dict[
    "CORDEX.EUR-12.GERICS.ERA5.evaluation.r1i1p1f1.REMO2020.v1.fx.sftlf.v20241120"
]

# Binary land mask: 1 wherever the land fraction is above zero, 0 elsewhere.
tas["mask"] = xr.where(sftlf["sftlf"] > 0, 1, 0)

# Keep temperature only on land cells; everything else becomes NaN.
tas["tas"] = tas["tas"].where(tas["mask"] == 1)

# Quick look at the first time step to check the masking.
tas["tas"].isel(time=0).plot()

In [None]:
%%time
# Build the bilinear regridder that maps fields from the E-OBS grid onto the
# model (tas) grid. The bare name displays the regridder in the notebook.
regridder = xe.Regridder(ds_in=eobs, ds_out=tas, method="bilinear")
regridder

In [None]:
# Apply the regridder to the full E-OBS dataset and plot the first time step of
# its "tg" variable as a visual check of the interpolation.
eobs_regridded = regridder(eobs)
eobs_regridded["tg"].isel(time=0).plot()

In [None]:
def seasonal_mean(da):
    """Return day-weighted seasonal means of a monthly-mean time series.

    Every month contributes in proportion to its number of days. Missing
    values are left out and the weights of the remaining months are
    renormalised, so a cell with gaps is not biased towards zero and a cell
    with no valid month in a season stays NaN instead of becoming 0.

    based on: https://xarray.pydata.org/en/stable/examples/monthly-means.html

    Parameters
    ----------
    da : xarray.DataArray
        Monthly means with a datetime-like ``time`` coordinate.

    Returns
    -------
    xarray.DataArray
        The input with ``time`` reduced to a ``season`` dimension.
    """
    # Number of days in each month of the time axis.
    month_length = da.time.dt.days_in_month

    # Per-month weights that add up to 1 within each season.
    weights = (
        month_length.groupby("time.season") / month_length.groupby("time.season").sum()
    )

    # Weighted sum; the reduction skips NaN, so missing months add nothing here.
    weighted_sum = (da * weights).groupby("time.season").sum(dim="time")

    # Total weight of the months that actually hold data, per cell and season.
    valid_weight = (da.notnull() * weights).groupby("time.season").sum(dim="time")

    # Renormalise by the valid weight. Where nothing was valid the divisor is
    # turned into NaN so the result is NaN rather than 0 / 0 noise or a fake 0.
    return weighted_sum / valid_weight.where(valid_weight > 0)

In [None]:
%%time

# Analysis window, applied to BOTH datasets so the two seasonal climatologies
# describe the same years.
period = slice("1980", "2020")

# Seasonal climatology of E-OBS daily-mean temperature (tg) on its native grid.
eobs_seasmean = seasonal_mean(eobs.tg.sel(time=period)).compute()

# Observation minus model on the model grid. The model field is converted with
# the exact Kelvin offset 273.15 (previously mistyped as 273.5, which shifted
# every difference by 0.35 degrees).
# NOTE(review): assumes tas is stored in kelvin and tg in degrees Celsius --
# confirm against the datasets' units attributes.
diffs = (
    regridder(eobs_seasmean) - seasonal_mean(tas.tas.sel(time=period) - 273.15)
).compute()

In [None]:
import matplotlib.pyplot as plt
from cartopy import crs as ccrs

# Subplots are organised in a Rows x Cols grid.
# The total number of panels (Tot) and the number of columns (Cols) are given.
Tot = 4
Cols = 2
# Rows needed to hold Tot panels: ceiling division, i.e. exactly one extra row
# when the last row is only partly filled (adding the whole remainder would
# allocate too many rows, e.g. 4 instead of 3 for 8 panels in 3 columns).
Rows = (Tot + Cols - 1) // Cols
# 1-based position index of every panel, as expected by add_subplot.
Position = range(1, Tot + 1)

In [None]:
import numpy as np

nrows = 2
ncols = 2

# Trim coordinate labels for which the difference field holds no data at all.
ds = diffs.where(~diffs.isnull(), drop=True)

# The model data live on a rotated-pole grid; take the pole position from the
# dataset's grid-mapping variable and use the same CRS for data and map.
pole = tas.cf["grid_mapping"]
transform = ccrs.RotatedPole(
    pole_latitude=pole.grid_north_pole_latitude,
    pole_longitude=pole.grid_north_pole_longitude,
)
projection = transform

# One set of zero-centred contour levels shared by all panels. With a plain
# integer `levels` matplotlib chooses levels per panel, and the single colorbar
# (built from the last panel) would then be wrong for the others.
abs_max = float(np.abs(ds).max())
if not np.isfinite(abs_max) or abs_max == 0:
    # Degenerate field (all NaN or all zero): fall back to a valid level range.
    abs_max = 1.0
levels = np.linspace(-abs_max, abs_max, 19)  # 19 boundaries -> 18 intervals

# One GeoAxes panel per season in a 2 x 2 grid.
fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    subplot_kw={"projection": projection},
    figsize=(12, 12),
    sharex=True,
    sharey=True,
)

# axs is a 2-dimensional array of `GeoAxes`; flatten it so it can be indexed by
# the season counter.
axs = axs.flatten()

# Loop over the seasons actually present in the trimmed data, so that the
# positional index and the title always refer to the same slice.
for i, season in enumerate(ds.season.values):

    # Difference field of this season
    data = ds.isel(season=i)

    # Filled contours on the shared levels
    cs = axs[i].contourf(
        ds.cf["X"],
        ds.cf["Y"],
        data,
        transform=transform,
        levels=levels,
        cmap="coolwarm",
        extend="both",
    )

    # Title each subplot with the season name
    axs[i].set_title(season)

    # Draw the coastlines for each subplot
    axs[i].coastlines(resolution="50m", color="black", linewidth=1)

    # Unlabelled 10-degree graticule
    axs[i].gridlines(
        draw_labels=False,
        linewidth=0.5,
        color="gray",
        xlocs=range(-180, 180, 10),
        ylocs=range(-90, 90, 10),
    )


# Adjust the location of the subplots on the page to make room for the colorbar
fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.9, wspace=0.08, hspace=0.08)

# Add a colorbar axis at the bottom of the figure
cbar_ax = fig.add_axes([0.2, 0.05, 0.6, 0.02])

# Draw the colorbar; it is valid for every panel because all share `levels`.
cbar = fig.colorbar(cs, cax=cbar_ax, orientation="horizontal")