# RO-MPI validation

Should we shut off the seasonal variation of the noise term in the equation for $\frac{dh}{dt}$? Maybe not, because doing so would remove the seasonal predictability barrier for $h$ (note that even though $\varepsilon$ has no seasonal cycle, $F_2$ *does*).

## Imports

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import warnings
import tqdm
import pathlib
import cmocean
import os

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr
import src.XRO_utils

## set plotting specs: seaborn defaults, with a white axes background and
## the grid switched off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI (set after sns.set, so this value is the one in effect)
mpl.rcParams["figure.dpi"] = 100

## get filepaths from the environment
## (raises KeyError if DATA_FP or SAVE_FP is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def fit_RO_to_ensemble(
    model, Th_data, T_var, h_var, fit_to_members=False, **fit_kwargs
):
    """Fit RO parameters to ensemble data.

    Args:
        model: object providing ``fit_matrix``; it is also handed to
            ``src.utils.get_RO_ensemble`` when fitting member by member.
        Th_data: dataset containing the ``T_var`` and ``h_var`` variables.
        T_var: name of the T variable used in the fit.
        h_var: name of the h variable used in the fit.
        fit_to_members: if True, fit each ensemble member individually;
            otherwise fit a single parameter set to all members at once.
        **fit_kwargs: forwarded to ``model.fit_matrix`` (joint fit only).

    Returns:
        The fitted parameters: the result of ``model.fit_matrix`` for the
        joint fit, or the second item returned by
        ``src.utils.get_RO_ensemble`` for the per-member fit.
    """

    ## joint fit: one set of parameters for all ensemble members
    if not fit_to_members:
        return model.fit_matrix(Th_data[[T_var, h_var]], **fit_kwargs)

    ## per-member fit
    ## NOTE(review): fit_kwargs are not forwarded on this path -- confirm
    ## whether get_RO_ensemble should receive them
    _, member_params = src.utils.get_RO_ensemble(
        Th_data, model=model, T_var=T_var, h_var=h_var, verbose=True
    )
    return member_params


def generate_RO_ensemble(model, params, fit_to_members=False, **simulation_kwargs):
    """Generate an RO ensemble from a model and fitted parameters.

    Args:
        model: object providing ``simulate``; it is also handed to
            ``src.utils.generate_ensemble`` for per-member parameters.
        params: fitted parameters (as returned by the fitting step).
        fit_to_members: True if ``params`` were fit member by member, in
            which case ``src.utils.generate_ensemble`` runs the simulation;
            otherwise ``model.simulate`` is called directly.
        **simulation_kwargs: forwarded to whichever simulation routine is used.

    Returns:
        The simulated ensemble.
    """

    ## single parameter set: simulate directly with the model
    if not fit_to_members:
        return model.simulate(fit_ds=params, **simulation_kwargs)

    ## per-member parameter sets: delegate to the helper
    return src.utils.generate_ensemble(model, params, **simulation_kwargs)


def format_axs_cyc(axs):
    """Label the three-panel seasonal-cycle figure.

    Args:
        axs: sequence of at least three axes; the first three receive
            titles and legends, the first and third get y ticks at 0-3.
    """

    ## left panel: ticks and units on the left-hand side
    axs[0].set_yticks([0, 1, 2, 3])
    axs[0].set_ylabel(r"$^{\circ}$C")

    ## right panel: same ticks, drawn on the right-hand side
    axs[2].set_yticks([0, 1, 2, 3])
    axs[2].yaxis.set_ticks_position("right")

    ## one title and one small legend per panel
    panel_titles = (r"1850-1889", r"2051-2100", r"Early vs. Late (model)")
    for idx, title in enumerate(panel_titles):
        axs[idx].set_title(title)
    for idx in range(len(panel_titles)):
        axs[idx].legend(prop={"size": 8})


def format_axs_autocorr(axs):
    """Apply ticks, labels and titles to the grid of correlation panels.

    Args:
        axs: 2-D array of axes indexed as ``axs[row, col]`` with at least
            two rows and three columns.
    """

    ## bottom row: ticks at lags -12/0/12, labelled -1/0/1
    for bottom_ax in axs[1, :]:
        bottom_ax.set_xticks([-12, 0, 12], labels=[-1, 0, 1])
        bottom_ax.set_xlabel("Lag (years)")

    ## left column: correlation tick marks
    for left_ax in axs[:, 0]:
        left_ax.set_yticks([0, 0.5, 1])

    ## column titles go on the top row
    column_titles = (r"1850-1889", r"2051-2100", r"Early vs. Late (MPI)")
    for col, title in enumerate(column_titles):
        axs[0, col].set_title(title)

    ## row labels go on the left column
    row_labels = (r"Corr$(T,T)$ ", r"Corr$(T,h)$ ")
    for row, label in enumerate(row_labels):
        axs[row, 0].set_ylabel(label)


def plot_xcorr2(axs, xcorr_eval, T_var="T_3", h_var="h_w", **plot_kwargs):
    """Plot the ``T_var`` and ``h_var`` correlations on a pair of axes.

    Args:
        axs: indexable holding two axes; ``axs[0]`` receives ``T_var`` and
            ``axs[1]`` receives ``h_var``.
        xcorr_eval: correlation results, passed to
            ``src.utils.get_ensemble_stats`` before plotting.
        T_var: name of the T variable in the correlation results.
        h_var: name of the h variable in the correlation results.
        **plot_kwargs: forwarded to ``src.utils.plot_xcorr`` for both axes.
    """

    ## reduce the correlations to ensemble statistics
    stats = src.utils.get_ensemble_stats(xcorr_eval)

    ## one variable per axis, T first and h second
    for idx, name in enumerate((T_var, h_var)):
        src.utils.plot_xcorr(axs[idx], stats[name], **plot_kwargs)

## Load data

In [None]:
## load the index data
Th = src.utils.load_cesm_indices()

## shorten the long index names for convenience
short_names = dict(
    north_tropical_atlantic="natl",
    atlantic_nino="nino_atl",
    tropical_indian_ocean="iobm",
    indian_ocean_dipole="iod",
    north_pacific_meridional_mode="npmm",
    south_pacific_meridional_mode="spmm",
)
Th = Th.rename(short_names)

## Diagnose variance issues

### Model fitting

In [None]:
# variables entering the RO fit, and the order of the annual cycle
# (full list: ["T_34", "h", "iobm","iod","npmm","spmm","natl","nino_atl"])
varnames = ["T_34", "h"]
ac_order = 3

# True: fit parameters to each ensemble member individually
# False: fit a single set of parameters to all ensemble members
# NOTE(review): no statement in this cell reads this flag
fit_to_members = False

# keyword arguments forwarded to the fit
fit_kwargs = {"ac_mask_idx": None, "maskNT": []}

# select 600 time steps ending 12 steps before the end of the record
# (alternative, start of the record: isel(time=slice(12, 612)))
X = Th[varnames].isel(time=slice(-612, -12))

# set up the model
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## fit the model to the selected data
fit = model.fit_matrix(X[varnames], **fit_kwargs)

## extract the RO parameters from the fit
p = model.get_RO_parameters(fit)

## simulation settings; the initial state is the first time step of member 0
simulation_kwargs = {
    "nyear": 49,
    "ncopy": 1000,
    "seed": 1000,
    "X0_ds": X.isel(member=0, time=0),
    "is_xi_stdac": True,
}

## generate the synthetic ensemble
X_sim = model.simulate(fit_ds=fit, **simulation_kwargs)

### Diagnostics

#### Seasonal synch

In [None]:
### single small panel
fig, ax = plt.subplots(figsize=(2.5, 2), layout="constrained")

## compare the data (x0) against the RO simulation (x1) for T_34
seasonal_comp_kwargs = {
    "x0": X,
    "x1": X_sim,
    "plot_kwargs0": {"label": "CESM"},
    "plot_kwargs1": {"label": "RO"},
    "varname": "T_34",
}
plot_data_early = src.utils.plot_seasonal_comp(ax, **seasonal_comp_kwargs)

## cap the y-axis at 2; the lower limit stays automatic
ax.set_ylim([None, 2])

plt.show()

#### PSD

Compute

In [None]:
## specify which variable to use
## (X and X_sim only hold the variables listed in `varnames`, i.e. "T_34"
## and "h"; selecting "T_3" from them raises a KeyError)
varname = "T_34"

## specify args for psd
psd_kwargs = dict(dim="time", dt=1 / 12, nw=5)


def compute_psd(x):
    """Return the PSD of ``x[varname]`` using the ``psd_kwargs`` settings.

    ``varname`` and ``psd_kwargs`` are read from the notebook namespace at
    call time.
    """
    return src.XRO_utils.pmtm(x[varname], **psd_kwargs)


## compute PSD for the data and for the RO simulation
psd = compute_psd(X)
psd_RO = compute_psd(X_sim)

Plot

In [None]:
fig, ax = plt.subplots(figsize=(3.5, 2.5), layout="constrained")

## plot data vs. RO simulation
src.utils.plot_psd(ax, psd, label="CESM")
src.utils.plot_psd(ax, psd_RO, label="RO")

## label (a single legend call is enough; the duplicate was redundant)
ax.set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
ax.legend(prop=dict(size=6))

plt.show()

#### Autocorrelation

Compute

In [None]:
## specify T variable to use
## (X and X_sim only hold the variables listed in `varnames`, i.e. "T_34"
## and "h"; selecting "T_3" from them raises a KeyError)
T_var = "T_34"


def get_xcorr(x):
    """Cross-correlate every variable in ``x`` with ``x[T_var]``, lags up to 18.

    ``T_var`` is read from the notebook namespace at call time.
    """
    return src.XRO.xcorr(x, x[T_var], maxlags=18)


## compute
## NOTE: the name `xcorr` rebinds the function imported from src.XRO;
## get_xcorr is unaffected because it goes through `src.XRO.xcorr`
xcorr = get_xcorr(X)
xcorr_sim = get_xcorr(X_sim)

Plot

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(2.33, 4.5), layout="constrained")

## xcorr / xcorr_sim are derived from X / X_sim, whose variables are "T_34"
## and "h", so plot_xcorr2's default names ("T_3", "h_w") must be overridden
var_kwargs = dict(T_var="T_34", h_var="h")
plot_xcorr2(axs, xcorr, label="CESM", **var_kwargs)
plot_xcorr2(axs, xcorr_sim, label="RO", **var_kwargs)

## label
axs[0].legend(prop=dict(size=8))
axs[1].set_ylim([-0.9, 0.9])
axs[0].set_title(r"$T$")
axs[1].set_title(r"$h$")

plt.show()

## (old) Fit RO to ensemble data

In [None]:
# specify variables to use in RO fit and order of annual cycle
T_var = "T_3"
h_var = "h_w"
ac_order = 5

# specify whether to fit parameters to each ens. member individually
# if False, fit single set of parameters to all ensemble members
fit_to_members = False

# specify fit kwargs
# ac_mask_idx: which indices to mask annual cycle out for.
# maskNT: which nonlinear terms to include for dT/dt
linear_fit_kwargs = dict(ac_mask_idx=[(1, 1)], maskNT=[])
fit_kwargs = dict(ac_mask_idx=None, maskNT=[])
# alternative settings for fit_kwargs:
# fit_kwargs = dict(ac_mask_idx=[(1, 1)], maskNT=["T2","T3"])
# fit_kwargs = dict(ac_mask_idx=[(1,1)], maskNT=[])
# fit_kwargs = dict(ac_mask_idx=None, maskNT=["T2","TH","T3"])

# get data for early/late period
Th_early = Th.isel(time=slice(12, 612))
Th_late = Th.isel(time=slice(-588, None))

# specify model to use
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

# The fits below pass the `fit_to_members` flag (rather than a hard-coded
# False) so they stay consistent with the simulation step, which reads the
# same flag to decide how the parameters are used.

# fit model to early period
RO_params_early = fit_RO_to_ensemble(
    model,
    Th_data=Th_early,
    T_var=T_var,
    h_var=h_var,
    fit_to_members=fit_to_members,
    **fit_kwargs
)

# fit model to early period using the "linear" fit settings
RO_params_linear = fit_RO_to_ensemble(
    model,
    Th_data=Th_early,
    T_var=T_var,
    h_var=h_var,
    fit_to_members=fit_to_members,
    **linear_fit_kwargs
)

# fit model to late period
RO_params_late = fit_RO_to_ensemble(
    model,
    Th_data=Th_late,
    T_var=T_var,
    h_var=h_var,
    fit_to_members=fit_to_members,
    **fit_kwargs
)

## Get ensemble of RO simulations

In [None]:
## simulation settings; the initial state is the first time step of member 0
simulation_kwargs = {
    "nyear": 49,
    "ncopy": 1000,
    "seed": 1000,
    "X0_ds": Th[[T_var, h_var]].isel(member=0, time=0),
    "is_xi_stdac": True,
}

## run one simulation per fitted parameter set
RO_ensemble_early = generate_RO_ensemble(
    model,
    RO_params_early,
    fit_to_members=fit_to_members,
    **simulation_kwargs,
)
RO_ensemble_linear = generate_RO_ensemble(
    model,
    RO_params_linear,
    fit_to_members=fit_to_members,
    **simulation_kwargs,
)
RO_ensemble_late = generate_RO_ensemble(
    model,
    RO_params_late,
    fit_to_members=fit_to_members,
    **simulation_kwargs,
)


def filter_sim(output, T_var):
    """Drop ensemble members whose ``T_var`` magnitude ever exceeds 5.

    Args:
        output: simulation output with ``time`` and ``member`` dimensions.
        T_var: name of the variable to screen.

    Returns:
        ``output`` restricted to the members that were not flagged.
    """

    ## flag members where |T_var| > 5 at any time step
    exceeds_threshold = np.abs(output[T_var]) > 5
    keep = ~exceeds_threshold.any("time")

    return output.isel(member=keep)


## filter out simulations which blow up
## (RO_ensemble_linear is left unfiltered)
RO_ensemble_early = filter_sim(RO_ensemble_early, T_var)
RO_ensemble_late = filter_sim(RO_ensemble_late, T_var)

## Diagnostics

### Seasonality

In [None]:
### three panels side by side
fig, axs = plt.subplots(1, 3, figsize=(7, 2), layout="constrained")

## first two panels: model data vs. RO ensemble, early then late period
plot_data_early, plot_data_late = [
    src.utils.plot_seasonal_comp(
        panel,
        x0=reference,
        x1=emulated,
        plot_kwargs0={"label": "MPI"},
        plot_kwargs1={"label": "RO"},
        varname=T_var,
    )
    for panel, reference, emulated in (
        (axs[0], Th_early, RO_ensemble_early),
        (axs[1], Th_late, RO_ensemble_late),
    )
]

## third panel: early vs. late RO ensembles
## (swap in x0=Th_early, x1=Th_late to compare the model data instead)
plot_data_delta = src.utils.plot_seasonal_comp(
    axs[2],
    x0=RO_ensemble_early,
    x1=RO_ensemble_late,
    plot_kwargs0={"label": "1850-1889", "c": "k"},
    plot_kwargs1={"label": "2051-2100", "c": "r"},
    varname=T_var,
)

## titles, ticks and legends
format_axs_cyc(axs)

## cap every panel at 2; lower limits stay automatic
for ax in axs:
    ax.set_ylim([None, 2])

plt.show()

In [None]:
## data to plot, restricted to a single calendar month
Th_ = Th_early

m = 12
t_idx = dict(time=Th_.time.dt.month == m)

## x: difference between "h" and "h_w"; y: "T_3"
h_minus_hw = Th_["h"].isel(t_idx) - Th_["h_w"].isel(t_idx)
T_vals = Th_["T_3"].isel(t_idx)

fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(h_minus_hw, T_vals, s=1)

## axis labels
ax.set_xlabel(r"$\overline{h}-h_w$")
ax.set_ylabel(r"$T_e$")

## thin reference lines through the origin, drawn behind the points
kwargs = dict(c="k", lw=0.8, zorder=0.5)
ax.axvline(0, **kwargs)
ax.axhline(0, **kwargs)
plt.show()

In [None]:
## extract RO parameters from the early-period (p0) and late-period (p1) fits
p0 = model.get_RO_parameters(RO_params_early)
p1 = model.get_RO_parameters(RO_params_late)

In [None]:
c = sns.color_palette()

months = np.arange(1, 13)
fig, ax = plt.subplots(figsize=(4, 3))

## (early-period curve, late-period curve, colour), drawn in this order
series = [
    (
        RO_params_early["normxi_stdac"].isel(ranky=0),
        RO_params_late["normxi_stdac"].isel(ranky=0),
        c[0],
    ),
    (p0["R"], p1["R"], c[1]),
    (-p0["epsilon"], -p1["epsilon"], c[2]),
    (p0["BJ_ac"], p1["BJ_ac"], "k"),
]

## early period uses the default line style, late period is dashed
for early, late, color in series:
    ax.plot(months, early, c=color)
    ax.plot(months, late, c=color, ls="--")

## zero reference line
ax.axhline(0, ls="--", c="k", lw=1)

### $T$, $h$ cross-correlation

Plotting functions

Compute cross-correlation stats

In [None]:
## T variable that every other variable is correlated against
T_var = "T_3"


def get_xcorr(x):
    """Cross-correlate every variable in ``x`` with ``x[T_var]``, lags up to 18.

    ``T_var`` is read from the notebook namespace at call time.
    """
    return src.XRO.xcorr(x, x[T_var], maxlags=18)


## model data, early and late periods
xcorr_Th_early = get_xcorr(Th_early)
xcorr_Th_late = get_xcorr(Th_late)

## RO ensembles, early and late periods
xcorr_RO_early = get_xcorr(RO_ensemble_early)
xcorr_RO_late = get_xcorr(RO_ensemble_late)

Make plot

In [None]:
fig, axs = plt.subplots(2, 3, figsize=(7, 4.5), layout="constrained")

## (column, correlation data, plot options), drawn in this order:
## column 0 = early period, column 1 = late period, column 2 = early vs. late
xcorr_curves = [
    (0, xcorr_Th_early, {"label": "model"}),
    (0, xcorr_RO_early, {"label": "RO"}),
    (1, xcorr_Th_late, {}),
    (1, xcorr_RO_late, {}),
    (2, xcorr_Th_early, {"c": "k", "label": "1850-1889"}),
    (2, xcorr_Th_late, {"c": "r", "label": "2051-2100"}),
]
for col, curve, options in xcorr_curves:
    plot_xcorr2(axs[:, col], curve, **options)

## titles, ticks and axis labels, then legends on the labelled top panels
format_axs_autocorr(axs)
for col in (0, 2):
    axs[0, col].legend(prop={"size": 8})

## bottom row: fix lower limit at -1, tick at +/-0.8, mark the -0.8 level
for ax in axs[1, :]:
    ax.set_ylim([-1, None])
    ax.set_yticks([-0.8, 0, 0.8])
    ax.axhline(-0.8)

plt.show()

### Power spectrum

Plotting functions

Compute PSD

In [None]:
## variable whose spectrum is computed
varname = "T_3"

## settings forwarded to the PSD routine
psd_kwargs = {"dim": "time", "dt": 1 / 12, "nw": 5}


def compute_psd(x):
    """Return the PSD of ``x[varname]`` using the ``psd_kwargs`` settings.

    ``varname`` and ``psd_kwargs`` are read from the notebook namespace at
    call time.
    """
    return src.XRO_utils.pmtm(x[varname], **psd_kwargs)


## early period: model data, then RO ensemble
psd_mpi_early = compute_psd(Th_early)
psd_RO_early = compute_psd(RO_ensemble_early)

## late period: model data, then RO ensemble
psd_mpi_late = compute_psd(Th_late)
psd_RO_late = compute_psd(RO_ensemble_late)

Plot PSD

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(7, 2.5), layout="constrained")

## (panel, spectrum, plot options), drawn in this order:
## panel 0 = early period, panel 1 = late period, panel 2 = early vs. late
psd_curves = [
    (0, psd_mpi_early, {"label": "MPI"}),
    (0, psd_RO_early, {"label": "RO"}),
    (1, psd_mpi_late, {"label": "MPI"}),
    (1, psd_RO_late, {"label": "RO"}),
    (2, psd_mpi_early, {"label": "1850-1889", "color": "k"}),
    (2, psd_mpi_late, {"label": "2051-2100", "color": "r"}),
]
for panel, spectrum, options in psd_curves:
    src.utils.plot_psd(axs[panel], spectrum, **options)

## y label on the first panel only; legends on the first and last panels
axs[0].set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
for panel in (0, 2):
    axs[panel].legend(prop={"size": 6})

## hide y ticks on every panel except the first
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()