# ENSO composites

## Imports

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy

# Import custom modules
import src.utils

## set plotting specs: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI of figures
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables
## (raises KeyError if DATA_FP or SAVE_FP is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Load data

### $T$, $h$

In [None]:
## open climate-index time series via the project loader
Th = src.utils.load_cesm_indices()

## rename indices to short names for convenience
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## load ELI and tropical sst data and merge them into the index dataset
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))
Th = xr.merge([Th, eli, trop_sst])

### Spatial data

In [None]:
## load spatial data: forced response and anomalies (project loader)
forced, anom = src.utils.load_consolidated()

## add T,h index time series to the anomaly dataset
for n in ["T_3", "T_34", "T_4", "eli_05", "h", "h_w", "trop_sst_05"]:
    anom[n] = Th[n]

## get "total" (forced + anomaly) sst, precip and net heat flux
for n in ["sst", "pr", "nhf"]:
    anom[f"{n}_total"] = forced[n] + anom[n]

#### Budget data

In [None]:
## load budget anomalies (first return value discarded) and keep only the
## temperature tendency and 3D advection terms
_, budg_anom = src.utils.load_budget_data()
anom = xr.merge([anom, budg_anom[["ADV_3D_TEMP", "TEND_TEMP"]]])

### Preprocess

In [None]:
## split into early/late periods (alternative period choices kept commented out)
t_early = dict(time=slice("1851", "1880"))
# t_early = dict(time=slice("1990", "2019"))
# t_late = dict(time=slice("2030", "2059"))
t_late = dict(time=slice("2071", "2100"))

## split surface data
anom_early = anom.sel(t_early)
anom_late = anom.sel(t_late)

## remove T-dependence from ssh field, separately for each period
## (T_var="T_3" -- presumably the Nino-3 temperature index)
for anom_ in [anom_early, anom_late]:
    anom_["ssh_hat"] = src.utils.remove_sst_dependence_v2(
        anom_,
        h_var="ssh",
        T_var="T_3",
    )
    ## ssh_hat shares the spatial components of ssh
    anom_["ssh_hat_comp"] = anom_["ssh_comp"]

## Composite

### Funcs

In [None]:
def get_composite(idx, data, peak_month, q=0.95, is_warm=True, event_type=None):
    """
    Composite `data` over events selected from the index `idx`.

    Parameters
    ----------
    idx : index time series used to pick events
    data : data to composite (here: projected scores)
    peak_month : calendar month the composite is centred on
    q : quantile threshold; warm events use `q` with ``x > cut``, cold events
        use ``1 - q`` with ``x < cut``
    is_warm : select warm (True) or cold (False) events
    event_type : forwarded unchanged to `src.utils.make_composite`

    Returns
    -------
    Whatever `src.utils.make_composite` returns for these arguments.
    """
    if not is_warm:
        threshold_q = 1 - q
        passes = lambda x, cut: x < cut
    else:
        threshold_q = q
        passes = lambda x, cut: x > cut

    return src.utils.make_composite(
        q=threshold_q,
        check_cutoff=passes,
        peak_month=peak_month,
        idx=idx,
        data=data,
        event_type=event_type,
    )


def get_spatial_composite(data, **composite_kwargs):
    """
    Composite the projected scores of `data` and rebuild spatial fields.

    `data` is split into spatial components and scores by
    `src.utils.split_components`. The scores are composited with
    `get_composite(**composite_kwargs)` and then mapped back to space. A
    relative-SST field ``sst_rel = sst_total - trop_sst_05`` is added.
    """
    comps, pcs = src.utils.split_components(data)

    projected = get_composite(data=pcs, **composite_kwargs)

    identity = lambda x: x
    spatial = reconstruct_helper(projected, comps, func=identity)
    spatial = spatial.drop_vars("mode")

    ## SST relative to the tropical-mean SST index
    spatial["sst_rel"] = spatial["sst_total"] - spatial["trop_sst_05"]
    return spatial


def reconstruct_helper(composite, components, func):
    """
    Reconstruct spatial fields from a composite of projected scores.

    Parameters
    ----------
    composite : xr.Dataset
        Composite of scores. Contains one variable per field, plus
        "<name>_total" variables holding the scores of the total field.
    components : xr.Dataset
        Spatial components, keyed by field name.
    func : callable
        Passed through as `fn` to `src.utils.reconstruct_fn`.

    Returns
    -------
    xr.Dataset
        Deep copy of `composite`. Every variable present in `components`,
        and every "<name>_total" variable, is replaced by its
        reconstruction. "<name>_total" is reconstructed with the components
        of "<name>".
    """
    suffix = "_total"

    ## copy so the input composite is left untouched
    composite_recon = copy.deepcopy(composite)

    ## reconstruct anomalies
    for c in list(components):
        composite_recon[c] = src.utils.reconstruct_fn(
            components=components[c],
            scores=composite[c],
            fn=func,
        )

    ## total fields share the anomaly's components ("sst_total" -> "sst").
    ## Match the suffix only: a substring test would also catch names such
    ## as "x_total_y" and then strip the wrong characters.
    for c in list(composite):
        if isinstance(c, str) and c.endswith(suffix) and c != suffix:
            n = c[: -len(suffix)]
            composite_recon[c] = src.utils.reconstruct_fn(
                components=components[n],
                scores=composite[c],
                fn=func,
            )

    return composite_recon


def get_spatial_clim(forced, lags, peak_month, fn=lambda x: x):
    """
    Build a lag-aligned background climatology for a composite.

    The monthly climatology is reconstructed from `forced` with
    `src.utils.reconstruct_clim`. For each lag, the calendar month
    ``peak_month + lag`` (wrapped into 1..12) is selected, and the slices are
    concatenated along `lags`.
    """
    monthly_clim = src.utils.reconstruct_clim(data=forced, fn=fn)

    ## calendar month for each lag, e.g. peak_month=12, lag=1 -> month 1
    cal_month = 1 + np.mod(lags + peak_month - 1, 12)

    per_lag = []
    for m in cal_month:
        per_lag.append(monthly_clim.sel(month=m).drop_vars("month"))

    return xr.concat(per_lag, dim=lags)


def get_T_ml_tendency(T, mld, delta=5, H0=None):
    """
    Return the lag-derivative of the mixed-layer-averaged temperature.

    The averaging is delegated to `src.utils.get_ml_avg` with ``Hm=mld``,
    `delta` and `H0`. The result is differentiated along "lag".
    """
    ml_avg = src.utils.get_ml_avg(T, Hm=mld, delta=delta, H0=H0)
    tendency = ml_avg.differentiate("lag")
    return tendency


def get_spatial_composite_wrapper(
    data,
    forced_scores,
    peak_month,
    H0=None,
    **composite_kwargs,
):
    """
    Build the spatial composite together with its lag-aligned background state.

    Parameters
    ----------
    data : dataset of components + scores passed to `get_spatial_composite`
    forced_scores : forced-response dataset used for the background climatology
    peak_month : calendar month the composite is centred on
    H0 : depth used to convert net heat flux to K/mo. When None, the
        climatological "mld" is used.
    **composite_kwargs : forwarded to `get_spatial_composite`
        (e.g. idx, q, is_warm, event_type)

    Returns
    -------
    (composite, comp_clim)
        `composite` gains "H", the feedback terms from
        `src.utils.get_feedbacks`, "diff", "ddt_sst", "ddt_ssh", "ddt_ssh_hat",
        "Q", "Q_total" and "Q_total_nl". `comp_clim` gains "H".
    """
    ## constants for converting net heat flux to a temperature tendency (K/mo)
    sec_per_mo = 8.64e4 * 30
    rho = 1.02e3
    Cp = 4.2e3

    def _flux_to_tendency(flux, depth):
        """Divide a heat flux by rho * Cp * depth, per 30-day month."""
        return flux * sec_per_mo / (rho * Cp * depth)

    ## anomaly composite, then background state on the same lag axis
    composite = get_spatial_composite(
        data=data, peak_month=peak_month, **composite_kwargs
    )
    comp_clim = get_spatial_clim(
        forced=forced_scores, lags=composite.lag, peak_month=peak_month
    )

    ## thermocline depth: climatological value and anomaly relative to it
    get_H = lambda temp: src.utils.get_H_int(temp, thresh=0.04)
    H_clim = get_H(comp_clim["T"])
    H_full = get_H(comp_clim["T"] + composite["T"])
    comp_clim["H"] = H_clim
    composite["H"] = H_full - H_clim

    ## ocean feedback terms (need "H" on both datasets)
    composite = xr.merge(
        [composite, src.utils.get_feedbacks(bar=comp_clim, prime=composite)]
    )

    ## residual between total tendency and 3D advection
    composite["diff"] = composite["TEND_TEMP"] - composite["ADV_3D_TEMP"]

    ## lag tendencies of surface fields
    tendency_sources = {
        "ddt_sst": "sst",
        "ddt_ssh": "ssh",
        "ddt_ssh_hat": "ssh_hat",
    }
    for out_name, in_name in tendency_sources.items():
        composite[out_name] = composite[in_name].differentiate("lag")

    ## heat flux as K/mo over a fixed or climatological depth
    H = comp_clim["mld"] if H0 is None else H0
    composite["Q"] = _flux_to_tendency(composite["nhf"], H)
    composite["Q_total"] = _flux_to_tendency(composite["nhf_total"], H)

    ## nonlinear version: the depth also includes the MLD anomaly
    composite["Q_total_nl"] = _flux_to_tendency(
        composite["nhf_total"], composite["mld"] + comp_clim["mld"]
    )

    return composite, comp_clim

### Compute

In [None]:
## specify which index defines events
VARNAME = "T_34"

## specify shared args: composite centred on month 12, quantile 0.75,
## cold events (is_warm=False selects idx < the 1-q quantile cutoff)
kwargs = dict(
    peak_month=12,
    q=0.75,
    is_warm=False,
    event_type=1,
)

## do the compute: each period uses its own forced background state
comp_early, clim_early = get_spatial_composite_wrapper(
    idx=anom_early[VARNAME],
    data=anom_early,
    forced_scores=forced.sel(t_early),
    **kwargs,
)
comp_late, clim_late = get_spatial_composite_wrapper(
    idx=anom_late[VARNAME],
    data=anom_late,
    forced_scores=forced.sel(t_late),
    **kwargs,
)

## Decompose feedback changes between periods; results are merged into comp_late
feedback_changes = src.utils.decompose_feedback_changes(
    bar_early=clim_early,
    prime_early=comp_early,
    bar_late=clim_late,
    prime_late=comp_late,
)
comp_late = xr.merge([comp_late, feedback_changes])

#### Integrate over mixed layer and latitudes

In [None]:
## Integrate over ML using a fixed depth H0=70 (presumably metres) rather
## than a varying MLD (Hm=None)
ml_kwargs = dict(H0=70, Hm=None)
comp_early = src.utils.get_ml_avg_ds(comp_early, **ml_kwargs)
comp_late = src.utils.get_ml_avg_ds(comp_late, **ml_kwargs)

## hovmoller version: mean over latitudes -5..5, "lag" as leading dimension
merimean = lambda x: x.sel(latitude=slice(-5, 5)).mean("latitude")
hov_comp_early = merimean(comp_early).transpose("lag", ...)
hov_comp_late = merimean(comp_late).transpose("lag", ...)

#### Line plots

In [None]:
## zonal mean of the hovmollers over longitudes 160-210 (Nino-4 band)
get_nino4_comp = lambda x: x.sel(longitude=slice(160, 210)).mean("longitude")
nino4_early = get_nino4_comp(hov_comp_early)
nino4_late = get_nino4_comp(hov_comp_late)
## late minus early
nino4_diff = nino4_late - nino4_early

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## late-minus-early Nino-4 changes; ad hoc scale factors put the curves on
## one axis (taux sign flipped)
ax.plot(nino4_early.lag, nino4_diff["ddt_ssh"], label="ddt_ssh")
ax.plot(nino4_early.lag, nino4_diff["ddt_sst"] * 5, label="ddt_sst")
# ax.plot(nino4_early.lag, nino4_diff["nhf"]/20, label="NHF")
ax.plot(nino4_early.lag, nino4_diff["nhf_total"] / 10, label="NHF")
ax.plot(nino4_early.lag, -nino4_diff["taux"] * 3e2, label=r"$\tau_x$")

## format: dashed zero lines (axhline/axvline default to 0)
ax.set_xlim([-4, 8])
ax_kwargs = dict(ls="--", c="k", lw=0.8)
ax.axhline(**ax_kwargs)
ax.axvline(**ax_kwargs)
ax.legend(prop=dict(size=8))


plt.show()

### Plot Hovmollers

In [None]:
def plot_cycle_hov_lagged(
    ax,
    data,
    amp,
    is_filled=True,
    xticks=(190, 240),
    lat_bound=5,
    nlev=5,
    cmap="cmo.balance",
    levels=None,
):
    """
    Plot a longitude-lag Hovmoller of `data` on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes to draw on
    data : field with "longitude" and "lag" dimensions. If it still has a
        "latitude" dimension, it is averaged over [-lat_bound, lat_bound].
    amp : contour amplitude, used only when `levels` is None
        (levels = ``src.utils.make_cb_range(amp, amp / nlev)``)
    is_filled : filled contours with `cmap` if True, black line contours otherwise
    xticks : longitudes to tick and mark with dashed vertical lines
    lat_bound : half-width of the meridional averaging band
    nlev : divisor setting the level spacing ``amp / nlev``
    cmap : colormap for filled contours
    levels : explicit contour levels; overrides `amp`/`nlev`

    Returns
    -------
    The contour set returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## set levels if not specified
    if levels is None:
        levels = src.utils.make_cb_range(amp, amp / nlev)

    ## specify shared kwargs
    shared_kwargs = dict(levels=levels, extend="both")

    ## specify kwargs
    if is_filled:
        plot_fn = ax.contourf
        kwargs = dict(cmap=cmap)
    else:
        plot_fn = ax.contour
        kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary). Done inline so that the band
    ## honours `lat_bound`; checking dims (not coords) skips scalar latitude coords.
    if "latitude" in data.dims:
        plot_data = data.sel(latitude=slice(-lat_bound, lat_bound)).mean("latitude")
    else:
        plot_data = data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.lag,
        plot_data.transpose("lag", "longitude"),
        **kwargs,
        **shared_kwargs,
    )

    ## format ax object
    line_kwargs = dict(c="w", ls="--", lw=0.8, alpha=0.5)
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(list(xticks))
    for tick in xticks:
        ax.axvline(tick, **line_kwargs)
    ax.axhline(0, **line_kwargs)

    return cp


def format_hov_axs(axs, peak_mon):
    """
    Title three Hovmoller panels (early / late / scaled difference) and label lags.

    Y ticks are removed on every panel. The lag axis of the first panel is
    then labelled by `src.utils.label_hov_yaxis`. Returns `axs`.
    """
    panel_titles = ("Early", "Late", "Difference (scaled)")
    for i, title in enumerate(panel_titles):
        axs[i].set_title(title, size=10)

    for panel in axs:
        panel.set_yticks([])

    src.utils.label_hov_yaxis(axs[0], peak_mon=peak_mon)
    return axs


def plot_eli_on_axs(axs, x0, x1, name="eli_05"):
    """
    Overlay the composite `name` (default "eli_05") against lag in magenta.

    The early period (`x0`) is drawn solid on panels 0 and 2. The late period
    (`x1`) is drawn dashed on panels 1 and 2.
    """
    overlays = (
        (0, x0, {}),
        (1, x1, {"ls": "--"}),
        (2, x0, {"ls": "-"}),
        (2, x1, {"ls": "--"}),
    )
    for panel_idx, ds, style in overlays:
        axs[panel_idx].plot(ds[name], ds.lag, c="magenta", **style)


def plot_level(ax, comp, level, ls="-", c="magenta"):
    """Draw the single contour `level` of a (lag, longitude) field on `ax`."""
    lag_first = comp.transpose("lag", ...)
    ax.contour(
        comp.longitude,
        comp.lag,
        lag_first,
        levels=[level],
        colors=c,
        linestyles=ls,
    )


def plot_sst_rel_on_axs(axs, x0, x1, lev=0):
    """
    Overlay the `lev` contour of relative SST ("sst_rel") on Hovmoller panels.

    Early (`x0`) is solid on panels 0 and 2; late (`x1`) is dashed on panels 1 and 2.
    """
    for panel_idx, ds, style in ((0, x0, "-"), (1, x1, "--"), (2, x0, "-"), (2, x1, "--")):
        plot_level(axs[panel_idx], ds["sst_rel"], level=lev, ls=style)


def plot_comp_on_axs(
    axs,
    x0,
    x1,
    name,
    amp,
    peak_month,
    amp_diff=None,
    nlev=5,
    **kwargs,
):
    """
    Plot variable `name` as early / late / scaled-difference Hovmollers.

    The difference panel shows ``(amp / amp_diff) * (x1 - x0)``. `amp_diff`
    defaults to ``amp / 2``, i.e. the difference is doubled. Extra `kwargs`
    go to `plot_cycle_hov_lagged`. Returns the contour set of the last
    panel drawn.
    """
    diff_amp = amp / 2 if amp_diff is None else amp_diff
    scaled_diff = (amp / diff_amp) * (x1 - x0)

    contour_sets = [
        plot_cycle_hov_lagged(ax=ax, data=panel[name], amp=amp, nlev=nlev, **kwargs)
        for ax, panel in zip(axs, [x0, x1, scaled_diff])
    ]

    format_hov_axs(axs, peak_mon=peak_month)

    return contour_sets[-1]

#### Shared args for hovmoller

In [None]:
## arguments shared by the plot_comp_on_axs calls below
## (`kwargs` here is the module-level composite-settings dict)
hov_kwargs = dict(
    x0=hov_comp_early,
    x1=hov_comp_late,
    peak_month=kwargs["peak_month"],
)

#### SST

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
cf = plot_comp_on_axs(axs, name="sst", amp=2.5, **hov_kwargs)

## label (ticks must lie inside the +/-amp level range to be drawn)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-2.5, 0, 2.5], label="K")

## plot ELI
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)

## mark lag 2
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### SSH

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
cf = plot_comp_on_axs(axs, name="ddt_ssh", amp=2, **hov_kwargs)

## label (ticks must lie inside the +/-amp level range to be drawn)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-2, 0, 2], label="cm/mo")

## plot ELI and mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Thermocline/Ekman feedbacks

In [None]:
fig, axs = plt.subplots(5, 3, figsize=(6, 11), layout="constrained")

## one row per term: SST tendency, then mixed-layer-averaged feedback terms
## (THF/EKM/ZAF presumably thermocline/Ekman/zonal-advective) and total advection
for j, n in enumerate(["ddt_sst", "THF_ml", "EKM_ml", "ZAF_ml", "ADV_ml"]):

    ## plot data
    cf = plot_comp_on_axs(axs[j, :], name=n, amp=1, **hov_kwargs)

    ## label
    cb = fig.colorbar(cf, ax=axs[j, 2], ticks=[-1, 0, 1], label="K/mo")

    ## plot ELI
    plot_eli_on_axs(axs[j, :], x0=comp_early, x1=comp_late)

## format: x labels only on the bottom row, titles only on the top row
for ax in axs[:-1].flatten():
    ax.set_xticks([])
    ax.set_xlabel(None)
for ax in axs[1:].flatten():
    ax.set_title(None)

## mark lag 2
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Budget terms

In [None]:
## make plot: one row per mixed-layer budget term
## (diff = TEND_TEMP - ADV_3D_TEMP residual)
fig, axs = plt.subplots(3, 3, figsize=(6, 7), layout="constrained")

for j, n in enumerate(["TEND_TEMP_ml", "ADV_3D_TEMP_ml", "diff_ml"]):

    ## plot data
    cf = plot_comp_on_axs(axs[j, :], name=n, amp=1, **hov_kwargs)

    ## label
    cb = fig.colorbar(cf, ax=axs[j, 2], ticks=[-1, 0, 1], label="K/mo")

    ## plot ELI
    plot_eli_on_axs(axs[j, :], x0=comp_early, x1=comp_late)

## format: x labels only on the bottom row, titles only on the top row,
## lag 2 marked on every panel
for ax in axs[:-1].flatten():
    ax.set_xticks([])
    ax.set_xlabel(None)
for ax in axs[1:].flatten():
    ax.set_title(None)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Precip, relative ELI

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot precip anomaly (reversed colormap: wet = blue)
cf = plot_comp_on_axs(axs, name="pr", amp=1e-4, cmap="cmo.balance_r", **hov_kwargs)

## label: precipitation is a mass flux per unit area
cb = fig.colorbar(cf, ax=axs[2], ticks=[-1e-4, 0, 1e-4], label="$kg~m^{-2}s^{-1}$")

## plot ELI and mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Total precip, relative SST

In [None]:
## shared args (explicit levels in mm/day; amp is unused when levels are given)
pr_kwargs = dict(levels=np.arange(0, 21, 3), cmap="cmo.rain", amp=1e-4)

## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (8.6e4 ~ seconds per day: kg m^-2 s^-1 -> mm/day)
cf0 = plot_cycle_hov_lagged(
    ax=axs[0],
    data=8.6e4 * hov_comp_early["pr_total"],
    **pr_kwargs,
)

cf1 = plot_cycle_hov_lagged(
    ax=axs[1],
    data=8.6e4 * hov_comp_late["pr_total"],
    **pr_kwargs,
)

## late minus early (not rescaled, despite the "scaled" title set below)
cf_diff = plot_cycle_hov_lagged(
    ax=axs[2],
    data=8.6e4 * (hov_comp_late - hov_comp_early)["pr_total"],
    cmap="cmo.balance_r",
    levels=src.utils.make_cb_range(9, 1.8),
    amp=1e-4,
)

## label
cb = fig.colorbar(cf1, ax=axs[1], ticks=[0, 9, 18])
cb = fig.colorbar(cf_diff, ax=axs[2], ticks=[-9, 0, 9], label="mm/day")
format_hov_axs(axs, peak_mon=hov_kwargs["peak_month"])

## overlay zero contour of relative SST and mark lag 2
plot_sst_rel_on_axs(axs, hov_comp_early, hov_comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### $\tau_x$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
cf = plot_comp_on_axs(axs, name="taux", amp=2e-2, **hov_kwargs)

## label (ticks must lie inside the +/-amp level range to be drawn)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-2e-2, 0, 2e-2], label="Pa")

## plot ELI and mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### $u_{ml}$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
cf = plot_comp_on_axs(axs, name="u_ml", amp=1.5e6, **hov_kwargs)

## label (ticks must lie inside the +/-amp level range to be drawn)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-1.5e6, 0, 1.5e6], label="m / month")

## plot ELI and mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Heat flux

##### anomaly

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot net heat flux anomaly (alternative: Q in K/mo, commented out)
cf = plot_comp_on_axs(axs, name="nhf", amp=80, **hov_kwargs)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-80, 0, 80], label=r"$W/m^{2}$")
# cf = plot_comp_on_axs(axs, name="Q", amp=2, **hov_kwargs)
# cb = fig.colorbar(cf, ax=axs[2], ticks=[-2, 0, 2], label=r"$K/K$")

## plot ELI and mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

In [None]:
## specify month
MONTH = 1

## mean over longitudes 160-210 and depth, then select MONTH
sel = lambda x: src.utils.sel_month(
    x.sel(longitude=slice(160, 210)).mean(["longitude", "z_t"]),
    months=MONTH,
)
# diff_early = sel(budg_anom.sel(t_early)["diff"])
## NOTE(review): assumes budg_anom contains a "diff" variable; the composite
## "diff" is computed separately as TEND_TEMP - ADV_3D_TEMP -- confirm
diff_late = sel(budg_anom.sel(t_late)["diff"])

## late-period total NHF for MONTH, reconstructed with src.utils.get_nino4
nhf_nino4 = src.utils.reconstruct_fn(
    scores=src.utils.sel_month(anom_late["nhf_total"], months=MONTH),
    components=anom_late["nhf_comp"],
    fn=src.utils.get_nino4,
)

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))

## one point per (time, member) sample: budget residual vs total NHF
ax.scatter(
    diff_late.stack(sample=["time", "member"]),
    nhf_nino4.stack(sample=["time", "member"]),
    s=3,
)

##### Total

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (difference scaled by amp / amp_diff = 160 / 40 = 4)
cf = plot_comp_on_axs(axs, name="nhf_total", amp=160, amp_diff=40, **hov_kwargs)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-160, 0, 160], label=r"$W/m^{2}$")

## plot ELI, retitle difference panel, mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
axs[-1].set_title("Difference (x4)", size=10)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

##### total, scaled by MLD

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: Q_total is net heat flux converted to a temperature tendency
## (K/mo); the difference is scaled by amp / amp_diff = 5
cf = plot_comp_on_axs(axs, name="Q_total", amp=5, amp_diff=1, **hov_kwargs)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-5, 0, 5], label="K/mo")

## plot ELI, retitle difference panel with the actual scale, mark lag 2
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
axs[-1].set_title("Difference (x5)", size=10)
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

In [None]:
## lag 2, averaged over latitudes -5..5
sel = lambda x: x.sel(lag=2).sel(latitude=slice(-5, 5)).mean("latitude")
fig, ax = plt.subplots(figsize=(3, 2.5))
for clim_, anom_ in zip([clim_early, clim_late], [comp_early, comp_late]):

    ## climatological MLD (thin dashed) and climatology + composite anomaly (thick)
    c0 = sel(clim_["mld"])
    c1 = sel(clim_["mld"] + anom_["mld"])

    p0 = ax.plot(clim_.longitude, c0, lw=1, ls="--")
    ax.plot(clim_.longitude, c1, c=p0[0].get_color(), lw=2)

## flip y-axis so depth increases downward
ax.set_ylim(ax.get_ylim()[::-1])

plt.show()

#### MLD

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (difference scaled by amp / amp_diff = 30 / 15 = 2)
cf = plot_comp_on_axs(axs, name="mld", amp=30, amp_diff=15, **hov_kwargs)
## ticks must lie inside the +/-amp level range to be drawn
cb = fig.colorbar(cf, ax=axs[2], ticks=[-30, 0, 30], label=r"$m$")

## plot ELI and retitle difference panel
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
axs[-1].set_title("Difference (x2)", size=10)

## mark lag 2
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

#### Thermocline depth

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (difference scaled by amp / amp_diff = 30 / 10 = 3)
cf = plot_comp_on_axs(axs, name="H", amp=30, amp_diff=10, **hov_kwargs)
## ticks must lie inside the +/-amp level range to be drawn
cb = fig.colorbar(cf, ax=axs[2], ticks=[-30, 0, 30], label=r"$m$")

## plot ELI and retitle difference panel with the actual scale
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)
axs[-1].set_title("Difference (x3)", size=10)

## mark lag 2
for ax in axs.flatten():
    ax.axhline(2, ls="--", c="k", lw=0.8)

plt.show()

### Subsurface

In [None]:
def plot_mlds(axs, comp_early, comp_late, lag=None, n="mld"):
    """
    Overlay a depth line (default: variable "mld") on three subsurface panels.

    Early is drawn solid black on panels 0 and 2; late is drawn dashed on
    panels 1 and 2. When `lag` is None, the field is averaged over "lag";
    otherwise that lag is selected. A "latitude" dimension, if present, is
    removed with `src.utils.merimean`.
    """

    def at_lag(x):
        if lag is None:
            return x.mean("lag")
        return x.sel(lag=lag)

    if "latitude" in comp_early[n].dims:
        collapse = lambda x: at_lag(src.utils.merimean(x))
        lon = src.utils.merimean(comp_early[n]).longitude
    else:
        collapse = at_lag
        lon = comp_early.longitude

    early_line = collapse(comp_early[n])
    late_line = collapse(comp_late[n])

    axs[0].plot(lon, early_line, c="k")
    axs[1].plot(lon, late_line, c="k", ls="--")
    axs[2].plot(lon, early_line, c="k")
    axs[2].plot(lon, late_line, c="k", ls="--")

In [None]:
## specify LAG
LAG = 1

## specify budget/feedback vars to plot as longitude-depth sections
plot_vars = ["TEND_TEMP", "diff", "ADV_3D_TEMP", "ADV", "THF", "EKM", "ZAF", "DD"]


for n in plot_vars:

    print(f"\n\n{n}")

    fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

    for ax, m in zip(axs[:2], [comp_early, comp_late]):

        ## early / late section of variable n at LAG
        cp = ax.contourf(
            m.longitude,
            m.z_t,
            m[n].sel(lag=LAG),
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(2, 0.2),
            extend="both",
        )

    ## difference (late minus early, same levels -- not rescaled)
    axs[2].contourf(
        comp_early.longitude,
        comp_early.z_t,
        (comp_late - comp_early)[n].sel(lag=LAG),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(2, 0.2),
        extend="both",
    )

    ## plot climatological MLD
    plot_mlds(
        axs,
        comp_early=clim_early,
        comp_late=clim_late,
        lag=LAG,
        n="mld",
    )
    ## plot thermocline depth H of climatology + composite
    plot_mlds(
        axs,
        comp_early=clim_early + comp_early,
        comp_late=clim_late + comp_late,
        lag=LAG,
        n="H",
    )

    ## colorbar, axis limits (depth flipped), and longitudes 190/240 marked
    cb = fig.colorbar(
        cp, ax=axs[2], ticks=[-2, 0, 2], label=r"$K~\left(\text{month}\right)^{-1}$"
    )
    src.utils.format_subsurf_axs(axs)
    for ax in axs:
        ax.set_xlim([140, 280])
        ax.set_ylim([300, 5])
        ax.axvline(190, ls="--", c="w", lw=0.8)
        ax.axvline(240, ls="--", c="w", lw=0.8)
    axs[0].set_yticks([200, 100, 0])

    plt.show()

In [None]:
lon = comp_late.longitude
## late minus early at lag 1
d = (comp_late - comp_early).sel(lag=1)
## near-equatorial mean (latitudes -1.5..1.5) and depth mean
sel_h = lambda x: x.sel(latitude=slice(-1.5, 1.5)).mean("latitude")
sel_v = lambda x: x.mean("z_t")
## ad hoc scale factor for the budget terms
sc = 5.5

fig, ax = plt.subplots(figsize=(4, 3))
ax.axhline(0, ls="--", c="k", lw=0.8)
# ax.plot(lon, sel_h(d["ddt_ssh"]))
ax.plot(lon, sel_v(sc * d["TEND_TEMP"]))
plt.plot(lon, sel_v(sc * d["ADV_3D_TEMP"]))
plt.plot(lon, sel_v(sc * (d["diff"])))
## advection plus scaled total NHF
plt.plot(lon, sel_v(sc * (d["ADV_3D_TEMP"])) + sel_h(d["nhf_total"]) * 2e-2)
# plt.plot(lon, sel_h(d["nhf_total"] * 2e-2), ls="--")
# plt.plot(lon, sel_h(d["nhf"] * 2e-2), ls="--")
ax.set_ylim([-2, 2])
# plt.plot(comp_late.longitude, (comp_late-comp_early).sel(lag=1)["ADV_3D_TEMP"].sel(z_t=slice(0,150)).mean("z_t")*4)

### Plot spatial

Updated plotting helpers for 2-D (latitude–longitude) maps.

In [None]:
def plot_lev_2d(ax, data, lev=0, lw=2, c="magenta", ls="-"):
    """
    Draw the single contour `lev` of a (latitude, longitude) field on a map axis.

    Coordinates are interpreted as PlateCarree. Returns the contour set
    from ``ax.contour``.
    """
    line_style = dict(linewidths=lw, colors=c, linestyles=ls)
    return ax.contour(
        data.longitude,
        data.latitude,
        data,
        transform=ccrs.PlateCarree(),
        levels=[lev],
        **line_style,
    )


def plot_sst_rel_on_axs_2d(axs, x0, x1, **kwargs):
    """
    Overlay the relative-SST ("sst_rel") contour on three map panels.

    Early (`x0`) is drawn on panels 0 and 2. Late (`x1`) is drawn dashed on
    panels 1 and 2. Extra `kwargs` go to `plot_lev_2d`.
    """
    overlays = (
        (0, x0, {}),
        (1, x1, {"ls": "--"}),
        (2, x0, {}),
        (2, x1, {"ls": "--"}),
    )
    for panel_idx, ds, extra in overlays:
        plot_lev_2d(axs[panel_idx], ds["sst_rel"], **extra, **kwargs)


def plot_2d(
    ax,
    data,
    amp=None,
    is_filled=True,
    xticks=(190, 240),
    lat_bound=5,
    nlev=10,
    cmap="cmo.balance",
    levels=None,
):
    """
    Contour a (latitude, longitude) field on a map axis (PlateCarree coords).

    Parameters
    ----------
    ax : map axis to draw on
    data : field with "latitude" and "longitude" dimensions
    amp : contour amplitude; required when `levels` is None
        (levels = ``src.utils.make_cb_range(amp, amp / nlev)``)
    is_filled : filled contours with `cmap` if True, black line contours otherwise
    xticks, lat_bound : accepted for signature parity with
        `plot_cycle_hov_lagged`; not used here
    nlev : divisor setting the level spacing ``amp / nlev``
    cmap : colormap for filled contours
    levels : explicit contour levels; overrides `amp`/`nlev`

    Returns
    -------
    The contour set returned by ``ax.contourf`` / ``ax.contour``.

    Raises
    ------
    TypeError
        If neither `amp` nor `levels` is given.
    """

    ## set levels if not specified
    if levels is None:
        if amp is None:
            raise TypeError("plot_2d: pass either `amp` or `levels`")
        levels = src.utils.make_cb_range(amp, amp / nlev)

    ## specify shared kwargs
    shared_kwargs = dict(levels=levels, extend="both")

    ## specify kwargs
    if is_filled:
        plot_fn = ax.contourf
        kwargs = dict(cmap=cmap)
    else:
        plot_fn = ax.contour
        kwargs = dict(colors="k", linewidths=0.8)

    ## do the plotting
    cp = plot_fn(
        data.longitude,
        data.latitude,
        data.transpose("latitude", "longitude"),
        transform=ccrs.PlateCarree(),
        **kwargs,
        **shared_kwargs,
    )
    return cp

#### Relative SST, precip

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=2)

## rows: early, late, late minus early; Pacific maps limited to max_lat=20
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot relative SST
cp_sst = plot_2d(axs[0, 0], sel(comp_early)["sst_rel"], amp=10)
plot_2d(axs[1, 0], sel(comp_late)["sst_rel"], amp=10)
cp_sst_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["sst_rel"], amp=2)

## plot precip (total); 8.6e4 ~ seconds per day
pr_kwargs = dict(levels=np.arange(0, 28, 4), cmap="cmo.rain")
cp_pr = plot_2d(axs[0, 1], 8.6e4 * sel(comp_early)["pr_total"], **pr_kwargs)
plot_2d(axs[1, 1], 8.6e4 * sel(comp_late)["pr_total"], **pr_kwargs)

## plot precip (difference)
pr_diff = 8.6e4 * sel(comp_late - comp_early)["pr_total"]
cp_pr_ = plot_2d(axs[2, 1], pr_diff, amp=12, nlev=5, cmap="cmo.balance_r")

## plot zero line of relative SST on both columns
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

## colorbars
fig.colorbar(cp_sst, ax=axs[:2, 0], ticks=[-10, 0, 10])
fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[0, 12, 24])
fig.colorbar(cp_sst_, ax=axs[2, 0], ticks=[-2, 0, 2])
fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-12, 0, 12])

plt.show()

#### Heat flux, precip anomaly

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=1)

## rows: early, late, late minus early
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot NHF anomaly
cp_nhf = plot_2d(axs[0, 0], sel(comp_early)["nhf"], amp=80)
plot_2d(axs[1, 0], sel(comp_late)["nhf"], amp=80)
cp_nhf_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["nhf"], amp=80)

## plot precip (anom), converted with 8.6e4 ~ seconds per day
pr_kwargs = dict(amp=10, cmap="cmo.balance_r")
cp_pr = plot_2d(axs[0, 1], 8.6e4 * sel(comp_early)["pr"], **pr_kwargs)
plot_2d(axs[1, 1], 8.6e4 * sel(comp_late)["pr"], **pr_kwargs)

## plot precip (difference)
pr_diff = 8.6e4 * sel(comp_late - comp_early)["pr"]
cp_pr_ = plot_2d(axs[2, 1], pr_diff, **pr_kwargs)

## plot zero line of relative SST
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

## colorbars
## NOTE(review): the NHF difference panel uses amp=80 but ticks only span +/-40
fig.colorbar(cp_nhf, ax=axs[:2, 0], ticks=[-80, 0, 80])
fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[-10, 0, 10])
fig.colorbar(cp_nhf_, ax=axs[2, 0], ticks=[-40, 0, 40])
fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-10, 0, 10])

plt.show()

#### Heat flux, wind stress

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=1)

## rows: early, late, late minus early
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot NHF anomaly
cp_nhf = plot_2d(axs[0, 0], sel(comp_early)["nhf"], amp=80)
plot_2d(axs[1, 0], sel(comp_late)["nhf"], amp=80)
cp_nhf_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["nhf"], amp=80)

## plot taux (variable names cp_pr* are reused from the precip cell)
taux_kwargs = dict(amp=3e-2, cmap="cmo.balance")
cp_pr = plot_2d(axs[0, 1], sel(comp_early)["taux"], **taux_kwargs)
plot_2d(axs[1, 1], sel(comp_late)["taux"], **taux_kwargs)

## plot taux (difference)
cp_pr_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["taux"], **taux_kwargs)

## plot zero line of relative SST
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

## outline the Nino-4 box on the difference row
for ax in axs[-1]:
    src.utils.plot_nino4_box(ax, c="k")

## colorbars
# fig.colorbar(cp_nhf, ax=axs[:2, 0], ticks=[-80, 0, 80])
# fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[-10, 0, 10])
# fig.colorbar(cp_nhf_, ax=axs[2, 0], ticks=[-40, 0, 40])
# fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-10, 0, 10])

plt.show()

#### SST, heat flux

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=1)

## rows: early, late, late minus early
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot SST tendency (ddt_sst; cp_nhf* names are reused and overwritten below)
cp_nhf = plot_2d(axs[0, 0], sel(comp_early)["ddt_sst"], amp=0.8)
plot_2d(axs[1, 0], sel(comp_late)["ddt_sst"], amp=0.8)
cp_nhf_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["ddt_sst"], amp=0.8)

## plot nhf
## plot NHF anomaly
cp_nhf = plot_2d(axs[0, 1], sel(comp_early)["nhf"], amp=100)
plot_2d(axs[1, 1], sel(comp_late)["nhf"], amp=100)
cp_nhf_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["nhf"], amp=50)

## plot tauy (difference) -- disabled
# cp_pr_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["tauy"], **taux_kwargs)

## plot zero line of relative SST
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

src.utils.plot_nino4_box(axs[-1, 0], c="k")

## colorbars
# fig.colorbar(cp_nhf, ax=axs[:2, 0], ticks=[-80, 0, 80])
# fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[-10, 0, 10])
# fig.colorbar(cp_nhf_, ax=axs[2, 0], ticks=[-40, 0, 40])
# fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-10, 0, 10])

plt.show()

#### ddt(SSH), heat flux

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=1)

## rows: early, late, late minus early
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot SSH tendency (ddt_ssh; cp_nhf* names are reused and overwritten below)
cp_nhf = plot_2d(axs[0, 0], sel(comp_early)["ddt_ssh"], amp=5)
plot_2d(axs[1, 0], sel(comp_late)["ddt_ssh"], amp=5)
cp_nhf_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["ddt_ssh"], amp=2.5)

## plot nhf
## plot NHF anomaly
cp_nhf = plot_2d(axs[0, 1], sel(comp_early)["nhf"], amp=100)
plot_2d(axs[1, 1], sel(comp_late)["nhf"], amp=100)
cp_nhf_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["nhf"], amp=50)

## plot tauy (difference) -- disabled
# cp_pr_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["tauy"], **taux_kwargs)

## plot zero line of relative SST
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

src.utils.plot_nino4_box(axs[-1, 0], c="k")

## colorbars
# fig.colorbar(cp_nhf, ax=axs[:2, 0], ticks=[-80, 0, 80])
# fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[-10, 0, 10])
# fig.colorbar(cp_nhf_, ax=axs[2, 0], ticks=[-40, 0, 40])
# fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-10, 0, 10])

plt.show()

#### SSH, heat flux

In [None]:
## specify which lag to look at
sel = lambda x: x.sel(lag=1)

## rows: early, late, late minus early
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot SSH (cp_nhf* names are reused and overwritten below)
cp_nhf = plot_2d(axs[0, 0], sel(comp_early)["ssh"], amp=15)
plot_2d(axs[1, 0], sel(comp_late)["ssh"], amp=15)
cp_nhf_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["ssh"], amp=7.5)

## plot zonal SSH gradient (derivative along "longitude")
## (despite the variable names, no NHF is plotted in this cell)
cp_nhf = plot_2d(axs[0, 1], sel(comp_early)["ssh"].differentiate("longitude"), amp=0.5)
plot_2d(axs[1, 1], sel(comp_late)["ssh"].differentiate("longitude"), amp=0.5)
cp_nhf_ = plot_2d(
    axs[2, 1], sel(comp_late - comp_early)["ssh"].differentiate("longitude"), amp=0.25
)

## plot tauy (difference) -- disabled
# cp_pr_ = plot_2d(axs[2, 1], sel(comp_late - comp_early)["tauy"], **taux_kwargs)

## plot zero line of relative SST
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

src.utils.plot_nino4_box(axs[-1, 0], c="k")
src.utils.plot_nino4_box(axs[-1, -1], c="k")

## colorbars
# fig.colorbar(cp_nhf, ax=axs[:2, 0], ticks=[-80, 0, 80])
# fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[-10, 0, 10])
# fig.colorbar(cp_nhf_, ax=axs[2, 0], ticks=[-40, 0, 40])
# fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-10, 0, 10])

plt.show()

#### NHF PDF analysis

##### Scatter vs. Niño indices

In [None]:
## reconstruct total net heat flux with src.utils.get_nino4 (Nino-4 reduction)
def get_nhf(x):
    """Reconstruct "nhf_total" from its scores and the "nhf_comp" components."""
    total_scores = x["nhf_total"]
    spatial_modes = x["nhf_comp"]
    return src.utils.reconstruct_fn(
        scores=total_scores,
        components=spatial_modes,
        fn=src.utils.get_nino4,
    )


def get_pr(x):
    """Reconstruct the precip anomaly "pr" with src.utils.get_nino4."""
    anom_scores = x["pr"]
    spatial_modes = x["pr_comp"]
    return src.utils.reconstruct_fn(
        scores=anom_scores,
        components=spatial_modes,
        fn=src.utils.get_nino4,
    )


## combine reconstructed NHF (total) and precip (anomaly) into one dataset
get_idxs = lambda x: xr.merge([get_nhf(x).rename("nhf"), get_pr(x).rename("pr")])

## get nhf and precip indices for each period
idxs_early = get_idxs(anom_early)
idxs_late = get_idxs(anom_late)

In [None]:
xvar = "T_4"
yvar = "nhf"

## specify sel func: month 2 only, (time, member) order
sel = lambda x: src.utils.sel_month(x, months=2).transpose("time", "member")

fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## early (left) and late (right) scatter of index vs reconstructed NHF
axs[0].scatter(sel(anom_early[xvar]), sel(idxs_early[yvar]), s=1)
axs[1].scatter(sel(anom_late[xvar]), sel(idxs_late[yvar]), s=1)

## format: shared limits, zero lines and a reference line at y=90
src.utils.set_lims(axs)
axs[1].set_yticks([])
for ax in axs:
    ax_kwargs = dict(ls="--", c="k", lw=0.8)
    ax.axhline(0, **ax_kwargs)
    ax.axvline(0, **ax_kwargs)
    ax.axhline(90, **ax_kwargs)

plt.show()

##### Look at pdfs

In [None]:
## get windowed anomaly ("nhf") and total ("total" = anomaly + forced) scores
nhf_scores = xr.merge([anom[["nhf"]], (anom["nhf"] + forced["nhf"]).rename("total")])
nhf_rolling = src.utils.get_windowed(nhf_scores, stride=120)

## add back components (the total shares the forced nhf components)
nhf_rolling = xr.merge(
    [nhf_rolling, forced["nhf_comp"], forced["nhf_comp"].rename("total_comp")]
)

## subset for month 2 and drop the final "year" entry
nhf_rolling = src.utils.sel_month(nhf_rolling, months=2).isel(year=slice(None, -1))

## reconstruct nino4
nhf_rolling = src.utils.reconstruct_wrapper(nhf_rolling, fn=src.utils.get_nino4)

## stack member/time into a single sample dimension
nhf_rolling = nhf_rolling.stack(sample=["member", "time"])

In [None]:
## should we plot total or anomaly?
PLOT_TOTAL = True

if PLOT_TOTAL:
    x = nhf_rolling["total"]
else:
    x = nhf_rolling["nhf"]

## empirical PDFs on shared bins for windows labelled 1870, 2040 and 2080
pdf_kwargs = dict(edges=np.arange(-140, 140, 10))
pdf0, edges = src.utils.get_empirical_pdf(x.sel(year=1870), **pdf_kwargs)
pdf2, _ = src.utils.get_empirical_pdf(x.sel(year=2040), **pdf_kwargs)
pdf3, _ = src.utils.get_empirical_pdf(x.sel(year=2080), **pdf_kwargs)

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))
## step plots of the three PDFs (1870, 2040, 2080 windows)
ax.stairs(pdf0, edges)
ax.stairs(pdf2, edges)
ax.stairs(pdf3, edges)
ax.set_xlim([-60, None])

plt.show()

In [None]:
## display the index dataset for the late period (`Th_late` was never defined)
Th.sel(t_late)

In [None]:
## list the variable names in the index dataset
list(Th)

In [None]:
EDGES = np.arange(-4.2, 4.6, 0.4)

sel = lambda x, t_idx: src.utils.sel_month(x["T_4"].sel(t_idx), months=3)

pdf0, _ = src.utils.get_empirical_pdf(sel(Th, t_early), edges=EDGES)
pdf1, _ = src.utils.get_empirical_pdf(sel(Th, t_late), edges=EDGES)

In [None]:
## early vs late T_4 PDFs
plt.stairs(pdf0, EDGES)
plt.stairs(pdf1, EDGES)