# ENSO composite over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## plotting defaults: white axes background, grid switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## data / output directories are read from environment variables
## (a missing variable raises KeyError here)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Temperature contrast between the entrainment zone and the mixed layer.

    Levels with ``z_t <= Hm`` count as mixed layer (ML); levels with
    ``Hm < z_t < Hm + delta`` count as entrainment zone (EZ). Each group is
    averaged over "z_t" and the ML mean is subtracted from the EZ mean, so
    the result is positive when the entrainment zone is warmer.

    Args:
        Tsub: temperature with a ``z_t`` coordinate (xarray-like).
        Hm: mixed-layer depth, compared directly against ``z_t``.
        delta: thickness of the entrainment zone below ``Hm``.

    Returns:
        EZ-mean minus ML-mean temperature.
    """

    depth = Tsub.z_t

    ## mean over the mixed layer ("Tbar", following Frankignoul et al paper)
    ml_mask = depth <= Hm
    Tbar = Tsub.where(ml_mask).mean("z_t")

    ## mean over the layer just below the mixed layer ("Tplus")
    ez_mask = (depth > Hm) & (depth < (delta + Hm))
    Tplus = Tsub.where(ez_mask).mean("z_t")

    return Tplus - Tbar


def prep(data):
    """
    Add "<name>_hat" variables for the h indices.

    For each of "h_w", "h" and "h_w_z20", the output of
    ``src.utils.remove_sst_dependence_v2`` (called with ``T_var="T_34"``)
    is stored on ``data`` under "<name>_hat".

    Note: ``data`` is modified in place and is also returned.
    """

    h_names = ("h_w", "h", "h_w_z20")

    for name in h_names:
        ## each call sees the "_hat" variables added by earlier iterations
        hat = src.utils.remove_sst_dependence_v2(data, h_var=name, T_var="T_34")
        data[f"{name}_hat"] = hat

    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load ground truth heat budget data.

    Opens the "adv" and "ddt_T" variables from ``DATA_FP/cesm/<var>_temp``,
    rescales them, keeps only the early and late periods, interpolates onto
    ``target_grid`` and splits the result with ``src.utils.separate_forced``.

    Args:
        t_early: indexers passed positionally to ``.sel`` for the early
            period (presumably a dict such as ``{"time": slice(...)}``).
        t_late: same, for the late period.
        target_grid: object whose coordinates are used by ``interp_like``.

    Returns:
        Tuple ``(forced_bud, anom_bud)`` as returned by
        ``src.utils.separate_forced``.

    Raises:
        FileNotFoundError: if no ``*.nc`` files are found for a variable.
    """

    def load_var(varname):
        """load variable from prepped folder (one file per ensemble member)"""

        ## find files
        folder = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")
        files = sorted(folder.glob("*.nc"))
        if not files:
            raise FileNotFoundError(f"no .nc files found in {folder}")

        ## open data
        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1 using the number of files actually opened
        ## (previously hard-coded to 100 members)
        n_members = data.sizes["member"]
        return data.assign_coords({"member": np.arange(n_members)})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## get difference
    # budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## conversion factors (a month is taken to be 30 days)
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30

    ## convert (i) z coord from cm to m and (ii) data from K/s to K/mo
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## fix longitude coordinate: index "nlon" by the values of "lon",
    ## then rename it to "longitude"
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## trim in time (keep only the early and late periods)
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, month=None):
    """
    Plot mixed layer depth on a triple of axs objects.

    The early profile is drawn (solid) on ``axs[0]``, the late profile
    (dashed) on ``axs[1]``, and both on ``axs[2]``.

    Args:
        axs: indexable collection of at least three matplotlib axes.
        bar_early: dataset containing "mld" for the early period.
        bar_late: dataset containing "mld" for the late period.
        month: month to select; if None, the mean over "month" is used.

    Returns:
        None
    """

    ## get longitude
    lon = src.utils.merimean(bar_early).longitude

    ## helper function to get month
    def pick_month(x):
        if month is None:
            return x.mean("month")
        return x.sel(month=month)

    ## compute each profile once (each is drawn on two axes)
    mld_early = pick_month(src.utils.merimean(bar_early["mld"]))
    mld_late = pick_month(src.utils.merimean(bar_late["mld"]))

    ## plot
    axs[0].plot(lon, mld_early, c="k")
    axs[1].plot(lon, mld_late, c="k", ls="--")
    axs[2].plot(lon, mld_early, c="k")
    axs[2].plot(lon, mld_late, c="k", ls="--")

    return


def get_nino3_da(x):
    """
    Apply ``src.utils.get_nino3`` to an array that may lack a latitude dim.

    If "latitude" is not one of ``x.dims`` it is added with ``expand_dims``
    before the helper is called.
    """

    has_lat = "latitude" in x.dims
    arr = x if has_lat else x.expand_dims("latitude")

    return src.utils.get_nino3(arr)


def get_nino34_da(x):
    """
    Apply ``src.utils.get_nino34`` to an array that may lack a latitude dim.

    If "latitude" is not one of ``x.dims`` it is added with ``expand_dims``
    before the helper is called.
    """

    ## already has a latitude dimension: pass straight through
    if "latitude" in x.dims:
        return src.utils.get_nino34(x)

    return src.utils.get_nino34(x.expand_dims("latitude"))


def get_composite(idx, data, peak_month, q=0.95, is_warm=True, event_type=None):
    """
    Build a composite with ``src.utils.make_composite``.

    Args:
        idx: index passed through as ``idx``.
        data: data passed through as ``data``.
        peak_month: month to center composite on.
        q: quantile threshold for composite. Warm composites pass ``q``;
            cold composites pass ``1 - q``.
        is_warm: if True the cutoff check is ``x > cut``; otherwise
            it is ``x < cut``.
        event_type: passed through unchanged.

    Returns:
        Whatever ``src.utils.make_composite`` returns.
    """

    def above(x, cut):
        return x > cut

    def below(x, cut):
        return x < cut

    ## handle warm/cold case
    if is_warm:
        quantile, check = q, above
    else:
        quantile, check = 1 - q, below

    return src.utils.make_composite(
        q=quantile,
        check_cutoff=check,
        peak_month=peak_month,
        idx=idx,
        data=data,
        event_type=event_type,
    )


def get_spatial_composite(data, **composite_kwargs):
    """
    Get spatial composite.

    ``data`` is split into components and scores, the composite is taken on
    the scores (see ``get_composite``; extra keyword arguments are forwarded
    to it), and the result is mapped back with ``reconstruct_helper`` using
    an identity function. The "mode" variable is dropped from the output.
    """

    def identity(x):
        return x

    ## pull out components
    components, scores = src.utils.split_components(data)

    ## composite in projected space, then reconstruct spatial fields
    projected = get_composite(data=scores, **composite_kwargs)
    spatial = reconstruct_helper(projected, components, func=identity)

    # ## reconstruct relative SST
    # comp["sst_rel"] = comp["sst_total"] - comp["trop_sst_05"]

    return spatial.drop_vars("mode")


def reconstruct_helper(composite, components, func):
    """
    Reconstruction helper function for composite.

    Works on a deep copy of ``composite``. Every variable named in
    ``components`` is replaced by ``src.utils.reconstruct_fn`` applied to
    that variable's components and scores. Variables named "<name>_total"
    are reconstructed with the components of "<name>", if those exist.

    Args:
        composite: mapping-like of scores (e.g. a dataset); not modified.
        components: mapping-like of components, keyed by variable name.
        func: function forwarded to ``src.utils.reconstruct_fn`` as ``fn``.

    Returns:
        Copy of ``composite`` with the reconstructed variables.
    """

    suffix = "_total"

    ## copy to hold reconstructed results
    composite_recon = copy.deepcopy(composite)

    ## names of the variables that have components
    component_names = list(components)

    ## reconstruct anomalies
    for name in component_names:
        composite_recon[name] = src.utils.reconstruct_fn(
            components=components[name],
            scores=composite[name],
            fn=func,
        )

    ## check for "total" fields. Match the suffix only at the END of the
    ## name, so the base name is always obtained by removing exactly the
    ## suffix (previously any name merely containing "_total" lost its
    ## last six characters).
    for name in list(composite):
        if not (isinstance(name, str) and name.endswith(suffix)):
            continue

        base = name[: -len(suffix)]
        if base in component_names:
            composite_recon[name] = src.utils.reconstruct_fn(
                components=components[base],
                scores=composite[name],
                fn=func,
            )

    return composite_recon


def get_spatial_clim(forced, lags, peak_month, fn=lambda x: x):
    """
    Get climatologies of spatial variables, arranged by lag.

    The monthly climatology comes from ``src.utils.reconstruct_clim``. Each
    lag is mapped to a calendar month (1-12) such that lag 0 corresponds to
    ``peak_month``, and the matching monthly fields are concatenated along
    ``lags``.
    """

    ## reconstruct monthly climatology for period
    clim = src.utils.reconstruct_clim(data=forced, fn=fn)

    ## calendar month for each lag (wraps around the year)
    months = 1 + np.mod(lags + peak_month - 1, 12)

    ## pick the month for each lag, dropping the "month" coordinate
    per_lag = []
    for m in months:
        per_lag.append(clim.sel(month=m).drop_vars("month"))

    return xr.concat(per_lag, dim=lags)


def get_T_ml_tendency(T, mld, delta=5, H0=None):
    """
    Compute mixed-layer temperature tendency.

    ``T`` is averaged with ``src.utils.get_ml_avg`` (``mld`` is passed as
    ``Hm``; ``delta`` and ``H0`` are forwarded) and the result is
    differentiated along "lag".
    """

    ml_kwargs = dict(Hm=mld, delta=delta, H0=H0)

    return src.utils.get_ml_avg(T, **ml_kwargs).differentiate("lag")


def get_spatial_composite_wrapper(
    data,
    forced_scores,
    peak_month,
    H0=None,
    **composite_kwargs,
):
    """
    Get spatial composite together with its background state.

    Steps:
      1. anomaly composite via ``get_spatial_composite``;
      2. background climatology on the composite's lags via
         ``get_spatial_clim``;
      3. "H" from ``src.utils.get_H_int`` (climatological value on the
         background; full-minus-climatological on the composite);
      4. terms from ``src.utils.get_feedbacks`` merged into the composite;
      5. "Q": "nhf" rescaled by ``sec_per_mo / (rho * Cp * H)``.

    Args:
        data: anomaly data passed to ``get_spatial_composite``.
        forced_scores: forced data passed to ``get_spatial_clim``.
        peak_month: month to center composite on.
        H0: depth used to scale "nhf"; if None, the climatological "mld"
            is used instead.
        **composite_kwargs: forwarded to ``get_spatial_composite``.

    Returns:
        Tuple ``(composite, comp_clim)``.
    """

    ## constants used to express NHF in units of K/mo
    SEC_PER_MO = 8.64e4 * 30
    RHO = 1.02e3
    CP = 4.2e3

    ## threshold passed to get_H_int
    H_THRESH = 0.04

    ## get spatial composite of anomalies
    composite = get_spatial_composite(
        data=data, peak_month=peak_month, **composite_kwargs
    )

    ## get background state
    comp_clim = get_spatial_clim(
        forced=forced_scores, lags=composite.lag, peak_month=peak_month
    )

    ## get thermocline depth: climatological, and anomaly (full minus clim)
    H_clim = src.utils.get_H_int(comp_clim["T"], thresh=H_THRESH)
    H_full = src.utils.get_H_int(comp_clim["T"] + composite["T"], thresh=H_THRESH)
    comp_clim["H"] = H_clim
    composite["H"] = H_full - H_clim

    ## compute ocean feedbacks and attach them to the composite
    feedbacks = src.utils.get_feedbacks(bar=comp_clim, prime=composite)
    composite = xr.merge([composite, feedbacks])

    ## get differences
    # composite["diff"] = composite["TEND_TEMP"] - composite["ADV_3D_TEMP"]

    ## add SST tendency
    # composite["ddt_sst"] = composite["sst"].differentiate("lag")
    # composite["ddt_ssh"] = composite["ssh"].differentiate("lag")
    # composite["ddt_z20"] = composite["z20"].differentiate("lag")
    # composite["ddt_ssh_hat"] = composite["ssh_hat"].differentiate("lag")

    ## depth used for scaling: fixed H0 if given, else climatological MLD
    H = comp_clim["mld"] if H0 is None else H0

    ## get NHF in units of K/mo
    composite["Q"] = composite["nhf"] * SEC_PER_MO / (RHO * CP * H)

    return composite, comp_clim

## Load data

### $T$, $h$

In [None]:
## load indices (with load_z20 switched on)
Th = src.utils.load_cesm_indices(load_z20=True)

## shorten the long index names for convenience
Th = Th.rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

## tropical SST average
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")

## T,h (total)
Th_total = xr.open_dataset(pathlib.Path(DATA_FP, "cesm", "Th.nc"))

## relative sst: total index minus the "trop_sst_10" tropical average
for n in ("T_3", "T_34", "T_4"):
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## spatial data: load_consolidated returns the (forced, anom) pair
forced, anom = src.utils.load_consolidated()

## append the Th indices, each divided by its own standard deviation
anom = xr.merge((anom, Th / Th.std()))

#### compute $T$ tendency

In [None]:
## tendencies of "sst" and "T" from get_ddt (is_forward=False),
## converted from 1/yr to 1/mo
for _field in ("sst", "T"):
    _ddt = src.utils.get_ddt(anom[[_field]], is_forward=False)[f"ddt_{_field}"]
    anom[f"ddt_{_field}"] = 1 / 12 * _ddt

#### Get data

In [None]:
## variables of interest: everything listed here, plus a "<name>_comp"
## entry for each of the last six
## NOTE: VARNAMES is only referenced by the commented-out lines below
VARNAMES = [
    "T_3",
    "T_34",
    "h_w",
    "h_w_z20",
    "ddt_T",
    "ssh",
    "sst",
    "nhf",
    "T",
    "w",
    "u",
]
VARNAMES = VARNAMES + [f"{v}_comp" for v in VARNAMES[-6:]]

## windowed data (stride=120), loaded into memory
# anom = src.utils.get_windowed(anom[VARNAMES], stride=120).compute()
# forced = src.utils.get_windowed(forced[VARNAMES[-12:]], stride=120).compute()
window_kwargs = dict(stride=120)
anom = src.utils.get_windowed(anom, **window_kwargs).compute()
forced = src.utils.get_windowed(forced, **window_kwargs).compute()

#### Compute composite

In [None]:
## specify what variable to use
VARNAME = "T_34"

## specify shared args
kwargs = dict(
    peak_month=12,
    q=0.95,
    event_type=1,
    H0=70,
)

## save filepath
## NOTE(review): the directory name encodes only VARNAME and q, so results
## cached with a different peak_month / event_type / H0 are reused silently.
## Delete the directory when changing those settings.
SAVE_DIR = SAVE_FP / "composites" / f"{VARNAME}_{kwargs['q']}"
CACHE_FILES = {name: SAVE_DIR / f"{name}.nc" for name in ("cold", "warm", "clim")}

## use the cache only if ALL three files exist. Read errors are allowed to
## propagate: swallowing them would leave the variables below undefined.
if all(path.is_file() for path in CACHE_FILES.values()):

    comps_cold = xr.open_dataset(CACHE_FILES["cold"])
    comps_warm = xr.open_dataset(CACHE_FILES["warm"])
    clims = xr.open_dataset(CACHE_FILES["clim"])

else:

    ## empty lists to hold per-window results
    cold_list = []
    warm_list = []
    clim_list = []

    for y in tqdm.tqdm(anom.year):

        ## get args for composite
        comp_kwargs = dict(
            idx=anom[VARNAME].sel(year=y),
            forced_scores=forced.sel(year=y),
            data=anom.sel(year=y),
        )

        ## warm
        comp_warm, clim = get_spatial_composite_wrapper(
            is_warm=True,
            **comp_kwargs,
            **kwargs,
        )

        ## cold (its climatology duplicates the warm one, so it is discarded)
        comp_cold, _ = get_spatial_composite_wrapper(
            is_warm=False,
            **comp_kwargs,
            **kwargs,
        )

        ## append to list
        clim_list.append(clim)
        warm_list.append(comp_warm)
        cold_list.append(comp_cold)

    ## convert to xarray
    comps_warm = xr.concat(warm_list, dim=anom.year)
    comps_cold = xr.concat(cold_list, dim=anom.year)
    clims = xr.concat(clim_list, dim=anom.year)

    ## save to file (create the output directory first)
    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    comps_cold.to_netcdf(CACHE_FILES["cold"])
    comps_warm.to_netcdf(CACHE_FILES["warm"])
    clims.to_netcdf(CACHE_FILES["clim"])

#### Mixed layer averages

In [None]:
## mixed-layer averages with H0=70 (Hm left unset)
ml_kwargs = {"H0": 70, "Hm": None}
comps_cold = src.utils.get_ml_avg_ds(comps_cold, **ml_kwargs)
comps_warm = src.utils.get_ml_avg_ds(comps_warm, **ml_kwargs)

## get time tendency of "T_ml" along "lag" (stored on each dataset)
for comp in (comps_cold, comps_warm):
    ddt_T_ml = comp["T_ml"].differentiate("lag")
    comp["ddt_T_ml"] = ddt_T_ml

### Plot

In [None]:
def plot_cycle_hov_lagged(
    ax,
    data,
    amp,
    is_filled=True,
    xticks=(190, 240),
    lat_bound=5,
    nlev=5,
    cmap="cmo.balance",
    levels=None,
):
    """
    Plot data on ax object as a longitude-vs-lag contour plot.

    Args:
        ax: matplotlib axes to draw on.
        data: array with "lag" and "longitude" dimensions. If it has a
            "latitude" coordinate it is first reduced with
            ``src.utils.merimean(data, lat_bound=lat_bound)``.
        amp: amplitude used to build the contour levels when ``levels`` is
            None (``src.utils.make_cb_range(amp, amp / nlev)``).
        is_filled: filled contours with ``cmap`` if True; otherwise thin
            black contour lines.
        xticks: longitudes to use as x ticks; a vertical guide line is
            drawn at each one.
        lat_bound: forwarded to ``src.utils.merimean``.
        nlev: controls the level spacing (``amp / nlev``).
        cmap: colormap for filled contours.
        levels: explicit contour levels; overrides ``amp`` / ``nlev``.

    Returns:
        The contour set returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## set levels if not specified
    if levels is None:
        levels = src.utils.make_cb_range(amp, amp / nlev)

    ## kwargs shared by both plot types
    shared_kwargs = dict(levels=levels, extend="both")

    ## kwargs specific to the plot type
    if is_filled:
        plot_fn = ax.contourf
        contour_kwargs = dict(cmap=cmap)
    else:
        plot_fn = ax.contour
        contour_kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary)
    if "latitude" in data.coords:
        plot_data = src.utils.merimean(data, lat_bound=lat_bound)
    else:
        plot_data = data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.lag,
        plot_data.transpose("lag", "longitude"),
        **contour_kwargs,
        **shared_kwargs,
    )

    ## format ax object: fixed longitude range, plus faint white guide
    ## lines at each x tick and at lag 0
    guide_kwargs = dict(c="w", ls="--", lw=0.8, alpha=0.5)
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(xticks)
    for tick in xticks:
        ax.axvline(tick, **guide_kwargs)
    ax.axhline(0, **guide_kwargs)

    return cp


def format_hov_axs(axs, peak_mon):
    """
    Format hov axs.

    Titles the first three axes "Early", "Late" and "Difference (scaled)",
    removes the y ticks from every axis, and then labels the y-axis of the
    first one with ``src.utils.label_hov_yaxis``.

    Returns:
        ``axs`` (same object that was passed in).
    """

    ## panel titles
    titles = ("Early", "Late", "Difference (scaled)")
    for i, title in enumerate(titles):
        axs[i].set_title(title, size=10)

    ## clear y ticks everywhere
    for ax in axs:
        ax.set_yticks([])

    ## y-axis labels on the left-most panel only
    src.utils.label_hov_yaxis(axs[0], peak_mon=peak_mon)

    return axs


def plot_eli_on_axs(axs, x0, x1, name="eli_05"):
    """
    Plot ELI boundary on hovmoller axs.

    ``x0[name]`` is drawn against ``x0.lag`` on ``axs[0]`` and ``axs[2]``;
    ``x1[name]`` is drawn (dashed) against ``x1.lag`` on ``axs[1]`` and
    ``axs[2]``. All lines are magenta.
    """

    ## (axis index, dataset, extra line style)
    jobs = (
        (0, x0, {}),
        (1, x1, {"ls": "--"}),
        (2, x0, {"ls": "-"}),
        (2, x1, {"ls": "--"}),
    )

    for i, ds, style in jobs:
        axs[i].plot(ds[name], ds.lag, c="magenta", **style)

    return


def plot_level(ax, comp, level, ls="-", c="magenta"):
    """
    Plot single level on hovmoller.

    Draws the contour ``comp == level`` on ``ax`` with longitude on the
    x-axis and lag on the y-axis, using color ``c`` and line style ``ls``.
    """

    lon = comp.longitude
    lag = comp.lag

    ## put "lag" first so that rows correspond to the y-axis
    field = comp.transpose("lag", ...)

    ax.contour(lon, lag, field, levels=[level], colors=c, linestyles=ls)

    return


def plot_sst_rel_on_axs(axs, x0, x1, lev=0):
    """
    Plot a single "sst_rel" contour on hovmoller axs.

    The ``lev`` contour of ``x0["sst_rel"]`` is drawn (solid) on ``axs[0]``
    and ``axs[2]``; that of ``x1["sst_rel"]`` is drawn (dashed) on
    ``axs[1]`` and ``axs[2]``.
    """

    ## (axis index, dataset, line style)
    jobs = (
        (0, x0, "-"),
        (1, x1, "--"),
        (2, x0, "-"),
        (2, x1, "--"),
    )

    for i, ds, ls in jobs:
        plot_level(axs[i], ds["sst_rel"], level=lev, ls=ls)

    return


def plot_comp_on_axs(
    axs,
    x0,
    x1,
    name,
    amp,
    peak_month,
    amp_diff=None,
    nlev=5,
    **kwargs,
):
    """
    Plot composite on axs objects.

    Variable ``name`` is drawn for ``x0``, for ``x1``, and for the
    difference ``x1 - x0`` multiplied by ``amp / amp_diff`` (so the same
    contour levels can be used on all panels). ``amp_diff`` defaults to
    ``amp / 2``. Extra keyword arguments go to ``plot_cycle_hov_lagged``.

    Returns:
        The contour set of the last panel drawn.
    """

    ## handle amp_diff=None
    if amp_diff is None:
        amp_diff = amp / 2

    ## panels: x0, x1, scaled (x1-x0)
    diff_scaled = (amp / amp_diff) * (x1 - x0)
    panels = (x0, x1, diff_scaled)

    for ax, panel in zip(axs, panels):
        cf = plot_cycle_hov_lagged(
            ax=ax, data=panel[name], amp=amp, nlev=nlev, **kwargs
        )

    ## format/label
    format_hov_axs(axs, peak_mon=peak_month)

    return cf

### show plots

In [None]:
## args shared by the hovmoller plots below
hov_kwargs = {"peak_month": kwargs["peak_month"]}


def merimean(x):
    """mean over latitudes in the slice -5 to 5"""
    return x.sel(latitude=slice(-5, 5)).mean("latitude")

#### SST

In [None]:
## three side-by-side panels
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(7, 3), layout="constrained")

## warm composite: year=1870 (x0) vs year=2070 (x1)
cf = plot_comp_on_axs(
    axs,
    x0=comps_warm.sel(year=1870),
    x1=comps_warm.sel(year=2070),
    name="sst",
    amp=3.5,
    nlev=10,
    **hov_kwargs,
)

## colorbar next to the last panel
cb = fig.colorbar(cf, ax=axs[2], label="K", ticks=[-2.5, 0, 2.5])

plt.show()

### Asymmetry

In [None]:
## value of the "year" coordinate selected in the warm-vs-cold plots below
YEAR = 1870

#### SST

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: sign-flipped cold composite (x0) vs warm composite (x1)
cf = plot_comp_on_axs(
    axs,
    name="sst",
    amp=3.5,
    x1=comps_warm.sel(year=YEAR),
    x0=-comps_cold.sel(year=YEAR),
    nlev=10,
    **hov_kwargs,
)

## plot_comp_on_axs titles the panels "Early"/"Late"; relabel them to match
## what is shown here (third panel is x1 - x0, i.e. warm plus cold)
for ax, title in zip(axs, ["Cold (sign flipped)", "Warm", "Asymmetry (scaled)"]):
    ax.set_title(title, size=10)

## label (ticks kept inside the +/- amp range; ticks outside it are not drawn)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-2.5, 0, 2.5], label="K")

plt.show()

#### SSH

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: sign-flipped cold composite (x0) vs warm composite (x1)
cf = plot_comp_on_axs(
    axs,
    name="ssh",
    amp=25,
    x1=comps_warm.sel(year=YEAR),
    x0=-comps_cold.sel(year=YEAR),
    nlev=10,
    **hov_kwargs,
)

## plot_comp_on_axs titles the panels "Early"/"Late"; relabel them to match
## what is shown here (third panel is x1 - x0, i.e. warm plus cold)
for ax, title in zip(axs, ["Cold (sign flipped)", "Warm", "Asymmetry (scaled)"]):
    ax.set_title(title, size=10)

## label (ticks kept inside the +/- amp range; ticks outside it are not drawn)
## TODO: add units to the label (it previously read "K", a temperature unit)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-20, 0, 20], label="SSH")

plt.show()

#### $\tau_x$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: sign-flipped cold composite (x0) vs warm composite (x1)
cf = plot_comp_on_axs(
    axs,
    name="taux",
    amp=5e-2,
    x1=comps_warm.sel(year=YEAR),
    x0=-comps_cold.sel(year=YEAR),
    nlev=10,
    **hov_kwargs,
)

## plot_comp_on_axs titles the panels "Early"/"Late"; relabel them to match
## what is shown here (third panel is x1 - x0, i.e. warm plus cold)
for ax, title in zip(axs, ["Cold (sign flipped)", "Warm", "Asymmetry (scaled)"]):
    ax.set_title(title, size=10)

## label (ticks kept inside the +/- amp range; ticks outside it are not drawn)
## TODO: add units to the label (it previously read "K", a temperature unit)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-0.04, 0, 0.04], label=r"$\tau_x$")

plt.show()

#### $Z_{20}$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: sign-flipped cold composite (x0) vs warm composite (x1)
cf = plot_comp_on_axs(
    axs,
    name="z20",
    amp=50,
    x1=comps_warm.sel(year=YEAR),
    x0=-comps_cold.sel(year=YEAR),
    nlev=10,
    **hov_kwargs,
)

## plot_comp_on_axs titles the panels "Early"/"Late"; relabel them to match
## what is shown here (third panel is x1 - x0, i.e. warm plus cold)
for ax, title in zip(axs, ["Cold (sign flipped)", "Warm", "Asymmetry (scaled)"]):
    ax.set_title(title, size=10)

## label
## TODO: add units to the label (it previously read "K", a temperature unit)
cb = fig.colorbar(cf, ax=axs[2], ticks=[-30, 0, 30], label=r"$Z_{20}$")

plt.show()

#### $Q$

In [None]:
## set up plot
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot data: sign-flipped cold composite (x0) vs warm composite (x1)
cf = plot_comp_on_axs(
    axs,
    name="Q",
    amp=1,
    x1=comps_warm.sel(year=YEAR),
    x0=-comps_cold.sel(year=YEAR),
    nlev=10,
    **hov_kwargs,
)

## plot_comp_on_axs titles the panels "Early"/"Late"; relabel them to match
## what is shown here (third panel is x1 - x0, i.e. warm plus cold)
for ax, title in zip(axs, ["Cold (sign flipped)", "Warm", "Asymmetry (scaled)"]):
    ax.set_title(title, size=10)

## label (ticks kept inside the +/- amp range; ticks outside it are not drawn)
## "Q" is the net heat flux converted to K/mo when the composite is built
cb = fig.colorbar(cf, ax=axs[2], ticks=[-0.5, 0, 0.5], label="K/mo")

plt.show()

#### budget

In [None]:
## three side-by-side panels
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(7, 3), layout="constrained")

## cold composite of "THF_ml": year=1870 (x0) vs year=2070 (x1)
cf = plot_comp_on_axs(
    axs,
    x0=comps_cold.sel(year=1870),
    x1=comps_cold.sel(year=2070),
    name="THF_ml",
    amp=2.5,
    nlev=10,
    **hov_kwargs,
)

## colorbar next to the last panel
## NOTE(review): label says "K"; "THF_ml" looks like a tendency term, so the
## unit may need to be K/mo -- confirm against src.utils.get_feedbacks
cb = fig.colorbar(cf, ax=axs[2], label="K", ticks=[-2.5, 0, 2.5])

plt.show()

## Hovmollers

In [None]:
def plot_hovs(comp, varname, amp0, amp1, lag=1, fn=src.utils.get_nino3):
    """
    Plot data on a pair of year-vs-x hovmoller panels.

    Left panel: ``comp[varname]`` at the given ``lag``, averaged over the
    latitude slice -5 to 5, with longitude on the x-axis. Right panel:
    ``fn(comp[varname])`` with lag on the x-axis. "year" is on the y-axis
    of both panels.

    Args:
        comp: dataset holding ``varname``.
        varname: name of the variable to plot.
        amp0: amplitude for the left panel's contour levels.
        amp1: amplitude for the right panel's contour levels.
        lag: lag selected for the left panel.
        fn: function applied to the data for the right panel.

    Returns:
        Tuple ``(fig, axs)``.
    """

    ## get data for plot, making sure latitude is a dimension
    field = comp[varname]
    if "latitude" not in field.dims:
        field = field.expand_dims("latitude")

    ## left panel data: fixed lag, averaged over latitude
    by_lon = field.sel(lag=lag, latitude=slice(-5, 5)).mean("latitude")
    by_lon = by_lon.transpose("year", ...)

    ## right panel data: reduced with fn
    by_lag = fn(field).transpose("year", ...)

    fig, axs = plt.subplots(1, 2, figsize=(6, 3), layout="constrained")
    ax_lon, ax_lag = axs

    def draw(ax, x, y, z, amp):
        """filled contours with levels spaced at amp / 10"""
        ax.contourf(
            x,
            y,
            z,
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(amp, amp / 10),
            extend="both",
        )

    ## longitude on x-axis
    draw(ax_lon, by_lon.longitude, by_lon.year, by_lon, amp0)

    ## lag on x-axis
    draw(ax_lag, field.lag, field.year, by_lag, amp1)

    ## format left panel: guide lines at longitudes 210 and 270
    guide = dict(ls="--", c="k", lw=0.8)
    ax_lon.set_xlim([140, 280])
    ax_lon.axvline(210, **guide)
    ax_lon.axvline(270, **guide)
    ax_lon.set_xticks([160, 210, 270])
    ax_lon.set_yticks([1870, 1980, 2090])

    ## format right panel: guide line at lag 0, no y ticks
    ax_lag.axvline(0, **guide)
    ax_lag.set_xticks([-12, -6, 0, 6, 12])
    ax_lag.set_yticks([])

    return fig, axs

In [None]:
## specify month
LAG = 1

## specify composite type
COMP = comps_warm
# COMP = -comps_cold

## specify function type
FN = src.utils.get_nino3

## get shared args
## NOTE: dedicated names are used here (instead of `kwargs` / `VARNAME`) so
## that the settings of the composite cell are not overwritten; re-running
## that cell afterwards would otherwise pick up the wrong arguments and
## cache directory.
hovs_kwargs = dict(amp0=1, amp1=1, lag=LAG, fn=FN, comp=COMP)


for varname in ["ddt_T_ml", "THF_ml", "EKM_ml", "ZAF_ml", "DD_ml", "Q"]:

    print(f"\n\n{varname}")
    fig, axs = plot_hovs(varname=varname, **hovs_kwargs)
    plt.show()

In [None]:
## specify month
LAG = 4

## specify comp
COMP = comps_cold


def sel(x):
    """change relative to the first year, evaluated at lag LAG"""
    return (x - x.isel(year=0)).sel(lag=LAG)


fig, axs = plt.subplots(1, 2, figsize=(7, 3))

## what goes on each panel, as {legend label: variable name}
panels = [
    ## positive feedbacks
    (axs[0], {f: f"{f}_ml" for f in ["EKM", "THF", "ZAF", "ADV", "ddt_T"]}),
    ## negative feedbacks
    (axs[1], {f: f for f in ["DD_ml", "Q"]}),
]
# axs[0].plot(
#     coefs.year, sel(get_nino3_da(coefs[f"ddt_T_ml"])), label=r"$\frac{dT}{dt}$"
# )

for ax, series in panels:
    for label, var in series.items():
        ax.plot(COMP.year, sel(get_nino3_da(COMP[var])), label=label)

## legends and zero line
for ax in axs:
    ax.legend(prop=dict(size=8))
    ax.axhline(0, ls="--", c="k", lw=0.8)


plt.show()