# Recharge oscillator (RO) change over time

## Imports

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings

## plotting defaults: seaborn theme with a white background and no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution (dots per inch)
mpl.rcParams.update({"figure.dpi": 100})

## data and output directories are read from environment variables
## (a KeyError is raised if either one is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def get_ensemble_fits_over_time(
    data,
    model,
    ac_mask_idx,
    mask_NT,
    window_size=480,
    step_size=60,
    by_ensemble_member=True,
    min_frac=0.7,
):
    """Get RO fits as a function of time using a sliding window.

    Args:
        data: dataset with a ``time`` dimension holding the variables to fit.
        model: XRO model instance. Used through ``model.fit_matrix`` for a
            pooled fit, or passed to ``src.utils.get_RO_ensemble`` for
            per-member fits.
        ac_mask_idx: passed through to the fitting routine.
        mask_NT: passed to ``model.fit_matrix`` as ``maskNT``. Only used when
            ``by_ensemble_member`` is False.
        window_size: size of sliding window (units: months).
        step_size: how many months to slide the window between each calculation.
        by_ensemble_member: if True, fit each ensemble member separately via
            ``src.utils.get_RO_ensemble``. Otherwise fit a single model to each
            window via ``model.fit_matrix``.
        min_frac: a window is used only if more than ``min_frac * window_size``
            timesteps remain between its start and the end of the record.
            Windows near the end of the record are truncated.

    Returns:
        Fits concatenated along a ``time`` dimension whose values are the
        start date of each window. The ``X``, ``Y``, ``Yfit`` and ``time``
        variables of each fit are dropped.

    Raises:
        ValueError: if no window satisfies the ``min_frac`` criterion.
    """
    n = len(data.time)

    ## candidate window starts, keeping only those with enough samples for a
    ## robust estimate (filtered up front so the progress bar count is exact)
    starts = [
        i
        for i in np.arange(0, n - step_size, step_size)
        if (n - i) > (min_frac * window_size)
    ]
    if not starts:
        raise ValueError(
            f"No window of size {window_size} months (min fraction {min_frac}) "
            f"fits in a record of {n} timesteps with step size {step_size}."
        )

    fits_by_year = []
    start_dates = []
    for i in tqdm.tqdm(starts):

        ## subset of data for fitting the model, and its start date
        data_subset = data.isel(time=slice(i, i + window_size))
        start_dates.append(data_subset.time.isel(time=0))

        ## get parameter fit
        if by_ensemble_member:
            _, fits = src.utils.get_RO_ensemble(
                data_subset,
                model=model,
                ac_mask_idx=ac_mask_idx,
            )
        else:
            ## silence warnings raised during the fit
            ## (`action=` requires Python 3.11+)
            with warnings.catch_warnings(action="ignore"):
                fits = model.fit_matrix(
                    data_subset,
                    ac_mask_idx=ac_mask_idx,
                    maskNT=mask_NT,
                )

        ## drop X, Y, Yfit and time variables, which are not needed downstream
        fits_by_year.append(fits.drop_vars(["X", "Y", "Yfit", "time"]))

    ## convert from lists to xarray, indexed by window start date
    start_dates = xr.concat(start_dates, dim="time")
    return xr.concat(fits_by_year, dim=start_dates)


def update_time_coord(data, window_size=480, end_year=2080):
    """Relabel window start times with the center year of each window.

    Args:
        data: object with a datetime ``time`` coordinate giving the start of
            each window (e.g. output of ``get_ensemble_fits_over_time``).
        window_size: window length in months. Start years are shifted by half
            the window, converted to whole years and truncated toward zero.
        end_year: drop entries whose center year is later than this. Pass
            ``None`` to keep every entry. The default preserves the original
            trim at 2080.

    Returns:
        ``data`` with ``time`` replaced by an integer ``year`` coordinate.
    """
    ## year at the start of each window (copied so ``data`` is untouched)
    year = data.time.dt.year.values.copy()

    ## shift to the midpoint of the window
    year += int(window_size / 12 / 2)

    ## replace and rename the coordinate
    data = data.assign_coords({"time": year}).rename({"time": "year"})

    ## optionally trim to end at `end_year`
    if end_year is not None:
        data = data.sel(year=slice(None, end_year))

    return data


def get_params(fits, model):
    """Return RO parameters from ``fits`` plus normalized noise amplitudes.

    Parameters come from ``model.get_RO_parameters(fits)``. Rows 0 and 1 of
    ``fits["normxi_stdac"]`` are then added as ``xi_T_norm`` and ``xi_h_norm``,
    relabelled with the ``cycle`` coordinate of the parameter set.
    """
    params = model.get_RO_parameters(fits)
    cycle = params.cycle

    for key, row in (("xi_T_norm", 0), ("xi_h_norm", 1)):
        noise = fits["normxi_stdac"].isel(ranky=row)
        params[key] = noise.assign_coords({"cycle": cycle})

    return params


def get_rolling_std(data, n=20):
    """Rolling standard deviation over a window of 2n+1 years.

    The std (``np.std``) is computed with ``src.utils.get_rolling_fn_bymonth``
    over a window of 2n+1 years centered on each year, to enlarge the sample
    used for each estimate. The ensemble ``member`` dimension is kept in the
    result (``reduce_ensemble_dim=False``). The result is then unstacked into
    separate year and month dimensions.
    """
    rolling_std = src.utils.get_rolling_fn_bymonth(
        data, fn=np.std, n=n, reduce_ensemble_dim=False
    )
    return src.utils.unstack_month_and_year(rolling_std)

## Load data

In [None]:
## load indices with src.utils.load_cesm_indices and shorten their names
Th = src.utils.load_cesm_indices().rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## divide each variable by its own standard deviation (for convenience)
Th /= Th.std()

## Compute RO parameters' change over time

### Version 2: single RO fit per sliding window (pooling all MPI ensemble members)

In [None]:
## RO state variables (T first, h second)
varnames = ["T_34", "h_w"]

## settings for the sliding-window fit; `kwargs["model"]` is reused later
kwargs = dict(
    model=src.XRO.XRO(ncycle=12, ac_order=3, is_forward=True),
    ac_mask_idx=None,
    window_size=480,
    step_size=120,
    mask_NT=[],
)

## one pooled fit per window (all ensemble members together)
fits_v2 = get_ensemble_fits_over_time(
    Th[varnames], by_ensemble_member=False, **kwargs
)

## add a length-1 "member" dim to match the rest of the script, then label
## each window by its center year
fits_v2 = fits_v2.expand_dims("member")
fits_v2 = update_time_coord(fits_v2, window_size=kwargs["window_size"])

## RO parameters for each window
params_v2 = get_params(fits=fits_v2, model=kwargs["model"])

## change relative to a reference window
## NOTE(review): the reference is isel(year=1), i.e. the SECOND window, not the
## first; the same index is used elsewhere in this notebook -- confirm intent.
delta_params_v2 = params_v2 - params_v2.isel(year=1)

## Generate an ensemble of RO simulations from the fitted parameters

In [None]:
def get_RO_sigma(model, params, **simulation_kwargs):
    """Simulate the RO and return the standard deviation for each calendar month.

    ``params`` is averaged over ``member`` and passed to ``model.simulate`` as
    ``fit_ds``, along with ``simulation_kwargs``. The simulated output is
    grouped by ``time.month`` and the std is taken within each group.
    """
    mean_params = params.mean("member")
    simulated = model.simulate(fit_ds=mean_params, **simulation_kwargs)
    by_month = simulated.groupby("time.month")
    return by_month.std()


def get_RO_sigma_over_time(model, params, **simulation_kwargs):
    """Run ``get_RO_sigma`` for every year in ``params`` and stack the results.

    Each year's parameters (``params.sel(year=...)``) are simulated
    independently. The outputs are concatenated along ``params.year``.
    """
    per_year = [
        get_RO_sigma(model=model, params=params.sel(year=yr), **simulation_kwargs)
        for yr in tqdm.tqdm(params.year)
    ]
    return xr.concat(per_year, dim=params.year)

In [None]:
## settings forwarded to model.simulate for every year
simulation_kwargs = dict(
    nyear=40,
    ncopy=50,
    seed=1000,
    X0_ds=Th[varnames].isel(member=0, time=0),  # first timestep of member 0
    noise_type="white",
)

## monthly std of RO simulations, one set per fitted window (pooled fit)
RO_sigma_over_time_v2 = get_RO_sigma_over_time(
    kwargs["model"], fits_v2, **simulation_kwargs
)

## compute change in $\sigma(\text{Niño 3.4})$

In [None]:
## std of every index over a (2 * 20 + 1)-year rolling window
Th_std = get_rolling_std(Th, n=20)

## percent change relative to the first year's ensemble-mean std
baseline = Th_std.isel(year=0).mean("member")
delta_Th_std = 100 * (Th_std - baseline) / baseline

### Compare MPI and RO

Function to plot results

In [None]:
def plot_stats_comp(ax, list_of_stats, labels, colors=None, n=varnames[0]):
    """Compare the standard deviation over time for several datasets on ``ax``.

    Args:
        ax: matplotlib axis to draw on.
        list_of_stats: datasets with a ``year`` coordinate and a ``q``
            (quantile) dimension that includes 0.5.
        labels: legend label for each dataset.
        colors: line color for each dataset. Defaults to the leading entries
            of the seaborn palette.
        n: variable to plot (default: ``varnames[0]`` at definition time).

    The median is drawn as a thick line and every other quantile as a thin line
    of the same color. Datasets, labels and colors are paired with ``zip``, so
    surplus entries in any of the three are ignored.
    """
    if colors is None:
        colors = sns.color_palette()[: len(list_of_stats)]

    for stats, label, color in zip(list_of_stats, labels, colors):
        series = stats[n]
        median_lines = ax.plot(
            stats.year, series.sel(q=0.5), lw=2.5, label=label, c=color
        )
        thin_style = dict(c=median_lines[0].get_color(), lw=0.8)
        for q in stats.q:
            if q == 0.5:
                continue
            ax.plot(stats.year, series.sel(q=q), **thin_style)

    ## axis labels, legend, limits and ticks
    ax.set_xlabel("Year")
    ax.set_ylabel(r"$\sigma_T$ ($^{\circ}$C)")
    ax.legend(prop=dict(size=8))
    ax.set_ylim([0, None])
    ax.set_xticks([1870, 1975, 2080])
    ax.set_yticks([0, 0.6, 1.2])

Make the plot

In [None]:
## reduce over months by selecting month 5
def sel_fn(x):
    return x.sel(month=5)


## 10th/50th/90th percentiles across members, with "quantile" renamed to "q"
def get_stats(x):
    quantiles = sel_fn(x).quantile(q=[0.1, 0.5, 0.9], dim="member")
    return quantiles.rename({"quantile": "q"})


## percentiles for MPI and for the RO simulations
stats_mpi = get_stats(Th_std)
stats_ro_v2 = get_stats(RO_sigma_over_time_v2)

## MPI (black) vs RO on a single axis
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
plot_stats_comp(
    ax,
    [stats_mpi, stats_ro_v2],
    labels=["MPI", "RO"],
    colors=["k", sns.color_palette()[1]],
    n=varnames[0],
)
plt.show()

## Plot diagnostics

### Hovmöller plots of changes in standard deviation, growth rate, and noise

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(8, 3.5), layout="constrained")

#### panel 0: % change in std dev (ensemble mean), by month and year
## ("cmo.balance" is presumably registered by the cmocean import)
plot_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(30, 6), extend="both"
)

## plot data
cp0 = axs[0].contourf(
    delta_Th_std.month,
    delta_Th_std.year,
    delta_Th_std[varnames[0]].mean("member"),
    **plot_kwargs
)

##### panels 1-3: change in model params, all on the same contour levels

## specify plotting specs (shared by cp1, cp2 and cp3)
plot_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(2, 0.2), extend="both"
)

## plot data: change in BJ, half the change in R, change in normalized noise
cp1 = axs[1].contourf(
    params_v2.cycle,
    params_v2.year,
    delta_params_v2["BJ_ac"].mean("member"),
    **plot_kwargs
)
cp2 = axs[2].contourf(
    params_v2.cycle,
    params_v2.year,
    0.5 * delta_params_v2["R"].mean("member"),
    **plot_kwargs
)
cp3 = axs[3].contourf(
    params_v2.cycle,
    params_v2.year,
    delta_params_v2["xi_T_norm"].mean("member"),
    **plot_kwargs
)

## month ticks, plus dashed white guidelines at July (x) and 2025 (y)
for ax in axs:
    ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
    ax.axvline(7, c="w", ls="--", lw=1, alpha=0.8)
    ax.set_xlabel("Month")
    ax.axhline(2025, c="w", ls="--", lw=1)

## colorbars: cb0 for panel 0; cb1 is built from cp3 but, since panels 1-3
## share levels, its scale applies to all three.
## NOTE(review): cb1 is labelled "growth rate" although cp3 is the noise panel.
cb0 = fig.colorbar(cp0, label=r"% change", ticks=[-30, 0, 30])
cb1 = fig.colorbar(cp3, label=r"$\Delta$ Growth rate (yr$^{-1}$)", ticks=[-1, 0, 1])

## label; panels 1-3 copy panel 0's y-limits so the year axes line up
axs[0].set_ylabel("Year")
axs[0].set_yticks([1870, 1975, 2080])
axs[0].set_title(r"$\frac{\Delta \sigma(T)}{\sigma(T)_{1870}}$", size=10)
axs[1].set_title(r"$\Delta BJ$", size=10)
axs[2].set_title(r"$\frac{1}{2}\Delta R$", size=10)
axs[3].set_title(r"$\Delta\left(\frac{\text{Noise}}{\sigma(T)}\right)$", size=10)
for ax in axs[1:]:
    ax.set_yticks([])
    ax.set_ylim(axs[0].get_ylim())

plt.show()

### Change in ensemble-mean parameters over time

Plotting functions

In [None]:
def plot_curve(ax, x, **plot_kwargs):
    """Plot ``x`` against its ``year`` coordinate on ``ax`` with line width 2.

    Extra keyword arguments are forwarded to ``ax.plot``. The list of lines
    returned by ``ax.plot`` is passed back to the caller.
    """
    return ax.plot(x.year, x, lw=2, **plot_kwargs)


def plot_param_change_over_time(ax, dp):
    """Plot the change in RO parameters over time on ``ax``.

    Args:
        ax: matplotlib axis to draw on.
        dp: change in parameters, with ``year`` and ``cycle`` coordinates and
            variables ``BJ_ac``, ``R``, ``epsilon``, ``xi_T_norm`` and
            ``xi_h_norm``.

    Annual means (over ``cycle``) are drawn as solid lines. July values of BJ
    and R are drawn as dashed lines.
    """
    ## July. Previously R_{Jul} used cycle=6 (June) while BJ_{Jul} used 7.
    jul = 7

    ## zero line
    ax.axhline(0, c="k", ls="-", lw=0.5)

    ## Bjerknes growth rate: annual mean and July
    plot_curve(ax, dp["BJ_ac"].mean("cycle"), c="k", label=r"$\overline{BJ}$")
    plot_curve(ax, dp["BJ_ac"].sel(cycle=jul), c="k", ls="--", label="$BJ_{Jul}$")

    ## R: annual mean and July, sharing one color
    R_plot = plot_curve(ax, dp["R"].mean("cycle"), label=r"$\overline{R}$")
    kwargs = dict(c=R_plot[0].get_color(), ls="--", label=r"$R_{Jul}$")
    plot_curve(ax, dp["R"].sel(cycle=jul), **kwargs)

    ## epsilon, plotted with flipped sign
    plot_curve(ax, -dp["epsilon"].mean("cycle"), label=r"$-\varepsilon$")

    ## normalized noise amplitudes (annual mean)
    plot_curve(
        ax,
        dp["xi_T_norm"].mean("cycle"),
        label=r"$\xi_T/\sigma_T$",
        c="darkgray",
    )
    plot_curve(
        ax, dp["xi_h_norm"].mean("cycle"), label=r"$\xi_h/\sigma_h$", c="lightgray"
    )

    ## set axis specs
    ax.set_xlim([None, None])
    ax.legend(prop=dict(size=6), loc="upper left")
    ax.set_xticks([1870, 1975, 2080])
    ax.set_yticks([-0.5, 0, 0.5, 1])
    ax.set_ylabel(r"year$^{-1}$")
    ax.set_xlabel("Year")
    ax.set_title(r"$\Delta \left(\text{RO parameters}\right)$")

    return

Make plot

In [None]:
fig, ax = plt.subplots(figsize=(3.5, 3))

## parameter changes from the pooled fit; its "member" dim has length 1, so
## the mean simply drops it
plot_param_change_over_time(ax, dp=delta_params_v2.mean("member"))

plt.show()

#### Sensitivity tests
Is the noise or the Bjerknes (BJ) index driving the changes?

In [None]:
def get_perturbed_params(params, name):
    """Return a deep copy of ``params`` in which variable ``name`` is frozen.

    Every year of ``params[name]`` is overwritten with its value at
    ``isel(year=1)`` (the second entry along ``year``). All other variables are
    copied unchanged.
    """
    frozen = copy.deepcopy(params)
    original = params[name]

    ## broadcast the reference slice to the full shape, in the original dim order
    reference = original.isel(year=1)
    filled = (reference * xr.ones_like(original)).transpose(*original.dims)

    frozen[name].values = filled
    return frozen


def get_perturbed_BJ(params):
    """Return a copy of ``params`` with R and epsilon held fixed over time.

    ``Lac`` is transposed so that ``year``, ``ranky`` and ``rankx`` come first.
    Its diagonal entries (0, 0) and (1, 1), which hold R and epsilon per the
    original docstring, are then set for every year to their values at year
    index 1. All other entries keep their fitted values.
    """
    perturbed = copy.deepcopy(params)
    lac = perturbed["Lac"].transpose("year", "ranky", "rankx", ...)
    arr = lac.values

    for k in (0, 1):
        arr[:, k, k] = arr[1:2, k, k].copy()

    perturbed["Lac"] = lac
    return perturbed


def get_perturbed_Lac(params):
    """Return a copy of ``params`` with the leading 2x2 block of ``Lac`` frozen.

    After moving ``year``, ``ranky`` and ``rankx`` to the front, the
    ``[:2, :2]`` block (the core RO operator for the first two variables) is
    set for every year to its value at year index 1. Any further rows/columns
    (e.g. couplings to extra variables) keep their fitted values.
    """
    perturbed = copy.deepcopy(params)
    lac = perturbed["Lac"].transpose("year", "ranky", "rankx", ...)

    core = lac.values[1:2, :2, :2]
    lac.values[:, :2, :2] = core

    perturbed["Lac"] = lac
    return perturbed


def get_perturbed_xi(params, ranky):
    """Return a copy of ``params`` with one variable's noise amplitude frozen.

    For both ``xi_std`` and ``xi_stdac``, row ``ranky`` is set for every year
    to its value at year index 1. Other rows keep their fitted values.
    """
    perturbed = copy.deepcopy(params)

    for var in ("xi_std", "xi_stdac"):
        xi = perturbed[var].transpose("year", "ranky", ...)
        xi.values[:, ranky] = xi.values[1:2, ranky]
        perturbed[var] = xi

    return perturbed

Get perturbed parameters and plot

In [None]:
## parameters with both noise amplitudes (xi_stdac, xi_std) frozen in time
params_fixed_noise = get_perturbed_params(
    get_perturbed_params(fits_v2, "xi_stdac"), "xi_std"
)

## parameters with R and epsilon frozen in time
params_fixed_BJ = get_perturbed_BJ(fits_v2)


## helpers reducing a parameter set to a time series (mean over member, cycle)
def sel_fn(x):
    return x.mean(["member", "cycle"])


def sel_noise(x):
    return sel_fn(x)["xi_stdac"].isel(ranky=0)


def sel_BJ(x):
    return kwargs["model"].get_RO_parameters(sel_fn(x))["BJ_ac"]


fig, ax = plt.subplots(figsize=(4, 3))

## T noise: control (solid) vs frozen (dashed)
c = sns.color_palette()[0]
ax.plot(params_fixed_noise.year, sel_noise(fits_v2), c=c, label=r"$\xi_T$")
ax.plot(params_fixed_noise.year, sel_noise(params_fixed_noise), c=c, ls="--")

## BJ index: control (solid) vs frozen R and epsilon (dashed)
c = sns.color_palette()[1]
ax.plot(params_fixed_BJ.year, sel_BJ(fits_v2), c=c, label=r"$BJ$")
ax.plot(params_fixed_BJ.year, sel_BJ(params_fixed_BJ), c=c, ls="--")

ax.legend()
ax.set_title(r"Parameters over time")
ax.axhline(0, lw=0.5, c="k")

plt.show()

Run simulations with perturbed parameters

In [None]:
## monthly RO std with the noise amplitudes frozen
RO_sigma_over_time_fixed_noise = get_RO_sigma_over_time(
    kwargs["model"], params_fixed_noise, **simulation_kwargs
)

## monthly RO std with R and epsilon frozen
RO_sigma_over_time_fixed_BJ = get_RO_sigma_over_time(
    kwargs["model"], params_fixed_BJ, **simulation_kwargs
)

Plot results

In [None]:
## member- and cycle-mean of two off-diagonal Lac entries:
## (ranky=0, rankx=1), and the negative of (ranky=1, rankx=0)
plt.plot(fits_v2["Lac"].isel(ranky=0, rankx=1).mean(["member", "cycle"]))
plt.plot(-fits_v2["Lac"].isel(ranky=1, rankx=0).mean(["member", "cycle"]))

In [None]:
## reduce over months by selecting month 4
## NOTE(review): the earlier MPI-vs-RO figure uses month=5 -- confirm the
## choice of month here is intentional.
def sel_fn(x):
    return x.sel(month=4)


## 10th/50th/90th percentiles across members, with "quantile" renamed to "q"
def get_stats(x):
    return (
        sel_fn(x)
        .quantile(q=[0.1, 0.5, 0.9], dim="member")
        .rename({"quantile": "q"})
    )


## compute stats
stats_ro_v2 = get_stats(RO_sigma_over_time_v2)
stats_fixed_noise = get_stats(RO_sigma_over_time_fixed_noise)
stats_fixed_BJ = get_stats(RO_sigma_over_time_fixed_BJ)

## plot results
fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5), layout="constrained")

## left panel: control vs fixed noise. Pass exactly as many datasets as
## labels/colors: plot_stats_comp zips them, so extras are dropped silently.
plot_stats_comp(
    axs[0],
    [stats_ro_v2, stats_fixed_noise],
    labels=["Control", "fixed noise"],
    colors=["k", sns.color_palette()[1]],
)

## right panel: control vs fixed BJ
plot_stats_comp(
    axs[1],
    [stats_ro_v2, stats_fixed_BJ],
    labels=["Control", r"fixed $BJ$"],
    colors=["k", sns.color_palette()[0]],
)

axs[1].set_yticks([])
axs[1].set_ylabel(None)

## reference line: control median (q index 1 = 0.5) at year index 1
for ax in axs:
    ax.axhline(
        stats_ro_v2[varnames[0]].isel(q=1, year=1).values.item(), c="magenta", lw=1
    )

plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## IOD: % change in std, averaged over members and then over months
ax.plot(delta_Th_std["iod"].year, delta_Th_std["iod"].mean("member").mean("month"))

## fitted noise amplitude for ranky=3 (member 0, cycle mean)
## NOTE(review): the fit above uses only `varnames` (two variables), so
## ranky=3 appears to be out of range; this cell likely predates that change.
xi_iod = fits_v2["xi_stdac"].isel(member=0, ranky=3).mean("cycle")

## % change in that noise amplitude relative to year index 1
ax.plot(fits_v2.year, 100 * (xi_iod - xi_iod.isel(year=1)) / xi_iod.isel(year=1))
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## cycle-mean normalized noise amplitude; squeeze drops length-1 dims (member)
normxi = fits_v2["normxi_stdac"].squeeze().mean("cycle")

## fractional change relative to year index 1, one line per fitted variable.
## Iterate over `varnames` rather than a hard-coded 8: the fit above has only
## len(varnames) rows, so larger indices raise IndexError.
for r, name in enumerate(varnames):
    p_ = normxi.isel(ranky=r)
    ax.plot(p_.year, (p_ - p_[1]) / p_[1], label=name)

ax.legend(prop=dict(size=8), loc="upper left")
ax.axhline(0, ls="--", c="k")
plt.show()

### Snapshots of seasonal changes, with variance

In [None]:
def plot_mean_and_bounds(ax, x, show_bounds=True, **plot_kwargs):
    """Plot the seasonal cycle in ``x`` on ``ax``.

    ``x`` needs a ``cycle`` coordinate and a ``posn`` dimension with a
    "center" entry. If ``show_bounds`` is set, it also needs "upper" and
    "lower" entries. The center is drawn with lw=2 plus ``plot_kwargs``. Each
    bound is drawn as a thin dashed line in the center's color;
    ``plot_kwargs`` is not applied to the bounds.

    Returns the list of lines created for the center curve.
    """
    center = ax.plot(x.cycle, x.sel(posn="center"), lw=2, **plot_kwargs)
    if not show_bounds:
        return center

    bound_style = dict(c=center[0].get_color(), ls="--", lw=0.5)
    for edge in ("upper", "lower"):
        ax.plot(x.cycle, x.sel(posn=edge), **bound_style)
    return center


def format_ax_and_twin(ax, ax_twin, title=None):
    """Format a seasonal-cycle panel and its twin y-axis.

    Sets the title and Jan/Jul/Dec month ticks on ``ax``, fixes the y-limits
    of both axes, and hides both axes' y-ticks.

    Args:
        ax: main axis.
        ax_twin: secondary y-axis for the same panel (e.g. from ``ax.twinx()``).
        title: panel title. If None, falls back to the module-level variable
            ``y``, which is what this function always used before. That relies
            on the plotting loop variable leaking into global scope, so prefer
            passing the title explicitly.
    """
    if title is None:
        title = y  # legacy fallback: notebook-level loop variable

    ax.set_title(title)
    ax.set_xticks([1, 7, 12], labels=["Jan", "Jul", "Dec"])
    ax.set_ylim([-2, 4.7])
    ax_twin.set_ylim([0, 2.2])
    ax.set_yticks([])
    ax_twin.set_yticks([])

    return

Three panels, one per year, each showing BJ, noise, and standard deviation

In [None]:
## NOTE(review): this cell originally used `params` (per-member fits), which is
## not defined in this notebook; it now uses the pooled fit `params_v2`. That
## fit has a length-1 "member" dim, so any across-member spread returned by
## `get_ensemble_stats` is presumably degenerate. method="nearest" is used
## because fitted windows are spaced kwargs["step_size"] months apart and need
## not include every requested year (possibly 1975).
stats_baseline = src.utils.get_ensemble_stats(
    params_v2.sel(year=1870, method="nearest")
)
var_baseline = src.utils.get_ensemble_stats(Th_std.sel(year=1870)).rename(
    {"month": "cycle"}
)

## change in params between years
fig, axs = plt.subplots(1, 3, figsize=(7, 7 / 3))

## one panel per year; the earliest year is drawn faintly
years = [1870, 1975, 2080]
alphas = [1 / 4, 1, 1]

## style for the faint baseline curves. A separate name is used so that the
## global `kwargs` (model settings read by other cells) is not overwritten.
baseline_kwargs = dict(show_bounds=False, alpha=1 / 3)

axs_twin = []
## the loop variable keeps the name `y`: format_ax_and_twin reads the global
## `y` as the panel title when no title is passed
for j, (y, a) in enumerate(zip(years, alphas)):

    ## twin axis for plotting the std
    axs_twin.append(axs[j].twinx())

    ## Get stats for year
    stats = src.utils.get_ensemble_stats(params_v2.sel(year=y, method="nearest"))
    var_stats = src.utils.get_ensemble_stats(Th_std.sel(year=y)).rename(
        {"month": "cycle"}
    )

    ## plot baselines (faint, without bounds)
    plot_mean_and_bounds(
        axs[j], x=stats_baseline["BJ_ac"], c="k", **baseline_kwargs
    )
    plot_mean_and_bounds(
        axs[j], x=stats_baseline["xi_T_norm"], c="k", ls="--", **baseline_kwargs
    )
    plot_mean_and_bounds(
        axs_twin[j], x=var_baseline["T_34"], c="r", **baseline_kwargs
    )

    ## plot stats for the selected year
    plot_mean_and_bounds(axs[j], x=stats["BJ_ac"], c="k", alpha=a, label=r"$BJ$")
    plot_mean_and_bounds(
        axs[j], x=stats["xi_T_norm"], ls="--", c="k", alpha=a, label=r"$\xi_T/\sigma_T$"
    )
    plot_mean_and_bounds(
        axs_twin[j], x=var_stats["T_34"], c="r", alpha=a, label=r"$\sigma(T)$"
    )

    ## format axes
    format_ax_and_twin(axs[j], axs_twin[j])

axs[-1].legend(prop=dict(size=6), loc="upper right")
axs[0].set_yticks([-2, 0, 2, 4])
axs[0].set_ylabel(r"year$^{-1}$")
axs_twin[-1].set_yticks([0, 1], labels=[0, 1], color="r")
axs_twin[-1].set_ylabel(r"$\sigma_T$ ($^{\circ}$C)", color="r")

plt.show()

### Growth rate and noise vs variance

Function to format plot

In [None]:
def format_ax(ax):
    """Draw zero guidelines and set ticks/labels for a BJ-vs-std scatter panel.

    Returns None.
    """
    for draw_line in (ax.axvline, ax.axhline):
        draw_line(0, c="k", lw=0.5)
    ax.set_xticks([-0.4, 0, 0.4, 0.8])
    ax.set_yticks([0, 35, 70])
    ax.set_xlabel(r"$\Delta ~\overline{BJ}$ (year$^{-1}$)")
    ax.set_ylabel(r"$\Delta \sigma_T$ (%)")

Compute the relative change in standard deviation, normalizing each ensemble member by its own first-year value

In [None]:
## per-member baseline: each member's std in the first year (no member mean)
baseline = Th_std.isel(year=0)
## percent change of every member relative to its own baseline
delta_Th_std_relative = 100 * (Th_std - baseline) / baseline

Make plot

In [None]:
## get data for plot
## NOTE(review): `delta_params` is not defined in this notebook; only the
## pooled-fit `delta_params_v2` is, and its "member" dim has length 1. This
## cell correlates across ensemble members, so it needs per-member fits (e.g.
## get_ensemble_fits_over_time(..., by_ensemble_member=True)). pearsonr
## requires at least two samples.
x0 = delta_params["BJ_ac"].sel(year=2080).mean("cycle")
x1 = delta_params["xi_T"].sel(year=2080).mean("cycle")
y = delta_Th_std_relative["T_34"].sel(year=2080).mean("month")

## correlations (p-values are kept, but not shown)
corr0, pval0 = scipy.stats.pearsonr(x0.values, y.values)
corr1, pval1 = scipy.stats.pearsonr(x1.values, y.values)

## set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2))

## scatter plot preferences
scatter_kwargs = dict(zorder=10, s=10)

## plot data (BJ) and label. format_ax returns None, so its result is not
## assigned back into `axs`.
axs[0].scatter(x0, y, **scatter_kwargs)
axs[0].set_title(f"$r=${corr0:.2f}")
format_ax(axs[0])

## plot data (noise) and label
axs[1].scatter(x1, y, **scatter_kwargs)
axs[1].set_title(f"$r=${corr1:.2f}")

## label 2nd axis
format_ax(axs[1])
axs[1].set_yticks([])
axs[1].set_ylabel(None)
axs[1].set_xlabel(r"$\Delta\xi_T$ (K year$^{-1}$)")

plt.show()

## Supplementary

#### Parameter estimates from individual ensemble members

In [None]:
## NOTE(review): `params` and `delta_params` (per-member fits) are not defined
## in this notebook; only the pooled `params_v2` / `delta_params_v2` exist.
## This figure needs fits from
## get_ensemble_fits_over_time(..., by_ensemble_member=True).

fig, axs = plt.subplots(2, 2, figsize=(6, 6), layout="constrained")

## parameters to plot, and the panel title for each
param_names = ["R", "xi_T_norm", "epsilon", "xi_h_norm"]
titles = [r"$R$", r"$\xi_T$", r"$\varepsilon$", r"$\xi_h$"]

for p, title, ax in zip(param_names, titles, axs.flatten()):

    ax.set_title(title)

    ## one thin gray line per ensemble member. Only the first is labelled, so
    ## the legend shows a single "Ensemble members" entry whatever the
    ## ensemble size.
    for i, m in enumerate(params.member.values):
        ax.plot(
            params.year,
            delta_params[p].sel(member=m).mean("cycle"),
            c="gray",
            alpha=0.5,
            lw=0.5,
            label="Ensemble members" if i == 0 else None,
        )

    ## plot ensemble mean
    ax.plot(
        params.year,
        delta_params[p].mean(["member", "cycle"]),
        c="k",
        lw=2,
        label="Ensemble mean",
    )

    ## set axis specs
    ax.set_ylim([-2, 2])
    ax.axhline(0, c="k", ls="--", lw=1)
    ax.set_xticks([])

## format
for ax in axs[1, :]:
    ax.set_xlabel("Time")
    ax.set_xticks([1870, 1975, 2080])
for ax in axs[:, 0]:
    ax.set_ylabel(r"Growth rate (yr$^{-1}$)")

## one legend is enough: every panel uses the same two line styles
axs[0, 0].legend(prop=dict(size=8))

plt.show()