# Example
In this example, we'll compare ERA5 reanalysis with output from a single climate model (CESM).

In [None]:
import pathlib
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import seaborn as sns
import time
import cmocean
import copy
import os
import xesmf as xe
import intake
import scipy.stats

## (optional) white axes background and no gridlines on plots
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## Data loading

### Functions

#### for loading from cloud

In [None]:
def prep_lsm(lsm, lon_range, lat_range):
    """Tidy a land-sea mask so it can serve as a regridding target.

    Renames ``latitude``/``longitude`` to ``lat``/``lon``, then linearly
    interpolates onto a 1-degree grid that starts at the lower bound of
    ``lon_range``/``lat_range`` and steps by 1 up to (at most one degree past)
    the upper bound. Attaches a binary ``mask`` that is 1 where the mask
    value is below 0.5 (ocean, as used downstream) and 0 elsewhere.
    Returns the result with dimensions ordered ``("lat", "lon")``.
    """

    renamed = lsm.rename({"latitude": "lat", "longitude": "lon"})

    ## 1x1-degree target axes (lon listed first, as interpolation order)
    lon_lo, lon_hi = lon_range[0], lon_range[1]
    lat_lo, lat_hi = lat_range[0], lat_range[1]
    target_axes = {
        "lon": np.arange(lon_lo, lon_hi + 1, 1),
        "lat": np.arange(lat_lo, lat_hi + 1, 1),
    }
    on_grid = renamed.interp(target_axes)

    ## binary ocean flag (1 = ocean) consumed by regridding / masking code
    on_grid["mask"] = (on_grid < 0.5).astype(int)

    return on_grid.transpose("lat", "lon")


def load_lsm_from_cloud(lon_range, lat_range):
    """Fetch the ERA5 land-sea mask from the WeatherBench2 Google Cloud
    bucket and prepare it with ``prep_lsm`` for the requested lon/lat box."""

    zarr_store = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"

    ## only the mask variable is needed; pull it into memory up front
    raw_mask = xr.open_zarr(zarr_store)["land_sea_mask"].compute()

    return prep_lsm(raw_mask, lon_range=lon_range, lat_range=lat_range)


def load_era5_from_cloud(lon_range, lat_range, apply_ocean_mask=True):
    """Load ERA5 2 m temperature from the WeatherBench2 bucket.

    Subsets to ``lon_range``/``lat_range``, loads into memory, converts
    Kelvin to degrees Celsius, averages the 6-hourly values into monthly
    means (labelled at month start) and orders dims as
    (time, latitude, longitude). When ``apply_ocean_mask`` is True, every
    point where the interpolated ocean mask is not exactly 1 becomes NaN.
    """

    store = "gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr"
    t2m = xr.open_zarr(store, chunks=dict(time=1024))["2m_temperature"]

    t2m = t2m.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))
    t2m.load()

    ## Kelvin -> Celsius, then 6-hourly -> monthly means (month-start labels)
    monthly = (t2m - 273.15).resample({"time": "MS"}).mean()
    monthly = monthly.transpose("time", "latitude", "longitude")

    if not apply_ocean_mask:
        return monthly

    ## ocean flag on the 1x1 grid, renamed to ERA5's coordinate names
    ocean = load_lsm_from_cloud(lon_range=lon_range, lat_range=lat_range)["mask"]
    ocean = ocean.rename({"lat": "latitude", "lon": "longitude"})

    ## linear interpolation of a 0/1 field is exactly 1 only where all
    ## contributing mask points are ocean, so coastal cells drop out
    ocean_on_era5 = ocean.interp(
        latitude=monthly.latitude, longitude=monthly.longitude
    )
    return monthly.where(ocean_on_era5 == 1)


def load_cesm_from_cloud(
    lon_range,
    lat_range,
    varname="TREFHT",
    load_ssp370=False,
    member_id=1,
    apply_ocean_mask=True,
):
    """Load monthly CESM2 Large Ensemble output from the public AWS catalog.

    Arguments:
        - lon_range, lat_range: [min, max] bounds passed to ``.sel`` on the
          catalog's ``lon``/``lat`` coordinates
        - varname: variable to load (default "TREFHT"); converted K -> degC
        - load_ssp370: if True, append the SSP3-7.0 run after the historical run
        - member_id: used with ``.isel``, i.e. a *position* along the
          ``member_id`` dimension rather than a member label
        - apply_ocean_mask: if True, keep only cells where the mask equals 1
    Returns:
        - DataArray bilinearly regridded onto the 1x1 grid of
          ``load_lsm_from_cloud``
    """

    catalog_url = "https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json"
    matching = intake.open_esm_datastore(catalog_url).search(
        variable=varname, frequency="monthly"
    )

    ## lazily open every matching dataset (nothing is loaded yet)
    datasets = matching.to_dataset_dict(
        aggregate=True,
        xarray_open_kwargs=dict(engine="zarr", decode_timedelta=True),
        zarr_kwargs={"consolidated": True},
        storage_options={"anon": True},
    )

    ds = datasets["atm.historical.monthly.cmip6"]
    if load_ssp370:
        ds = xr.concat([ds, datasets["atm.ssp370.monthly.cmip6"]], dim="time")

    ## spatial subset first, then one ensemble member (by position)
    ds = ds.sel(dict(lon=slice(*lon_range), lat=slice(*lat_range)))
    ds = ds.isel(member_id=member_id)

    ## Kelvin -> Celsius, relabel longitudes, then pull into memory
    field = swap_longitude_range(ds[varname] - 273.15).compute()

    ## bilinear regrid onto the 1x1 land-sea-mask grid
    lsm = load_lsm_from_cloud(lon_range=lon_range, lat_range=lat_range)
    # NOTE(review): `lsm` carries a "mask" coordinate, which xESMF may use as
    # a destination mask; confirm whether land is already excluded here.
    field = xe.Regridder(field, lsm, "bilinear", ignore_degenerate=False)(field)

    if apply_ocean_mask:
        field = field.where(lsm["mask"] == 1)

    return field

#### for loading from server

In [None]:
def load_simulation(
    server_fp, varname, member_id, simulation_type, preprocess_func=None
):
    """
    Load dataset for single simulation, for single variable.
    Arguments:
        - server_fp: pathlib.Path to the root of the mounted data server
        - varname: name of variable to load, e.g. "SST" or "PSL"
        - member_id: integer ID of the ensemble member; formatted as a
          3-digit zero-padded number in the filename (e.g. 10 -> "010")
        - simulation_type: one of {"hist", "rcp85"}
        - preprocess_func: optional callable applied to the opened dataset
          before the variable is extracted
    Returns:
        - xarray DataArray for ``varname`` with length-1 dims squeezed out
    Raises:
        - ValueError: if ``simulation_type`` is not recognised
        - FileNotFoundError: if no file matches the naming pattern
    """

    ## tag used in the LENS filenames for each simulation type
    scenario_tags = {"hist": "20TRC", "rcp85": "RCP85"}
    if simulation_type not in scenario_tags:
        raise ValueError(
            f"Not a valid simulation type: {simulation_type!r} "
            f"(expected one of {sorted(scenario_tags)})"
        )

    ## Filepath to the CESM LENS dataset
    lens_fp = pathlib.Path("cmip6/data/cmip6/CMIP/NCAR/LENS")

    #### 1. get filepath to data
    data_fp = server_fp / lens_fp / pathlib.Path(varname)

    #### 2. get naming pattern for files to open
    file_pattern = f"*{scenario_tags[simulation_type]}*.{member_id:03d}.*.nc"

    #### 3. locate the file; sorting makes the choice deterministic
    matches = sorted(data_fp.glob(file_pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching {file_pattern!r} in {data_fp}")

    # NOTE(review): only the first match is opened; if a run is split over
    # several files, the remaining ones are ignored -- confirm this is intended.
    fp = matches[0]

    ## load data
    data = xr.open_dataset(fp, chunks=None, decode_timedelta=True)

    ## apply (optional) preprocessing
    if preprocess_func is not None:
        data = preprocess_func(data)

    return data[varname].squeeze(drop=True)


def load_lsm_from_server(server_fp, lon_range=(0, 360), lat_range=(-90, 90)):
    """Load the 2020 ERA5 land-sea mask from the CMIP6 server and prepare it.

    Arguments:
        - server_fp: pathlib.Path to the mounted server root
        - lon_range, lat_range: (min, max) bounds forwarded to ``prep_lsm``;
          the defaults are immutable tuples so they cannot be shared/mutated
          between calls
    Returns:
        - the output of ``prep_lsm`` (mask on a 1x1 grid with a binary
          ``mask`` coordinate)
    """

    ## get server path to ERA5 land sea mask
    lsm_fp = server_fp / pathlib.Path(
        "cmip6/data/era5/reanalysis/single-levels/monthly-means/land_sea_mask/2020_land_sea_mask.nc"
    )

    ## keep the first time step only and drop the scalar time coordinate
    lsm = xr.open_dataarray(lsm_fp).isel(time=0).drop_vars("time")

    return prep_lsm(lsm, lon_range=lon_range, lat_range=lat_range)


def load_cesm_from_server(lon_range, lat_range, server_fp, varname, member_id=10):
    """Load CESM LENS output (historical + RCP8.5) for one member from the server.

    The two runs are concatenated in time, ``TLONG``/``TLAT`` are renamed to
    ``lon``/``lat``, the data are trimmed to ``lon_range``/``lat_range``,
    longitudes are relabelled to (-180, 180] and sorted ascending along
    ``nlon``, and every value not greater than -10 (NaN included) is set to
    NaN. No land mask is applied: SST fields are assumed to be ocean-only.
    """
    shared_args = dict(server_fp=server_fp, varname=varname, member_id=member_id)

    runs = [
        load_simulation(simulation_type=kind, **shared_args).compute()
        for kind in ("hist", "rcp85")
    ]
    data = xr.concat(runs, dim="time").rename({"TLONG": "lon", "TLAT": "lat"})

    data = trim(data, lon_range=lon_range, lat_range=lat_range)
    data = sort_longitude(swap_longitude_range(data))

    # Crude validity filter: values at or below -10 are treated as invalid
    # (zeros are kept).
    return data.where(data > -10, other=np.nan)


def load_era5_from_server(server_fp, lon_range, lat_range):
    """Load ERA5 monthly-mean SST (degrees Celsius) from the CMIP6 server.

    All ``*.nc`` files in the ``sea_surface_temperature`` folder are opened
    together, converted from Kelvin to Celsius, subset to
    ``lon_range``/``lat_range``, loaded into memory and returned with the
    latitude axis in ascending order.
    """
    monthly_means = pathlib.Path("cmip6/data/era5/reanalysis/single-levels/monthly-means")
    sst_dir = server_fp / monthly_means / pathlib.Path("sea_surface_temperature")

    sst = xr.open_mfdataset(sst_dir.glob("*.nc"))["sst"]
    sst_celsius = sst.copy() - 273.15

    ## latitude is presumably stored descending, so the slice runs high -> low
    window = {"longitude": slice(*lon_range), "latitude": slice(*lat_range[::-1])}
    subset = sst_celsius.sel(window).compute()

    ## flip latitude so it increases
    ascending = subset["latitude"].values[::-1]
    return subset.reindex({"latitude": ascending})

#### utilities

In [None]:
def trim(data, lon_range, lat_range):
    """Select the part of ``data`` whose lon/lat lie in the closed ranges given.

    For curvilinear data (an ``nlon`` dimension), every ``nlon`` column and
    ``nlat`` row containing at least one in-range point is kept, i.e. the
    result is a bounding rectangle in index space. Otherwise ``lon`` and
    ``lat`` are used as independent 1-D indexers.
    """

    def _within(coord, bounds):
        return (bounds[0] <= coord) & (coord <= bounds[1])

    lon_ok = _within(data["lon"], lon_range)
    lat_ok = _within(data["lat"], lat_range)

    if "nlon" not in data.dims:
        return data.isel(lon=lon_ok, lat=lat_ok)

    ## joint in-range mask, evaluated eagerly before reducing it
    both_ok = (lon_ok & lat_ok).load()
    return data.isel(nlon=both_ok.any("nlat"), nlat=both_ok.any("nlon"))


def sort_longitude(data):
    """shuffles data so that longitude is monotonically increasing

    Reorders points along the ``nlon`` dimension, row by row, so that the
    ``lon`` coordinate increases; ``lat`` and the data values are permuted
    with the same indices so they stay paired with ``lon``.

    Returns the transposed DataArray (``nlon`` last) with sorted values.
    NOTE(review): assumes ``lon``/``lat`` are coordinates whose last dim is
    ``nlon`` after the transpose (2-D curvilinear, as in
    ``load_cesm_from_server``), and that the data's last two dims have the
    same shape as those coordinates -- confirm before using on new inputs.
    """

    ## Transpose data so that longitude is last dimension
    ## (we'll do all the sorting along this dimension)
    data = data.transpose(..., "nlon")

    ## Get indices needed to sort longitude to be monotonic increasing
    ## (argsort along the last axis, so each row is sorted independently)
    lon_sort_idx = np.argsort(data["lon"].values, axis=-1)

    ## sort the lon/lat coordinates
    ## lat uses lon's permutation so every (lon, lat) pair stays together
    sort = lambda X, idx: np.take_along_axis(X.values, indices=idx, axis=-1)
    data["lon"].values = sort(data["lon"], idx=lon_sort_idx)
    data["lat"].values = sort(data["lat"], idx=lon_sort_idx)

    #### sort the data

    # first, check to see if data has more than two dimensions
    # (if so, prepend singleton axes so the 2-D index array broadcasts over
    # the leading dims, e.g. time, in take_along_axis)
    if data.ndim > 2:
        extra_dims = [i for i in range(data.ndim - 2)]
        lon_sort_idx = np.expand_dims(lon_sort_idx, axis=extra_dims)

    ## now, do the actual sorting (overwrites the values of the transposed copy)
    data.values = sort(data, idx=lon_sort_idx)

    return data


def swap_longitude_range(data):
    """swap longitude range of xr.DataArray from [0,360) to (-180, 180].

    Longitudes strictly greater than 180 have 360 subtracted; 180 itself is
    left unchanged. Works whether ``lon`` is a dimension coordinate or a
    non-dimension (e.g. 2-D curvilinear) coordinate.

    Returns a DataArray with the relabelled ``lon``; the input object is not
    modified.
    """

    ## relabel a private copy of the coordinate values
    lon_new = np.array(data["lon"].values, copy=True)
    exceeds_180 = lon_new > 180
    lon_new[exceeds_180] -= 360

    if "lon" in data.dims:
        return data.assign_coords({"lon": lon_new})

    ## A shallow copy gives fresh coordinate Variable objects (sharing the
    ## underlying arrays), so replacing the values below does not leak back
    ## into the caller's DataArray.
    data = data.copy(deep=False)
    data["lon"].values = lon_new

    return data

#### Data-handling

In [None]:
def ensure_consistent_grids(era5_data, cesm_data, target_grid=None):
    """Ensure both datasets are on the same spatial grid.

    Arguments:
        - era5_data: ERA5 data (``latitude``/``longitude`` coordinates)
        - cesm_data: CESM data (``lat``/``lon`` or ``latitude``/``longitude``
          coordinates)
        - target_grid: destination grid; defaults to ``era5_data``
    Returns:
        - (era5_data, cesm_data); cesm_data is bilinearly regridded onto
          ``target_grid`` with xESMF unless its lat/lon values already match
          the target's exactly.
    """

    if target_grid is None:
        # Use ERA5 grid as target by default
        target_grid = era5_data

    def _coord_values(obj, *names):
        # values of the first coordinate in ``names`` present on ``obj``, else None
        for name in names:
            if name in obj.coords:
                return obj[name].values
        return None

    cesm_lat = _coord_values(cesm_data, "lat", "latitude")
    cesm_lon = _coord_values(cesm_data, "lon", "longitude")
    grid_lat = _coord_values(target_grid, "latitude", "lat")
    grid_lon = _coord_values(target_grid, "longitude", "lon")

    # np.array_equal is False for differing shapes, so 2-D curvilinear CESM
    # coordinates never match 1-D target axes and always get regridded.
    all_found = all(v is not None for v in (cesm_lat, cesm_lon, grid_lat, grid_lon))
    same_grid = (
        all_found
        and np.array_equal(cesm_lat, grid_lat)
        and np.array_equal(cesm_lon, grid_lon)
    )

    if same_grid:
        return era5_data, cesm_data

    print("Regridding CESM data to match ERA5 grid...")
    regridder = xe.Regridder(
        cesm_data, target_grid, "bilinear", ignore_degenerate=False
    )
    return era5_data, regridder(cesm_data)


def load_comparable_datasets(
    server_fp=None,
    lon_range=(260, 360),
    lat_range=(10, 70),
    member_id=10,
    load_from_cloud=True,
    load_ssp370=True,
):
    """
    Load ERA5 and CESM datasets with consistent processing for comparison

    Args:
        server_fp: Path to server data (only needed if load_from_cloud=False)
        lon_range: Longitude range (min, max)
        lat_range: Latitude range (min, max)
        member_id: CESM ensemble member ID
        load_from_cloud: Whether to load from cloud (True) or server (False)
        load_ssp370: Whether to include SSP370 scenario for CESM
            (only used when loading from cloud)

    Returns:
        tuple: (era5_data, cesm_data) - datasets on a common spatial grid

    Raises:
        ValueError: if load_from_cloud is False and server_fp is None
    """

    if not load_from_cloud and server_fp is None:
        raise ValueError("server_fp is required when load_from_cloud=False")

    kwargs = dict(lon_range=lon_range, lat_range=lat_range)

    if load_from_cloud:
        print("Loading from cloud...")
        print("- ERA5: 2m temperature with ocean mask")
        print("- CESM: 2m temperature with ocean mask")

        # Load atmospheric 2m temp with ocean mask for both
        era5_data = load_era5_from_cloud(apply_ocean_mask=True, **kwargs)
        cesm_data = load_cesm_from_cloud(
            varname="TREFHT",
            load_ssp370=load_ssp370,
            member_id=member_id,
            apply_ocean_mask=True,
            **kwargs,
        )

    else:
        print("Loading from server...")
        print("- ERA5: SST (ocean only)")
        print("- CESM: SST (ocean only)")

        # Load ocean-only SST data
        era5_data = load_era5_from_server(server_fp, **kwargs)
        cesm_data = load_cesm_from_server(
            varname="SST", server_fp=server_fp, member_id=member_id, **kwargs
        )

    # NOTE: time axes are not aligned here; callers select a common period
    # themselves (e.g. with .sel(time=slice(...))).

    # Ensure consistent spatial grids
    era5_data, cesm_data = ensure_consistent_grids(era5_data, cesm_data)

    # Print summary statistics
    print("\nDataset summary:")
    print(f"ERA5 shape: {era5_data.shape}")
    print(f"CESM shape: {cesm_data.shape}")

    return era5_data, cesm_data

### Do the data-loading

In [None]:
## Execute data loading
## Configuration: server mount point, lon/lat box in degrees (longitudes on
## the [0, 360) convention), and the data source:
##   LOAD_FROM_CLOUD = True  -> ERA5/CESM 2 m temperature from the cloud
##   LOAD_FROM_CLOUD = False -> ERA5/CESM SST files from SERVER_FP
SERVER_FP = pathlib.Path("/Volumes")
LON_RANGE = [260, 359.9]
LAT_RANGE = [0, 70]
LOAD_FROM_CLOUD = False

## try to suppress file locking
## (set before any files are opened below; NOTE(review): presumably to avoid
## HDF5 locking errors on the network mount -- confirm)
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

## Load comparable datasets (and time how long it takes)
t0 = time.time()
era5_data, cesm_data = load_comparable_datasets(
    server_fp=SERVER_FP,
    lon_range=LON_RANGE,
    lat_range=LAT_RANGE,
    member_id=10,
    load_from_cloud=LOAD_FROM_CLOUD,
    load_ssp370=True,
)

print(f"\nTotal loading time: {time.time() - t0:.1f} seconds")

## Analysis

Let's do a timeseries analysis similar to the one in Tutorial 1, but now comparing the two datasets:

### Functions

In [None]:
def spatial_avg(data):
    """Area-weighted horizontal mean of data on a regular lon/lat grid.

    Each latitude row is weighted by cos(latitude). Coordinates may be named
    ``latitude``/``longitude`` (checked first) or ``lat``/``lon``; anything
    else raises ValueError.
    """
    for lat_name, lon_name in (("latitude", "longitude"), ("lat", "lon")):
        if lat_name in data.coords:
            break
    else:
        raise ValueError(
            "Data must have either 'latitude'/'longitude' or 'lat'/'lon' coordinates"
        )

    ## cos(lat) weights, converting degrees to radians first
    weights = np.cos(np.deg2rad(data[lat_name]))

    return data.weighted(weights=weights).mean([lon_name, lat_name])


def compute_T_wh(x, ERA=None, lon_bounds=(287.5, 293.5), lat_bounds=(39, 44)):
    """Compute Woods Hole temperature index.

    The index is the cos(latitude)-weighted mean (via ``spatial_avg``) of
    ``x`` inside the box ``lon_bounds`` x ``lat_bounds`` (default:
    287.5-293.5 degrees E, 39-44 degrees N).

    Arguments:
        - x: data with ``longitude``/``latitude`` or ``lon``/``lat`` coords
        - ERA: coordinate-naming override. True -> ``longitude``/``latitude``;
          False -> ``lon``/``lat``; None (default) -> detected from
          ``x.coords``
        - lon_bounds, lat_bounds: (min, max) box limits passed to ``slice``;
          like any label slice this assumes ascending coordinates
    """
    if ERA is None:
        ERA = "longitude" in x.coords

    if ERA:
        lonlat_idx = dict(longitude=slice(*lon_bounds), latitude=slice(*lat_bounds))
    else:
        lonlat_idx = dict(lon=slice(*lon_bounds), lat=slice(*lat_bounds))

    ## get subset of data inside the box
    data_subset = x.sel(lonlat_idx)

    return spatial_avg(data_subset)


def get_empirical_pdf(x, bin_edges=None):
    """
    Estimate the empirical probability density function of ``x``.

    The result is a histogram scaled so that sum(pdf * bin_width) == 1 over
    the bins (values outside the bins are not counted). When ``bin_edges``
    is None, numpy's default binning is used.
    Returns (pdf, bin_edges).
    """
    if bin_edges is not None:
        counts = np.histogram(x, bins=bin_edges)[0]
    else:
        counts, bin_edges = np.histogram(x)

    widths = bin_edges[1:] - bin_edges[:-1]
    total_area = (counts * widths).sum()

    return counts / total_area, bin_edges


def get_gaussian_best_fit(x):
    """Fit a normal distribution to ``x`` and evaluate its PDF.

    The fit uses ``x.mean()`` and ``x.std()`` (ddof=0 for numpy arrays). The
    PDF is evaluated at 50 evenly spaced points on the symmetric interval
    [-max|x|, max|x|].
    Returns (pdf_values, evaluation_points).
    """
    fitted = scipy.stats.norm(loc=x.mean(), scale=x.std())

    half_width = np.max(np.abs(x))
    points = np.linspace(-half_width, half_width)

    return fitted.pdf(points), points


# Compute JAS (July-August-September) seasonal averages
def get_jas_averages(data):
    """Mean over July, August and September for each year of monthly data.

    The ``time`` dimension is replaced by ``year`` (from ``time.year``).
    NOTE(review): months are read from the time labels as-is; if a dataset
    stamps monthly means at the end of the averaging period (e.g. the 1st of
    the following month), this picks a shifted season -- confirm per dataset.
    """
    month = data.time.dt.month
    in_jas = (month >= 7) & (month <= 9)

    return data.isel(time=in_jas).groupby("time.year").mean("time")


def make_cb_range(amp, delta):
    """Make colorbar levels for cmo.balance that skip zero.

    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar
    Returns levels from -amp up to (excluding) 0, followed by levels from
    delta up to about amp, all spaced by delta.
    """
    negative = np.arange(-amp, 0, delta)
    positive = np.arange(delta, amp + delta, delta)

    return np.concatenate((negative, positive))


def get_trend(data, dim="time", deg=1):
    """
    Return the least-squares polynomial of degree ``deg`` fitted to ``data``
    along ``dim``, evaluated at the original ``dim`` coordinate values.
    """
    coefficients = data.polyfit(dim=dim, deg=deg)["polyfit_coefficients"]

    return xr.polyval(data[dim], coefficients)


def detrend(data, dim="time", deg=1):
    """
    Return ``data`` minus its degree-``deg`` polynomial trend along ``dim``.
    """
    trend = get_trend(data, dim=dim, deg=deg)

    return data - trend


def detrend_by_month(data, deg=1):
    """
    Remove a degree-``deg`` polynomial trend along ``time`` separately for
    each calendar month (groups taken from ``time.month``).
    """
    by_month = data.groupby("time.month")

    return by_month.map(detrend, dim="time", deg=deg)

### 1. Timeseries of $T_{wh}$

In [None]:
# Compute T_wh for both datasets
# (area-weighted mean over the Woods Hole box, one value per month)
print("Computing Woods Hole temperature index...")
t_wh_era5 = compute_T_wh(era5_data)
t_wh_cesm = compute_T_wh(cesm_data)

# Select time period 1979-2006 first
t_wh_era5_1979_2006 = t_wh_era5.sel(time=slice("1979-01-01", "2006-12-31"))
t_wh_cesm_1979_2006 = t_wh_cesm.sel(time=slice("1979-01-01", "2006-12-31"))

# Compute JAS averages for both datasets (one value per year)
era5_jas = get_jas_averages(t_wh_era5_1979_2006)
cesm_jas = get_jas_averages(t_wh_cesm_1979_2006)

# Create timeseries plot: ERA5 as blue circles, CESM as red squares
plt.figure(figsize=(5, 3), layout="constrained")
plt.plot(
    era5_jas.year,
    era5_jas.values,
    "b-o",
    linewidth=2,
    markersize=6,
    alpha=0.8,
    label="ERA5",
)
plt.plot(
    cesm_jas.year,
    cesm_jas.values,
    "r-s",
    linewidth=2,
    markersize=6,
    alpha=0.8,
    label="CESM",
)

plt.xlabel("Year")
plt.ylabel("JAS Temperature (°C)")
plt.title("Woods Hole Temperature Index (JAS season; 1979-2006)")
plt.legend()
plt.grid(True, alpha=0.3)

# ## save to file
# plt.savefig("figs/timeseries.svg")

plt.show()

# Print some statistics (computed over the yearly JAS means)
print(f"\nERA5 JAS statistics (1979-2006):")
print(f"  Mean: {era5_jas.mean().values:.2f}°C")
print(f"  Std:  {era5_jas.std().values:.2f}°C")
print(f"  Min:  {era5_jas.min().values:.2f}°C")
print(f"  Max:  {era5_jas.max().values:.2f}°C")

print(f"\nCESM JAS statistics (1979-2006):")
print(f"  Mean: {cesm_jas.mean().values:.2f}°C")
print(f"  Std:  {cesm_jas.std().values:.2f}°C")
print(f"  Min:  {cesm_jas.min().values:.2f}°C")
print(f"  Max:  {cesm_jas.max().values:.2f}°C")

### 2. Histogram of $T_{wh}$

In [None]:
## Preprocess: remove linear trend (fitted separately for each calendar
## month, which also removes the mean seasonal cycle)
cesm_detrend = detrend_by_month(t_wh_cesm_1979_2006)
era5_detrend = detrend_by_month(t_wh_era5_1979_2006)

## Compute histograms
## Bin edges are shifted by half a bin width so that one bin is centred on
## zero; the cloud (2 m temperature) data get wider bins than the server SST.
if LOAD_FROM_CLOUD:
    edges = np.arange(-4.5, 4.5, 0.5) + 0.25
else:
    edges = np.arange(-2.4, 2.4, 0.3) + 0.15

pdf_era5, _ = get_empirical_pdf(era5_detrend, bin_edges=edges)
pdf_cesm, _ = get_empirical_pdf(cesm_detrend, bin_edges=edges)

#### Plot result
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

## plot histogram (ERA5 filled, CESM as an outline)
ax.stairs(values=pdf_era5, edges=edges, label="ERA5", fill=True, alpha=0.3)
ax.stairs(values=pdf_cesm, edges=edges, label="CESM")


## plot zero line
ax.axvline(0, ls="--", c="k", lw=0.8)

## label (an anomaly in K has the same size as one in °C)
ax.set_xlabel(r"$K$ anomaly")
ax.set_ylabel("Prob. density")

ax.legend(prop=dict(size=8))

# ## save to file
# fig.savefig("figs/histogram.svg")

plt.show()

### 3. Spatial bias

In [None]:
# Align datasets to common time period (1979-2006)
era5_common = era5_data.sel(time=slice("1979-01-01", "2006-12-31"))
cesm_common = cesm_data.sel(time=slice("1979-01-01", "2006-12-31"))

# Compute JAS averages for both datasets, then average over the years
era5_jas_mean = get_jas_averages(era5_common).mean("year")
cesm_jas_mean = get_jas_averages(cesm_common).mean("year")

# Compute error (CESM - ERA5): positive values mean CESM is warmer
error = cesm_jas_mean - era5_jas_mean

# Create spatial error plot with shared colorbar
fig = plt.figure(figsize=(20, 6))

# Create GridSpec for better control of subplot layout
# (the narrow 4th column is reserved, but the colorbars below are placed
# with fig.add_axes at fixed figure coordinates instead)
gs = fig.add_gridspec(1, 4, width_ratios=[1, 1, 1, 0.1], wspace=0.3)

# ERA5 mean
ax1 = fig.add_subplot(gs[0], projection=ccrs.PlateCarree())
im1 = ax1.contourf(
    era5_jas_mean.longitude,
    era5_jas_mean.latitude,
    era5_jas_mean.values,
    levels=np.arange(-3, 33, 3),
    cmap="cmo.thermal",
    transform=ccrs.PlateCarree(),
)
ax1.set_title("ERA5 Mean Temperature (1979-2006)")
ax1.coastlines()

# CESM mean (same levels as ERA5, so one colorbar serves both panels)
ax2 = fig.add_subplot(gs[1], projection=ccrs.PlateCarree())
im2 = ax2.contourf(
    cesm_jas_mean.longitude,
    cesm_jas_mean.latitude,
    cesm_jas_mean.values,
    levels=np.arange(-3, 33, 3),
    cmap="cmo.thermal",
    transform=ccrs.PlateCarree(),
)
ax2.set_title("CESM Mean Temperature (1979-2006)")
ax2.coastlines()

# Error (CESM - ERA5), with levels that skip zero
ax3 = fig.add_subplot(gs[2], projection=ccrs.PlateCarree())
im3 = ax3.contourf(
    error.longitude,
    error.latitude,
    error.values,
    levels=make_cb_range(5, 0.5),
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)
ax3.set_title("Error: CESM - ERA5")
ax3.coastlines()

# Shared colorbar for ERA5 and CESM (left side)
cbar_ax1 = fig.add_axes([0.05, 0.25, 0.02, 0.5])  # [left, bottom, width, height]
cbar1 = plt.colorbar(im1, cax=cbar_ax1, orientation="vertical")
cbar1.set_label("Temperature (°C)")

# Separate colorbar for error (right side)
cbar_ax2 = fig.add_axes([0.85, 0.25, 0.02, 0.5])  # [left, bottom, width, height]
cbar2 = plt.colorbar(im3, cax=cbar_ax2, orientation="vertical")
cbar2.set_label("Temperature Difference (°C)")

# ## save to file
# fig.savefig("figs/spatial-bias.svg")

plt.show()

The mean temperature plots look fairly similar; however, the error plot reveals some differences between the datasets.

***Note:*** The errors differ between the server data (sea surface temperature) and the cloud data (2 m air temperature). If you are able to, it is worth loading both to look at the differences.

***Potential extension***: How does this error behave for another region or variable?

### 4. Histogram of gridcell-level bias

In [None]:
# Define the "Cape Cod" region used below; note these bounds are identical to
# the Woods Hole box used for T_wh (287.5-293.5 E, 39-44 N)
wh_lon = slice(287.5, 293.5)
wh_lat = slice(39, 44)

# Extract Cape Cod subset
error_wh = error.sel(longitude=wh_lon, latitude=wh_lat)

# Flatten both datasets
error_flat = error.values.flatten()
error_wh_flat = error_wh.values.flatten()

# Remove NaN values (masked land / missing points)
error_flat = error_flat[~np.isnan(error_flat)]
error_wh_flat = error_wh_flat[~np.isnan(error_wh_flat)]

# Create bin edges for both histograms (use same range for comparison).
# The Cape Cod values are a subset of the whole region, so the combined
# range equals the whole-region range.
all_errors = np.concatenate([error_flat, error_wh_flat])
bin_edges = np.linspace(all_errors.min(), all_errors.max(), 50)

# Compute PDFs with get_empirical_pdf (defined above)
error_pdf, error_edges = get_empirical_pdf(error_flat, bin_edges=bin_edges)
error_wh_pdf, error_wh_edges = get_empirical_pdf(error_wh_flat, bin_edges=bin_edges)

# Compute Gaussian fits (evaluated on [-max|error|, max|error|])
error_gauss, error_gauss_pts = get_gaussian_best_fit(error_flat)
error_cape_cod_gauss, error_cape_cod_gauss_pts = get_gaussian_best_fit(error_wh_flat)

# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Plot 1: Whole Region
ax1.stairs(
    values=error_pdf,
    edges=error_edges,
    color="blue",
    alpha=0.7,
    label="Whole Region",
    linewidth=2,
)
ax1.plot(
    error_gauss_pts, error_gauss, "b--", linewidth=1.5, alpha=0.8, label="Gaussian Fit"
)
ax1.set_xlabel("Temperature Error (°C)")
ax1.set_ylabel("Probability Density")
ax1.set_title("Error Distribution: Whole Region (CESM - ERA5)")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Cape Cod Region
ax2.stairs(
    values=error_wh_pdf,
    edges=error_wh_edges,
    color="red",
    alpha=0.7,
    label="Cape Cod Region",
    linewidth=2,
)
ax2.plot(
    error_cape_cod_gauss_pts,
    error_cape_cod_gauss,
    "r--",
    linewidth=1.5,
    alpha=0.8,
    label="Gaussian Fit",
)
ax2.set_xlabel("Temperature Error (°C)")
ax2.set_ylabel("Probability Density")
ax2.set_title("Error Distribution: Cape Cod Region (CESM - ERA5)")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

What do you think causes the bias? Try plotting the dataset without contours (i.e., make a simpler plot than the one above — you can use xarray's built-in plotting method, `ds.plot()`).