# RO reference
Fit RO on observations to get "ground truth" behavior

## Imports

In [None]:
import warnings
import copy
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats

# Import custom modules
import src.XRO
import src.XRO_utils
import src.utils

## set plotting specs: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure resolution (dots per inch)
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables (KeyError if either is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## get RNG (no seed is passed, so draws differ from run to run)
rng = np.random.default_rng()

#### Funcs

In [None]:
def get_dCdt(C0_cs):
    """Estimate the time derivative of a cyclostationary covariance matrix.

    Uses centered finite differences along the last axis, which is treated as
    periodic (the step after the last one wraps around to the first).

    Args:
        C0_cs: object exposing a writable ``.values`` numpy array whose last
            axis indexes position within the cycle (e.g. an ``xr.DataArray``
            with the 12 calendar months on its last dimension).

    Returns:
        Deep copy of ``C0_cs`` whose ``.values`` hold the derivative. The
        cycle is assumed to span one year, so the spacing between adjacent
        steps is ``1 / n_cycle`` years.
    """

    C = C0_cs.values

    ## distance between the two neighbors used in the centered difference
    n_cycle = C.shape[-1]
    dt = 2 * 1 / n_cycle

    ## neighbors at t+1 and t-1, wrapping around the ends of the cycle
    C_plus = np.roll(C, -1, axis=-1)
    C_minus = np.roll(C, 1, axis=-1)

    ## write the finite difference into a copy so the input is left untouched
    dCdt = copy.deepcopy(C0_cs)
    dCdt.values[...] = 1 / dt * (C_plus - C_minus)

    return dCdt


def get_Qcs(Lcs, C0_cs, dCdt):
    """Get the cyclostationary noise covariance from the covariance budget.

    For every step ``t`` along the last axis, evaluates
    ``Q(t) = dCdt(t) - (L(t) @ C0(t) + C0(t) @ L(t).T)``.

    Args:
        Lcs: object exposing ``.values``; slicing the last axis must give a
            2-D matrix (the linear operator at that step of the cycle).
        C0_cs: object exposing ``.values`` with the covariance matrix at each
            step, laid out like ``Lcs``.
        dCdt: object exposing ``.values`` with the time derivative of the
            covariance at each step (e.g. the output of ``get_dCdt``).

    Returns:
        Deep copy of ``Lcs`` whose ``.values`` are overwritten with ``Q``.
    """

    L_all = Lcs.values
    C_all = C0_cs.values
    dC_all = dCdt.values

    ## result inherits the container (and any metadata) of the operator
    Qcs = copy.deepcopy(Lcs)

    for t_idx in range(L_all.shape[-1]):

        L_t = L_all[..., t_idx]
        C_t = C_all[..., t_idx]

        ## deterministic tendency of the covariance at this step
        drift = L_t @ C_t + C_t @ L_t.T

        ## whatever the dynamics don't explain is attributed to the noise
        Qcs.values[..., t_idx] = dC_all[..., t_idx] - drift

    return Qcs


def SEMIPRO(
    L,
    Cnn,
    dt,
    tf,
    seed,
    IC=None,
):
    """
    A 'simplified' version of the PRO model: a 2-variable linear system with a
    time-dependent operator, forced by additive white noise.

    Integrated with a forward Euler (Euler-Maruyama) step,
    ``Y[i+1] = Y[i] + dt * L(t[i]) @ Y[i] + M @ dW``, with
    ``dW = sqrt(dt) * N(0, 1)``. The output is then downsampled to monthly
    means and the first 360 months (30 years) are discarded as spinup.

    Args:
        L: callable ``L(t)`` returning the linear operator at time ``t``.
        Cnn: noise matrix, passed to ``get_M_np`` to build the matrix ``M``
            that multiplies the white-noise increments.
        dt: timestep (unit: years)
        tf: final time (unit: years)
        seed: seed for the random number generator used inside the function.
        IC: initial condition (length-2 array, representing T and h at the
            initial time). If None, it is drawn from a standard normal using
            the seeded generator, so the whole run is reproducible.

    Returns:
        Tuple ``(Y, t)``: monthly-mean state and matching time labels, with
        the first 360 months removed.
    """

    ## initialize rng (local generator, so results depend only on `seed`)
    rng = np.random.default_rng(seed=seed)

    ## draw the initial condition here rather than in the signature: a default
    ## in the signature is evaluated once at definition time and shared by
    ## every call
    if IC is None:
        IC = rng.normal(size=2)

    ## get noise matrix
    ## NOTE(review): `get_M_np` is not defined in the visible part of this
    ## file; presumably it returns M with M @ M.T == Cnn -- confirm
    M = get_M_np(Cnn)

    ## Get timesteps and dimensions for output
    t = np.arange(0, tf, dt)
    nt = len(t)

    ## Create empty arrays to hold simulation output
    Y = np.zeros([2, nt])  # state vector
    xi = np.zeros([2, nt])  # noise

    ## initialize for first timestep
    Y[:, 0] = IC
    dW = np.sqrt(dt) * rng.normal(size=2)
    xi[:, 0] = (M @ dW).squeeze()

    ## Integrate
    for i, t_ in enumerate(tqdm.tqdm(t[:-1])):

        ## compute steps
        dx = dt * L(t_) @ Y[:, i]
        Y[:, i + 1] = Y[:, i] + dx.squeeze() + xi[:, i]

        ## generate next noise term
        dW = np.sqrt(dt) * rng.normal(size=2)
        xi[:, i + 1] = (M @ dW).squeeze()

    ## downsample to monthly
    Y, t = get_monthly(Y, t, dt_years=dt)

    ## remove 30 years of spinup
    n_spinup_months = 360

    return Y[:, n_spinup_months:], t[n_spinup_months:]


def L(t, F=2):
    """Time-varying linear operator for the 2-variable model.

    Args:
        t: time; the first diagonal entry is modulated by ``cos(2 * pi * t)``,
            i.e. with a period of one time unit (presumably years, matching
            the ``dt`` units of ``SEMIPRO``).
        F: magnitude of the off-diagonal coupling terms.

    Returns:
        2x2 numpy array ``[[-0.5 + cos(2 * pi * t), F], [-F, -0.3]]``.
    """

    ## periodic modulation of the first diagonal entry
    R_cos = np.cos(t * 2 * np.pi)

    ## add to stationary part (named `operator` so it doesn't shadow `L`)
    operator = np.array([[-0.5 + R_cos, F], [-F, -0.3]])

    return operator


## Get monthly averages
def expand_month_dim(X, dt_years):
    """Separate time dimension into timestep-within-month and month.

    Args:
        X: array whose first axis is kept as-is and whose remaining entries
            are the time samples (callers in this file pass shape
            ``(n_vars, n_time)``).
        dt_years: timestep of ``X`` in years.

    Returns:
        Array of shape ``(X.shape[0], timesteps_per_month, n_months)`` where
        ``out[v, s, m] == X[v, m * timesteps_per_month + s]``.

    Raises:
        ValueError: if a month is not a whole number of timesteps, or (from
            ``reshape``) if the number of samples is not a multiple of the
            number of timesteps per month.
    """

    ## Get timestep in months
    months_per_year = 12
    dt_months = dt_years * months_per_year

    ## Get number of timesteps in month. Round instead of truncating: floating
    ## point error can make 1 / dt_months land just below the intended integer
    steps_exact = 1 / dt_months
    timesteps_per_month = int(round(steps_exact))
    if timesteps_per_month < 1 or not np.isclose(steps_exact, timesteps_per_month):
        raise ValueError(
            f"dt_years={dt_years} does not divide a month into a whole "
            "number of timesteps"
        )

    ## Fortran order keeps consecutive timesteps together within a month
    X_reshape = X.reshape(X.shape[0], timesteps_per_month, -1, order="F")

    return X_reshape


def get_monthly(Y, t, dt_years):
    """Downsample PRO output to monthly means.

    Args:
        Y: state array with time on the last axis (``SEMIPRO`` passes shape
            ``(2, n_time)``).
        t: 1-D array of times (years) matching the last axis of ``Y``.
        dt_years: timestep of ``Y`` and ``t`` in years.

    Returns:
        Tuple ``(Y_monthly, t_monthly)``: the mean of ``Y`` over the timesteps
        in each month, and a 1-D array with one time label per month (first
        timestep of the month plus half a month).
    """

    ## Reshape data to (variable, timestep in month, month)
    Y_by_month = expand_month_dim(Y, dt_years=dt_years)
    t_by_month = expand_month_dim(t[None, :], dt_years=dt_years)

    ## Get monthly mean output
    Y_monthly = Y_by_month.mean(1)

    ## label with midpoint of month. Index explicitly instead of squeezing:
    ## squeeze() would also drop the timestep or month axis when it has
    ## length 1, turning the result into a scalar
    dt_years_new = 1 / 12  # new timestep after downsampling
    t_monthly = t_by_month[0, 0, :] + dt_years_new * 1 / 2

    return Y_monthly, t_monthly

## Load data

In [None]:
## use XRO indices: open the ORAS5 index file stored under DATA_FP
idx = xr.open_dataset(DATA_FP / "XRO_indices_oras5.nc")

## Fit RO models

In [None]:
## variables used for the larger (4-variable) fit
other_vars = ["Nino34", "WWV", "NPMM", "SPMM"]

## specify order of annual cycle, mask parameters
ac_order = 3
ac_mask_idx = None

## initialize model (ncycle=12: presumably 12 steps per annual cycle)
model = src.XRO.XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## get fit for reanalysis: `fit` uses only Nino34 and WWV, while `fit_all`
## uses every variable listed in `other_vars`
fit = model.fit_matrix(idx[["Nino34", "WWV"]], ac_mask_idx=ac_mask_idx)
fit_all = model.fit_matrix(idx[other_vars], ac_mask_idx=ac_mask_idx)

## extract params
## NOTE(review): `p`, `p_all` and `fit_all` are not referenced again in the
## visible cells -- confirm whether they are still needed
p = model.get_RO_parameters(fit)
p_all = model.get_RO_parameters(fit_all)

In [None]:
## stack the two indices into a single array with variables along "ranky"
XY = idx[["Nino34", "WWV"]].to_dataarray().rename({"variable": "ranky"})
C0 = src.XRO.cov_xr(XY).values

## cycle-mean operator and the matching stationary noise covariance,
## Q = -(L C0 + C0 L^T).
## Stored as `L_mean` (not `L`) so the function `L(t, F)` defined above is not
## overwritten at module level.
L_mean = fit.Lac.mean("cycle").values
Q = -(L_mean @ C0 + C0 @ L_mean.T)

## try a different way: least-squares one-step propagator G with Y ~ G @ X
X = XY.values[:, :-1]
Y = XY.values[:, 1:]
G = Y @ np.linalg.pinv(X)
lam, V = np.linalg.eig(G)

## reconstruct L from the eigen-decomposition of G (factor of 12 presumably
## converts from per-month to per-year, assuming monthly samples)
gamma = 12 * np.log(lam)
L2 = np.real(V @ np.diag(gamma) @ np.linalg.inv(V))
Q2 = -(L2 @ C0 + C0 @ L2.T)

## cyclostationary covariance: one covariance matrix per calendar month,
## relabelled to use the fit's "cycle" coordinate and moved to the last axis
C0_cs = XY.groupby("time.month").map(src.XRO.cov_xr)
C0_cs = C0_cs.rename({"month": "cycle"}).assign_coords({"cycle": fit.cycle})
C0_cs = C0_cs.transpose(..., "cycle")

## get change in covariance
dCdt = get_dCdt(C0_cs)

## get cyclostationary noise
Qcs = get_Qcs(fit["Lac"], C0_cs, dCdt)

## get copy of fit with updated noise cov
fit_aug = copy.deepcopy(fit)
fit_aug["xi_covac"] = Qcs.transpose("ranky", "cycle", ...).rename({"rankx": "ranky_"})

#### Interlude: noise covariance

In [None]:
# ## specify noise matrix
# Cnn = np.diag([1, 0.8])

# ## add noise
# Cnn_corr = copy.deepcopy(Cnn)
# rho = -0.5
# cov = rho * np.prod(np.diag(Cnn))
# Cnn_corr[0, 1] = cov
# Cnn_corr[1, 0] = cov

# ## integrate
# kwargs = dict(L=L, dt=1 / (12 * 90), tf=1200, seed=1000)
# Y, t = SEMIPRO(Cnn=Cnn, **kwargs)
# Y_corr, _ = SEMIPRO(Cnn=Cnn_corr, **kwargs)

# ## reshape to (variable, year, month)
# Y = Y.reshape(2, -1, 12)
# Y_corr = Y_corr.reshape(2, -1, 12)

# ## variance
# Y_var = Y.var(1)
# Y_corr_var = Y_corr.var(1)

# fig, ax = plt.subplots(figsize=(3, 2.5))
# ax.plot(Y_var[1])
# ax.plot(Y_corr_var[1])
# ax.set_ylim([0, None])
# plt.show()

## Check stats

Generate simulations

In [None]:
## specify random IC: pick one time index at random from the observed record
x0 = idx.isel(time=rng.choice(np.arange(len(idx.time))))

## simulation specs (shared by every call to `model.simulate` below)
sim_kwargs = dict(
    nyear=63,
    ncopy=1000,
    is_xi_stdac=True,
    # use_noise_cov=True,
    X0_ds=x0[["Nino34", "WWV"]],
)

## do simulations: same fit and settings, differing only in noise_type
## NOTE(review): `X_all` is simulated from the 2-variable `fit`, not from
## `fit_all` -- confirm this is intended
X = model.simulate(fit_ds=fit, noise_type="white", **sim_kwargs)
# X1 = model.simulate(fit_ds=fit, noise_type="white", use_noise_cov=True, **sim_kwargs)
# X1 = model.simulate(fit_ds=fit_aug, noise_type="white", use_noise_cov=False, **sim_kwargs)
X_all = model.simulate(fit_ds=fit, noise_type="red", **sim_kwargs)
# X1 = model.simulate(fit_ds=fit_aug, **sim_kwargs)

#### Seasonal synchronization

In [None]:
### Set up plot
fig, axs = plt.subplots(1, 2, figsize=(4.5, 2), layout="constrained")

## left panel: observations vs. RO simulated with white noise (`X`)
plot_data_early = src.utils.plot_seasonal_comp(
    axs[0],
    x0=idx.expand_dims("member"),
    x1=X,
    plot_kwargs0=dict(label="ORAS"),
    plot_kwargs1=dict(label="RO"),
    varname="Nino34",
    use_quantile=True,
)

## right panel: observations vs. RO simulated with red noise (`X_all`)
plot_data_early = src.utils.plot_seasonal_comp(
    axs[1],
    x0=idx.expand_dims("member"),
    x1=X_all,
    plot_kwargs0=dict(label="ORAS"),
    plot_kwargs1=dict(label="RO"),
    varname="Nino34",
    use_quantile=True,
)

## label: both panels show Niño 3.4 from the same fit, so the titles
## identify the noise type, which is the only thing that differs
axs[1].legend(prop=dict(size=8))
axs[0].set_yticks([0, 1, 2])
axs[0].set_ylabel(r"$\sigma(T)$")
axs[0].set_title(r"Niño 3.4 (white noise)")
axs[1].set_title(r"Niño 3.4 (red noise)")
for ax in axs:
    ax.set_ylim([0, 2])
    ax.set_xticks([1, 4, 12], labels=["Jan", "Apr", "Dec"])

plt.show()

#### Power spectrum

Compute

In [None]:
## keyword arguments forwarded to every `pmtm` call below
psd_kwargs = {"dim": "time", "dt": 1 / 12, "nw": 5}


def compute_psd(x):
    """Return `src.XRO_utils.pmtm(x)` evaluated with the shared `psd_kwargs`."""
    return src.XRO_utils.pmtm(x, **psd_kwargs)


## spectra of Niño 3.4: observations, then the two sets of simulations
psd_oras = compute_psd(idx["Nino34"])
psd_RO_h = compute_psd(X["Nino34"])
psd_RO_hw = compute_psd(X_all["Nino34"])

Plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3), layout="constrained")

## left panel: observations vs. RO simulated with white noise
src.utils.plot_psd(axs[0], psd_oras, label=r"ORAS", color="k")
src.utils.plot_psd(axs[0], psd_RO_h, label=r"RO", color=sns.color_palette()[1])

## right panel: observations vs. RO simulated with red noise
src.utils.plot_psd(axs[1], psd_oras, label=r"ORAS", color="k")
src.utils.plot_psd(axs[1], psd_RO_hw, label=r"RO", color=sns.color_palette()[1])

## label: every spectrum here is computed from "Nino34"; the panels differ
## only in the noise type used for the simulation
axs[0].set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
axs[0].set_title(r"Niño 3.4 spectrum (white noise)")
axs[1].set_title(r"Niño 3.4 spectrum (red noise)")
axs[0].legend(prop=dict(size=6))
axs[1].legend(prop=dict(size=6))
# for ax in axs:
#     ax.set_ylim([None,None])

plt.show()

## Cross-correlation

In [None]:
def format_xcorr_ax(ax):
    """Apply the shared styling for cross-correlation plots to ``ax``.

    Draws thin reference lines through zero on both axes, fixes the y-range
    and y-ticks, and places x-ticks at multiples of 12 labelled as whole
    years.
    """

    ## faint black guide lines through zero correlation / zero lag
    guide_style = {"c": "k", "lw": 0.5, "alpha": 0.5}
    for add_guide in (ax.axhline, ax.axvline):
        add_guide(0, **guide_style)

    ## x-axis: ticks every 12 lag units, shown as years
    lag_ticks = [-24, -12, 0, 12, 24]
    ax.set_xlabel("Lag (years)")
    ax.set_xticks(lag_ticks, labels=[tick // 12 for tick in lag_ticks])

    ## y-axis: correlation range with a little headroom above 1
    ax.set_ylim([-0.9, 1.1])
    ax.set_yticks([-0.5, 0, 0.5, 1])
    ax.set_ylabel("Correlation")

In [None]:
## name of the temperature index every variable is correlated against
T_var = "Nino34"


def get_xcorr(x):
    """Cross-correlate each variable in `x` with `x[T_var]` (maxlags=18)."""
    return src.XRO.xcorr(x, x[T_var], maxlags=18)


## observations first, then the two sets of simulations
xcorr = get_xcorr(idx)
xcorr0 = get_xcorr(X)
xcorr1 = get_xcorr(X_all)

In [None]:
## set up figure
fig, ax = plt.subplots(figsize=(3.5, 3))

## for each variable: observations in black, then the member-mean of the
## two sets of simulations
for varname in ("Nino34", "WWV"):
    ax.plot(xcorr.lag, xcorr[varname], label="actual", c="k")
    ax.plot(xcorr.lag, xcorr0[varname].mean("member"), label="first")
    ax.plot(xcorr.lag, xcorr1[varname].mean("member"), label="second")

## title, legend and shared axis styling
ax.set_title("Corr. with Niño 3.4")
ax.legend(prop=dict(size=8))
format_xcorr_ax(ax)

plt.show()