# RO change over time

## Imports

In [None]:
import warnings
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
from climpred import HindcastEnsemble
from dateutil.relativedelta import *
from matplotlib.ticker import AutoMinorLocator
import warnings
import tqdm
import pathlib
import cmocean

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## set plotting specs
sns.set(
    rc={
        "axes.facecolor": "white",  # plain white axes background
        "axes.grid": False,  # no background grid lines
    }
)

## bump up DPI of rendered figures
mpl.rcParams["figure.dpi"] = 100

## Load data

In [None]:
## MPI data: location of the saved T/h dataset, opened with xarray
mpi_load_fp = pathlib.Path(
    "/Users/theo/research/enso2025_xro/data/mpi_Th/Th.nc"
)
Th = xr.open_dataset(mpi_load_fp)

## Create ensembles of RO to match MPI

Function to fit RO to ensemble

In [None]:
# def get_ensemble_params(data, T_var="T_3", h_var="h_w", verbose=False):
#     """get RO params for each ensemble member"""

#     ## empty list to hold params
#     params = []

#     ## Loop thru ensemble members
#     for m in tqdm.tqdm(data.member, disable=not (verbose)):

#         ## initialize model
#         model = XRO(ncycle=12, ac_order=1, is_forward=True)

#         ## select ensemble member and variables
#         data_subset = data[[T_var, h_var]].sel(member=m)

#         ## fit model
#         with warnings.catch_warnings(action="ignore"):
#             fit = model.fit_matrix(data_subset, maskNT=[], maskNH=[])

#         ## append to list of parameters
#         params.append(model.get_RO_parameters(fit))

#     return xr.concat(params, dim=data.member)

### Get ensemble of RO models for early/late period
Here we fit a separate RO model to each MPI ensemble member, so the RO ensemble has the same number of members as MPI. To increase the RO ensemble size, we could (i) randomly draw RO members from this ensemble, or (ii) estimate the covariance of the parameters and then randomly draw parameter sets.

In [None]:
## variables used in the RO fit, and the order of the annual cycle
T_var, h_var = "T_3", "h_w"
ac_order = 0

## early period = first 588 time steps, late period = last 588 time steps
Th_early, Th_late = (
    Th.isel(time=slice(None, 588)),
    Th.isel(time=slice(-588, None)),
)

## model template handed to the ensemble fitter
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## fit one RO per MPI member in each period, with identical settings
kwargs = {"model": model, "T_var": T_var, "h_var": h_var, "verbose": True}
model, RO_params_early = src.utils.get_RO_ensemble(Th_early, **kwargs)
_, RO_params_late = src.utils.get_RO_ensemble(Th_late, **kwargs)

In [None]:
def get_member_fit(i, data=None, T_var="T_3", h_var="h_w"):
    """Fit a linear (T, h) operator to a single ensemble member.

    Regresses the tendency of (T, h) on the (T, h) state and returns
    ``Cyx @ pinv(Cxx)``. This is the regression operator, assuming
    ``src.utils.get_cov`` returns a covariance matrix.

    Args:
        i: integer position of the member along the ``member`` dimension.
        data: dataset to fit; defaults to the global ``Th_early``.
        T_var: name of the T variable.
        h_var: name of the h variable.

    Returns:
        numpy array with rows/columns ordered as (T, h).
    """
    if data is None:
        data = Th_early

    ## select the member once, then stack (T, h) along a new leading axis
    ## (assumes time is the only remaining dimension -> shape (2, n_time))
    member = data.isel(member=i)
    XY = np.stack([member[T_var].values, member[h_var].values], axis=0)

    ## feature matrix: state at step t
    X = XY[:, :-2]

    ## tendency (x[t+2] - x[t]) / (2 * dt); the factor 6 = 1 / (2 * (1/12))
    ## gives per-year units, assuming monthly sampling
    Y = 6 * (XY[:, 2:] - XY[:, :-2])

    Cxx = src.utils.get_cov(X, X)
    Cyx = src.utils.get_cov(Y, X)

    return Cyx @ np.linalg.pinv(Cxx)


## one fitted operator per ensemble member, stacked along a new leading axis
fits = np.stack([get_member_fit(i) for i in range(Th_early.sizes["member"])])

Fit a single model to all ensemble members at once (pooling their samples)

In [None]:
## stack (T, h) for all members; force time to be the LAST axis so the
## time slicing below never cuts across the member axis, whatever the
## on-disk dimension order of the dataset is
XY = np.stack(
    [
        Th_early["T_3"].transpose(..., "time").values,
        Th_early["h_w"].transpose(..., "time").values,
    ],
    axis=0,
)

## feature matrix: state at step t (sliced along time within each member)
X = XY[..., :-2]

## tendency, same scheme as get_member_fit: 6 * (x[t+2] - x[t])
Y = 6 * (XY[..., 2:] - XY[..., :-2])

## pool every member's samples into one long sample axis: (2, n_samples)
X = X.reshape(2, -1)
Y = Y.reshape(2, -1)

Cxx = src.utils.get_cov(X, X)
Cyx = src.utils.get_cov(Y, X)

fit_all = Cyx @ np.linalg.pinv(Cxx)

## NOTE(review): this diagnostic is L[0,0] - L[1,1]; confirm it is the intended
## quantity (the RO growth rate / BJ index is usually half the trace,
## (L[0,0] + L[1,1]) / 2).
fits_bj = fits[:, 0, 0] - fits[:, 1, 1]
fit_all_bj = fit_all[0, 0] - fit_all[1, 1]

In [None]:
pdf, edges = src.utils.get_empirical_pdf(fits_bj, edges=np.arange(0.3, 1.6, 0.1))

## distribution of the per-member diagnostic vs. the pooled fit and the member mean
fig, ax = plt.subplots(figsize=(4, 3))
ax.stairs(pdf, edges, color="k", label="MPI (PDF)")
ax.axvline(fit_all_bj, c="r", label="Pooled fit")
ax.axvline(fits_bj.mean(), c="r", ls="--", label="Mean of member fits")
## without a legend the labels above are never shown
ax.legend()

In [None]:
## sanity check of memory layout: first 10 samples of the first variable/member,
## vs. the first 10 entries after a Fortran-order reshape to (2, -1)
ordered_subset = Th_early.transpose(..., "member")[["T_3", "h_w"]]
arr = src.XRO._convert_to_numpy(ordered_subset)
for row in (arr[0, :10, 0], arr.reshape(2, -1, order="F")[0, :10]):
    print(row)

In [None]:
## compare, for member 0 only: true tendency vs. the manual fit vs. XRO's operator
member0 = Th_early.isel(member=0)
XY = np.stack(
    [
        member0["T_3"].values,
        member0["h_w"].values,
    ],
    axis=0,
)

## rebuild features / tendency from member 0 (same scheme as get_member_fit);
## otherwise X and Y would be whatever an earlier cell left behind
X = XY[:, :-2]
Y = 6 * (XY[:, 2:] - XY[:, :-2])

A = get_member_fit(0)
A2 = RO_params_early.isel(member=0, cycle=0).Lac

Yhat = A @ X
Yhat2 = A2.values @ X

## first 36 samples of the T tendency
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(Y[0, :36], label="truth")
ax.plot(Yhat[0, :36], label="prediction")
ax.plot(Yhat2[0, :36], label="prediction2", ls="--")
ax.legend()

### Get ensemble of RO simulations

Function to generate ensemble

In [None]:
def generate_ensemble(params, sampling_type="ensemble_mean", **simulation_kwargs):
    """Generate an ensemble of RO simulations from per-member RO parameters.

    Args:
        params: fitted RO parameters with a ``member`` dimension (one set per
            MPI ensemble member).
        sampling_type: how the parameter ensemble is turned into a model.
            Only "ensemble_mean" is supported: simulate with the
            member-averaged parameters.
        **simulation_kwargs: forwarded unchanged to ``model.simulate``.

    Returns:
        Whatever ``model.simulate`` returns.

    Raises:
        ValueError: if ``sampling_type`` is not supported.

    Note: relies on the global ``model`` defined elsewhere in this notebook.
    """
    if sampling_type != "ensemble_mean":
        ## fail loudly; printing and falling through would hit an unbound local
        raise ValueError(f"Not a valid sampling type: {sampling_type!r}")

    return model.simulate(fit_ds=params.mean("member"), **simulation_kwargs)

Run the simulation

In [None]:
## simulation settings; the initial condition is the first time step of MPI member 0
simulation_kwargs = {
    "nyear": 49,
    "ncopy": 50,
    "seed": 1000,
    "X0_ds": Th[[T_var, h_var]].isel(member=0, time=0),
}

## integrate the RO with the early-period parameters
RO_ensemble = generate_ensemble(RO_params_early, **simulation_kwargs)

## Check if RO can reproduce MPI stats

### Seasonality

In [None]:
def get_std(x):
    """Standard deviation over ``time``, computed separately for each calendar month."""
    return x.groupby("time.month").std("time")


def get_stats(x):
    """Ensemble mean and ±1 std bounds of the monthly std of ``x[T_var]``."""
    return src.utils.get_ensemble_stats(get_std(x[T_var]))


def _plot_center_and_bounds(ax, x, stats, label):
    """Plot the "center" curve of ``stats`` plus dashed "upper"/"lower" curves
    in the same color; return the list of lines from the center plot."""
    center = ax.plot(x, stats.sel(posn="center"), label=label)
    bound_kwargs = dict(c=center[0].get_color(), ls="--", lw=1)
    for bound in ["upper", "lower"]:
        ax.plot(x, stats.sel(posn=bound), **bound_kwargs)
    return center


## compute std for each dataset
mpi_std_plot = get_stats(Th_early)
ro_std_plot = get_stats(RO_ensemble)

## months (x-coordinate for plotting)
months = np.arange(1, 13)

### Set up plot
fig, ax = plt.subplots(figsize=(4, 3))

## ensemble mean (solid) and bounds (dashed): MPI first, then RO
mpi_plot = _plot_center_and_bounds(ax, months, mpi_std_plot, "MPI")
ro_plot = _plot_center_and_bounds(ax, months, ro_std_plot, "RO")

## adjust limits and label
ax.set_ylim([0.4, None])
ax.set_yticks([0.5, 1])
ax.set_xticks([1, 5, 12], labels=["Jan", "May", "Dec"])
ax.set_xlabel("Month")
ax.set_ylabel(r"$\sigma(T)$")
ax.set_title("Seasonal synchronization (Niño 3)")
ax.legend()
plt.show()

### Power spectrum

### $T$, $h$ cross-correlation

Compute cross-correlation

In [None]:
## lagged correlation of every variable with T (up to 18 lags), per dataset
xcorr_mpi, xcorr_ro = (
    xcorr(ds, ds[T_var], maxlags=18) for ds in (Th_early, RO_ensemble)
)

## summarize across ensemble members (center / upper / lower)
xcorr_mpi_stats, xcorr_ro_stats = map(
    src.utils.get_ensemble_stats, (xcorr_mpi, xcorr_ro)
)

Function to help with plotting

In [None]:
def plot_xcorr(ax, data, color, label=None):
    """Draw the center curve of ``data`` and shade between its bounds.

    Args:
        ax: matplotlib axes to draw on.
        data: object with a ``lag`` coordinate that supports
            ``sel(posn=...)`` for "center", "upper" and "lower"
            (presumably the output of ``src.utils.get_ensemble_stats``).
        color: color used for both the line and the shaded band.
        label: legend label attached to the center line.
    """
    lags = data.lag
    center, upper, lower = (data.sel(posn=p) for p in ("center", "upper", "lower"))

    ax.plot(lags, center, c=color, label=label)
    ax.fill_between(lags, upper, lower, color=color, alpha=0.2)

Make plots

In [None]:
## specify plot properties for legend
legend_prop = dict(size=7)

## style of the zero guidelines (same for every panel)
axis_kwargs = dict(c="k", lw=0.5, alpha=0.5)

fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## plot <T,T>
axs[0].set_title(r"$<T, T>$")
plot_xcorr(axs[0], xcorr_mpi_stats[T_var], color="k", label="MPI")
plot_xcorr(axs[0], xcorr_ro_stats[T_var], color="r", label="RO")
axs[0].legend(prop=legend_prop)

## plot <T,h>
axs[1].set_title(r"$<T, h>$")
plot_xcorr(axs[1], xcorr_mpi_stats[h_var], color="k", label="MPI")
plot_xcorr(axs[1], xcorr_ro_stats[h_var], color="r", label="RO")

## clean up axes and label
for ax in axs:
    ax.set_ylim([-0.5, 1.05])
    ## ticks at ±12 lags are labelled as ±1 year
    ax.set_xticks([-12, 0, 12], labels=[-1, 0, 1])
    ax.set_xlabel("Lag (years)")
    ax.axhline(0, **axis_kwargs)
    ax.axvline(0, **axis_kwargs)

plt.show()

## Look at parameter change over time

Next, define a function to compute parameter changes over time

In [None]:
def _get_ensemble_params(data, T_var="T_3", h_var="h_w", verbose=False):
    """Fit an XRO model to each ensemble member and return its RO parameters.

    Same procedure as the commented-out ``get_ensemble_params`` cell in this
    notebook. Results are concatenated along ``data.member``.
    """
    params = []
    for m in tqdm.tqdm(data.member, disable=not verbose):
        ## a fresh model per member
        model = XRO(ncycle=12, ac_order=1, is_forward=True)
        data_subset = data[[T_var, h_var]].sel(member=m)
        ## suppress warnings emitted during fitting
        with warnings.catch_warnings(action="ignore"):
            fit = model.fit_matrix(data_subset, maskNT=[], maskNH=[])
        params.append(model.get_RO_parameters(fit))
    return xr.concat(params, dim=data.member)


def get_ensemble_params_over_time(
    data, T_var="T_3", h_var="h_w", window_size=360, step_size=60
):
    """Get RO params for each ensemble member as a function of time.

    Args:
        data: dataset with ``time`` and ``member`` dimensions containing
            ``T_var`` and ``h_var``.
        T_var: name of the T variable to fit.
        h_var: name of the h variable to fit.
        window_size: size of sliding window (units: months)
        step_size: how many months to slide the window between each calculation

    Returns:
        RO parameters for every window, concatenated along a ``time``
        dimension holding each window's start date.

    Raises:
        ValueError: if no window holds more than 70% of ``window_size`` samples.
    """
    n = len(data.time)

    params_by_year = []
    start_dates = []

    for i in tqdm.tqdm(np.arange(0, n - step_size, step_size)):
        ## skip trailing windows too short for a robust estimate
        if (n - i) <= 0.7 * window_size:
            continue

        data_subset = data.isel(time=slice(i, i + window_size))
        start_dates.append(data_subset.time.isel(time=0))
        params_by_year.append(
            _get_ensemble_params(data_subset, T_var=T_var, h_var=h_var, verbose=False)
        )

    ## xr.concat([]) would fail with an opaque error; explain what went wrong
    if not params_by_year:
        raise ValueError(
            f"No window with more than {0.7 * window_size:g} samples fits in "
            f"data of length {n} (window_size={window_size}, step_size={step_size})."
        )

    ## convert from list to xarray, indexed by window start date
    start_dates = xr.concat(start_dates, dim="time")
    return xr.concat(params_by_year, dim=start_dates)

### Do the computation and save parameters to file

In [None]:
## cache location for the sliding-window parameter fits
save_fp = pathlib.Path("/Users/theo/research/enso2025_xro/results/params.nc")

if not save_fp.is_file():
    ## not cached yet: fit every window, then write the result for next time
    params = get_ensemble_params_over_time(Th, window_size=360, step_size=60)
    params.to_netcdf(save_fp)
else:
    ## reuse previously computed parameters
    params = xr.open_dataset(save_fp)

### Plot diagnostics

#### Ensemble-mean growth rate as a function of seasonal cycle and time

In [None]:
fig, ax = plt.subplots(figsize=(3.5, 3.5))

## filled contours of the member-averaged "BJ_ac": x = params.cycle, y = params.time
bj_ensemble_mean = params["BJ_ac"].mean("member")
cp = ax.contourf(
    params.cycle, params.time, bj_ensemble_mean, cmap="cmo.amp", levels=10
)

## month ticks, plus a dashed white guideline at x = 7
ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
ax.axvline(7, c="w", ls="--", lw=1, alpha=0.8)

## colorbar for the contour levels
cb = fig.colorbar(cp, label=r"Growth rate (yr$^{-1}$)")

## title and axis labels
ax.set(title="Ensemble-mean growth rate", xlabel="Month", ylabel="Year")

plt.show()

#### Plot annual max growth rate

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## plot each member's maximum growth rate over the cycle; label only the
## first member so the legend gets exactly one "Ensemble members" entry
for i, m in enumerate(params.member.values):

    label = "Ensemble members" if i == 0 else None

    ax.plot(
        params.time,
        params["BJ_ac"].sel(member=m).max("cycle"),
        c="gray",
        alpha=0.5,
        lw=0.5,
        label=label,
    )

## plot ensemble mean
ax.plot(
    params.time,
    params["BJ_ac"].mean("member").max("cycle"),
    c="k",
    lw=2,
    label="Ensemble mean",
)

## add labels and set plot style
ax.axhline(0, c="k", ls="--", lw=1)
ax.set_ylim([None, 0.5])
ax.set_xlabel("Time")
ax.set_ylabel(r"Growth rate (yr$^{-1}$)")
ax.set_title("Max annual growth rate")
ax.legend()

plt.show()

### Stochastic integration

In [None]:
## Fit one RO model per ensemble member; applied below to the first and last 600 months (50 years) of the dataset


def get_RO_ensemble(data, T_var="T_3", h_var="h_w", verbose=False, model=None):
    """Fit an RO (XRO) model separately to each ensemble member.

    Args:
        data: dataset with a ``member`` dimension containing ``T_var`` and ``h_var``.
        T_var: name of the T variable to fit.
        h_var: name of the h variable to fit.
        verbose: show a tqdm progress bar over members.
        model: XRO instance used for every fit. Defaults to
            ``XRO(ncycle=12, ac_order=1, is_forward=True)``; mirrors the
            ``model=`` argument accepted by ``src.utils.get_RO_ensemble``.

    Returns:
        (model, fits): the model used, and the per-member fits concatenated
        along ``data.member``.
    """
    if model is None:
        model = XRO(ncycle=12, ac_order=1, is_forward=True)

    fits = []
    for m in tqdm.tqdm(data.member, disable=not verbose):
        data_subset = data[[T_var, h_var]].sel(member=m)
        ## suppress warnings emitted during fitting
        with warnings.catch_warnings(action="ignore"):
            fits.append(model.fit_matrix(data_subset, maskNT=[], maskNH=[]))

    return model, xr.concat(fits, dim=data.member)

### Get ensemble of RO models for early/late period

In [None]:
## get data for early/late period
## early / late periods: first and last 600 time steps of the record
Th_early, Th_late = Th.isel(time=slice(None, 600)), Th.isel(time=slice(-600, None))

## one RO per MPI member in each period; keep the model returned by the early fit
model, RO_params_early = get_RO_ensemble(Th_early, verbose=True)
_, RO_params_late = get_RO_ensemble(Th_late, verbose=True)

### Get ensemble of RO simulations

In [None]:
seed = 1000

## inputs: member-averaged early-period parameters, and the first MPI state
## (member 0) as the initial condition
early_mean_params = RO_params_early.mean("member")
initial_state = Th[["T_3", "h_w"]].isel(time=0, member=0)

## stochastic integration driven by red noise
RO_ensemble_early = model.simulate(
    fit_ds=early_mean_params,
    X0_ds=initial_state,
    nyear=50,
    ncopy=100,
    noise_type="red",
    seed=seed,
    is_xi_stdac=False,
    xi_B=0.0,
    is_heaviside=True,
)

#### Seasonal synchronization

In [None]:
## Compare the monthly standard deviation of one variable between MPI (early
## period) and the RO stochastic ensemble.
## NOTE(review): no legend, axis labels or plt.show() in view; this cell may
## continue beyond this excerpt.

## specify which variable to plot
var_name = "T_3"

## compute std for each dataset, grouped by calendar month
RO_std = RO_ensemble_early.groupby("time.month").std()
mpi_std = Th_early.groupby("time.month").std()

fig, ax = plt.subplots(figsize=(4, 3))

## MPI: member-mean of the monthly std, then mean + 1 std across members (black)
ax.plot(mpi_std.month, mpi_std[var_name].mean("member"))
ax.plot(
    mpi_std.month,
    mpi_std[var_name].mean("member") + mpi_std[var_name].std("member"),
    c="k",
)
## RO: same two curves; the x-axis reuses mpi_std.month
## (assumes both datasets cover the same months — TODO confirm)
ax.plot(mpi_std.month, RO_std[var_name].mean("member"))
ax.plot(mpi_std.month, RO_std[var_name].mean("member") + RO_std[var_name].std("member"))