# RO reference
Fit the recharge oscillator (RO) model to observations (ORAS5 reanalysis indices) to get "ground truth" behavior.

## Imports

In [None]:
import warnings
import copy
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats

# Import custom modules
import src.XRO
import src.XRO_utils
import src.utils

## plotting defaults: white axes background, grid switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## input / output directories, read from environment variables of the same name
DATA_FP, SAVE_FP = (
    pathlib.Path(os.environ[env_name]) for env_name in ("DATA_FP", "SAVE_FP")
)

## random number generator (no seed given, so draws differ between runs)
rng = np.random.default_rng()

## Load data

In [None]:
## open the ORAS5 XRO index file from the data directory
idx = xr.open_dataset(DATA_FP.joinpath("XRO_indices_oras5.nc"))

## Fit RO models

In [None]:
## variable list for the larger model (the smaller one uses only Nino34 and WWV)
other_vars = ["Nino34", "WWV", "NPMM", "SPMM"]

## annual-cycle order and parameter mask passed to the fit
ac_order = 3
ac_mask_idx = None

## build the model object
model = src.XRO.XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## fit both variable sets to the reanalysis indices (2-variable fit first)
fit, fit_all = (
    model.fit_matrix(idx[names], ac_mask_idx=ac_mask_idx)
    for names in (["Nino34", "WWV"], other_vars)
)

## pull the RO parameters out of each fit (same order as the fits)
p, p_all = map(model.get_RO_parameters, (fit, fit_all))

#### Interlude: noise covariance

In [None]:
def cov_xr(data, outer_dim = "ranky", cov_dim="time"):
    """Covariance of `data` over `cov_dim`, normalized by the sample count.

    The result pairs `outer_dim` with a renamed copy of itself,
    `f"{outer_dim}_"`; `cov_dim` is summed out.
    """

    ## anomalies relative to the mean over cov_dim
    anom = data - data.mean(cov_dim)

    ## same anomalies with outer_dim renamed, so the product below spans
    ## every (outer_dim, outer_dim_) pair
    anom_t = anom.rename({outer_dim: f"{outer_dim}_"})

    ## sum the pairwise products over samples and scale by 1/n
    n_samples = len(data[cov_dim])
    return 1 / n_samples * (anom * anom_t).sum(cov_dim)

def get_M_helper(Cxx, dim = "ranky"):
    """Build M from the SVD of Cxx so that M @ M.T reproduces Cxx.

    NOTE(review): the identity relies on Cxx being symmetric positive
    semi-definite (as a covariance matrix should be) -- confirm for callers.
    `dim` is accepted but not used.
    """

    ## left singular vectors and singular values of the matrix
    left_vecs, sing_vals, _ = np.linalg.svd(Cxx.values)

    ## NaN-filled container shaped like Cxx for the result
    factor = np.nan * xr.zeros_like(Cxx)

    ## scale the singular vectors by the square roots of the singular values
    factor.values = left_vecs @ np.diag(np.sqrt(sing_vals))

    return factor

def get_M(data, outer_dim = "ranky", cov_dim="time"):
    """Return the factor M of the covariance of `data` (M @ M.T = Cxx).

    The covariance is taken over `cov_dim` with `cov_xr`, then factored
    with `get_M_helper`.
    """

    return get_M_helper(
        cov_xr(data, outer_dim=outer_dim, cov_dim=cov_dim),
        dim=outer_dim,
    )

Test that it works

In [None]:
## copy the fit's X array and tag every time step with a 0-11 cycle index
## (assumes monthly samples with the first one at cycle 0 -- TODO confirm)
X_ = copy.deepcopy(fit.X)

## look up the length of "time" by name rather than by axis position, so the
## cell does not depend on the dimension order of X
X_ = X_.assign_coords(cycle=("time", np.arange(X_.sizes["time"]) % 12))

## kwargs for cov
kwargs = dict(outer_dim="rankx")

## covariance, and its factor M, for each cycle value
## (groupby.map forwards extra keyword arguments to the mapped function)
Cxx = X_.groupby("cycle").map(cov_xr, **kwargs)
M = X_.groupby("cycle").map(get_M, **kwargs)

## check that M @ M.T reproduces the covariance for one cycle index
mi = 7
M_ = M.isel(cycle=mi).values
Cxx_ = Cxx.isel(cycle=mi).values
print(np.allclose(M_ @ M_.T, Cxx_))

In [None]:
## display the per-cycle covariance array
## NOTE(review): this cell originally held an unfinished expression (`Cxx[[`),
## which is a syntax error; the intended selection is unknown
Cxx

In [None]:
## call gen_noise on the index-0 slice along the LAST axis of the covariance array
## NOTE(review): confirm the last axis is the intended one to slice -- a single
## cycle's covariance matrix would be selected with Cxx.isel(cycle=...) instead
src.XRO.gen_noise(Cxx.values[...,0])

## Check stats

Generate simulations

In [None]:
## initial condition: the indices at one randomly drawn time step
x0 = idx.isel(time=rng.choice(len(idx.time)))

## settings shared by both simulations
simulation_kwargs = {"nyear": 63, "ncopy": 1000, "is_xi_stdac": True}

## simulate with the 2-variable fit
kwargs_h = {**simulation_kwargs, "fit_ds": fit, "X0_ds": x0[["Nino34", "WWV"]]}
X = model.simulate(**kwargs_h)

## simulate with the fit that uses every variable in other_vars
kwargs_hw = {**simulation_kwargs, "fit_ds": fit_all, "X0_ds": x0[other_vars]}
X_all = model.simulate(**kwargs_hw)

#### Seasonal synchronization

In [None]:
### Set up plot
fig, axs = plt.subplots(1, 2, figsize=(4.5, 2), layout="constrained")

## one panel per simulation: 2-variable fit (left), 4-variable fit (right).
## Both panels compare Nino34 in the reanalysis indices (given a singleton
## "member" dim) against the simulated ensemble.
## plot_data_early ends up holding the right-hand panel's return value,
## as it did when the two calls were written out separately.
for ax, sim in zip(axs, (X, X_all)):
    plot_data_early = src.utils.plot_seasonal_comp(
        ax,
        x0=idx.expand_dims("member"),
        x1=sim,
        plot_kwargs0=dict(label="ORAS"),
        plot_kwargs1=dict(label="RO"),
        varname="Nino34",
        use_quantile=True,
    )

## label
axs[1].legend(prop=dict(size=8))
axs[0].set_yticks([0, 1, 2])
axs[0].set_ylabel(r"$\sigma(T)$")
axs[0].set_title(r"Niño 3.4, $h=\overline{h}$")

## the right panel also shows Nino34 and uses the same WWV variable; the
## difference from the left panel is the extra NPMM/SPMM variables in the fit
axs[1].set_title(r"Niño 3.4, +NPMM/SPMM")
for ax in axs:
    ax.set_ylim([0, 2])
    ax.set_xticks([1, 4, 12], labels=["Jan", "Apr", "Dec"])

plt.show()

#### Power spectrum

Compute

In [None]:
## settings forwarded to the PSD routine
psd_kwargs = {"dim": "time", "dt": 1 / 12, "nw": 5}


def compute_psd(x):
    """Run src.XRO_utils.pmtm on `x` with the settings in psd_kwargs."""
    return src.XRO_utils.pmtm(x, **psd_kwargs)


## Nino34 spectra: reanalysis, 2-variable simulation, 4-variable simulation
psd_oras = compute_psd(idx["Nino34"])
psd_RO_h = compute_psd(X["Nino34"])
psd_RO_hw = compute_psd(X_all["Nino34"])

Plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3), layout="constrained")

## left: Nino34 spectrum, reanalysis vs. simulation from the 2-variable fit
src.utils.plot_psd(axs[0], psd_oras, label=r"ORAS", color="k")
src.utils.plot_psd(axs[0], psd_RO_h, label=r"RO", color=sns.color_palette()[1])

## right: Nino34 spectrum, reanalysis vs. simulation from the 4-variable fit
src.utils.plot_psd(axs[1], psd_oras, label=r"ORAS", color="k")
src.utils.plot_psd(axs[1], psd_RO_hw, label=r"RO", color=sns.color_palette()[1])

## label
axs[0].set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
axs[0].set_title(r"Niño 3.4 spectrum (using $h=\overline{h}$)")

## both panels show Nino34 spectra; the right-hand simulation differs only in
## that its fit also includes NPMM and SPMM
axs[1].set_title(r"Niño 3.4 spectrum (+NPMM/SPMM)")
for ax in axs:
    ax.legend(prop=dict(size=6))

plt.show()

In [None]:
## overlay the "normxi_stdac" field of the two fits at the first ranky index
## (presumably the Nino34 entry, since Nino34 is listed first in both fits -- TODO confirm)
plt.plot(fit["normxi_stdac"].isel(ranky=0))
plt.plot(fit_all["normxi_stdac"].isel(ranky=0))

## Cross-correlation

In [None]:
def format_xcorr_ax(ax):
    """Apply the shared cross-correlation styling to `ax`.

    Draws faint zero lines, fixes the y-range and y-ticks, and labels the
    lag ticks in years (ticks sit at multiples of 12, so lag is presumably
    in months -- TODO confirm).
    """

    ## faint reference lines through zero on both axes
    zero_line_style = {"c": "k", "lw": 0.5, "alpha": 0.5}
    ax.axhline(0, **zero_line_style)
    ax.axvline(0, **zero_line_style)

    ## fixed correlation range
    ax.set_ylim([-0.9, 1.1])

    ## lag axis: one tick every 12 lag units, labelled as whole years
    ax.set_xlabel("Lag (years)")
    year_labels = [-2, -1, 0, 1, 2]
    ax.set_xticks([12 * year for year in year_labels], labels=year_labels)

    ## correlation axis ticks and label
    ax.set_yticks([-0.5, 0, 0.5, 1])
    ax.set_ylabel("Correlation")

In [None]:
## name of the temperature index used as the reference series
T_var = "Nino34"


def get_xcorr(x):
    """Call src.XRO.xcorr on `x` against x[T_var] with maxlags=18."""
    return src.XRO.xcorr(x, x[T_var], maxlags=18)


## reanalysis, 2-variable simulation, 4-variable simulation
xcorr, xcorr0, xcorr1 = map(get_xcorr, (idx, X, X_all))

In [None]:
## plot result
fig, ax = plt.subplots(figsize=(3.5, 3))

## one colour per data source, one line style per variable, so that every
## legend entry is unique (simulated curves are averaged over "member")
palette = sns.color_palette()
for varname, ls in (("Nino34", "-"), ("WWV", "--")):
    ax.plot(xcorr.lag, xcorr[varname], label=f"ORAS, {varname}", c="k", ls=ls)
    ax.plot(
        xcorr.lag,
        xcorr0[varname].mean("member"),
        label=f"RO (2-var), {varname}",
        c=palette[0],
        ls=ls,
    )
    ax.plot(
        xcorr.lag,
        xcorr1[varname].mean("member"),
        label=f"RO (4-var), {varname}",
        c=palette[1],
        ls=ls,
    )

## format plot
ax.set_title("Corr. with Niño 3.4")
ax.legend(prop=dict(size=8))
format_xcorr_ax(ax)

plt.show()