# Overview
__In this notebook, we'll:__
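- Walk through an example of model validation: comparing 2 m air temperature from the CESM2 historical simulation against the ERA5 reanalysis over the North Atlantic, for the overlapping period 1979–2014.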

# Preliminaries

#### Specify filepath to data

In [None]:
## filepath for the mount. Override by setting the CMIP6_FP environment
## variable (e.g. when the data is mounted somewhere other than /Volumes).
## `os` is imported here because this cell runs before the main imports cell.
import os

cmip6_fp = os.environ.get("CMIP6_FP", "/Volumes/cmip6/data")

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
import glob
import cftime
import src.utils
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.dates as mdates
import xesmf as xe

## plotting style: seaborn defaults, but with a plain white axes
## background and grid lines switched off
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## initialize random number generator with a fixed seed, so any random
## sampling is reproducible under "Restart & Run All".
## Set RNG_SEED = None to draw fresh entropy on every run instead.
RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)

#### Load data

In [None]:
## function to subset longitude/latitude
def trim_to_north_atl(x, lon_range=(258.5, 318.5), lat_range=(60, 20)):
    """Trim data to the North Atlantic / eastern North America region.

    Arguments:
        - x: xarray object with 'longitude' and 'latitude' coordinates.
        - lon_range: (west, east) longitude bounds in degrees east on a
            0-360 grid. The default (258.5, 318.5) is roughly 101.5W-41.5W.
        - lat_range: (start, stop) latitude bounds, used as a label slice.
            The default (60, 20) runs north-to-south, so it assumes latitude
            is stored in descending order (as for ERA5 here); data with
            ascending latitude needs (20, 60) or must be flipped first.

    Returns the subset of 'x' inside the given bounds.

    The region arguments have defaults, so the function can still be passed
    directly as `preprocess=` to `xr.open_mfdataset` (which passes only the
    dataset).
    """
    return x.sel(longitude=slice(*lon_range), latitude=slice(*lat_range))

Load in ERA5 data

In [None]:
## specify file path to ERA5 reanalysis product.
## We'll look at surface temperature
era5_path = f"{cmip6_fp}/era5/reanalysis/single-levels/monthly-means/2m_temperature"

## Collect every netCDF file in the folder (sorted, so the order is deterministic)
file_list = sorted(glob.glob(f"{era5_path}/*.nc"))
if not file_list:
    raise FileNotFoundError(f"No ERA5 files found in {era5_path}")

## Load in the data, trimming each file to the region as it is opened
T2m_era = xr.open_mfdataset(file_list, preprocess=trim_to_north_atl)["t2m"]

## pull the (regional) subset into memory
T2m_era = T2m_era.load()

Load in CESM2 data

In [None]:
cesm_path = f"{cmip6_fp}/cmip6/CMIP/NCAR/CESM2/historical/r1i1p1f1/Amon/tas/gn/1"
cesm_fname = "tas_Amon_CESM2_historical_r1i1p1f1_gn_185001-201412.nc"

## open monthly near-surface air temperature, using ERA5's coordinate names
T2m_cesm = xr.open_dataset(f"{cesm_path}/{cesm_fname}")["tas"].rename(
    {"lon": "longitude", "lat": "latitude"}
)

## flip the latitude order to match ERA5, so the north-to-south latitude
## slice used in trim_to_north_atl selects the same region for both datasets
latitude_reversed = T2m_cesm.latitude.values[::-1]
T2m_cesm = trim_to_north_atl(T2m_cesm.reindex(latitude=latitude_reversed))

## pull the regional subset into memory
T2m_cesm = T2m_cesm.load()

Trim both datasets to overlapping period (1979-2014)

In [None]:
## restrict both datasets explicitly to the common 1979-2014 window, instead
## of relying on where each record happens to start or end
T2m_era = T2m_era.sel(time=slice("1979", "2014"))
T2m_cesm = T2m_cesm.sel(time=slice("1979", "2014"))

## the time overwrite below is only valid if both records are the same length
## and line up month by month, so check that explicitly
assert T2m_era.sizes["time"] == T2m_cesm.sizes["time"], (
    f"time lengths differ: ERA5={T2m_era.sizes['time']}, "
    f"CESM2={T2m_cesm.sizes['time']}"
)
assert (T2m_era.time.dt.year.values == T2m_cesm.time.dt.year.values).all()
assert (T2m_era.time.dt.month.values == T2m_cesm.time.dt.month.values).all()

## for convenience, set T2m_cesm's time to match T2m_era, so the two arrays
## align directly in arithmetic.
## NOTE(review): CESM2 time stamps are presumably cftime objects (cftime is
## imported above); this replaces them with ERA5's time stamps.
T2m_cesm["time"] = T2m_era.time

#### Plot climatology

In [None]:
## First, a generic plot setup function
def plot_setup(ax, lon_range, lat_range, xticks, yticks, scale):
    """
    Create map background for plotting spatial data.

    Arguments:
        - ax: cartopy GeoAxes to draw on (the plot "canvas"); it must be
            created with a `projection=` so it supports set_extent,
            coastlines and gridlines.
        - lon_range/lat_range: 2-element sequences giving the plot
            boundaries, in degrees (interpreted in PlateCarree coordinates).
        - xticks/yticks: longitudes/latitudes at which gridlines and their
            labels are drawn.
        - scale: multiplier for grid transparency, line widths and label
            font size.

    Returns:
        (ax, gl): the modified axes and its gridliner object. Labels are drawn
        on the top and left edges only, and callers can customize them further
        through `gl`.
    """

    ## specify transparency/linewidths. Alpha is capped at 1 because
    ## matplotlib rejects alpha values above 1, so a large 'scale' would
    ## otherwise fail.
    grid_alpha = min(0.1 * scale, 1.0)
    grid_linewidth = 0.5 * scale
    coastline_linewidth = 0.3 * scale
    label_size = 8 * scale

    ## crop map and plot coastlines
    ax.set_extent([*lon_range, *lat_range], crs=ccrs.PlateCarree())
    ax.coastlines(linewidth=coastline_linewidth)

    ## plot grid
    gl = ax.gridlines(
        draw_labels=True,
        linestyle="--",
        alpha=grid_alpha,
        linewidth=grid_linewidth,
        color="k",
        zorder=1.05,
    )

    ## add tick labels (top and left only)
    gl.bottom_labels = False
    gl.right_labels = False
    gl.xlabel_style = {"size": label_size}
    gl.ylabel_style = {"size": label_size}
    gl.ylocator = mticker.FixedLocator(yticks)
    gl.xlocator = mticker.FixedLocator(xticks)

    return ax, gl


## Next, a function to plot the North Atlantic
def plot_setup_north_atl(ax, scale=1):
    """Create a North Atlantic map background (102.5W-42.5W, 20N-60N).

    Arguments:
        - ax: cartopy GeoAxes to draw on.
        - scale: passed through to plot_setup; scales grid transparency,
            line widths and label font size.

    Returns:
        (ax, gl): the modified axes and its gridliner object.
    """

    ## specify range and ticklabels for plot (degrees; negative = west)
    lon_range = [-102.5, -42.5]
    lat_range = [60, 20]
    xticks = [-90, -75, -60]
    yticks = [25, 40, 55]

    ax, gl = plot_setup(ax, lon_range, lat_range, xticks, yticks, scale)

    return ax, gl

In [None]:
## specify colorbar levels (K): 266, 270, ..., 302
levels = np.arange(266, 306, 4)

## make plot: one map panel per dataset, sharing the same contour levels
fig = plt.figure(figsize=(8, 4))
map_axes = []

for i, (data, label) in enumerate(zip([T2m_era, T2m_cesm], ["ERA5", "CESM2"]), start=1):

    ## set up the map panel for this dataset
    ax = fig.add_subplot(1, 2, i, projection=ccrs.PlateCarree())
    ax, gl = plot_setup_north_atl(ax)
    ax.set_title(f"{label}")
    map_axes.append(ax)

    ## plot time-mean temperature over the overlapping period
    plot = ax.contourf(
        data.longitude,
        data.latitude,
        data.mean("time"),
        cmap="cmo.thermal",
        levels=levels,
        extend="both",
    )

## both panels use identical levels and colormap, so one shared colorbar
## describes both
cb = fig.colorbar(
    plot, ax=map_axes, orientation="horizontal", shrink=0.6, label=r"$T_{2m}$ (K)"
)
plt.show()

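#### Look at bias: differences in means and standard deviations
First, regrid (interpolate) ERA5 onto the CESM2 grid, so the two fields can be differenced point by point.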

In [None]:
## Interpolate ERA5 onto CESM2's latitude/longitude points.
## NOTE(review): this is xarray's default linear point interpolation, not
## area-averaging; xesmf (imported above) could do conservative regridding.
T2m_era_regrid = T2m_era.interp(
    latitude=T2m_cesm.latitude, longitude=T2m_cesm.longitude
)

## time-mean difference (CESM2 minus ERA5)
bias = T2m_cesm.mean("time") - T2m_era_regrid.mean("time")

## difference in standard deviation over all months (CESM2 minus ERA5).
## The seasonal cycle is not removed first, so this includes differences in
## seasonal-cycle amplitude.
bias_std = T2m_cesm.std("time") - T2m_era_regrid.std("time")

Plot mean bias

In [None]:
## make plot
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)

## `bias` is CESM2's mean minus regridded ERA5's mean, so positive values
## mean CESM2 is warmer; the title states that ordering
ax.set_title(r"$\mu$ bias (CESM2 – ERA5)")

## plot data on a diverging colormap
## (levels presumably symmetric in [-5, 5] with 0.5 spacing — see
## src.utils.make_cb_range)
plot = ax.contourf(
    bias.longitude,
    bias.latitude,
    bias,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

cb = fig.colorbar(plot, ticks=[-5, 0, 5], label=r"$T_{2m}$ bias (K)")
plt.show()

Plot standard deviation bias

In [None]:
## make plot
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax, gl = plot_setup_north_atl(ax)

## `bias_std` is CESM2's std. dev. minus regridded ERA5's std. dev., so
## positive values mean CESM2 is more variable; the title states that ordering
ax.set_title(r"$\sigma$ bias (CESM2 – ERA5)")

## plot data on a diverging colormap
## (levels presumably symmetric in [-5, 5] with 0.5 spacing — see
## src.utils.make_cb_range)
plot = ax.contourf(
    bias_std.longitude,
    bias_std.latitude,
    bias_std,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

cb = fig.colorbar(plot, ticks=[-5, 0, 5], label=r"$T_{2m}$ std. dev. bias (K)")
plt.show()

## Get temperature near Woods Hole (41.5°N, 288.5°E, i.e. 71.5°W)

In [None]:
## interpolate both datasets to the Woods Hole location (41.5N, 288.5E)
T2m_era_wh, T2m_cesm_wh = (
    da.interp(longitude=288.5, latitude=41.5) for da in (T2m_era, T2m_cesm)
)

Look at the seasonal cycle (monthly means) and the ±1 standard deviation spread for each month.

In [None]:
def plot_clim(ax, x, label):
    """Plot the monthly-mean seasonal cycle of 'x' with a +/- 1 std. dev. band.

    Arguments:
        - ax: Matplotlib axes to draw on.
        - x: data with a 'time' coordinate; it is grouped by calendar month.
        - label: legend label for the mean line.

    Returns the modified 'ax'.
    """
    by_month = x.groupby("time.month")
    clim_mean = by_month.mean()
    clim_std = by_month.std()

    ## draw the mean line and remember its color for the shaded band
    line_color = ax.plot(clim_mean.month, clim_mean, label=label)[0].get_color()

    ## shade mean +/- one standard deviation in the same color
    upper = clim_mean + clim_std
    lower = clim_mean - clim_std
    ax.fill_between(clim_mean.month, upper, lower, color=line_color, alpha=0.2)

    return ax


## seasonal cycle at Woods Hole for both datasets on shared axes
fig, ax = plt.subplots(figsize=(4, 3))
ax = plot_clim(ax, T2m_era_wh, label="ERA5")
ax = plot_clim(ax, T2m_cesm_wh, label="CESM2")
ax.legend()
ax.set_xticks([3, 6, 9, 12], labels=["Mar", "Jun", "Sep", "Dec"])
ax.set_ylabel(r"$T_{2m}$ (K)")
ax.set_title(r"Woods Hole: monthly mean $\pm 1\sigma$")
plt.show()

#### Detrend each calendar month separately, and look at histograms of the anomalies

In [None]:
def get_trend(data, dim="time"):
    """Return the least-squares linear fit of an xr.DataArray along 'dim'.

    The result is the fitted line evaluated at each coordinate value of 'dim'.
    """
    coefs = data.polyfit(dim=dim, deg=1)["polyfit_coefficients"]
    return xr.polyval(data[dim], coefs)


def detrend(data, dim="time"):
    """Remove the linear trend of 'data' along dimension 'dim'.

    Returns 'data' minus its least-squares linear fit along 'dim'
    (as computed by get_trend).
    """
    ## forward 'dim' so detrending along a non-time dimension works as requested
    return data - get_trend(data, dim=dim)


def detrend_by_month(data):
    """Detrend 'data' separately for each calendar month.

    Each month's values (e.g. all Januaries) get their own linear trend removed.
    """
    monthly_groups = data.groupby("time.month")
    return monthly_groups.map(detrend)


## remove each calendar month's own linear trend from both point time series
T2m_era_wh_detrend, T2m_cesm_wh_detrend = (
    detrend_by_month(series) for series in (T2m_era_wh, T2m_cesm_wh)
)

Make histogram

In [None]:
def plot_histogram_comparison(
    ax, samples0, samples1, label0=None, label1=None, bin_width=0.5, bin_max=5.25
):
    """
    Compute two histograms on shared bins, one each for samples0 and samples1,
    and plot them on 'ax'. samples0 is drawn as an unfilled black outline
    labelled 'label0'; samples1 as a translucent filled black histogram
    labelled 'label1'.

    Arguments:
        - ax: Matplotlib axes to draw on.
        - samples0/samples1: array-likes accepted by np.histogram.
        - label0/label1: legend labels for the two histograms.
        - bin_width: width of each bin (default 0.5). Must be positive.
        - bin_max: bin edges start at -bin_max and step by bin_width up to
            at most +bin_max (default 5.25). Samples outside the outermost
            edges are not counted.

    Returns the modified 'ax'.
    """
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width}")

    ## First, make the histograms on shared, symmetric bins. Stopping half a
    ## step past bin_max keeps +bin_max as the last edge without the
    ## float-rounding ambiguity of stopping exactly one step past it.
    bin_edges = np.arange(-bin_max, bin_max + bin_width / 2, bin_width)

    # compute histograms
    hist0, _ = np.histogram(samples0, bins=bin_edges)
    hist1, _ = np.histogram(samples1, bins=bin_edges)

    ## plot histograms
    ax.stairs(values=hist0, edges=bin_edges, color="k", label=label0)
    ax.stairs(
        values=hist1,
        edges=bin_edges,
        color="k",
        label=label1,
        fill=True,
        alpha=0.3,
    )

    ## label plot
    ax.set_ylabel("Count")
    ax.set_xlabel(r"$T_{2m}$ anomaly ($^{\circ}C$)")

    return ax

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))
ax = plot_histogram_comparison(
    ax,
    T2m_era_wh_detrend,
    T2m_cesm_wh_detrend,
    label0="ERA5",
    label1="CESM2",
)
## dotted zero line; it is white, so it shows where it crosses the filled bars
ax.axvline(0, c="w", ls=":")
ax.legend()
plt.show()

#### Get the JJA (June–August) average for each year

In [None]:
def get_jja_avg(data):
    """Return the June-July-August mean of 'data' for each year.

    Timesteps outside June-August are dropped; the remaining ones are averaged
    within each calendar year, giving a result indexed by 'year'.
    """
    in_jja = data.time.dt.month.isin([6, 7, 8])
    summer_only = data.sel(time=in_jja)
    return summer_only.groupby("time.year").mean()


## yearly JJA-mean temperature at Woods Hole for both datasets
T2m_era_wh_jja, T2m_cesm_wh_jja = (
    get_jja_avg(series) for series in (T2m_era_wh, T2m_cesm_wh)
)