# Intermodel comparison example
{download}`Download notebook<./example.ipynb>`

## Packages

In [None]:
import xarray as xr
import cftime
import time
import os.path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cmocean
import seaborn as sns
import pathlib
import tqdm
import warnings
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.patches as mpatches
import copy
import xesmf as xe

## only required for loading from cloud
import gcsfs
import zarr

## (optional) apply seaborn styling with a white axes background and
## no gridlines on plots
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## Functions

### Plotting / misc.

In [None]:
def plot_setup(
    fig, projection, lon_range, lat_range, posn=(1, 1, 1), xticks=None, yticks=None
):
    """Add a subplot to the figure with the given map projection
    and lon/lat range. Returns an Axes object.

    Args:
        - 'fig': figure to which the subplot is added
        - 'projection': cartopy projection for the new axes. NOTE: its
            'threshold' attribute is reduced in place (see below), so the
            object passed in by the caller is modified.
        - 'lon_range': (min, max) longitude of region to show
        - 'lat_range': (min, max) latitude of region to show
        - 'posn': subplot position; unpacked into fig.add_subplot
        - 'xticks': (optional) longitudes at which to draw gridlines
        - 'yticks': (optional) latitudes at which to draw gridlines

    Labeled gridlines are drawn if 'xticks' or 'yticks' is given. If only
    one of the two is given, the other axis keeps the default locator.
    """

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(*posn, projection=projection)

    ## Subset to given region (range is specified in lon/lat coordinates)
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## add tick labels
    if xticks is not None or yticks is not None:

        ## add lon/lat labels
        gl = ax.gridlines(
            draw_labels=True,
            linestyle="-",
            alpha=0.1,
            linewidth=0.5,
            color="k",
            zorder=1.05,
        )

        ## specify which axes to label
        gl.top_labels = False
        gl.right_labels = False

        ## specify ticks (only for the axes where they were provided;
        ## a FixedLocator cannot be built from None)
        if yticks is not None:
            gl.ylocator = mticker.FixedLocator(yticks)
        if xticks is not None:
            gl.xlocator = mticker.FixedLocator(xticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """
    Draw an unfilled rectangle around the specified lon/lat range
    on the given ax object, with edge color 'c'. Returns the same
    ax object.
    """

    ## lower-left corner of the box
    lon_min, lat_min = lon_range[0], lat_range[0]

    ## build the rectangle (corner and size are given in lon/lat)
    box = mpatches.Rectangle(
        xy=[lon_min, lat_min],
        width=lon_range[1] - lon_min,
        height=lat_range[1] - lat_min,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )

    ## add rectangle to plot
    ax.add_patch(box)

    return ax


def plot_setup_atlantic(fig, posn=(1, 1, 1)):
    """Add a subplot showing the Atlantic region to the figure,
    at subplot position 'posn'. Returns the Axes object."""

    ## orthographic map projection, centered on 50W, 40N
    proj = ccrs.Orthographic(central_longitude=-50, central_latitude=40)

    ## region to show, and where to draw labeled gridlines
    region = dict(lon_range=[-90, -10], lat_range=[20, 60])
    ticks = dict(xticks=[-80, -50, -20], yticks=[25, 45])

    ## get ax object
    return plot_setup(fig, proj, posn=posn, **region, **ticks)


def make_cb_range(amp, delta):
    """Make symmetric colorbar levels for cmo.balance, skipping zero.
    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar
    Returns:
        1-D array of levels spaced by 'delta' on either side of zero
        (-amp, ..., -delta, delta, ..., amp when 'amp' is a multiple
        of 'delta'); zero itself is not included.
    """

    ## levels below zero (zero is the excluded endpoint)
    negative = np.arange(-amp, 0, delta)

    ## levels above zero (stop is amp + delta so that amp is included)
    positive = np.arange(delta, amp + delta, delta)

    return np.concatenate((negative, positive))

### Server loading functions

In [None]:
def load_grid(lon_range, lat_range):
    """Create target grid / mask from OISST data on the NOAA PSL server.

    Args:
        - 'lon_range': (min, max) longitude, used as a slice on 'lon'
        - 'lat_range': (min, max) latitude, used as a slice on 'lat'

    Returns:
        DataArray on the OISST grid, trimmed to the given region, equal
        to 1 where OISST has data and NaN elsewhere. A boolean "mask"
        coordinate marks the non-NaN cells.
    """

    ## open sst data (time values aren't needed, so don't decode them)
    sst = xr.open_dataset(
        r"http://psl.noaa.gov/thredds/dodsC/Datasets/noaa.oisst.v2/new/sst.oisst.mon.ltm.1991-2020.nc",
        decode_times=False,
    )
    sst = sst["sst"].isel(time=0).drop_vars("time")

    ## sel lon/lat range. Do this before any computation, so that only
    ## the requested region has to be read from the remote server
    sst = sst.sel(lon=slice(*lon_range), lat=slice(*lat_range))

    ## convert to lsm (fill ones over ocean; NaN values are kept as NaN)
    lsm = sst.where(np.isnan(sst), other=1.0)

    ## add binary mask for regridding (True where the grid has data);
    ## presumably this is picked up by the regridder -- TODO confirm
    lsm["mask"] = ~np.isnan(lsm)

    return lsm


def get_filepath(cmip_fp, model_center, model):
    """Build path to the folder holding output for the given model:
    <cmip_fp>/<model_center>/<model>/<shared suffix>"""

    ## fixed sub-path below the model folder (shared by all models)
    suffix = pathlib.Path("1pctCO2/r1i1p1f1/Omon/tos/gn/1")

    return pathlib.Path(cmip_fp) / model_center / model / suffix


def load_model_fromserver(cmip_fp, grid, model_center, model, **load_kwargs):
    """Load data for given model from the CMIP server, regrid it to
    'grid', and return the annual-mean "tos" field.

    Args:
        - 'cmip_fp': root folder of the CMIP archive
        - 'grid': target grid for the regridding
        - 'model_center': name of modeling center (folder name)
        - 'model': name of model (folder name; also used to name the result)
        - 'load_kwargs': optional overrides for the default arguments
            passed to xr.open_mfdataset

    Returns:
        DataArray named after 'model' with a 'year' dimension.
    """

    ## get filepath
    fp = get_filepath(cmip_fp, model_center=model_center, model=model)

    ## specify args for loading data
    ## (defaults first, so that caller-supplied kwargs take precedence)
    load_kwargs = {
        **dict(
            decode_times=False,
            mask_and_scale=True,
            chunks=dict(time=3000),
            coords="minimal",
            data_vars="minimal",
            concat_dim="time",
            combine="nested",
            compat="override",
        ),
        **load_kwargs,
    }

    ## Sort the files: with combine="nested" they are concatenated in the
    ## order given, and glob does not guarantee any particular order
    files = sorted(fp.glob("*nc"))

    ## load data (ignore serialization warning)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)
        data = xr.open_mfdataset(files, **load_kwargs).compute()

    ## regrid
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        regridder = xe.Regridder(data, grid, "bilinear", ignore_degenerate=True)
        data = regridder(data)

    ## reset time axis: monthly timesteps, starting in Jan 1850
    dates = pd.date_range(start="1850-01", freq="MS", periods=len(data.time))
    data["time"] = dates

    ## get annual average
    data = data.groupby("time.year").mean()

    return data["tos"].rename(f"{model}")


def load_model(cmip_fp, save_fp, grid, model_center, model):
    """Load data for given model, re-gridded to specified grid.

    The pre-processed result is cached in the file
    '<save_fp>/<model_center>_<model>.nc'. If that file exists it is
    opened directly; otherwise the data is loaded from the server,
    saved to that file, and returned.
    """

    ## get save filepath
    fp = pathlib.Path(save_fp, f"{model_center}_{model}.nc")

    ## use the saved copy if there is one
    if fp.is_file():
        return xr.open_dataarray(fp)

    ## otherwise, load (and pre-process) data from server
    data = load_model_fromserver(
        cmip_fp=cmip_fp,
        grid=grid,
        model_center=model_center,
        model=model,
    )

    ## make sure the save folder exists (otherwise the write fails,
    ## and the slow pre-processing step above is wasted)
    fp.parent.mkdir(parents=True, exist_ok=True)

    ## save to file
    data.to_netcdf(fp)

    return data


def load_models(cmip_fp, save_fp, grid, model_names):
    """Load output from all models in 'model_names', a list of
    (model center, model) tuples. Returns a single DataArray with
    a 'model' dimension, trimmed to the first 150 years."""

    ## arguments shared by all models
    shared = dict(cmip_fp=cmip_fp, save_fp=save_fp, grid=grid)

    ## load each model (with progress bar)
    per_model = [
        load_model(model_center=model_center, model=model, **shared)
        for model_center, model in tqdm.tqdm(model_names)
    ]

    ## merge into single dataset (one variable per model) and trim
    merged = xr.merge(per_model).isel(year=slice(None, 150))

    ## convert from dataset to dataarray: variables become 'model' dim
    return merged.to_dataarray(dim="model")

### Cloud loading functions

In [None]:
def load_models_from_cloud(save_fp, grid):
    """Load models from Google Cloud, regridded to 'grid'.

    Each model's pre-processed data is cached in '<save_fp>/<model>.nc';
    a model is only loaded from the cloud if its file doesn't exist yet.
    Returns a DataArray with a 'model' dimension, trimmed to years up to
    and including 1999.
    """

    ## Get path to models. For example of how to get this list (and
    ## search for other models), see example at this link:
    # https://nbviewer.org/github/pangeo-data/pangeo-cmip6-examples/blob/master/basic_search_and_load.ipynb
    model_dict = {
        "IPSL-CM6A-LR": "gs://cmip6/CMIP6/CMIP/IPSL/IPSL-CM6A-LR/1pctCO2/r1i1p1f1/Omon/tos/gn/v20180727/",
        "MIROC6": "gs://cmip6/CMIP6/CMIP/MIROC/MIROC6/1pctCO2/r1i1p1f1/Omon/tos/gn/v20181212/",
        "CESM2": "gs://cmip6/CMIP6/CMIP/NCAR/CESM2/1pctCO2/r1i1p1f1/Omon/tos/gn/v20190425/",
        "CESM2-WACCM": "gs://cmip6/CMIP6/CMIP/NCAR/CESM2-WACCM/1pctCO2/r1i1p1f1/Omon/tos/gn/v20190425/",
        "MPI-ESM1-2-LR": "gs://cmip6/CMIP6/CMIP/MPI-M/MPI-ESM1-2-LR/1pctCO2/r1i1p1f1/Omon/tos/gn/v20190710/",
        "ACCESS-ESM1-5": "gs://cmip6/CMIP6/CMIP/CSIRO/ACCESS-ESM1-5/1pctCO2/r1i1p1f1/Omon/tos/gn/v20191115/",
        "GISS-E2-2-G": "gs://cmip6/CMIP6/CMIP/NASA-GISS/GISS-E2-2-G/1pctCO2/r1i1p1f1/Omon/tos/gn/v20191120/",
    }

    ## empty list to hold models
    data = []

    ## loop through dictionary items
    for model_name, load_fp in tqdm.tqdm(model_dict.items()):

        ## save filename for model
        save_filename = pathlib.Path(save_fp, f"{model_name}.nc")

        ## check if file exists
        if save_filename.is_file():

            data.append(xr.open_dataarray(save_filename))

        else:

            ## load data for model
            data_ = load_model_from_cloud(
                grid=grid, model_name=model_name, load_fp=load_fp
            )

            ## save it to file (creating the save folder if needed)
            save_filename.parent.mkdir(parents=True, exist_ok=True)
            data_.to_netcdf(save_filename)

            ## append the data we just loaded to the list
            ## (don't load it from the cloud a second time)
            data.append(data_)

    ## convert list into dataarray, labeling new dimension by model name
    data = xr.concat(data, dim=pd.Index(list(model_dict), name="model"))

    ## trim to period 1850-1999
    data = data.sel(year=slice(None, 1999))

    return data


def load_model_from_cloud(grid, model_name, load_fp):
    """Load "tos" for a single model from the cloud store at 'load_fp',
    regrid it to 'grid', and return the annual mean (computed in memory).
    Note that 'model_name' is not used inside this function."""

    ## anonymous access to Google Cloud Storage
    fs = gcsfs.GCSFileSystem(token="anon")

    ## open the store using xarray and zarr, via a
    ## mutable-mapping-style interface to it
    open_kwargs = dict(consolidated=True, chunks=dict(time=2400), decode_times=False)
    tos = xr.open_zarr(fs.get_mapper(load_fp), **open_kwargs)["tos"]

    ## regrid
    regrid = xe.Regridder(tos, grid, "bilinear", ignore_degenerate=True)
    tos = regrid(tos)

    ## reset time axis: monthly timesteps, starting in Jan 1850
    tos["time"] = pd.date_range(start="1850-01", freq="MS", periods=len(tos.time))

    ## get annual average
    annual = tos.groupby("time.year").mean()

    return annual.compute()

## Data loading

````{admonition} To-do: set pre-processing specs
Update the constants below (see code cell for description).
````

In [None]:
## Should we load from cloud? (if False, data is loaded from the
## CMIP server at CMIP_FP instead)
LOAD_FROM_CLOUD = False

## Lon/lat range for trimming data (used to subset the target grid
## which all models are regridded to)
LON_RANGE = [260, 360]
LAT_RANGE = [10, 70]

## where to save the pre-processed data (pre-processed files found in
## this folder are re-used instead of being re-computed)
SAVE_FP = pathlib.Path("./data/server")

## File paths (uncomment second line for windows)
## Only used when LOAD_FROM_CLOUD is False
CMIP_FP = pathlib.Path("/Volumes/cmip6/data/cmip6/CMIP")
# CMIP_FP = pathlib.Path("Z:/data/cmip6/CMIP")

````{warning}
**Preprocessing the data in the code cell below is slow** ($\sim 5-10$ minutes on a laptop). To skip this step, download [the preprocessed data from Google Drive](https://drive.google.com/drive/folders/1ghkCtBgynHOy18LZdvIsiVHcWiPKkuYq?usp=sharing). After downloading, save the files in the folder given by `SAVE_FP` (set `SAVE_FP` in the code cell above).
````

In [None]:
## load target grid (we'll regrid to this)
grid = load_grid(lon_range=LON_RANGE, lat_range=LAT_RANGE)

## list of (model center, model) tuples
## (only used when loading from the server; the cloud loader has
## its own list of models)
model_names = [
    ("NCAR", "CESM2"),
    ("NCAR", "CESM2-WACCM"),
    ("CSIRO", "ACCESS-ESM1-5"),
    ("MIROC", "MIROC6"),
    ("MPI-M", "MPI-ESM1-2-LR"),
    ("IPSL", "IPSL-CM6A-LR"),
    ("NASA-GISS", "GISS-E2-2-G"),
]

## specify shared keyword arguments (used for both cloud/server loading)
shared_kwargs = dict(grid=grid, save_fp=SAVE_FP)

## load the data: result is a DataArray with a 'model' dimension
## and one value per year
if LOAD_FROM_CLOUD:
    data = load_models_from_cloud(**shared_kwargs)

else:
    data = load_models(cmip_fp=CMIP_FP, model_names=model_names, **shared_kwargs)

## Analysis

Compute climatology and change over time

In [None]:
## Get change (last 30 yrs minus first 30)
## 'first30' and 'last30' are per-model time averages;
## 'change' is the difference between them, averaged over models
first30 = data.isel(year=slice(None, 30)).mean("year")
last30 = data.isel(year=slice(-30, None)).mean("year")
change = (last30 - first30).mean("model")

### Compare climatologies for two of the models

Compute climatologies and difference

In [None]:
## specify models to compare (names must match values on the
## 'model' dimension of the data)
model1 = "IPSL-CM6A-LR"
model2 = "GISS-E2-2-G"

## compute climatologies and difference
## (climatology here is the average over the first 30 years; the 'model'
## coordinate is dropped from each before taking the difference)
clim1 = first30.sel(model=model1).drop_vars("model")
clim2 = first30.sel(model=model2).drop_vars("model")
diff = clim1 - clim2

#### Functions to help with plotting

In [None]:
def plot_clim(ax, clim):
    """Draw filled contours of the climatology on given ax object and
    outline the two regions of interest. Returns the contour plot
    (e.g., for making a colorbar)."""

    ## contour levels: every 3 degrees, from -3 to 27
    levels = np.arange(-3, 30, 3)

    ## plot data (lon/lat coordinates, so use PlateCarree transform)
    filled = ax.contourf(
        clim.lon,
        clim.lat,
        clim,
        cmap="cmo.thermal",
        levels=levels,
        transform=ccrs.PlateCarree(),
        extend="both",
    )

    ## plot outlines of Woods Hole region (magenta) and
    ## warming hole region (default color)
    plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44], c="magenta")
    plot_box_outline(ax, lon_range=[323, 343], lat_range=[45, 55])

    return filled


def plot_diff(ax, diff):
    """Draw filled contours of the difference on given ax object and
    outline the two regions of interest. Returns the contour plot
    (e.g., for making a colorbar)."""

    ## symmetric contour levels (between -5 and 5, in steps of 0.5)
    levels = make_cb_range(5, 0.5)

    ## plot data (lon/lat coordinates, so use PlateCarree transform)
    filled = ax.contourf(
        diff.lon,
        diff.lat,
        diff,
        cmap="cmo.balance",
        levels=levels,
        extend="both",
        transform=ccrs.PlateCarree(),
    )

    ## plot outlines of Woods Hole region (magenta) and
    ## warming hole region (default color)
    plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44], c="magenta")
    plot_box_outline(ax, lon_range=[323, 343], lat_range=[45, 55])

    return filled

#### Make the plot

In [None]:
## setup plot
fig = plt.figure(figsize=(5, 8), layout="constrained")

## generate axes objects for plotting
## (3 rows, 1 column: model 1, model 2, and their difference)
ax1 = plot_setup_atlantic(fig, posn=(3, 1, 1))
ax2 = plot_setup_atlantic(fig, posn=(3, 1, 2))
ax3 = plot_setup_atlantic(fig, posn=(3, 1, 3))

## plot climatologies of models
for ax, clim, model in zip([ax1, ax2], [clim1, clim2], [model1, model2]):

    ## plot data and label
    ## (colorbar ticks are at the lowest/highest contour levels)
    p = plot_clim(ax, clim)
    ax.set_title(model)
    cb = fig.colorbar(p, ticks=[-3, 27], label=r"$^{\circ}$C")


## plot difference and label
diff_plot = plot_diff(ax3, diff)
ax3.set_title("Difference")
cb_diff = fig.colorbar(diff_plot, ticks=[-5, 5], label=r"$^{\circ}$C")


plt.show()

### Plot ensemble-mean warming pattern

In [None]:
## setup plot (single map of the Atlantic region)
fig = plt.figure(figsize=(5, 2.75), layout="constrained")

ax = plot_setup_atlantic(fig)

## plot the difference
## ('change' is the ensemble-mean difference between the last and
## first 30 years; color scale is limited to the range 1 to 4)
plot_data = ax.pcolormesh(
    change.lon,
    change.lat,
    change,
    cmap="cmo.amp",
    vmax=4,
    vmin=1,
    transform=ccrs.PlateCarree(),
)

## plot the background state
## (white contours of the ensemble-mean over the first 30 years)
ax.contour(
    first30.lon,
    first30.lat,
    first30.mean("model"),
    colors="w",
    levels=np.arange(-2, 34, 4),
    transform=ccrs.PlateCarree(),
    extend="both",
    linewidths=1,
    alpha=0.7,
)

## make a colorbar
cb = fig.colorbar(plot_data, ticks=[1, 4], label=r"$^{\circ}$C")

## plot outline of Woods Hole region
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

## plot outline of warming hole region
ax = plot_box_outline(ax, lon_range=[323, 343], lat_range=[45, 55])

## label
## NOTE(review): the years in the title assume the data spans 1850-1999;
## update the title if the period changes
ax.set_title(r"$\Delta$(SST) b/n 1850-1880 and 1970-2000")

plt.show()

### Climate index

#### Compute index

````{admonition} To-do
If you'd like to look at a different index, modify the function below.
````

In [None]:
def compute_T_wh(x):
    """Compute Woods Hole temperature index: the average of 'x'
    over a lon/lat box around Woods Hole."""

    ## lon/lat bounds of the Woods Hole region
    box = dict(lon=slice(287.5, 293.5), lat=slice(39, 44))

    ## compute spatial average (unweighted mean over gridcells in box)
    return x.sel(**box).mean(["lon", "lat"])

Do computation here

In [None]:
## do the computation here (one timeseries per model)
idx = compute_T_wh(data)

## normalize by removing mean of first 30 years
## ('idx_base' and 'idx_end' are the averages over the first and
## last 30 years, respectively)
idx_base = idx.isel(year=slice(None, 30)).mean("year")
idx_end = idx.isel(year=slice(-30, None)).mean("year")
idx_norm = idx - idx_base

## get change in index
idx_change = idx_end - idx_base

#### Change over time

In [None]:
## setup plot: raw index on the left, normalized index on the right
fig, axs = plt.subplots(1, 2, figsize=(7, 3), layout="constrained")

## first, plot the raw data; then, plot the normalized data
for i, idx_ in enumerate([idx, idx_norm]):

    ## loop through each model
    for model in idx_.model.values:

        axs[i].plot(idx_.year, idx_.sel(model=model), label=model, lw=1)

## plot ensemble mean of the normalized index (refer to 'idx_norm'
## explicitly, rather than to the variable left over from the loop)
axs[1].plot(idx_norm.year, idx_norm.mean("model"), label="mean", lw=2, c="k")

## label and format
axs[0].set_title("Raw index")
axs[1].set_title("Change from first 30 yrs")
axs[1].legend(prop={"size": 7})
axs[0].set_yticks([11, 16, 21])
axs[1].set_yticks([0, 3, 6])

## put y-axis ticks/label of right panel on the right-hand side
axs[1].yaxis.tick_right()
axs[1].yaxis.set_label_position("right")
for ax in axs:
    ax.set_xlabel("Year")
    ax.set_ylabel(r"$^{\circ}$C")

plt.show()

#### Change vs. climatology

In [None]:
## specify markers and colors to use in plot (one of each per model)
markers = ["*", "+", "s", "<", "v", "o", "P"]
colors = sns.color_palette()[: len(markers)]

## plot
fig, ax = plt.subplots(figsize=(4.5, 3), layout="constrained")

## loop thru models
## (zip stops at the shortest input, so at most len(markers)
## models are plotted)
for model, m, c in zip(idx_base.model.values, markers, colors):

    ## plot datapoint
    ## (x: average over first 30 years; y: change in the index)
    ax.scatter(
        idx_base.sel(model=model),
        idx_change.sel(model=model),
        marker=m,
        color=c,
        s=100,
        label=model,
    )

## label/format plot (legend location is given in axes coordinates,
## to the right of the axes)
ax.legend(prop=dict(size=8), loc=(1.2, 0.2))
ax.set_xticks([12, 15, 18])
ax.set_yticks([0, 3, 6])
ax.set_ylim([-0.5, None])
ax.set_xlabel(r"SST$_{1850}$ ($^{\circ}$C)")
ax.set_ylabel(r"$\Delta$ SST ($^{\circ}$C)")

## reference line at zero change
ax.axhline(0, c="k", lw=1)

plt.show()