# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy

# Import custom modules
import src.utils

## plotting theme (seaborn): white axes background and no grid lines
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## raise figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## input-data and save directories, read from environment variables
## (a missing variable raises KeyError)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[var]) for var in ("DATA_FP", "SAVE_FP"))

## Funcs

In [None]:
def format_subsurf_axs(axs):
    """Apply shared depth-section formatting to a row of panels.

    Every panel has its y-limits swapped (flipping the y direction), its
    x-axis upper limit set to 281 (lower limit left automatic), its y-ticks
    cleared and the x-label "Longitude". The first panel then gets y-ticks
    at 300/150/0 and the y-label "Depth (m)".

    Args:
        axs: sequence of matplotlib Axes.
    """

    for panel in axs:
        lower, upper = panel.get_ylim()
        panel.set_ylim((upper, lower))  # swap limits -> flip y direction
        panel.set_xlim([None, 281])
        panel.set_yticks([])
        panel.set_xlabel("Longitude")

    ## only the left-most panel carries the depth axis
    first = axs[0]
    first.set_yticks([300, 150, 0])
    first.set_ylabel("Depth (m)")


def format_hov_axs(axs):
    """Put a row of Hovmöller (longitude x month) panels in a standard format.

    The titles "Early", "Late" and "Difference (x2)" go on the first (up to)
    three panels. Only the left-most panel keeps month ticks and the "Month"
    label; every other panel has its y-ticks removed. Every panel gets
    x-ticks at 190 and 240, with white dashed vertical lines at those
    longitudes.

    Works for any number of panels >= 1. Earlier versions indexed axs[2]
    unconditionally and left axs[3:] with y-ticks, so 4-panel callers had
    to strip them by hand.

    Args:
        axs: sequence (list or numpy array) of matplotlib Axes.
    """

    ## shared font settings for labels/titles
    font_kwargs = dict(size=8)

    ## titles for the first (up to) three panels; zip stops at the shorter
    ## sequence, so 1- or 2-panel rows no longer raise IndexError
    for ax, title in zip(axs, ["Early", "Late", "Difference (x2)"]):
        ax.set_title(title, **font_kwargs)

    ## month axis is labelled only on the left-most panel
    axs[0].set_ylabel("Month", **font_kwargs)
    axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
    for ax in axs[1:]:
        ax.set_yticks([])

    ## longitude ticks with dashed white guide lines at the same longitudes
    for ax in axs:
        ax.set_xticks([190, 240])
        ax.axvline(240, ls="--", c="w", lw=1)
        ax.axvline(190, ls="--", c="w", lw=1)

    return


def get_dT_sub(Tsub, mld, delta=25):
    """Get the temperature difference between the mixed layer and the entrainment zone.

    Following Frankignoul et al., ``Tbar`` is the mean of ``Tsub`` over depths
    ``z_t <= mld``. ``Tplus`` is the mean over the layer ``mld < z_t < mld + delta``
    just below it.

    Args:
        Tsub: subsurface temperature with a ``z_t`` (depth) coordinate and a
            ``lon`` dimension.
        mld: mixed-layer depth with ``month`` and ``lon`` dimensions, in the
            same units as ``Tsub.z_t`` (they are compared directly).
        delta: thickness of the entrainment zone below the mixed layer.

    Returns:
        ``Tbar - Tplus``, restricted to longitudes where ``mld`` is not NaN in
        every month.
    """

    ## subset for longitudes where the MLD is defined in at least one month
    valid_lon_idx = ~np.isnan(mld).all("month")
    mld = mld.isel(lon=valid_lon_idx)
    Tsub = Tsub.isel(lon=valid_lon_idx)

    ## find depths in the ML and the entrainment zone (ez). Compare against the
    ## subset ``mld`` itself; the previous ``mld_interp`` name was never
    ## defined and raised NameError.
    in_ml = Tsub.z_t <= mld
    in_ez = (Tsub.z_t > mld) & (Tsub.z_t < (delta + mld))

    ## get Tbar and Tplus (following Frankignoul et al paper)
    Tbar = Tsub.where(in_ml).mean("z_t")
    Tplus = Tsub.where(in_ez).mean("z_t")

    ## get gradient
    dT = Tbar - Tplus

    return dT


def merimean(x, lat_bound=5):
    """Average ``x`` over latitude within +/- ``lat_bound``.

    ``x`` is first restricted (label-based ``sel``) to longitudes 140-285 and
    latitudes [-lat_bound, lat_bound], then averaged over ``latitude``.
    """

    lon_window = slice(140, 285)
    lat_window = slice(-lat_bound, lat_bound)
    band = x.sel({"longitude": lon_window, "latitude": lat_window})
    return band.mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True, xticks=(190, 240), lat_bound=5):
    """Plot a month-vs-longitude Hovmöller of ``data`` on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        data: field with ``month`` and ``longitude`` coordinates. If it also
            has a ``latitude`` coordinate, it is first averaged with
            ``merimean`` over +/- ``lat_bound``.
        amp: contour range; levels are ``src.utils.make_cb_range(amp, amp / 5)``.
        is_filled: filled contours with the ``cmo.balance`` colormap if True,
            otherwise thin black contour lines.
        xticks: longitudes to tick and mark with white dashed lines. The
            default is a tuple (not a list), so it cannot be mutated across calls.
        lat_bound: half-width of the latitude band used for averaging.

    Returns:
        The contour set returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## levels/extend shared by both plot styles
    shared_kwargs = dict(levels=src.utils.make_cb_range(amp, amp / 5), extend="both")

    ## choose plotting function and style-specific kwargs
    if is_filled:
        plot_fn = ax.contourf
        style_kwargs = dict(cmap="cmo.balance")

    else:
        plot_fn = ax.contour
        style_kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary)
    if "latitude" in data.coords:
        plot_data = merimean(data, lat_bound=lat_bound)
    else:
        plot_data = data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.month,
        plot_data.transpose("month", "longitude"),
        **style_kwargs,
        **shared_kwargs,
    )

    ## format ax object: fixed longitude window, ticks + dashed guide lines
    line_kwargs = dict(c="w", ls="--", lw=1)
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(list(xticks))
    for tick in xticks:
        ax.axvline(tick, **line_kwargs)

    return cp


def prep(data):
    """Add SST-independent thermocline indices to ``data`` (modified in place).

    For each of ``h_w`` and ``h`` (in that order), the result of
    ``src.utils.remove_sst_dependence_v2(data, h_var=<name>, T_var="T_34")``
    is stored as ``<name>_hat``. Returns the same ``data`` object.
    """

    for h_name in ("h_w", "h"):
        residual = src.utils.remove_sst_dependence_v2(data, h_var=h_name, T_var="T_34")
        data[h_name + "_hat"] = residual
    return data


def load_budget_data(t_early, t_late, target_grid):
    """Load the CESM ground-truth heat-budget terms for two time windows.

    Steps:
    1. Open ``adv`` and ``ddt_T`` from ``DATA_FP/cesm/<var>_temp/*.nc``, one
       file per member, stacked along ``member``.
    2. Add ``diff = TEND_TEMP - ADV_3D_TEMP``.
    3. Convert ``z_t`` from cm to m and the data variables from K/s to K/month
       (30-day months).
    4. Use ``lon`` values as the ``nlon`` coordinate and rename it to
       ``longitude``.
    5. Keep only the ``t_early`` and ``t_late`` windows and interpolate onto
       ``target_grid``.
    6. Split with ``src.utils.separate_forced``.

    Args:
        t_early, t_late: ``.sel`` indexers, e.g. ``dict(time=slice(...))``.
        target_grid: object passed to ``interp_like``.

    Returns:
        ``(forced_bud, anom_bud)`` as returned by ``src.utils.separate_forced``.

    Raises:
        FileNotFoundError: if a variable folder contains no ``.nc`` files.
    """

    def load_var(varname):
        """open every member file for one budget variable, stacked along member"""

        folder = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")
        files = sorted(folder.glob("*.nc"))
        if not files:
            raise FileNotFoundError(f"no .nc files found in {folder}")

        ## open data
        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1 from the number of files actually found
        ## (previously hard-coded to 100, which fails for any other count)
        return data.assign_coords({"member": np.arange(data.sizes["member"])})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## get difference
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## convert (i) z coord from cm to m and (ii) units from K/s to K/mo
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## fix longitude coordinate
    ## (assumes ``lon`` is 1-D along ``nlon`` -- TODO confirm)
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## trim in time
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud

## Load data

### $T$, $h$

In [None]:
## load the CESM climate-mode indices and give them short names
Th = src.utils.load_cesm_indices().rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## tropical-mean SST, then total (not anomaly) T,h indices
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")
Th_total = xr.open_dataset(pathlib.Path(DATA_FP, "cesm", "Th.nc"))

## relative SST index = total index minus the "trop_sst_10" tropical mean
for idx_name in ("T_3", "T_34", "T_4"):
    Th[idx_name + "_rel"] = Th_total[idx_name] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## consolidated spatial fields: forced component and anomalies
CONS_DIR = pathlib.Path(DATA_FP, "cesm", "consolidated")
forced, anom = (xr.open_dataset(CONS_DIR / fname) for fname in ("forced.nc", "anom.nc"))

## attach the T,h indices to the anomaly dataset
for idx_name in ("T_3", "T_34", "T_4", "h", "h_w"):
    anom[idx_name] = Th[idx_name]

#### early/late split

In [None]:
## time windows for the two periods being compared
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## subset anomalies to each window, then add the SST-independent h indices
anom_early, anom_late = (prep(anom.sel(window)) for window in (t_early, t_late))

### Compute BJ coefficients

In [None]:
def fit(y_var, data, x_vars=("T_34", "h_w_hat")):
    """Regress ``y_var`` on ``x_vars`` month by month; return the first predictor's slope.

    Args:
        y_var: name of the predictand variable in ``data``.
        data: dataset passed to ``src.utils.multi_regress_bymonth``.
        x_vars: predictor names. The coefficient of the FIRST one (the
            temperature index) is returned. The default is a tuple, so it
            cannot be mutated across calls.

    Returns:
        ``coefs[x_vars[0]]`` from ``src.utils.multi_regress_bymonth``.
    """

    ## fresh list: the helper still receives a list (as before), and nothing
    ## it does can leak back into the caller's sequence or the default
    x_vars = list(x_vars)

    ## infer T variable
    T_var = x_vars[0]

    ## get coeffs
    coefs = src.utils.multi_regress_bymonth(data, y_var=y_var, x_vars=x_vars)

    return coefs[T_var]

#### Compute regression coefficients

In [None]:
## empty datasets to hold the per-month regression slopes
m_early, m_late = xr.Dataset(), xr.Dataset()

## T_34 slope (h_w_hat also in the regression via fit's default) per field
for var_name in tqdm.tqdm(["sst", "nhf", "taux", "T", "w", "u"]):
    for m_out, anom_in in ((m_early, anom_early), (m_late, anom_late)):
        m_out[var_name] = fit(y_var=var_name, data=anom_in)


## taux-Tsub: per-month src.utils.regress_proj with x_var="taux", y_var="T"
kwargs = dict(x_var="taux", y_var="T")


def get_slope(x, fn_x):
    """group ``x`` by calendar month and apply src.utils.regress_proj to each group"""
    return x.groupby("time.month").map(src.utils.regress_proj, fn_x=fn_x, **kwargs)


for m_out, anom_in in ((m_early, anom_early), (m_late, anom_late)):
    m_out["taux_T"] = get_slope(anom_in, fn_x=src.utils.get_nino4)

#### compute feedback terms

Compute feedback terms that do *not* depend on the mixed-layer depth (MLD)

In [None]:
## climatological mean state of the subsurface fields in each period
sub_vars = ["u", "w", "T", "u_comp", "T_comp", "w_comp"]
bar_early, bar_late = (
    src.utils.reconstruct_clim(forced[sub_vars].sel(window))
    for window in (t_early, t_late)
)

## feedback terms per period: mean state ("bar") combined with slopes ("prime")
feedbacks_early = src.utils.get_feedbacks(bar=bar_early, prime=m_early)
feedbacks_late = src.utils.get_feedbacks(bar=bar_late, prime=m_late)

## decompose late-minus-early feedback changes (the delta_* terms used below)
feedback_changes = src.utils.decompose_feedback_changes(
    bar_early=bar_early, bar_late=bar_late, prime_early=m_early, prime_late=m_late
)

## early gets its feedbacks; late also carries the change decomposition
m_early = xr.merge([m_early, feedbacks_early])
m_late = xr.merge([m_late, feedbacks_late, feedback_changes])

Compute quantities that *do* depend on the mixed-layer depth (MLD), i.e. averages over the mixed layer

In [None]:
## mixed-layer settings, kept as separate knobs per period
ml_early_kwargs = {"H0": 70, "Hm": None}
ml_late_kwargs = {"H0": 70, "Hm": None}

## average each period's fields over the mixed layer
m_early = src.utils.get_ml_avg_ds(m_early, **ml_early_kwargs)
m_late = src.utils.get_ml_avg_ds(m_late, **ml_late_kwargs)

#### Plot feedback Hovmöllers
E.g., the thermocline feedback ($\overline{w}~\frac{\partial T'}{\partial z}$) and the Ekman feedback ($w'~\frac{\partial \overline{T}}{\partial z}$)

##### Plot mixed layer integral

In [None]:
## specify plot amplitude
amp = 1

## early / late / late-minus-early Hovmöllers of the mixed-layer terms
## (thermocline, Ekman, advection; "_ml" variables -- assumed to come from
## get_ml_avg_ds above)
for n in ["THF_ml", "EKM_ml", "ADV_ml"]:

    print(f"\n{n}")
    fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

    ## plot data
    cp0 = plot_cycle_hov(axs[0], data=m_early[n], amp=amp)
    cp1 = plot_cycle_hov(axs[1], data=m_late[n], amp=amp)
    cp2 = plot_cycle_hov(axs[2], data=(m_late - m_early)[n], amp=amp * 0.75)

    ## make it look nicer
    # NOTE(review): the colorbar is built from cp0 (early panel, +/-amp levels)
    # but placed next to the difference panel, which uses +/-0.75*amp levels.
    # format_hov_axs also titles that panel "Difference (x2)". Confirm which
    # scale is intended.
    cb = fig.colorbar(
        cp0,
        ax=axs[2],
        ticks=[-amp, 0, amp],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )
    format_hov_axs(axs)
    for ax in axs:
        ## dashed guide at month 7 (Jul)
        ax.axhline(7, ls="--", c="k", lw=1)

    plt.show()

##### Decompose changes in Ekman

In [None]:
## colorbar amplitude shared by all four panels
amp = 0.5

## total change in the Ekman feedback, then its mean-state, anomaly and
## nonlinear contributions (in that order)
plot_data = [
    (m_late - m_early)["EKM_ml"],
    m_late["delta_EKM_mean_ml"],
    m_late["delta_EKM_anom_ml"],
    m_late["delta_EKM_nl_ml"],
]
titles = [
    r"$\Delta\left(w~\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$w_0'~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$\Delta\left(w'\right)~\frac{\partial \overline{T}_0}{\partial z}$",
    r"$\Delta\left(w'\right)~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")
format_hov_axs(axs)

## one Hovmöller per term; each panel's own title replaces the default one
for i, ax in enumerate(axs):
    cp = plot_cycle_hov(ax, data=plot_data[i], amp=amp)
    ax.set_title(titles[i])
    ax.axhline(7, ls="--", c="k", lw=1)

## make sure the 4th panel has no y-ticks either
axs[3].set_yticks([])

## single colorbar (from the last panel) next to the right-most panel
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

##### Decompose changes in thermocline feedback

In [None]:
## specify colorbar amp
amp = 0.5

## total change in the thermocline feedback (w-bar * dT'/dz), then its
## mean-state, anomaly and nonlinear contributions (in that order)
plot_data = [
    (m_late - m_early)["THF_ml"],
    m_late["delta_THF_mean_ml"],
    m_late["delta_THF_anom_ml"],
    m_late["delta_THF_nl_ml"],
]
## The mean-state term multiplies the change in w-bar by the EARLY-period
## anomalous gradient dT'_0/dz. The prime was missing (it read dT_0/dz),
## which contradicts the THF definition in the first title.
titles = [
    r"$\Delta\left(\overline{w}~\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\frac{\partial T'_0}{\partial z}$",
    r"$\overline{w}_0~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

format_hov_axs(axs)

## plot data (each panel's own title replaces the default one)
for ax, p, t in zip(axs, plot_data, titles):
    cp = plot_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

#### Ground truth budget

Load data

In [None]:
## ground-truth heat budget (forced part and anomalies) on the anom grid
budg_forced, budg_anom = load_budget_data(
    t_early=t_early, t_late=t_late, target_grid=anom[["longitude", "z_t"]]
)

## pull each period's anomalies into memory
budg_early, budg_late = (budg_anom.sel(window).compute() for window in (t_early, t_late))

## attach the predictors used by the regressions below
for idx_name in ("T_34", "h_w_hat"):
    budg_early[idx_name] = anom_early[idx_name]
    budg_late[idx_name] = anom_late[idx_name]

Compute regression coefficients

In [None]:
## datasets to hold the budget regression coefficients
m_early_bud, m_late_bud = xr.Dataset(), xr.Dataset()

## per-month T_34 coefficient of each ground-truth tendency, plus its
## mixed-layer average
for n in ["TEND_TEMP", "ADV_3D_TEMP"]:

    # NOTE(review): fit() above uses src.utils.multi_regress_bymonth for the
    # same two predictors; confirm regress_bymonth is meant to take x_vars.
    kwargs = dict(x_vars=["T_34", "h_w_hat"], y_var=n)
    for m_bud, budg, ml_kwargs in (
        (m_early_bud, budg_early, ml_early_kwargs),
        (m_late_bud, budg_late, ml_late_kwargs),
    ):
        m_bud[n] = src.utils.regress_bymonth(budg, **kwargs)["T_34"]
        m_bud[f"{n}_ml"] = src.utils.get_ml_avg(m_bud[n], **ml_kwargs)


## d/dt(SST): tendency at the first z_t level (surface)
for m_bud in (m_early_bud, m_late_bud):
    m_bud["ddt_sst"] = m_bud["TEND_TEMP"].isel(z_t=0)

Plot

In [None]:
## colorbar amplitude for the early/late panels
amp = 1

## early / late / late-minus-early Hovmöllers of the ground-truth budget:
## mixed-layer-averaged 3-D advection and the surface temperature tendency
for n in ["ADV_3D_TEMP_ml", "ddt_sst"]:

    print(f"\n{n}")
    fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

    ## plot data
    cp0 = plot_cycle_hov(axs[0], data=m_early_bud[n], amp=amp)
    cp1 = plot_cycle_hov(axs[1], data=m_late_bud[n], amp=amp)
    cp2 = plot_cycle_hov(axs[2], data=(m_late_bud - m_early_bud)[n], amp=amp * 0.75)

    ## make it look nicer
    # NOTE(review): as in the mixed-layer plot above, the colorbar uses cp0's
    # +/-amp levels but sits next to the +/-0.75*amp difference panel.
    cb = fig.colorbar(
        cp0,
        ax=axs[2],
        ticks=[-amp, 0, amp],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )
    format_hov_axs(axs)
    for ax in axs:
        ## dashed guide at month 7 (Jul)
        ax.axhline(7, ls="--", c="k", lw=1)

    plt.show()

### Plot BJ couplings

#### Compute $T_{sub}$

In [None]:
## height of entrainment zone (from base of ML)
# NOTE(review): this entire cell is commented out. It needs Hm_early and
# Hm_late (mixed-layer depths), which are not defined anywhere in this
# notebook. As a result m_early/m_late never receive "dT_n34" / "dT_taux",
# and the later cells that read those variables raise KeyError until this
# cell is restored.
# ez_height = 20

## helper to interpolate onto m_early's longitude grid
# update_lon = lambda x: x.interp({"longitude": m_early.longitude})

# ## T_ml - T_sub (regr on Niño 3.4 and taux)
# NOTE(review): get_dT_sub returns Tbar - Tplus (ML minus entrainment zone),
# so the leading minus below yields T_sub - T_ml, the opposite of the comment
# above. Confirm the intended sign.
# fit_Tsub = lambda **kwargs: update_lon(-get_dT_sub(**kwargs, delta=ez_height))
# m_early["dT_n34"] = fit_Tsub(Tsub=m_early["T"], mld=Hm_early)
# m_late["dT_n34"] = fit_Tsub(Tsub=m_late["T"], mld=Hm_late)
# m_early["dT_taux"] = fit_Tsub(Tsub=m_early["taux_T"], mld=Hm_early)
# m_late["dT_taux"] = fit_Tsub(Tsub=m_late["taux_T"], mld=Hm_late)

#### Niño 3.4 - SST

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
left, right = axs

## left: early coupling (filled) with late coupling overlaid as line
## contours, both averaged over latitudes within +/-1.5
cp0 = plot_cycle_hov(left, data=m_early["sst"], amp=2, lat_bound=1.5)
plot_cycle_hov(left, data=m_late["sst"], amp=2, is_filled=False, lat_bound=1.5)

## right: late minus early
cp2 = plot_cycle_hov(right, data=m_late["sst"] - m_early["sst"], amp=1, lat_bound=1.5)

## month ticks/label only on the left panel
right.set_yticks([])
left.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
left.set_ylabel("Month")
left.set_title(r"Niño 3.4-SST coupling")
right.set_title("Change")

## dashed guide at month 6 (Jun)
for ax in axs:
    ax.axhline(6, c="k", ls="--", lw=1)

plt.show()

### SST - NHF

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
left, right = axs

## left: early NHF coupling (filled) with late coupling as line contours,
## both averaged over latitudes within +/-1.5
cp0 = plot_cycle_hov(left, data=m_early["nhf"], amp=40, lat_bound=1.5)
plot_cycle_hov(left, data=m_late["nhf"], amp=40, lat_bound=1.5, is_filled=False)

## right: late minus early
cp2 = plot_cycle_hov(right, data=m_late["nhf"] - m_early["nhf"], amp=20, lat_bound=1.5)

## month ticks/label only on the left panel
right.set_yticks([])
left.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
left.set_ylabel("Month")
left.set_title(r"$NHF$-SST coupling")
right.set_title("Change")

## dashed guide at month 7 (Jul)
for ax in axs:
    ax.axhline(7, c="k", ls="--", lw=1)

plt.show()

### SST-$\tau_x$

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
left, right = axs

## left: early taux coupling (filled) with late coupling as line contours
cp0 = plot_cycle_hov(left, data=m_early["taux"], amp=0.015)
plot_cycle_hov(left, data=m_late["taux"], amp=0.015, is_filled=False)

## right: late minus early
cp2 = plot_cycle_hov(right, data=m_late["taux"] - m_early["taux"], amp=0.0075)

## month ticks/label only on the left panel
right.set_yticks([])
left.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
left.set_ylabel("Month")
left.set_title(r"$\tau_x$-SST coupling")
right.set_title("Change")

## dashed guide at month 6 plus white dashed lines at longitudes 160 and 210
for ax in axs:
    ax.axhline(6, c="k", ls="--", lw=1)
    for lon in (160, 210):
        ax.axvline(lon, ls="--", c="w")

plt.show()

In [None]:
def sel(x):
    """June values of ``x``, averaged over latitudes -5..5"""
    return x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")


## June change in the taux-SST coupling as a function of longitude
fig, ax = plt.subplots(figsize=(3, 2.5))
taux_change = sel(m_late["taux"] - m_early["taux"])
ax.plot(m_early["taux"].longitude, taux_change)
ax.axhline(0, ls="--", c="k", lw=0.8)
ax.set_xlim([140, 280])

### $T_{sub}-$Niño3.4

In [None]:
## helper to pull the dT_n34 (T_ml - T_sub regressed on Niño 3.4) field
# NOTE(review): "dT_n34" is only created by the commented-out T_sub cell
# above, so this cell raises KeyError as the notebook stands.
sel_data = lambda x: x["dT_n34"]

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs shared by the early/late panels
# NOTE(review): this calls src.utils.plot_cycle_hov (library version), not
# the local plot_cycle_hov defined above, which has a different signature.
kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(1e0, 1e-1), extend="both"
)
cb_kwargs = dict(ticks=[-1, 0, 1], label=r"$K~/~K$")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], sel_data(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
cp1 = src.utils.plot_cycle_hov(axs[1], sel_data(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference (half the early/late range)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    sel_data(m_late - m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.5e0, 0.5e-1),
    extend="both",
)

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-0.5, 0, 0.5], label=r"$K~/~K$")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])
axs[-1].axhline(7, c="k", lw=0.8, ls="--")

plt.show()

### $T_{sub}-\tau_x$

In [None]:
## specify which period/month to plot
# sel = lambda x: x.mean("month")
sel = lambda x: x.sel(month=7)

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

for ax, m in zip(axs[:2], [m_early, m_late]):

    ## regression of subsurface T on taux ("taux_T"), depth vs longitude
    cp = ax.contourf(
        m.lon,
        m.z_t,
        sel(m["taux_T"]),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(200, 20),
        extend="both",
    )

## difference
# NOTE(review): ``m`` here is whatever the loop above left behind (m_late);
# the coordinates are presumably identical, but m_early.lon would be clearer.
axs[2].contourf(
    m.lon,
    m.z_t,
    sel(m_late - m_early)["taux_T"],
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(100, 10),
    extend="both",
)
## plot MLD
# NOTE(review): plot_mlds is not defined anywhere in this notebook, so this
# line raises NameError as it stands.
plot_mlds(axs=axs, sel=sel)

## label
# NOTE(review): the colorbar is built from ``cp`` (levels from
# make_cb_range(200, 20)) but ticked at +/-10; confirm the tick values.
cb = fig.colorbar(cp, ax=axs[2], ticks=[-10, 0, 10], label=r"$K~\text{Pa}^{-1}$")
format_subsurf_axs(axs)
for ax in axs:
    ax.set_ylim([100, 5])
    ax.axvline(190, ls="--", c="w", lw=0.8)
    ax.axvline(240, ls="--", c="w", lw=0.8)

plt.show()

In [None]:
## helper to pull the dT_taux (T_ml - T_sub regressed on taux) field
# NOTE(review): "dT_taux" is only created by the commented-out T_sub cell
# above, so this cell raises KeyError as the notebook stands.
sel_data = lambda x: x["dT_taux"]

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs shared by the early/late panels (src.utils.plot_cycle_hov, not
## the local plot_cycle_hov defined above)
kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(1e2, 1e1), extend="both"
)
cb_kwargs = dict(ticks=[-100, 0, 100], label=r"$K~/~Pa$")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], sel_data(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
cp1 = src.utils.plot_cycle_hov(axs[1], sel_data(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference (half the early/late range)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    sel_data(m_late - m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5e1, 5e0),
    extend="both",
)

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-50, 0, 50], label=r"$K~/~Pa$")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

plt.show()

##### Horizontal cross-section

In [None]:
## depth-vs-longitude sections (month 6) of each feedback term: early, late,
## late-minus-early
# NOTE(review): "ADV_3D_TEMP" is stored in m_early_bud/m_late_bud, not in
# m_early/m_late, so that iteration likely raises KeyError unless the
# src.utils feedback helpers also add it -- confirm.
for n in ["THF", "EKM", "ZAF", "ADV", "ADV_3D_TEMP"]:

    print(f"\n\n{n}")

    ## specify which period/month to plot (redefined each iteration; loop-invariant)
    sel = lambda x: x.sel(month=6)

    fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

    for ax, m in zip(axs[:2], [m_early, m_late]):

        ## feedback term, depth vs longitude
        cp = ax.contourf(
            m.lon,
            m.z_t,
            sel(m)[n],
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(2, 0.2),
            extend="both",
        )

    ## difference
    axs[2].contourf(
        m_early.lon,
        m_early.z_t,
        sel(m_late - m_early)[n],
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(1, 0.1),
        extend="both",
    )

    ## plot MLD
    # NOTE(review): plot_mlds is not defined anywhere in this notebook, so
    # this line raises NameError as it stands.
    plot_mlds(axs, sel=sel)

    ## colorbar (from the late panel), ax limits, and Niño 3.4 bounds
    cb = fig.colorbar(
        cp, ax=axs[2], ticks=[-2, 0, 2], label=r"$K~\left(\text{month}\right)^{-1}$"
    )
    format_subsurf_axs(axs)
    for ax in axs:
        ax.set_ylim([100, 5])
        ax.axvline(190, ls="--", c="w", lw=0.8)
        ax.axvline(240, ls="--", c="w", lw=0.8)

    plt.show()