# ENSO composites

## Imports

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy

# Import custom modules
import src.utils

## plotting style: white axes background, grid lines switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## raise the default figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## data / output directories come from the environment (KeyError if unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Load data

### $T$, $h$

In [None]:
## load the CESM indices, giving the long-named ones short aliases
Th = src.utils.load_cesm_indices().rename(
    north_tropical_atlantic="natl",
    atlantic_nino="nino_atl",
    tropical_indian_ocean="iobm",
    indian_ocean_dipole="iod",
    north_pacific_meridional_mode="npmm",
    south_pacific_meridional_mode="spmm",
)

## merge the ELI and tropical SST datasets into the same dataset
eli = xr.open_dataset(DATA_FP / "cesm" / "eli.nc")
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")
Th = xr.merge([Th, eli, trop_sst])

### Spatial data

In [None]:
## load spatial data
## (``anom`` holds deviations from ``forced``: their sum is the "total"
## field built at the bottom of this cell)
CONS_DIR = pathlib.Path(DATA_FP, "cesm", "consolidated")
forced = xr.open_dataset(CONS_DIR / "forced.nc")
anom = xr.open_dataset(CONS_DIR / "anom.nc")

## add T,h information
## (index time series from ``Th`` are attached to ``anom`` so they are
## composited alongside the spatial fields later on)
for n in ["T_3", "T_34", "T_4", "eli_05", "h", "h_w", "trop_sst_05"]:
    anom[n] = Th[n]

## get "total" sst and precip (forced background + anomaly)
for n in ["sst", "pr"]:
    anom[f"{n}_total"] = forced[n] + anom[n]

### Preprocess

In [None]:
## 30-year analysis windows: early (1851-1880) and late (2071-2100)
t_early = {"time": slice("1851", "1880")}
t_late = {"time": slice("2071", "2100")}  # alternative late window: 1990-2020

## subset the anomalies to each window
anom_early = anom.sel(t_early)
anom_late = anom.sel(t_late)

## Composite

### Funcs

In [None]:
def get_composite(idx, data, peak_month, q=0.95, is_warm=True):
    """
    Composite ``data`` on events where ``idx`` crosses a quantile threshold.

    - idx: index whose extremes select the events
    - data: fields to composite
    - peak_month: month on which the composite is centered
    - q: quantile threshold; for cold events the lower quantile ``1 - q``
      is used instead
    - is_warm: if True, events satisfy idx > cutoff; otherwise idx < cutoff

    Returns the result of ``src.utils.make_composite``.
    """
    if is_warm:
        threshold, crosses = q, (lambda x, cut: x > cut)
    else:
        threshold, crosses = 1 - q, (lambda x, cut: x < cut)

    return src.utils.make_composite(
        q=threshold,
        check_cutoff=crosses,
        peak_month=peak_month,
        idx=idx,
        data=data,
    )


def get_spatial_composite(data, **composite_kwargs):
    """
    Composite spatial fields that are stored as (components, scores).

    The scores are composited with ``get_composite`` (``composite_kwargs``
    are forwarded to it). The spatial fields are then rebuilt from the
    components, and a relative-SST field
    ``sst_rel = sst_total - trop_sst_05`` is added.
    """
    components, scores = src.utils.split_components(data)
    projected = get_composite(data=scores, **composite_kwargs)

    identity = lambda x: x
    result = reconstruct_helper(projected, components, func=identity)
    result = result.drop_vars("mode")

    result["sst_rel"] = result["sst_total"] - result["trop_sst_05"]
    return result


def reconstruct_helper(composite, components, func):
    """
    Rebuild spatial fields from composited scores.

    Parameters
    ----------
    composite : mapping of variable name -> composited scores
    components : mapping of variable name -> spatial components
    func : callable, passed as ``fn`` to ``src.utils.reconstruct_fn``

    Returns
    -------
    A deep copy of ``composite`` in which every variable in ``components``
    is replaced by its reconstruction. Variables whose name ends in
    ``_total`` are reconstructed with the components of their base variable
    (e.g. ``sst_total`` uses ``components["sst"]``).
    """
    suffix = "_total"

    ## work on a copy so the input composite is left untouched
    composite_recon = copy.deepcopy(composite)

    ## reconstruct anomalies
    for c in list(components):
        composite_recon[c] = src.utils.reconstruct_fn(
            components=components[c],
            scores=composite[c],
            fn=func,
        )

    ## "total" fields share the spatial components of their base variable.
    ## Match the suffix only at the END of the name, so that stripping it
    ## always recovers the base name (a substring test would also match
    ## names like "x_total_y" and then slice them incorrectly).
    for c in list(composite):
        if c.endswith(suffix):
            n = c[: -len(suffix)]
            composite_recon[c] = src.utils.reconstruct_fn(
                components=components[n],
                scores=composite[c],
                fn=func,
            )

    return composite_recon


def get_spatial_clim(forced, lags, peak_month):
    """
    Return the background-state climatology laid out along the lag axis.

    ``src.utils.reconstruct_clim`` yields a monthly climatology. Each lag is
    mapped to a calendar month (lag 0 -> ``peak_month``, wrapping modulo
    12), and the matching monthly fields are stacked along ``lags``.
    """
    monthly_clim = src.utils.reconstruct_clim(data=forced)

    ## calendar month (1-12) corresponding to each lag
    calendar_months = np.mod(lags + peak_month - 1, 12) + 1

    per_lag = []
    for month in calendar_months:
        per_lag.append(monthly_clim.sel(month=month).drop_vars("month"))

    return xr.concat(per_lag, dim=lags)


def get_T_ml_tendency(T, mld, delta=5, H0=None):
    """
    Return the derivative along ``lag`` of the mixed-layer temperature.

    ``T`` is first averaged over the mixed layer with
    ``src.utils.get_ml_avg``: ``mld`` is passed as ``Hm``, and ``delta`` and
    ``H0`` are forwarded unchanged.
    """
    return src.utils.get_ml_avg(T, Hm=mld, delta=delta, H0=H0).differentiate("lag")


def get_spatial_composite_wrapper(
    data,
    forced_scores,
    peak_month,
    delta=5,
    H0=None,
    **composite_kwargs,
):
    """
    Spatial composite of anomalies together with the matching background state.

    Parameters
    ----------
    data : anomaly data, passed to ``get_spatial_composite``
    forced_scores : forced/background data, passed to ``get_spatial_clim``
    peak_month : month on which the composite is centered
    delta, H0 : accepted but NOT used inside this function.
        NOTE(review): presumably kept so one shared kwargs dict can be
        passed in; confirm before removing.
    **composite_kwargs : forwarded to ``get_spatial_composite``
        (e.g. ``idx``, ``q``, ``is_warm``)

    Returns
    -------
    (composite, comp_clim)
        ``composite`` gains the ocean feedback terms, ``ddt_sst`` (the SST
        derivative along ``lag``) and ``Q`` (net heat flux converted to
        K/mo using the climatological ``mld``). ``comp_clim`` is the
        background climatology on the composite's lag axis.
    """
    composite = get_spatial_composite(data=data, peak_month=peak_month, **composite_kwargs)
    comp_clim = get_spatial_clim(
        forced=forced_scores, lags=composite.lag, peak_month=peak_month
    )

    ## ocean feedback terms from background (bar) and anomaly (prime) states
    composite = xr.merge(
        [composite, src.utils.get_feedbacks(bar=comp_clim, prime=composite)]
    )

    composite["ddt_sst"] = composite["sst"].differentiate("lag")

    ## net heat flux -> K/mo:  Q = nhf * (seconds per month) / (rho * Cp * H)
    ## NOTE(review): this is K/mo only if mld is in metres; confirm units.
    seconds_per_month = 8.64e4 * 30  # 30-day month
    rho = 1.02e3  # seawater density (kg/m^3 assumed)
    Cp = 4.2e3  # specific heat (J/kg/K assumed)
    heat_capacity = rho * Cp
    mixed_layer_depth = comp_clim["mld"]
    composite["Q"] = (
        composite["nhf"] * seconds_per_month / (heat_capacity * mixed_layer_depth)
    )

    return composite, comp_clim

### Compute

In [None]:
## specify what variable to use
## (index whose extremes define the events; NOTE(review): T_34 is presumably
## Nino-3.4 temperature; confirm)
VARNAME = "T_34"

## specify shared args
## NOTE: delta and H0 are accepted by get_spatial_composite_wrapper but not
## used inside it. peak_month is also read by the hovmoller plotting cells.
kwargs = dict(
    peak_month=12,
    q=0.95,
    is_warm=True,
    delta=10,
    H0=60,
)

## do the compute
## (each period is composited on its own index, against its own forced state)
comp_early, clim_early = get_spatial_composite_wrapper(
    idx=anom_early[VARNAME],
    data=anom_early,
    forced_scores=forced.sel(t_early),
    **kwargs,
)
comp_late, clim_late = get_spatial_composite_wrapper(
    idx=anom_late[VARNAME],
    data=anom_late,
    forced_scores=forced.sel(t_late),
    **kwargs,
)

## Decompose feedback changes
## (early vs. late background/anomaly states; result stored on comp_late)
feedback_changes = src.utils.decompose_feedback_changes(
    bar_early=clim_early,
    prime_early=comp_early,
    bar_late=clim_late,
    prime_late=comp_late,
)
comp_late = xr.merge([comp_late, feedback_changes])

#### Integrate over mixed layer and latitudes

In [None]:
## Integrate over ML
ml_kwargs = dict(H0=70, Hm=None)
comp_early = src.utils.get_ml_avg_ds(comp_early, **ml_kwargs)
comp_late = src.utils.get_ml_avg_ds(comp_late, **ml_kwargs)


## hovmoller version
def merimean(x, lat_bound=5):
    """Meridional mean of ``x`` over latitudes in [-lat_bound, lat_bound].

    ``lat_bound`` is a keyword (default 5) because plotting helpers call
    ``merimean(data, lat_bound=...)``; a one-argument lambda would raise
    TypeError there.
    """
    return x.sel(latitude=slice(-lat_bound, lat_bound)).mean("latitude")


hov_comp_early = merimean(comp_early).transpose("lag", ...)
hov_comp_late = merimean(comp_late).transpose("lag", ...)

### Plot Hovmollers

In [None]:
def plot_cycle_hov_lagged(
    ax,
    data,
    amp,
    is_filled=True,
    xticks=(190, 240),
    lat_bound=5,
    nlev=5,
    cmap="cmo.balance",
    levels=None,
):
    """
    Plot a longitude-lag hovmoller of ``data`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on
    data : DataArray with ``longitude`` and ``lag`` coordinates. If it still
        has a ``latitude`` coordinate, it is first averaged over latitudes in
        [-lat_bound, lat_bound].
    amp : amplitude passed to ``src.utils.make_cb_range`` for the default
        levels (ignored when ``levels`` is given)
    is_filled : filled contours with ``cmap`` if True, otherwise thin black
        contour lines
    xticks : longitudes to tick and mark with dashed vertical guide lines
    lat_bound : latitude half-width of the meridional average
    nlev : the default contour spacing is ``amp / nlev``
    cmap : colormap for filled contours
    levels : explicit contour levels (override ``amp``/``nlev``)

    Returns
    -------
    The contour set returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## set levels if not specified
    if levels is None:
        levels = src.utils.make_cb_range(amp, amp / nlev)

    ## specify shared kwargs
    shared_kwargs = dict(levels=levels, extend="both")

    ## filled vs. line contours
    if is_filled:
        plot_fn = ax.contourf
        style_kwargs = dict(cmap=cmap)
    else:
        plot_fn = ax.contour
        style_kwargs = dict(colors="k", linewidths=0.8)

    ## average over the latitude band (if latitude is still present);
    ## computed inline so that ``lat_bound`` is honoured
    if "latitude" in data.coords:
        plot_data = data.sel(latitude=slice(-lat_bound, lat_bound)).mean("latitude")
    else:
        plot_data = data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.lag,
        plot_data.transpose("lag", "longitude"),
        **style_kwargs,
        **shared_kwargs,
    )

    ## format ax object: fixed longitude window plus dashed white guide lines
    line_kwargs = dict(c="w", ls="--", lw=0.8, alpha=0.5)
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(xticks)
    for tick in xticks:
        ax.axvline(tick, **line_kwargs)
    ax.axhline(0, **line_kwargs)

    return cp


def format_hov_axs(axs, peak_mon):
    """
    Title the three hovmoller panels and label the lag axis.

    Panels 0-2 are titled "Early", "Late" and "Difference (scaled)".
    Y-ticks are removed from every panel, and the first panel's y-axis is
    labelled via ``src.utils.label_hov_yaxis``. Returns ``axs``.
    """
    titles = ("Early", "Late", "Difference (scaled)")
    for panel_idx, title in enumerate(titles):
        axs[panel_idx].set_title(title, size=10)

    for ax in axs:
        ax.set_yticks([])

    src.utils.label_hov_yaxis(axs[0], peak_mon=peak_mon)
    return axs


def plot_eli_on_axs(axs, x0, x1, name="eli_05"):
    """
    Overlay the ``name`` series (default ELI) against lag on three panels.

    Panel 0 shows ``x0`` (solid magenta), panel 1 shows ``x1`` (dashed
    magenta), and panel 2 (difference) shows both.
    """
    ax_early, ax_late, ax_diff = axs[0], axs[1], axs[2]
    early_xy = (x0[name], x0.lag)
    late_xy = (x1[name], x1.lag)

    ax_early.plot(*early_xy, c="magenta")
    ax_late.plot(*late_xy, c="magenta", ls="--")

    ## difference panel: overlay both periods
    ax_diff.plot(*early_xy, c="magenta", ls="-")
    ax_diff.plot(*late_xy, c="magenta", ls="--")


def plot_level(ax, comp, level, ls="-", c="magenta"):
    """
    Draw the single contour ``comp == level`` on a longitude-lag axis.

    ``ls`` and ``c`` set the line style and colour.
    """
    line_opts = {"levels": [level], "colors": c, "linestyles": ls}
    ax.contour(comp.longitude, comp.lag, comp.transpose("lag", ...), **line_opts)


def plot_sst_rel_on_axs(axs, x0, x1, lev=0):
    """
    Overlay the ``sst_rel == lev`` contour on three hovmoller panels.

    Panel 0 shows ``x0["sst_rel"]`` (solid), panel 1 shows
    ``x1["sst_rel"]`` (dashed), and panel 2 shows both.
    """
    panels = (
        (0, x0, "-"),
        (1, x1, "--"),
        (2, x0, "-"),
        (2, x1, "--"),
    )
    for panel_idx, comp, line_style in panels:
        plot_level(axs[panel_idx], comp["sst_rel"], level=lev, ls=line_style)


def plot_comp_on_axs(
    axs,
    x0,
    x1,
    name,
    amp,
    peak_month,
    amp_diff=None,
    nlev=5,
    **kwargs,
):
    """
    Plot early, late and scaled difference hovmollers of ``name``.

    The difference ``x1 - x0`` is multiplied by ``amp / amp_diff`` (with
    ``amp_diff`` defaulting to ``amp / 2``), so it shares the colour levels
    of the other two panels. Extra ``kwargs`` go to
    ``plot_cycle_hov_lagged``. Returns the contour set of the last panel
    drawn.
    """
    diff_scale = amp / (amp / 2 if amp_diff is None else amp_diff)
    panels = [x0, x1, diff_scale * (x1 - x0)]

    for ax, panel in zip(axs, panels):
        cf = plot_cycle_hov_lagged(ax=ax, data=panel[name], amp=amp, nlev=nlev, **kwargs)

    format_hov_axs(axs, peak_mon=peak_month)
    return cf

#### Shared args for hovmoller

In [None]:
## arguments shared by every hovmoller figure below
hov_kwargs = {
    "x0": hov_comp_early,
    "x1": hov_comp_late,
    "peak_month": kwargs["peak_month"],
}

#### SST, SSH

In [None]:
## set up plot
## (three panels: early, late, scaled late-minus-early difference)
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
## (SST anomaly; the difference panel is multiplied by amp / amp_diff,
## which is 2 with the default amp_diff)
cf = plot_comp_on_axs(axs, name="sst", amp=3.5, **hov_kwargs)

## label
cb = fig.colorbar(cf, ax=axs[2], ticks=[-3.5, 0, 3.5], label="K")

## plot ELI (early solid, late dashed, both on the difference panel)
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)

plt.show()

#### SST tendency and thermocline/Ekman feedbacks

In [None]:
fig, axs = plt.subplots(3, 3, figsize=(6, 7), layout="constrained")

## one row per variable: SST tendency, then the thermocline (THF) and
## Ekman (EKM) terms (``_ml``: mixed-layer averages)
for j, n in enumerate(["ddt_sst", "THF_ml", "EKM_ml"]):

    ## plot data
    cf = plot_comp_on_axs(axs[j, :], name=n, amp=2, **hov_kwargs)

    ## label: ddt_sst is a derivative along ``lag`` and one lag is one month,
    ## so the units are K/mo (not K/s). The feedback terms share this scale.
    ## NOTE(review): confirm src.utils.get_feedbacks also returns K/mo.
    cb = fig.colorbar(cf, ax=axs[j, 2], ticks=[-2, 0, 2], label="K/mo")

    ## plot ELI
    plot_eli_on_axs(axs[j, :], x0=comp_early, x1=comp_late)

## format: x-ticks only on the bottom row, titles only on the top row
for ax in axs[:-1].flatten():
    ax.set_xticks([])
    ax.set_xlabel(None)
for ax in axs[1:].flatten():
    ax.set_title(None)

plt.show()

#### Precip anomaly, ELI

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (precipitation anomaly)
cf = plot_comp_on_axs(axs, name="pr", amp=1e-4, **hov_kwargs)

## label: pr is a precipitation rate, not a temperature. Multiplying it by
## ~8.6e4 (seconds per day) gives mm/day, so pr itself is in mm/s
## (equivalently kg m^-2 s^-1).
cb = fig.colorbar(cf, ax=axs[2], ticks=[-1e-4, 0, 1e-4], label="mm/s")

## plot ELI
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)

plt.show()

#### Total precip, relative SST

In [None]:
## shared args
## (fixed levels 0, 3, ..., 18; amp is not used when levels are given)
pr_kwargs = dict(levels=np.arange(0, 21, 3), cmap="cmo.rain", amp=1e-4)

## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data
## (8.6e4 is ~seconds per day: converts pr to mm/day, the colorbar units)
cf0 = plot_cycle_hov_lagged(
    ax=axs[0],
    data=8.6e4 * hov_comp_early["pr_total"],
    **pr_kwargs,
)

cf1 = plot_cycle_hov_lagged(
    ax=axs[1],
    data=8.6e4 * hov_comp_late["pr_total"],
    **pr_kwargs,
)

## unscaled late-minus-early difference, with its own symmetric levels
cf_diff = plot_cycle_hov_lagged(
    ax=axs[2],
    data=8.6e4 * (hov_comp_late - hov_comp_early)["pr_total"],
    cmap="cmo.balance_r",
    levels=src.utils.make_cb_range(9, 1.8),
    amp=1e-4,
)

## label
## (panels 0 and 1 share levels, so one colorbar on axs[1] covers both)
cb = fig.colorbar(cf1, ax=axs[1], ticks=[0, 9, 18])
cb = fig.colorbar(cf_diff, ax=axs[2], ticks=[-9, 0, 9], label="mm/day")
format_hov_axs(axs, peak_mon=hov_kwargs["peak_month"])

## plot the zero contour of relative SST (early solid, late dashed)
plot_sst_rel_on_axs(axs, hov_comp_early, hov_comp_late)

plt.show()

#### $\tau_x$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (zonal wind stress anomaly; difference panel scaled as above)
cf = plot_comp_on_axs(axs, name="taux", amp=4e-2, **hov_kwargs)

## label
cb = fig.colorbar(cf, ax=axs[2], ticks=[-4e-2, 0, 4e-2], label="Pa")

## plot ELI (early solid, late dashed)
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)

plt.show()

#### Heat flux

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data (net heat flux anomaly, nhf, in W/m^2)
cf = plot_comp_on_axs(axs, name="nhf", amp=80, **hov_kwargs)

## label
cb = fig.colorbar(cf, ax=axs[2], ticks=[-80, 0, 80], label=r"$W/m^{2}$")

## plot ELI (early solid, late dashed)
plot_eli_on_axs(axs, x0=comp_early, x1=comp_late)

plt.show()

### Plot spatial

Updated: map (latitude-longitude) views of relative SST and total precipitation at a fixed lag.

In [None]:
def plot_lev_2d(ax, data, lev=0, lw=2, c="magenta", ls="-"):
    """
    Draw the single contour ``data == lev`` of a lat/lon field on a map axis.

    Coordinates are interpreted in PlateCarree. ``lw``, ``c`` and ``ls`` set
    the line width, colour and style. Returns the contour set.
    """
    line_style = dict(linewidths=lw, colors=c, linestyles=ls)
    return ax.contour(
        data.longitude,
        data.latitude,
        data,
        transform=ccrs.PlateCarree(),
        levels=[lev],
        **line_style,
    )


def plot_sst_rel_on_axs_2d(axs, x0, x1, **kwargs):
    """
    Overlay the relative-SST contour on three map panels.

    Panel 0 shows ``x0["sst_rel"]`` (solid), panel 1 shows
    ``x1["sst_rel"]`` (dashed), and panel 2 shows both. ``kwargs`` are
    forwarded to ``plot_lev_2d`` (e.g. ``lev``, ``lw``, ``c``). Passing
    ``ls`` raises TypeError on the dashed calls, which set ``ls``
    themselves.
    """
    panels = (
        (0, x0, {}),
        (1, x1, {"ls": "--"}),
        (2, x0, {}),
        (2, x1, {"ls": "--"}),
    )
    for panel_idx, comp, line_style in panels:
        plot_lev_2d(axs[panel_idx], comp["sst_rel"], **line_style, **kwargs)


def plot_2d(
    ax,
    data,
    amp=None,
    is_filled=True,
    xticks=(190, 240),
    lat_bound=5,
    nlev=10,
    cmap="cmo.balance",
    levels=None,
):
    """
    Contour a latitude-longitude field on a map axis (PlateCarree transform).

    Parameters
    ----------
    ax : map axis accepting a cartopy ``transform=`` keyword
    data : field with ``latitude`` and ``longitude`` coordinates
    amp : amplitude passed to ``src.utils.make_cb_range`` for the default
        levels; required unless ``levels`` is given
    is_filled : filled contours with ``cmap`` if True, otherwise thin black lines
    xticks, lat_bound : accepted but not used here (presumably kept for
        parity with the hovmoller plotting helper)
    nlev : the default contour spacing is ``amp / nlev``
    cmap : colormap for filled contours
    levels : explicit contour levels (override ``amp``/``nlev``)

    Returns
    -------
    The contour set from ``ax.contourf`` / ``ax.contour``.

    Raises
    ------
    TypeError
        If both ``amp`` and ``levels`` are None.
    """

    ## set levels if not specified; fail with a clear message instead of an
    ## opaque ``None / int`` error when neither amp nor levels was given
    if levels is None:
        if amp is None:
            raise TypeError("plot_2d needs either `amp` or `levels` to set contour levels")
        levels = src.utils.make_cb_range(amp, amp / nlev)

    ## specify shared kwargs
    shared_kwargs = dict(levels=levels, extend="both")

    ## filled vs. line contours
    if is_filled:
        plot_fn = ax.contourf
        style_kwargs = dict(cmap=cmap)
    else:
        plot_fn = ax.contour
        style_kwargs = dict(colors="k", linewidths=0.8)

    ## do the plotting
    cp = plot_fn(
        data.longitude,
        data.latitude,
        data.transpose("latitude", "longitude"),
        transform=ccrs.PlateCarree(),
        **style_kwargs,
        **shared_kwargs,
    )
    return cp

In [None]:
## specify which month to look at
## (lag=2: two months after peak_month, since lag 0 maps to peak_month)
sel = lambda x: x.sel(lag=2)

## 3x2 map grid: rows = early / late / late-minus-early,
## columns = relative SST / total precip
fig = plt.figure(figsize=(15, 4.375), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot relative SST
cp_sst = plot_2d(axs[0, 0], sel(comp_early)["sst_rel"], amp=10)
plot_2d(axs[1, 0], sel(comp_late)["sst_rel"], amp=10)
cp_sst_ = plot_2d(axs[2, 0], sel(comp_late - comp_early)["sst_rel"], amp=2)

## plot precip (total)
## (8.6e4 is ~seconds per day: converts pr to mm/day)
pr_kwargs = dict(levels=np.arange(0, 28, 4), cmap="cmo.rain")
cp_pr = plot_2d(axs[0, 1], 8.6e4 * sel(comp_early)["pr_total"], **pr_kwargs)
plot_2d(axs[1, 1], 8.6e4 * sel(comp_late)["pr_total"], **pr_kwargs)

## plot precip (difference)
pr_diff = 8.6e4 * sel(comp_late - comp_early)["pr_total"]
cp_pr_ = plot_2d(axs[2, 1], pr_diff, amp=12, nlev=5, cmap="cmo.balance_r")

## plot zero line of relative SST in both columns (early solid, late dashed)
for j in [0, 1]:
    plot_sst_rel_on_axs_2d(axs[:, j], sel(comp_early), sel(comp_late), lw=1.5)

## colorbars (the first two rows share one colorbar per column)
fig.colorbar(cp_sst, ax=axs[:2, 0], ticks=[-10, 0, 10])
fig.colorbar(cp_pr, ax=axs[:2, 1], ticks=[0, 12, 24])
fig.colorbar(cp_sst_, ax=axs[2, 0], ticks=[-2, 0, 2])
fig.colorbar(cp_pr_, ax=axs[2, 1], ticks=[-12, 0, 12])

plt.show()