# MPI–ORAS validation
Comparison between MPI climate model and ORAS5

## Imports

In [None]:
import warnings
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import pandas as pd
import os
import scipy.stats
import calendar
import copy

# import cartopy.util

# Import custom modules
from src.XRO import XRO, xcorr
import src.XRO_utils
import src.utils

## plotting defaults: white axes background, grid lines switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution (dots per inch)
mpl.rcParams["figure.dpi"] = 100

## data / output directories come from environment variables
## (a KeyError is raised if either variable is unset)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[k]) for k in ("DATA_FP", "SAVE_FP"))

## $T, h$ metrics

### Load data

In [None]:
## load the model index time series
Th = src.utils.load_cesm_indices()

## shorten the long index names (long name -> short alias)
_index_aliases = dict(
    north_tropical_atlantic="natl",
    atlantic_nino="nino_atl",
    tropical_indian_ocean="iobm",
    indian_ocean_dipole="iod",
    north_pacific_meridional_mode="npmm",
    south_pacific_meridional_mode="spmm",
)
Th = Th.rename(_index_aliases)

## scale each variable by its own standard deviation
## (computed before the time period is trimmed further below)
Th /= Th.std()

In [None]:
## ORAS5 reanalysis indices (used as the benchmark)
oras_load_fp = DATA_FP / "XRO_indices_oras5.nc"
Th_oras = xr.open_dataset(oras_load_fp)

## restrict the model indices to 1979-2024
## NOTE(review): only the model data are trimmed here; this assumes the ORAS5
## file itself covers exactly 1979-2024 -- confirm
Th = Th.sel(time=slice("1979", "2024"))

## keep the two ORAS5 variables we need, renamed to match the model's names
_oras_names = {"Nino34": "T_34", "WWV": "h"}
Th_oras = Th_oras[list(_oras_names)].rename(_oras_names)

### Diagnostics

#### Seasonal synchronization

In [None]:
## name of the "T" variable (used for both datasets in this cell)
T_var = "T_34"


def get_std(x):
    """Return the standard deviation over time for each calendar month."""
    return x.groupby("time.month").std("time")


## monthly standard deviation for reanalysis and model
oras_std = get_std(Th_oras[T_var])
model_std = get_std(Th[T_var])

## ensemble summary; indexed below by posn = "center" / "upper" / "lower"
model_std_plot = src.utils.get_ensemble_stats(model_std)

## months (x-coordinate for plotting)
months = np.arange(1, 13)

### Set up plot
fig, ax = plt.subplots(figsize=(4, 3))

## reanalysis curve
oras_plot = ax.plot(months, oras_std, label="ORAS5")

## model: central curve ...
model_plot = ax.plot(months, model_std_plot.sel(posn="center"), label="Model")

## ... and its bounds, dashed, in the same color as the central curve
kwargs = {"c": model_plot[0].get_color(), "ls": "--", "lw": 1}
for bound in ("upper", "lower"):
    ax.plot(months, model_std_plot.sel(posn=bound), **kwargs)

## adjust limits and label (only the lower y-limit is pinned)
ax.set_ylim(bottom=0.4)
ax.set_yticks([0.5, 1])
ax.set_xticks([1, 5, 12], labels=["Jan", "May", "Dec"])
ax.set_xlabel("Month")
ax.set_ylabel(r"$\sigma(T)$")
ax.set_title("Seasonal synchronization (Niño 3.4)")
ax.legend()
plt.show()

#### $T,h$ cross-correlation

Plotting funcs

In [None]:
def format_axs(axs):
    """Apply shared formatting to a grid of cross-correlation panels.

    Every panel gets thin zero lines and a fixed y-range. Rows 0/1 and
    columns 0/1 are then labelled as top/bottom and left/right, so ``axs``
    must be a 2-D array of axes with at least two rows and two columns.

    Args:
        axs: 2-D numpy array of matplotlib axes.

    Returns:
        The same ``axs`` array (axes are modified in place).
    """

    ## thin guide lines through the origin + common y-range on every panel
    guide_style = {"c": "k", "lw": 0.5, "alpha": 0.5}
    for panel in axs.flatten():
        panel.axhline(0, **guide_style)
        panel.axvline(0, **guide_style)
        panel.set_ylim([-0.7, 1.1])

    top_row, bottom_row = axs[0], axs[1]
    left_col, right_col = axs[:, 0], axs[:, 1]

    ## bottom row carries the x-axis: ticks at -12/0/12 are labelled -1/0/1
    for panel in bottom_row:
        panel.set_xlabel("Lag (years)")
        panel.set_xticks([-12, 0, 12], labels=[-1, 0, 1])

    ## top row: no x ticks
    for panel in top_row:
        panel.set_xticks([])

    ## left column carries the y-axis label
    for panel in left_col:
        panel.set_ylabel("Correlation")

    ## right column: no y ticks
    for panel in right_col:
        panel.set_yticks([])

    return axs


def plot_oras(ax, data, color, label=None):
    """Draw the ORAS curve as a dashed line against ``data.lag``.

    Args:
        ax: axes to draw on.
        data: values to plot; must expose a ``lag`` attribute (x-coordinate).
        color: line color.
        label: optional legend label.
    """

    line_style = dict(c=color, label=label, ls="--")
    ax.plot(data.lag, data, **line_style)


def plot_model(ax, data, color, label=None):
    """Draw the model's central curve and shade between its bounds.

    Args:
        ax: axes to draw on.
        data: object with a ``lag`` attribute and a ``posn`` dimension that
            can be selected at "center", "upper" and "lower".
        color: color shared by the line and the shading.
        label: optional legend label (attached to the central line only).
    """

    lags = data.lag
    center, upper, lower = (data.sel(posn=p) for p in ("center", "upper", "lower"))

    ## central estimate
    ax.plot(lags, center, c=color, label=label)

    ## translucent band between the bounds
    ax.fill_between(lags, upper, lower, color=color, alpha=0.2)

Compute stats

In [None]:
## names of the model's "T" and "h" variables
T_var, h_var = "T_34", "h"

## cross-correlation of every variable with T (same maxlags for both datasets)
_xcorr_kwargs = dict(maxlags=18)
xcorr_oras = xcorr(Th_oras, Th_oras["T_34"], **_xcorr_kwargs)
xcorr_model = xcorr(Th, Th[T_var], **_xcorr_kwargs)

## summarize the model result (indexed by "posn" when plotted)
xcorr_model_stats = src.utils.get_ensemble_stats(xcorr_model)

Make plot

In [None]:
## font properties shared by every legend in this figure
legend_prop = dict(size=7)

fig, axs = plt.subplots(2, 2, figsize=(6, 5), layout="constrained")

## top row: one panel per correlation, ORAS5 vs. model
_top_panels = (
    (axs[0, 0], r"$<T, T>$", "T_34", T_var, "r"),
    (axs[0, 1], r"$<T, h>$", "h", h_var, "k"),
)
for ax, title, oras_name, model_name, color in _top_panels:
    ax.set_title(title)
    plot_oras(ax, xcorr_oras[oras_name], color=color, label="ORAS5")
    plot_model(ax, xcorr_model_stats[model_name], color=color, label="Model")

## bottom left: both correlations for ORAS5
axs[1, 0].set_title("ORAS5")
plot_oras(axs[1, 0], xcorr_oras["T_34"], color="r", label="$<T, T>$")
plot_oras(axs[1, 0], xcorr_oras["h"], color="k", label="$<T, h>$")

## bottom right: both correlations for the model
axs[1, 1].set_title("Model")
plot_model(axs[1, 1], xcorr_model_stats[T_var], color="r", label="$<T, T>$")
plot_model(axs[1, 1], xcorr_model_stats[h_var], color="k", label="$<T, h>$")

## clean up axes
axs = format_axs(axs)

## legend on every panel
for ax in axs.flatten():
    ax.legend(prop=legend_prop)

plt.show()

#### Power spectrum

Compute

In [None]:
## variable whose spectrum is computed
varname = "T_34"

## arguments forwarded to the spectral estimator
psd_kwargs = {"dim": "time", "dt": 1 / 12, "nw": 5}


def compute_psd(x):
    """Run ``src.XRO_utils.pmtm`` on ``x[varname]`` with ``psd_kwargs``."""
    return src.XRO_utils.pmtm(x[varname], **psd_kwargs)


## spectra for reanalysis and model
psd_oras = compute_psd(Th_oras)
psd_model = compute_psd(Th)

Make plot

In [None]:
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

## draw both spectra on the same axes (reanalysis first, then model)
for psd, name in ((psd_oras, "ORAS5"), (psd_model, "Model")):
    src.utils.plot_psd(ax, psd, label=name)

## label
ax.set_ylabel(r"PSD ($^{\circ}$C$^2$/cpm)")
ax.legend(prop={"size": 6})

plt.show()

#### PDF / Skewness

In [None]:
## keep every 12th time step, starting from the first one
time_idx = {"time": slice(None, None, 12)}

## variables to compare
model_var = "T_34"
oras_var = "T_34"

## bin edges for the empirical PDFs
edges = np.arange(-3.4, 3.8, 0.4)

## subsample and flatten to 1-D sample vectors
x_model = Th[model_var].isel(time_idx).values.flatten()
x_oras = Th_oras[oras_var].isel(time_idx).values.flatten()

## empirical PDFs on the shared bins (second return value is discarded)
pdf_model, _ = src.utils.get_empirical_pdf(x_model, edges=edges)
pdf_oras, _ = src.utils.get_empirical_pdf(x_oras, edges=edges)

## sample skewness of each dataset
skew_model = scipy.stats.skew(x_model)
skew_oras = scipy.stats.skew(x_oras)

## Gaussian fitted to the reanalysis samples (for reference)
pdf_gauss, pdf_gauss_pts = src.utils.get_gaussian_best_fit(x_oras)

## set up plot
fig, ax = plt.subplots(figsize=(4, 3))

## model PDF: filled steps
model_label = f"Model (skew = {skew_model:.2f})"
ax.stairs(pdf_model, edges, fill=True, alpha=0.3, label=model_label)

## reanalysis PDF: red outline
oras_label = f"ORAS (skew = {skew_oras:.2f})"
ax.stairs(pdf_oras, edges, color="r", lw=1.5, label=oras_label)

## Gaussian best fit: thin black curve (no legend entry)
ax.plot(pdf_gauss_pts, pdf_gauss, c="k", lw=1)

## label
ax.legend(prop={"size": 8})
ax.set_xlabel(r"$^{\circ}C$")
ax.set_ylabel("Prob. density")

plt.show()

### RO-based diagnostics
Note: need to non-dimensionalize $T$ and $h$ data to compare $F_1$ and $F_2$ parameters to each other and to $R$ and $\epsilon$

In [None]:
## per-variable scale = standard deviation, computed separately for each dataset
scale_oras, scale_model = Th_oras.std(), Th.std()

## divide each dataset by its own scale
Th_oras_nondim = Th_oras / scale_oras
Th_model_nondim = Th / scale_model

In [None]:
def plot_param_ensemble(ax, param, label=None):
    """Plot a parameter for every ensemble member plus the ensemble mean.

    Args:
        ax: axes to draw on.
        param: object with a ``cycle`` coordinate (x-axis) and a ``member``
            dimension.
        label: optional legend label, attached to the ensemble-mean line.
    """

    cycle = param.cycle

    ## ensemble mean: thick black line (drawn first)
    ax.plot(cycle, param.mean("member"), c="k", lw=3, label=label)

    ## individual members: thin, mostly transparent black lines
    member_style = {"c": "k", "lw": 0.5, "alpha": 0.2}
    for member_id in param.member:
        ax.plot(cycle, param.sel(member=member_id), **member_style)


def plot_param(ax, param, **plot_kwargs):
    """Plot ``param`` against its ``cycle`` coordinate.

    Args:
        ax: axes to draw on.
        param: values to plot; must expose a ``cycle`` attribute (x-axis).
        **plot_kwargs: forwarded unchanged to ``ax.plot``.
    """

    x_coord = param.cycle
    ax.plot(x_coord, param, **plot_kwargs)


def label_ac_ax(ax):
    """Label an annual-cycle panel: y-label, zero line and three month ticks.

    Args:
        ax: axes to label (modified in place).
    """

    month_ticks = {1: "Jan", 8: "Aug", 12: "Dec"}

    ax.set_ylabel(r"Growth rate (yr$^{-1}$)")

    ## thin solid reference line at zero
    ax.axhline(0, c="k", ls="-", lw=0.5)

    ax.set_xticks(list(month_ticks), labels=list(month_ticks.values()))

#### Growth rate / periodicity

Do the computation

In [None]:
## names of the model's T and h variables
T_var_model, h_var_model = "T_34", "h"

## order of the annual cycle used in the fit
ac_order = 3

## parameters whose annual cycle is masked out: [(y_idx0, x_idx0), ...]
## (set to None to disable masking)
ac_mask_idx = [(1, 1)]

## initialize the XRO model (shared by all fits below)
model = XRO(ncycle=12, ac_order=ac_order, is_forward=True)

## (1) fit to the reanalysis
fit_oras = model.fit_matrix(Th_oras_nondim[["T_34", "h"]], ac_mask_idx=ac_mask_idx)

## (2) "ensemble-of-RO" fit for the model; first return value is not needed
_, fit_model = src.utils.get_RO_ensemble(
    Th_model_nondim,
    model=model,
    T_var=T_var_model,
    h_var=h_var_model,
    ac_mask_idx=ac_mask_idx,
)

## (3) "ensemble fit": one fit to the full model dataset
model_Th_pair = Th_model_nondim[[T_var_model, h_var_model]]
fit_model_all = model.fit_matrix(model_Th_pair, ac_mask_idx=ac_mask_idx)

## growth rate ("bj") and period for each of the three fits
bj_oras, period_oras = src.utils.get_timescales(model, fit_oras)
bj_model, period_model = src.utils.get_timescales_ensemble(model, fit_model)
bj_model_all, period_model_all = src.utils.get_timescales(model, fit_model_all)

## empirical PDF of the ensemble's periods (bins from 2 to 8 in steps of 1/3)
period_bins = np.arange(2, 8, 1 / 3)
pdf, edges = src.utils.get_empirical_pdf(period_model, edges=period_bins)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(7, 3), layout="constrained")

### left panel: seasonal cycle of growth rate
# model: individual members and the ensemble mean
plot_param_ensemble(axs[0], bj_model, label="model (ensemble mean)")

# model: fit to all ensemble members
plot_kwargs = {"c": "r", "ls": "-", "lw": 2, "label": "model (ensemble fit)"}
plot_param(axs[0], bj_model_all, **plot_kwargs)

# ORAS5
plot_kwargs = {"c": "k", "ls": "--", "lw": 2, "label": "ORAS5"}
plot_param(axs[0], bj_oras, **plot_kwargs)
axs[0].legend(prop={"size": 8})

## label axes
label_ac_ax(axs[0])

### right panel: period
## distribution over the ensemble
axs[1].stairs(pdf, edges, color="k", label="model (PDF)")

## vertical markers: reanalysis, ensemble fit, mean of the distribution
_period_markers = (
    (period_oras, dict(c="k", ls="--", lw=2, label="ORAS5")),
    (period_model_all, dict(c="r", ls="-", lw=2, label="model (ensemble fit)")),
    (period_model.mean(), dict(c="k", lw=2, label="model (PDF mean)")),
)
for period_value, marker_style in _period_markers:
    axs[1].axvline(period_value, **marker_style)

## only the left x-limit is pinned
axs[1].set_xlim(left=2)
axs[1].set_yticks([])
axs[1].set_xlabel("Period (yrs)")
axs[1].legend(prop={"size": 8})

plt.show()

#### Parameters

In [None]:
## RO parameters for each of the three fits
params_oras = model.get_RO_parameters(fit_oras)
params_model = model.get_RO_parameters(fit_model)
params_model_all = model.get_RO_parameters(fit_model_all)

## panel titles (LaTeX) and the matching variable names in the parameter sets
labels = [r"$R$", r"$F_1$", r"$F_2$", r"$\epsilon$"]
var_names = ["R", "F1", "F2", "epsilon"]

## set up plot
fig, axs = plt.subplots(2, 2, figsize=(7, 6), layout="constrained")

## separate panel for each parameter
for i, (ax, label, n) in enumerate(zip(axs.flatten(), labels, var_names)):

    ### seasonal cycle of the parameter: ensemble, ensemble fit, reanalysis
    plot_param_ensemble(ax, params_model[n], label="model (ensemble mean)")
    plot_param(ax, params_model_all[n], c="r", lw=2, label="model (ensemble fit)")
    plot_param(ax, params_oras[n], c="k", ls="--", lw=2, label="ORAS")

    ### same title/limits/labels treatment for every panel
    ax.set_title(label)
    ax.set_ylim([-3, 5])
    label_ac_ax(ax)

## make plot less cluttered
## top row: drop x labels/ticks
for ax in axs[0, :]:
    ax.set_xlabel(None)
    ax.set_xticks([])

## right column: drop y labels/ticks
for ax in axs[:, 1]:
    ax.set_ylabel(None)
    ax.set_yticks([])

## left column: fixed y ticks
for ax in axs[:, 0]:
    ax.set_yticks([-2, 0, 2, 4])

## single legend, in the top-left panel
axs[0, 0].legend(prop={"size": 8})
plt.show()

## Spatial data

### Load ORAS5 data and compute anomalies

In [None]:
## load spatial ORAS data
oras_varnames = ["tos", "ssh", "taux", "d20", "nhf"]
data_oras = src.utils.load_oras_spatial_extended(
    DATA_FP / "oras", varnames=oras_varnames
)

## convert m to cm (modifies the underlying array in place)
data_oras["ssh"].values *= 100


def detrend_fn(x):
    """Detrend ``x`` along time with a degree-2 fit (via src.utils.detrend_dim)."""
    return src.utils.detrend_dim(x, dim="time", deg=2)


## anomalies: detrend each calendar month separately ...
oras_anom = data_oras.groupby("time.month").map(detrend_fn)

## ... and treat what was removed as the estimate of the forced signal
oras_forced = data_oras - oras_anom

### Load MPI data and compute anomalies

In [None]:
## directory holding the EOF files
eofs_fp = pathlib.Path(DATA_FP, "cesm")

## variables to load (and how to rename them)
names = ["tos", "zos", "tauu", "nhf"]
newnames = ["sst", "ssh", "taux", "nhf"]


def load_var(x):
    """Load the EOF object stored in ``eofs_<x>.nc`` under ``eofs_fp``."""
    return src.utils.load_eofs(pathlib.Path(eofs_fp, f"eofs_{x}.nc"))


## EOF objects keyed by the *new* variable names
eofs = {new: load_var(old) for old, new in zip(names, newnames)}

## for convenience, put spatial patterns / components in a single dataset
components = xr.merge(
    [eof_obj.components().rename(name) for name, eof_obj in eofs.items()]
)

## reset member dimension so they all match (NHF labeled differently...)
## NOTE(review): hard-codes 100 members -- confirm this matches the data
member_coord = dict(member=np.arange(100))


def get_scores(x, n):
    """Return the scores of EOF object ``x``, relabelled by member and named ``n``."""
    return x.scores().assign_coords(member_coord).rename(n)


pc_data = xr.merge([get_scores(eof_obj, name) for name, eof_obj in eofs.items()])

## convert ssh from m to cm
pc_data["ssh"].values *= 100

## convert from stress on atm to stress on ocn
pc_data["taux"].values *= -1

## forced signal = ensemble mean; anomalies = departure from it
forced_all = pc_data.mean("member")
anom_all = pc_data - forced_all

## first and last year covered by the ORAS data (as strings)
yr_range = [str(y) for y in data_oras.time.dt.year.values[[0, -1]]]

## trim both to the same period as ORAS
forced = forced_all.sel(time=slice(*yr_range))
anom = anom_all.sel(time=slice(*yr_range))

### Mean

#### Compute

In [None]:
## time-mean of the ORAS forced signal
mean_oras = oras_forced.mean("time")

## time-mean of the model's forced signal (still in EOF space) ...
mean_mpi = forced.mean("time")

## ... mapped back to physical space for SST and SSH
mean_mpi_spatial = xr.merge(
    [eofs[v].inverse_transform(mean_mpi[v]).rename(v) for v in ("sst", "ssh")]
)

## interpolate to match ORAS grid
mean_mpi_spatial = mean_mpi_spatial.interp_like(mean_oras)

## bias = model minus reanalysis (signed)
bias = mean_mpi_spatial - mean_oras

#### Plot

In [None]:
## get lon and lat for convenience
lon, lat = mean_oras.longitude, mean_oras.latitude

## filled contours for the bias
bias_kwargs = {
    "cmap": "cmo.balance",
    "levels": src.utils.make_cb_range(3, 0.3),
    "transform": ccrs.PlateCarree(),
    "extend": "both",
    "alpha": 1,
}

## thin, semi-transparent black contour lines for the reference mean state
truth_kwargs = {
    "levels": 10,
    "transform": ccrs.PlateCarree(),
    "extend": "both",
    "linewidths": 0.7,
    "colors": "k",
    "alpha": 0.5,
}

## set up the plot
fig = plt.figure(figsize=(6, 1.2), layout="constrained")
ax = src.utils.plot_setup(fig, lon_range=[100, 300], lat_range=[-20, 20])
ax.set_title("SST bias")

## bias (with colorbar)
plot_data = ax.contourf(lon, lat, bias["sst"], **bias_kwargs)
cb = fig.colorbar(plot_data, ticks=[-3, 0, 3], label=r"$^{\circ}$C")

## reanalysis mean state, overlaid as contour lines
plot_data = ax.contour(lon, lat, mean_oras["sst"], **truth_kwargs)

plt.show()

### Spatial variance

#### Compute

In [None]:
## months to subset to (None is passed through to sel_month unchanged)
months = None

## subset both anomaly datasets the same way
oras_anom_ = src.utils.sel_month(oras_anom, months)
anom_ = src.utils.sel_month(anom, months)

## ORAS: standard deviation over time
## (note: the "var_*" names hold standard deviations here)
var_oras = oras_anom_.std("time")

## model: standard deviation reconstructed from scores + spatial patterns,
## then interpolated onto the ORAS grid
var_mpi = src.utils.reconstruct_std(scores=anom_, components=components)
var_mpi = var_mpi.interp_like(var_oras)

## bias = model minus reanalysis
bias = var_mpi - var_oras

#### Plot (large scale)

In [None]:
def plot_var(x0, x1, amp, amp_diff, label, amp_min=0):
    """Draw and show a three-row map comparison of two fields.

    Creates a new figure with 3x1 projected axes, hands everything to
    ``src.utils.make_variance_subplots`` and calls ``plt.show()``.

    Args:
        x0: first field (forwarded as ``var0``).
        x1: second field (forwarded as ``var1``).
        amp: forwarded as ``amp``.
        amp_diff: forwarded as ``amp_diff``.
        label: colorbar label (forwarded as ``cbar_label``).
        amp_min: forwarded as ``amp_min``; defaults to 0.
    """

    def format_func(ax):
        """Per-axes setup used for every panel."""
        return src.utils.plot_setup_pac(ax, max_lat=20)

    ## set up paneled subplot
    fig = plt.figure(figsize=(6, 3.6), layout="constrained")
    axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

    ## plot data
    fig, axs = src.utils.make_variance_subplots(
        fig,
        axs,
        var0=x0,
        var1=x1,
        amp=amp,
        amp_diff=amp_diff,
        amp_min=amp_min,
        show_colorbars=True,
        cbar_label=label,
    )

    plt.show()

SST

In [None]:
## SST: reanalysis vs. model
plot_var(
    x0=var_oras["sst"],
    x1=var_mpi["sst"],
    label=r"$^{\circ}$C",
    amp=1.8,
    amp_min=0.5,
    amp_diff=1,
)

SSH

In [None]:
## SSH: reanalysis vs. model
plot_var(
    x0=var_oras["ssh"],
    x1=var_mpi["ssh"],
    label=r"$cm$",
    amp=10,
    amp_min=3,
    amp_diff=5,
)

#### Plot (zoom in on Niño regions)

In [None]:
def plot_setup_zoom(ax):
    """Add coastlines and zoom the map to 160-285 lon, 10S-10N.

    Args:
        ax: map axes providing ``coastlines`` and ``set_extent``.

    Returns:
        The same ``ax``.
    """

    ## [lon_min, lon_max, lat_min, lat_max]
    zoom_extent = [160, 285, -10, 10]

    ax.coastlines(linewidth=0.3)
    ax.set_extent(zoom_extent, crs=ccrs.PlateCarree())

    return ax


## set up paneled subplot (3 rows, zoomed-in map on each)
fig = plt.figure(figsize=(5, 3.5), layout="constrained")
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=plot_setup_zoom)

## plot data
kwargs = {
    "var0": var_oras["sst"],
    "var1": var_mpi["sst"],
    "amp": 1.8,
    "amp_diff": 1,
    "amp_min": 0.5,
    "show_colorbars": True,
}
fig, axs = src.utils.make_variance_subplots(fig, axs, **kwargs)

## outline the Niño 3.4 (white) and Niño 3 (dashed black) boxes on every panel
kwargs = {"linewidth": 0.9, "alpha": 0.8}
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, c="w", **kwargs)
    src.utils.plot_nino3_box(ax, c="k", ls="--", **kwargs)

## thick yellow vertical line at x=275 on the middle panel
## NOTE(review): purpose not evident from this cell -- possibly a leftover marker
axs[1, 0].axvline(275, c="yellow", lw=5)

plt.show()

### Seasonal cycle

#### SST mean

Compute

In [None]:
def equatorial_mean(x):
    """Average ``x`` over latitude after selecting the slice -2..2."""
    return x.sel(latitude=slice(-2, 2)).mean("latitude")


## reanalysis: monthly climatology of the forced signal, averaged near the equator
clim_oras = equatorial_mean(oras_forced["sst"].groupby("time.month").mean())

## model: monthly climatology (EOF space), then reconstruct the equatorial mean
monthly_clim = forced.groupby("time.month").mean()
recon = src.utils.reconstruct_fn(
    components["sst"], monthly_clim["sst"], fn=equatorial_mean
)

## treat exact zeros in the reconstruction as missing
recon.values[recon.values == 0] = np.nan

## bias = model (on the reanalysis coordinates) minus reanalysis
bias = recon.interp_like(clim_oras) - clim_oras

Plot

In [None]:
## make hövmöllers (one row each for reanalysis, model and bias)
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## shared style for the two climatology panels
kwargs = {"cmap": "cmo.thermal", "levels": np.arange(22, 32), "extend": "both"}

## plot reanalysis
cp0 = src.utils.plot_cycle_hov(axs[0], clim_oras, **kwargs)

## plot model data
cp1 = src.utils.plot_cycle_hov(axs[1], recon, **kwargs)

## plot bias (diverging colormap, symmetric levels)
bias_levels = src.utils.make_cb_range(2, 0.4)
cp2 = src.utils.plot_cycle_hov(
    axs[2], bias, cmap="cmo.balance", levels=bias_levels, extend="both"
)

## titles
for ax, title in zip(axs, ("ORAS5", "MPI", "Bias")):
    ax.set_title(title)

## x-axis label/ticks on the bottom panel only
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel
clim_ticks = [22, 26, 30]
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=clim_ticks, label=r"$^{\circ}C$")
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=clim_ticks, label=r"$^{\circ}C$")
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-2, 0, 2], label=r"$^{\circ}C$")

plt.show()

#### SST variance

Compute

In [None]:
## ORAS: variance over time for each calendar month, averaged near the equator
var_oras = equatorial_mean(oras_anom["sst"].groupby("time.month").var("time"))


def get_var(x):
    """Reconstruct the SST variance from scores ``x`` and average near the equator."""
    return equatorial_mean(
        src.utils.reconstruct_var(scores=x, components=components["sst"])
    )


## model: same quantity, computed month by month from the EOF scores
var_mpi = anom["sst"].groupby("time.month").map(get_var)

## difference = model (on the reanalysis coordinates) minus reanalysis
diff = var_mpi.interp_like(var_oras) - var_oras

In [None]:
## shared args for the two variance panels
plot_kwargs = {"cmap": "cmo.amp", "extend": "max"}
var_levels = np.arange(0, 4.4, 0.4)

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## plot reanalysis
cp0 = src.utils.plot_cycle_hov(axs[0], var_oras, levels=var_levels, **plot_kwargs)

## plot model data
cp1 = src.utils.plot_cycle_hov(axs[1], var_mpi, levels=var_levels, **plot_kwargs)

## plot difference (diverging colormap, symmetric levels)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    diff,
    levels=src.utils.make_cb_range(2, 0.2),
    cmap="cmo.balance",
    extend="both",
)

## titles
for ax, title in zip(axs, ("ORAS5", "MPI", "Bias")):
    ax.set_title(title)

## x-axis label/ticks on the bottom panel only
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## add colorbars (difference panel gets symmetric ticks)
kwargs = {"ticks": [0, 2, 4], "label": r"$^{\circ}\text{C}^2$"}
cb0 = fig.colorbar(cp0, ax=axs[0], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], **{**kwargs, "ticks": [-2, 0, 2]})

plt.show()

## ENSO composites

Get indices for composite

In [None]:
## composite on warm events (True) or cold events (False)?
is_warm = True

## which month to composite on?
peak_month = 12

## functions that compute the index from an SST field
idx_fn_oras = src.utils.get_nino34
idx_fn_mpi = src.utils.get_nino34

## quantile used as the event threshold
q = 0.85


def _above(x, cut):
    """True where ``x`` exceeds the cutoff."""
    return x > cut


def _below(x, cut):
    """True where ``x`` is under the cutoff."""
    return x < cut


## warm events: above the q-quantile; cold events: below the (1 - q)-quantile
kwargs = (
    dict(q=q, check_cutoff=_above) if is_warm else dict(q=1 - q, check_cutoff=_below)
)

## get ORAS index
idx_oras = idx_fn_oras(oras_anom["sst"])

## get MPI index (reconstructed from EOF scores and spatial patterns)
idx_mpi = src.utils.reconstruct_fn(
    components=components["sst"], scores=anom["sst"], fn=idx_fn_mpi
)

Compute composites

In [None]:
## composites for the reanalysis (computed directly from the anomalies)
comps_oras = src.utils.get_composites(idx_oras, oras_anom, **kwargs)

## composites for the model (spatial patterns are passed in as well)
comps_mpi = src.utils.get_composites(
    idx_mpi,
    anom,
    components=components,
    **kwargs,
)

Plot

In [None]:
## beta passed to plot_hov for each column
betas = np.array([1, 1, 0.5])

## set up plot: rows = datasets, columns = composite components
fig, axs = plt.subplots(2, 3, figsize=(6, 5), layout="constrained")

for row_axs, comps in zip(axs, [comps_oras, comps_mpi]):
    for ax, comp, beta in zip(row_axs, comps, betas):

        ## draw the hövmöller, then strip all ticks (re-added selectively below)
        cf, _ = src.utils.plot_hov(ax=ax, x=comp, beta=beta)
        ax.set_xticks([])
        ax.set_yticks([])

## column titles (top row only)
for top_ax, title in zip(axs[0], ["Total", "'In-phase'", "'Out-of-phase'"]):
    top_ax.set_title(title)

## dataset names on the right-hand side of the last column
for right_ax, name in zip(axs[:, -1], ["ORAS", "MPI"]):
    right_ax.set_ylabel(name)
    right_ax.yaxis.set_label_position("right")

## y-axis labels on the first column
for ax in axs[:, 0]:
    src.utils.label_hov_yaxis(ax, peak_mon=peak_month)

## label x axis (bottom row only)
for ax in axs[1, :]:
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])

plt.show()

## Bjerknes couplings

In [None]:
def make_scatter(ax, data, x_var, y_var, fn_x, fn_y, scale=1, label=None):
    """Scatter two derived indices against each other and draw the fitted slope.

    If ``data`` has a "mode" dimension, the indices are reconstructed via
    ``src.utils.reconstruct_fn`` from ``data[<var>]`` and ``data[<var>_comp]``,
    and the "member" and "time" dimensions are stacked into a single "sample"
    dimension. Otherwise ``fn_x`` / ``fn_y`` are applied directly to
    ``data[x_var]`` / ``data[y_var]`` and the regression runs over "time".

    Args:
        ax: axes to draw on.
        data: dataset holding ``x_var`` and ``y_var`` (plus ``<var>_comp``
            entries when it has a "mode" dimension).
        x_var: name of the variable used for the x index.
        y_var: name of the variable used for the y index.
        fn_x: function producing the x index.
        fn_y: function producing the y index.
        scale: factor applied to the y index before regressing and plotting.
        label: optional legend label for the best-fit line. (Previously the
            body read an undefined local ``label``, which silently picked up
            whatever global of that name existed, or raised NameError.)

    Returns:
        The regression slope as a Python scalar.
    """

    ## evaluate functions
    if "mode" in data.dims:
        fn_x_eval = src.utils.reconstruct_fn(
            scores=data[x_var], components=data[f"{x_var}_comp"], fn=fn_x
        )
        fn_y_eval = src.utils.reconstruct_fn(
            scores=data[y_var], components=data[f"{y_var}_comp"], fn=fn_y
        )

        ## stack member/time dim so every (member, time) pair is one sample
        fn_x_eval = fn_x_eval.stack(sample=["member", "time"])
        fn_y_eval = fn_y_eval.stack(sample=["member", "time"])
        dim = "sample"

    else:
        fn_x_eval = fn_x(data[x_var])
        fn_y_eval = fn_y(data[y_var])
        dim = "time"

    ## apply the unit scaling once; used for both the fit and the scatter
    y_scaled = scale * fn_y_eval

    ## compute slope for best fit line and convert to a plain scalar
    slope = src.utils.regress_core(X=fn_x_eval, Y=y_scaled, dim=dim)
    slope = slope.values.item()

    ## plot data
    ax.scatter(fn_x_eval, y_scaled, s=0.5)

    ## plot best fit (line through the origin, spanning the range of x)
    xtest = np.linspace(fn_x_eval.values.min(), fn_x_eval.values.max())
    ax.plot(xtest, slope * xtest, c="k", lw=1, label=label)

    ## plot some guidelines
    ax.axhline(0, ls="--", lw=0.8, c="k")
    ax.axvline(0, ls="--", lw=0.8, c="k")

    return slope


def pcolor_plot(ax, x, amp, sel=lambda x: x.mean("month")):
    """Draw ``sel(x)`` as a pcolormesh with a symmetric color range.

    Args:
        ax: map axes to draw on.
        x: field with ``longitude`` and ``latitude`` coordinates.
        amp: color range is set to [-amp, amp].
        sel: reduces ``x`` to the 2-D field that is plotted; by default the
            mean over "month".

    Returns:
        The object returned by ``ax.pcolormesh`` (e.g. for use with a
        colorbar). The original returned the unrelated name ``cp0``.
    """
    cp = ax.pcolormesh(
        x.longitude,
        x.latitude,
        sel(x),
        vmax=amp,
        vmin=-amp,
        cmap="cmo.balance",
        transform=ccrs.PlateCarree(),
    )

    return cp


def contour_plot(ax, x, amp, sel=lambda x: x.mean("month")):
    """Draw ``sel(x)`` as filled contours on a map.

    Args:
        ax: map axes to draw on.
        x: field with ``longitude`` and ``latitude`` coordinates.
        amp: passed to ``src.utils.make_cb_range(amp, amp / 10)`` to build
            the contour levels.
        sel: reduces ``x`` to the 2-D field that is plotted; by default the
            mean over "month".

    Returns:
        The object returned by ``ax.contourf``.
    """
    field = sel(x)
    levels = src.utils.make_cb_range(amp, amp / 10)

    return ax.contourf(
        x.longitude,
        x.latitude,
        field,
        levels=levels,
        cmap="cmo.balance",
        transform=ccrs.PlateCarree(),
        extend="both",
    )


def merimean(x):
    """Select longitude 140..285 and latitude -5..5, then average over latitude."""
    box = dict(longitude=slice(140, 285), latitude=slice(-5, 5))
    return x.sel(**box).mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True):
    """Draw a longitude-vs-month plot of the meridional mean of ``data``.

    This is the notebook-local helper (signature differs from
    ``src.utils.plot_cycle_hov``, which other cells call).

    Args:
        ax: axes to draw on.
        data: field that ``merimean`` can reduce; the result must have
            ``longitude`` and ``month`` coordinates.
        amp: passed to ``src.utils.make_cb_range(amp, amp / 5)`` to build
            the contour levels.
        is_filled: filled contours with a diverging colormap if True,
            thin black contour lines otherwise.

    Returns:
        The object returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## reduce once and reuse for coordinates and values
    reduced = merimean(data)

    ## pick the drawing routine and its specific style
    if is_filled:
        plot_fn, style = ax.contourf, {"cmap": "cmo.balance"}
    else:
        plot_fn, style = ax.contour, {"colors": "k", "linewidths": 0.8}

    ## do the plotting (levels/extend are shared by both variants)
    cp = plot_fn(
        reduced.longitude,
        reduced.month,
        reduced,
        **style,
        levels=src.utils.make_cb_range(amp, amp / 5),
        extend="both",
    )

    ## format ax object: label, ticks, and a dashed white line at each tick
    tick_lons = [160, 210]
    ax.set_xlabel("Lon")
    ax.set_xticks(tick_lons)
    for lon_tick in tick_lons:
        ax.axvline(lon_tick, c="w", ls="--", lw=1)

    return cp


def get_bias(cesm_data, oras_data):
    """Return ``cesm_data`` minus ``oras_data`` on the reanalysis grid.

    The model field is first interpolated onto the ``longitude`` and
    ``latitude`` coordinates of ``oras_data`` so the two can be subtracted.

    Args:
        cesm_data: model field (must support ``.interp``).
        oras_data: reanalysis field providing the target grid.

    Returns:
        The interpolated model field minus ``oras_data``.
    """

    ## target grid = reanalysis coordinates
    target_grid = {
        "longitude": oras_data.longitude,
        "latitude": oras_data.latitude,
    }

    return cesm_data.interp(target_grid) - oras_data

#### Add components to data (useful for reconstructing)

In [None]:
## store each spatial pattern next to the scores as "<var>_comp"
## (skipped if already present, so the cell can be re-run safely)
for v in list(components):
    comp_name = f"{v}_comp"
    if comp_name in list(anom):
        continue
    anom[comp_name] = components[v]

### SST-SSH

In [None]:
## shared kwargs
kwargs = dict(x_var="ssh", y_var="sst")

## compute local regression by month
m_oras = oras_anom.groupby("time.month").map(src.utils.regress, **kwargs)

## then, reconstruct regression coefficient
m_cesm = anom.groupby("time.month").map(src.utils.regress_proj, **kwargs)

In [None]:
## set up plot
fig = plt.figure(figsize=(7, 2.8), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## plot data
contour_kwargs = dict(amp=0.3, sel=lambda x: x.mean("month"))
cp0 = contour_plot(axs[0, 0], m_oras, **contour_kwargs)
cp1 = contour_plot(axs[1, 0], m_cesm, **contour_kwargs)

## colorbar
cb_kwargs = dict(ticks=[-0.3, 0, 0.3], label=r"$K ~cm^{-1}$")
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)

plt.show()

In [None]:
## specify kwargs
scatter_kwargs = dict(kwargs, fn_x=src.utils.get_nino3, fn_y=src.utils.get_nino3)

## helper func for label
get_label = lambda slope: f"slope$=${slope:.2f} " + r"$K ~cm^{-1}$"

## plot
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")

## ORAS
slope = make_scatter(axs[0], data=oras_anom, **scatter_kwargs)
axs[0].set_title(get_label(slope))

## MPI
slope = make_scatter(axs[1], data=anom, **scatter_kwargs)
axs[1].set_title(get_label(slope))

## set limits
for ax in axs:
    ax.set_xlim([-25, 28])
    ax.set_ylim([-6, 6])
    ax.set_ylabel(r"Niño 3 SST ($K$)")

## format
axs[0].set_xticks([])
axs[1].set_xticks([-25, 0, 25])
axs[1].set_xlabel(r"Niño 3 SSH ($cm$)")


plt.show()

##### Hövmöller

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
coupling_ax, bias_ax = axs

## left: reanalysis (filled) with the model overlaid as contour lines
cp0 = plot_cycle_hov(coupling_ax, data=m_oras, amp=0.2)
plot_cycle_hov(coupling_ax, data=m_cesm, amp=0.2, is_filled=False)

## right: model minus reanalysis
cp2 = plot_cycle_hov(bias_ax, data=get_bias(m_cesm, m_oras), amp=0.15)

## make it look nicer: month ticks on the left panel only
bias_ax.set_yticks([])
coupling_ax.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
coupling_ax.set_ylabel("Month")

coupling_ax.set_title("SSH-SST coupling")
bias_ax.set_title("Bias")

plt.show()

### $\tau_x$-SST

In [None]:
## shared kwargs
kwargs = dict(x_var="sst", y_var="taux", fn_x=src.utils.get_nino3)

## compute local regression by month
m_oras = oras_anom.groupby("time.month").map(src.utils.regress, **kwargs)

## then, reconstruct regression coefficient
m_cesm = anom.groupby("time.month").map(src.utils.regress_proj, **kwargs)

In [None]:
## set up plot
fig = plt.figure(figsize=(7, 2.8), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## plot data
contour_kwargs = dict(amp=2e-2, sel=lambda x: x.mean("month"))
cp0 = contour_plot(axs[0, 0], m_oras, **contour_kwargs)
cp1 = contour_plot(axs[1, 0], m_cesm, **contour_kwargs)

## colorbar
cb_kwargs = dict(ticks=[-0.02, 0, 0.02], label=r"$N~m^{-2}~K^{-1}$")
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)

plt.show()

In [None]:
## scaling factor (to make units nicer)
scale = 1000

## specify kwargs
scatter_kwargs = dict(kwargs, fn_y=src.utils.get_nino4, scale=scale)

## helper func for label
get_label = lambda slope: f"slope$=${slope:.1f} " + r"$mPa ~K^{-1}$"

## plot
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")

## ORAS
slope = make_scatter(axs[0], data=oras_anom, **scatter_kwargs)
axs[0].set_title(get_label(slope))

## MPI
slope = make_scatter(axs[1], data=anom, **scatter_kwargs)
axs[1].set_title(get_label(slope))

## set limits
for ax in axs:
    ax.set_xlim([-5, 6])
    ax.set_ylim([-70, 70])
    ax.axhline(0, ls="--", lw=0.8, c="k")
    ax.axvline(0, ls="--", lw=0.8, c="k")
    ax.set_ylabel(r"Niño 4 $\tau_x$ ($mPa$)")

# ## format
axs[0].set_xticks([])
axs[1].set_xticks([-3, 0, 3])
axs[1].set_xlabel(r"Niño 3 SST ($K$)")


plt.show()

##### Hövmöller

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
coupling_ax, bias_ax = axs

## left: reanalysis (filled) with the model overlaid as contour lines
cp0 = plot_cycle_hov(coupling_ax, data=m_oras, amp=0.02)
plot_cycle_hov(coupling_ax, data=m_cesm, amp=0.02, is_filled=False)

## right: model minus reanalysis
cp2 = plot_cycle_hov(bias_ax, data=get_bias(m_cesm, m_oras), amp=0.01)

## make it look nicer: month ticks on the left panel only
bias_ax.set_yticks([])
coupling_ax.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
coupling_ax.set_ylabel("Month")

coupling_ax.set_title(r"$\tau_x$-SST coupling")
bias_ax.set_title("Bias")

plt.show()

### Look at asymmetry of wind stress

In [None]:
def regress_relu(data, is_pos=True):
    """Regress "taux" on a one-sided copy of "nino3".

    Samples whose "nino3" value is on the wrong side of zero are not dropped;
    every variable in `data` is set to 0 for them before the regression.

    Args:
        data: dataset holding "taux" and "nino3" with "member" and "time"
            dimensions.
        is_pos: if True keep samples with nino3 > 0, otherwise nino3 < 0.

    Returns:
        Whatever `src.utils.regress_core` returns for the masked data,
        regressing over the stacked "sample" dimension.
    """

    ## pool ensemble members and times into one sample dimension
    stacked = data.stack(sample=["member", "time"])

    ## boolean mask for the requested sign of nino3
    nino3 = stacked["nino3"]
    keep = nino3 > 0 if is_pos else nino3 < 0

    ## zero out (rather than remove) the samples on the other side
    masked = stacked.where(keep, other=0)

    ## regress taux on the one-sided index
    return src.utils.regress_core(Y=masked["taux"], X=masked["nino3"], dim="sample")

In [None]:
## Niño 3 index rebuilt from the SST components and scores
nino3 = src.utils.reconstruct_fn(
    fn=src.utils.get_nino3, components=components["sst"], scores=anom["sst"]
)

## wind-stress scores and the index side by side in one dataset
taux = xr.merge([anom["taux"], nino3.rename("nino3")])

## month-by-month one-sided regressions (positive first, then negative),
## each passed through the taux EOF model's inverse transform
m_pos, m_neg = (
    eofs["taux"].inverse_transform(
        taux.groupby("time.month").map(regress_relu, is_pos=flag)
    )
    for flag in (True, False)
)

In [None]:
## set up plot: three stacked map panels (positive, negative, difference)
fig = plt.figure(figsize=(7, 4.2), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## plot data for month=5; the difference panel uses half the amplitude
contour_kwargs = dict(amp=2e-2, sel=lambda x: x.sel(month=5))
cp0 = contour_plot(axs[0, 0], m_pos, **contour_kwargs)
cp1 = contour_plot(axs[1, 0], m_neg, **contour_kwargs)
cp2 = contour_plot(axs[2, 0], m_pos - m_neg, **dict(contour_kwargs, amp=1e-2))

## colorbars
## NOTE: ticks follow the `amp` passed to contour_plot (2e-2 and 1e-2), and the
## label uses the units of a taux-on-SST regression, matching the other
## wind-stress coupling map in this notebook. The previous ticks (0.2 / 0.1)
## were ten times the plotted amplitude and the label read K cm^-1.
cb_kwargs = dict(ticks=[-0.02, 0, 0.02], label=r"$N~m^{-2}~K^{-1}$")
cb0 = fig.colorbar(cp0, ax=axs[:-1], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **(dict(cb_kwargs, ticks=[-0.01, 0, 0.01])))

## outline the Niño 4 box on every panel
for ax in axs.flatten():
    src.utils.plot_nino4_box(ax, c="k", lw=1, ls="--")

plt.show()

Look at meridional means

In [None]:
## two side-by-side panels: responses (left) and their difference (right)
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left: m_pos with the default style, m_neg overlaid with is_filled=False
cp0 = plot_cycle_hov(axs[0], data=m_pos, amp=0.015)
plot_cycle_hov(axs[0], data=m_neg, amp=0.015, is_filled=False)

## right: positive-minus-negative difference at half the amplitude
cp2 = plot_cycle_hov(axs[1], data=m_pos - m_neg, amp=0.0075)

## titles
axs[0].set_title("Wind response")
axs[1].set_title("Asymmetry")

## month ticks on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

#### Net heat flux

In [None]:
## regression settings: "nhf" (y) against the Niño 3 function of "sst" (x)
## NOTE(review): this rebinds the notebook-level `kwargs` used by later cells
kwargs = {"x_var": "sst", "y_var": "nhf", "fn_x": src.utils.get_nino3}

## ORAS: regression computed separately for each calendar month
m_oras = oras_anom.groupby("time.month").map(src.utils.regress, **kwargs)

## model: same grouping, but via regress_proj (reconstructs the coefficient)
m_cesm = anom.groupby("time.month").map(src.utils.regress_proj, **kwargs)

In [None]:
## figure with two stacked map panels
fig = plt.figure(figsize=(7, 2.8), layout="constrained")


def format_func(ax):
    """Apply the shared Pacific map formatting from src.utils to one axis."""
    return src.utils.plot_setup_pac(ax, max_lat=20)


axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## draw the month-averaged coefficients: m_oras on top, m_cesm below
contour_kwargs = {"amp": 30, "sel": lambda x: x.mean("month")}
cp0 = contour_plot(axs[0, 0], m_oras, **contour_kwargs)
cp1 = contour_plot(axs[1, 0], m_cesm, **contour_kwargs)

## single colorbar attached to both panels (ticks match the `amp` above)
cb_kwargs = {"ticks": [-30, 0, 30], "label": r"$W~m^{-2}~K^{-1}$"}
cb0 = fig.colorbar(cp0, ax=axs, **cb_kwargs)

plt.show()

In [None]:
## two side-by-side panels: heat-flux feedback (left) and bias (right)
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left: m_oras with the default style, m_cesm overlaid with is_filled=False
cp0 = plot_cycle_hov(axs[0], data=m_oras, amp=40)
plot_cycle_hov(axs[0], data=m_cesm, amp=40, is_filled=False)

## right: get_bias(m_cesm, m_oras), same amplitude as the left panel
cp2 = plot_cycle_hov(axs[1], data=get_bias(m_cesm, m_oras), amp=40)

## titles
axs[0].set_title(r"Net heat flux")
axs[1].set_title("Bias")

## month ticks on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

In [None]:
## scaling factor handed to make_scatter (1 = leave units unchanged)
scale = 1

## shared regression kwargs, extended with the y-index function and scaling
scatter_kwargs = {**kwargs, "fn_y": src.utils.get_nino3, "scale": scale}


def get_label(slope):
    """Format a regression slope as a panel title."""
    return f"slope$=${slope:.1f} " + r"$W~m^{-2} ~K^{-1}$"


## two stacked panels
fig, axs = plt.subplots(2, 1, figsize=(2, 4), layout="constrained")

## top panel: ORAS
slope = make_scatter(axs[0], data=oras_anom, **scatter_kwargs)
axs[0].set_title(get_label(slope))

## bottom panel: model (`anom`)
slope = make_scatter(axs[1], data=anom, **scatter_kwargs)
axs[1].set_title(get_label(slope))

## identical y-limits, zero reference lines and y-label on both panels
for ax in axs:
    # ax.set_xlim([-5, 6])
    ax.set_ylim([-120, 60])
    ax.axvline(0, ls="--", lw=0.8, c="k")
    ax.axhline(0, ls="--", lw=0.8, c="k")
    ax.set_ylabel(r"Niño 3 NHF ($W~m^{-2}$)")

## x ticks / label only on the bottom panel
axs[1].set_xlabel(r"Niño 3 SST ($K$)")
axs[1].set_xticks([-3, 0, 3])
axs[0].set_xticks([])

plt.show()

## LIM

In [None]:
## number of leading EOF modes kept per variable
nmodes = 20

## positional selection of the last 480 time steps
t_idx = {"time": slice(-480, None)}


def prep(x):
    """Keep the first `nmodes` modes and stack (member, time) into "sample"."""
    return x.isel(mode=slice(None, nmodes)).stack(sample=["member", "time"])


def prep2(x):
    """Convert to a DataArray whose leading dim "m" stacks (variable, mode)."""
    return x.to_dataarray().stack(m=["variable", "mode"]).transpose("m", ...)


## predictor (all but the last step) and target (shifted one step later)
X = prep2(prep(anom.isel(t_idx).isel(time=slice(None, -1))))
Y = prep2(prep(anom.isel(t_idx).isel(time=slice(1, None))))

## zero-based calendar month of every predictor sample
month_idx = X.time.dt.month.values - 1

## plain numpy arrays from here on
X = X.values
Y = Y.values

In [None]:
import src.lim

## fit LIM
## X / Y are the predictor / target arrays built in the previous cell;
## month_labels holds the zero-based calendar month of each X sample
lim = src.lim.LIM_CS(X=X, Y=Y, month_labels=month_idx)

## get spectrum
## rescaled so both can be plotted in year^-1 (see the axis labels in the
## spectrum plot); presumably lim.sigma is per month and lim.omega is in
## radians per month -- TODO confirm against src.lim
sigma = 12 * lim.sigma
omega = 12 / (2 * np.pi) * lim.omega

In [None]:
## scatter of the LIM spectrum: sigma on x, omega on y
fig, ax = plt.subplots(figsize=(3, 2.5))
ax.scatter(sigma, omega)

## plot RO as reference
# ax.scatter(sigma_RO, omega_RO)

## dashed horizontal guide lines at 0 and +/- 1/4
## NOTE(review): this rebinds the notebook-level `kwargs` that the regression
## and scatter cells also use; re-run the cell defining it before re-running them
kwargs = {"c": "k", "zorder": 0.5, "lw": 1, "ls": "--"}
for yt in [-0.25, 0, 0.25]:
    ax.axhline(yt, **kwargs)

## ticks and limits
## NOTE(review): the tick at -1/4 is labelled "1/4" -- confirm this is intended
ax.set_yticks([-1 / 4, 0, 1 / 4], labels=[r"1/4", "0", "1/4"])
ax.set_xticks([-1, -0.5, 0])
ax.set_xlim([-1, 0])

## axis labels
ax.set_ylabel(r"$\omega$ (year$^{-1}$)")
ax.set_xlabel(r"$\sigma$ (year$^{-1}$)")

plt.show()

In [None]:
## index of the ENSO eigenfunction and of the one right after it
## (presumably its complex-conjugate partner -- TODO confirm ordering in src.lim)
enso_idx = 0
enso_idxs = np.arange(enso_idx, enso_idx + 2)

## independent copies of the eigendecomposition restricted to those two indices
Uk, gammak, Vk = (
    copy.deepcopy(arr)
    for arr in (lim.U[..., enso_idxs], lim.gamma[enso_idxs], lim.V[..., enso_idxs])
)

## evaluate every eigenfunction on the data, using each sample's month
## (note: uses the full lim.V, not the reduced Vk)
varphi = np.einsum("nmk,mn->nk", lim.V[month_idx], X)

## plot relationship between leading eigenfuncs:
## keep samples whose column-0 magnitude is above its 95th percentile
mag = np.abs(varphi[:, 0])
is_large = mag > np.percentile(mag, 95)

## angle of column 0 against the angle of columns 3 and 6
## NOTE(review): 0, 3 and 6 are hard-coded and not tied to enso_idx
fig, axs = plt.subplots(1, 2, figsize=(4, 2))
axs[0].scatter(
    src.utils.get_angle(varphi[is_large, 0]),
    src.utils.get_angle(varphi[is_large, 3]),
    s=3,
)
axs[1].scatter(
    src.utils.get_angle(varphi[is_large, 0]),
    src.utils.get_angle(varphi[is_large, 6]),
    s=3,
)
plt.show()

Reconstruction functions

In [None]:
def reconstruct(scores, other_coord=None):
    """Rebuild "T" and "h" fields from a stacked array of EOF scores.

    Args:
        scores: array whose first `nmodes` rows are passed to the `eofs_sst`
            model and whose remaining rows are passed to the `eofs_ssh` model.
        other_coord: forwarded unchanged to `reconstruct_helper`.

    Returns:
        Real part of the merged dataset with variables "T" and "h".
    """

    ## split the rows and reconstruct each block with its own EOF model
    T_recon = reconstruct_helper(
        scores[:nmodes], model=eofs_sst, other_coord=other_coord
    )
    h_recon = reconstruct_helper(
        scores[nmodes:], model=eofs_ssh, other_coord=other_coord
    )

    ## combine under the short variable names
    fields = [T_recon.rename("T"), h_recon.rename("h")]
    return xr.merge(fields).real


def reconstruct_helper(scores, model, other_coord=None):
    """Reconstruct a 2-D array of scores with the given EOF model.

    Args:
        scores: array of shape (n_rows, n). Row i is written into mode
            position i of a zero-filled template; any remaining modes of the
            model stay at zero.
        model: object providing `scores()` (with "mode", "member" and "time"
            dimensions) and `inverse_transform()`.
        other_coord: optional index-like object with `.name` and `.values`
            (e.g. a `pd.Index`) of length n; replaces the "time" dimension of
            the result. Defaults to `pd.Index(np.arange(n), name="n")`.

    Returns:
        Output of `model.inverse_transform` with "time" renamed to
        `other_coord.name` and relabelled with `other_coord.values`.
    """

    ## number of score rows / columns supplied by the caller
    ## (taken from `scores` itself rather than the notebook-level `nmodes`,
    ## so the helper also works for a different number of rows)
    n_rows = scores.shape[0]
    n = scores.shape[1]

    ## zero-filled template shaped like the model's scores (first member,
    ## first n time steps) with "mode" as the leading dimension
    scores_xr = xr.zeros_like(
        model.scores().isel(member=0, time=slice(None, n))
    ).transpose("mode", ...)

    ## put raw scores into the leading modes of the template
    scores_xr.values[:n_rows] = scores

    ## do inverse transform
    data = model.inverse_transform(scores_xr)

    ## rename coord: "time" here is only a placeholder axis of length n
    if other_coord is None:
        other_coord = pd.Index(np.arange(n), name="n")
    data = data.rename({"time": other_coord.name})
    data[other_coord.name] = other_coord.values

    return data

Find peak angle for El Niño

In [None]:
## zero-based month index at which to look for the peak (11 = December)
peak_month_idx = 11

## test points on a circle in the complex plane: 64 angles in [0, 2*pi)
## NOTE(review): the radius is the std over *all* columns of varphi, not just
## the ENSO column -- confirm this is the intended amplitude
theta_test = np.arange(0, 2 * np.pi, np.pi / 32)
varphi_test = varphi.std() * np.exp(1j * theta_test)

## pair each test value with its complex conjugate (one row each)
VtX = np.stack([varphi_test, varphi_test.conj()])

## reconstruction in EOF space for the peak month
recon = np.matmul(Uk[peak_month_idx], VtX)

## reconstruction in real space, indexed by the test angle
recon_xr = reconstruct(recon.real, other_coord=pd.Index(theta_test, name="theta"))
nino_recon = src.utils.get_nino3(recon_xr["T"])

## get theta for maximum Niño in peak month
theta_max = (
    recon_xr["theta"].isel(theta=nino_recon.argmax("theta")).values.item()
)

Get "real" reconstruction

In [None]:
## lags from -12 to +12 months around the peak, also expressed in years
lags_months = np.arange(-12, 13)
lags_years = lags_months / 12

## phase offset accumulated at each lag (omega is in cycles per year)
delta_theta = 2 * np.pi * omega[enso_idx] * lags_years

## calendar month and phase angle for every lag
month_idx_test = np.mod(peak_month_idx + lags_months, 12)
theta_test = np.mod(theta_max + delta_theta, 2 * np.pi)

## eigenfunction value at each lag, stacked with its complex conjugate
varphi_test = varphi.std() * np.exp(1j * theta_test)
varphi_test = np.stack([varphi_test, varphi_test.conj()])

## reconstruction in EOF space, using each lag's own month
recon = np.einsum("nmk,kn->mn", Uk[month_idx_test], varphi_test).real

## reconstruction in real space, indexed by lag in months
recon_xr = reconstruct(recon, other_coord=pd.Index(lags_months, name="lag"))

In [None]:
## average the reconstruction over the 5S-5N latitude band
lat = {"latitude": slice(-5, 5)}
recon_merimean = recon_xr.sel(lat).mean("latitude")
# comp_merimean = comp.sel(lat).mean("latitude")

## two panels; only the first one is filled in below
fig, axs = plt.subplots(1, 2, figsize=(6, 3), layout="constrained")

for ax, merimean in zip(axs[:1], [recon_merimean]):

    ## "T" as filled contours
    ax.contourf(
        merimean.longitude,
        merimean.lag,
        merimean["T"],
        levels=src.utils.make_cb_range(3, 0.3),
        cmap="cmo.balance",
        extend="both",
    )

    ## "h" as black contour lines on top
    ax.contour(
        merimean.longitude,
        merimean.lag,
        merimean["h"],
        levels=src.utils.make_cb_range(9, 1.5),
        colors="k",
        linewidths=1,
        extend="both",
    )

    ## longitude ticks and range
    ax.set_xticks([190, 240])
    ax.set_xlim([120, 280])

## lag ticks on the first panel only
## NOTE(review): lag 0 corresponds to peak_month_idx = 11 (zero-based), so
## these ticks may belong to December rather than January -- confirm labels
axs[0].set_yticks([-12, 0, 12], labels=["Jan(-1)", "Jan(0)", "Jan(+1)"])
axs[1].set_yticks([])
# ax.set_xticks([190, 240])

plt.show()