# Example
In this example, we'll use a single ensemble member from CESM1-LE to diagnose the response of sea surface temperature (SST) near Woods Hole to external forcing.

## Imports

In [None]:
import pathlib
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import datetime
import seaborn as sns
import cmocean
import matplotlib.patches as mpatches
import matplotlib.ticker as mticker
import copy
import pandas as pd
import tqdm
import time

## (optional) plot styling: white axes background and no gridlines
plot_rc = {"axes.facecolor": "white", "axes.grid": False}
sns.set(rc=plot_rc)

## flag controlling whether figures are written to disk
SAVE_FIGS = False

## Functions

In [None]:
def trim(data, lon_range, lat_range):
    """Return the part of 'data' lying inside a lon/lat box.

    Arguments:
        - data: xarray object with 2-D 'TLONG'/'TLAT' coordinates on the
          ('nlat', 'nlon') grid
        - lon_range, lat_range: [min, max] bounds (inclusive), in the same
          longitude convention as data['TLONG']
    Returns:
        - 'data' subset to every nlon column / nlat row that contains at least
          one grid cell inside the box (so some cells outside the box may remain)
    """
    lon_lo, lon_hi = lon_range[0], lon_range[1]
    lat_lo, lat_hi = lat_range[0], lat_range[1]

    lon = data["TLONG"]
    lat = data["TLAT"]

    ## boolean mask of grid cells inside the box, pulled into memory
    inside_box = (lon >= lon_lo) & (lon <= lon_hi) & (lat >= lat_lo) & (lat <= lat_hi)
    inside_box = inside_box.load()

    ## a column (row) is kept if any of its cells falls in the box
    keep_cols = inside_box.any("nlat")
    keep_rows = inside_box.any("nlon")
    return data.isel(nlon=keep_cols, nlat=keep_rows)


def sort_longitude(data):
    """Reorder 'data' along 'nlon' so that TLONG increases monotonically.

    The permutation is computed row by row from TLONG and applied to TLONG,
    TLAT and the data values. Returns the transposed object with 'nlon' as
    its last dimension.
    """

    def _reorder(values, order):
        ## apply a per-row permutation along the last axis
        return np.take_along_axis(values, indices=order, axis=-1)

    ## all sorting happens along the last axis, so make 'nlon' last
    data = data.transpose(..., "nlon")

    ## per-row permutation that sorts longitude in ascending order
    order = np.argsort(data["TLONG"].values, axis=-1)

    ## apply it to both 2-D coordinates (TLONG first, then TLAT)
    for coord_name in ("TLONG", "TLAT"):
        data[coord_name].values = _reorder(data[coord_name].values, order)

    ## broadcast the 2-D permutation over any leading dims (e.g. time)
    n_leading = data.ndim - 2
    if n_leading > 0:
        order = np.expand_dims(order, axis=list(range(n_leading)))

    data.values = _reorder(data.values, order)
    return data


def swap_longitude_range(data):
    """Swap longitude range of an xarray object from [0, 360) to (-180, 180].

    Handles 2-dimensional longitude coordinates ('TLONG'): every TLONG value
    greater than 180 is replaced by TLONG - 360, and all other values
    (including NaN) are kept. Returns a new object carrying the relabeled
    coordinate; the input 'data' is not modified.
    """
    TLONG = data["TLONG"]
    TLONG_values = TLONG.values

    ## np.where builds a new array, so the input coordinate is left untouched
    TLONG_new = np.where(TLONG_values > 180, TLONG_values - 360, TLONG_values)

    ## attach the relabeled coordinate, keeping its dims and attributes
    return data.assign_coords(TLONG=(TLONG.dims, TLONG_new, TLONG.attrs))


def plot_setup(fig, projection, lon_range, lat_range, xticks=None, yticks=None):
    """Add a map subplot to 'fig' and return its Axes.

    Arguments:
        - fig: matplotlib Figure to add the subplot to
        - projection: cartopy projection for the subplot. NOTE: its 'threshold'
          is divided by 1000 in place, so pass a fresh projection on each call.
        - lon_range, lat_range: [min, max] lon/lat (PlateCarree degrees) of the
          map extent
        - xticks, yticks: (optional) longitudes / latitudes of labelled
          gridlines. Gridlines are drawn if either list is given; an axis whose
          list is None uses cartopy's default locator.
    Returns:
        - the new (cartopy) Axes
    """

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(projection=projection)

    ## Subset to given region
    ax.set_extent([*lon_range, *lat_range], crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## add labelled lon/lat gridlines if any ticks were requested
    if xticks is not None or yticks is not None:
        gl = ax.gridlines(
            draw_labels=True,
            linestyle="-",
            alpha=0.1,
            linewidth=0.5,
            color="k",
            zorder=1.05,
        )

        ## only label the bottom and left edges
        gl.top_labels = False
        gl.right_labels = False

        ## FixedLocator(None) is invalid, so set each locator only when given
        if xticks is not None:
            gl.xlocator = mticker.FixedLocator(xticks)
        if yticks is not None:
            gl.ylocator = mticker.FixedLocator(yticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """Draw an unfilled rectangle outlining a lon/lat box on 'ax'.

    Arguments:
        - ax: cartopy Axes to draw on
        - lon_range, lat_range: [min, max] corners of the box, in PlateCarree degrees
        - c: edge color of the outline
    Returns:
        - the same 'ax'
    """
    lon_min, lon_max = lon_range[0], lon_range[1]
    lat_min, lat_max = lat_range[0], lat_range[1]

    box = mpatches.Rectangle(
        xy=[lon_min, lat_min],
        height=lat_max - lat_min,
        width=lon_max - lon_min,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )
    ax.add_patch(box)

    return ax


def plot_setup_woodshole(fig):
    """Set up a zoomed-in map around Woods Hole.

    Resizes 'fig' to 5 x 3 inches and adds an orthographic-projection subplot
    spanning 80W-60W, 35N-45N, with labelled gridlines.
    Returns:
        - (fig, ax)
    """

    ## adjust figure size
    fig.set_size_inches(5, 3)

    ## orthographic projection centred on the region
    proj = ccrs.Orthographic(central_longitude=-67.5, central_latitude=40)

    ## latitude ticks are chosen inside the plotted 35N-45N window
    ## (a tick at 25N would never be drawn)
    ax = plot_setup(
        fig,
        proj,
        lon_range=[-80, -60],
        lat_range=[35, 45],
        xticks=[-80, -70, -60],
        yticks=[35, 40, 45],
    )

    return fig, ax


def plot_setup_atlantic(fig):
    """Set up a map of the North Atlantic.

    Resizes 'fig' to 7.5 x 3.75 inches and adds an orthographic-projection
    subplot spanning 90W-10W, 20N-60N, with labelled gridlines.
    Returns:
        - (fig, ax)
    """
    fig.set_size_inches(7.5, 3.75)

    ax = plot_setup(
        fig,
        ccrs.Orthographic(central_longitude=-50, central_latitude=40),
        lon_range=[-90, -10],
        lat_range=[20, 60],
        xticks=[-80, -50, -20],
        yticks=[25, 45],
    )
    return fig, ax


def make_cb_range(amp, delta):
    """Make symmetric colorbar levels for cmo.balance (zero is omitted).

    Args:
        - 'amp': amplitude of maximum value for colorbar (assumed to be a
          multiple of 'delta'; it is rounded to the nearest multiple)
        - 'delta': positive increment for colorbar
    Returns:
        - array [-amp, ..., -delta, delta, ..., amp]
    Raises:
        - ValueError: if 'delta' is not positive
    """
    if delta <= 0:
        raise ValueError(f"'delta' must be positive, got {delta}")

    ## count steps with integers: float np.arange can include an extra
    ## endpoint (e.g. np.arange(0.1, 0.4, 0.1) has 4 elements)
    n_steps = int(round(amp / delta))
    positive = delta * np.arange(1, n_steps + 1)

    ## mirror the positive levels so the range is exactly symmetric
    return np.concatenate([-positive[::-1], positive])


def get_empirical_pdf(x, bin_edges=None):
    """
    Estimate the "empirical" probability density function (PDF) of the data x.

    The result is a histogram normalized so that each bin's value is
    count / (total count * bin width). Integrating over the histogram
    (sum of value * bin width) therefore yields 1. The values are a true
    density even when the bins have unequal widths.

    Arguments:
        - x: data to bin (array-like; np.histogram works on the flattened array)
        - bin_edges: (optional) monotonically increasing bin edges. If None,
          np.histogram's default of 10 equal-width bins spanning the data is used.
    Returns:
        - pdf: normalized histogram, one value per bin
        - bin_edges: edges of the histogram bins (len(pdf) + 1 values)
    Note: values of x outside the given bin_edges are not counted, so the PDF
    is normalized over the binned range only.
    """
    bins = 10 if bin_edges is None else bin_edges

    ## density=True divides each count by (total count * that bin's own width)
    pdf, bin_edges = np.histogram(x, bins=bins, density=True)

    return pdf, bin_edges


def get_autocorr(x, lags, month=None):
    """Compute the autocorrelation of 'x' in time at each lag in 'lags'.

    Arguments:
        - x: xarray object with a 'time' dimension
        - lags: sequence of integer lags (in time steps)
        - month: (optional) restrict the correlation to samples from this month
    Returns:
        - xr.DataArray of correlations stacked along a new 'lag' dimension
    """
    per_lag = [get_autocorr_helper(x, lag, month) for lag in lags]
    lag_dim = pd.Index(lags, name="lag")
    return xr.concat(per_lag, dim=lag_dim)


def get_autocorr_helper(x, lag, month=None):
    """Correlation in time between 'x' and 'x' shifted by 'lag' steps.

    For lag 0 an array of ones (one time slice's shape) is returned. Otherwise
    the overlapping parts of x(t) and x(t + lag) are paired on x(t)'s time
    labels. If 'month' is given, only pairs whose base time x(t) falls in
    that month are used.
    """
    if lag == 0:
        return xr.ones_like(x.isel(time=0))

    ## positional windows for the base series and the shifted series
    if lag > 0:
        base_window, shifted_window = slice(None, -lag), slice(lag, None)
    else:
        base_window, shifted_window = slice(-lag, None), slice(None, lag)

    base = x.isel(time=base_window)
    shifted = x.isel(time=shifted_window)

    ## give the shifted series the base series' time labels so they align
    shifted["time"] = base.time

    if month is not None:
        in_month = base.time.dt.month == month
        base, shifted = base.isel(time=in_month), shifted.isel(time=in_month)

    return xr.corr(base, shifted, dim="time")


def load_simulation(varname, member_id, simulation_type, preprocess_func=None):
    """
    Load dataset for single simulation, for single variable.
    Arguments:
        - varname: name of variable to load, one of {"SST","PSL"}
        - member_id: ID of ensemble member to load, an integer in the range [1,10]
        - simulation_type: one of {"hist", "rcp85"}
        - preprocess_func: (optional) function applied to the opened xr.Dataset
          before the variable is extracted
    Returns:
        - xarray dataarray with given data (singleton dimensions dropped)
    Raises:
        - ValueError: if 'simulation_type' is not one of {"hist", "rcp85"}
        - FileNotFoundError: if no file matches the requested simulation/member
    Notes:
        - reads the module-level SERVER_FP (set in the "Set filepaths" cell)
        - if several files match, only the first in sorted (name) order is opened
    """

    ## filename tag identifying each simulation type
    sim_tags = {"hist": "20TRC", "rcp85": "RCP85"}

    ## validate up front, rather than failing later with a NameError
    if simulation_type not in sim_tags:
        raise ValueError(
            f"Not a valid simulation type: {simulation_type!r} "
            f"(expected one of {list(sim_tags)})"
        )

    ## Filepath to the CESM LENS dataset
    lens_fp = pathlib.Path("cmip6/data/cmip6/CMIP/NCAR/LENS")

    #### 1. get filepath to data
    data_fp = SERVER_FP / lens_fp / pathlib.Path(varname)

    #### 2. get naming pattern for files to open
    file_pattern = f"*{sim_tags[simulation_type]}*.{member_id:03d}.*.nc"

    #### 3. find the file; glob order is arbitrary, so sort for a deterministic pick
    matches = sorted(data_fp.glob(file_pattern))
    if not matches:
        raise FileNotFoundError(f"No file matching '{file_pattern}' found in {data_fp}")
    fp = matches[0]

    ## load data
    data = xr.open_dataset(fp, chunks=None, decode_timedelta=True)

    ## apply (optional) preprocessing
    if preprocess_func is not None:
        data = preprocess_func(data)

    return data[varname].squeeze(drop=True)


def get_trend(data, dim="time", deg=1):
    """
    Return the best-fit polynomial of degree 'deg' to an xr.DataArray
    along dimension 'dim', evaluated at the coordinate values of 'dim'.
    """
    coefs = data.polyfit(dim=dim, deg=deg).polyfit_coefficients
    return xr.polyval(data[dim], coefs)


def detrend(data, dim="time", deg=1):
    """
    Subtract the best-fit polynomial of degree 'deg' along 'dim' from 'data'
    (see get_trend) and return the residual.
    """
    fitted = get_trend(data, dim=dim, deg=deg)
    residual = data - fitted
    return residual

## Load data

### Set filepaths

````{admonition} To-do
Update the filepaths ```SERVER_FP``` and ```save_fp``` in the code cell below.
````

In [None]:
## Root of the file server holding the CESM data
## (on the cluster, use pathlib.Path("/vortexfs1/share") instead)
SERVER_FP = pathlib.Path("/", "Volumes")

## Folder for saving trimmed data; "." is the current directory
save_fp = pathlib.Path(".")

### Load the data

In [None]:
## start the clock so we can report how long loading takes
t0 = time.time()

## arguments shared by both simulations (same variable, same ensemble member)
load_kwargs = {"varname": "SST", "member_id": 10, "preprocess_func": None}

## load the historical run, then the RCP8.5 run, fully into memory
data_hist, data_rcp = (
    load_simulation(simulation_type=sim, **load_kwargs).compute()
    for sim in ("hist", "rcp85")
)

## join them into a single record along time
data = xr.concat([data_hist, data_rcp], dim="time")

## report loading time
t1 = time.time()
print(f"Elapsed time: {t1-t0:.1f} seconds")

### Trim in lon/lat space

````{admonition} To-do
If you'd like to look at a different region, change ```lon_range``` and ```lat_range``` in the call to ```trim``` below (at that point longitudes are still in the [0, 360) convention).
````

In [None]:
## keep only the North Atlantic box (longitudes are still in [0, 360) here),
## relabel longitudes to (-180, 180], then sort them into ascending order
trimmed = trim(data, lon_range=[260, 360], lat_range=[0, 70])
data = sort_longitude(swap_longitude_range(trimmed))

### Downsample from monthly to seasonal averages

In [None]:
## average monthly data into seasons: "QS-DEC" groups quarters starting in
## December, giving DJF, MAM, JJA and SON means. The first and last quarters
## don't have enough months for a full seasonal average, so drop them.
data = data.resample({"time": "QS-DEC"}).mean().isel(time=slice(1, -1))

### Plot a sample

````{admonition} To-do
If you're not looking at Woods Hole, you may need to adapt the plotting function below (```plot_setup_atlantic```) and the box passed to ```plot_box_outline```.
````

In [None]:
## blank canvas
fig = plt.figure()

## plot background (Atlantic map)
fig, ax = plot_setup_atlantic(fig)

## plot SST for the first season.
## extend="both" so values outside the contour levels are still colored;
## with extend="min", anything above 27 (plausible near the equator, which is
## inside the trimmed domain) would be left blank.
plot_data = ax.contourf(
    data.TLONG,
    data.TLAT,
    data.isel(time=0),
    cmap="cmo.thermal",
    levels=np.arange(-3, 30, 3),
    transform=ccrs.PlateCarree(),
    extend="both",
)

## make a colorbar
cb = fig.colorbar(plot_data, ticks=[-3, 12, 27])

## plot outline of region
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

plt.show()

## Statistics

### Compute index

````{admonition} To-do
If you'd like to look at a different index, modify the function below.
````

In [None]:
def compute_T_wh(x):
    """Woods Hole temperature index.

    Averages 'x' over the grid cells kept by trim() for the box
    72.5W-66.5W, 39N-44N (longitudes in (-180, 180] convention). The mean
    is unweighted: cell areas are not taken into account.
    """
    box = trim(x, lon_range=[-72.5, -66.5], lat_range=[39, 44])
    return box.mean(dim=["nlon", "nlat"])

Do the computation, and extract the SON (September–November) seasonal averages

In [None]:
## Woods Hole index for every season
idx = compute_T_wh(data)

## SON is the "QS-DEC" quarter that starts in September;
## keep those samples and label each one by its year
is_son = idx.time.dt.month == 9
idx_son = idx.isel(time=is_son)
idx_son["time"] = idx_son.time.dt.year

### SON trend and anomalies

In [None]:
## fit linear (deg=1) and quadratic (deg=2) trends to the SON index
idx_son_trend1, idx_son_trend2 = (
    get_trend(idx_son, dim="time", deg=degree) for degree in (1, 2)
)

## anomalies are taken relative to the quadratic trend
idx_son_anom = idx_son - idx_son_trend2

#### Plotting function

In [None]:
def setup_trend_plot():
    """Create a 1x2 figure for trend plots: "Total" (left) and "Anomaly" (right).

    Both panels get year ticks, axis labels and a dashed line at 2006 (the
    HIST/RCP boundary). The anomaly panel has a zero line and its y-axis on
    the right. Returns (fig, axs).
    """
    fig, axs = plt.subplots(1, 2, figsize=(7, 2.75), layout="constrained")
    total_ax, anom_ax = axs

    ## shared labelling
    for ax in axs:
        ax.set_xticks([1920, 2006, 2080])
        ax.set_xlabel("Year")
        ax.set_ylabel(r"$^{\circ}$C")

    ## left panel: raw values
    total_ax.set_yticks([22, 24.5, 27])
    total_ax.set_title("Total")

    ## right panel: anomalies, with a zero reference line and y-axis on the right
    anom_ax.set_yticks([-1, 0, 1])
    anom_ax.set_title("Anomaly")
    anom_ax.axhline(0, c="k", zorder=0.5, lw=0.5)
    anom_ax.yaxis.tick_right()
    anom_ax.yaxis.set_label_position("right")

    ## mark the boundary between the HIST and RCP simulations
    for ax in axs:
        ax.axvline(2006, ls="--", lw=0.5, c="k")

    return fig, axs

#### Make the plot

In [None]:
fig, axs = setup_trend_plot()

## left panel: raw SON index with its linear (dashed) and quadratic (solid) fits
axs[0].plot(idx_son.time, idx_son, c="gray", label="Raw")
trend_styles = zip([idx_son_trend1, idx_son_trend2], ["--", "-"])
for i, (trend, ls) in enumerate(trend_styles, start=1):
    axs[0].plot(trend.time, trend, c="r", ls=ls, label=f"Trend (degree {i})")

## right panel: anomalies relative to the quadratic trend
axs[1].plot(idx_son_anom.time, idx_son_anom, c="gray")

## legend on the left panel only
axs[0].legend(prop=dict(size=8))

plt.show()

## Spatial pattern of warming

#### Function to compute climatology for given period and month

In [None]:
def get_clim(data, yr_range, QS_month=9):
    """Time-mean of 'data' over one quarter within a range of years.

    Arguments:
        - data: seasonal-mean data with a 'time' dimension
        - yr_range: [start, end] year strings; both end years are included
        - QS_month: starting month of the quarter to keep (9 -> SON)
    Returns:
        - the average over all matching samples
    """
    same_quarter = data.sel(time=(data.time.dt.month == QS_month))
    in_years = same_quarter.sel(time=slice(*yr_range))
    return in_years.mean("time")

#### Do computation

In [None]:
## two periods to compare; string-year slices include both end years,
## so each period spans 31 years (not 30)
yr_range0 = ["1980", "2010"]
yr_range1 = ["2050", "2080"]

## SON climatology for each period, then the change between them
clim0, clim1 = (
    get_clim(data, yr_range=period, QS_month=9) for period in (yr_range0, yr_range1)
)
delta_clim = clim1 - clim0

#### Plot difference

In [None]:
## new figure with the Atlantic map as background
fig, ax = plot_setup_atlantic(plt.figure())

## shade the change in SON SST between the two periods
plot_data = ax.pcolormesh(
    data.TLONG,
    data.TLAT,
    delta_clim,
    cmap="cmo.amp",
    vmin=1,
    vmax=4,
    transform=ccrs.PlateCarree(),
)

## overlay the earlier period's climatology as white contours
ax.contour(
    data.TLONG,
    data.TLAT,
    clim0,
    levels=np.arange(-2, 34, 4),
    colors="w",
    linewidths=1,
    alpha=0.7,
    extend="both",
    transform=ccrs.PlateCarree(),
)

## colorbar, Woods Hole box and title
cb = fig.colorbar(plot_data, ticks=[1, 4], label=r"$^{\circ}$C")
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])
ax.set_title(r"$\Delta$(SST) b/n 1980-2010 and 2050-2080")

plt.show()

## Scratch

Plot difference

Let's plot the annual mean

Compute temperature change over two pairs of periods

In [None]:
def get_delta_T_pdf(T, t0, t1, bin_edges):
    """PDF of the change T(t1) - T(t0), binned with 'bin_edges'.

    Both time selections are squeezed before subtracting. Returns only the
    normalized histogram from get_empirical_pdf (not the edges).
    """
    T_end = T.sel(time=t1).squeeze()
    T_start = T.sel(time=t0).squeeze()
    return get_empirical_pdf(T_end - T_start, bin_edges)[0]


## specify params for PDF
## NOTE(review): 'idx_mar_norm' and 'idx_sep_norm' are not defined anywhere in this
## notebook as shown, so this cell raises NameError when run top to bottom. They look
## like leftovers from an earlier (monthly) version of the analysis -- TODO define or remove.
## Each dict holds the start/end times and histogram bin edges for one pair of periods;
## the same dicts are reused by the plotting cell (bin edges and panel titles).
kwargs0 = dict(t0="1924", t1="2000", bin_edges=np.arange(-0.3, 1.9, 0.15))
kwargs1 = dict(t0="2000", t1="2076", bin_edges=np.arange(2, 4.5, 0.2))

## compute PDFs of temperature change for March and September, for each pair of periods
pdf_mar_0 = get_delta_T_pdf(idx_mar_norm, **kwargs0)
pdf_sep_0 = get_delta_T_pdf(idx_sep_norm, **kwargs0)
pdf_mar_1 = get_delta_T_pdf(idx_mar_norm, **kwargs1)
pdf_sep_1 = get_delta_T_pdf(idx_sep_norm, **kwargs1)

Plot result

In [None]:
## Set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2.75), layout="constrained")

## plot styles: March as a filled histogram, September as an outline
mar_kwargs = dict(fill=True, alpha=0.3, label="Mar")
sep_kwargs = dict(lw=1.5, label="Sep")

## plot temperature change for first period
axs[0].stairs(pdf_mar_0, edges=kwargs0["bin_edges"], **mar_kwargs)
axs[0].stairs(pdf_sep_0, edges=kwargs0["bin_edges"], **sep_kwargs)

## plot for second period
axs[1].stairs(pdf_mar_1, edges=kwargs1["bin_edges"], **mar_kwargs)
axs[1].stairs(pdf_sep_1, edges=kwargs1["bin_edges"], **sep_kwargs)

## set axis limits
axs[0].set_xlim([-0.75, 2.25])
axs[1].set_xlim([1.5, 4.5])
for ax in axs:
    ax.set_ylim([0, 3])

## label
axs[0].legend(prop=dict(size=8))
axs[0].set_title(f"{kwargs0['t0']}-{kwargs0['t1']}")
axs[1].set_title(f"{kwargs1['t0']}-{kwargs1['t1']}")
axs[0].set_xticks([-0.5, 0.5, 1.5])
axs[1].set_xticks([2, 3, 4])
axs[0].set_yticks([0, 1.5, 3])
axs[1].set_yticks([])
axs[0].set_ylabel("Prob. density")
for ax in axs:
    ax.set_xlabel(r"$\Delta T$ ($^{\circ}$C)")

## save to file (create the output folder first, since savefig won't)
if SAVE_FIGS:
    pathlib.Path("figs").mkdir(exist_ok=True)
    fig.savefig("figs/histograms.svg")

plt.show()

## Look at change in standard deviation (spatial pattern)

In [None]:
## estimate internal variability (the part of SST not explained by forcing).
## With an ensemble (a "member" dimension), subtract the ensemble mean.
## The data loaded above come from a single ensemble member (member_id=10), so
## in that case approximate the forced response by a quadratic trend fitted
## separately for each season; this removes the seasonal cycle as well.
if "member" in data.dims:
    data_iv = data - data.mean("member")
    std_dims = ["time", "member"]
else:
    data_iv = data.groupby("time.month").map(detrend, dim="time", deg=2)
    std_dims = ["time"]

## compute standard dev over 30-yr windows at the start and end of the record.
## Select by year rather than by a fixed number of steps: the data are seasonal
## averages at this point, so 360 steps would span 90 years, not 30.
std_init = data_iv.sel(time=slice("1920", "1950")).std(std_dims, skipna=False)
std_end = data_iv.sel(time=slice("2050", "2080")).std(std_dims, skipna=False)

## get percentage change
std_pct_change = 100 * (std_end - std_init) / std_init

### Plot initial standard deviation result 

In [None]:
## new figure with the zoomed-in Woods Hole map as background
fig, ax = plot_setup_woodshole(plt.figure())

## shade the standard deviation in the initial period
plot_data = ax.pcolormesh(
    std_init.TLONG,
    std_init.TLAT,
    std_init,
    cmap="cmo.amp",
    vmin=0.3,
    vmax=1,
    transform=ccrs.PlateCarree(),
)

## slim colorbar
cb = fig.colorbar(
    plot_data,
    fraction=0.015,
    pad=0.05,
    ticks=[0.3, 1],
    label=r"Std. dev. ($^{\circ}$C)",
)

## Woods Hole box and title
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])
ax.set_title("Standard dev. of SST (1920-1950)")

plt.show()

### Plot % change

In [None]:
## blank canvas
fig = plt.figure(layout="constrained")

## plot background
fig, ax = plot_setup_woodshole(fig)


## plot the data
plot_data = ax.pcolormesh(
    std_init.TLONG,
    std_init.TLAT,
    std_pct_change,
    cmap="cmo.balance",
    vmax=40,
    vmin=-40,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data, fraction=0.015, pad=0.05, ticks=[-40, 0, 40], label="% change"
)

## plot outline of region
ax = plot_box_outline(ax, lon_range=[287.5, 293.5], lat_range=[39, 44])

## label
ax.set_title("Change in standard dev. of SST (1935-2065)")

## save to  file
if SAVE_FIGS:
    fig.savefig("figs/sigma-change.svg")

plt.show()