# RO-CESM validation

Should we shut off the seasonal variation of the noise term in the equation for $\frac{dh}{dt}$? Maybe not, because doing so would remove the seasonal predictability barrier for $h$ (note that even though $\varepsilon$ has no seasonal cycle, $F_2$ *does*).

## Imports

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import warnings
import tqdm
import pathlib
import cmocean
import os
import scipy.stats

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr
import src.XRO_utils

## plotting defaults: white axes background, grid switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution
mpl.rcParams["figure.dpi"] = 100

## input/output directories are taken from environment variables
## (a missing variable raises KeyError here)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## load mixed layer depth EOFs
MLD_EOFS = src.utils.load_eofs(DATA_FP / "cesm" / "eofs_mlotst.nc")

## Functions

In [None]:
def fit_RO_to_ensemble(
    model, Th_data, T_var, h_var, fit_to_members=False, **fit_kwargs
):
    """Estimate RO parameters from ensemble data.

    Args:
        model: object exposing ``fit_matrix``.
        Th_data: data container holding the ``T_var`` and ``h_var`` variables.
        T_var: name of the temperature-like variable.
        h_var: name of the thermocline-like variable.
        fit_to_members: if true, delegate to ``src.utils.get_RO_ensemble``;
            otherwise call ``model.fit_matrix`` once on the
            ``[T_var, h_var]`` subset of ``Th_data``.
        **fit_kwargs: forwarded to ``model.fit_matrix`` (single-fit path only).

    Returns:
        The fitted parameters from whichever path was taken.
    """

    if not fit_to_members:
        ## one set of parameters, fit to all ensemble members simultaneously
        return model.fit_matrix(Th_data[[T_var, h_var]], **fit_kwargs)

    ## fit to individual ensemble members
    ## NOTE(review): fit_kwargs is not forwarded on this path -- confirm intended
    _, member_params = src.utils.get_RO_ensemble(
        Th_data, model=model, T_var=T_var, h_var=h_var, verbose=True
    )
    return member_params


def generate_RO_ensemble(model, params, fit_to_members=False, **simulation_kwargs):
    """Simulate an RO ensemble from a model and fitted parameters.

    Args:
        model: object exposing ``simulate``.
        params: fitted parameters (as returned by the fitting step).
        fit_to_members: if true, delegate to ``src.utils.generate_ensemble``;
            otherwise call ``model.simulate(fit_ds=params, ...)`` directly.
        **simulation_kwargs: forwarded unchanged to whichever call is made.

    Returns:
        The simulated ensemble.
    """

    if fit_to_members:
        return src.utils.generate_ensemble(model, params, **simulation_kwargs)

    return model.simulate(fit_ds=params, **simulation_kwargs)


def format_axs_cyc(axs):
    """Label a three-panel seasonal-cycle figure.

    Sets panel titles, fixed y-ticks on the outer panels, a y-label on the
    first panel, right-hand ticks on the third panel, and a small legend on
    every panel.

    Args:
        axs: indexable collection with (at least) three matplotlib axes.
    """

    ## NOTE(review): period labels are hard-coded; the fitting cell selects
    ## 1851-1900 for the early period -- confirm "1850-1889" is still intended
    panel_titles = (r"1850-1889", r"2051-2100", r"Early vs. Late (model)")
    for idx, title in enumerate(panel_titles):
        axs[idx].set_title(title)

    ## outer panels share the same tick locations
    for idx in (0, 2):
        axs[idx].set_yticks([0, 1, 2, 3])

    axs[0].set_ylabel(r"$^{\circ}$C")
    axs[2].yaxis.set_ticks_position("right")

    for idx in range(3):
        axs[idx].legend(prop=dict(size=8))


def format_axs_autocorr(axs):
    """Label a 2x3 grid of lagged-correlation panels.

    Args:
        axs: 2-D array of matplotlib axes, indexed as ``axs[row, col]``
            (as returned by ``plt.subplots(2, 3)``).
    """

    ## column titles go on the top row
    column_titles = (r"1850-1889", r"2051-2100", r"Early vs. Late (Model)")
    for col, title in enumerate(column_titles):
        axs[0, col].set_title(title)

    ## row labels go on the left column
    row_labels = (r"Corr$(T,T)$ ", r"Corr$(T,h)$ ")
    for row, label in enumerate(row_labels):
        axs[row, 0].set_ylabel(label)

    ## bottom row: ticks at lags of -12/0/12, labelled -1/0/1
    for bottom_ax in axs[1, :]:
        bottom_ax.set_xticks([-12, 0, 12], labels=[-1, 0, 1])
        bottom_ax.set_xlabel("Lag (years)")

    ## left column: fixed y-ticks
    for left_ax in axs[:, 0]:
        left_ax.set_yticks([0, 0.5, 1])


def plot_xcorr2(axs, xcorr_eval, T_var="T_3", h_var="h_w", **plot_kwargs):
    """Plot lagged correlations for ``T_var`` and ``h_var`` on two axes.

    Args:
        axs: pair of axes; ``axs[0]`` receives ``T_var``, ``axs[1]`` ``h_var``.
        xcorr_eval: precomputed cross-correlation data, passed to
            ``src.utils.get_ensemble_stats``.
        T_var: name of the variable drawn on the first axis.
        h_var: name of the variable drawn on the second axis.
        **plot_kwargs: forwarded to ``src.utils.plot_xcorr`` for both axes.
    """

    ## summarize across the ensemble (the correlation itself is computed upstream)
    stats = src.utils.get_ensemble_stats(xcorr_eval)

    ## one variable per axis
    for ax, varname in zip((axs[0], axs[1]), (T_var, h_var)):
        src.utils.plot_xcorr(ax, stats[varname], **plot_kwargs)


def remove_sst_dependence_core(h, T, dims=("time", "member")):
    """Remove the component of ``h`` that is linearly related to ``T``.

    The regression coefficient ``cov(h, T) / var(T)`` is computed over
    ``dims`` and ``coef * T`` is subtracted from ``h``. No intercept is
    fitted or removed.

    Args:
        h: xarray object to be adjusted.
        T: xarray object whose linear imprint is removed from ``h``.
        dims: dimension name, or sequence of names, over which the
            covariance and variance are computed.

    Returns:
        ``h - cov(h, T) / var(T) * T``.
    """

    ## accept a single dimension name as well as a sequence of names
    dims = [dims] if isinstance(dims, str) else list(dims)

    ## get covariance/variance over the requested dims; ddof=0 so that the
    ## covariance uses the same normalization as ``.var`` (whose default is 0)
    cov = xr.cov(h, T, dim=dims, ddof=0)
    var = T.var(dim=dims)

    ## remove linear dependence
    return h - cov / var * T


def remove_sst_dependence(Th, T_var="T_3", h_var="h_w", dims=["time", "member"]):
    """Remove the linear dependence of ``h_var`` on ``T_var``, month by month.

    Args:
        Th: xarray Dataset holding ``T_var`` and ``h_var`` with a ``time``
            coordinate.
        T_var: name of the temperature-like variable.
        h_var: name of the variable to adjust.
        dims: passed to ``remove_sst_dependence_core`` as ``dims``.

    Returns:
        Result of mapping ``remove_sst_dependence_core`` over the
        ``time.month`` groups of ``Th``.
    """

    def _adjust_one_month(month_ds):
        """Apply the core routine to the data of one calendar month."""
        return remove_sst_dependence_core(
            h=month_ds[h_var], T=month_ds[T_var], dims=dims
        )

    ## apply to each calendar month separately
    return Th.groupby("time.month").map(_adjust_one_month)

## Load data

In [None]:
## load ELI data and keep its anomalous (non-forced) component
eli = xr.open_dataset(DATA_FP / "cesm/eli.nc")
_, eli_anom = src.utils.separate_forced(eli)

## open CESM indices
Th = src.utils.load_cesm_indices(load_z20=True)

## shorter names for the indices (for convenience)
Th = Th.rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

## add ELI (before standardizing, so that it is scaled like everything else)
Th["eli"] = eli_anom["eli_05"]

## standardize (for convenience): each variable is divided by its own
## standard deviation taken over all dimensions; the mean is not removed
Th /= Th.std()

## remove linear dependence of h_w on T_3 (the function's default variables)
Th["h_w_hat"] = remove_sst_dependence(Th)

In [None]:
## specify T_var
T_var = "T_3"


def sel(x):
    """Keep only the samples of ``x`` that fall in calendar month 3."""
    return x.sel(time=x.time.dt.month == 3)


fig, axs = plt.subplots(1, 3, figsize=(7, 2), layout="constrained")

## one panel per h-like quantity, each scattered against T
panels = (
    (Th["h"] - Th["h_w"], r"$h-h_w$"),
    (Th["h_w"], r"$h_w$"),
    (Th["h_w_hat"], r"$\hat{h}_w$"),
)
for ax, (h_like, title) in zip(axs, panels):
    ax.scatter(sel(Th[T_var]), sel(h_like), s=1)
    ax.set_title(title)

plt.show()

## Fit XRO

For ELI, use:  
```python
varnames = ["eli","h_w"]
maskNT = ["T2"]
```
For RO, use:
```python
varnames = ["T_3","h_w"]
maskNT = []
```

In [None]:
## specify which variables to use: [T-like variable, h-like variable]
## (downstream cells read varnames[0] and varnames[1] in that order)
# varnames = ["T_34", "h_w"]
# varnames = ["T_3", "h_w"]
varnames = ["T_34", "h_w_z20"]
# varnames = ["eli", "h_w"]

# specify order of annual cycle
ac_order = 4

# specify whether to fit parameters to each ens. member individually
# if False, fit single set of parameters to all ensemble members
# NOTE(review): this flag is not read by the fits below, which always call
# model.fit_matrix directly -- confirm whether it is still needed
fit_to_members = False

# specify fit kwargs (forwarded unchanged to every fit_matrix call below)
fit_kwargs = dict(ac_mask_idx=None, maskNT=[])

# get data for early/late period (label-based selection on the time axis)
Th_early = Th[varnames].sel(time=slice("1851", "1900"))
Th_late = Th[varnames].sel(time=slice("2051", "2100"))

# specify model to use
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

# fit models to early/late periods
fit_early = model.fit_matrix(Th_early, **fit_kwargs)
fit_late = model.fit_matrix(Th_late, **fit_kwargs)

## sensitivity test: check center vs. forward differencing
## (identical settings apart from is_forward=False)
model_c = XRO(ncycle=12, ac_order=ac_order, is_forward=False)
fit_early_c = model_c.fit_matrix(Th_early, **fit_kwargs)
fit_late_c = model_c.fit_matrix(Th_late, **fit_kwargs)

### check sensitivity to differencing scheme
(see ```tests/test_Q.ipynb``` for more)

In [None]:
## should we plot early or late period?
USE_EARLY = True

## pair the forward-difference fit with the is_forward=False fit for that period
fits = [fit_early, fit_early_c] if USE_EARLY else [fit_late, fit_late_c]

## colormap
colors = sns.color_palette("colorblind")

fig, axs = plt.subplots(1, 4, figsize=(9, 2), layout="constrained")

for m, fit, c in zip([model, model_c], fits, colors):

    p = m.get_RO_parameters(fit)
    x_vals = fit.cycle * 12

    ## seasonal cycle of each parameter, one panel per parameter
    for ax, param in zip(axs, [p.R, p.epsilon, p.BJ_ac, p.xi_T]):
        ax.plot(x_vals, param, c=c)

    ## dashed horizontal line: mean over the cycle (last two panels only)
    for ax, param in zip(axs[2:], [p.BJ_ac, p.xi_T]):
        ax.axhline(param.mean("cycle"), c=c, ls="--")


## zero line and title on every panel
for ax, t in zip(axs, ["R", "eps", "BJ", "noise"]):
    ax.set_title(t)
    ax.axhline(0, ls="--", c="k", lw=0.6)
    # ax.set_yticks([])

plt.show()

## Stochastic simulations

In [None]:
## arguments shared by the early- and late-period simulations
## NOTE(review): both runs use the same seed and the same X0_ds (taken from
## the first member/time step of Th_early) -- confirm this is intended
simulation_kwargs = {
    "nyear": 49,
    "ncopy": 1000,
    "seed": 1000,
    "X0_ds": Th_early.isel(member=0, time=0),
    "is_xi_stdac": True,
    "noise_type": "white",
    "use_noise_cov": False,
}

## generate one ensemble per set of fitted parameters
RO_ensemble_early = model.simulate(fit_ds=fit_early, **simulation_kwargs)
RO_ensemble_late = model.simulate(fit_ds=fit_late, **simulation_kwargs)

## Diagnostics

### Seasonality

In [None]:
def plot_seasonal_stats(fn, var_idx, ylim=(None, None), yticks=()):
    """Plot the seasonal cycle of a statistic for CESM and the RO.

    Draws three panels: CESM vs. RO (early period), CESM vs. RO (late
    period), and CESM early vs. late. Reads the notebook-level names
    ``varnames``, ``Th_early``, ``Th_late``, ``RO_ensemble_early`` and
    ``RO_ensemble_late``.

    Args:
        fn: statistic handed to ``src.utils.plot_seasonal_comp`` as ``fn``
            (e.g. ``np.std``).
        var_idx: index into ``varnames`` selecting the variable to plot.
        ylim: two-element (bottom, top) y-limits applied to every panel.
        yticks: y-tick locations applied to the outer panels.

    Returns:
        Tuple ``(fig, axs)`` of the created figure and its axes.
    """

    varname = varnames[var_idx]
    print(varname)

    ## arguments shared by all three panels
    shared = dict(fn=fn, varname=varname)

    ### Set up plot
    fig, axs = plt.subplots(1, 3, figsize=(7, 2), layout="constrained")

    ## plot CESM vs. RO (early period)
    src.utils.plot_seasonal_comp(
        axs[0],
        x0=Th_early,
        x1=RO_ensemble_early,
        plot_kwargs0=dict(label="CESM"),
        plot_kwargs1=dict(label="RO"),
        **shared,
    )

    ## plot CESM vs. RO (late period)
    src.utils.plot_seasonal_comp(
        axs[1],
        x0=Th_late,
        x1=RO_ensemble_late,
        plot_kwargs0=dict(label="CESM"),
        plot_kwargs1=dict(label="RO"),
        **shared,
    )

    ## plot early vs. late (CESM)
    src.utils.plot_seasonal_comp(
        axs[2],
        x0=Th_early,
        x1=Th_late,
        plot_kwargs0=dict(label="1850-1889", c="k"),
        plot_kwargs1=dict(label="2051-2100", c="r"),
        **shared,
    )

    ## label plot
    format_axs_cyc(axs)

    ## applied after format_axs_cyc, so these limits/ticks replace the
    ## tick locations it sets
    for ax in axs:
        ax.set_ylim(list(ylim))
        ax.set_yticks(list(yticks))
        ax.axhline(0, c="k", lw=0.5, ls="-")

    ## middle panel carries no y-ticks
    axs[1].set_yticks([])

    return fig, axs

Std dev.

In [None]:
## standard deviation, same limits/ticks for both variables
kwargs = dict(fn=np.std, ylim=[0.5, 2], yticks=[0.5, 2])

## one figure per entry of varnames
for var_idx in (0, 1):
    fig, axs = plot_seasonal_stats(var_idx=var_idx, **kwargs)
    plt.show()

Skewness

In [None]:
## skewness, same limits/ticks for both variables
kwargs = dict(fn=scipy.stats.skew, ylim=[-1.5, 1.5], yticks=[-1.5, 1.5])

## one figure per entry of varnames
for var_idx in (0, 1):
    fig, axs = plot_seasonal_stats(var_idx=var_idx, **kwargs)
    plt.show()

Scaling factor (to look at the effect of a shallower MLD)

In [None]:
def get_H(t_bnds):
    """Get the monthly-mean Niño 3.4 MLD for a window of time indices.

    Args:
        t_bnds: arguments for ``slice`` (e.g. ``(start, stop)``), applied
            positionally along the ``time`` dimension of the EOF scores.

    Returns:
        Reconstructed MLD with the ``month`` dimension renamed to ``cycle``.
    """

    ## ensemble-mean scores inside the requested window
    window = slice(*t_bnds)
    ens_mean_scores = MLD_EOFS.scores().isel(time=window).mean("member")

    ## monthly climatology of the scores
    monthly_clim = ens_mean_scores.groupby("time.month").mean()

    ## reconstruct nino MLD from the climatological scores
    mld = src.utils.reconstruct_fn(
        scores=monthly_clim,
        components=MLD_EOFS.components(),
        fn=src.utils.get_nino34,
    )

    return mld.rename({"month": "cycle"})


## computation: MLD from the first 600 time steps, and from the 600 time
## steps ending 12 steps before the end of the record
## NOTE(review): these windows are positional and are not tied to the
## "1851"-"1900" / "2051"-"2100" label slices used for fitting -- confirm
## that they cover the intended periods
H_early = get_H((None, 600))
H_late = get_H((-612, -12))

## get difference
delta_H = H_late - H_early

## scaling factor (algebraically equal to H_late / H_early)
gamma = 1 + delta_H / H_early

Get (and plot) parameters

In [None]:
## get parameters
p0 = model.get_RO_parameters(fit_early)
p1 = model.get_RO_parameters(fit_late)


## plot them
c = sns.color_palette()
months = np.arange(1, 13)
fig, ax = plt.subplots(figsize=(4, 3))

## (early, late, color, legend label, also draw late * gamma?)
## the noise curves use "xi_stdac"; "xi_std" was the alternative tried here
curves = [
    (
        fit_early["xi_stdac"].isel(ranky=0),
        fit_late["xi_stdac"].isel(ranky=0),
        c[0],
        "noise",
        False,
    ),
    (p0["R"], p1["R"], c[1], r"$R$", True),
    (p0["epsilon"], p1["epsilon"], c[2], r"$\varepsilon$", False),
    (p0["BJ_ac"], p1["BJ_ac"], "k", r"$BJ$", False),
]

## solid = early period, dashed = late period, dotted = late scaled by gamma
for early, late, color, label, add_scaled in curves:
    ax.plot(months, early, c=color, label=label)
    ax.plot(months, late, c=color, ls="--")
    if add_scaled:
        ax.plot(months, late * gamma, c=color, ls=":")

ax.axhline(0, ls="--", c="k", lw=1)
ax.legend(prop=dict(size=6))

plt.show()

### $T$, $h$ cross-correlation

Plotting functions

Compute cross-correlation stats

In [None]:
def get_xcorr(x):
    """Call ``src.XRO.xcorr`` on ``x`` against ``x[varnames[0]]`` (maxlags=18)."""
    return src.XRO.xcorr(x, x[varnames[0]], maxlags=18)


## CESM, early and late periods
xcorr_Th_early = get_xcorr(Th_early)
xcorr_Th_late = get_xcorr(Th_late)

## RO ensembles, early and late periods
xcorr_RO_early = get_xcorr(RO_ensemble_early)
xcorr_RO_late = get_xcorr(RO_ensemble_late)

Make plot

In [None]:
## specify which variables to plot
kwargs = dict(T_var=varnames[0], h_var=varnames[1])

fig, axs = plt.subplots(2, 3, figsize=(7, 4.5), layout="constrained")

## (column, data, extra plotting arguments), drawn in this order
panels = [
    (0, xcorr_Th_early, dict(label="model")),
    (0, xcorr_RO_early, dict(label="RO")),
    (1, xcorr_Th_late, dict()),
    (1, xcorr_RO_late, dict()),
    (2, xcorr_Th_early, dict(c="k", label="1850-1889")),
    (2, xcorr_Th_late, dict(c="r", label="2051-2100")),
]
for col, data, extra in panels:
    plot_xcorr2(axs[:, col], data, **kwargs, **extra)

## label axes; legends only where labelled curves were drawn on the top row
format_axs_autocorr(axs)
for col in (0, 2):
    axs[0, col].legend(prop=dict(size=8))

## bottom row gets its own limits and ticks
for ax in axs[1, :]:
    ax.set_ylim([-1, None])
    ax.set_yticks([-0.8, 0, 0.8])

plt.show()

### Power spectrum

Plotting functions

Compute PSD

In [None]:
## specify args for psd
psd_kwargs = dict(dim="time", dt=1 / 12, nw=5)


def compute_psd(x):
    """Run ``src.XRO_utils.pmtm`` on ``x[varnames[0]]`` with ``psd_kwargs``."""
    return src.XRO_utils.pmtm(x[varnames[0]], **psd_kwargs)


## compute PSD (the "mpi" in these names refers to the CESM data here)
psd_mpi_early = compute_psd(Th_early)
psd_RO_early = compute_psd(RO_ensemble_early)
psd_mpi_late = compute_psd(Th_late)
psd_RO_late = compute_psd(RO_ensemble_late)

Plot PSD

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(7, 2.5), layout="constrained")

## (panel index, PSD, plotting arguments), drawn in this order:
## panel 0 = early (CESM vs. RO), panel 1 = late (CESM vs. RO),
## panel 2 = early vs. late (CESM)
panels = [
    (0, psd_mpi_early, dict(label="CESM")),
    (0, psd_RO_early, dict(label="RO")),
    (1, psd_mpi_late, dict(label="CESM")),
    (1, psd_RO_late, dict(label="RO")),
    (2, psd_mpi_early, dict(label="1850-1889", color="k")),
    (2, psd_mpi_late, dict(label="2051-2100", color="r")),
]
for idx, psd, plot_args in panels:
    src.utils.plot_psd(axs[idx], psd, **plot_args)

## label
axs[0].set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
for idx in (0, 2):
    axs[idx].legend(prop=dict(size=6))

## only the first panel keeps y-ticks
for ax in axs[1:]:
    ax.set_yticks([])


plt.show()