# RO change over time

## Imports

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings
import calendar

# import gsw

## set plotting specs: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set default figure resolution (dots per inch)
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables "DATA_FP" and "SAVE_FP"
## (indexing os.environ raises KeyError if either variable is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def get_ensemble_fits_over_time(
    data,
    model,
    ac_mask_idx,
    mask_NT,
    window_size=480,
    step_size=60,
    by_ensemble_member=True,
    remove_T_dependence=False,
    min_window_frac=0.7,
):
    """Get RO fits as a function of time, using a sliding window.

    Args:
        data: dataset with a "time" dimension. If 'remove_T_dependence' is
            True it must contain exactly two variables, ordered (T, h).
        model: model object; passed to src.utils.get_RO_ensemble when
            'by_ensemble_member' is True, otherwise its 'fit_matrix' method
            is called directly.
        ac_mask_idx: forwarded unchanged to the fitting routine.
        mask_NT: forwarded as 'maskNT' to model.fit_matrix. Only used when
            'by_ensemble_member' is False.
        window_size: size of sliding window (units: months)
        step_size: how many months to slide the window between each calculation
        by_ensemble_member: if True, fit with src.utils.get_RO_ensemble;
            if False, do a single fit with model.fit_matrix (warnings
            raised during that fit are suppressed).
        remove_T_dependence: if True, replace the h variable in each window
            with the output of src.utils.remove_sst_dependence_v2.
        min_window_frac: a window is only fit if the number of timesteps
            remaining from its start exceeds this fraction of 'window_size'.

    Returns:
        The per-window fits concatenated along a dimension whose coordinate
        holds the first timestamp of each window. Variables "X", "Y",
        "Yfit" and "time" are dropped from each fit before concatenating.
    """

    ## Get number of timesteps in data
    n = len(data.time)

    ## empty lists to hold results and dates
    fits_by_year = []
    start_dates = []

    ## loop through rolling windows (i is the index of the window's first timestep)
    for i in tqdm.tqdm(np.arange(0, n - step_size, step_size)):

        ## make sure there's enough samples for robust estimate: windows near
        ## the end of the record are truncated by the slice below, so skip
        ## those that are too short
        if (n - i) <= (min_window_frac * window_size):
            continue

        ## get subset of data for fitting model
        data_subset = data.isel(time=slice(i, i + window_size))

        ## Get start date for subset
        start_dates.append(data_subset.time.isel(time=0))

        ## remove h's linear dependence on T, if desired
        if remove_T_dependence:
            T_var, h_var = list(data_subset)
            data_subset[h_var] = src.utils.remove_sst_dependence_v2(
                data_subset, h_var=h_var, T_var=T_var
            )

        ## get parameter fit
        if by_ensemble_member:
            _, fits = src.utils.get_RO_ensemble(
                data_subset,
                model=model,
                ac_mask_idx=ac_mask_idx,
            )

        else:
            with warnings.catch_warnings(action="ignore"):
                fits = model.fit_matrix(
                    data_subset,
                    ac_mask_idx=ac_mask_idx,
                    maskNT=mask_NT,
                )

        ## drop X,Y,time variables
        fits = fits.drop_vars(["X", "Y", "Yfit", "time"])
        fits_by_year.append(fits)

    ## convert from list to xarray
    start_dates = xr.concat(start_dates, dim="time")
    fits_by_year = xr.concat(fits_by_year, dim=start_dates)

    return fits_by_year


def update_time_coord(data, window_size=480, end_year=2082):
    """Change time coordinate to center of period used for computing stats.

    Args:
        data: xarray object whose "time" coordinate holds the start date of
            each window.
        window_size: window length in months.
        end_year: last year (inclusive, label-based) to keep after shifting.
            Pass None to keep every window.

    Returns:
        'data' with "time" replaced by an integer "year" coordinate equal to
        the window's start year plus half the window length (in whole years),
        trimmed to years <= 'end_year'.
    """

    ## get year corresponding to start of period (copy, so the in-place
    ## addition below cannot modify the original coordinate values)
    year = copy.deepcopy(data.time.dt.year.values)

    ## update to midpoint of period (half the window, truncated to whole years)
    year_shift = int(window_size / 12 / 2)
    year += year_shift

    ## update coordinate
    data = data.assign_coords({"time": year}).rename({"time": "year"})

    ## trim so that the last window center is no later than 'end_year'
    data = data.sel(year=slice(None, end_year))

    return data


def get_params(fits, model):
    """Extract named RO parameters (plus derived diagnostics) from 'fits'.

    Starts from model.get_RO_parameters(fits) and adds:
        - "xi_T_norm", "xi_h_norm": fits["normxi_stdac"] at ranky=0 and
          ranky=1, relabeled onto the parameters' "cycle" coordinate
        - "wyrtki": sign(F1*F2) * sqrt(|F1*F2|)

    Returns the result with singleton dimensions squeezed out.
    """

    ## named parameters from the model
    params = model.get_RO_parameters(fits)

    def _on_param_cycle(da):
        """relabel 'cycle' so it matches the parameters' coordinate"""
        return da.assign_coords({"cycle": params.cycle})

    ## normalized noise stats: one entry per state variable (ranky index)
    noise_norm = fits["normxi_stdac"]
    for name, ranky_idx in (("xi_T_norm", 0), ("xi_h_norm", 1)):
        params[name] = _on_param_cycle(noise_norm.isel(ranky=ranky_idx))

    ## wyrtki index: signed geometric mean of F1 and F2
    coupling = params["F1"] * params["F2"]
    params["wyrtki"] = np.sign(coupling) * np.sqrt(np.abs(coupling))

    return params.squeeze()


def get_rolling_std(data, n=20):
    """
    Rolling standard deviation of 'data', separated by month and year.

    Applies np.std through src.utils.get_rolling_fn_bymonth (called with
    'n' and reduce_ensemble_dim=False), then splits the result into month
    and year with src.utils.unstack_month_and_year. Per the original note,
    the statistic for a given year uses a window of 2n+1 years centered on
    that year, to increase the sample size.
    """

    ## rolling std, computed separately for each calendar month
    rolled = src.utils.get_rolling_fn_bymonth(
        data, fn=np.std, n=n, reduce_ensemble_dim=False
    )

    ## unstack year and month
    return src.utils.unstack_month_and_year(rolled)

## Load data

In [None]:
## open data (CESM indices, loaded by project helper)
Th = src.utils.load_cesm_indices()

## rename indices for convenience (long names -> short aliases)
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## omit first year, i.e. the first 12 timesteps (bc of NaN in h,hw vars)
Th = Th.isel(time=slice(12, None))

## standardize (for convenience): divide each variable by its standard
## deviation. No 'dim' is passed to std(), and the mean is not removed.
Th /= Th.std()

## Compute time-varying RO parameters

In [None]:
## specify variables names (order matters: first is T, second is h)
varnames = ["T_3", "h_w"]

## specify: should we remove h's linear dependence on T?
REMOVE_T_DEP = False

## specify save filepath
## NOTE(review): this rebinds SAVE_FP from the save *directory* (set during
## setup) to a specific *file*; any later code expecting a directory would
## break. Consider a separate name.
if REMOVE_T_DEP:
    SAVE_FP = pathlib.Path(os.environ["SAVE_FP"], "fits_h_hat.nc")
else:
    # SAVE_FP = pathlib.Path(os.environ["SAVE_FP"], "fits.nc")
    SAVE_FP = pathlib.Path(os.environ["SAVE_FP"], "fits_test.nc")

# ## specify args for model fit
## (window_size and step_size are in months: 40-year windows every 10 years)
kwargs = dict(
    model=src.XRO.XRO(ncycle=12, ac_order=3, is_forward=True),
    ac_mask_idx=None,
    window_size=480,
    step_size=120,
    mask_NT=[],
    remove_T_dependence=REMOVE_T_DEP,
)

## try to load pre-computed
## NOTE(review): a loaded file is used as-is (update_time_coord is not
## applied), so it is assumed to already have a "year" coordinate.
if SAVE_FP.is_file():
    fits = xr.open_dataset(SAVE_FP)

else:

    ## Get fits (single fit per window, not one per ensemble member)
    fits = get_ensemble_fits_over_time(Th[varnames], by_ensemble_member=False, **kwargs)

    ## expand "ensemble dim" to match rest of script
    fits = update_time_coord(
        fits.expand_dims("member"), window_size=kwargs["window_size"]
    )

    ## saving is currently disabled, so fits are recomputed on every run
    # fits.to_netcdf(SAVE_FP)

## extract parameters
params = get_params(fits=fits, model=kwargs["model"])

## get change from initial period (first entry along "year")
delta_params = params - params.isel(year=0)

## Validate changes in variance over time

### Estimate variance of RO model over time

In [None]:
def get_RO_sigma(model, params, **simulation_kwargs):
    """Simulate the RO with member-averaged parameters and return the
    standard deviation of the output for each calendar month.

    'simulation_kwargs' are forwarded unchanged to model.simulate.
    """

    ## average the fit over ensemble members before simulating
    member_mean_fit = params.mean("member")
    simulated = model.simulate(fit_ds=member_mean_fit, **simulation_kwargs)

    ## standard deviation within each calendar month
    by_month = simulated.groupby("time.month")
    return by_month.std()


def get_RO_sigma_over_time(model, params, **simulation_kwargs):
    """Run get_RO_sigma once per entry of params.year and concatenate the
    results along "year".

    'simulation_kwargs' are forwarded to every simulation.
    """

    ## one simulation per year, each using that year's parameters
    sigma_by_year = [
        get_RO_sigma(model=model, params=params.sel(year=yr), **simulation_kwargs)
        for yr in tqdm.tqdm(params.year)
    ]

    ## put back in xarray
    return xr.concat(sigma_by_year, dim=params.year)

In [None]:
## simulation specs, forwarded to model.simulate for every year.
## The initial condition is the first timestep of the first ensemble member.
simulation_kwargs = dict(
    nyear=40,
    ncopy=50,
    seed=1000,
    X0_ds=Th[varnames].isel(member=0, time=0),
    noise_type="white",
)

## compute with parameters estimated from all ensemble members
## (one set of simulations per entry of fits.year)
RO_sigma_over_time_v2 = get_RO_sigma_over_time(
    model=kwargs["model"], params=fits, **simulation_kwargs
)

### estimate variance of CESM over time

In [None]:
## compute rolling std (n=20 is passed to get_rolling_std)
Th_std = get_rolling_std(Th, n=20)

## compute percentage change in std, relative to the ensemble-mean value
## at the first year
baseline = Th_std.isel(year=0).mean("member")
delta_Th_std = 100 * (Th_std - baseline) / baseline

### Compare model and RO

Function to plot results

In [None]:
def plot_stats_comp(ax, list_of_stats, labels, colors=None, n=varnames[0]):
    """Plot quantile curves over time for several datasets on one axis.

    Args:
        ax: axis to draw on.
        list_of_stats: datasets with a "year" coordinate and a "q"
            (quantile) dimension that includes 0.5.
        labels: legend label for each dataset's median curve.
        colors: one color per dataset; defaults to the first entries of
            seaborn's current palette.
        n: name of the variable to plot (default is fixed when the function
            is defined).
    """

    palette = colors
    if palette is None:
        palette = sns.color_palette()[: len(list_of_stats)]

    for ens_stats, name, color in zip(list_of_stats, labels, palette):

        years = ens_stats.year

        ## thick line for the median
        median_line = ax.plot(
            years, ens_stats[n].sel(q=0.5), lw=2.5, label=name, c=color
        )

        ## thin lines, in the same color, for every other quantile
        bound_style = dict(c=median_line[0].get_color(), lw=0.8)
        for level in ens_stats.q:
            if level == 0.5:
                continue
            ax.plot(years, ens_stats[n].sel(q=level), **bound_style)

    ## axis labels, limits and ticks
    ax.set_xlabel("Year")
    ax.set_ylabel(r"$\sigma_T$ ($^{\circ}$C)")
    ax.set_ylim([0.3, 1.7])
    ax.set_xticks([1870, 1975, 2080])
    ax.set_yticks([0.6, 1.2])

    return

Make the plot

In [None]:
## specify function to select a single calendar month
sel_fn = lambda x, m: x.sel(month=m)

## specify function to compute bounds: 10th/50th/90th percentiles across
## ensemble members for month m, with the quantile dimension renamed to "q"
get_stats = (
    lambda x, m: sel_fn(x, m)
    .quantile(q=[0.1, 0.5, 0.9], dim="member")
    .rename({"quantile": "q"})
)

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

## one panel per selected month (Feb, May, Aug, Nov)
for i, (ax, m) in enumerate(zip(axs, [2, 5, 8, 11])):

    ## compute stats ('stats_mpi' holds the CESM rolling std, despite its name)
    stats_mpi = get_stats(Th_std, m=m)
    stats_ro_v2 = get_stats(RO_sigma_over_time_v2, m=m)

    ## first, CESM vs RO
    plot_stats_comp(
        ax,
        [stats_mpi, stats_ro_v2],
        labels=["CESM", "RO"],
        colors=["k", sns.color_palette()[1]],
        n=varnames[0],
    )

    ## title with full month name; only the leftmost panel keeps y ticks/label
    ax.set_title(calendar.month_name[m])
    if i > 0:
        ax.set_yticks([])
        ax.set_ylabel(None)


axs[1].legend(prop=dict(size=8))

plt.show()

## Plot diagnostics

### Snapshots of parameters over time

In [None]:
def format_params_line(axs):
    """Add a dashed zero line and x ticks at 2 and 11 to every axis in 'axs'."""

    zero_line_style = dict(ls="--", c="k", lw=0.8)
    tick_positions = [2, 11]

    for panel in axs:
        panel.axhline(0, **zero_line_style)
        panel.set_xticks(tick_positions)

    return

In [None]:
## specify colormap and norm. The norm is applied to the loop index i
## (0..3) below; with vmin=-1 and vmax=3, index 0 maps to 0.25 of the
## colormap rather than its starting color.
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

## specify years to plot
YEARS = [1871, 1981, 2021, 2081]

## Plot variance, noise, bjerknes index, and wyrtki index
fig, axs = plt.subplots(1, 4, figsize=(8, 2), layout="constrained")

## plot data (one curve per year; later years use later colormap values)
for i, y in enumerate(YEARS):

    ## variance: ensemble-mean std, using the nearest available year
    axs[0].plot(
        Th_std.month,
        Th_std[varnames[0]].sel(year=y, method="nearest").mean("member"),
        c=CMAP(CMAP_NORM(i)),
    )

    ## noise/BJ/wyrtki (exact year match required here, unlike above)
    for ax, p in zip(axs[1:], ["xi_T_norm", "BJ_ac", "wyrtki"]):
        ax.plot(params.cycle, params[p].sel(year=y), c=CMAP(CMAP_NORM(i)))
        ax.set_title(p)

## formatting/label
format_params_line(axs)
axs[0].set_title(r"$\sigma(T)$")

plt.show()

## same, but for individual RO parameters
fig, axs = plt.subplots(1, 4, figsize=(8, 2), layout="constrained")

## loop thru parameters and years
for ax, p in zip(axs, ["R", "epsilon", "F1", "F2"]):
    for i, y in enumerate(YEARS):

        ## plot data
        ax.plot(params.cycle, params[p].sel(year=y), c=CMAP(CMAP_NORM(i)))

        ## formatting
        ax.set_title(p)

## formatting
format_params_line(axs)
plt.show()

### Hovmoller plots for variance, growth rate, and noise

In [None]:
def format_params_hov(axs):
    """Format a row of hovmoller panels (x: month, y: year).

    Every panel gets Jan/Jul/Dec ticks, an x label and white dashed guide
    lines at x=7 and y=2025. Only the first panel keeps its y ticks and
    label; the others are set to the first panel's y limits.
    """

    guide_style = dict(c="w", ls="--", lw=1)

    for panel in axs:
        panel.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
        panel.axvline(7, alpha=0.8, **guide_style)
        panel.set_xlabel("Month")
        panel.axhline(2025, **guide_style)

    ## y axis is only labeled on the first panel
    leftmost = axs[0]
    leftmost.set_ylabel("Year")
    leftmost.set_yticks([1870, 1975, 2080])

    ## remaining panels: no ticks, same y range as the first
    for panel in axs[1:]:
        panel.set_yticks([])
        panel.set_ylim(leftmost.get_ylim())

    return

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(8, 3.5), layout="constrained")

#### plot change in std dev
## contour levels come from the project helper make_cb_range
plot_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(30, 6), extend="both"
)

## plot data (ensemble-mean percentage change; x: month, y: year)
cp0 = axs[0].contourf(
    delta_Th_std.month,
    delta_Th_std.year,
    delta_Th_std[varnames[0]].mean("member"),
    **plot_kwargs
)

##### plot change in model params

## specify plotting specs (shared by the three parameter panels)
plot_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(2, 0.2), extend="both"
)

## plot data
cp1 = axs[1].contourf(
    params.cycle, params.year, delta_params["xi_T_norm"], **plot_kwargs
)
cp2 = axs[2].contourf(params.cycle, params.year, delta_params["BJ_ac"], **plot_kwargs)
cp3 = axs[3].contourf(params.cycle, params.year, delta_params["wyrtki"], **plot_kwargs)

## add colorbars: one for the std panel, one (built from cp3) for the
## three parameter panels, which share levels
cb0 = fig.colorbar(cp0, label=r"% change", ticks=[-30, 0, 30])
cb1 = fig.colorbar(cp3, label=r"yr$^{-1}$", ticks=[-2, 0, 2])

## label
axs[0].set_title(r"$\frac{\Delta \sigma(T)}{\sigma(T)_{1870}}$", size=10)
axs[1].set_title(r"$\Delta\left(\frac{\text{Noise}}{\sigma(T)}\right)$", size=10)
axs[2].set_title(r"$\Delta$ BJ", size=10)
axs[3].set_title(r"$\Delta$ Wyrtki", size=10)

## formatting
format_params_hov(axs)

plt.show()

### Same, but for RO parameters

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(7, 3.5), layout="constrained")


##### plot change in model params

## specify plotting specs (shared by all four panels)
plot_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(2, 0.2), extend="both"
)

## plot data (change relative to first year; x: cycle, y: year)
cp0 = axs[0].contourf(params.cycle, params.year, delta_params["R"], **plot_kwargs)
cp1 = axs[1].contourf(params.cycle, params.year, delta_params["epsilon"], **plot_kwargs)
cp2 = axs[2].contourf(params.cycle, params.year, delta_params["F1"], **plot_kwargs)
cp3 = axs[3].contourf(params.cycle, params.year, delta_params["F2"], **plot_kwargs)

## add colorbar (built from cp3; all panels share the same levels)
cb1 = fig.colorbar(cp3, ticks=[-2, 0, 2])

## label
## (the y label and y ticks set here are set again, to the same values,
## by format_params_hov below)
axs[0].set_ylabel("Year")
axs[0].set_yticks([1870, 1975, 2080])
axs[0].set_title(r"$\Delta R$", size=10)
axs[1].set_title(r"$\Delta \epsilon$", size=10)
axs[2].set_title(r"$\Delta F_1$", size=10)
axs[3].set_title(r"$\Delta F_2$", size=10)

## formatting
format_params_hov(axs)

plt.show()

### Change in annual-mean parameters over time

Plotting funcs

In [None]:
def plot_curve(ax, x, **plot_kwargs):
    """Plot 'x' against its "year" coordinate on 'ax' with linewidth 2.

    Extra keyword arguments go to ax.plot; its return value is returned.
    """
    years = x.year
    return ax.plot(years, x, lw=2, **plot_kwargs)


def format_pot_ax(ax):
    """Apply the shared formatting of the parameter-over-time panels:
    legend, ticks, axis labels, title and a thin zero line."""

    ## limits (left to matplotlib) and legend
    ax.set_xlim([None, None])
    ax.legend(prop=dict(size=6), loc="upper left")

    ## ticks
    ax.set_xticks([1870, 1975, 2080])
    ax.set_yticks([-0.5, 0, 0.5])

    ## text
    ax.set_ylabel(r"year$^{-1}$")
    ax.set_xlabel("Year")
    ax.set_title(r"$\Delta \left(\text{RO parameters}\right)$")

    ## zero reference line
    zero_line_style = dict(c="k", ls="-", lw=0.5)
    ax.axhline(0, **zero_line_style)

    return


def plot_pot0(ax, dp):
    """Plot cycle-mean change over time of the BJ index, wyrtki index and
    both normalized noise amplitudes on 'ax'. 'dp' is the change in params."""

    ## (parameter name, line style) in drawing/legend order
    curve_specs = (
        ("BJ_ac", dict(c="k", label=r"BJ")),
        ("wyrtki", dict(c="k", ls="--", label=r"wyrtki")),
        ("xi_T_norm", dict(label=r"$\xi_T/\sigma_T$", c="darkgray")),
        ("xi_h_norm", dict(label=r"$\xi_h/\sigma_h$", c="lightgray")),
    )

    for name, style in curve_specs:
        plot_curve(ax, dp[name].mean("cycle"), **style)

    ## set axis specs
    format_pot_ax(ax)

    return


def plot_pot1(ax, dp):
    """Plot cycle-mean change over time of R, epsilon, F1 and F2 on 'ax'.
    'dp' is the change in params."""

    ## (parameter name, line style) in drawing/legend order;
    ## R and epsilon are solid, F1 and F2 dashed
    curve_specs = (
        ("R", dict(label=r"$\overline{R}$")),
        ("epsilon", dict(label=r"$\varepsilon$")),
        ("F1", dict(label=r"$F_1$", ls="--")),
        ("F2", dict(label=r"$F_2$", ls="--")),
    )

    for name, style in curve_specs:
        plot_curve(ax, dp[name].mean("cycle"), **style)

    ## set axis specs
    format_pot_ax(ax)

    return

Make plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3))

## plot change over time of parameters
## (left: BJ / wyrtki / noise; right: R, epsilon, F1, F2)
plot_pot0(axs[0], dp=delta_params)
plot_pot1(axs[1], dp=delta_params)

## right panel: drop y ticks and y label
axs[1].set_yticks([])
axs[1].set_ylabel(None)

plt.show()

## Sensitivity tests

### Get perturbed params

In [None]:
def get_perturbed_multi(params, idxs, fix_others=False, fix_noise=False):
    """
    Return a copy of 'params' in which selected entries of the linear
    operator "Lac" are held at their value for the first year (index 0).

    Args:
        params: parameter set containing "Lac" with dims including "year",
            "ranky" and "rankx". It is deep-copied, not modified.
        idxs: iterable of (ranky, rankx) index pairs selecting entries of Lac.
        fix_others: if True, every entry *except* those in 'idxs' is fixed
            to its first-year value (the entries in 'idxs' keep their
            time-varying values). Otherwise, only the entries in 'idxs' are
            fixed.
        fix_noise: if True, the result is additionally passed through
            get_perturbed_noise(..., fix_others=False).
    """

    ## unfold indices

    ## initialize copy to hold perturbed parameters; put "year" first so the
    ## numpy indexing below is [year, ranky, rankx, ...]
    params_new = copy.deepcopy(params)
    params_new["Lac"] = params_new["Lac"].transpose("year", "ranky", "rankx", ...)

    ## get numpy version of linear operator
    Lac = params_new["Lac"].values

    if fix_others:

        ## get copies of 'perturbed' parameters (these stay time-varying)
        pparams = [copy.deepcopy(Lac[:, y_i, x_i]) for (y_i, x_i) in idxs]

        ## fix all parameters to initial value (broadcast year 0 over all years)
        Lac = Lac[:1] * np.ones_like(Lac)

        ## add back perturbed parameters
        for i, (y_i, x_i) in enumerate(idxs):
            Lac[:, y_i, x_i] = pparams[i]

    else:

        ## update Lac: overwrite the selected entries with their year-0 value
        for y_i, x_i in idxs:
            Lac[:, y_i, x_i] = Lac[:1, y_i, x_i]

    ## add back to parameters (ones_like keeps dims/coords; Lac supplies values)
    params_new["Lac"] = xr.ones_like(params_new["Lac"]) * Lac

    ## fix noise if necessary
    if fix_noise:
        params_new = get_perturbed_noise(params_new, fix_others=False)

    return params_new


def get_perturbed_R_cyc(params):
    """
    Return a copy of 'params' in which the Lac[ranky=0, rankx=0] entry (R)
    is rescaled, year by year, so that its mean over "cycle" equals the
    cycle-mean at year index 1. Each year's seasonal cycle is multiplied by
    a single factor, so its relative shape is preserved.

    NOTE(review): the reference is year index 1, whereas get_perturbed_multi
    uses index 0 as the "initial" value -- confirm this is intended.
    """

    ## copy parameters; order dims so that "year" is first and "cycle" is last
    pparams = copy.deepcopy(params)
    pparams["Lac"] = pparams["Lac"].transpose("year", "ranky", "rankx", ..., "cycle")

    ## get numpy version of linear operator
    Lac = pparams["Lac"].values

    ## update: scale each year's R by (cycle-mean at year index 1) / (its own cycle-mean)
    R = Lac[:, 0, 0]
    R_mean = R.mean(-1, keepdims=True)
    Lac[:, 0, 0] = R * R_mean[1] / R_mean

    ## update Lac
    pparams["Lac"] = xr.ones_like(pparams["Lac"]) * Lac

    return pparams


def get_perturbed_noise_helper(params, name):
    """get version of parameters where specified parameter
    is fixed, for all years, to its value at year index 1.

    'params' is deep-copied; the copy with variable 'name' overwritten is
    returned.

    NOTE(review): the reference is isel(year=1), not year 0, although this
    was originally described as the "starting value" -- confirm intended.
    """

    ## copy of params to hold perturbed values
    pparams = copy.deepcopy(params)

    ## get reference value of parameter and broadcast it to correct shape
    x0 = params[name].isel(year=1) * xr.ones_like(params[name])

    ## transpose dims to make sure they match
    x0 = x0.transpose(*params[name].dims)

    ## update parameters (assigning to .values keeps dims/coords unchanged)
    pparams[name].values = x0

    return pparams


def get_perturbed_noise(params, fix_others=False):
    """
    Return a perturbed copy of 'params' for the noise experiment.

    If 'fix_others' is True, the whole linear operator "Lac" is fixed to
    its first-year (index 0) value and the noise variables are left
    untouched. Otherwise "Lac" is left untouched and the noise variables
    "xi_stdac" and "xi_std" are fixed via get_perturbed_noise_helper.
    """

    if fix_others:

        ## copy parameters
        pparams1 = copy.deepcopy(params)

        ## get Lac, with "year" as the leading dimension
        pparams1["Lac"] = pparams1["Lac"].transpose("year", "ranky", "rankx", ...)

        ## get numpy version of linear operator
        Lac = pparams1["Lac"].values

        ## fix all parameters to initial value (broadcast year 0 over all years)
        Lac = Lac[:1] * np.ones_like(Lac)

        ## add back to parameters
        pparams1["Lac"] = xr.ones_like(pparams1["Lac"]) * Lac

    else:

        ## fix noise for ac version
        pparams0 = get_perturbed_noise_helper(params, "xi_stdac")

        ## fix annual mean value
        pparams1 = get_perturbed_noise_helper(pparams0, "xi_std")

    return pparams1


def get_perturbed_xi(params, ranky):
    """
    Return a copy of 'params' in which the noise variables "xi_std" and
    "xi_stdac" are fixed, for the given 'ranky' index, to their value at
    year index 1. (Despite the earlier description, R and epsilon are not
    modified here.)

    NOTE(review): reference is year index 1 (slice 1:2), not year 0 --
    confirm intended.
    """

    ## copy parameters so the input is not modified
    pparams = copy.deepcopy(params)

    ## loop over the two noise variables
    for n in ["xi_std", "xi_stdac"]:
        ## put "year" and "ranky" first so numpy indexing is [year, ranky, ...]
        xi = pparams[n].transpose("year", "ranky", ...)
        ## reference value; the 1:2 slice keeps the year axis for broadcasting
        xi0 = xi.values[1:2, ranky]

        ## update matrix (in-place write through .values)
        xi.values[:, ranky] = xi0
        pparams[n] = xi

    return pparams

In [None]:
## should we fix the given parameter? or all others?
FIX_OTHERS = True
FIX_NOISE = True

## build one parameter set per sensitivity experiment. In get_perturbed_multi,
## 'idxs' are (ranky, rankx) entries of the linear operator: the BJ
## experiment uses the two diagonal entries, the Wyrtki experiment the two
## off-diagonal entries. "test0"/"test1" repeat this with fix_noise=False.
pparam_kwargs = dict(params=fits, fix_others=FIX_OTHERS, fix_noise=FIX_NOISE)
param_set_dict = {
    "control": fits,
    "noise": get_perturbed_noise(fits, fix_others=FIX_OTHERS),
    "BJ": get_perturbed_multi(idxs=[(0, 0), (1, 1)], **pparam_kwargs),
    "Wyrtki": get_perturbed_multi(idxs=[(1, 0), (0, 1)], **pparam_kwargs),
    "test0": get_perturbed_multi(
        idxs=[(0, 0), (1, 1)], **dict(pparam_kwargs, fix_noise=False)
    ),
    "test1": get_perturbed_multi(
        idxs=[(0, 0), (1, 1), (1, 0)], **dict(pparam_kwargs, fix_noise=False)
    ),
}

# ## get list of param sets and labels
## (same order as the dict; "control" is first)
param_sets = list(param_set_dict.values())
labels = list(param_set_dict.keys())

### plot perturbed params

In [None]:
def plot_param_set(ax, params, model):
    """Plot the parameters of one experiment over time on 'ax'.

    Each parameter is averaged over "member" and "cycle" and drawn
    together with a thin dashed horizontal line at its first-year value.
    """

    ## get named params (annual mean)
    annual = model.get_RO_parameters(params).mean(["member", "cycle"])

    def _curve_with_baseline(name, **line_style):
        """draw one parameter plus a dashed line at its first-year value"""
        drawn = ax.plot(annual.year, annual[name], label=name, **line_style)
        ## reuse the explicit color if given, else whatever matplotlib picked
        color = line_style["c"] if "c" in line_style else drawn[0].get_color()
        ax.axhline(annual[name].isel(year=0), ls="--", c=color, lw=0.5)

    ## core params (colors from the default cycle)
    for name in ["R", "epsilon", "F1", "F2"]:
        _curve_with_baseline(name)

    ## noise (fixed gray tones)
    for name, gray in zip(["xi_T", "xi_h"], ["gray", "lightgray"]):
        _curve_with_baseline(name, c=gray)

    ## format ax
    ax.set_xticks([1870, 2070])

    return

In [None]:
## one panel per experiment (figure width scales with number of experiments)
fig, axs = plt.subplots(1, len(labels), figsize=(1.6 * len(labels), 1.5))
for ax, param_set, label in zip(axs, param_sets, labels):
    plot_param_set(ax, param_set, model=kwargs["model"])
    ax.set_title(label)

## formatting: only the first panel keeps its y ticks
for ax in axs[1:]:
    ax.set_yticks([])

## legend, positioned relative to the last panel's axes
axs[-1].legend(loc=(1.3, 0.1), prop=dict(size=8))

plt.show()

### Run simulations

In [None]:
## compute RO sigma over time for each experiment. The first parameter set
## ("control") is skipped here; RO_sigma_over_time_v2 was already computed
## from the same 'fits' and simulation specs.
exp_kwargs = dict(**simulation_kwargs, model=kwargs["model"])
RO_sigma_exp = [get_RO_sigma_over_time(params=p, **exp_kwargs) for p in param_sets[1:]]

### Compute stats

In [None]:
def get_stats(x):
    """Return the 0.1, 0.5 and 0.9 quantiles of 'x' across "member", with
    the quantile dimension renamed to "q" (used as plotting bounds)."""
    levels = [0.1, 0.5, 0.9]
    return x.quantile(q=levels, dim="member").rename({"quantile": "q"})

In [None]:
## specify sum idxs (or set to None): indices into 'stats_exp' whose
## departures from the control's first-year value are summed below
SUM_IDXS = None

## Get stats (quantiles across members) for control and each experiment
stats_control = get_stats(RO_sigma_over_time_v2)
stats_exp = [get_stats(x) for x in RO_sigma_exp]

## Get linearized sum of subset: baseline plus the sum of each selected
## experiment's departure from that baseline
## NOTE(review): x0 is a single variable (varnames[0]) while the entries of
## stats_exp are computed from full simulation outputs -- confirm that the
## arithmetic below broadcasts as intended.
if SUM_IDXS is not None:
    ## 'n' is not used within this cell
    n = len(stats_exp)
    x0 = stats_control[varnames[0]].isel(year=0)
    stats_exp.append(x0 + sum([stats_exp[i] - x0 for i in SUM_IDXS]))

    ## 'labels' also contains "control", so it is one longer than the
    ## experiments; equal lengths here mean the "sum" label is still missing
    if len(labels) == (len(stats_exp)):
        labels.append("sum")

### Plot results

In [None]:
def format_row(axs, y0, month=None):
    """Format one row (one calendar month) of the comparison plot.

    Args:
        axs: the axes making up the row.
        y0: baseline value for the row. Used for the y ticks of the first
            axis, the dashed reference line and the y limits of every axis.
        month: calendar month number (1-12) used to label the row on the
            right-hand side. If None, falls back to the module-level
            variable 'm', which is what existing callers rely on.
    """

    ## backward-compatible fallback to the enclosing loop variable
    if month is None:
        month = m

    ## y ticks at the baseline and 0.5 above it (first axis only)
    axs[0].set_yticks([np.round(y0, 1), np.round(y0 + 0.5, 1)])

    ## y labels: cleared everywhere except the last axis, which carries
    ## the abbreviated month name on its right-hand side
    for ax in axs[1:]:
        ax.set_ylabel(None)
    axs[-1].set_ylabel(calendar.month_abbr[month])
    axs[-1].yaxis.set_label_position("right")

    ## reference line and shared y range around the baseline
    for ax in axs:
        ax.axhline(y0, c="gray", lw=1, ls="--")
        ax.set_ylim([y0 - 0.3, y0 + 0.7])

    return


def format_subplots(axs, labels):
    """Tidy a 2-D grid of axes: x ticks/labels only on the bottom row,
    y ticks only in the first column, and one title per column (taken from
    'labels') on the top row."""

    ## every row except the last: no x ticks or x label
    for panel in axs[:-1].flatten():
        panel.set_xticks([])
        panel.set_xlabel(None)

    ## every column except the first: no y ticks
    for panel in axs[:, 1:].flatten():
        panel.set_yticks([])

    ## column titles on the top row
    for col, panel in enumerate(axs[0]):
        panel.set_title(labels[col])

    return

In [None]:
## plot results: rows are months (Feb, May, Aug, Nov), columns are experiments
n = len(stats_exp)
fig, axs = plt.subplots(4, n, figsize=(n * 2, 5), layout="constrained")

for i, m in enumerate([2, 5, 8, 11]):

    ## loop thru experiments
    for j in range(n):
        ## NOTE(review): 'stats_exp_' is assigned but never used
        stats_exp_ = RO_sigma_exp
        ## control (black) vs experiment j; labels[0] is "control", so
        ## experiment j's label is labels[j + 1]
        plot_stats_comp(
            axs[i, j],
            [stats_control.sel(month=m), stats_exp[j].sel(month=m)],
            labels=[labels[0], labels[j + 1]],
            colors=["k", sns.color_palette()[j]],
        )

    ## get baseline value: control median (q index 1) at the first year
    sigma0 = stats_control.sel(month=m)[varnames[0]].isel(q=1, year=0).values.item()

    ## format (format_row also reads the loop variable 'm' as a global)
    format_row(axs=axs[i], y0=sigma0)

## format all subplots
format_subplots(axs, labels=labels[1:])

plt.show()

## Look at MLD scaling

In [None]:
## load mixed layer depth EOFs (file "cesm/eofs_mlotst.nc" under DATA_FP)
MLD_EOFS = src.utils.load_eofs(pathlib.Path(DATA_FP, "cesm", "eofs_mlotst.nc"))


def get_H(t_bnds):
    """Get monthly-climatology MLD for the time-index bounds 't_bnds'.

    't_bnds' is unpacked into slice(), so it holds (start, stop[, step])
    positions along "time". EOF scores in that range are averaged over
    members and then by calendar month, and the field is reconstructed
    with src.utils.reconstruct_fn using fn=src.utils.get_nino34
    (presumably a Niño 3.4 average). The "month" dimension of the result
    is renamed to "cycle".
    """

    ## member-mean scores within the requested window
    window = slice(*t_bnds)
    member_mean_scores = MLD_EOFS.scores().isel(time=window).mean("member")

    ## monthly climatology of the scores
    monthly_clim = member_mean_scores.groupby("time.month").mean()

    ## reconstruct MLD from climatological scores
    nino_mld = src.utils.reconstruct_fn(
        scores=monthly_clim,
        components=MLD_EOFS.components(),
        fn=src.utils.get_nino34,
    )

    return nino_mld.rename({"month": "cycle"})


## get forced component (ensemble mean of the EOF scores)
scores_mean = MLD_EOFS.scores().mean("member")

## compute MLD (last timestep is dropped before reconstruction)
H = src.utils.reconstruct_fn(
    scores=scores_mean.isel(time=slice(None, -1)),
    components=MLD_EOFS.components(),
    fn=src.utils.get_nino34,
)

## sep month/year
H = src.utils.unstack_month_and_year(H)

## smooth with a centered 31-year running mean, then drop the 15 years at
## each end, where the centered window is incomplete
H = H.rolling({"year": 31}, center=True).mean()
H = H.isel(year=slice(15, -15))

## get inverse MLD
beta = 1 / H

In [None]:
## specify months: positional slice applied with isel below, so slice(4, 5)
## selects the single entry at zero-based position 4 (presumably May --
## confirm against the month/cycle coordinates)
month_range = slice(4, 5)

## get F2, eps and R from the linear operator (ensemble mean); F2 and eps
## are defined with a minus sign relative to the operator entries
F2 = -fits.Lac.isel(rankx=0, ranky=1).mean("member")
eps = -fits.Lac.isel(rankx=1, ranky=1).mean("member")
R = fits.Lac.isel(rankx=0, ranky=0).mean("member")

## subset for months (same month range for every quantity)
beta_ = beta.isel(month=month_range).mean("month")
F2_ = F2.isel(cycle=month_range).mean("cycle")
eps_ = eps.isel(cycle=month_range).mean("cycle")
R_ = R.isel(cycle=month_range).mean("cycle")

## get fractional change, relative to the entry at year index 1
get_delta = lambda x: (x - x.isel(year=1)) / x.isel(year=1)
delta_beta = get_delta(beta_)
delta_F2 = get_delta(F2_)
delta_eps = get_delta(eps_)
delta_R = get_delta(R_)

In [None]:
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

## plot data: fractional change in inverse MLD (thick black) vs fractional
## change in F2 and epsilon
ax.plot(
    delta_beta.year, delta_beta, label=r"$\frac{\Delta H^{-1}}{H^{-1}}$", c="k", lw=2
)
ax.plot(delta_F2.year, delta_F2, label=r"$\frac{\Delta F_2}{F_2}$")
ax.plot(delta_eps.year, delta_eps, label=r"$\frac{\Delta \varepsilon}{\varepsilon}$")
## (R is currently left out of the plot)
# ax.plot(delta_R.year, delta_R*.25, label=r"$\frac{\Delta R}{R}$")

## label
ax.set_ylabel("Fractional change")


ax.legend()
plt.show()

In [None]:
# ## get Niño 3.4 mean
# Th_forced = xr.open_dataset(DATA_FP / "cesm" / "Th_emean.nc")
# T34 = Th_forced["T_34"].groupby("time.year").mean().isel(year=slice(None, -1))
# T34 = T34.rolling({"year": 15}, center=True).mean().isel(year=slice(7, -7))

# n_year = len(T34.year)
# sa = 10 * np.ones(n_year)
# p = 0 * np.ones(n_year)
# alpha = gsw.alpha(SA=sa, CT=T34.values, p=p)
# dalpha = (alpha - alpha[0]) / alpha[0]

## Floquet analysis

In [None]:
def get_L(fit, cycle_shift=np.arange(0, 1, 1 / 12)):
    """Reconstruct the seasonally varying linear operator from the harmonic
    coefficients stored in fit.Lcoef.

    Args:
        fit: fit with "ranky", "rankx" and "ac_rank" coordinates and an
            "Lcoef" variable indexed by (ranky, rankx, cossin). Along
            "cossin", indices 0..ac_order are cosine coefficients (index 0
            being the constant term) and the rest are sine coefficients
            for harmonics 1..ac_order.
        cycle_shift: positions within the cycle (as a fraction in [0, 1))
            at which to evaluate the operator.

    Returns:
        DataArray with dims (ranky, rankx, cycle). The "cycle" coordinate
        is 'cycle_shift' offset by half of cycle_shift[1].
    """

    n_harmonics = len(fit.ac_rank) - 1
    n_terms = 2 * n_harmonics + 1
    n_rows, n_cols = len(fit.ranky), len(fit.rankx)
    omega = 2 * np.pi

    def basis(term):
        """harmonic basis function for the given index along 'cossin'"""
        if term <= n_harmonics:
            return np.cos(term * omega * cycle_shift)
        return np.sin((term - n_harmonics) * omega * cycle_shift)

    ## accumulate coefficient * basis for every operator entry
    operator = np.zeros((n_rows, n_cols, len(cycle_shift)))
    for row in range(n_rows):
        for col in range(n_cols):
            for term in range(n_terms):
                coef = fit.Lcoef.isel(dict(rankx=col, ranky=row, cossin=term)).values
                operator[row, col, :] = operator[row, col, :] + coef * basis(term)

    ## put in xarray
    cycle_coord = cycle_shift + cycle_shift[1] / 2
    return xr.DataArray(
        operator,
        coords=dict(ranky=fit.ranky, rankx=fit.rankx, cycle=cycle_coord),
        dims=["ranky", "rankx", "cycle"],
    )


def integrate(L, x0, t0, tf, dt=1 / 2 * 1 / 365, save_hist=False, verbose=False):
    """Integrate the linear ODE dx/dt = L(t) x with forward-Euler steps.

    Args:
        L: function of time returning the (n, n) linear operator at time t.
        x0: initial condition; its leading dimension n sets the size of the
            system (an (n, n) matrix integrates n initial conditions at once).
        t0, tf: start and end time; steps are taken at np.arange(t0, tf, dt).
        dt: timestep.
        save_hist: if True, also keep the state after every step.
        verbose: if True, show a tqdm progress bar.

    Returns:
        The final state if 'save_hist' is False. Otherwise a tuple
        (history, time), where history is the states after each step
        concatenated along axis 1 (so states must be at least 2-D).
    """

    ## initialize and get time to integrate over
    x = x0
    time = np.arange(t0, tf, dt)

    ## identity matching the size of the state (not hard-coded to 2x2)
    I = np.eye(np.shape(x0)[0])

    ## empty list to hold results if desired
    x_hist = [] if save_hist else None

    ## only wrap in a progress bar when one is requested
    steps = tqdm.tqdm(time) if verbose else time

    ## integrate: x_{k+1} = (I + L(t_k) dt) x_k
    for t in steps:
        x = (I + L(t) * dt) @ x

        ## save results if specified
        if save_hist:
            x_hist.append(x)

    if save_hist:
        return np.concatenate(x_hist, axis=1), time

    return x


def get_monodromy(L, dt=1 / 2 * 1 / 365, verbose=False):
    """Construct the monodromy matrix of the time-dependent linear operator
    'L' by integrating the 2x2 identity from t=0 to t=1 (one cycle)."""

    ## columns of the identity are the initial conditions; integrating them
    ## over one cycle gives the monodromy matrix
    return integrate(L=L, x0=np.eye(2), t0=0, tf=1, dt=dt, verbose=verbose)


def get_timescales(fit, dt=1 / 365 * 1 / 2, verbose=False):
    """Get e-folding timescale and period for given RO fit (Floquet analysis).

    The seasonally varying operator is rebuilt with get_L at resolution
    'dt' and integrated over one cycle to get the monodromy matrix. With w
    the leading eigenvalue (largest modulus) of that matrix and
    gamma = log(w), the growth rate is Re(gamma) and the angular frequency
    is |Im(gamma)|.

    Args:
        fit: RO fit for a single period (as accepted by get_L).
        dt: timestep, as a fraction of one cycle.
        verbose: if True, show a progress bar during integration.

    Returns:
        (efold, period), in units of one cycle: efold = -1 / Re(gamma) and
        period = 2*pi / |Im(gamma)|. The period is inf when the leading
        eigenvalue has no oscillatory part.
    """

    ## explicit import: the file-level imports only name scipy.stats
    import scipy.linalg

    ## get cyclostationary operator
    L_cyc = get_L(fit, cycle_shift=np.arange(0, 1, dt))

    def L_(t):
        """numpy version of the operator at the grid point nearest to t"""
        return L_cyc.sel(cycle=t, method="nearest").transpose("ranky", "rankx").values

    ## Get monodromy mat
    M = get_monodromy(L=L_, dt=dt, verbose=verbose)

    ## eigenvalues of the monodromy matrix (Floquet multipliers)
    w = np.asarray(scipy.linalg.eigvals(M), dtype=complex)

    ## eigenvalues come back in no particular order, so pick the leading
    ## (least damped) one explicitly. For a complex-conjugate pair either
    ## member gives the same growth rate and |frequency|.
    w_lead = w[np.argmax(np.abs(w))]

    ## get eigenvalue of linear operator (Floquet exponent)
    gamma = np.log(w_lead)
    sigma = gamma.real
    omega = np.abs(gamma.imag)

    ## compute timescales
    efold = -1 / sigma
    period = 2 * np.pi / omega if omega > 0 else np.inf

    return efold, period

In [None]:
## check L reconstruction works: compare the rebuilt operator with the
## stored "Lac" for the second period. The boolean is not assigned or
## asserted, so it is only visible as the cell's output.
np.allclose(get_L(fits.isel(year=1)), fits.isel(year=1).Lac)

In [None]:
## years to loop thru (the first entry of fits.year is skipped)
years = fits.year.values[1:]

## empty list to hold results
efolds = []
periods = []

## Floquet analysis for each period (squeeze drops singleton dims)
for y in tqdm.tqdm(years):
    efold, period = get_timescales(fits.sel(year=y).squeeze())
    efolds.append(efold)
    periods.append(period)

## convert to arrays for plotting
efolds = np.array(efolds)
periods = np.array(periods)

In [None]:
def format_xaxis(ax):
    """Label the x axis "Time", set ticks at 1880/2010/2080 and draw a
    dashed gray vertical guide line at 2010."""
    marker_year = 2010

    ax.set_xlabel("Time")
    ax.set_xticks([1880, marker_year, 2080])
    ax.axvline(marker_year, ls="--", c="gray", lw=0.8)

    return


## specify colors (colorblind palette, skipping its first entry)
colors = sns.color_palette("colorblind")[1:]
# colors = [colors[i] for i in [0,2]]


### Plot 1: same scale (e-fold and period share one y axis)
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

ax.plot(years, efolds, label=r"$e$-fold", c=colors[0])
ax.plot(years, periods, label="period", c=colors[1])

## formatting
ax.set_ylim([0, None])
ax.set_yticks([0, 2, 4])
ax.set_ylabel("year")
ax.legend(prop=dict(size=10))
format_xaxis(ax)

plt.show()


### Plot 2: different scales (e-fold on left axis, period on twin right axis)
fig, ax = plt.subplots(figsize=(3.5, 2.5), layout="constrained")

ax.plot(years, efolds, label=r"$e$-fold", c=colors[0])
ax.set_yticks([1.4, 1.6, 1.8])
ax.set_ylabel(r"$e$-fold (yr)", color=colors[0])

ax2 = ax.twinx()
ax2.plot(years, periods, c=colors[1])
ax2.set_yticks([3.3, 3.9, 4.5])
ax2.set_ylabel("period (yr)", color=colors[1])

## scale axes (limits are hard-coded for each axis)
ax.set_ylim([1.3, 1.9])
ax2.set_ylim([3, 4.8])

format_xaxis(ax)


plt.show()

## Snapshots of seasonal changes, with variance

In [None]:
def plot_mean_and_bounds(ax, x, show_bounds=True, **plot_kwargs):
    """Plot the seasonal cycle of 'x' on 'ax'.

    'x' needs a "cycle" coordinate and a "posn" dimension with entries
    "center", "upper" and "lower". The center is drawn with linewidth 2
    (plus any 'plot_kwargs'); if 'show_bounds' is True, the upper and
    lower bounds are added as thin dashed lines in the same color.

    Returns the output of ax.plot for the center line.
    """

    cycle = x.cycle

    ## center line
    center_line = ax.plot(cycle, x.sel(posn="center"), lw=2, **plot_kwargs)

    if not show_bounds:
        return center_line

    ## bounds, matching the center line's color
    bound_style = dict(c=center_line[0].get_color(), ls="--", lw=0.5)
    for edge in ("upper", "lower"):
        ax.plot(cycle, x.sel(posn=edge), **bound_style)

    return center_line


def format_ax_and_twin(ax, ax_twin, title=None):
    """Format a panel and its twin y axis.

    Args:
        ax: main axis (title, month ticks, y range [-3, 4.7]).
        ax_twin: twin axis (y range [0, 2.2]).
        title: title for 'ax'. If None, falls back to the module-level
            variable 'y', which is what existing callers rely on.

    Y ticks are removed from both axes; callers add them back where wanted.
    """

    ## backward-compatible fallback to the enclosing loop variable
    if title is None:
        title = y

    ax.set_title(title)
    ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
    ax.set_ylim([-3, 4.7])
    ax_twin.set_ylim([0, 2.2])
    ax.set_yticks([])
    ax_twin.set_yticks([])

    return

3 panels: BJ, noise, variance

In [None]:
## display the years for which parameters are available (cell output only)
params.year

In [None]:
## baseline stats: parameters and T std at the year nearest 1871
## ("month" is renamed to "cycle" so both can be plotted by the same helper)
stats_baseline = src.utils.get_ensemble_stats(params.sel(year=1871, method="nearest"))
var_baseline = src.utils.get_ensemble_stats(
    Th_std.sel(year=1871, method="nearest")
).rename({"month": "cycle"})

## change in params between years (one panel per year)
fig, axs = plt.subplots(1, 3, figsize=(7, 7 / 3))

## colors for different years
## NOTE(review): 'c' is unpacked in the loop below but never used; the
## curves are drawn in black/red with per-year alpha instead.
colors = sns.color_palette("mako")[::2]
years = [1871, 1971, 2071]
alphas = [1 / 4, 1, 1]

axs_twin = []
for j, (y, c, a) in enumerate(zip(years, colors, alphas)):

    ## twin axis for plotting variance
    axs_twin.append(axs[j].twinx())

    ## Get stats for year
    stats = src.utils.get_ensemble_stats(params.sel(year=y, method="nearest"))
    var_stats = src.utils.get_ensemble_stats(
        Th_std.sel(year=y, method="nearest")
    ).rename({"month": "cycle"})

    ## plot baselines (faint, no bounds) for reference in every panel
    kwargs = dict(show_bounds=False, alpha=1 / 3)
    plot_mean_and_bounds(axs[j], x=stats_baseline["BJ_ac"], c="k", **kwargs)
    plot_mean_and_bounds(
        axs[j], x=stats_baseline["xi_T_norm"], c="k", ls="--", **kwargs
    )
    plot_mean_and_bounds(axs_twin[j], x=var_baseline["T_34"], c="r", **kwargs)

    ## plot stats for new period
    plot_mean_and_bounds(axs[j], x=stats["BJ_ac"], c="k", alpha=a, label=r"$BJ$")
    plot_mean_and_bounds(
        axs[j], x=stats["xi_T_norm"], ls="--", c="k", alpha=a, label=r"$\xi_T/\sigma_T$"
    )
    plot_mean_and_bounds(
        axs_twin[j], x=var_stats["T_34"], c="r", alpha=a, label=r"$\sigma(T)$"
    )

    ## format axes (format_ax_and_twin reads the loop variable 'y' for the title)
    format_ax_and_twin(axs[j], axs_twin[j])

## legend on the last panel; y ticks/labels only on the outer edges
axs[-1].legend(prop=dict(size=6), loc="upper right")
axs[0].set_yticks([-2, 0, 2, 4])
axs[0].set_ylabel(r"year$^{-1}$")
axs_twin[-1].set_yticks([0, 1], labels=[0, 1], color="r")
axs_twin[-1].set_ylabel(r"$\sigma_T$ ($^{\circ}$C)", color="r")

plt.show()

## Growth rate and noise vs variance
(need ensemble of parameters to do this...)

Function to format plot

In [None]:
def format_ax(ax):
    """Add thin zero guidelines, ticks and axis labels to a scatter panel.

    Returns None (the axis is modified in place).
    """
    guide_style = dict(c="k", lw=0.5)

    ## guidelines through the origin
    ax.axvline(0, **guide_style)
    ax.axhline(0, **guide_style)

    ## ticks
    ax.set_xticks([-0.4, 0, 0.4, 0.8])
    ax.set_yticks([0, 35, 70])

    ## labels
    ax.set_xlabel(r"$\Delta ~\overline{BJ}$ (year$^{-1}$)")
    ax.set_ylabel(r"$\Delta \sigma_T$ (%)")
    return

Compute the relative (percentage) change in standard deviation, normalizing each ensemble member by its own value in the first year.

In [None]:
## baseline is the first-year value of each ensemble member (no ensemble
## mean is taken), so each member is normalized separately.
## Note: this rebinds the name 'baseline'.
baseline = Th_std.isel(year=0)
delta_Th_std_relative = 100 * (Th_std - baseline) / baseline

Make plot

In [None]:
## get data for plot
x0 = delta_params["BJ_ac"].sel(year=2071).mean("cycle")
x1 = delta_params["xi_T"].sel(year=2071).mean("cycle")
y = delta_Th_std_relative["T_34"].sel(year=2071).mean("month")

## compute correlation
corr0, pval0 = scipy.stats.pearsonr(x0.values, y.values)
corr1, pval0 = scipy.stats.pearsonr(x1.values, y.values)

## set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2))

## scatter plot preferences
scatter_kwargs = dict(zorder=10, s=10)

## plot data (BJ) and label
axs[0].scatter(x0, y, **scatter_kwargs)
axs[0].set_title(f"$r=${corr0:.2f}")
axs[0] = format_ax(axs[0])

## plot data (noise) and label
axs[1].scatter(x1, y, **scatter_kwargs)
axs[1].set_title(f"$r=${corr1:.2f}")

## label 2nd axis
format_ax(axs[1])
axs[1].set_yticks([])
axs[1].set_ylabel(None)
axs[1].set_xlabel(r"$\Delta\xi_T$ (K year$^{-1}$)")

plt.show()

## Supplementary

#### Parameter estimates from individual ensemble members

In [None]:
## param to plot
fig, axs = plt.subplots(2, 2, figsize=(6, 6), layout="constrained")

## parameters to plot (one panel each) and their titles
param_names = ["R", "xi_T_norm", "epsilon", "xi_h_norm"]
titles = [r"$R$", r"$\xi_T$", r"$\varepsilon$", r"$\xi_h$"]

for p, title, ax in zip(param_names, titles, axs.flatten()):

    ax.set_title(title)

    ## plot ensemble members
    for i, m in enumerate(params.member.values):

        ## label only the first member's line, so that a legend (if one is
        ## added) shows a single "Ensemble members" entry
        label = "Ensemble members" if (i == 0) else None

        ax.plot(
            params.year,
            delta_params[p].sel(member=m).mean("cycle"),
            c="gray",
            alpha=0.5,
            lw=0.5,
            label=label,
        )

    ## plot ensemble mean
    ax.plot(
        params.year,
        delta_params[p].mean(["member", "cycle"]),
        c="k",
        lw=2,
        label="Ensemble mean",
    )

    ## set axis specs
    ax.set_ylim([-2, 2])
    ax.axhline(0, c="k", ls="--", lw=1)
    ax.set_xticks([])

## format: x ticks/labels on the bottom row, y labels on the left column
for ax in axs[1, :]:
    ax.set_xlabel("Time")
    ax.set_xticks([1870, 1975, 2080])
for ax in axs[:, 0]:
    ax.set_ylabel(r"Growth rate (yr$^{-1}$)")

plt.show()