# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## set plotting specs: white axes background, grid turned off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables
## (indexing os.environ raises KeyError if DATA_FP or SAVE_FP is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Temperature difference between the entrainment zone and the mixed layer.

    The mixed layer is every level with ``z_t <= Hm``; the entrainment zone is
    every level with ``Hm < z_t < Hm + delta``. ``Tsub`` is averaged over
    ``z_t`` within each layer and the result is entrainment-zone mean minus
    mixed-layer mean (so it is positive when the entrainment zone is warmer).
    Naming (Tbar / Tplus) follows the Frankignoul et al paper.
    """

    depth = Tsub.z_t

    def layer_mean(mask):
        """mean of Tsub over z_t, using only levels where mask is True"""
        return Tsub.where(mask).mean("z_t")

    ## masks for the mixed layer and the entrainment zone just below it
    ml_mask = depth <= Hm
    ez_mask = (depth > Hm) & (depth < (delta + Hm))

    ## layer averages
    T_ml = layer_mean(ml_mask)
    T_ez = layer_mean(ez_mask)

    return T_ez - T_ml


def prep(data):
    """
    Add SST-independent versions of the h indices to ``data``.

    For each of "h_w", "h" and "h_w_z20", calls
    ``src.utils.remove_sst_dependence_v2`` with ``T_var="T_34"`` and stores the
    result under ``"<name>_hat"``. ``data`` is modified in place and returned.
    """

    h_names = ("h_w", "h", "h_w_z20")

    for name in h_names:
        h_hat = src.utils.remove_sst_dependence_v2(data, h_var=name, T_var="T_34")
        data[f"{name}_hat"] = h_hat

    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load ground truth heat budget data.

    Opens the "adv" and "ddt_T" files under ``DATA_FP/cesm``, adds a "diff"
    variable (TEND_TEMP minus ADV_3D_TEMP), rescales units, keeps only the two
    requested time windows, interpolates onto ``target_grid`` and splits the
    result with ``src.utils.separate_forced``.

    Args:
        t_early: indexer passed to ``.sel`` for the early period
            (e.g. ``dict(time=slice(...))``).
        t_late: indexer passed to ``.sel`` for the late period.
        target_grid: object passed to ``interp_like``.

    Returns:
        Tuple ``(forced_bud, anom_bud)`` as returned by
        ``src.utils.separate_forced``.
    """

    def load_var(varname):
        """load variable from prepped folder, one file per ensemble member"""

        ## files are sorted so member order is deterministic
        var_dir = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")

        ## open data, stacking files along a new "member" dimension
        data = xr.open_mfdataset(
            sorted(var_dir.glob("*.nc")),
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1 using the number of files actually opened
        ## (rather than assuming exactly 100 members)
        n_member = data.sizes["member"]
        return data.assign_coords({"member": np.arange(n_member)})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## get difference
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## conversion factors (a month is taken to be 30 days)
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30

    ## convert from (i) cm to m and (ii) K/s to K/mo
    ## (the multiplication rescales every data variable, including "diff")
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## fix longitude coordinate: index "nlon" by the values of "lon", then rename
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## trim in time: keep only the early and late windows
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, month=None):
    """
    Plot mixed layer depth ("mld") against longitude on three axes.

    ``axs[0]`` gets the early curve (solid), ``axs[1]`` the late curve (dashed)
    and ``axs[2]`` both. If ``month`` is None the curves are averaged over the
    "month" dimension; otherwise that month is selected.
    """

    ## get longitude
    lon = src.utils.merimean(bar_early).longitude

    def pick_month(x):
        """average over months, or select the requested one"""
        if month is None:
            return x.mean("month")
        return x.sel(month=month)

    def mld_curve(bar):
        """meridional-mean mixed layer depth, reduced over month"""
        return pick_month(src.utils.merimean(bar["mld"]))

    ## line styles: early is solid, late is dashed
    early_style = dict(c="k")
    late_style = dict(c="k", ls="--")

    ## plot
    axs[0].plot(lon, mld_curve(bar_early), **early_style)
    axs[1].plot(lon, mld_curve(bar_late), **late_style)
    axs[2].plot(lon, mld_curve(bar_early), **early_style)
    axs[2].plot(lon, mld_curve(bar_late), **late_style)


def plot_level(ax, data, level, ls="-", c="w"):
    """
    Draw a single contour of ``data`` at value ``level`` on ``ax``.

    The contour is drawn against ``data.longitude`` (x) and ``data.z_t`` (y),
    with color ``c``, line style ``ls`` and a line width of 1.
    """
    line_spec = dict(levels=[level], colors=c, linestyles=ls, linewidths=1)
    ax.contour(data.longitude, data.z_t, data, **line_spec)

## Load data

#### $T$, $h$

In [None]:
## open data: index dataset from the project helper
## (a later cell copies T_3, T_34, T_4, h, h_w, h_w_z20 and h_z20 out of it)
Th = src.utils.load_cesm_indices(load_z20=True)

### Spatial data

#### Load

In [None]:
## load spatial data (helper returns two objects, unpacked as forced / anom)
forced, anom = src.utils.load_consolidated()

## add T,h information: copy each index from Th into anom under the same name
for n in ["T_3", "T_34", "T_4", "h", "h_w", "h_w_z20", "h_z20"]:
    # anom[n] = Th[n] / Th[n].std()
    anom[n] = Th[n]

#### compute some indices

In [None]:
## get sst tendency (and convert from 1/yr to 1/mo)
## NOTE(review): the 1/12 factor assumes get_ddt returns a per-year tendency — confirm
anom["ddt_sst"] = 1 / 12 * src.utils.get_ddt(anom[["sst"]], is_forward=True)["ddt_sst"]

#### early/late split

In [None]:
## time windows for the two periods being compared
t_early = {"time": slice("1851", "1880")}
# t_late = dict(time=slice("1981", "2010"))
t_late = {"time": slice("2071", "2100")}

## for each period: subset the anomalies, run prep(), and load into memory
anom_early, anom_late = [
    prep(anom.sel(period)).compute() for period in (t_early, t_late)
]

## Compute/plot

In [None]:
def get_slope_bymonth(x, **kwargs):
    """
    Apply ``src.utils.regress_proj`` separately to each calendar month of ``x``.

    ``x`` is grouped by "time.month"; ``kwargs`` are forwarded unchanged to
    ``regress_proj`` for every group.
    """
    by_month = x.groupby("time.month")
    return by_month.map(src.utils.regress_proj, **kwargs)

### SST - SST

In [None]:
## regression setup: x and y are both "sst"; x is reduced with get_nino3
kwargs = {"x_var": "sst", "y_var": "sst", "fn_x": src.utils.get_nino3}

## monthly regression coefficients for the early and late periods
m_early, m_late = (
    get_slope_bymonth(period_anom, **kwargs)
    for period_anom in (anom_early, anom_late)
)

Hovmoller

In [None]:
## shared args
kwargs = dict(amp=1.5, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2x (late - early) on the third panel
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early, **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late, **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early), **kwargs)

## make it look nicer
## (SST regressed onto an SST index, so the coefficient is in K per K;
## the previous W m^-2 K^-1 label belonged to the heat-flux plot)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~K^{-1}$",
)
src.utils.format_hov_axs(axs)

## reference line at y=7 on every panel
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

Spatial

In [None]:
## select month
sel = lambda x: x.sel(month=6)

## set up plot: three stacked map panels (early, late, late - early)
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## the difference panel uses half the amplitude of the other two
contour_kwargs = dict(amp=2, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early, **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late, **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (m_late - m_early), **dict(contour_kwargs, amp=1)
)

## colorbar
## (SST regressed onto an SST index, so units are K per K;
## the previous Pa K^-1 label belonged to the wind-stress plot)
lab = r"$K~K^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-2, 0, 2], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-1, 0, 1], label=lab)

## Niño 3.4 box on every panel
## NOTE(review): the slopes plotted here were computed with fn_x=get_nino3;
## confirm whether the Niño 3 box should be drawn instead
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

Scatter

In [None]:
## specify func: difference between the get_nino3 and get_nino4 reductions
grad = lambda x: src.utils.get_nino3(x) - src.utils.get_nino4(x)

## kwargs for plotting
kwargs = dict(months=6, scale=1, x_var="T_3", y_var="sst", fn_y=grad)

## set up plot (make_scatter2's return value is used as the slope in the title)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
## (SST difference vs. SST index, so the slope is in K per K;
## the previous W m^-2 K^-1 units belonged to the heat-flux plot)
axs[0].set_title(f"{m0:.1f}" + r" $K~K^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $K~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### SST - NHF

In [None]:
## shared kwargs: x is "sst" (reduced with get_nino34), y is "nhf"
kwargs = dict(x_var="sst", y_var="nhf", fn_x=src.utils.get_nino34)

## then, reconstruct regression coefficient (one per calendar month)
m_early = get_slope_bymonth(anom_early, **kwargs)
m_late = get_slope_bymonth(anom_late, **kwargs)

## bundle both periods into one dataset with variables "early" and "late"
alpha = xr.merge([m_early.rename("early"), m_late.rename("late")])

In [None]:
## settings shared by all three panels
kwargs = {"amp": 40, "lat_bound": 1.5}

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, and 2x (late - early)
cp0, cp1, cp2 = [
    src.utils.make_cycle_hov(axs[i], data=field, **kwargs)
    for i, field in enumerate(
        (alpha["early"], alpha["late"], 2 * (alpha["late"] - alpha["early"]))
    )
]

## colorbar (built from the first panel, placed next to the last one)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    label=r"$W~m^{-2}~K^{-1}$",
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
)

## axis formatting, plus a reference line at y=7 on every panel
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

In [None]:
## kwargs for plotting
## NOTE(review): this cell uses months=4 while the other scatter cells in
## this notebook use months=6 — confirm this is intentional
kwargs = dict(months=4, scale=1, x_var="T_34", y_var="nhf", fn_y=src.utils.get_nino34)

## set up plot (make_scatter2's return value is used as the slope in the title)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $W~m^{-2}~K^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $W~m^{-2}~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### SST-$\tau_x$

In [None]:
## shared kwargs: x is "sst" (reduced with get_nino34), y is "taux"
kwargs = dict(x_var="sst", y_var="taux", fn_x=src.utils.get_nino34)

## then, reconstruct regression coefficient (one per calendar month)
m_early = get_slope_bymonth(anom_early, **kwargs)
m_late = get_slope_bymonth(anom_late, **kwargs)

## bundle both periods into one dataset with variables "early" and "late"
mu_a = xr.merge([m_early.rename("early"), m_late.rename("late")])

In [None]:
## shared args
kwargs = dict(amp=0.02, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2x (late - early) on the third panel
cp0 = src.utils.make_cycle_hov(axs[0], data=mu_a["early"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=mu_a["late"], **kwargs)
cp2 = src.utils.make_cycle_hov(
    axs[2], data=2 * (mu_a["late"] - mu_a["early"]), **kwargs
)

## make it look nicer
## (label matches the spatial plot of the same coefficient, Pa per K;
## the previous K month^-1 label did not describe a taux-on-SST slope)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$Pa~K^{-1}$",
)
src.utils.format_hov_axs(axs)

## reference lines at y=5 and y=7 on every panel
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

##### Spatial plot

In [None]:
## select month
sel = lambda x: x.sel(month=6)

## set up plot: three stacked map panels (early, late, late - early)
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## the difference panel uses half the amplitude of the other two
contour_kwargs = dict(amp=3e-2, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], mu_a["early"], **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], mu_a["late"], **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (mu_a["late"] - mu_a["early"]), **dict(contour_kwargs, amp=1.5e-2)
)

## colorbar
lab = r"$Pa~K^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-3e-2, 0, 3e-2], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-1.5e-2, 0, 1.5e-2], label=lab)

## region boxes: Niño 4 box on the early panel, Niño 3.4 box on the
## late and difference panels
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
src.utils.plot_nino4_box(axs[0, 0], **box_kwargs)
for ax in axs[1:].flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

In [None]:
## select month 6 and average over latitudes -5 to 5
sel = lambda x: x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")

## plot late-minus-early change in mu_a against longitude, with a zero line
fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(mu_a.longitude, sel(mu_a["late"] - mu_a["early"]))
ax.axhline(0, ls="--", c="k", lw=0.8)
ax.set_xlim([140, 280])
plt.show()

In [None]:
## kwargs for plotting
kwargs = dict(months=6, scale=1e3, x_var="T_34", y_var="taux")

## set up plot
## (taux is reduced with get_nino4 for the early period and get_nino34 for
## the late period; make_scatter2's return value is used in the title)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, fn_y=src.utils.get_nino4, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, fn_y=src.utils.get_nino34, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $mPa~K^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $mPa~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $\tau_x$ - $Z_{20}$

In [None]:
## shared kwargs: x is "taux", y is "z20"
kwargs = dict(x_var="taux", y_var="z20")

## then, reconstruct regression coefficient
## (taux is reduced with get_nino4 for early and get_nino34 for late)
m_early = get_slope_bymonth(anom_early, fn_x=src.utils.get_nino4, **kwargs)
m_late = get_slope_bymonth(anom_late, fn_x=src.utils.get_nino34, **kwargs)

## bundle both periods, then take the derivative along longitude
beta_h_int = xr.merge([m_early.rename("early"), m_late.rename("late")])
beta_h_grad = beta_h_int.differentiate("longitude")

##### Hovmoller

Note: could plot the zonal thermocline gradient, $\frac{dh}{dx}$, here instead; it lines up better with the shift in the wind pattern.

In [None]:
## shared args
kwargs = dict(amp=2e3, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2x (late - early) on the third panel
cp0 = src.utils.make_cycle_hov(axs[0], data=beta_h_int["early"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=beta_h_int["late"], **kwargs)
cp2 = src.utils.make_cycle_hov(
    axs[2], data=2 * (beta_h_int["late"] - beta_h_int["early"]), **kwargs
)

## overlay the late-minus-early change in mu_a (averaged over latitudes
## -5 to 5) as black line contours on the difference panel
axs[2].contour(
    mu_a.longitude,
    mu_a.month,
    (mu_a["late"] - mu_a["early"]).sel(latitude=slice(-5, 5)).mean("latitude"),
    colors="k",
    linewidths=1,
)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$cm~\left(\text{Pa}\right)^{-1}$",
)
src.utils.format_hov_axs(axs)

## reference lines at y=5 and y=7 on every panel
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

Scatter

In [None]:
## select regions for beta h
## sel_eq: mean over latitudes -2 to 2
## sel_w_beta_h / sel_e_beta_h: sel_eq, then mean over longitudes 120-210 / 210-270
## get_beta_h_grad: east-box value minus west-box value
sel_eq = lambda x: x.sel(latitude=slice(-2, 2)).mean("latitude")
# sel_e_beta_h = lambda x: sel_eq(x).sel(longitude=slice(210, 270)).mean("longitude")
# sel_w_beta_h = lambda x: sel_eq(x).sel(longitude=slice(140, 210)).mean("longitude")
sel_w_beta_h = lambda x: sel_eq(x).sel(longitude=slice(120, 210)).mean("longitude")
sel_e_beta_h = lambda x: sel_eq(x).sel(longitude=slice(210, 270)).mean("longitude")
get_beta_h_grad = lambda x: sel_e_beta_h(x) - sel_w_beta_h(x)

## kwargs for plotting
kwargs = dict(months=6, scale=1e-3, x_var="taux", y_var="z20", fn_y=get_beta_h_grad)

## set up plot
## (taux is reduced with get_nino4 for early and get_nino34 for late)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, fn_x=src.utils.get_nino4, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, fn_x=src.utils.get_nino34, **kwargs)

## label
axs[0].set_title(f"{m0:.2f}" + r" $cm~mPa^{-1}$")
axs[1].set_title(f"{m1:.2f}" + r" $cm~mPa^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

Spatial plot

In [None]:
## NOTE(review): ccrs is not referenced anywhere in this cell
import cartopy.crs as ccrs

## specify plot var and amplitude
## (the commented pair below switches to the longitude derivative instead)
# PLOT_VAR = beta_h_grad
# AMP = 5e1
PLOT_VAR = beta_h_int
AMP = 1e3

## select month
sel = lambda x: x.sel(month=6)

## set up plot: three stacked map panels (early, late, late - early)
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## the difference panel uses half the amplitude of the other two
contour_kwargs = dict(amp=AMP, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], PLOT_VAR["early"], **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], PLOT_VAR["late"], **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0],
    (PLOT_VAR["late"] - PLOT_VAR["early"]),
    **dict(contour_kwargs, amp=AMP / 2)
)

## colorbar
lab = r"$cm~Pa^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-AMP, 0, AMP], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-AMP / 2, 0, AMP / 2], label=lab)

## region boxes: Niño 4 box on the early panel; Niño 3.4 box and the
## "hw" box on the late and difference panels
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
src.utils.plot_nino4_box(axs[0, 0], **box_kwargs)
for ax in axs[1:].flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)
    src.utils.plot_hw_box(ax, **box_kwargs)
    # src.utils.plot_nino3_box(ax, **box_kwargs)

plt.show()

In [None]:
## select month 6 and average over latitudes -5 to 5
sel = lambda x: x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")

fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5))

## plot wind: early and late mu_a vs. longitude
axs[0].plot(mu_a.longitude, sel(mu_a["early"]))
axs[0].plot(mu_a.longitude, sel(mu_a["late"]))

## plot longitude derivative of the thermocline coefficient, early and late
## (x values are taken from mu_a.longitude)
axs[1].plot(mu_a.longitude, sel(beta_h_grad["early"]))
axs[1].plot(mu_a.longitude, sel(beta_h_grad["late"]))
axs[1].set_ylim([-40, 40])

## (disabled) overlay of the undifferentiated thermocline coefficient
# ax2 = ax.twinx()
# ax2.plot(mu_a.longitude, sel(beta_h_int["late"]))


## format: zero line and shared longitude range
for ax in axs:
    ax.axhline(0, ls="--", c="k", lw=0.8)
    ax.set_xlim([120, 280])
plt.show()

Subsurface plot

In [None]:
## shared kwargs: x is "taux", y is "T"
kwargs = dict(x_var="taux", y_var="T")

## then, reconstruct regression coefficient
## (taux is reduced with get_nino4 for early and get_nino34 for late)
m_early = get_slope_bymonth(anom_early, fn_x=src.utils.get_nino4, **kwargs)
m_late = get_slope_bymonth(anom_late, fn_x=src.utils.get_nino34, **kwargs)

In [None]:
## get climatologies: month 6 of the forced fields, averaged over latitudes
## -2 to 2, then time-averaged within each period and passed through
## reconstruct_wrapper
T_forced = forced[["T", "T_comp", "z20", "z20_comp"]]
T_forced = src.utils.sel_month(T_forced, 6).sel(latitude=slice(-2, 2)).mean("latitude")
T_clim_early = src.utils.reconstruct_wrapper(T_forced.sel(t_early).mean("time"))
T_clim_late = src.utils.reconstruct_wrapper(T_forced.sel(t_late).mean("time"))

## get scale for taux (computed from the early period only, with get_nino4)
sigma_taux = src.utils.reconstruct_std(
    scores=anom_early["taux"],
    components=anom_early["taux_comp"],
    fn=src.utils.get_nino4,
)

## get typical thermocline deviation: climatological z20 plus the
## taux-z20 coefficient times 3 * sigma_taux (same SCALE for both periods)
## NOTE(review): z20_early / z20_late are not used in the visible cells
SCALE = sigma_taux * 3
sel_ = lambda x: x.sel(latitude=slice(-2, 2), month=6).mean("latitude")
z20_early = T_clim_early["z20"] + sel_(beta_h_int["early"]) * SCALE
z20_late = T_clim_late["z20"] + sel_(beta_h_int["late"]) * SCALE

## scale T anomalies by same amount
m_early_ = m_early.sel(month=6) * SCALE
m_late_ = m_late.sel(month=6) * SCALE

## put everything in xr.datasets (variables "clim" and "anom")
T_early = xr.merge([T_clim_early["T"].rename("clim"), m_early_.rename("anom")])
T_late = xr.merge([T_clim_late["T"].rename("clim"), m_late_.rename("anom")])

In [None]:
## specify amplitude (used for anomaly contour levels and colorbar ticks)
amp = 4

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

## first two panels: early and late periods
for ax, T_, lev in zip(axs[:2], [T_early, T_late], [20, 20]):

    ## temperature (filled contours of the climatology)
    cp = ax.contourf(
        T_.longitude,
        T_.z_t,
        T_["clim"],
        cmap="cmo.thermal",
        levels=np.arange(10, 34, 2),
        extend="both",
    )

    ## warming pattern (black line contours of the scaled anomaly)
    ax.contour(
        T_.longitude,
        T_.z_t,
        T_["anom"],
        colors="k",
        levels=src.utils.make_cb_range(amp, amp / 5),
        extend="both",
        linewidths=1,
    )

    ## highlight the `lev` contour: climatology (solid) and
    ## climatology + anomaly (dashed)
    plot_level(ax, T_["clim"], level=lev)
    plot_level(ax, T_["clim"] + T_["anom"], level=lev, ls="--")

## difference (late minus early anomaly); keep the handle for the colorbar
cp_diff = axs[2].contourf(
    m_early.longitude,
    m_early.z_t,
    (T_late - T_early)["anom"],
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(amp, amp / 10),
    extend="both",
)

## colorbar for the difference panel
## (built from cp_diff: the ticks [-amp, 0, amp] only make sense on the
## difference color scale, not on the 10-34 climatology scale held by `cp`)
## NOTE(review): the anomalies are taux regressions scaled by SCALE, so the
## "per Niño 3" label may be stale — confirm
cb = fig.colorbar(
    cp_diff,
    ax=axs[2],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{Niño 3}\right)^{-1}$",
)

## format axes and set shared limits (y-axis runs from 200 down to 5)
src.utils.format_subsurf_axs(axs)
for ax in axs:
    ax.set_xlim([125, 280])
    ax.set_ylim([200, 5])

plt.show()

### $Z_{20} - T_{sub}$

In [None]:
## regression setup: x is "z20", y is "T"
kwargs = {"x_var": "z20", "y_var": "T"}

## monthly regression coefficients; z20 is reduced with get_nino34 in both periods
m_early, m_late = (
    get_slope_bymonth(period_anom, fn_x=src.utils.get_nino34, **kwargs)
    for period_anom in (anom_early, anom_late)
)

## bundle both periods into one dataset with variables "early" and "late"
a_h = xr.merge(
    [m.rename(label) for m, label in ((m_early, "early"), (m_late, "late"))]
)

In [None]:
## specify Hm; sel picks the z_t level nearest to it
Hm = 70
sel = lambda x: x.sel(z_t=Hm, method="nearest")

## shared args
kwargs = dict(amp=0.3, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2x (late - early) on the third panel
cp0 = src.utils.make_cycle_hov(axs[0], data=sel(a_h["early"]), **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=sel(a_h["late"]), **kwargs)
cp2 = src.utils.make_cycle_hov(
    axs[2], data=2 * sel(a_h["late"] - a_h["early"]), **kwargs
)

## make it look nicer
## (label matches the scatter plot of the same relation, K per cm;
## the previous K month^-1 label did not describe a T-on-z20 slope)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~cm^{-1}$",
)
src.utils.format_hov_axs(axs)

## reference lines at y=5 and y=7 on every panel
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

In [None]:
## region selectors
## sel_lon: mean over longitudes 210-280
## sel_z20: sel_lon, then mean over latitudes -2 to 2
## sel_T: sel_lon, then the z_t level nearest to Hm
sel_lon = lambda x: x.sel(longitude=slice(210, 280)).mean("longitude")
sel_z20 = lambda x: sel_lon(x).sel(latitude=slice(-2, 2)).mean("latitude")
sel_T = lambda x: sel_lon(x).sel(z_t=Hm, method="nearest")

## kwargs for plotting
## NOTE(review): scale=1e1 is used with a "K cm^-1" title — confirm the
## title units match how make_scatter2 applies `scale`
kwargs = dict(months=6, scale=1e1, x_var="z20", y_var="T", fn_x=sel_z20, fn_y=sel_T)

## set up plot (make_scatter2's return value is used as the slope in the title)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $K~cm^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $K~cm^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

#### $T_{sub}$ vs $T_e$ and $h_w$

In [None]:
# get_beta_h_grad = lambda x: sel_e_beta_h(x) - sel_w_beta_h(x)
## specify early/late funcs (region reduction applied to the taux coefficient)
tau_fn_early = src.utils.get_nino4
tau_fn_late = src.utils.get_nino34

## get area-averaged coefficients, all for a single calendar month
MONTH = 6
mu_a_ = xr.merge([tau_fn_early(mu_a["early"]), tau_fn_late(mu_a["late"])]).sel(
    month=MONTH
)
beta_h_ = get_beta_h_grad(beta_h_int).sel(month=MONTH)
a_h_ = sel_T(a_h).sel(month=MONTH)

## compute variables: reconstructed T reduced with sel_T, for MONTH only
recon = lambda x: src.utils.reconstruct_wrapper(x[["T", "T_comp"]], fn=sel_T)["T"]
T_sub_early = src.utils.sel_month(recon(anom_early), MONTH)
T_sub_late = src.utils.sel_month(recon(anom_late), MONTH)

## thermocline slope (east-box minus west-box z20)
get_dh = lambda x: src.utils.reconstruct_wrapper(
    x[["z20", "z20_comp"]], fn=get_beta_h_grad
)["z20"]
dh_early = src.utils.sel_month(get_dh(anom_early), MONTH)
dh_late = src.utils.sel_month(get_dh(anom_late), MONTH)

## thermocline in east (he) and in west (hw)
get_he = lambda x: src.utils.reconstruct_wrapper(
    x[["z20", "z20_comp"]], fn=sel_e_beta_h
)["z20"]
get_hw = lambda x: src.utils.reconstruct_wrapper(
    x[["z20", "z20_comp"]], fn=sel_w_beta_h
)["z20"]
he_early = src.utils.sel_month(get_he(anom_early), MONTH)
he_late = src.utils.sel_month(get_he(anom_late), MONTH)
hw_early = src.utils.sel_month(get_hw(anom_early), MONTH)
hw_late = src.utils.sel_month(get_hw(anom_late), MONTH)

## Niño 3.4
## (note: this rebinds T_early / T_late, which an earlier cell used for the
## subsurface clim/anom datasets)
T_early = src.utils.sel_month(anom_early["T_34"], MONTH)
T_late = src.utils.sel_month(anom_late["T_34"], MONTH)

## get predicted thermocline slope: (taux-z20 coef) * (SST-taux coef) * T_34
dh_early_hat = (beta_h_ * mu_a_)["early"] * T_early
dh_early_hat_est = 1.6 * dh_early_hat  ## this gives better estimate
dh_late_hat = (beta_h_ * mu_a_)["late"] * T_late

## get predicted subsurface temperature:
## predicted east thermocline = west thermocline + predicted slope,
## then multiplied by the z20-T coefficient
he_hat_early = hw_early + dh_early_hat
he_hat_late = hw_late + dh_late_hat
T_sub_early_hat = he_hat_early * a_h_["early"]
T_sub_late_hat = he_hat_late * a_h_["late"]

# ## merge for easier plotting
# dh_early = xr.merge([dh_early.rename("actual"), dh_early_hat.rename("hat")])
# he_early = xr.merge([he_early.rename("actual"), he_early_hat.rename("hat")])
# hw_early = xr.merge([hw_early.rename("actual"), hw_early_hat.rename("hat")])
# T_sub_early = xr.merge([T_sub_early.rename("actual"), T_sub_early_hat.rename("hat")])

Plot slope prediction

In [None]:
## actual vs. predicted thermocline slope for the early period
## NOTE(review): two panels are created but only axs[0] is drawn on
fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5))
# for dh, dh_hat in zip([[dh_early, dh_late], [
axs[0].scatter(
    dh_early,
    dh_early_hat,
    s=3,
    alpha=0.5,
)

## one-to-one reference line
zz = np.linspace(-20, 20)
axs[0].plot(zz, zz, c="k")
plt.show()

Plot $h_e$ prediction

In [None]:
## actual vs. predicted east thermocline value for the early period
fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.scatter(
    he_early,
    he_hat_early,
    s=3,
    alpha=0.5,
)

## one-to-one reference line
zz = np.linspace(-20, 20)
ax.plot(zz, zz, c="k")
plt.show()

In [None]:
## NOTE(review): h_e_est is computed here but not used in this cell
h_e_est = dh_early_hat_est + anom_early["h_w_z20"]

## actual vs. predicted subsurface temperature for the early period
## (T_sub_early was already month-selected when it was created, so the
## sel_month call here selects month 6 a second time)
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(
    src.utils.sel_month(T_sub_early, 6),
    T_sub_early_hat,
    s=3,
    alpha=0.5,
)

## one-to-one reference line
zz = np.linspace(-4, 4)
ax.plot(zz, zz, c="k")

Compute differently

In [None]:
# for x in [anom_early, anom_late]:
#     x["T_e"] =

In [None]:
## month-by-month regression via project helper, with y_vars=["T"], x_vars=["T_34"]
## NOTE(review): `data` is never assigned at notebook scope in the visible
## source (it only appears as a function parameter), so this cell looks like
## it raises NameError as written — confirm the intended input dataset
kwargs = dict(y_vars=["T"], x_vars=["T_34"])
coefs = src.utils.regress_xr_bymonth(data, **kwargs)

## Scratch

### $\tau_x-T_{sub}$
Note: there's no quasi-balance condition for this; $\tau_x$ tells us the tilt of the thermocline, not its depth.

In [None]:
## shared kwargs: x is "taux", y is "T"
kwargs = dict(x_var="taux", y_var="T")

## then, reconstruct regression coefficient
## (taux is reduced with get_nino4 for early and get_nino34 for late)
m_early = get_slope_bymonth(anom_early, fn_x=src.utils.get_nino4, **kwargs)
m_late = get_slope_bymonth(anom_late, fn_x=src.utils.get_nino34, **kwargs)

In [None]:
## specify Hm; sel picks the z_t level nearest to it
Hm = 70
sel = lambda x: x.sel(z_t=Hm, method="nearest")

## shared args
kwargs = dict(amp=3e2, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2x (late - early) on the third panel
cp0 = src.utils.make_cycle_hov(axs[0], data=sel(m_early), **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=sel(m_late), **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * sel(m_late - m_early), **kwargs)

## make it look nicer
## (T regressed onto taux, so the coefficient is in K per Pa;
## the previous K month^-1 label did not describe this slope)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~Pa^{-1}$",
)
src.utils.format_hov_axs(axs)

## reference lines at y=5 and y=7 on every panel
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

In [None]:
## sel_T: mean over longitudes 210-280, then the z_t level nearest to Hm
## (this rebinds the sel_T defined in an earlier cell)
sel_T = (
    lambda x: x.sel(longitude=slice(210, 280))
    .mean("longitude")
    .sel(z_t=Hm, method="nearest")
)

## kwargs for plotting
## NOTE(review): scale=1e-2 is used with a "K hPa^-1" title — confirm the
## title units match how make_scatter2 applies `scale`
kwargs = dict(months=6, scale=1e-2, x_var="taux", y_var="T", fn_y=sel_T)

## set up plot
## (taux is reduced with get_nino4 for early and get_nino34 for late)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, fn_x=src.utils.get_nino4, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, fn_x=src.utils.get_nino34, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $K~hPa^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $K~hPa^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)