# Intermodel comparison example

## Packages

In [None]:
import xarray as xr
import cftime
import os.path
import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt
import seaborn as sns
import pathlib
import tqdm
import warnings
import xesmf as xe

## (optional) remove gridlines from plots
## (overrides two rc settings: white axes background, axes grid switched off)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## Functions

In [None]:
def trim(data, lon_range, lat_range):
    """Subset `data` to the rows/columns that overlap a lon/lat box.

    Args:
        data: object with "lon" and "lat" variables and "nlon"/"nlat"
            dimensions (presumably an xarray Dataset on a curvilinear
            grid -- confirm against callers).
        lon_range: (min, max) longitude bounds; both ends are inclusive.
        lat_range: (min, max) latitude bounds; both ends are inclusive.

    Returns:
        `data` indexed so that only the "nlon" and "nlat" positions holding
        at least one cell inside the box are kept.
    """

    ## boolean mask: True where a cell lies inside both ranges
    lon, lat = data["lon"], data["lat"]
    inside_box = ((lon_range[0] <= lon) & (lon <= lon_range[1])) & (
        (lat_range[0] <= lat) & (lat <= lat_range[1])
    )

    ## load the mask into memory (in place) before reducing it
    inside_box.load()

    ## keep every position with at least one valid grid cell
    keep = {"nlon": inside_box.any("nlat"), "nlat": inside_box.any("nlon")}

    return data.isel(keep)

def plot_setup(fig, projection, lon_range, lat_range, xticks=None, yticks=None):
    """Add a subplot to the figure with the given map projection
    and lon/lat range. Returns an Axes object.

    Args:
        fig: figure to which the subplot is added.
        projection: map projection used for the subplot. Its `threshold`
            attribute is divided by 1000 IN PLACE, so re-using the same
            projection object across calls keeps shrinking it.
        lon_range: (min, max) longitude of the plotted extent.
        lat_range: (min, max) latitude of the plotted extent.
        xticks: longitudes at which to place labelled gridlines (optional).
        yticks: latitudes at which to place labelled gridlines (optional).

    Returns:
        The newly created Axes object.
    """

    ## NOTE(review): `ccrs` and `mticker` are not imported in the visible part
    ## of this file (presumably cartopy.crs and matplotlib.ticker) -- confirm
    ## they are in scope before this function is called.

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(projection=projection)

    ## Subset to given region
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## add tick labels if ticks were requested along either axis
    if xticks is not None or yticks is not None:

        ## add lon/lat labels
        gl = ax.gridlines(
            draw_labels=True,
            linestyle="-",
            alpha=0.1,
            linewidth=0.5,
            color="k",
            zorder=1.05,
        )

        ## specify which axes to label
        gl.top_labels = False
        gl.right_labels = False

        ## specify ticks (only for the axes whose ticks were supplied, so
        ## that a missing list never reaches FixedLocator as None)
        if yticks is not None:
            gl.ylocator = mticker.FixedLocator(yticks)
        if xticks is not None:
            gl.xlocator = mticker.FixedLocator(xticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """
    Draw an unfilled rectangle outlining the specified lon/lat range
    on the given ax object (edge color `c`), then return the ax object.
    """

    ## NOTE(review): `mpatches` and `ccrs` are not imported in the visible part
    ## of this file -- confirm they are in scope before this is called.

    ## lower-left corner of the box
    west, south = lon_range[0], lat_range[0]

    ## build the rectangle (coordinates given in PlateCarree)
    box = mpatches.Rectangle(
        xy=[west, south],
        width=lon_range[1] - west,
        height=lat_range[1] - south,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )

    ## add rectangle to plot
    ax.add_patch(box)

    return ax


def plot_setup_woodshole(fig):
    """Set up a zoomed-in map of the Woods Hole region.

    Resizes `fig` to 5 x 3 inches and adds an orthographic map subplot
    (central longitude -67.5, central latitude 40) covering longitudes
    -80 to -60 and latitudes 35 to 45. Returns the (fig, ax) pair.
    """

    ## adjust figure size
    fig.set_size_inches(5, 3)

    ## delegate to the generic plotting function, using an orthographic projection
    ax = plot_setup(
        fig,
        ccrs.Orthographic(central_longitude=-67.5, central_latitude=40),
        lon_range=[-80, -60],
        lat_range=[35, 45],
        xticks=[-80, -70, -60],
        yticks=[35, 40, 45],
    )

    return fig, ax

## Data loading

### Filepaths

````{admonition} To-do
Update the ```SERVER_FP``` and ```SAVE_FP``` in the code cell below (intermediate results will be saved to ```SAVE_FP```).
````

In [None]:
## Root of the file server, and folder for intermediate results
SERVER_FP = pathlib.Path("/Volumes")
SAVE_FP = pathlib.Path("./data")

## (model center, model) pairs to compare
model_names = [
    ("NCAR", "CESM2"),
    ("CSIRO", "ACCESS-ESM1-5"),
    ("MIROC", "MIROC6"),
    ("MPI-M", "MPI-ESM1-2-LR"),
    ("IPSL", "IPSL-CM6A-LR"),
    ("NASA-GISS", "GISS-E2-2-G"),
]

## keyword arguments used when opening model output
## (times are decoded with cftime)
time_decoder = xr.coding.times.CFDatetimeCoder(use_cftime=True)
load_kwargs = {
    "decode_times": time_decoder,
    "mask_and_scale": True,
    "chunks": {"time": 3000},
    "coords": "minimal",
    "data_vars": "minimal",
    "concat_dim": "time",
    "combine": "nested",
    "compat": "override",
}

#### Get path to data for each model

In [None]:
def get_filepath(model_center, model):
    """Return the path to the output folder of the given model.

    The path is SERVER_FP, followed by the CMIP6 archive root, the model
    center, the model name, and a suffix shared by all models.
    """

    ## root of the cmip6 archive on the server
    cmip_root = SERVER_FP / "cmip6/data/cmip6/CMIP"

    ## append model-specific folders, then the suffix shared by all models
    return cmip_root / model_center / model / "1pctCO2/r1i1p1f1/Omon/tos/gn/1"

def get_grid_mask(data, x_coord="lon", y_coord="lat"):
    """Return a 1/0 mask marking where "tos" is not NaN (useful for xesmf regridder).

    If `data` has a "time" dimension, only its first time step is inspected
    and the "time" variable is dropped. `x_coord` and `y_coord` are accepted
    but not used by the current implementation.
    """

    ## a single snapshot is enough to locate the NaN cells
    snapshot = data.isel(time=0).drop_vars("time") if "time" in data.dims else data

    ## 1 where "tos" holds a number, 0 where it is NaN
    is_valid = ~np.isnan(snapshot["tos"])

    return xr.where(is_valid, 1, 0).compute()
    

def get_grid(lon_range=(280, 300), lat_range=(35, 45), **load_kwargs):
    """Build the target grid from the lon/lat grid of CESM2.

    Args:
        lon_range: (min, max) longitude bounds used to trim the grid.
        lat_range: (min, max) latitude bounds used to trim the grid.
        **load_kwargs: forwarded to `xr.open_mfdataset`. NOTE: this parameter
            shadows the module-level `load_kwargs`; those settings are only
            applied here if the caller passes them explicitly.

    Returns:
        In-memory object holding the trimmed "lon" and "lat" variables plus
        a "mask" variable (1 where "tos" is valid, 0 where it is NaN).
    """

    ## get filepath
    fp = get_filepath(model_center="NCAR", model="CESM2")

    ## Path.glob yields files in arbitrary order; sort so that the file
    ## order (and hence which file supplies the first timestep when files
    ## are combined in the order given) is deterministic
    files = sorted(fp.glob("*nc"))

    ## load first timestep for cesm2 (ignore serialization warning)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)
        data = xr.open_mfdataset(files, **load_kwargs).isel(time=0)

    ## drop time as a coord
    data = data.drop_vars("time")

    ## trim in lon/lat space
    data = trim(data, lon_range, lat_range)

    ## get grid
    grid = data[["lon", "lat"]]

    ## mask NaN values
    grid["mask"] = get_grid_mask(data)

    return grid.compute()

def load_model_fromfile(grid, model_center, model):
    """Load "tos" for the given model from the server and regrid it.

    Args:
        grid: target grid passed to the xesmf regridder.
        model_center: name of the modelling center (e.g. "NCAR").
        model: name of the model (e.g. "CESM2").

    Returns:
        The regridded "tos" variable, renamed to "{model_center}_{model}".
    """

    ## get filepath
    fp = get_filepath(model_center=model_center, model=model)

    ## Path.glob yields files in arbitrary order, but the module-level
    ## load_kwargs use combine="nested", which concatenates along time in
    ## the order the files are given -- so sort them first
    files = sorted(fp.glob("*nc"))

    ## load data (ignore serialization warning)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)
        data = xr.open_mfdataset(files, **load_kwargs).compute()

    ## get mask of NaN values
    data["mask"] = get_grid_mask(data)

    ## regrid (bilinear interpolation onto the target grid)
    regridder = xe.Regridder(data, grid, "bilinear")
    data = regridder(data)

    return data["tos"].rename(f"{model_center}_{model}")

def load_model(grid, model_center, model):
    """Load data for given model, re-gridded to specified grid.

    The regridded result is cached as "{model_center}_{model}.nc" inside
    SAVE_FP: if that file exists it is opened directly, otherwise the data
    are loaded from the server, regridded and written to that file.

    Args:
        grid: target grid, forwarded to `load_model_fromfile`.
        model_center: name of the modelling center.
        model: name of the model.

    Returns:
        The data array for the given model.
    """

    ## get save filepath
    fp = pathlib.Path(SAVE_FP, f"{model_center}_{model}.nc")

    ## re-use the cached result if available
    if fp.is_file():
        return xr.open_dataarray(fp)

    ## load data from file
    data = load_model_fromfile(grid, model_center, model)

    ## make sure the save folder exists (writing would fail otherwise)
    fp.parent.mkdir(parents=True, exist_ok=True)

    ## save to file
    data.to_netcdf(fp)

    return data

In [None]:
## get target grid
## NOTE(review): called without keyword arguments, so the module-level
## `load_kwargs` are not forwarded to get_grid -- confirm this is intended
grid = get_grid()

## empty list to hold models
data = []

## loop through models, loading each one on the target grid
## (each entry of model_names is unpacked into model_center, model)
for model_name in tqdm.tqdm(model_names):
    data.append(load_model(grid, *model_name))

### Function to compute climate index

````{admonition} To-do
If you'd like to look at a different index, modify the function below.
````

In [None]:
def compute_T_wh(x):
    """Return the Woods Hole temperature index computed from `x`.

    Delegates to `trim_and_avg` with the longitude range [287.5, 293.5]
    and the latitude range [39, 44].
    """

    ## NOTE(review): `trim_and_avg` is not defined in the visible part of
    ## this file -- confirm it is in scope before this is called.

    ## box over which the spatial average is computed
    woods_hole_box = dict(lon_range=[287.5, 293.5], lat_range=[39, 44])

    return trim_and_avg(x, **woods_hole_box)

### Loop through models

In [None]:
## empty list to hold result
T2m_idx = []

## loop through each (model, filepath) pair
## NOTE(review): `fp_dict` is not defined in the visible part of this file --
## confirm it is in scope before running this cell.
for model, fp in tqdm.tqdm(fp_dict.items()):

    ## get file pattern
    file_pattern = os.path.join(fp, "*.nc")

    ## open dataset (the index is computed per file via `preprocess`),
    ## keep the "tas" variable and append to list.
    ## combine="nested" is needed for `concat_dim` to be accepted: with the
    ## default combine="by_coords", xarray raises a ValueError.
    T2m_idx.append(
        xr.open_mfdataset(
            file_pattern,
            decode_times=False,
            mask_and_scale=False,
            preprocess=compute_T_wh,
            chunks=dict(time=1800),
            coords="minimal",
            data_vars="minimal",
            concat_dim="time",
            combine="nested",
        )["tas"]
    )

In [None]:
## stack the per-model results along a new "model" dimension
## NOTE(review): `models` is not defined in the visible part of this file
T2m_idx_ = xr.concat(T2m_idx, dim=pd.Index(models, name="model"), coords="minimal")

In [None]:
## show the documentation of xr.open_mfdataset (interactive exploration)
help(xr.open_mfdataset)

In [None]:
## inspect the fifth entry of the list of per-model results
T2m_idx[4]

In [None]:
## inspect the concatenated array
T2m_idx_

### Combine into a single data array

In [None]:
def reset_year(T2m):
    """Relabel "year" as consecutive integers starting at 1850.

    The number of labels equals the current length of `T2m.year`.
    `T2m` is modified in place and is also returned.
    """

    ## number of entries along "year"
    n_years = len(T2m.year)

    ## overwrite the labels with 1850, 1851, ...
    T2m["year"] = 1850 + np.arange(n_years)

    return T2m


## reset the years of each model, then stack the models along a new
## "model" dimension
T2m_idx = xr.concat(
    list(map(reset_year, T2m_idx)),
    dim=pd.Index(models, name="model"),
    coords="minimal",
)

## Drop unnecessary coordinates
T2m_idx = T2m_idx.drop_vars(["height", "lon", "lat"])

## Load into memory, and print the elapsed time (in seconds)
start = time.time()
T2m_idx.load()
end = time.time()
print(end - start)

### Normalized version of data

In [None]:
## baseline: mean over the first 30 entries along "year"
T2m_idx_baseline = T2m_idx.isel(year=slice(30)).mean("year")

## normalize by removing the baseline
T2m_idx_norm = T2m_idx - T2m_idx_baseline

## Plot results

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

## first panel: the raw data; second panel: the normalized data
## (the loop variable is deliberately not called `data`, so the list of
## loaded models stored under that name is not overwritten)
for ax, series in zip(axs, [T2m_idx, T2m_idx_norm]):

    ## loop through each model
    for model in series.model:

        ax.plot(series.year, series.sel(model=model), label=model.item(), lw=1)

## plot ensemble mean of the normalized data (named explicitly rather than
## relying on the variable left over from the loop above)
axs[1].plot(
    T2m_idx_norm.year, T2m_idx_norm.mean("model"), label="mean", lw=2, c="k"
)
axs[1].set_xlim([None, 1999])

## add legend
axs[1].legend(prop={"size": 8})

plt.show()