# Example
In this example, we'll use CESM1-LE to diagnose the forced response to external forcing near Woods Hole.  
{download}`Download notebook<./woods-hole_example.ipynb>`

## Imports

In [None]:
import pathlib
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import datetime
import seaborn as sns
import cmocean
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
import copy
import pandas as pd
import time
import intake
import xesmf as xe

## (optional) plot style: white axes background and no gridlines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## should we save figs? (if True, plotting cells below call fig.savefig)
SAVE_FIGS = False

## Misc. functions

In [None]:
def plot_setup(fig, projection, lon_range, lat_range, xticks=None, yticks=None):
    """Add a subplot to the figure with the given map projection
    and lon/lat range. Returns an Axes object.

    Args:
        - fig: figure to which the subplot is added
        - projection: map projection for the subplot. NOTE: its 'threshold'
            attribute is reduced in place, so the caller's object is modified
        - lon_range, lat_range: 2-element sequences with the bounds of the
            region to show (interpreted in PlateCarree coordinates)
        - xticks, yticks: (optional) positions of labeled gridlines. Either
            one may be given on its own; if both are None, no gridlines or
            labels are drawn
    """

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(projection=projection)

    ## Subset to given region
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## add tick labels (if ticks were requested for at least one axis)
    if xticks is not None or yticks is not None:

        ## add lon/lat labels
        gl = ax.gridlines(
            draw_labels=True,
            linestyle="-",
            alpha=0.1,
            linewidth=0.5,
            color="k",
            zorder=1.05,
        )

        ## specify which axes to label
        gl.top_labels = False
        gl.right_labels = False

        ## specify ticks; only override a locator when ticks were supplied
        ## for that axis (passing None to FixedLocator is not valid)
        if yticks is not None:
            gl.ylocator = mticker.FixedLocator(yticks)
        if xticks is not None:
            gl.xlocator = mticker.FixedLocator(xticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """
    Draw an unfilled rectangle around the specified lon/lat range
    on the given ax object (edge color set by 'c'). Returns the ax object.
    """

    ## dimensions of the box
    box_height = lat_range[1] - lat_range[0]
    box_width = lon_range[1] - lon_range[0]

    ## build the rectangle, anchored at the (min lon, min lat) corner
    box = mpatches.Rectangle(
        xy=[lon_range[0], lat_range[0]],
        height=box_height,
        width=box_width,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )

    ## draw it
    ax.add_patch(box)

    return ax


def plot_setup_woodshole(fig):
    """Set up a zoomed-in map of the region around Woods Hole.
    Resizes the figure and adds an orthographic map subplot (via plot_setup).
    Returns the figure and the new Axes object."""

    ## adjust figure size
    fig.set_size_inches(5, 3)

    ## region to show, and where to put the labeled gridlines
    region_kwargs = dict(
        lon_range=[-80, -60],
        lat_range=[35, 45],
        xticks=[-80, -70, -60],
        yticks=[35, 40, 45],
    )

    ## orthographic projection centered on the region;
    ## the generic plotting function builds the ax object
    ax = plot_setup(
        fig,
        ccrs.Orthographic(central_longitude=-67.5, central_latitude=40),
        **region_kwargs,
    )

    return fig, ax


def make_cb_range(amp, delta):
    """Make colorbar levels for cmo.balance, skipping zero.
    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar
    Returns a 1-D array: levels from -amp up to (but excluding) zero,
    followed by levels from delta upward in steps of delta.
    """

    ## levels below zero
    negative_levels = np.arange(-amp, 0, delta)

    ## levels above zero
    positive_levels = np.arange(delta, amp + delta, delta)

    return np.concatenate((negative_levels, positive_levels))


def get_empirical_pdf(x, bin_edges=None):
    """
    Estimate the "empirical" probability distribution function for the data x.
    In this case the result is a normalized histogram.
    Normalized means that integrating over the histogram yields 1.

    Args:
        - x: data to bin
        - bin_edges: (optional) bin specification, passed to np.histogram as
            'bins' (e.g., an array or list of edges). If None, np.histogram's
            default binning is used.

    Returns the PDF (normalized histogram) and edges of the histogram bins
    (as returned by np.histogram).
    """

    ## compute histogram; always take the edges from np.histogram so that
    ## lists (or other array-likes) of edges work in the arithmetic below
    if bin_edges is None:
        hist, bin_edges = np.histogram(x)

    else:
        hist, bin_edges = np.histogram(x, bins=bin_edges)

    ## normalize to a probability distribution (PDF):
    ## divide counts by the total area under the histogram
    bin_width = np.diff(bin_edges)
    pdf = hist / (hist * bin_width).sum()

    return pdf, bin_edges


def plot_quantiles(ax, x, label=None, month=None, time_dim="time", **kwargs):
    """plot .1, .5, and .9 quantiles (taken over "member_id") on given ax object

    Args:
        - ax: object to plot on (its 'plot' method is called)
        - x: data with a "member_id" dimension and a 'time_dim' coordinate
        - label: legend label, attached to the median line only
        - month: (optional) integer month; if given, only timesteps in that
            month are kept. Requires the 'time_dim' coordinate to support
            the '.dt' accessor
        - time_dim: name of the coordinate used for the x-axis (and for the
            month filter)
        - kwargs: extra keyword arguments passed to every ax.plot call
    """

    ## filter for month if specified (along 'time_dim', not a fixed name)
    if month is not None:
        in_month = x[time_dim].dt.month == month
        x = x.sel({time_dim: in_month})

    ## compute quantiles
    x_quantiles = x.quantile([0.1, 0.5, 0.9], dim="member_id")

    ## plot median
    ax.plot(
        x_quantiles[time_dim],
        x_quantiles.sel(quantile=0.5),
        lw=2,
        label=label,
        **kwargs,
    )

    ## plot upper/lower bounds (thinner, semi-transparent, unlabeled)
    for q in [0.1, 0.9]:
        ax.plot(
            x_quantiles[time_dim],
            x_quantiles.sel(quantile=q),
            lw=1,
            alpha=0.5,
            **kwargs,
        )

## Load data

### Data-loading functions

In [None]:
def load_from_server_fn(
    lens_fp, simulation_type, lon_range, lat_range, n_members=2, varname="SST"
):
    """
    Load ensemble data for CESM-LE from server.
    Args:
        - lens_fp: filepath to LENS data on CMIP server
        - simulation_type: either "historical" or "future"
        - lon_range, lat_range: not used by this function (accepted so it
            can be called with the same arguments as load_from_cloud_fn)
        - n_members: number of ensemble members to load
        - varname: name of the variable to load (e.g., "SST"); also the name
            of the subdirectory of 'lens_fp' that is searched for files
    Returns the variable 'varname' from the opened (not yet loaded) dataset.
    Raises ValueError if 'simulation_type' is not recognized.
    """

    ## get filename pattern
    if simulation_type == "historical":
        pattern = "*20TRC*.nc"

    elif simulation_type == "future":
        pattern = "*RCP85*2006*.nc"

    else:
        ## fail loudly rather than returning None to the caller
        raise ValueError(
            f"not a valid simulation type: {simulation_type!r} "
            '(expected "historical" or "future")'
        )

    ## get (sorted) list of files
    files = sorted(pathlib.Path(lens_fp, varname).glob(pattern))

    ## get subset of files to load (one file per ensemble member)
    files = files[:n_members]

    ## open the data (but don't load to memory)
    data = xr.open_mfdataset(
        files,
        concat_dim="member_id",
        decode_timedelta=True,
        combine="nested",
        chunks={"time": 1872},
        parallel=True,
    )

    ## trim in time (gets around NaN values)
    data = data.sel(time=slice("1921", "2080"))

    ## drop vertical coord if it exists
    if "z_t" in data.dims:
        data = data.drop_vars("z_t").squeeze()

    ## rename TLONG/TLAT coords
    if "TLONG" in data.coords:
        data = data.rename({"TLONG": "lon", "TLAT": "lat"})

    return data[varname]


def load_from_cloud_fn(
    simulation_type,
    lon_range,
    lat_range,
    varname="TREFHT",
    n_members=2,
):
    """Load CESM data from cloud. Args:
    - simulation_type: either "historical" or "future"
    - lon_range, lat_range: not used by this function (accepted so it can be
        called with the same arguments as load_from_server_fn)
    - varname: variable to load ("TREFHT" is 2m-temperature)
    - n_members: number of ensemble members to load
    Returns the variable 'varname' from the opened (not yet loaded) dataset.
    Raises ValueError if 'simulation_type' is not recognized.
    """

    ## map simulation type to the key of the matching dataset;
    ## validate up front, before any network access
    dataset_keys = {
        "historical": "atm.historical.monthly.cmip6",
        "future": "atm.ssp370.monthly.cmip6",
    }
    if simulation_type not in dataset_keys:
        raise ValueError(
            f"Not a valid simulation type: {simulation_type!r} "
            '(expected "historical" or "future")'
        )

    ## get catalog of available data
    catalog = intake.open_esm_datastore(
        "https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json"
    )

    ## subset for temperature data
    ## to look at available data, use: catalog.df
    catalog_subset = catalog.search(variable=varname, frequency="monthly")

    ## kwargs for opening data
    kwargs = dict(
        aggregate=True,
        xarray_open_kwargs=dict(
            engine="zarr",
            decode_timedelta=True,
        ),
        zarr_kwargs={"consolidated": True},
        storage_options={"anon": True},
    )

    ## open data (but don't load to memory)
    dsets = catalog_subset.to_dataset_dict(**kwargs)

    ## load future or historical
    data = dsets[dataset_keys[simulation_type]]

    ## subset for ensemble members
    data = data.isel(member_id=slice(None, n_members))

    return data[varname]


def load_datasets(
    varname, lon_range, lat_range, lens_fp=None, load_from_cloud=True, n_members=2
):
    """
    Load historical and future datasets for the specified variable,
    regridded onto the grid of the mask returned by load_lsm.
    Args:
        - varname: name of variable to load
        - lon_range, lat_range: 2-elements lists specifying bounds of region to load
        - lens_fp: filepath to LENS directory on CMIP server (only required if using server)
        - load_from_cloud: boolean specifying whether to load from cloud
        - n_members: number of ensemble members to load
    Returns the (historical, future) pair of regridded datasets.
    """

    ## arguments shared by both loading functions
    loader_kwargs = {
        "varname": varname,
        "n_members": n_members,
        "lon_range": lon_range,
        "lat_range": lat_range,
    }

    ## pick the loading function; only the server version needs a filepath
    load_fn = load_from_cloud_fn if load_from_cloud else load_from_server_fn
    if not load_from_cloud:
        loader_kwargs["lens_fp"] = lens_fp

    ## open the data (historical first, then future)
    data_hist, data_fut = (
        load_fn(simulation_type=sim_type, **loader_kwargs)
        for sim_type in ("historical", "future")
    )

    ## build regridder from the data grid to the LSM grid
    target_grid = load_lsm(lon_range=lon_range, lat_range=lat_range)
    regridder = xe.Regridder(
        data_hist, target_grid, "bilinear", ignore_degenerate=False
    )

    ## do the regridding
    return regridder(data_hist), regridder(data_fut)


def load_lsm(lon_range, lat_range):
    """Create a mask from OISST data (opened from a remote server).
    The result is 1.0 where the first SST field is defined and NaN elsewhere,
    restricted to the given lon/lat bounds, with a boolean "mask" attached."""

    ## open sst data and keep the first time slice (without its time coord)
    oisst = xr.open_dataset(
        r"http://psl.noaa.gov/thredds/dodsC/Datasets/noaa.oisst.v2/new/sst.oisst.mon.ltm.1991-2020.nc",
        decode_times=False,
    )
    sst_field = oisst["sst"].isel(time=0).drop_vars("time")

    ## keep NaNs where SST is missing; put ones everywhere else (ocean)
    ones_where_valid = sst_field.where(np.isnan(sst_field), other=1.0)

    ## restrict to the requested lon/lat range
    lon_bounds = slice(*lon_range)
    lat_bounds = slice(*lat_range)
    lsm = ones_where_valid.sel(lon=lon_bounds, lat=lat_bounds)

    ## attach binary mask for regridding (True where lsm is not NaN)
    lsm["mask"] = ~np.isnan(lsm)

    return lsm

### Initialize dask cluster (optional)

````{note} Dask cluster
**If you'd like to use a Dask cluster, uncomment the lines in the cell below** (the ```n_workers``` argument specifies how many processes/CPUs to use). You can copy and paste the Dashboard url into a separate browser tab to monitor the cluster.

Using a cluster may speed up the data preprocessing via parallelization. E.g., rather than loading all ensemble members, then trimming them in lon/lat space, we assign a separate "worker" (e.g., CPU) to trim data for each ensemble member separately.

In our experience, using a cluster leads to a larger speed-up when using Poseidon than when remotely connected to the CMIP server (possibly because the network/download is the bottleneck, rather than compute resources). The cluster may also exacerbate the much-feared ```NetCDF:HDF``` error and trigger other hard-to-decipher errors.
````

In [None]:
# from dask.distributed import LocalCluster, Client
# cluster = LocalCluster(n_workers=4)
# client = Client(cluster)
# client

### Set pre-processing specs

````{admonition} To-do
Update the constants below (see code cell for description).
````

````{warning} 
**Preprocessing the data in the code cell below is slow** ($\sim 20$ minutes on a laptop). To speed up the process, here are a few options (pick one):
1. Download [pre-processed data from Google Drive](https://drive.google.com/drive/folders/1sBa-Z1-b6iKaBHo_UDCdSL5d4zz9GVJE?usp=sharing). Save the files in ```SAVE_FP``` (specify ```SAVE_FP``` below).
2. Load fewer ensemble members (e.g., reduce ```N_MEMBERS``` from 35 to 9 in the code cell below).
3. Use Poseidon (takes less than a minute to run).
````

In [None]:
## specify "save" filepath: directory where to save the data
## (so we don't have to load from server/cloud every time);
## pre-computed "data_hist.nc"/"data_fut.nc" files are read from here
SAVE_FP = pathlib.Path("./data/server")

## should we load from the cloud? (if False, load from the server at LENS_FP)
LOAD_FROM_CLOUD = False

## path to LENS on CMIP server (uncomment second line for windows);
## only used when loading from the server
LENS_FP = pathlib.Path("/Volumes/cmip6/data/cmip6/CMIP/NCAR/LENS")
# LENS_FP = pathlib.Path("Z:/data/cmip6/CMIP/NCAR/LENS")

## variable name
## suggested: "SST" if on server and "TREFHT" if on cloud (2m atmos. temp)
VARNAME = "SST"  # use TREFHT for cloud

## number of ensemble members to load
N_MEMBERS = 35

## lon/lat range: [min, max] bounds of the region to load
## NOTE(review): longitudes look like a 0-360 convention (280-300) -- confirm
LON_RANGE = [280, 300]
LAT_RANGE = [35, 45]

### Open data (but don't load to memory)

In [None]:
## check if pre-computed data exists: both files are opened below,
## so both must be present to skip loading from server/cloud
files_exist = all(
    pathlib.Path(SAVE_FP, fname).is_file()
    for fname in ("data_hist.nc", "data_fut.nc")
)

## try to load pre-computed data
if files_exist:
    data_hist = xr.open_dataarray(pathlib.Path(SAVE_FP, "data_hist.nc"))
    data_fut = xr.open_dataarray(pathlib.Path(SAVE_FP, "data_fut.nc"))

else:
    ## open the data from server/cloud (regridded, but not loaded to memory)
    data_hist, data_fut = load_datasets(
        varname=VARNAME,
        n_members=N_MEMBERS,
        lon_range=LON_RANGE,
        lat_range=LAT_RANGE,
        lens_fp=LENS_FP,
        load_from_cloud=LOAD_FROM_CLOUD,
    )

### Now, load to memory (this is the slow part)

In [None]:
## load historical data to memory (in place) and report elapsed time
print("Loading historical data")
t0 = time.time()
data_hist.load()
t1 = time.time()
print(f"Elapsed time: {(t1-t0)/60:.1f} minutes\n")

## same for future data
print("Loading future data")
t0 = time.time()
data_fut.load()
t1 = time.time()
print(f"Elapsed time: {(t1-t0)/60:.1f} minutes\n")

## save to file (only if the data wasn't read from pre-computed files);
## create the save directory first, since to_netcdf won't create it
if not files_exist:
    pathlib.Path(SAVE_FP).mkdir(parents=True, exist_ok=True)
    data_hist.to_netcdf(pathlib.Path(SAVE_FP, "data_hist.nc"))
    data_fut.to_netcdf(pathlib.Path(SAVE_FP, "data_fut.nc"))

## Concatenate in time
data = xr.concat([data_hist, data_fut], dim="time")

## Plot a sample

In [None]:
## blank canvas
fig = plt.figure()

## plot background (map of the Woods Hole region)
fig, ax = plot_setup_woodshole(fig)


## plot the data (first ensemble member, first timestep)
plot_data = ax.pcolormesh(
    data.lon,
    data.lat,
    data.isel(member_id=0, time=0),
    cmap="cmo.thermal",
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(plot_data, fraction=0.015, pad=0.05)

## plot outline of region (same lon/lat bounds as used in compute_T_wh)
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

plt.show()

## Analysis

### Compute index

````{admonition} To-do
If you'd like to look at a different index, modify the function below.
````

In [None]:
def compute_T_wh(x):
    """Compute Woods Hole temperature index: the average of x over the
    box with longitude in [287.5, 293.5] and latitude in [39, 44].

    Args:
        - x: data with "lon" and "lat" coordinates (supports .sel and .mean)
    Returns x averaged over "lon" and "lat" inside the box.
    """

    ## get subset of data inside the box
    ## (use the argument 'x', not a global variable)
    data_subset = x.sel(lon=slice(287.5, 293.5), lat=slice(39, 44))

    ## compute spatial average
    return data_subset.mean(["lon", "lat"])


## do the computation here (on the concatenated historical + future data)
idx = compute_T_wh(data)

### Separate forced response from internal variability

Estimate forced response and internal variability

In [None]:
## forced response: the mean across ensemble members
idx_forced = idx.mean("member_id")

## internal variability: what remains after removing the forced response
idx_iv = idx - idx_forced


def get_ann_mean(x):
    """Average x over each calendar year"""
    return x.groupby("time.year").mean()


## annual means of the forced response, internal variability, and full index
idx_forced_ann, idx_iv_ann, idx_ann = (
    get_ann_mean(series) for series in (idx_forced, idx_iv, idx)
)

Let's plot the annual mean

In [None]:
## set up plot: forced response (left) and internal variability (right)
fig, axs = plt.subplots(1, 2, figsize=(7, 2.75), layout="constrained")

## plot ensemble median and 10%/90% percentiles
plot_quantiles(axs[0], idx_ann, c="k", time_dim="year", label="Ensemble median")

## plot forced response (ensemble mean)
axs[0].plot(idx_forced_ann.year, idx_forced_ann, c="r", lw=2, label="Ensemble mean")

## plot internal variability (at most the first 10 members)
kwargs = dict(lw=0.5, alpha=0.5, c="gray")
members_shown = idx_iv_ann.member_id[:10]
for m in members_shown:
    axs[1].plot(idx_iv_ann.year, idx_iv_ann.sel(member_id=m), **kwargs)


## label plots
for ax in axs:
    ax.set_xticks([1920, 2000, 2080])
    ax.set_xlabel("Year")
    ax.set_ylabel(r"$^{\circ}$C")

axs[1].set_yticks([-1, 0, 1])
axs[0].set_title("Forced response")

## report the actual member counts (they change if fewer members are loaded)
axs[1].set_title(
    f"Internal variability ({members_shown.size}/{idx_iv_ann.member_id.size} members)"
)
axs[1].axhline(0, ls="--", c="k", zorder=0.5)
axs[1].yaxis.tick_right()
axs[1].yaxis.set_label_position("right")
axs[0].legend(prop=dict(size=8))

## save fig (create the output directory first; savefig won't create it)
if SAVE_FIGS:
    pathlib.Path("figs").mkdir(parents=True, exist_ok=True)
    fig.savefig("figs/forced-response.svg")

plt.show()

### Compare warming rates in Mar and Sep

#### Get smoothed timeseries for Mar/Sep

Function to do the smoothing

In [None]:
def preprocess_idx(idx, month):
    """
    Prepare index for plotting:
    1. keep only the specified month (integer from 1 to 12)
    2. subtract the mean over the first 10 years (and all ensemble members)
    3. smooth with a centered 9-year rolling mean, then drop the first and
       last 4 years
    """

    ## 1. keep only timesteps that fall in the requested month
    month_mask = idx.time.dt.month == month
    monthly = idx.isel(time=month_mask)

    ## replace the "time" dimension with an integer "year" coordinate
    years = monthly.time.dt.year.values
    monthly = monthly.rename({"time": "year"})
    monthly = monthly.assign_coords({"year": years})

    ## 2. Normalize data (subtract mean over first 10 years)
    first_years = monthly.isel(year=slice(None, 10))
    anomaly = monthly - first_years.mean(["member_id", "year"])

    ## 3. Smooth timeseries with 9-year rolling mean and trim 4 years per end
    smoothed = anomaly.rolling({"year": 9}, center=True).mean()

    return smoothed.isel(year=slice(4, -4))

Do the computation

In [None]:
## normalized, smoothed index for March (month 3) and September (month 9)
idx_mar_norm, idx_sep_norm = (
    preprocess_idx(idx, month=month_number) for month_number in (3, 9)
)

#### Plot timeseries of forced response for Mar/Sep

In [None]:
## get colors for plot
colors = sns.color_palette()

fig, ax = plt.subplots(figsize=(3.5, 2.75), layout="constrained")

## plot data (ensemble median and 10%/90% quantiles for each month)
plot_quantiles(ax, idx_mar_norm, c=colors[0], label="Mar", time_dim="year")
plot_quantiles(ax, idx_sep_norm, c=colors[1], label="Sep", time_dim="year")

## label
ax.legend(prop=dict(size=8))
ax.set_ylabel(r"$\Delta T$ ($^{\circ}$C)")
ax.set_xticks([1920, 2000, 2080])
ax.set_xlabel("Year")

## save to file (create the output directory first; savefig won't create it)
if SAVE_FIGS:
    pathlib.Path("figs").mkdir(parents=True, exist_ok=True)
    fig.savefig("figs/forced-response_by-seasonal.svg")

plt.show()

#### Compare histograms for early/late period

Function to compute histograms

In [None]:
def get_delta_T_pdf(T, t0, t1, bin_edges):
    """Compute the PDF (normalized histogram over 'bin_edges') of the
    temperature difference between year t1 and year t0"""

    ## temperature in the later and earlier year
    T_late = T.sel(year=t1).squeeze()
    T_early = T.sel(year=t0).squeeze()

    ## PDF of the change (the returned bin edges are not needed)
    pdf = get_empirical_pdf(T_late - T_early, bin_edges)[0]

    return pdf

Do the computation

In [None]:
## params for each period: start year, end year, and histogram bin edges
kwargs0 = {"t0": 1925, "t1": 2000, "bin_edges": np.arange(-0.3, 1.9, 0.15)}
kwargs1 = {"t0": 2000, "t1": 2075, "bin_edges": np.arange(2, 4.5, 0.2)}

## PDFs of the Mar/Sep temperature change over the first period...
pdf_mar_0, pdf_sep_0 = (
    get_delta_T_pdf(series, **kwargs0) for series in (idx_mar_norm, idx_sep_norm)
)

## ...and over the second period
pdf_mar_1, pdf_sep_1 = (
    get_delta_T_pdf(series, **kwargs1) for series in (idx_mar_norm, idx_sep_norm)
)

Plot result

In [None]:
## Set up plot: one panel per period
fig, axs = plt.subplots(1, 2, figsize=(5, 2.75), layout="constrained")

## plot styles: filled histogram for March, outline for September
mar_kwargs = dict(fill=True, alpha=0.3, label="Mar")
sep_kwargs = dict(lw=1.5, label="Sep")

## plot temperature change for first period
axs[0].stairs(pdf_mar_0, edges=kwargs0["bin_edges"], **mar_kwargs)
axs[0].stairs(pdf_sep_0, edges=kwargs0["bin_edges"], **sep_kwargs)

## plot for second period
axs[1].stairs(pdf_mar_1, edges=kwargs1["bin_edges"], **mar_kwargs)
axs[1].stairs(pdf_sep_1, edges=kwargs1["bin_edges"], **sep_kwargs)

## set axis limits
axs[0].set_xlim([-0.75, 2.25])
axs[1].set_xlim([1.5, 4.5])
for ax in axs:
    ax.set_ylim([0, 3])

## label
axs[0].legend(prop=dict(size=8))
axs[0].set_title(f"{kwargs0['t0']}-{kwargs0['t1']}")
axs[1].set_title(f"{kwargs1['t0']}-{kwargs1['t1']}")
axs[0].set_xticks([-0.5, 0.5, 1.5])
axs[1].set_xticks([2, 3, 4])
axs[0].set_yticks([0, 1.5, 3])
axs[1].set_yticks([])
axs[0].set_ylabel("Prob. density")
for ax in axs:
    ax.set_xlabel(r"$\Delta T$ ($^{\circ}$C)")

## save to file (create the output directory first; savefig won't create it)
if SAVE_FIGS:
    pathlib.Path("figs").mkdir(parents=True, exist_ok=True)
    fig.savefig("figs/histograms.svg")

plt.show()

### Look at change in standard deviation (spatial pattern)

#### Compute change

In [None]:
## get internal variability signal (remove the ensemble mean)
data_iv = data - data.mean("member_id")

## standard dev (over time and members) for the first and last 360 timesteps
## (30 years of monthly data); NaNs are not skipped
std_init, std_end = (
    data_iv.isel(time=window).std(["time", "member_id"], skipna=False)
    for window in (slice(None, 360), slice(-360, None))
)

## get percentage change relative to the initial period
std_pct_change = 100 * (std_end - std_init) / std_init

#### Plot initial standard deviation 

In [None]:
## blank canvas
fig = plt.figure(layout="constrained")

## plot background (map of the Woods Hole region)
fig, ax = plot_setup_woodshole(fig)


## plot the data (standard deviation over the initial period);
## color scale is clipped to the range [0.3, 1.0]
plot_data = ax.pcolormesh(
    std_init.lon,
    std_init.lat,
    std_init,
    cmap="cmo.amp",
    vmax=1.0,
    vmin=0.3,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data,
    fraction=0.04,
    label=r"Std. dev. ($^{\circ}$C)",
)

## plot outline of region (same lon/lat bounds as used in compute_T_wh)
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

## Label
ax.set_title("Standard dev. of SST (1920-1950)")

plt.show()

#### Plot % change

In [None]:
## blank canvas
fig = plt.figure(layout="constrained")

## plot background (map of the Woods Hole region)
fig, ax = plot_setup_woodshole(fig)


## plot the data (percentage change in standard deviation);
## color scale is symmetric about zero and clipped at +/- 30%
plot_data = ax.pcolormesh(
    std_init.lon,
    std_init.lat,
    std_pct_change,
    cmap="cmo.balance",
    vmax=30,
    vmin=-30,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data, fraction=0.015, pad=0.05, ticks=[-30, 0, 30], label="% change"
)

## plot outline of region (same lon/lat bounds as used in compute_T_wh)
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

## label
ax.set_title("Change in standard dev. of SST (1935-2065)")

## save to file (create the output directory first; savefig won't create it)
if SAVE_FIGS:
    pathlib.Path("figs").mkdir(parents=True, exist_ok=True)
    fig.savefig("figs/sigma-change.svg")

plt.show()