In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import warnings
import tqdm
import pathlib
import cmocean
import os
import scipy.stats
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr
import src.XRO_utils

## set plotting specs: seaborn theme with a white axes background and no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set the figure resolution (dots per inch) used for the figures below
mpl.rcParams["figure.dpi"] = 100

## get filepaths from the environment; both DATA_FP and SAVE_FP must be set,
## otherwise the lookup raises KeyError
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## funcs

In [None]:
def get_cov(X, Y=None):
    """Covariance between the variables of X and those of Y (or of X itself).

    Each dataset is stacked into a single array with one entry per data
    variable: X along a new "ranky" dimension and Y along a new "rankx"
    dimension. When Y is omitted, the stacked X (with "ranky" renamed to
    "rankx") serves as the second operand, giving the covariance of X with
    itself.

    Args:
        X: dataset providing `to_dataarray`, with "member" and "time"
            dimensions.
        Y: optional second dataset with the same requirements as X.

    Returns:
        The result of `xr.cov`, reduced over "member" and "time".
    """

    ## stack the variables of X along "ranky"
    stacked_X = X.to_dataarray(dim="ranky")

    ## second operand: Y stacked along "rankx", or X itself under that name
    if Y is not None:
        stacked_Y = Y.to_dataarray(dim="rankx")
    else:
        stacked_Y = stacked_X.rename({"ranky": "rankx"})

    return xr.cov(stacked_X, stacked_Y, dim=["member", "time"])


def get_Q(L, X):
    """Estimate the residual Q = dC/dt - (L C + (L C)^T) along the cycle.

    C is the covariance of X computed separately for each calendar month
    (via `get_cov`), relabelled from "month" to the "cycle" coordinate of L.
    dC/dt is the derivative of C along that coordinate. The result is
    multiplied by 12 before being returned.

    Args:
        L: operator array with "ranky", "rankx" and "cycle" dimensions; its
            "cycle" coordinate replaces the month labels of the covariance.
        X: dataset with a "time" coordinate that supports `time.month`
            grouping, plus whatever `get_cov` requires.

    Returns:
        Array holding 12 * (dC/dt - (L C + (L C)^T)).
    """

    ## monthly covariance, relabelled from month to cycle (to match fits)
    monthly_cov = (
        X.groupby("time.month")
        .map(get_cov)
        .rename({"month": "cycle"})
        .assign_coords({"cycle": L.cycle})
    )

    ## tendency of the covariance along the cycle
    ## NOTE(review): `differentiate` treats "cycle" as a non-periodic axis, so
    ## the first and last entries use one-sided differences -- confirm this is
    ## acceptable for a periodic annual cycle
    cov_tendency = monthly_cov.differentiate("cycle")

    ## copy of the operator carrying the same rank labels as the covariance
    operator = copy.deepcopy(L).assign_coords(
        {"ranky": cov_tendency.ranky, "rankx": cov_tendency.rankx}
    )

    ## L C (contracting over a shared "rankz" dimension) and its transpose
    LC = xr.dot(
        operator.rename({"rankx": "rankz"}),
        monthly_cov.rename({"ranky": "rankz"}),
        dim="rankz",
    )
    LC_transposed = LC.rename({"rankx": "ranky", "ranky": "rankx"})

    ## NOTE(review): the reason for the factor of 12 is not visible here;
    ## presumably a unit conversion -- confirm
    return (cov_tendency - (LC + LC_transposed)) * 12

## Load data

In [None]:
## open the CESM indices and keep the two variables analyzed here
Th = src.utils.load_cesm_indices()[["T_3", "h_w"]]

## rescale each variable by its standard deviation (for convenience);
## note the mean is not removed
Th = Th / Th.std()

## fit XRO

In [None]:
## variables included in the model
varnames = ["T_3", "h_w"]

## order of the annual cycle
ac_order = 4

## keyword arguments forwarded to every `fit_matrix` call
fit_kwargs = {"ac_mask_idx": None, "maskNT": []}

## data for the early (1851-1900) and late (2051-2100) periods
Th_early, Th_late = (
    Th[varnames].sel(time=slice(start, end))
    for start, end in [("1851", "1900"), ("2051", "2100")]
)

## model with is_forward=True, fit to each period in turn
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)
fit_early, fit_late = (
    model.fit_matrix(period, **fit_kwargs) for period in (Th_early, Th_late)
)

## sensitivity test (center vs. forward differencing): same fits with
## is_forward=False
model_c = XRO(ncycle=12, ac_order=ac_order, is_forward=False)
fit_early_c, fit_late_c = (
    model_c.fit_matrix(period, **fit_kwargs) for period in (Th_early, Th_late)
)

## diagnostics

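The forward- and centered-difference estimates of `xi_T` appear to be correlated but differ by a scaling factor; the plot below rescales the forward-difference estimate by 0.8 to compare the two.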

In [None]:
## empirical factor applied to the forward-difference curve so that its shape
## can be compared with the centered-difference curve
XI_T_FORWARD_SCALE = 0.8

## one panel per period: (title, forward-difference fit, centered-difference fit)
periods = [("early", fit_early, fit_early_c), ("late", fit_late, fit_late_c)]

fig, axs = plt.subplots(1, 2, figsize=(5, 2), layout="constrained")

for ax, (period, fit_fwd, fit_ctr) in zip(axs, periods):

    ## xi_T parameter from each differencing scheme
    xi_T_fwd = model.get_RO_parameters(fit_fwd).xi_T
    xi_T_ctr = model_c.get_RO_parameters(fit_ctr).xi_T

    ## each curve is drawn against its own cycle coordinate
    ax.plot(
        xi_T_fwd.cycle,
        xi_T_fwd * XI_T_FORWARD_SCALE,
        label=f"forward x {XI_T_FORWARD_SCALE}",
    )
    ax.plot(xi_T_ctr.cycle, xi_T_ctr, label="centered")

    ax.set_title(period)
    ax.set_xlabel("cycle")

axs[0].set_ylabel("xi_T")
axs[0].legend(fontsize="x-small")
plt.show()

Plot all params

In [None]:
## RO parameters to plot, one panel each
param_names = ["R", "epsilon", "BJ_ac", "xi_T", "F1", "F2"]

## both periods are drawn on every panel, distinguished by linestyle:
## (label, [forward-difference fit, centered-difference fit], linestyle)
periods = [
    ("early", [fit_early, fit_early_c], "-"),
    ("late", [fit_late, fit_late_c], "--"),
]

## differencing schemes, distinguished by color
schemes = [("forward", model), ("centered", model_c)]

## colormap
colors = sns.color_palette("colorblind")

fig, axs = plt.subplots(
    1, len(param_names), figsize=(13.5, 2), layout="constrained"
)

for period, period_fits, ls in periods:

    for (scheme, m), fit, c in zip(schemes, period_fits, colors):

        p = m.get_RO_parameters(fit)

        for ax, name in zip(axs, param_names):
            ax.plot(
                fit.cycle * 12,
                getattr(p, name),
                ls=ls,
                c=c,
                label=f"{scheme} ({period})",
            )

for ax, name in zip(axs, param_names):
    ax.axhline(0, ls="--", c="k", lw=0.6)
    ax.set_title(name)

## single legend, placed outside the last panel so it does not hide the curves
axs[-1].legend(
    fontsize="x-small", loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False
)

plt.show()

## Stochastic sims

In [None]:
## settings shared by both stochastic ensembles
simulation_kwargs = {
    "nyear": 49,
    "ncopy": 1000,
    "seed": 1000,
    "X0_ds": Th_early.isel(member=0, time=0),
    "is_xi_stdac": True,
    "noise_type": "white",
    "use_noise_cov": False,
}

## one ensemble per fitted period
## NOTE(review): both runs start from the first early-period state (X0_ds);
## confirm this is intended for the late-period ensemble
RO_ensemble_early = model.simulate(fit_ds=fit_early, **simulation_kwargs)
RO_ensemble_late = model.simulate(fit_ds=fit_late, **simulation_kwargs)

## Estimate Q

In [None]:
## Q from the synthetic RO ensemble and from the early-period data, both
## evaluated with the early-period operator
Q_synth = get_Q(L=fit_early.Lac, X=RO_ensemble_early)
Q_early = get_Q(L=fit_early.Lac, X=Th_early)

## Plot

In [None]:
## (0, 0) element of each estimate
xi_cov_00 = fit_early.xi_covac.isel(ranky=0, ranky_=0)
Q_early_00 = Q_early.isel(rankx=0, ranky=0)
Q_synth_00 = Q_synth.isel(rankx=0, ranky=0)

fig, ax = plt.subplots(figsize=(4, 3), layout="constrained")
ax.plot(xi_cov_00, label="xi_covac (fit_early)")
ax.plot(Q_early_00, label="Q_early")
ax.plot(Q_synth_00, ls="--", label="Q_synth")

## means over the cycle, plus a zero reference line
ax.axhline(Q_synth_00.mean("cycle"), ls="--", c="gray", label="mean Q_synth")
ax.axhline(xi_cov_00.mean("cycle"), ls="--", c="r", label="mean xi_covac")
ax.axhline(0, ls="-", c="k", lw=0.8)

ax.set_xlabel("cycle index")
ax.set_ylabel("(0, 0) element")
ax.legend(fontsize="x-small")
plt.show()

## Try estimating $L$ a different way

Get lagged data

In [None]:
## stack the early-period variables along "ranky"
X = copy.deepcopy(Th_early).to_dataarray(dim="ranky")

## every time step except the last ...
X_ = X.isel(time=slice(None, -1))

## ... and the step that follows each of them, relabelled with the time of
## the step it follows so the two arrays share one time coordinate
X_plus = X.isel(time=slice(1, None)).assign_coords({"time": X_.time})

## put in dataset; the shifted copies are stored as "T_plus" / "h_plus"
ds = xr.merge(
    [
        X_.to_dataset("ranky"),
        X_plus.to_dataset("ranky").rename({"T_3": "T_plus", "h_w": "h_plus"}),
    ]
)

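Fit the one-step propagator $G$ (regressing `T_plus` and `h_plus` onto `T_3` and `h_w`), then recover the operator as $\hat{L} = 12 \log G$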

In [None]:
## settings shared by both regressions (kept under its own name so the
## `fit_kwargs` used for the XRO fits is not overwritten)
green_kwargs = dict(x_vars=["T_3", "h_w"], max_order=4)

## regress each shifted variable onto the unshifted state
G_top = src.utils.regress_harm_wrapper(ds, y_vars=["T_plus"], **green_kwargs)
G_bot = src.utils.regress_harm_wrapper(ds, y_vars=["h_plus"], **green_kwargs)
G = xr.merge([G_top, G_bot]).to_dataarray("ranky").rename({"ell": "rankx"})

## put in numpy, with one (ranky, rankx) matrix per month
G = G.transpose("month", "ranky", "rankx")
G_np = G.values.copy()

## get operator: Lhat = 12 * log(G), with the matrix logarithm evaluated
## through the eigendecomposition G = v diag(w) v^-1
Lhat = np.nan * xr.zeros_like(G)
for i, G_i in enumerate(G_np):

    w, v = np.linalg.eig(G_i)
    Lhat.values[i] = 12 * np.real(v @ np.diag(np.log(w)) @ np.linalg.inv(v))

Get "ground truth" Lac

In [None]:
## settings shared by both regressions (kept under its own name so the
## `fit_kwargs` used for the XRO fits is not overwritten)
tendency_kwargs = dict(x_vars=["T_3", "h_w"], max_order=4)

## `get_ddt` is expected to provide the "ddt_T_3" / "ddt_h_w" variables that
## the regressions below use as targets
Th_early_ = src.utils.get_ddt(Th_early)

## regress each tendency onto the state, then stack the two results into a
## single (ranky, rankx) operator
R_F1 = src.utils.regress_harm_wrapper(
    Th_early_, y_vars=["ddt_T_3"], **tendency_kwargs
)
F2_eps = src.utils.regress_harm_wrapper(
    Th_early_, y_vars=["ddt_h_w"], **tendency_kwargs
)
Lac_recon = xr.merge([R_F1, F2_eps]).to_dataarray("ranky").rename({"ell": "rankx"})

Plot comparison

In [None]:
## matrix element to compare: first entry along both rank dimensions
idx = {"rankx": 0, "ranky": 0}

fig, ax = plt.subplots(figsize=(3, 2.5))

## one curve per operator estimate
for operator, name in [(Lac_recon, "orig"), (Lhat, "green's mat")]:
    ax.plot(operator.isel(idx), label=name)

ax.legend()
plt.show()

In [None]:
## Q implied by the Green's-function operator; get_Q needs a "cycle" dimension
Q_green = get_Q(Lhat.rename({"month": "cycle"}), RO_ensemble_early)

fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

## NOTE(review): the division by 12 undoes the `* 12` applied at the end of
## get_Q, for this curve only -- confirm the three curves are meant to be in
## the same units
ax.plot(Q_green.isel(idx) / 12, label="Q from Lhat / 12")

ax.plot(Q_synth.isel(idx), label="Q_synth")
ax.plot(fit_early.xi_covac.isel(ranky=0, ranky_=0), label="xi_covac (fit_early)")

ax.set_xlabel("cycle index")
ax.set_ylabel("(0, 0) element")
ax.legend(fontsize="x-small")
plt.show()

## How to compute the lagged covariance

In [None]:
def get_cov2(X, Y=None):
    """Covariance between two arrays that already have a "ranky" dimension.

    The second operand (Y, or X itself when Y is omitted) has its "ranky"
    dimension renamed to "rankx" so the result carries both dimensions.

    Args:
        X: array with "ranky", "member" and "time" dimensions.
        Y: optional second array with the same dimensions as X.

    Returns:
        The result of `xr.cov`, reduced over "member" and "time".
    """

    ## second operand defaults to X; either way its "ranky" becomes "rankx"
    second = X if Y is None else Y

    return xr.cov(X, second.rename({"ranky": "rankx"}), dim=["member", "time"])


def get_cov_helper(z):
    """Apply `get_cov2` to the "X_" and "X_plus" entries of `z`."""
    first, second = z["X_"], z["X_plus"]
    return get_cov2(first, second)


## combine the unshifted and shifted arrays into one dataset
X_Xplus = xr.merge(
    [
        X_.rename("X_"),
        X_plus.rename("X_plus"),
    ]
)

## covariance between the two for each calendar month, then averaged over months
cov_bymonth = (
    X_Xplus.groupby("time.month")
    .map(get_cov_helper)
    .mean("month")
)