# ENSO example

## Imports

In [None]:
import copy
import datetime
import pathlib

import cartopy.crs as ccrs
import cmocean
import dask.distributed
import matplotlib.dates as mdates
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
import scipy.signal
import scipy.stats
import seaborn as sns
import xarray as xr

## (optional) remove gridlines from plots
## white axes background and no grid lines for all subsequent plots
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## Functions

In [None]:
def plot_setup(fig, projection, lon_range, lat_range, xticks=None, yticks=None):
    """Add a map subplot to ``fig`` with the given projection and extent.

    Args:
        - fig: matplotlib Figure to add the subplot to
        - projection: cartopy projection for the subplot. NOTE: its
          ``threshold`` is divided by 1000 in place, so pass a freshly
          constructed projection to each call.
        - lon_range, lat_range: [min, max] extent, in PlateCarree coordinates
        - xticks, yticks: optional longitude/latitude tick locations.
          Gridlines and labels are drawn if either is given; the locator
          for an axis is only overridden when its ticks are given.
    Returns:
        - the new Axes object
    """

    ## increase resolution for projection
    ## (otherwise lines plotted on surface won't follow curved trajectories)
    projection.threshold /= 1000

    ## Create subplot with given projection
    ax = fig.add_subplot(projection=projection)

    ## Subset to given region
    extent = [*lon_range, *lat_range]
    ax.set_extent(extent, crs=ccrs.PlateCarree())

    ## draw coastlines
    ax.coastlines(linewidths=0.5)

    ## add lon/lat gridlines and labels if any ticks were requested
    if xticks is not None or yticks is not None:
        gl = ax.gridlines(
            draw_labels=True,
            linestyle="-",
            alpha=0.1,
            linewidth=0.5,
            color="k",
            zorder=1.05,
        )

        ## specify which axes to label
        gl.top_labels = False
        gl.right_labels = False

        ## only set a locator for axes with explicit ticks
        ## (avoid passing None to FixedLocator)
        if yticks is not None:
            gl.ylocator = mticker.FixedLocator(yticks)
        if xticks is not None:
            gl.xlocator = mticker.FixedLocator(xticks)

    return ax


def plot_box_outline(ax, lon_range, lat_range, c="k"):
    """
    Draw an unfilled rectangle on ``ax`` spanning the given lon/lat
    ranges (interpreted in PlateCarree coordinates).

    Args:
        - ax: map Axes to draw on
        - lon_range, lat_range: [min, max] bounds of the box
        - c: edge color of the box
    Returns:
        - the same ax, for convenience
    """

    ## lower-left corner and opposite corner of the box
    lon_min, lon_max = lon_range[0], lon_range[1]
    lat_min, lat_max = lat_range[0], lat_range[1]

    outline = mpatches.Rectangle(
        xy=[lon_min, lat_min],
        width=lon_max - lon_min,
        height=lat_max - lat_min,
        transform=ccrs.PlateCarree(),
        facecolor="none",
        edgecolor=c,
        linewidth=1,
    )
    ax.add_patch(outline)

    return ax


def plot_correlation(plot_setup_fn, corr, x, y):
    """
    Filled-contour map of a pre-computed spatial correlation field.

    Args:
        - plot_setup_fn: callable taking a Figure and returning (fig, ax),
          e.g. plot_setup_pacific
        - corr: xarray with spatial correlation
        - x, y: lon/lat points for plotting
    Returns:
        - fig, ax
    """

    ## let the setup function draw the background map on a new figure
    fig, ax = plot_setup_fn(plt.figure())

    ## contour levels in steps of 0.1 between -1 and 1 (zero excluded)
    filled = ax.contourf(
        x,
        y,
        corr,
        transform=ccrs.PlateCarree(),
        levels=make_cb_range(1, 0.1),
        extend="both",
        cmap="cmo.balance",
    )

    ## colorbar for the correlation values
    fig.colorbar(filled, label="Corr.", ticks=[-1, -0.5, 0, 0.5, 1])

    return fig, ax


def plot_setup_pacific(fig):
    """Draw the tropical Pacific background map (lon 100-300, lat -30-30)
    on ``fig``, resizing the figure to 5x3 inches.

    Returns:
        - fig, ax
    """

    fig.set_size_inches(5, 3)

    ## map centered on the date line region so the basin is contiguous
    map_options = dict(
        lon_range=[100, 300],
        lat_range=[-30, 30],
        xticks=[150, -160, -110],
        yticks=[-20, 0, 20],
    )
    ax = plot_setup(fig, ccrs.PlateCarree(central_longitude=-160), **map_options)

    return fig, ax


def make_cb_range(amp, delta):
    """Make symmetric colorbar levels for cmo.balance (zero excluded).

    Returns the levels -k*delta ... -delta, delta ... k*delta, where k is
    the largest integer with k*delta <= amp. Building the levels as integer
    multiples of ``delta`` (rather than with a floating-point ``np.arange``)
    keeps the result symmetric and never exceeds ``amp``.

    Args:
        - 'amp': amplitude of maximum value for colorbar
        - 'delta': increment for colorbar (must be positive)
    Returns:
        - 1-D np.ndarray of levels
    Raises:
        - ValueError if delta <= 0
    """
    if delta <= 0:
        raise ValueError("delta must be positive")

    ## small tolerance so e.g. 0.3 / 0.1 = 2.9999999999999996 counts as 3
    n_steps = int(np.floor(amp / delta + 1e-9))
    positive = delta * np.arange(1, n_steps + 1)

    return np.concatenate([-positive[::-1], positive])


def plot_setup_timeseries():
    """
    Create a 4x3-inch figure whose time axis starts in 1970 and is
    labeled by year at 1979, mid-2000 and end of 2021.

    Returns:
        - fig, ax
    """

    fig, ax = plt.subplots(figsize=(4, 3))

    tick_dates = [
        datetime.date(1979, 1, 1),
        datetime.date(2000, 6, 30),
        datetime.date(2021, 12, 31),
    ]

    ## left limit fixed at 1970; right limit left to matplotlib
    ax.set_xlim([datetime.date(1970, 1, 1), None])
    ax.set_xticks(tick_dates)
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))

    return fig, ax


def plot_seasonal_cycle(mean, std):
    """
    Plot a 12-month seasonal cycle: the monthly mean (thick line) and the
    mean plus/minus one standard deviation (thin lines).

    Args:
        - mean, std: length-12 sequences indexed by month (Jan..Dec)
    Returns:
        - fig, ax
    """

    months = np.arange(1, 13)
    upper = mean + std
    lower = mean - std

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(months, mean, c="k", label=r"$\mu$")
    ax.plot(months, upper, c="k", lw=0.5, label=r"$\mu \pm \sigma$")
    ax.plot(months, lower, c="k", lw=0.5)
    ax.legend()

    return fig, ax


def spatial_avg(data):
    """Area-weighted mean over 'longitude' and 'latitude' for data on a
    regular lon/lat grid, using cos(latitude) as the weight."""

    weights = np.cos(np.deg2rad(data.latitude))
    return data.weighted(weights=weights).mean(["longitude", "latitude"])


def get_trend_coefs(data, dim="time", deg=1):
    """Fit a degree-``deg`` polynomial along ``dim`` and return the
    'polyfit_coefficients' variable from xarray's polyfit result."""
    fit = data.polyfit(dim=dim, deg=deg)
    coefficients = fit["polyfit_coefficients"]
    return coefficients


def get_trend(data, dim="time", deg=1):
    """
    Evaluate the best-fit polynomial of degree ``deg`` (fit along ``dim``)
    at every point of ``data[dim]``; for deg=1 this is the linear trend.
    """
    return xr.polyval(data[dim], get_trend_coefs(data=data, dim=dim, deg=deg))


def detrend(data, dim="time", deg=1):
    """
    Return ``data`` with its degree-``deg`` polynomial trend along ``dim``
    subtracted.
    """
    fitted = get_trend(data, dim=dim, deg=deg)
    residual = data - fitted
    return residual


def get_empirical_pdf(x, bin_edges=None):
    """
    Estimate the "empirical" probability density function for the data x.
    The result is a normalized histogram: summing density * bin width over
    all bins yields 1.

    Args:
        - x: array-like data (np.histogram flattens it)
        - bin_edges: optional array-like of increasing bin edges; if None,
          np.histogram's default binning is used
    Returns:
        - pdf: density value for each bin
        - bin_edges: np.ndarray of bin edges (one longer than pdf)
    """

    ## compute histogram
    if bin_edges is None:
        hist, bin_edges = np.histogram(x)

    else:
        ## accept lists/tuples: the width computation below needs an array
        bin_edges = np.asarray(bin_edges)
        hist, _ = np.histogram(x, bins=bin_edges)

    ## normalize to a probability density (integral == 1)
    bin_width = np.diff(bin_edges)
    pdf = hist / (hist * bin_width).sum()

    return pdf, bin_edges


def get_gaussian_best_fit(x):
    """Fit a normal distribution to ``x`` (sample mean and ddof=0 std, NaNs
    ignored) and evaluate its PDF on 50 points spanning [-amp, amp], where
    amp is the largest absolute (non-NaN) value in ``x``.

    Args:
        - x: array-like data (np.ndarray or xr.DataArray)
    Returns:
        - pdf_eval: Gaussian PDF evaluated at x_eval
        - x_eval: evaluation points
    """

    ## plain float array so both ndarrays and DataArrays are accepted
    values = np.asarray(x, dtype=float)

    ## get normal distribution best fit (NaN-skipping, like xarray's defaults)
    gaussian = scipy.stats.norm(loc=np.nanmean(values), scale=np.nanstd(values))

    ## evaluate over a range symmetric about zero covering all data
    amp = np.nanmax(np.abs(values))
    x_eval = np.linspace(-amp, amp)
    pdf_eval = gaussian.pdf(x_eval)

    return pdf_eval, x_eval


def swap_longitude_range(data):
    """Swap longitude range of an xr.DataArray from [0, 360) to (-180, 180].

    Every longitude from the first one exceeding 180 onward is shifted by
    -360, and the data is rolled so those points come first. This assumes
    ``data.longitude`` is sorted in increasing order (TODO confirm for new
    inputs). The input object is not modified. If no longitude exceeds 180,
    ``data`` is returned unchanged.
    """

    ## copy of longitude coordinate to be modified
    new_longitude = np.array(data.longitude.values, copy=True)

    ## nothing to swap: np.argmax would return 0 here and shift every point
    exceeds_180 = new_longitude > 180
    if not exceeds_180.any():
        return data

    ## index where longitude first exceeds 180
    ## (np.argmax returns the first "True" in a boolean array)
    swap_idx = int(np.argmax(exceeds_180))

    ## relabel values >180
    new_longitude[swap_idx:] = -360 + new_longitude[swap_idx:]

    ## attach the new coordinate on a new object (leaves caller's data intact)
    data = data.assign_coords(longitude=new_longitude)

    ## "roll" the data to be centered at zero
    return data.roll({"longitude": -swap_idx}, roll_coords=True)


def get_autocorr_helper(x, lag, month=None):
    """Correlation between ``x`` at time t and ``x`` at time t + lag.

    Args:
        - x: anomaly DataArray with a 'time' dimension (get_corr_coef also
          averages over 'member', so x needs that dimension too)
        - lag: integer lag in time steps (may be negative); lag 0 returns 1.0
        - month: optional month (1-12); only start times t in that month
          are used
    Returns:
        - float correlation coefficient. Uses get_corr_coef, which assumes
          centered data; the month subset is not re-centered.
    """

    if lag == 0:
        return 1.0

    ## pick the "base" (time t) and "shifted" (time t + lag) windows
    if lag > 0:
        base_window, shifted_window = slice(None, -lag), slice(lag, None)
    else:
        base_window, shifted_window = slice(-lag, None), slice(None, lag)

    shifted = x.isel(time=shifted_window)
    base = x.isel(time=base_window)

    ## align the shifted series on the base series' timestamps
    shifted["time"] = base.time

    ## keep only start times falling in the requested month
    if month is not None:
        in_month = base.time.dt.month == month
        base = base.isel(time=in_month)
        shifted = shifted.isel(time=in_month)

    return get_corr_coef(base, shifted).item()


def get_autocorr(x, lags, month=None):
    """Autocorrelation of ``x`` at each lag in ``lags``.

    Args:
        - x: anomaly DataArray (see get_autocorr_helper)
        - lags: sequence of integer lags, in time steps
        - month: optional start month (1-12)
    Returns:
        - xr.DataArray of correlation coefficients with a 'lag' coordinate
    """
    correlations = [get_autocorr_helper(x, lag=k, month=month) for k in lags]
    return xr.DataArray(data=correlations, coords={"lag": lags})


def get_autocorr_by_month(x, lags):
    """Autocorrelation of ``x`` for every start month (1-12), stacked along
    a new 'month' dimension (result has dims month x lag)."""

    months = np.arange(1, 13)
    per_month = [get_autocorr(x, lags, month=m) for m in months]
    month_index = pd.Index(months, name="month")

    return xr.concat(per_month, dim=month_index)


def load_simulation(varname, member_id, simulation_type, preprocess_func=None):
    """
    Load dataset for single simulation, for single variable.
    Arguments:
        - varname: name of variable to load, one of {"SST","PSL"}
        - member_id: ID of ensemble member to load, an integer in the range [1,10]
        - simulation_type: one of {"hist", "rcp85"}
        - preprocess func: optional preprocessing function to apply to the simulation
    Returns:
        - xarray dataarray with given data
    Raises:
        - ValueError: if simulation_type is not recognized
        - FileNotFoundError: if no files match the expected naming pattern
    Note:
        - files are looked up under the module-level SERVER_FP
    """

    ## tag that appears in the file names for each supported scenario
    scenario_tags = {"hist": "20TRC", "rcp85": "RCP85"}

    ## validate before touching the file system
    if simulation_type not in scenario_tags:
        raise ValueError(
            f"Not a valid simulation type: {simulation_type!r} "
            f"(expected one of {sorted(scenario_tags)})"
        )

    ## Filepath to the CESM LENS dataset
    lens_fp = pathlib.Path("cmip6/data/cmip6/CMIP/NCAR/LENS")

    #### 1. get filepath to data
    data_fp = SERVER_FP / lens_fp / pathlib.Path(varname)

    #### 2. get naming pattern for files to open
    file_pattern = f"*{scenario_tags[simulation_type]}*.{member_id:03d}.*.nc"

    ## sorted for a deterministic file order
    paths = sorted(data_fp.glob(file_pattern))
    if not paths:
        raise FileNotFoundError(f"No files matching {file_pattern!r} in {data_fp}")

    #### 3. open the relevant datasets, applying preprocessing function
    data = xr.open_mfdataset(
        paths=paths,
        preprocess=preprocess_func,
        chunks={"time": 60},
    )

    return data[varname].squeeze(drop=True)


def trim(data, lon_range=(100, 300), lat_range=(-30, 30)):
    """Select the part of ``data`` within the given lon/lat range.

    For data on the POP "T"-grid (has a 'TLONG' coordinate), keep every
    nlon column / nlat row that contains at least one cell whose
    TLONG/TLAT lie inside the box, so the result is a rectangular
    (nlat, nlon) block that may include some cells outside the box.
    Otherwise select by the 'lon'/'lat' coordinates with label slices.

    Ranges are inclusive [min, max]; ranges crossing the 0/360 seam are
    not supported.
    """

    ## non-T-grid data: plain coordinate-based selection
    if "TLONG" not in data.coords:
        return data.sel(lon=slice(*lon_range), lat=slice(*lat_range))

    def _in_range(values, bounds):
        """Boolean mask: bounds[0] <= values <= bounds[1]."""
        return (bounds[0] <= values) & (values <= bounds[1])

    ## mask of T-grid cells inside the box
    in_box = _in_range(data["TLONG"], lon_range) & _in_range(data["TLAT"], lat_range)

    ## load to memory (the mask is used for integer indexing below)
    in_box.load()

    ## Retain all points with at least one valid grid cell
    keep_nlon = in_box.any("nlat")
    keep_nlat = in_box.any("nlon")

    return data.isel(nlon=keep_nlon, nlat=keep_nlat)


def load_ensemble_helper(varname, simulation_type, preprocess_func=None):
    """
    Load ensemble members 1-10 for the given simulation type and variable
    (via load_simulation) and stack them.
    Arguments:
        - varname: name of variable to load, one of {"SST","PSL"}
        - simulation_type: one of {"hist", "rcp85"}
        - preprocess func: optional preprocessing function to apply to the simulation
    Returns:
        - xarray dataarray with given data and a 'member' dimension (1-10)
    """

    member_ids = np.arange(1, 11)

    members = []
    for member in member_ids:
        members.append(
            load_simulation(
                member_id=member,
                varname=varname,
                simulation_type=simulation_type,
                preprocess_func=preprocess_func,
            )
        )

    return xr.concat(members, dim=pd.Index(member_ids, name="member"))


def load_ensemble(varname, simulation_type, preprocess_func=None, save_fp=None):
    """
    Load all ensemble members for given simulation type and variable,
    using a netCDF cache when ``save_fp`` is given.

    If ``save_fp`` is set, the cache file is
    ``save_fp / f"{varname}_{simulation_type}.nc"``. It is opened directly if
    it exists; otherwise the data are loaded and written there.
    Arguments:
        - varname: name of variable to load, one of {"SST","PSL"}
        - simulation_type: one of {"hist", "rcp85"}
        - preprocess func: optional preprocessing function to apply to the simulation
        - save_fp: pathlib.Path of a directory for the cache (None: no caching)
    Returns:
        - xarray dataarray with given data and 'member' dimension
    """

    def _load_from_source():
        return load_ensemble_helper(
            varname=varname,
            simulation_type=simulation_type,
            preprocess_func=preprocess_func,
        )

    ## no caching requested
    if save_fp is None:
        return _load_from_source()

    cache_file = save_fp / f"{varname}_{simulation_type}.nc"

    ## cache hit
    if cache_file.is_file():
        return xr.open_dataarray(cache_file)

    ## cache miss: load and write for next time
    data = _load_from_source()
    print("saving to file")
    data.to_netcdf(cache_file)
    return data


def preprocess(data):
    """
    Preprocessing applied to each opened file:
        1. drop data before Feb 1920
        2. trim to the default lon/lat box (see trim)
        3. replace the (cftime) time axis with a pandas month-start
           date range starting at the first remaining month
    """

    trimmed = trim(data.sel(time=slice("1920-02", None)))

    first = trimmed.time.isel(time=0)
    year = first.dt.year.item()
    month = first.dt.month.item()

    trimmed["time"] = pd.date_range(
        start=f"{year}-{month}-01",
        periods=len(trimmed.time),
        freq="MS",
    )

    return trimmed

## Set up Dask dashboard

This cell is commented out because some users reported "HDF: error"-like messages that may be related to it.

In [None]:
# Uncomment the lines below to start a local Dask cluster and show its info.
# ## set up local cluster on dask
# client = dask.distributed.Client()

# ## display information about cluster (including address of dashboard)
# client

## Load data

### Set filepaths

````{admonition} To-do
Update the filepaths ```SERVER_FP``` and ```save_fp``` in the code cell below.
````

In [None]:
## Path to file server
## (load_simulation looks for the LENS files under this root)
SERVER_FP = pathlib.Path("/Volumes")

## Specify folder location for saving trimmed data ("./" means current directory)
## (load_ensemble caches "<varname>_<simulation_type>.nc" files here)
save_fp = pathlib.Path("./")

### Do the loading

In [None]:
## Load data
## SST for all ensemble members, historical and RCP 8.5 runs
## (uses cached netCDF files in save_fp when they exist)
data_hist = load_ensemble("SST", "hist", preprocess_func=preprocess, save_fp=save_fp)
data_rcp = load_ensemble("SST", "rcp85", preprocess_func=preprocess, save_fp=save_fp)

## Optional: load into memory (warning: may be slow!)
## (the trailing ";" only suppresses the notebook's cell output)
data_hist.load()
data_rcp.load();

## Plot a sample

In [None]:
## blank canvas
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
## (note: plot_setup_pacific resizes the figure to 5x3 inches)
fig, ax = plot_setup_pacific(fig)

## plot the data: SST of the first ensemble member at the first time step
plot_data = ax.pcolormesh(
    data_hist.TLONG,
    data_hist.TLAT,
    data_hist.isel(member=0, time=0),
    cmap="cmo.thermal",
    vmax=30,
    vmin=15,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(plot_data, ticks=[15, 20, 25, 30], fraction=0.015, pad=0.05)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5])

plt.show()

## Statistics for "historical" scenario

### Spatial mean and standard deviation

#### Compute spatial mean/variance

In [None]:
## compute climatology statistics
## per calendar month, pooling all years and ensemble members;
## skipna=False: a grid cell with any NaN stays NaN in the result
spatial_mean = data_hist.groupby("time.month").mean(["time", "member"], skipna=False)
spatial_std = data_hist.groupby("time.month").std(["time", "member"], skipna=False)

#### Plot mean for December

In [None]:
## blank canvas
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
fig, ax = plot_setup_pacific(fig)

## plot the data
## (month=12 selects the December climatology)
plot_data = ax.pcolormesh(
    spatial_mean.TLONG,
    spatial_mean.TLAT,
    spatial_mean.sel(month=12),
    cmap="cmo.thermal",
    vmax=30,
    vmin=15,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data, ticks=[15, 20, 25, 30], label=r"$^{\circ}C$", fraction=0.015, pad=0.05
)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5])

plt.show()

#### Plot standard dev. for December

In [None]:
## blank canvas
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
fig, ax = plot_setup_pacific(fig)

## plot the data
## (month=12 selects the December standard deviation)
plot_data = ax.pcolormesh(
    spatial_std.TLONG,
    spatial_std.TLAT,
    spatial_std.sel(month=12),
    cmap="cmo.amp",
    vmax=2,
    vmin=0,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data, ticks=[0, 1, 2], label=r"$^{\circ}C$", fraction=0.015, pad=0.05
)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5])

plt.show()

### Niño 3.4 index

Function to compute Niño 3.4 index 

In [None]:
def get_nino34_idx(data):
    """Niño 3.4 index: unweighted mean of T-grid data over the box
    lon 190-240, lat -5-5 (averaging over 'nlon' and 'nlat')."""
    nino34_box = trim(data, lon_range=[190, 240], lat_range=[-5, 5])
    return nino34_box.mean(["nlon", "nlat"])

Compute the index and do some pre-processing

In [None]:
## compute nino 3.4 climatology
## (raw box-mean SST; mean/std per calendar month, pooled over years and members)
nino34_raw = get_nino34_idx(data_hist)
nino34_mean = nino34_raw.groupby("time.month").mean(["time", "member"])
nino34_std = nino34_raw.groupby("time.month").std(["time", "member"])

## Compute anomalies (spatial data and nino 3.4)
## (subtract the per-calendar-month climatology spatial_mean)
data_anom = data_hist.groupby("time.month") - spatial_mean
nino34_anom = get_nino34_idx(data_anom)

## compute detrended anomalies (spatial data and Niño 3.4)
## the trend is fit to the ensemble mean, so the same trend is
## subtracted from every member
data_trend = get_trend(data_anom.mean("member"))
data_detrend = data_anom - data_trend
nino34_detrend = get_nino34_idx(data_detrend)

#### plot seasonal cycle

In [None]:
## make the plot
fig, ax = plot_seasonal_cycle(nino34_mean, nino34_std)

## add some labels
ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
ax.set_title(r"Niño 3.4 climatology")
## SST here is in degrees Celsius (same units as the maps above)
ax.set_ylabel(r"$^{\circ}C$")

plt.show()

#### plot Niño 3.4 over time

In [None]:
## Niño 3.4 anomaly over the historical run: members in gray, ensemble mean in red
fig, ax = plt.subplots(figsize=(7, 3))

## plot individual ensemble members
for m in nino34_anom.member:
    ax.plot(data_hist.time, nino34_anom.sel(member=m), alpha=0.5, c="gray", lw=0.5)

## plot ensemble mean
ax.plot(
    data_hist.time,
    nino34_anom.mean("member"),
    c="r",
    lw=2,
    zorder=2,
    label="Ensemble mean",
)

## label
ax.set_ylabel(r"Niño 3.4 ($^{\circ}C$)")
ax.legend()
ax.axhline(0, ls="--", c="k", lw=1)
ax.set_xlim([datetime.datetime(1920, 1, 1), datetime.datetime(2006, 12, 31)])

plt.show()

#### Estimate trend for historical simulation

Function to estimate trend in units of [1/century]

In [None]:
def get_trend_per_100yrs(x):
    """Linear trend of ``x`` along 'time', in units of [x] per 100 years.

    The change of the fitted linear trend between the first and last time
    step is divided by the elapsed number of months. This assumes ``x`` is
    sampled monthly (as in this notebook). All other dimensions
    (e.g. 'member', 'nlat', 'nlon') are preserved.

    Raises:
        - ValueError: if ``x`` has fewer than two time steps
    """

    n_months = len(x.time)
    if n_months < 2:
        raise ValueError("need at least two time steps to estimate a trend")

    ## Get timeseries of trend
    x_trend_timeseries = get_trend(x)

    ## total change of the trend line over the period; select by dimension
    ## name because positional [-1]/[0] would index whichever dimension
    ## xr.polyval happens to put first
    dx = x_trend_timeseries.isel(time=-1, drop=True) - x_trend_timeseries.isel(
        time=0, drop=True
    )

    ## first and last samples are n_months - 1 monthly steps apart
    dx_dt = dx / (n_months - 1)

    ## convert from 1/month to 1/(100 yrs)
    months_per_100_yrs = 100 * 12
    return dx_dt * months_per_100_yrs

Compute trend in Niño 3.4 index

In [None]:
## compute nino34_trend
## (one trend value per ensemble member, in degrees C per 100 years)
nino34_trend = get_trend_per_100yrs(nino34_anom)

Plot result

In [None]:
## one dot per ensemble member (all at x=1), dashed line = ensemble mean
fig, ax = plt.subplots(figsize=(1, 3))
ax.scatter(np.ones_like(nino34_trend), nino34_trend, c="k", s=10)
ax.axhline(nino34_trend.mean(), c="k", ls="--", label="Mean")
ax.axhline(0, c="gray", alpha=0.5, lw=1)
ax.set_yticks(
    [
        0,
        0.3,
        0.6,
    ]
)
ax.set_xticks([])
ax.set_ylim([-0.1, 0.8])
ax.legend(prop={"size": 8})
ax.set_title("Niño 3.4 trend by ensemble member")
ax.set_ylabel(r"$^{\circ}C$ / 100 yrs")
plt.show()

Compute spatial trend (trend at each gridcell)

In [None]:
## compute trend
## (trend at every grid cell for each member, then averaged over members)
spatial_trend = get_trend_per_100yrs(data_anom).mean("member")

Plot spatial trend

In [None]:
## plot
## map of the ensemble-mean historical SST trend
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
fig, ax = plot_setup_pacific(fig)

## plot the data
plot_data = ax.pcolormesh(
    data_hist.TLONG,
    data_hist.TLAT,
    spatial_trend,
    cmap="cmo.balance",
    vmax=1,
    vmin=-1,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data,
    ticks=[-1, 0, 1],
    label=r"$^{\circ}C$ / 100 yrs",
    fraction=0.015,
    pad=0.05,
)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5])

## label
ax.set_title("Trend over historical simulation")

plt.show()

### Spatial pattern

#### Regression

Functions to compute regression and correlation coefficients

In [None]:
def get_regression_coef(Y, X):
    """
    Least-squares slope M in Y = M X, computed from second moments over
    the 'member' and 'time' dimensions (i.e. <XY> / <XX>).

    Assumes both inputs are already centered (zero mean over
    ["time", "member"]); no mean is removed here.
    """
    sample_dims = ["member", "time"]
    covariance = (X * Y).mean(sample_dims)
    variance = (X * X).mean(sample_dims)
    return covariance / variance


def get_corr_coef(Y, X):
    """
    Correlation coefficient between X and Y, using second moments over the
    'member' and 'time' dimensions: <XY> / sqrt(<XX> <YY>).

    Assumes both inputs are already centered (zero mean over
    ["time", "member"]); no mean is removed here.
    """
    sample_dims = ["member", "time"]

    def _moment(a, b):
        return (a * b).mean(sample_dims)

    return _moment(X, Y) / np.sqrt(_moment(X, X) * _moment(Y, Y))

Compute the coefficients

In [None]:
## compute linear regression coefficient and correlation coefficient
regression_coef = get_regression_coef(data_detrend, nino34_detrend)
corr = get_corr_coef(data_detrend, nino34_detrend)

Plot the result

In [None]:
## plot regression coefficient
## (degrees C of local SST per degree C of Niño 3.4)
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
fig, ax = plot_setup_pacific(fig)

## plot the data
plot_data = ax.pcolormesh(
    data_hist.TLONG,
    data_hist.TLAT,
    regression_coef,
    cmap="cmo.balance",
    vmax=1.5,
    vmin=-1.5,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data,
    ticks=[-1.5, 0, 1.5],
    label=r"$^{\circ}C$ / Niño$_{3.4}$",
    fraction=0.015,
    pad=0.05,
)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5], c="w")

## label
ax.set_title(r"ENSO spatial pattern (historical)")

plt.show()

#### Composite
First, a function to compute composites.

In [None]:
def composite(data, mask):
    """
    Average ``data`` over all (member, time) samples selected by ``mask``.

    Args:
        - data: dataarray to use for the composite
        - mask: boolean dataarray with dimensions ["member","time"]
    Returns:
        - composite: NaN-skipping mean of the masked data over member/time
        - n_samples: number of True entries in the mask
    """
    masked = data.where(mask)
    return masked.mean(["member", "time"], skipna=True), mask.sum()


## compute composites

Next, compute the composites

In [None]:
## get composite for warm and cold events
## (samples where detrended Niño 3.4 is above +1.5 / below -1.5 degrees C)
comp_warm, n_warm = composite(data_detrend, mask=nino34_detrend > 1.5)
comp_cold, n_cold = composite(data_detrend, mask=nino34_detrend < -1.5)

Finally, plot the composites

In [None]:
## spatial pattern of composites
for comp, count, label in zip(
    [comp_warm, comp_cold], [n_warm, n_cold], ["warm", "cold"]
):

    ## new figure for this composite
    fig = plt.figure(figsize=(10, 3))

    ## make background of trop. Pacific
    fig, ax = plot_setup_pacific(fig)

    ## plot the data
    plot_data = ax.pcolormesh(
        data_hist.TLONG,
        data_hist.TLAT,
        comp,
        cmap="cmo.balance",
        vmax=3,
        vmin=-3,
        transform=ccrs.PlateCarree(),
    )

    ## make a colorbar
    cb = fig.colorbar(
        plot_data,
        ticks=[-3, 0, 3],
        label=r"$^{\circ}C$",
        fraction=0.015,
        pad=0.05,
    )

    ## plot outline of Niño 3.4 region
    ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5], c="w")

    ## label (count is the number of member/time samples in the composite)
    ax.set_title(f"ENSO composite ({label}, n = {count.values.item()})")

    plt.show()


## Plot the asymmetry: warm + cold composite (zero if perfectly symmetric)
fig = plt.figure(figsize=(10, 3))

## make background of trop. Pacific
fig, ax = plot_setup_pacific(fig)

## plot the data
plot_data = ax.pcolormesh(
    data_hist.TLONG,
    data_hist.TLAT,
    comp_warm + comp_cold,
    cmap="cmo.balance",
    vmax=1.5,
    vmin=-1.5,
    transform=ccrs.PlateCarree(),
)

## make a colorbar
cb = fig.colorbar(
    plot_data,
    ticks=[-1.5, 0, 1.5],
    label=r"$^{\circ}C$",
    fraction=0.015,
    pad=0.05,
)

## plot outline of Niño 3.4 region
ax = plot_box_outline(ax, lon_range=[190, 240], lat_range=[-5, 5], c="w")

## label
ax.set_title("Composite asymmetry (warm plus cold)")

plt.show()

## Autocorrelation

Compute autocorrelation by month

In [None]:
## compute autocorrelation by month
## (detrended Niño 3.4, lags of -24 to +24 months, for each start month)
autocorr_by_month = get_autocorr_by_month(nino34_detrend, lags=np.arange(-24, 25))

Plot the result

In [None]:
## plot
## heatmap of autocorrelation: start month (rows) vs. non-negative lag (columns)
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_aspect("equal")

## plot data
lags = np.arange(0, 19)
months = np.arange(1, 13)
plot_data = ax.pcolormesh(
    lags, months, autocorr_by_month.sel(lag=lags), cmap="cmo.balance", vmin=-1, vmax=1
)

## colorbar
cb = fig.colorbar(plot_data, label="corr. coef.", ticks=[-1, 0, 1])

## label
ax.set_ylabel("Start month")
ax.set_xlabel("Lag (months)")
ax.set_xticks([0, 6, 12, 18])
ax.set_yticks([2, 7, 12], labels=["Feb", "Jul", "Dec"])
ax.set_title("Niño 3.4 autocorrelation")

## swap direction of y-axis (month 1 at the top)
ax.set_ylim(ax.get_ylim()[::-1])

plt.show()

## Future Projections

### Niño 3.4 Timeseries

Compute Niño 3.4 in RCP 8.5 simulation and concatenate to Niño 3.4 from historical simulation.

In [None]:
## compute Niño 3.4 for RCP 8.5
nino34_rcp_raw = get_nino34_idx(data_rcp)

## concatenate with historical timeseries
nino34_raw_long = xr.concat([nino34_raw, nino34_rcp_raw], dim="time")

## remove seasonal cycle
## (a single per-calendar-month climatology, computed over the full
## concatenated record and all members)
deseason = lambda x: x.groupby("time.month") - x.groupby("time.month").mean(
    ["time", "member"]
)
nino34_anom_long = deseason(nino34_raw_long)

Plot concatenated timeseries

In [None]:
## plot result
## (historical + RCP 8.5 Niño 3.4 anomalies: members in gray, ensemble mean in red)
fig, ax = plt.subplots(figsize=(7, 3))

## plot individual ensemble members
for m in nino34_anom_long.member:
    ax.plot(
        nino34_anom_long.time,
        nino34_anom_long.sel(member=m),
        alpha=0.5,
        c="gray",
        lw=0.5,
    )

## plot ensemble mean
ax.plot(
    nino34_anom_long.time,
    nino34_anom_long.mean("member"),
    c="r",
    lw=2,
    zorder=2,
    label="Ensemble mean",
)

## label
ax.set_ylabel(r"Niño 3.4 ($^{\circ}C$)")
ax.axhline(0, ls="--", c="k", lw=1)
## vertical line at 2006-01-01, presumably where the RCP 8.5 segment begins (confirm)
ax.axvline(datetime.datetime(2006, 1, 1), lw=0.5, c="k")
ax.legend()

plt.show()

### Compare variance between 1920-1960 and 2060-2100

Compute PDFs over early and late period

In [None]:
## Get Niño 3.4 over early and late periods
nino34_anom_early = nino34_anom_long.sel(time=slice("1920", "1960"))
nino34_anom_late = nino34_anom_long.sel(time=slice("2060", "2100"))

## compute PDFs (normalized histograms)
## (default np.histogram binning, so early and late bins differ)
pdf_early, edges_early = get_empirical_pdf(nino34_anom_early)
pdf_late, edges_late = get_empirical_pdf(nino34_anom_late)

## compute PDFs on detrended data
# helper function to detrend
# (removes the linear trend of the ensemble mean from every member)
detrend_fn = lambda x: x - get_trend(x.mean("member"))

# specify bin edges for the histograms
# edges run from -4.875 to 4.875 in steps of 0.75 (13 bins, middle bin centered on 0)
bin_edges = np.arange(-4 - 0.875, 4 + 1.25, 0.75)

# compute the PDFs
pdf_early_, edges_early_ = get_empirical_pdf(
    detrend_fn(nino34_anom_early), bin_edges=bin_edges
)
pdf_late_, edges_late_ = get_empirical_pdf(
    detrend_fn(nino34_anom_late), bin_edges=bin_edges
)

Plot result

In [None]:
#### Plot result
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

## plot histogram
axs[0].stairs(values=pdf_early, edges=edges_early, label="1920-1960")
axs[0].stairs(values=pdf_late, edges=edges_late, label="2060-2100")

## label (get_empirical_pdf returns a density: it integrates to 1)
axs[0].set_xlabel(r"$^{\circ}C$ anomaly")
axs[0].set_ylabel("Probability density")

axs[0].legend()
axs[0].set_title("Niño 3.4 PDFs")


#### next, plot centered histograms
axs[1].stairs(values=pdf_early_, edges=edges_early_, label="1920-1960")
axs[1].stairs(values=pdf_late_, edges=edges_late_, label="2060-2100")

## make sure y-axis is the same across plots and remove y-ticks from the RHS panel
axs[0].set_ylim(axs[1].get_ylim())
axs[1].set_yticks([])

## label
axs[1].set_xlabel(r"$^{\circ}C$ anomaly")
axs[1].set_title("Niño 3.4 PDFs (detrended data)")

plt.show()

### Change in autocorrelation

Estimate autocorrelation in early & late periods, and compute the difference

In [None]:
## specify lags for autocorrelation
lags = np.arange(1, 19)

## compute autocorrelations and difference
## (each period is detrended separately with detrend_fn before correlating)
autocorr_early = get_autocorr_by_month(detrend_fn(nino34_anom_early), lags=lags)
autocorr_late = get_autocorr_by_month(detrend_fn(nino34_anom_late), lags=lags)
autocorr_diff = autocorr_late - autocorr_early

Plot the result

In [None]:
## plot autocorrelation for early time series
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_aspect("equal")

## plot data
plot_data = ax.pcolormesh(
    autocorr_early.lag,
    autocorr_early.month,
    autocorr_early,
    cmap="cmo.balance",
    vmin=-1,
    vmax=1,
)

## colorbar
cb = fig.colorbar(plot_data, label="corr. coef.", ticks=[-1, 0, 1])

## label
ax.set_ylabel("Start month")
ax.set_xlabel("Lag (months)")
ax.set_xticks([0, 6, 12, 18])
ax.set_yticks([2, 7, 12], labels=["Feb", "Jul", "Dec"])
## title matches the early period selected for nino34_anom_early
ax.set_title("Niño 3.4 autocorrelation (1920-1960)")

## swap direction of y-axis (month 1 at the top)
ax.set_ylim(ax.get_ylim()[::-1])

plt.show()

## plot difference (late minus early)
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_aspect("equal")

## plot data
plot_data = ax.pcolormesh(
    autocorr_diff.lag,
    autocorr_diff.month,
    autocorr_diff,
    cmap="cmo.balance",
    vmin=-0.3,
    vmax=0.3,
)

## colorbar (ticks must lie within [vmin, vmax] to be displayed)
cb = fig.colorbar(plot_data, label="corr. coef.", ticks=[-0.3, 0, 0.3])

## label
ax.set_ylabel("Start month")
ax.set_xlabel("Lag (months)")
ax.set_xticks([0, 6, 12, 18])
ax.set_yticks([2, 7, 12], labels=["Feb", "Jul", "Dec"])
ax.set_title("Change in autocorr. (2060-2100 minus 1920-1960)")

## swap direction of y-axis (month 1 at the top)
ax.set_ylim(ax.get_ylim()[::-1])

plt.show()