# Look at climate change in CESM

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import warnings
import tqdm
import pathlib
import cmocean
import os
import calendar

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## set plotting specs: white axes background, no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables
## (a missing DATA_FP or SAVE_FP raises KeyError here)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def plot_hov(ax, data, amp, label=None):
    """
    Plot Hovmoller diagram (longitude vs. year) on a single axes object.

    Args:
    - ax: matplotlib axes to draw on
    - data: array exposing `longitude` and `year` coordinates; it is
        transposed (`data.T`) before being passed to contourf
    - amp: amplitude of the colorbar range; levels come from
        `src.utils.make_cb_range(amp, amp / 10)` and colorbar ticks are
        placed at -amp, 0, amp
    - label: optional label for the colorbar

    Only `ax` (and its parent figure, for the colorbar) is touched: the
    function no longer relies on notebook-level `fig`/`axs` variables, so
    repeated calls do not re-format or stack guide lines on other panels.
    """

    plot_data = ax.contourf(
        data.longitude,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="both",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## attach colorbar to the panel being plotted
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## label: longitude ticks/guide lines at 190 and 240, x axis on top
    guide_kwargs = dict(ls="--", c="w", lw=0.8)
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])
    ax.axvline(190, **guide_kwargs)
    ax.axvline(240, **guide_kwargs)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def plot_hov2(ax, data, amp, label=None):
    """
    Plot Hovmoller diagram (month vs. year) on a single axes object.

    Args:
    - ax: matplotlib axes to draw on
    - data: array exposing `month` and `year` coordinates; it is
        transposed (`data.T`) before being passed to contourf
    - amp: amplitude of the color range; levels come from
        `src.utils.make_cb_range(amp, amp / 10)`
    - label: currently unused (no colorbar is drawn); kept so existing
        callers keep working

    Only `ax` is formatted: the function no longer loops over a
    notebook-level `axs` variable (which also shadowed the `ax` argument).
    """

    ax.contourf(
        data.month,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="max",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## hide ticks; keep x axis (and any label set by the caller) on top
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def get_rolling_var(data, n=10):
    """
    Rolling variance of `data`, evaluated separately for each month.

    Thin wrapper: forwards `data` to `src.utils.get_rolling_fn_bymonth`
    with `fn=np.var` and the window parameter `n`. Per the original
    author's note, the variance is pooled over time and ensemble member
    within a window of 2n+1 years centered on each year (behavior lives in
    the helper — confirm there).
    """

    rolling_kwargs = dict(fn=np.var, n=n)
    return src.utils.get_rolling_fn_bymonth(data, **rolling_kwargs)

## Change in $T$, $h$

### Load data
And compute variance/skewness

In [None]:
## open data (CESM index dataset, loaded by project helper)
Th = src.utils.load_cesm_indices()

## rename indices for convenience (shorter names are used as panel titles below)
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## get rolling variance, by month (window parameter n=15)
Th_var = get_rolling_var(Th, n=15)
Th_var_bymonth = src.utils.unstack_month_and_year(Th_var)

## get rolling skew, by month (same window parameter as the variance)
Th_skew = src.utils.get_rolling_fn_bymonth(Th, fn=scipy.stats.skew, n=15)
Th_skew_bymonth = src.utils.unstack_month_and_year(Th_skew)

## Get % increase in variance, relative to the mean over the first 30 entries
## along "year"
baseline = Th_var_bymonth.isel(year=slice(None, 30)).mean("year")
Th_var_bymonth_pct = 100 * (Th_var_bymonth - baseline) / baseline

### Variance

Hovmöller diagram

In [None]:
## one narrow panel per index in Th
index_names = list(Th)
n_panels = len(index_names)
fig, axs = plt.subplots(
    1, n_panels, figsize=(0.9 * n_panels, 2), layout="constrained"
)

## percent change in variance for each index, with a dashed guide at month 8
for panel, name in zip(axs, index_names):
    plot_hov2(panel, Th_var_bymonth_pct[name].T, amp=100)
    panel.set_title(name, fontsize=9)
    panel.axvline(8, ls="--", c="w", lw=1)


plt.show()

Compare August and December variance over time

In [None]:
fig, ax = plt.subplots(figsize=(2.5, 2))

## plot data: Niño 3.4 variance for August (month=8) and December (month=12).
## The "Dec" curve previously selected month=11, which did not match its label.
ax.plot(
    Th_var_bymonth.year, Th_var_bymonth["T_34"].sel(month=8), label="Aug", c="k", ls="-"
)
ax.plot(
    Th_var_bymonth.year,
    Th_var_bymonth["T_34"].sel(month=12),
    label="Dec",
    c="gray",
    ls="--",
)

## label and style (y axis starts at zero since variance is non-negative)
ax.set_title(r"$\sigma^2\left(\text{Niño 3.4}\right)$")
ax.set_ylim([0, None])
ax.legend(prop=dict(size=8))
ax.set_ylabel(r"$^{\circ}\text{C}^2$")

plt.show()

### Skewness

In [None]:
## indices to show, one panel each
skew_vars = ["T_4", "T_34", "T_3", "h", "h_w"]
fig, axs = plt.subplots(1, 5, figsize=(3.6, 2), layout="constrained")

## rolling skewness for each index, with a faint dashed guide at month 5
for panel, name in zip(axs, skew_vars):
    plot_hov2(panel, Th_skew_bymonth[name].T, amp=1.5)
    panel.set_title(name, fontsize=9)
    panel.axvline(5, ls="--", c="w", lw=1, alpha=0.5)

plt.show()

### Seasonality

## Change in spatial patterns

In [None]:
## path to EOF data
eofs_fp = pathlib.Path(DATA_FP, "cesm")

## variables to load (and how to rename them); the two lists are paired by position
names = ["tos", "zos", "tauu", "nhf"]
newnames = ["sst", "ssh", "taux", "nhf"]

## load the EOFs (one file per variable: eofs_<name>.nc), keyed by the new name
load_var = lambda x: src.utils.load_eofs(pathlib.Path(eofs_fp, f"eofs_{x}.nc"))
eofs = {y: load_var(x) for (y, x) in zip(newnames, names)}

## for convenience, put spatial patterns / components in single dataset
components = xr.merge([eofs_.components().rename(y) for (y, eofs_) in eofs.items()])

# reset member dimension so they all match (NHF labeled differently...)
# NOTE(review): hard-codes 100 members — confirm this matches the ensemble size
member_coord = dict(member=np.arange(100))
get_scores = lambda x, n: x.scores().assign_coords(member_coord).rename(n)
scores = xr.merge([get_scores(eofs_, n) for (n, eofs_) in eofs.items()])

## convert ssh from m to cm (modifies the underlying array in place)
scores["ssh"].values *= 100

## convert from stress on atm to stress on ocn (sign flip, in place)
scores["taux"].values *= -1

## get forced/anomalous component
forced, anom = src.utils.separate_forced(scores)

### Reconstruct equatorial strip

In [None]:
## specify which calendar month to keep
## NOTE(review): names below say "ndj" but the value originally set here is 3 —
## confirm the intended month
target_month = 3

## get rolling avg
forced_rolling = src.utils.get_rolling_avg(forced, n=1)

## keep only time steps that fall in the target month. (Previously the scalar
## was overwritten by the month array, so the mask `month == month` was True
## everywhere and every month was kept, giving duplicate year labels below.)
months_of_time = forced_rolling.time.dt.month
forced_ndj = forced_rolling.isel(time=(months_of_time == target_month))

## reconstruct equatorial average (5S-5N mean over longitudes 125-279)
eq_avg = lambda x: x.sel(latitude=slice(-5, 5), longitude=slice(125, 279)).mean(
    "latitude"
)
kwargs = dict(scores=forced_ndj, components=components, fn=eq_avg)
forced_ndj_recon = src.utils.reconstruct_fn(**kwargs)

## replace time axis with year for easier plotting
forced_ndj_recon["time"] = forced_ndj_recon["time"].dt.year
forced_ndj_recon = forced_ndj_recon.rename({"time": "year"})

## subtract leading 30-year mean
clim = forced_ndj_recon.isel(year=slice(None, 30)).mean("year")
forced_anom = forced_ndj_recon - clim

### Hovmöller plot

In [None]:
## four panels: totals in the first two, zonal-mean-removed fields in the last two
fig, axs = plt.subplots(1, 4, figsize=(7.2, 4), layout="constrained")

## (variable, amplitude for total panel, amplitude for difference panel)
panel_specs = [("sst", 3, 1), ("ssh", 5, 4)]

for j, (varname, total_amp, diff_amp) in enumerate(panel_specs):

    ## total change, and its departure from the zonal mean
    total = forced_anom[varname]
    diff = total - total.mean("longitude")

    ## draw both versions, two panels apart
    plot_hov(axs[j], total, amp=total_amp, label=f"{varname} (total)")
    plot_hov(axs[j + 2], diff, amp=diff_amp, label=f"{varname} (diff)")

## year axis on the leftmost panel only
axs[0].set_ylabel("Year")
axs[0].set_yticks(np.arange(1860, 2100, 55))
plt.show()

### Line plots

Function to do formatting. Not sure about units for SSH: data attribute says "m", but data itself suggests "cm"...

### Change in mean

In [None]:
def recon_clim(data, components):
    """
    Reconstruct the monthly SST climatology of `data` along the equator.

    The "sst" scores in `data` are averaged by calendar month, then passed
    together with `components["sst"]` to `src.utils.reconstruct_fn`, using a
    2S-2N latitude mean as the reconstruction function. Entries that come
    back exactly zero are replaced with NaN before returning.
    """

    def _equatorial_mean(field):
        """mean over latitudes between -2 and 2"""
        return field.sel(latitude=slice(-2, 2)).mean("latitude")

    ## climatology in PC space (SST only)
    sst_clim_scores = data.groupby("time.month").mean()["sst"]

    ## map back to physical space
    recon = src.utils.reconstruct_fn(
        components["sst"], sst_clim_scores, fn=_equatorial_mean
    )

    ## mask exact zeros
    is_zero = recon.values == 0
    recon.values[is_zero] = np.nan

    return recon


## climatology of the forced component over the first / last 360 time steps
first_360 = slice(None, 360)
last_360 = slice(-360, None)
clim_early = recon_clim(forced.isel(time=first_360), components)
clim_late = recon_clim(forced.isel(time=last_360), components)

## late minus early
clim_diff = clim_late - clim_early

In [None]:
## make Hovmöller diagrams (early, late, difference), stacked vertically
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs shared by the early/late panels
kwargs = dict(cmap="cmo.thermal", levels=np.arange(23, 32), extend="both")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], clim_early, **kwargs)

## plot late (same colormap, levels shifted up by 3; must run after the
## early panel because `kwargs` is modified here)
kwargs["levels"] = kwargs["levels"] + 3
cp1 = src.utils.plot_cycle_hov(axs[1], clim_late, **kwargs)

## plot difference (late minus early)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    clim_diff,
    cmap="cmo.amp",
    levels=np.arange(2.6, 5.2, 0.2),
    extend="both",
)

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel; ticks follow each panel's level range
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[23, 27, 31], label=r"$^{\circ}C$")
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[26, 30, 34], label=r"$^{\circ}C$")
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[3, 4, 5], label=r"$^{\circ}C$")

plt.show()

### Change in variance

#### Equatorial region

In [None]:
def get_var(data, year_center, n=15, fn=None):
    """
    Reconstruct variance from the samples near `year_center`.

    Time steps whose year lies within `n` years of `year_center`
    (inclusive) are selected and handed, as scores, to
    `src.utils.reconstruct_var` together with the notebook-level
    `components` and the optional function `fn`.
    """

    ## boolean mask over time: inside the window or not
    years = data.time.dt.year
    in_window = np.abs(years - year_center) <= n

    ## reconstruct variance from the windowed samples
    windowed = data.isel(time=in_window)
    return src.utils.reconstruct_var(components=components, scores=windowed, fn=fn)


def get_var_bymonth(data, year_center, n=15, fn=None):
    """Apply `get_var` separately to each calendar month of `data`."""

    def _var_one_month(month_data):
        """windowed variance for a single month's samples"""
        return get_var(month_data, year_center=year_center, n=n, fn=fn)

    return data.groupby("time.month").map(_var_one_month)


def get_var_for_periods(data, periods, n=15, fn=None, by_month=True):
    """
    Reconstruct variance for each window center in `periods`.

    Uses `get_var_bymonth` when `by_month` is true and `get_var` otherwise,
    and concatenates the results along a new "period" dimension labeled by
    the entries of `periods`. Progress is shown with tqdm.
    """

    ## pick the reconstruction routine
    if by_month:
        compute_var = get_var_bymonth
    else:
        compute_var = get_var

    ## evaluate one window at a time
    results = []
    for center in tqdm.tqdm(periods):
        results.append(compute_var(data=data, year_center=center, n=n, fn=fn))

    ## stack along a dimension indexed by window center
    return xr.concat(results, dim=pd.Index(periods, name="period"))


## window centers (years) at which variance is estimated
period_centers = np.array([1868, 1939, 2010, 2082])

## Get equatorial strip variance (all four periods, reduced with eq_avg)
kwargs = dict(data=anom, periods=period_centers, n=15, fn=eq_avg)
var_by_period_eq = get_var_for_periods(**kwargs)

## get full variance (fn=None), for the first and last periods only;
## reuses the kwargs above with fn/periods overridden
kwargs = dict(kwargs, fn=None, periods=period_centers[[0, -1]], by_month=True)
var_by_period = get_var_for_periods(**kwargs)

Set plot style for plots

#### Spatial (Hovmöller)

In [None]:
## get plot data: SST variance for first and last period, and their difference
## (note: rebinds the notebook-level name `baseline`)
baseline = var_by_period_eq["sst"].isel(period=0)
future = var_by_period_eq["sst"].isel(period=-1)
change = future - baseline

## shared args for plotting the two variance panels
plot_kwargs = dict(cmap="cmo.amp", extend="max")

## Set up plot
fig, axs = plt.subplots(3, 1, figsize=(4, 6), layout="constrained")

## make Hovmöller diagrams (baseline/future share levels; difference uses a
## symmetric range)
cp0 = src.utils.plot_cycle_hov(
    axs[0], baseline, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
)
cp1 = src.utils.plot_cycle_hov(
    axs[1], future, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    change,
    levels=src.utils.make_cb_range(1.5, 0.15),
    cmap="cmo.balance",
    extend="both",
)

## label
axs[0].set_title("Baseline")
axs[1].set_title("Future")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## add colorbars (difference panel overrides the tick locations)
kwargs = dict(ticks=[0, 1, 2, 3], label=r"$^{\circ}\text{C}^2$")
cb0 = fig.colorbar(cp0, ax=axs[0], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], **dict(kwargs, ticks=[-1.5, 0, 1.5]))

plt.show()

#### Spatial

Compute

In [None]:
## specify month to plot
month = 5

## set up paneled subplot (3 rows x 1 col, axes created by project helper)
fig = plt.figure(figsize=(6, 5), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=3, ncols=1, format_func=src.utils.plot_setup_pac
)

## plot data: SST variance for first (var0) and last (var1) period at the
## chosen month
kwargs = dict(
    var0=var_by_period["sst"].isel(period=0).sel(month=month),
    var1=var_by_period["sst"].isel(period=-1).sel(month=month),
    amp=2,
    amp_diff=1,
    show_colorbars=True,
    cbar_label=r"$^{\circ}$C$^2$",
)
fig, axs = src.utils.make_variance_subplots(fig, axs, **kwargs)

## titles: year ranges are the period centers (1868, 2082) +/- 15 years
axs[0, 0].set_title("Early (1853 – 1883)")
axs[1, 0].set_title("Late (2067 – 2097)")
axs[2, 0].set_title("Difference")

plt.show()

#### ENSO composite

Function to compute composite

In [None]:
def get_hov_composite(data, peak_month, q, idx_fn, is_warm=True):
    """
    Build a Hovmoller composite (lag vs. longitude) of events.

    Args:
    - data: scores used both to compute the index and to form the composite
    - peak_month: month to center the composite on
    - q: quantile threshold for selecting events
    - idx_fn: function that computes the index from spatial SST anomalies
    - is_warm: if True, keep samples above the `q` cutoff; otherwise keep
        samples below the `1 - q` cutoff

    Returns the 5S-5N meridional mean of the reconstructed composite, with
    "lag" as the leading dimension. Relies on the notebook-level
    `components` dataset.
    """

    ## index reconstructed from SST scores
    idx = src.utils.reconstruct_fn(
        components=components["sst"], scores=data["sst"], fn=idx_fn
    )

    ## quantile and comparison depend on warm/cold case
    if is_warm:
        quantile = q
        check_cutoff = lambda x, cut: x > cut
    else:
        quantile = 1 - q
        check_cutoff = lambda x, cut: x < cut

    ## composite in projected (PC) space
    comp_proj = src.utils.make_composite(
        q=quantile,
        check_cutoff=check_cutoff,
        peak_month=peak_month,
        idx=idx,
        data=data,
    )

    def _meridional_mean(x):
        """mean over latitudes between -5 and 5"""
        return x.sel(dict(latitude=slice(-5, 5))).mean("latitude")

    ## reconstruct the meridional mean and put lag first
    recon = src.utils.reconstruct_fn(
        components=components, scores=comp_proj, fn=_meridional_mean
    )
    return recon.transpose("lag", ...)

Do the computation

In [None]:
## specify shared args (warm events: Niño 3.4 above its 95th percentile,
## composite centered on month 12)
kwargs = dict(
    peak_month=12,
    q=0.95,
    idx_fn=src.utils.get_nino34,
    is_warm=True,
)

## get early/late composites from the first / last 360 time steps.
## (The late window was previously slice(-361, -1), which dropped the final
## time step and was shifted by one step relative to the early window.)
comp_early = get_hov_composite(anom.isel(time=slice(None, 360)), **kwargs)
comp_late = get_hov_composite(anom.isel(time=slice(-360, None)), **kwargs)

Plot the result

In [None]:
## specify amplitudes for plots (early, late, difference)
scales = np.array([1.5, 1.5, 0.5])

## set up plot
fig, axs = plt.subplots(1, 3, figsize=(6, 3), layout="constrained")

## panels to draw. The loop variable is deliberately not called `merimean`:
## at notebook scope that name would overwrite the `merimean()` helper
## whenever this cell is re-run after the helper has been defined.
panels = [comp_early, comp_late, comp_late - comp_early]

for ax, panel, scale in zip(axs, panels, scales):
    cf, _ = src.utils.plot_hov(ax=ax, x=panel, beta=scale)

    ## label x axis; no y ticks (the first panel's y axis is labeled below)
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])

## label
axs[0].set_title("Early (1853-1883)")
axs[1].set_title("Late (2067-2097)")
axs[2].set_title("Difference")
src.utils.label_hov_yaxis(axs[0], peak_mon=kwargs["peak_month"])

plt.show()

In [None]:
## specify variable
varname = "T_3"

## specify month idx (offset into the time axis; every 12th step from here is used)
## NOTE(review): index 1 is presumably February if the series starts in January — confirm
m_idx = 1

## histogram bin edges: narrower range for variables whose name contains "h"
if "h" in varname:
    edges = np.arange(-10e-2, 11e-2, 1e-2)
else:
    edges = np.arange(-4, 4.25, 0.25)


## get data: every 12th sample within the first 600 / last 600 time steps,
## flattened across all remaining dimensions
y0 = Th[varname].isel(time=slice(m_idx, 600, 12)).values.flatten()
y1 = Th[varname].isel(time=slice(-600 + m_idx, None, 12)).values.flatten()

## compute pdfs
pdf0, _ = src.utils.get_empirical_pdf(
    y0,
    edges=edges,
)
pdf1, _ = src.utils.get_empirical_pdf(
    y1,
    edges=edges,
)

## compute skewness
s0 = scipy.stats.skew(y0)
s1 = scipy.stats.skew(y1)

## plot both histograms (first curve = early samples, second = late samples);
## legend shows the skewness of each
fig, ax = plt.subplots(figsize=(4, 3))
ax.stairs(pdf0, edges, label=f"skew = {s0:.2f}")
ax.stairs(pdf1, edges, label=f"skew = {s1:.2f}")
ax.axvline(0, c="k", lw=1)
ax.legend(prop=dict(size=8))

plt.show()

### Change in Bjerknes coupling

In [None]:
def merimean(x):
    """Meridional mean of `x` over latitudes -5 to 5, for longitudes 140 to 285."""
    region = x.sel(longitude=slice(140, 285), latitude=slice(-5, 5))
    return region.mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True):
    """
    Plot the meridional mean of `data` (month vs. longitude) on `ax`.

    Args:
    - ax: matplotlib axes to draw on
    - data: field reduced with `merimean` before plotting
    - amp: amplitude of the level range; levels come from
        `src.utils.make_cb_range(amp, amp / 5)`
    - is_filled: filled contours with the "cmo.balance" colormap if True,
        thin black contour lines otherwise

    Returns the contour set created by the plotting call.
    """

    ## reduce once, reuse for coordinates and values
    reduced = merimean(data)

    ## options common to both plot types
    levels = src.utils.make_cb_range(amp, amp / 5)

    ## draw filled or line contours
    if is_filled:
        cp = ax.contourf(
            reduced.longitude,
            reduced.month,
            reduced,
            cmap="cmo.balance",
            levels=levels,
            extend="both",
        )
    else:
        cp = ax.contour(
            reduced.longitude,
            reduced.month,
            reduced,
            colors="k",
            linewidths=0.8,
            levels=levels,
            extend="both",
        )

    ## format ax object: ticks and dashed white guides at 160 and 210
    guide_lons = [160, 210]
    ax.set_xlabel("Lon")
    ax.set_xticks(guide_lons)
    for lon in guide_lons:
        ax.axvline(lon, c="w", ls="--", lw=1)

    return cp

Add EOF info to dataset

In [None]:
## attach each variable's spatial patterns to `anom` as "<var>_comp"
## (skipped if already present, so the cell can be re-run)
for v in components:
    comp_name = f"{v}_comp"
    if comp_name not in anom.data_vars:
        anom[comp_name] = components[v]

## add Niño indices to data, reconstructed from the SST scores/patterns
names = ["nino3", "nino34", "nino4"]
fns = [src.utils.get_nino3, src.utils.get_nino34, src.utils.get_nino4]
for n, fn in zip(names, fns):
    if n in anom.data_vars:
        continue
    anom[n] = src.utils.reconstruct_fn(
        scores=anom["sst"], components=anom["sst_comp"], fn=fn
    )

Partition into early/late

In [None]:
## 360 time steps each, skipping the first / last 12 steps of the record
anom_early = anom.isel(time=slice(12, 372))
anom_late = anom.isel(time=slice(-372, -12))

#### $\tau_x$-SST

Compute slope

In [None]:
## shared kwargs: regress taux (y) on sst (x), with x reduced by the Niño 3 function
kwargs = dict(x_var="sst", y_var="taux", fn_x=src.utils.get_nino3)

## function to get slope, separately for each calendar month
get_slope = lambda x: x.groupby("time.month").map(src.utils.regress_proj, **kwargs)

## then, reconstruct regression coefficient for each period
m_early = get_slope(anom_early)
m_late = get_slope(anom_late)

Plot hovmoller

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left panel: early slope (filled) with late slope overlaid as contour lines
cp0 = plot_cycle_hov(axs[0], data=m_early, amp=0.015)
plot_cycle_hov(axs[0], data=m_late, amp=0.015, is_filled=False)

## right panel: late minus early
cp2 = plot_cycle_hov(axs[1], data=m_late - m_early, amp=0.0075)

## titles
for panel, title in zip(axs, [r"$\tau_x$-SST coupling", "Change"]):
    panel.set_title(title)

## month axis on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

##### Explained variance

In [None]:
def regress_helper(data):
    """
    Regress "taux" on "nino3" using `src.utils.regress_core`.

    The "member" and "time" dimensions of `data` are first stacked into a
    single "sample" dimension, over which the regression is computed.
    """

    stacked = data.stack(sample=["member", "time"])
    return src.utils.regress_core(X=stacked["nino3"], Y=stacked["taux"], dim="sample")


def get_recon_skill(data):
    """
    Skill of the nino3-based linear reconstruction of taux.

    Returns the reconstructed covariance between the regression prediction
    (`nino3 * m`) and the actual taux scores, divided by the reconstructed
    variance of taux. Both are mapped to physical space with the taux
    patterns stored in `data["taux_comp"]`.
    """

    ## regression coefficients and the prediction they imply
    coefs = regress_helper(data)
    taux_hat = data["nino3"] * coefs

    ## spatial patterns shared by prediction and target
    taux_patterns = data["taux_comp"]

    ## covariance of prediction with target
    cov = src.utils.reconstruct_cov_da(
        V_x=taux_hat,
        V_y=data["taux"],
        U_x=taux_patterns,
        U_y=taux_patterns,
    )

    ## variance of target
    var = src.utils.reconstruct_var(
        scores=data["taux"],
        components=taux_patterns,
    )

    return cov / var


## function to get reconstruction skill (cov / var ratio), by calendar month
get_skill = lambda x: x.groupby("time.month").map(get_recon_skill)

## then, compute the skill for each period
r_early = get_skill(anom_early)
r_late = get_skill(anom_late)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left panel: early skill (filled) with late skill overlaid as contour lines
cp0 = plot_cycle_hov(axs[0], data=r_early, amp=0.8)
plot_cycle_hov(axs[0], data=r_late, amp=0.8, is_filled=False)

## right panel: late minus early
cp2 = plot_cycle_hov(axs[1], data=r_late - r_early, amp=0.4)

## titles
for panel, title in zip(axs, [r"Correlation", "Change"]):
    panel.set_title(title)

## month axis on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

#### Heat flux

In [None]:
## shared kwargs: regress nhf (y) on sst (x), with x reduced by the Niño 3.4 function
kwargs = dict(x_var="sst", y_var="nhf", fn_x=src.utils.get_nino34)

## function to get slope, separately for each calendar month
get_slope = lambda x: x.groupby("time.month").map(src.utils.regress_proj, **kwargs)

## then, reconstruct regression coefficient for each period
## (rebinds m_early / m_late from the previous section)
m_early = get_slope(anom_early)
m_late = get_slope(anom_late)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left panel: early slope (filled) with late slope overlaid as contour lines
cp0 = plot_cycle_hov(axs[0], data=m_early, amp=40)
plot_cycle_hov(axs[0], data=m_late, amp=40, is_filled=False)

## right panel: late minus early
cp2 = plot_cycle_hov(axs[1], data=m_late - m_early, amp=20)

## titles
for panel, title in zip(axs, [r"NHF-SST coupling", "Change"]):
    panel.set_title(title)

## month axis on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

#### SST-SSH

In [None]:
## shared kwargs: regress sst (y) on ssh (x); no index function is passed here
kwargs = dict(x_var="ssh", y_var="sst")

## function to get slope, separately for each calendar month
get_slope = lambda x: x.groupby("time.month").map(src.utils.regress_proj, **kwargs)

## then, reconstruct regression coefficient for each period
## (rebinds m_early / m_late from the previous section)
m_early = get_slope(anom_early)
m_late = get_slope(anom_late)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## left panel: early slope (filled) with late slope overlaid as contour lines
cp0 = plot_cycle_hov(axs[0], data=m_early, amp=0.3)
plot_cycle_hov(axs[0], data=m_late, amp=0.3, is_filled=False)

## right panel: late minus early
cp2 = plot_cycle_hov(axs[1], data=m_late - m_early, amp=0.15)

## titles
for panel, title in zip(axs, [r"SST-SSH coupling", "Change"]):
    panel.set_title(title)

## month axis on the left panel only
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[1].set_yticks([])

plt.show()

#### $\tau_x$-SST asymmetry

In [None]:
def regress_relu(data, is_pos=True):
    """
    Regress "taux" on "nino3" using only one sign of nino3.

    After stacking "member" and "time" into "sample", samples where nino3
    is positive (`is_pos=True`) or negative (`is_pos=False`) are kept; all
    other entries of the dataset are replaced with 0 (not dropped) before
    calling `src.utils.regress_core`.
    """

    ## stack data
    stacked = data.stack(sample=["member", "time"])

    ## mask of samples with the requested sign
    nino3 = stacked["nino3"]
    keep = nino3 > 0 if is_pos else nino3 < 0

    ## zero out everything else
    masked = stacked.where(keep, other=0)

    ## do regression
    return src.utils.regress_core(Y=masked["taux"], X=masked["nino3"], dim="sample")


def regress_relu_wrapper(data, eofs):
    """
    Asymmetry of the taux response to nino3.

    Reconstructs nino3 from the SST scores in `data`, regresses taux on it
    by calendar month separately for positive and negative nino3 (see
    `regress_relu`), maps both coefficient sets back to physical space with
    `eofs["taux"].inverse_transform`, and returns positive minus negative.
    """

    ## nino3 index from SST scores and patterns
    nino3 = src.utils.reconstruct_fn(
        components=eofs["sst"].components(), scores=data["sst"], fn=src.utils.get_nino3
    )

    ## taux scores and nino3 in one dataset
    merged = xr.merge([data["taux"], nino3.rename("nino3")])

    def _slope_pattern(is_pos):
        """monthly regression coefficients for one sign, in physical space"""
        coefs = merged.groupby("time.month").map(regress_relu, is_pos=is_pos)
        return eofs["taux"].inverse_transform(coefs)

    m_pos = _slope_pattern(True)
    m_neg = _slope_pattern(False)

    return m_pos - m_neg

Compute

In [None]:
## asymmetry (positive-minus-negative nino3 regression) for each period
asym_early = regress_relu_wrapper(anom_early, eofs)
asym_late = regress_relu_wrapper(anom_late, eofs)

Plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")

## plot data: early asymmetry on the left, late asymmetry on the right
## (same amplitude for both panels)
cp0 = plot_cycle_hov(axs[0], data=asym_early, amp=0.0075)
# plot_cycle_hov(axs[0], data=asym_late, amp=0.0075, is_filled=False)
cp2 = plot_cycle_hov(axs[1], data=asym_late, amp=0.0075)

## make it look nicer: month labels on the left panel only
axs[1].set_yticks([])
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
axs[0].set_ylabel("Month")

axs[0].set_title("Early asym.")
axs[1].set_title("Late asym.")

plt.show()

## Change in zonal gradient

### Compute zonal gradient

In [None]:
def get_zonal_grad(x):
    """
    Zonal gradient of `x`: eastern box average minus western box average.

    Boxes span latitudes -5 to 5, with longitudes 210-270 (east) and
    120-180 (west). Ref: Fig 7 in Maher et al, 2023.
    """

    ## latitude band shared by both boxes
    lat_band = slice(-5, 5)

    ## select the two boxes
    east_box = x.sel(dict(longitude=slice(210, 270), latitude=lat_band))
    west_box = x.sel(dict(longitude=slice(120, 180), latitude=lat_band))

    ## east minus west
    return src.utils.spatial_avg(east_box) - src.utils.spatial_avg(west_box)

In [None]:
## compute zonal gradient of the forced component
zonal_grad_forced = src.utils.reconstruct_fn(components, forced, fn=get_zonal_grad)

## sep. time into year and month
zonal_grad_forced_bymonth = src.utils.unstack_month_and_year(zonal_grad_forced)

## get change from initial climatology (mean over the first 30 entries along
## "year"); note this rebinds the notebook-level name `clim`
clim = zonal_grad_forced_bymonth.isel(year=slice(None, 30)).mean("year")
zonal_grad_change = zonal_grad_forced_bymonth - clim

In [None]:
## setup plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## plot data: change in SST zonal gradient, month (x) vs. year (y)
kwargs = dict(levels=src.utils.make_cb_range(2, 0.2), cmap="cmo.balance")
plot_data = ax.contourf(
    zonal_grad_change.month, zonal_grad_change.year, zonal_grad_change["sst"], **kwargs
)

## label; y axis is cropped to start at 1950
ax.set_ylabel("Year")
ax.set_xlabel("Month")
ax.set_xticks([1, 6, 12])
ax.set_title(r"$\Delta \left(\partial_x T\right)$")
ax.set_ylim([1950, None])

plt.show()

## Look at mean state-dependence

### Compare Bjerknes growth rate to zonal gradient

#### Load $T,h$ data

In [None]:
## MPI data
## NOTE(review): the path and comment refer to MPI, while the rest of this
## notebook analyzes CESM — confirm this is the intended file. This also
## rebinds `Th`, replacing the CESM indices loaded earlier.
mpi_load_fp = pathlib.Path(DATA_FP, "mpi_Th", "Th.nc")
Th = xr.open_dataset(mpi_load_fp)

#### Fit RO to $T,h$ data

In [None]:
## get subset of data (time range 1979 through 2024)
Th_sub = Th.sel(time=slice("1979", "2024"))

## initialize model
model = XRO(ncycle=12, ac_order=3, is_forward=True)

## fit to individual ensemble members, using T_3 and h_w as the T and h variables
kwargs = dict(model=model, T_var="T_3", h_var="h_w", verbose=True)
_, fits = src.utils.get_RO_ensemble(Th_sub, **kwargs)

## extract parameters from the fits
params = model.get_RO_parameters(fits)

#### Look at intra-ensemble spread
$\partial_x T$ vs. BJ index  
Compare to Maher et al (2023)

In [None]:
## get subset of data to look at (same time range as the RO fit)
anom_ = anom.sel(time=slice("1979", "2024"))

## compute zonal gradient at every time step
zonal_grad = src.utils.reconstruct_fn(components, anom_, fn=get_zonal_grad)

## get monthly avg
zonal_grad_by_month = zonal_grad.groupby("time.month").mean()

## scatter: month-mean SST zonal gradient vs. cycle-mean "R" parameter
## NOTE(review): the gradient comes from `anom` while `params` comes from the
## separately loaded T/h file — confirm both describe the same ensemble/members
fig, ax = plt.subplots(figsize=(2, 2))
ax.scatter(zonal_grad_by_month["sst"].mean("month"), params["R"].mean("cycle"), s=10)
ax.set_xlabel("Zonal SST gradient")
ax.set_ylabel("Bjerknes feedback")
plt.show()