# RO reference
Fit the recharge oscillator (RO) model to observations (ORAS5 reanalysis) to get "ground truth" behavior.

## Imports

In [None]:
import warnings
import copy
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats
import xeofs as xe
import src.lim

# Import custom modules
import src.XRO
import src.XRO_utils
import src.utils

## plotting defaults: white axes background, no grid, 100 dpi figures
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})
mpl.rcParams.update({"figure.dpi": 100})

## input-data and output directories are taken from environment variables
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## random number generator (unseeded, so draws differ between runs)
rng = np.random.default_rng()

## Load data

In [None]:
## location of the ORAS5 files
oras_fp = DATA_FP / "oras5"

## load the "tos" and "d20" fields
data = src.utils.load_oras_spatial_extended(oras_fp, varnames=["tos", "d20"])

## RO indices, using d20 as the thermocline ("h") variable
idx_total = src.utils.get_RO_indices(data, h_var="d20")

## additional d20 averages over latitudes -15..15:
##   "h_"  -> longitudes 120..280 (the "wide" average)
##   "h__" -> longitudes 120..210
for h_name, lon_range in {"h_": slice(120, 280), "h__": slice(120, 210)}.items():
    idx_total[h_name] = src.utils.spatial_avg(
        data["d20"].sel(longitude=lon_range, latitude=slice(-15, 15))
    )

## Detrend / pre-process

In [None]:
## "forced" signal: the polynomial-in-time fit (deg=3) that detrend_dim removes,
## computed separately for each calendar month; idx holds the residual anomalies
def detrend_fn(x):
    return src.utils.detrend_dim(x, dim="time", deg=3)


idx = idx_total.groupby("time.month").map(detrend_fn)
idx_forced = idx_total - idx

## scale every variable by its own standard deviation
idx /= idx.std()

## Cross correlation stats

Plotting function

In [None]:
def format_xcorr_ax(ax):
    """Apply the shared styling for cross-correlation plots.

    Draws faint zero lines on both axes, fixes the correlation range to
    [-0.9, 1.1], and labels the lag axis in years (ticks every 12 lags, so
    lags are assumed to be in months).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to format in place.
    """
    zero_line_style = {"c": "k", "lw": 0.5, "alpha": 0.5}
    for draw_zero_line in (ax.axhline, ax.axvline):
        draw_zero_line(0, **zero_line_style)

    ax.set_ylim([-0.9, 1.1])
    ax.set_xlabel("Lag (years)")
    month_ticks = [-24, -12, 0, 12, 24]
    ax.set_xticks(month_ticks, labels=[t // 12 for t in month_ticks])
    ax.set_yticks([-0.5, 0, 0.5, 1])
    ax.set_ylabel("Correlation")

Make the plot

In [None]:
## cross-correlation of every index with Niño 3 (T_3), up to 36 lags
xcorr = src.XRO.xcorr(idx, idx["T_3"], maxlags=36)

fig, ax = plt.subplots(figsize=(5, 3.5))

## (variable, legend label, line style) for each curve, in drawing order
curves = [
    ("T_3", r"$T_3$", dict(c="k")),
    ("T_34", r"$T_{3.4}$", dict(c="k", ls="--")),
    ("h", r"$h$", dict()),
    ("h_w", r"$h_w$", dict()),
    ("h_", r"$h_{wide}$", dict()),
]
for name, label, style in curves:
    ax.plot(xcorr.lag, xcorr[name], label=label, **style)

## title, legend and shared axis formatting
ax.set_title("Corr. with Niño 3")
ax.legend()
format_xcorr_ax(ax)

plt.show()

## Fit RO models

In [None]:
## annual-cycle order and which parameters get masked in the fit
ac_order = 3
ac_mask_idx = [(1, 1)]  # epsilon
## other masks that have been tried:
# ac_mask_idx = [(1,0), (1, 1)] # epsilon and F2
# ac_mask_idx = [(0,1),(1,0),(1, 1)] # all except R
# ac_mask_idx = None

model = src.XRO.XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## fit T_34 together with either h or the wide average h_
fit_h, fit_hw = [
    model.fit_matrix(idx[["T_34", h_name]], ac_mask_idx=ac_mask_idx)
    for h_name in ("h", "h_")
]

## RO parameters for each fit
p_h, p_hw = [model.get_RO_parameters(fit) for fit in (fit_h, fit_hw)]

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4.5, 2), layout="constrained")

## annual cycle of R (left) and BJ_ac (right) for the fits using h and h_
for ax, n in zip(axs, ["R", "BJ_ac"]):
    ax.plot(p_h.cycle, p_h[n], label=r"$h$")
    ## p_hw comes from the fit with "h_" (the wide average), not "h_w"
    ax.plot(p_hw.cycle, p_hw[n], label=r"$h_{wide}$")

    ## format
    ax.axhline(0, ls="-", c="k", lw=0.5)
    ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])

## plot epsilon values as well (dashed, same color as the corresponding fit)
kwargs = dict(lw=1, ls="--")
axs[0].plot(p_h.cycle, p_h["epsilon"], c=sns.color_palette()[0], **kwargs)
axs[0].plot(p_hw.cycle, p_hw["epsilon"], c=sns.color_palette()[1], **kwargs)

## label
axs[1].set_yticks([])
axs[0].set_yticks([-4, -2, 0, 2])
axs[0].set_title(r"$R$")
axs[1].set_title(r"$2\left(R-\varepsilon\right)$")
axs[0].legend(prop=dict(size=8))
axs[0].set_ylim([-6, 3])
## right panel uses half the left panel's range
axs[1].set_ylim(0.5 * np.array(axs[0].get_ylim()))

plt.show()

## Check statistics of the simulated ensembles

Generate simulations

In [None]:
## random initial condition drawn from the ORAS record
t0 = rng.choice(np.arange(len(idx.time)))
x0 = idx.isel(time=t0)

## settings shared by both ensembles
simulation_kwargs = {"nyear": 63, "ncopy": 1000, "is_xi_stdac": False}

## ensemble from the fit using h
kwargs_h = {**simulation_kwargs, "fit_ds": fit_h, "X0_ds": x0[["T_34", "h"]]}
X_h = model.simulate(**kwargs_h)

## ensemble from the fit using the wide average h_
kwargs_hw = {**simulation_kwargs, "fit_ds": fit_hw, "X0_ds": x0[["T_34", "h_"]]}
X_hw = model.simulate(**kwargs_hw)

In [None]:
### Set up plot
fig, axs = plt.subplots(1, 2, figsize=(4.5, 2), layout="constrained")

## compare seasonal T_34 statistics of ORAS against the RO ensembles:
## left = RO fit with h, right = RO fit with the wide average h_
## (ORAS gets a singleton "member" dim so it matches the ensembles)
for ax, X_sim, title in zip(axs, [X_h, X_hw], [r"$h$", r"$h_{wide}$"]):
    plot_data_early = src.utils.plot_seasonal_comp(
        ax,
        x0=idx.expand_dims("member"),
        x1=X_sim,
        plot_kwargs0=dict(label="ORAS"),
        plot_kwargs1=dict(label="RO"),
        varname="T_34",
        use_quantile=True,
    )
    ax.set_title(title)

axs[1].legend(prop=dict(size=8))

for ax in axs:
    ax.set_ylim([0, 2])

plt.show()

## Compute $R$ explicitly

In [None]:
## time tendency of every index (rows presumably follow the variable order of idx — confirm)
grads = src.XRO.gradient(src.XRO._convert_to_numpy(idx))

## store each tendency in idx (in place) as "<name>_grad"
idx.update(
    {
        f"{name}_grad": xr.DataArray(tendency, coords=dict(time=idx.time))
        for name, tendency in zip(list(idx), grads)
    }
)


def get_stats(T_var, h_var, month, data=None):
    """Get the normalized statistics needed to estimate R for one calendar month.

    Parameters
    ----------
    T_var, h_var : str
        Names of the T and h variables. The dataset must also contain
        ``f"{T_var}_grad"`` and ``f"{h_var}_grad"``.
    month : int
        Calendar month (1-12) to select.
    data : xr.Dataset, optional
        Dataset to draw from; defaults to the module-level ``idx``.

    Returns
    -------
    dict
        ``r``   : Pearson correlation between T and h in that month,
        ``TtT`` : mean(T_grad * T) / sigma_T**2,
        ``Tth`` : mean(T_grad * h) / (sigma_T * sigma_h).
    """
    if data is None:
        data = idx

    ## indexer selecting the requested calendar month
    t_idx = dict(time=data.time.dt.month == month)

    ## get subset of array, with short canonical names
    sel_vars = [T_var, h_var, f"{T_var}_grad", f"{h_var}_grad"]
    names = {
        T_var: "T",
        h_var: "h",
        f"{T_var}_grad": "T_grad",
        f"{h_var}_grad": "h_grad",
    }
    X = data.isel(t_idx)[sel_vars].rename(names)

    ## compute stats
    sigma_T = np.std(X["T"].values)
    sigma_h = np.std(X["h"].values)
    r = scipy.stats.pearsonr(X["T"].values, X["h"].values)[0]

    ## normalized covariances; xarray's .mean() skips NaNs as before.
    ## Plain division replaces np.pow, which only exists in NumPy >= 2.0.
    TtT = (X["T_grad"] * X["T"]).mean().values.item() / sigma_T**2
    Tth = (X["T_grad"] * X["h"]).mean().values.item() / (sigma_T * sigma_h)

    return dict(r=r, TtT=TtT, Tth=Tth)


def get_R(T_var, h_var, month):
    """Estimate R for one calendar month from the statistics of ``get_stats``.

    Implements R = 1 / (1 - r**2) * (TtT - r * Tth).
    """
    stats = get_stats(T_var=T_var, h_var=h_var, month=month)
    r, TtT, Tth = stats["r"], stats["TtT"], stats["Tth"]
    return 1 / (1 - r**2) * (TtT - r * Tth)

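\begin{align}
    R &= \frac{\sigma_T^{-1}}{1-r^2}\left[\left<T_t,\tilde{T}\right> - r\left<T_t,\tilde{h}\right>\right]
\end{align}

Here $\tilde{T}=T/\sigma_T$ and $\tilde{h}=h/\sigma_h$ are the standardized anomalies, $T_t$ is the time tendency of $T$, $r$ is the correlation between $T$ and $h$, and $\left<\cdot,\cdot\right>$ is the sample mean of the product. All quantities are computed separately for each calendar month.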

In [None]:
#### compute
## kwargs
kwargs_h = dict(T_var="T_34", h_var="h")
kwargs_hw = dict(T_var="T_34", h_var="h_")

## explicit estimates for each calendar month (assumes p_*.cycle runs over months 1..12)
months = np.arange(1, 13)
Rhats_h = np.array([get_R(month=m, **kwargs_h) for m in months])
Rhats_hw = np.array([get_R(month=m, **kwargs_hw) for m in months])


#### plot
colors = sns.color_palette()
fig, ax = plt.subplots(figsize=(4, 3))

## solid: R from the XRO fit; dashed: explicit estimate from get_R
## fit with h
ax.plot(p_h.cycle, Rhats_h, c=colors[0], ls="--")
ax.plot(p_h.cycle, p_h["R"], c=colors[0], label=r"$h$")

## fit with the wide average h_ (plotted on its own cycle coordinate)
ax.plot(p_hw.cycle, Rhats_hw, c=colors[1], ls="--")
ax.plot(p_hw.cycle, p_hw["R"], c=colors[1], label=r"$h_{wide}$")

## legend proxy explaining the dashed curves
ax.plot([], [], c="k", ls="--", label=r"explicit $\hat{R}$")

## label
ax.legend(prop=dict(size=8))
ax.axhline(0, lw=1, c="k", zorder=0.5)

plt.show()

## Scatter plots

Look at relationship between $T_e$ and $\overline{h}-h_w$

In [None]:
m = 10
t_idx = dict(time=idx.time.dt.month == m)

fig, ax = plt.subplots(figsize=(3, 3))

## Niño 3.4 SST vs. thermocline index h for a single calendar month
## (pairings explored previously: T_3 on x; h - h_w, h_ - h or h_ on y)
ax.scatter(
    idx["T_34"].isel(t_idx),
    idx["h"].isel(t_idx),
)


## label: axis labels must describe the variables actually plotted above
ax.set_ylabel(r"$h$")
ax.set_xlabel(r"$T_{3.4}$")
ax.set_title(f"Month {m}")
kwargs = dict(c="k", lw=0.8, zorder=0.5)
ax.axvline(0, **kwargs)
ax.axhline(0, **kwargs)
plt.show()

In [None]:
m = 11
t_idx = dict(time=idx.time.dt.month == m)

## Niño 3 SST and h_w restricted to the chosen month
T_month = idx["T_3"].isel(t_idx)
h_w_month = idx["h_w"].isel(t_idx)

## left: h - h_w, right: h_ - h_w, both against T_3
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
for ax, h_name in zip(axs, ["h", "h_"]):
    ax.scatter(T_month, idx[h_name].isel(t_idx) - h_w_month)


## label
axs[0].set_ylabel(r"$\Delta h$")
axs[1].set_yticks([])
for ax, title in zip(axs, [r"$\overline{h}-h_w$", r"$\overline{h}_{wide}-h_w$"]):
    ax.set_title(title)
kwargs = dict(c="k", lw=0.8, zorder=0.5)
axs[1].set_ylim(axs[0].get_ylim())

for ax in axs:
    ax.axvline(0, **kwargs)
    ax.axhline(0, **kwargs)
    ax.set_xlabel(r"$T_e$")
plt.show()

## Compute EOFs

In [None]:
def remove_sst_dependence(ds, sst_idx, remove_from_sst=False):
    """Remove the linear dependence of variables on an SST index.

    For each calendar month separately, the selected variables are passed to
    ``src.utils.detrend_dim(..., deg=1, dim="sst_idx")``. This presumably
    removes a linear fit against the SST index.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a ``time`` dimension.
    sst_idx : xr.DataArray
        SST index along ``time``. It is attached positionally (``.data``), so
        it must have the same length and time ordering as ``ds``.
    remove_from_sst : bool, default False
        If True, adjust every variable in ``ds``. If False, adjust only
        variables whose name contains "h" or "d20"; all other variables are
        passed through unchanged.

    Returns
    -------
    xr.Dataset
        Adjusted variables merged with the pass-through variables.
    """

    ## split variable names into those to adjust ("h-variables") and pass-through ones
    names = list(ds)
    if remove_from_sst:
        h_vars = names
        T_vars = []
    else:
        ## NOTE: name-based heuristic; any name containing "h" is treated as an h-variable
        h_vars = [n for n in names if ("h" in n) or ("d20" in n)]
        ## preserve the dataset's variable order (a set difference would make the
        ## order of the merged output depend on string hashing)
        T_vars = [n for n in names if n not in h_vars]

    ## create array to hold results
    ds_hat = copy.deepcopy(ds[h_vars])

    ## attach the SST index as a non-dimension coordinate along time
    ds_hat = ds_hat.assign_coords(dict(sst_idx=("time", sst_idx.data)))

    ## function to remove linear dependence on the sst_idx coordinate
    fn = lambda x: src.utils.detrend_dim(x, deg=1, dim="sst_idx")

    ## remove linear dependence for each month separately
    ds_hat = ds_hat.groupby("time.month").map(fn)

    ## drop the helper sst_idx coordinate
    ds_hat = ds_hat.drop_vars("sst_idx")

    ## merge with the pass-through (T) variables
    ds_hat = xr.merge([ds_hat, ds[T_vars]])

    return ds_hat

Compute time derivatives of the (detrended) data

In [None]:
def get_ddt(ds):
    """Add a ``ddt_<name>`` time-tendency variable for every data variable.

    The dataset is first transposed so ``time`` is the last dimension. Each
    tendency is then ``src.XRO.gradient`` applied to that variable's values.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with a ``time`` dimension.

    Returns
    -------
    xr.Dataset
        The transposed dataset with the ``ddt_*`` variables added.
    """
    ds = ds.transpose(..., "time")

    ## snapshot the names first so the newly added ddt_ variables are not revisited
    for name in list(ds):
        tendency = xr.zeros_like(ds[name])
        tendency.values = src.XRO.gradient(ds[name].values)
        ds[f"ddt_{name}"] = tendency

    return ds

In [None]:
## SST and d20 fields, detrended month-by-month like the indices, at longitudes 120..280
data_ = (
    data[["sst", "d20"]]
    .groupby("time.month")
    .map(detrend_fn)
    .sel(longitude=slice(120, 280))
)

## versions with the linear Niño 3.4 dependence removed from the h-variables
data_hat, idx_hat = (remove_sst_dependence(ds, idx["T_34"]) for ds in (data_, idx))

## add time tendencies (ddt_*) to both spatial datasets
data_, data_hat = get_ddt(data_), get_ddt(data_hat)

In [None]:
## specify time index for EOFs (slice(None) keeps every time step)
t_idx = dict(time=slice(None, None, None))

## specify latitude range for EOFs
lat_idx = dict(latitude=slice(-5, 5))

## define subsetting operator: time subset, then the latitude band above
subset = lambda x: x.isel(t_idx).sel(lat_idx)

## fit EOFs (10 modes each, with cos-latitude weighting) to the detrended d20 and SST fields
model_d20 = xe.single.EOF(use_coslat=True, n_modes=10)
model_d20.fit(subset(data_["d20"]), dim="time")

model_sst = xe.single.EOF(use_coslat=True, n_modes=10)
model_sst.fit(subset(data_["sst"]), dim="time")

## get weights for each mode: square root of each d20 mode's explained variance
s = np.sqrt(model_d20.explained_variance())

## get patterns: d20 EOFs scaled by s (presumably to give them physical
## d20 units — confirm against xeofs' component normalization)
patterns = model_d20.components() * s

## plot leading modes (top: first d20 mode, bottom: second)
fig = plt.figure(figsize=(8, 2.5), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=2, ncols=1, format_func=src.utils.plot_setup_pac
)

## shared arguments for plotting
kwargs = dict(
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
    levels=src.utils.make_cb_range(16, 2),
    extend="both",
)

cp = axs[0, 0].contourf(
    patterns.longitude,
    patterns.latitude,
    patterns.isel(mode=0),
    **kwargs,
)
cb = fig.colorbar(cp)

cp = axs[1, 0].contourf(
    patterns.longitude,
    patterns.latitude,
    patterns.isel(mode=1),
    **kwargs,
)
## kwargs is reused here for dashed white lines at ±5° latitude (the EOF band)
kwargs = dict(ls="--", c="w", lw=0.8)
for ax in axs.flatten():
    ax.set_extent([120, 280, -20, 20], crs=ccrs.PlateCarree())
    ax.axhline(-5, **kwargs)
    ax.axhline(5, **kwargs)

## colorbar for the second (bottom) mode
cb = fig.colorbar(cp)

plt.show()

## Look at regression with Niño 3.4 index

TODO: create Hövmöller diagrams of this regression for each season?

In [None]:
#### specify regression to plot
use_hat = True  # True: use the data with linear Niño 3.4 dependence removed
# idx_var = "T_34"
# spatial_var = "ddt_sst"
idx_var = "T_34"
spatial_var = "ddt_d20"
month = 4  # calendar month (1-12) to correlate

## get data to use
if use_hat:
    spatial_data = data_hat[spatial_var]
    idx_data = idx_hat[idx_var]

else:
    spatial_data = data_[spatial_var]
    idx_data = idx[idx_var]

## select the requested calendar month by its label. A positional slice
## (month - 1, None, 12) silently assumes the record starts in January.
spatial_month = spatial_data.isel(time=spatial_data.time.dt.month == month)
idx_month = idx_data.isel(time=idx_data.time.dt.month == month)

## compute correlation coefs
coefs = src.utils.rho_xr(Y_data=spatial_month, idx=idx_month)

## plot correlation map
fig = plt.figure(figsize=(6, 1.25), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=1, ncols=1, format_func=src.utils.plot_setup_pac
)

## shared arguments for plotting
kwargs = dict(
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
    levels=src.utils.make_cb_range(0.75, 0.15),
    extend="both",
)

cp = axs[0, 0].contourf(
    coefs.longitude,
    coefs.latitude,
    coefs,
    **kwargs,
)
cb = fig.colorbar(cp)

## dashed white lines at ±5° latitude
kwargs = dict(ls="--", c="w", lw=0.8)
for ax in axs.flatten():
    ax.set_extent([120, 280, -20, 20], crs=ccrs.PlateCarree())
    ax.axhline(-5, **kwargs)
    ax.axhline(5, **kwargs)

plt.show()

## LIM

#### Put data in correct format

In [None]:
def reconstruct(scores, other_coord=None):
    """Map stacked EOF-space scores back to physical SST ("T") and d20 ("h").

    The first 10 rows of ``scores`` are SST-mode scores and the rest are d20-mode
    scores. Each part is passed to ``src.utils.reconstruct_helper`` with the
    module-level ``model_sst`` / ``model_d20``.

    Returns
    -------
    xr.Dataset
        Real part of the merged ``T`` and ``h`` reconstructions.
    """
    helper_kwargs = {"other_coord": other_coord}
    sst_scores, d20_scores = scores[:10], scores[10:]

    T_field = src.utils.reconstruct_helper(sst_scores, model=model_sst, **helper_kwargs)
    h_field = src.utils.reconstruct_helper(d20_scores, model=model_d20, **helper_kwargs)

    merged = xr.merge([T_field.rename("T"), h_field.rename("h")])
    return merged.real


def get_weighted_scores(model):
    """Return the EOF scores of ``model`` scaled by ``src.utils.get_weight(model)``."""
    return model.scores() * src.utils.get_weight(model).values


## stack weighted SST scores above weighted d20 scores -> array of (mode, time)
XY_np = np.concatenate(
    [get_weighted_scores(eof_model).values for eof_model in (model_sst, model_d20)],
    axis=0,
)

## calendar-month label (0 = January) for each time step
month_idx = model_sst.scores().time.dt.month.values - 1

## fit LIM_CS on consecutive (t, t+1) pairs, labelled by the month of the first state
lim = src.lim.LIM_CS(X=XY_np[:, :-1], Y=XY_np[:, 1:], month_labels=month_idx[:-1])

## eigenvalue spectrum rescaled for plotting (presumably monthly rates -> per year)
sigma = 12 * lim.sigma
omega = 12 / (2 * np.pi) * lim.omega

## sanity check: the hand-written reconstruction reproduces xeofs' d20 inverse transform
recon_ok = np.allclose(
    model_d20.inverse_transform(model_d20.scores()).values,
    reconstruct(XY_np)["h"].values,
    equal_nan=True,
)
print(recon_ok)

## reference: the same LIM fit on the two-variable RO state (T_34, h)
XY_idx = np.stack([idx["T_34"].values, idx["h"].values])
RO = src.lim.LIM_CS(X=XY_idx[:, :-1], Y=XY_idx[:, 1:], month_labels=month_idx[:-1])
sigma_RO = 12 * RO.sigma
omega_RO = 12 / (2 * np.pi) * RO.omega

Plot spectrum

In [None]:
## plot spectrum
fig, ax = plt.subplots(figsize=(3, 2.5))

## eigenvalues of the EOF-space LIM
ax.scatter(sigma, omega, label="LIM")

## plot RO as reference
ax.scatter(sigma_RO, omega_RO, label="RO")

## label / guidelines: dashed lines at omega = -1/4, 0, 1/4
kwargs = dict(c="k", zorder=0.5, lw=1, ls="--")
for yt in [-0.25, 0, 0.25]:
    ax.axhline(yt, **kwargs)
## the tick at -1/4 must carry a minus sign
ax.set_yticks([-1 / 4, 0, 1 / 4], labels=["-1/4", "0", "1/4"])
ax.set_xticks([-1, -0.5, 0])
ax.set_xlim([-1, 0])
ax.set_xlabel(r"$\sigma$ (year$^{-1}$)")
ax.set_ylabel(r"$\omega$ (year$^{-1}$)")
ax.legend(prop=dict(size=8))

plt.show()

#### evaluate eigenfuncs

In [None]:
## specify eigenfunction index for ENSO
## (assumes modes enso_idx and enso_idx + 1 form the ENSO oscillatory pair — confirm against the spectrum)
enso_idx = 0
enso_idxs = np.array([enso_idx, enso_idx + 1])

## get reduced eigendecomp: copies of the ENSO columns of lim's U, gamma and V
Uk = copy.deepcopy(lim.U[..., enso_idxs])
gammak = copy.deepcopy(lim.gamma[enso_idxs])
Vk = copy.deepcopy(lim.V[..., enso_idxs])

## evaluate ENSO eigenfunction: for each time n, contract V for that time's
## calendar month (indices n, m, k) with the state XY_np[:, n] over the state
## index m -> varphi has shape (time, 2)
## NOTE(review): no complex conjugate is applied to V; confirm this matches LIM_CS's convention
varphi = np.einsum("nmk,mn->nk", Vk[month_idx], XY_np)

## same, but in RO (all RO eigenfunctions, evaluated on the (T_34, h) state)
varphi_RO = np.einsum("nmk,mn->nk", RO.V[month_idx], XY_idx)

Find phase of "peak" ENSO for Dec

In [None]:
## specify month to find peak index
## NOTE(review): month_idx is built as (calendar month - 1), so index 0 selects
## January, while the heading of this cell says "Dec" — confirm which is intended.
peak_month_idx = 0

## get test pts: values on a circle in the complex plane whose radius is the std of
## varphi (taken over all entries), sampled every pi/32 radians of phase
theta_test = np.arange(0, 2 * np.pi, np.pi / 32)
varphi_test = varphi.std() * np.exp(1j * theta_test)
## pair each test value with its complex conjugate (one row per column of Uk)
VtX = np.stack([varphi_test, np.conj(varphi_test)], axis=0)

## Get recon (EOF space) using the U matrix for the peak month
recon = Uk[peak_month_idx] @ VtX

## Get recon (real space), indexed by the test phase theta
recon_xr = reconstruct(recon, other_coord=pd.Index(theta_test, name="theta"))
nino_recon = src.utils.get_nino34(recon_xr["T"])

## get theta for maximum Niño 3.4 in peak month
theta_max = recon_xr.theta.isel(theta=nino_recon.argmax("theta")).values.item()

print(theta_max)

#### Look at evolution of "peak" event

In [None]:
## phase advance of the ENSO eigenfunction at each lag: 2*pi*omega*(lag in years)
lags_months = np.arange(-12, 13)
lags_years = lags_months / 12
delta_theta = 2 * np.pi * omega[enso_idx] * lags_years

## calendar-month index and phase at each lag, measured from the peak phase
month_idx_test = (peak_month_idx + lags_months) % 12
theta_test = (theta_max + delta_theta) % (2 * np.pi)

## fixed amplitude of 1.25 std (growth/decay is not applied), paired with its conjugate
amplitude = 1.25 * varphi.std()
varphi_pt = amplitude * np.exp(1j * theta_test)
varphi_test = np.stack([varphi_pt, np.conj(varphi_pt)], axis=0)

## EOF-space state at each lag, using that lag's calendar month of Uk
recon = np.einsum("nmk,kn->mn", Uk[month_idx_test], varphi_test)

## physical-space reconstruction indexed by lag (months)
recon_xr = reconstruct(recon, other_coord=pd.Index(lags_months, name="lag"))

Same, but with composite

In [None]:
## optionally composite the data with the linear Niño 3.4 dependence removed from
## the h/d20 variables (SST itself is untouched because remove_from_sst=False).
## A separate name is used so the `data_hat` needed by the regression cell
## (with its ddt_ variables) is not overwritten.
use_hat_comp = False
comp_data = (
    remove_sst_dependence(data_, idx["T_34"], remove_from_sst=False)
    if use_hat_comp
    else data_
)

## make composite keyed on Niño 3.4 (q=0.85 cutoff with x > cut, peak_month=12;
## see src.utils.make_composite for the exact event selection)
comp = src.utils.make_composite(
    idx=idx["T_34"],
    data=comp_data,
    peak_month=12,
    q=0.85,
    check_cutoff=lambda x, cut: x > cut,
).rename({"sst": "T", "d20": "h"})

#### Hövmöller

In [None]:
## meridional means over latitudes -5..5 for the LIM reconstruction and the composite
lat = dict(latitude=slice(-5, 5))
recon_merimean, comp_merimean = (
    ds.sel(lat).mean("latitude") for ds in (recon_xr, comp)
)

## left: LIM reconstruction, right: composite
fig, axs = plt.subplots(1, 2, figsize=(6, 3), layout="constrained")

for ax, merimean in zip(axs, [recon_merimean, comp_merimean]):
    lon_vals, lag_vals = merimean.longitude, merimean.lag

    ## SST (filled contours)
    ax.contourf(
        lon_vals,
        lag_vals,
        merimean["T"],
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(3, 0.3),
        extend="both",
    )

    ## thermocline (black contour lines)
    ax.contour(
        lon_vals,
        lag_vals,
        merimean["h"],
        colors="k",
        levels=src.utils.make_cb_range(40, 8),
        extend="both",
        linewidths=1,
    )

    ax.set_xticks([190, 240])

## lag labels on the left panel only
## NOTE(review): labels assume lag 0 is January; the composite uses peak_month=12 — confirm
axs[0].set_yticks([-12, 0, 12], labels=["Jan(-1)", "Jan(0)", "Jan(+1)"])
axs[1].set_yticks([])

plt.show()

In [None]:
## evaluate all LIM eigenfunctions at every time step -> (time, mode)
varphi_all = np.einsum("nmk,mn->nk", lim.V[month_idx], XY_np)

## keep times where the ENSO eigenfunction magnitude exceeds its 70th percentile
mag = np.abs(varphi_all[:, enso_idx])
is_large = mag > np.percentile(mag, 70)

## phase of the ENSO eigenfunction vs. phase of eigenfunction 4 at those times
angle_enso, angle_other = (
    src.utils.get_angle(varphi_all[is_large, k]) for k in (enso_idx, 4)
)

fig, ax = plt.subplots(figsize=(2, 2))
ax.scatter(angle_enso, angle_other, s=10)

plt.show()

#### Plot result

In [None]:
## specify plot data (LIM recon or comp)
plot_data = recon_xr
# plot_data = comp

## (variable, max |level|, level spacing) for the top (SST) and bottom (h) rows
panel_specs = [("T", 3, 0.3), ("h", 40, 4)]

for lag in plot_data.lag[::2]:

    ## one figure per plotted lag (every other lag)
    fig = plt.figure(figsize=(8, 2.5), layout="constrained")
    axs = src.utils.subplots_with_proj(
        fig, nrows=2, ncols=1, format_func=src.utils.plot_setup_pac
    )

    for row, (varname, vmax, step) in enumerate(panel_specs):
        cp = axs[row, 0].contourf(
            plot_data.longitude,
            plot_data.latitude,
            plot_data[varname].sel(lag=lag),
            levels=src.utils.make_cb_range(vmax, step),
            cmap="cmo.balance",
            transform=ccrs.PlateCarree(),
            extend="both",
        )

    ## dashed white lines at ±5° latitude
    kwargs = dict(ls="--", c="w", lw=0.8)
    for ax in axs.flatten():
        ax.set_extent([120, 280, -20, 20], crs=ccrs.PlateCarree())
        ax.axhline(-5, **kwargs)
        ax.axhline(5, **kwargs)

    axs[0, 0].set_title(f"Lag = {lag.values.item()} months")

    plt.show()

#### Compare RO eigenmode to LIM

In [None]:
def cross_corr(x0, x1):
    """Compute the normalized cross-correlation of two 1-D signals.

    The cross-covariance ``scipy.signal.correlate(x0, x1, mode="same")`` is
    divided by ``sqrt(<x0, x0> <x1, x1>)``, the zero-lag autocorrelations. No
    mean is removed, so the inputs are assumed to be anomalies. Complex inputs
    are supported (scipy conjugates the second argument).

    Parameters
    ----------
    x0, x1 : array_like, 1-D
        Signals to correlate. They may have different lengths.

    Returns
    -------
    xcorr : ndarray
        Normalized cross-correlation, same length as ``x0``.
    lags : ndarray of int
        Lag for each entry of ``xcorr`` (consistent with
        ``scipy.signal.correlate`` for these input lengths).
    """
    ## scipy.signal is not loaded by `import scipy.stats`, so import it explicitly
    import scipy.signal

    ## compute cross correlation and the zero-lag normalizers
    cov = scipy.signal.correlate(x0, x1, mode="same")
    var0 = scipy.signal.correlate(x0, x0, mode="valid")
    var1 = scipy.signal.correlate(x1, x1, mode="valid")

    ## lags must use both input lengths; (len(x0), len(x0)) is wrong when they differ
    lags = scipy.signal.correlation_lags(len(x0), len(x1), mode="same")

    return cov / np.sqrt(var0 * var1), lags

In [None]:
## phase rotation applied to the RO eigenfunction before comparing
## (this works for T_34/h combination)
theta = 1 * np.pi / 8
phi = np.exp(1j * theta)

## rotated RO eigenfunction vs. LIM ENSO eigenfunction, and RO autocorrelation
xcorr, lags = cross_corr(phi * varphi_RO[:, 0], varphi[:, 0])
xcorr0, _ = cross_corr(varphi_RO[:, 0], varphi_RO[:, 0])

fig, ax = plt.subplots(figsize=(4, 3))
for part in (xcorr.real, xcorr.imag):
    ax.plot(lags, part)
ax.plot(lags, xcorr0.real, c=sns.color_palette()[0], ls="--")
ax.axvline(0, c="k")
ax.set_xlim([-36, 36])
plt.show()