# Intermodel comparison example
{download}`Download notebook<./example.ipynb>`

## Packages

In [None]:
import xarray as xr
import cftime
import os.path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
import pathlib
import tqdm
import warnings
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
import copy

## (optional, cosmetic) apply seaborn settings with a white axes background
## and with gridlines switched off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## Functions

In [None]:
def trim(data, lon_range, lat_range):
    """Subset 'data' to the rows/columns that overlap a lon/lat box.

    Args:
        - 'data': object with 'lon' and 'lat' entries defined on the
            'nlat'/'nlon' dimensions
        - 'lon_range': [min, max] longitude bounds (inclusive)
        - 'lat_range': [min, max] latitude bounds (inclusive)

    Returns 'data' restricted (via 'isel') to every 'nlon' and 'nlat'
    index holding at least one point inside the box. Because whole
    rows/columns are kept, the result may still include points that
    fall outside the requested range.
    """

    lon = data["lon"]
    lat = data["lat"]

    ## flag points inside each (inclusive) range
    lon_ok = (lon_range[0] <= lon) & (lon <= lon_range[1])
    lat_ok = (lat_range[0] <= lat) & (lat <= lat_range[1])

    ## a point is inside the box only if both flags are set
    inside = lon_ok & lat_ok

    ## pull the mask into memory before reducing it
    inside.load()

    ## keep each column/row that contains at least one inside point
    indexers = {
        "nlon": inside.any("nlat"),
        "nlat": inside.any("nlon"),
    }

    return data.isel(indexers)


def plot_setup(
    fig, projection, lon_range, lat_range, posn=(1, 1, 1), xticks=None, yticks=None
):
    """Add a map subplot to the figure and return its Axes object.

    Args:
        - 'fig': figure to add the subplot to
        - 'projection': cartopy projection used for the subplot
        - 'lon_range': [min, max] longitude extent of the map
        - 'lat_range': [min, max] latitude extent of the map
        - 'posn': subplot position, unpacked into 'fig.add_subplot'
        - 'xticks': longitudes at which to draw/label gridlines (optional)
        - 'yticks': latitudes at which to draw/label gridlines (optional)

    Gridlines (labelled on the bottom/left only) are drawn when 'xticks'
    and/or 'yticks' is given; a fixed locator is set only for the axis
    whose ticks were supplied.

    Note: 'projection.threshold' is divided by 1000 *in place*, so passing
    the same projection object to several calls compounds the refinement.
    """

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(*posn, projection=projection)

    ## Subset to given region
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## nothing more to do if no ticks were requested
    if xticks is None and yticks is None:
        return ax

    ## add lon/lat gridlines and labels
    gl = ax.gridlines(
        draw_labels=True,
        linestyle="-",
        alpha=0.1,
        linewidth=0.5,
        color="k",
        zorder=1.05,
    )

    ## specify which axes to label
    gl.top_labels = False
    gl.right_labels = False

    ## specify ticks (only for the axes that were given; passing None to
    ## FixedLocator is not valid)
    if yticks is not None:
        gl.ylocator = mticker.FixedLocator(yticks)
    if xticks is not None:
        gl.xlocator = mticker.FixedLocator(xticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """
    Draw an unfilled rectangle around the specified lon/lat range on
    the given ax object (corners are interpreted in PlateCarree
    coordinates). 'c' sets the edge color. Returns the ax object.
    """

    ## lower-left corner of the box
    lon_min = lon_range[0]
    lat_min = lat_range[0]

    ## build the outline (no fill, thin edge)
    box = mpatches.Rectangle(
        xy=[lon_min, lat_min],
        height=lat_range[1] - lat_min,
        width=lon_range[1] - lon_min,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )

    ## attach it to the axes
    ax.add_patch(box)

    return ax


def plot_setup_atlantic(fig, posn=(1, 1, 1)):
    """Add an Atlantic-region map subplot to 'fig' at subplot position
    'posn' and return its ax object. The figure size is left untouched."""

    ## orthographic projection centered on (50W, 40N), spanning
    ## 90W-10W and 20N-60N, with a few labelled gridlines
    return plot_setup(
        fig,
        ccrs.Orthographic(central_longitude=-50, central_latitude=40),
        lon_range=[-90, -10],
        lat_range=[20, 60],
        xticks=[-80, -50, -20],
        yticks=[25, 45],
        posn=posn,
    )


def plot_setup_woodshole(fig):
    """Set up a zoomed-in map around Woods Hole.

    Resizes 'fig' to 5 x 3 inches, adds a single map subplot and
    returns the (fig, ax) pair.
    """

    ## resize the figure
    fig.set_size_inches(5, 3)

    ## orthographic projection centered on (67.5W, 40N), spanning
    ## 80W-60W and 35N-45N; built with the generic plotting function
    ax = plot_setup(
        fig,
        ccrs.Orthographic(central_longitude=-67.5, central_latitude=40),
        lon_range=[-80, -60],
        lat_range=[35, 45],
        xticks=[-80, -70, -60],
        yticks=[35, 40, 45],
    )

    return fig, ax


def make_cb_range(amp, delta):
    """Build symmetric colorbar levels that skip zero (for cmo.balance).

    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar

    Returns a 1-D array running from -amp up to (but excluding) 0,
    followed by delta up to (but excluding) amp + delta.
    """

    ## levels below zero
    negative = np.arange(-amp, 0, delta)

    ## levels above zero
    positive = np.arange(delta, amp + delta, delta)

    return np.concatenate([negative, positive])


def get_filepath(server_fp, model_center, model):
    """Build the directory path holding output for the given model.

    The path is assembled as
    <server_fp>/cmip6/data/cmip6/CMIP/<model_center>/<model>/<suffix>,
    where the suffix (1pctCO2/r1i1p1f1/Omon/tos/gn/1) is shared by
    all models.
    """

    ## root of the cmip6 archive on the server
    cmip_root = server_fp / pathlib.Path("cmip6/data/cmip6/CMIP")

    ## model-specific directory, followed by the shared suffix
    model_dir = pathlib.Path(cmip_root, model_center, model)

    return model_dir / pathlib.Path("1pctCO2/r1i1p1f1/Omon/tos/gn/1")


def get_grid_mask(data, x_coord="lon", y_coord="lat"):
    """Return a 0/1 mask marking where 'tos' is not NaN
    (useful for xesmf regridder).

    If 'data' has a 'time' dimension, only the first timestep is
    inspected. 'x_coord' and 'y_coord' are accepted but currently unused.
    """

    ## work on a single snapshot, with time removed as a coordinate
    has_time = "time" in data.dims
    snapshot = data.isel(time=0).drop_vars("time") if has_time else data

    ## 1 where 'tos' holds a valid number, 0 where it is NaN
    is_valid = ~np.isnan(snapshot["tos"])

    return xr.where(is_valid, 1, 0).compute()


def get_grid(server_fp, lon_range, lat_range):
    """Get the target grid, based on the lon/lat grid of CESM2.

    Args:
        - 'server_fp': path to the file server (see 'get_filepath')
        - 'lon_range': [min, max] longitude range to keep
        - 'lat_range': [min, max] latitude range to keep

    Returns a dataset (loaded into memory) holding 'lon', 'lat' and a
    0/1 'mask' variable marking where 'tos' is not NaN.

    Raises FileNotFoundError if no '*.nc' file exists for CESM2.
    """

    ## get filepath
    fp = get_filepath(server_fp, model_center="NCAR", model="CESM2")

    ## pick the first file in sorted order, so the choice is deterministic
    filenames = sorted(fp.glob("*.nc"))
    if not filenames:
        raise FileNotFoundError(f"No '*.nc' files found for CESM2 in {fp}")
    filename = filenames[0]

    ## load first timestep for cesm2 (ignore serialization warning)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)
        data = xr.open_dataset(filename).isel(time=0)

    ## drop time as a coord
    data = data.drop_vars("time")

    ## trim in lon/lat space
    data = trim(data, lon_range, lat_range)

    ## get grid
    grid = data[["lon", "lat"]]

    ## mask NaN values
    grid["mask"] = get_grid_mask(data)

    return grid.compute()


def load_model_fromserver(server_fp, grid, model_center, model, **load_kwargs):
    """Load 'tos' for the given model from the server and pre-process it.

    Steps: open all of the model's files concatenated along 'time',
    regrid bilinearly onto 'grid' (with a NaN mask attached), relabel the
    time axis as monthly steps starting in 1850-01, and average by year.

    Args:
        - 'server_fp': path to the file server (see 'get_filepath')
        - 'grid': target grid passed to the xesmf regridder
        - 'model_center': name of the modeling center
        - 'model': name of the model
        - '**load_kwargs': extra keyword arguments for
            'xr.open_mfdataset'; these override the defaults below

    Returns the annual-mean 'tos', renamed to '<model_center>_<model>'.

    Raises FileNotFoundError if no '*.nc' file exists for the model.
    """

    ## get filepath
    fp = get_filepath(server_fp, model_center=model_center, model=model)

    ## default args for loading data (caller-supplied args take precedence)
    open_kwargs = dict(
        decode_times=False,
        mask_and_scale=True,
        chunks=dict(time=3000),
        coords="minimal",
        data_vars="minimal",
        concat_dim="time",
        combine="nested",
        compat="override",
    )
    open_kwargs.update(load_kwargs)

    ## With combine="nested" the files are concatenated in the order given,
    ## and glob makes no ordering guarantee, so sort explicitly.
    ## NOTE(review): assumes sorted filenames are in chronological order
    ## (e.g., zero-padded date ranges in the names) -- confirm for new models
    filenames = sorted(fp.glob("*.nc"))
    if not filenames:
        raise FileNotFoundError(f"No '*.nc' files found in {fp}")

    ## load data (ignore serialization warning)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)
        data = xr.open_mfdataset(filenames, **open_kwargs).compute()

    ## get mask of NaN values
    data["mask"] = get_grid_mask(data)

    ## regrid
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        regridder = xe.Regridder(data, grid, "bilinear", ignore_degenerate=True)
        data = regridder(data)

    ## reset time axis
    time = pd.date_range(start="1850-01", freq="MS", periods=len(data.time))
    data["time"] = time

    ## get annual average
    data = data.groupby("time.year").mean()

    return data["tos"].rename(f"{model_center}_{model}")


def load_model(server_fp, save_fp, grid, model_center, model):
    """Load data for given model, re-gridded to specified grid.

    Pre-processed output is cached in '<save_fp>/<model_center>_<model>.nc':
    if that file exists it is opened and returned; otherwise the data are
    loaded from the server (see 'load_model_fromserver'), written to that
    file and returned.

    Args:
        - 'server_fp': path to the file server
        - 'save_fp': directory for the cached files (created if missing)
        - 'grid': target grid for regridding
        - 'model_center': name of the modeling center
        - 'model': name of the model
    """

    ## get save filepath
    fp = pathlib.Path(save_fp, f"{model_center}_{model}.nc")

    ## use the cached copy if there is one
    if fp.is_file():
        return xr.open_dataarray(fp)

    ## otherwise, load data from the server
    data = load_model_fromserver(
        server_fp,
        grid,
        model_center,
        model,
    )

    ## make sure the save directory exists before writing to it
    fp.parent.mkdir(parents=True, exist_ok=True)

    ## save to file
    data.to_netcdf(fp)

    return data


def sort_longitude(data):
    """Reorder data along 'nlon' so that longitude is increasing.

    The data are transposed so 'nlon' is the last dimension, then the
    'lon' and 'lat' coordinates and the data values are all reordered
    along that dimension using the argsort of 'lon'.
    """

    ## Put longitude last (all sorting happens along this dimension)
    data = data.transpose(..., "nlon")

    ## indices which sort longitude along the last axis
    order = np.argsort(data["lon"].values, axis=-1)

    ## reorder the lon/lat coordinates
    for coord in ("lon", "lat"):
        data[coord].values = np.take_along_axis(
            data[coord].values, indices=order, axis=-1
        )

    ## if the data has more than two dimensions, add leading axes to the
    ## indices so they line up with the data's extra dimensions
    ## (assumes 'lon' itself is 2-dimensional -- TODO confirm)
    if data.ndim > 2:
        leading_axes = list(range(data.ndim - 2))
        order = np.expand_dims(order, axis=leading_axes)

    ## reorder the data itself
    data.values = np.take_along_axis(data.values, indices=order, axis=-1)

    return data


def swap_longitude_range(data):
    """Relabel longitudes of xr.DataArray from [0,360) to (-180, 180].

    Works on the 'lon' coordinate as a plain array, so it also handles
    2-dimensional longitude coordinates. The coordinate is updated on
    'data' itself, which is then returned.
    """

    ## current longitude values
    lon = data.lon.values

    ## shift values greater than 180 down by 360; leave the rest as-is
    data["lon"].values = np.where(lon > 180, lon - 360, lon)

    return data


def load_models(server_fp, save_fp, grid, model_names):
    """Load output from all models into a single array.

    'model_names' is a sequence of (model center, model) pairs. Each
    model is loaded with 'load_model'; the results are merged, cut to
    the first 150 entries along 'year', and stacked along a new
    'model' dimension.
    """

    ## load each model in turn (with a progress bar)
    per_model = [
        load_model(
            model_center=model_center,
            model=model,
            server_fp=server_fp,
            save_fp=save_fp,
            grid=grid,
        )
        for model_center, model in tqdm.tqdm(model_names)
    ]

    ## merge into single dataset, keeping the first 150 years
    merged = xr.merge(per_model).isel(year=slice(None, 150))

    ## convert from dataset to dataarray
    return merged.to_dataarray(dim="model")

## Data loading

````{admonition} To-do
Update ```server_fp``` and ```save_fp``` in the code cell below (intermediate results will be saved to ```save_fp```).
````

In [None]:
## Path to file server
## (get_filepath appends "cmip6/data/cmip6/CMIP/..." to this path)
server_fp = pathlib.Path("/Volumes")

## Directory for pre-processed data
## (load_model reads/writes "<model_center>_<model>.nc" files here)
save_fp = pathlib.Path("./data")

### Do the dataloading

````{warning} 
1. **Loading the data from the CMIP server is slow** ($\sim 15$ minutes on a laptop). To avoid this step, download [the trimmed data from Google Drive](https://drive.google.com/drive/folders/1ghkCtBgynHOy18LZdvIsiVHcWiPKkuYq?usp=sharing) (then unzip the file and save it to the same directory as this notebook). If you choose to load the data from the server, the loading function (```load_model```) will save the pre-processed data to a file (and load it, if it already exists), so the cell below will run much faster the second time around.

2. **Loading the data from the CMIP server requires the ```xesmf``` package**. If not already installed, you can install it with:  
```mamba install -c conda-forge xesmf```. You don't need this package if using the pre-computed data from Google Drive.
````

In [None]:
## comment out this import if you are using data from Google Drive
## (and don't have xesmf installed)
## ('xe' is only used inside load_model_fromserver, for regridding)
import xesmf as xe

## list of (model center, model) tuples
model_names = [
    ("NCAR", "CESM2"),
    ("NCAR", "CESM2-WACCM"),
    ("CSIRO", "ACCESS-ESM1-5"),
    ("MIROC", "MIROC6"),
    ("MPI-M", "MPI-ESM1-2-LR"),
    ("IPSL", "IPSL-CM6A-LR"),
    ("NASA-GISS", "GISS-E2-2-G"),
]

## get target grid (CESM2's lon/lat grid, trimmed to the given box)
## NOTE(review): get_grid opens a CESM2 file under server_fp, so this step
## appears to need server access even when pre-computed data are used -- confirm
grid = get_grid(server_fp, lon_range=[260, 360], lat_range=[10, 70])

## load data (one entry per model, stacked along the 'model' dimension)
kwargs = dict(server_fp=server_fp, save_fp=save_fp, grid=grid, model_names=model_names)
data = load_models(**kwargs)

## swap longitude range from [0,360) to (-180, 180]
data = swap_longitude_range(data)

## make sure longitude is in ascending order
data = sort_longitude(data)

## Analysis

Compute climatology and change over time

In [None]:
## Get change (last 30 yrs minus first 30)
## 'first30' and 'last30' are per-model time means (they keep the 'model'
## dimension); 'change' is their difference, averaged across models
first30 = data.isel(year=slice(None, 30)).mean("year")
last30 = data.isel(year=slice(-30, None)).mean("year")
change = (last30 - first30).mean("model")

### Compare climatologies for two of the models

Compute climatologies and difference

In [None]:
## specify models to compare
## (labels have the form "<model_center>_<model>", as set when loading)
model1 = "IPSL_IPSL-CM6A-LR"
model2 = "NASA-GISS_GISS-E2-2-G"

## compute climatologies (mean of first 30 years) and their difference
## (the 'model' coordinate is dropped from each before subtracting)
clim1 = first30.sel(model=model1).drop_vars("model")
clim2 = first30.sel(model=model2).drop_vars("model")
diff = clim1 - clim2

In [None]:
def plot_clim(ax, clim):
    """Draw filled contours of a climatology on the given ax object,
    outline the two regions of interest, and return the contour set."""

    ## filled contours every 3 units, from -3 up to 27
    filled = ax.contourf(
        clim.lon,
        clim.lat,
        clim,
        levels=np.arange(-3, 30, 3),
        cmap="cmo.thermal",
        extend="both",
        transform=ccrs.PlateCarree(),
    )

    ## outlines: Woods Hole region (magenta), then warming hole region
    regions = (
        dict(lon_range=[287.5, 293.5], lat_range=[39, 44], c="magenta"),
        dict(lon_range=[323, 343], lat_range=[45, 55]),
    )
    for region in regions:
        plot_box_outline(ax, **region)

    return filled


def plot_diff(ax, diff):
    """Draw filled contours of a difference field on the given ax object,
    outline the two regions of interest, and return the contour set."""

    ## symmetric levels (no zero level) out to +/- 5, in steps of 0.5
    levels = make_cb_range(5, 0.5)

    filled = ax.contourf(
        diff.lon,
        diff.lat,
        diff,
        levels=levels,
        cmap="cmo.balance",
        transform=ccrs.PlateCarree(),
        extend="both",
    )

    ## outlines: Woods Hole region (magenta), then warming hole region
    regions = (
        dict(lon_range=[287.5, 293.5], lat_range=[39, 44], c="magenta"),
        dict(lon_range=[323, 343], lat_range=[45, 55]),
    )
    for region in regions:
        plot_box_outline(ax, **region)

    return filled

````{admonition} To-do
If you're not looking at Woods Hole, you may need to adapt the map-setup function used in the cell below (```plot_setup_atlantic```, defined in the Functions section above).
````

In [None]:
## setup plot (three map panels stacked vertically)
fig = plt.figure(figsize=(5, 8), layout="constrained")

## generate axes objects for plotting
## (rows 1 and 2: the two models; row 3: their difference)
ax1 = plot_setup_atlantic(fig, posn=(3, 1, 1))
ax2 = plot_setup_atlantic(fig, posn=(3, 1, 2))
ax3 = plot_setup_atlantic(fig, posn=(3, 1, 3))

## plot climatologies of models
for ax, clim, model in zip([ax1, ax2], [clim1, clim2], [model1, model2]):

    ## plot data, title the panel with the model label, and add a colorbar
    p = plot_clim(ax, clim)
    ax.set_title(model)
    cb = fig.colorbar(p, ticks=[-3, 27], label=r"$^{\circ}$C")


## plot difference (clim1 minus clim2) and label
diff_plot = plot_diff(ax3, diff)
ax3.set_title("Difference")
cb_diff = fig.colorbar(diff_plot, ticks=[-5, 5], label=r"$^{\circ}$C")


plt.show()

### Plot ensemble-mean warming pattern

In [None]:
## setup plot (single Atlantic map)
fig = plt.figure(figsize=(5, 2.75), layout="constrained")

ax = plot_setup_atlantic(fig)

## plot the difference (ensemble-mean change, last 30 yrs minus first 30)
plot_data = ax.pcolormesh(
    change.lon,
    change.lat,
    change,
    cmap="cmo.amp",
    vmax=4,
    vmin=1,
    transform=ccrs.PlateCarree(),
)

## plot the background state
## (white contours of the ensemble-mean climatology over the first 30 yrs)
ax.contour(
    first30.lon,
    first30.lat,
    first30.mean("model"),
    colors="w",
    levels=np.arange(-2, 34, 4),
    transform=ccrs.PlateCarree(),
    extend="both",
    linewidths=1,
    alpha=0.7,
)

## make a colorbar (for the change field)
cb = fig.colorbar(plot_data, ticks=[1, 4], label=r"$^{\circ}$C")

## plot outline of Woods Hole region
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

## plot outline of warming hole region
ax = plot_box_outline(ax, lon_range=[323, 343], lat_range=[45, 55])

## label
ax.set_title(r"$\Delta$(SST) b/n 1850-1880 and 1970-2000")

plt.show()

### Climate index

#### Compute index

````{admonition} To-do
If you'd like to look at a different index, modify the function below.
````

In [None]:
def compute_T_wh(x):
    """Compute Woods Hole temperature index: the average of 'x' over
    the 'nlon'/'nlat' dimensions, after trimming to the Woods Hole box
    (72.5W-66.5W, 39N-44N)."""

    ## cut out the Woods Hole region
    woods_hole = trim(x, lon_range=[-72.5, -66.5], lat_range=[39, 44])

    ## compute spatial average
    return woods_hole.mean(["nlon", "nlat"])

Do computation here

In [None]:
## do the computation here (spatial-mean index, computed from 'data')
idx = compute_T_wh(data)

## normalize by removing mean of first 30 years
## ('idx_base'/'idx_end' are the means over the first/last 30 years)
idx_base = idx.isel(year=slice(None, 30)).mean("year")
idx_end = idx.isel(year=slice(-30, None)).mean("year")
idx_norm = idx - idx_base

## get change in index (last 30 yrs minus first 30)
idx_change = idx_end - idx_base

#### Change over time

In [None]:
## setup plot (two side-by-side panels)
fig, axs = plt.subplots(1, 2, figsize=(7, 3), layout="constrained")

## first, plot the raw data; then, plot the normalized data
for i, idx_ in enumerate([idx, idx_norm]):

    ## loop through each model
    for model in idx_.model.values:

        ## label each line with the model name (model center prefix removed)
        label = model.split("_", 1)[1]
        axs[i].plot(idx_.year, idx_.sel(model=model), label=label, lw=1)

## plot ensemble mean of the normalized index on the right-hand panel
## (reference 'idx_norm' explicitly rather than the leftover loop variable)
axs[1].plot(idx_norm.year, idx_norm.mean("model"), label="mean", lw=2, c="k")

## label and format
axs[0].set_title("Raw index")
axs[1].set_title("Change from first 30 yrs")
axs[1].legend(prop={"size": 7})
axs[0].set_yticks([11, 16, 21])
axs[1].set_yticks([0, 3, 6])

## put the right-hand panel's y-axis ticks and label on its right side
axs[1].yaxis.tick_right()
axs[1].yaxis.set_label_position("right")
for ax in axs:
    ax.set_xlabel("Year")
    ax.set_ylabel(r"$^{\circ}$C")

plt.show()

#### Change vs. climatology

In [None]:
## specify markers and colors to use in plot
## (one marker/color per model; zip below stops at the shortest sequence,
## so models beyond the number of markers would not be plotted)
markers = ["*", "+", "s", "<", "v", "o", "P"]
colors = sns.color_palette()[: len(markers)]

## plot
fig, ax = plt.subplots(figsize=(4.5, 3), layout="constrained")

## loop thru models
for model, m, c in zip(idx_base.model, markers, colors):

    ## get label (model name with the model center prefix removed)
    label = model.values.item().split("_", 1)[1]

    ## plot datapoint
    ## (x: mean index over first 30 yrs; y: change in index)
    ax.scatter(
        idx_base.sel(model=model),
        idx_change.sel(model=model),
        marker=m,
        color=c,
        s=100,
        label=label,
    )

## label/format plot (legend is placed outside the axes, to the right)
ax.legend(prop=dict(size=8), loc=(1.2, 0.2))
ax.set_xticks([12, 15, 18])
ax.set_yticks([2, 4, 6])
ax.set_xlabel(r"SST$_{1850}$ ($^{\circ}$C)")
ax.set_ylabel(r"$\Delta$ SST ($^{\circ}$C)")

plt.show()