# ENSO composites

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy

# Import custom modules
import src.utils

## plotting defaults: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## higher-resolution figures
mpl.rcParams.update({"figure.dpi": 100})

## root directories for input data and saved output, read from the environment
DATA_FP, SAVE_FP = (
    pathlib.Path(os.environ[var]) for var in ("DATA_FP", "SAVE_FP")
)

## Load data

### $T$, $h$

In [None]:
## open precomputed indices and shorten their names for convenience
Th = src.utils.load_cesm_indices().rename(
    north_tropical_atlantic="natl",
    atlantic_nino="nino_atl",
    tropical_indian_ocean="iobm",
    indian_ocean_dipole="iod",
    north_pacific_meridional_mode="npmm",
    south_pacific_meridional_mode="spmm",
)

## tropical-mean SST, then total (not anomaly) T, h
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## relative SST: each total T_* index minus the tropical-mean SST ("trop_sst_10")
for n in ("T_3", "T_34", "T_4"):
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

### Spatial data

In [None]:
## consolidated spatial fields: forced (background) scores and anomaly scores
CONS_DIR = DATA_FP / "cesm" / "consolidated"
forced = xr.open_dataset(CONS_DIR / "forced.nc")
anom = xr.open_dataset(CONS_DIR / "anom.nc")

## attach the T, h indices to the anomaly dataset
for n in ("T_3", "T_34", "T_4", "h", "h_w"):
    anom[n] = Th[n]

### Preprocess

In [None]:
## time windows: early (1851-1880) and late (2071-2100) periods
t_early = {"time": slice("1851", "1880")}
t_late = {"time": slice("2071", "2100")}

## subset the anomaly data to each window
anom_early, anom_late = anom.sel(t_early), anom.sel(t_late)

## Composite

### Funcs

In [None]:
def get_composite(idx, data, peak_month, time_idx, q=0.95, is_warm=True):
    """
    Get hovmoller composite of `data` via src.utils.make_composite.

    - idx: index passed to make_composite (presumably used to pick events)
    - data: fields to composite
    - peak_month: month to center composite on
    - time_idx: accepted for a shared call signature but not used here
      (NOTE(review): confirm callers pre-subset idx/data in time)
    - q: quantile threshold; warm composites pass q, cold ones pass 1 - q
    - is_warm: selects check_cutoff `x > cut` (True) or `x < cut` (False)

    Returns whatever make_composite returns (the projected composite).
    """
    if is_warm:
        threshold = q
        is_event = lambda x, cut: x > cut  # noqa: E731
    else:
        threshold = 1 - q
        is_event = lambda x, cut: x < cut  # noqa: E731

    return src.utils.make_composite(
        q=threshold,
        check_cutoff=is_event,
        peak_month=peak_month,
        idx=idx,
        data=data,
    )


def get_spatial_composite(components, **composite_kwargs):
    """
    Composite projected data, then reconstruct the result as spatial fields.

    composite_kwargs are forwarded unchanged to get_composite; `components` is
    handed to reconstruct_helper with an identity function. The "mode"
    variable is dropped, and "sst_rel" is added as
    "sst_total" - "trop_sst_05".
    """
    projected = get_composite(**composite_kwargs)

    spatial = reconstruct_helper(
        composite=projected,
        components=components,
        func=lambda x: x,
    )
    spatial = spatial.drop_vars("mode")

    spatial["sst_rel"] = spatial["sst_total"] - spatial["trop_sst_05"]
    return spatial


def reconstruct_helper(composite, components, func):
    """
    Reconstruct spatial fields from a composite of projected scores.

    Args:
        composite: dataset of composited scores. For every variable in
            `components` it must hold a variable of the same name. Any
            variable named "<var>_total" is reconstructed with
            components["<var>"].
        components: dataset with one variable per field, passed as
            `components` to src.utils.reconstruct_fn.
        func: passed as `fn` to src.utils.reconstruct_fn.

    Returns:
        A deep copy of `composite` in which every reconstructed variable is
        replaced by its reconstruction; other variables are left as-is.

    Raises:
        KeyError: if a score variable or the base of a "_total" variable is
            missing from `composite` / `components`.
    """
    suffix = "_total"

    ## copy so the caller's composite is not modified
    composite_recon = copy.deepcopy(composite)

    ## reconstruct anomalies
    for c in list(components):
        composite_recon[c] = src.utils.reconstruct_fn(
            components=components[c],
            scores=composite[c],
            fn=func,
        )

    ## reconstruct "total" fields with the components of their base variable.
    ## Match the suffix (not a substring anywhere in the name) so that stripping
    ## it always yields the correct base name.
    for c in list(composite):
        if c.endswith(suffix):
            base = c[: -len(suffix)]
            composite_recon[c] = src.utils.reconstruct_fn(
                components=components[base],
                scores=composite[c],
                fn=func,
            )

    return composite_recon


def get_spatial_clim(forced, time_idx, lags, peak_month, components):
    """
    Monthly climatology of spatial fields, laid out along composite lags.

    forced: background scores; subset with `time_idx`, averaged by calendar
        month, and reconstructed with `components` (identity fn).
    lags: lag values (e.g. composite.lag); used as the concat dimension.
    peak_month: calendar month corresponding to lag 0.

    Each entry along `lags` holds the climatology of calendar month
    peak_month + lag, wrapped into 1..12.
    """
    monthly_scores = forced.sel(time_idx).groupby("time.month").mean()
    monthly_clim = src.utils.reconstruct_fn(
        scores=monthly_scores,
        components=components,
        fn=lambda x: x,
    )

    ## calendar month (1..12) for each lag, relative to the peak month
    cal_months = 1 + np.mod(lags + peak_month - 1, 12)

    per_lag = [monthly_clim.sel(month=m).drop_vars("month") for m in cal_months]
    return xr.concat(per_lag, dim=lags)


def add_advection_terms(comp, comp_clim, delta=5, H0=None):
    """
    Return a copy of `comp` with advection terms and their mixed-layer averages.

    Notation: prime = composite anomaly (`comp`), bar = background (`comp_clim`).
      adv_uprime_Tbar = -get_udTdx(u=comp.u,      T=comp_clim.T)
      adv_ubar_Tprime = -get_udTdx(u=comp_clim.u, T=comp.T)
      adv_wprime_Tbar =  get_wdTdz(w=comp.w,      T=comp_clim.T)
      adv_wbar_Tprime =  get_wdTdz(w=comp_clim.w, T=comp.T)
    Every variable whose name contains "adv" also gets an "<name>_ml" version
    from get_ml_avg (Hm=comp_clim["mld"], delta, H0). Finally
    Th_zaf_ml = adv_wbar_Tprime_ml + adv_uprime_Tbar_ml.
    """
    out = copy.deepcopy(comp)
    clim_mld = comp_clim["mld"]

    ## zonal terms: sign flip applied here
    out["adv_uprime_Tbar"] = -get_udTdx(u=comp["u"], T=comp_clim["T"])
    out["adv_ubar_Tprime"] = -get_udTdx(T=comp["T"], u=comp_clim["u"])

    ## vertical terms: no sign flip here; presumably get_wdTdz's convention
    ## handles it — TODO confirm
    out["adv_wprime_Tbar"] = get_wdTdz(w=comp["w"], T=comp_clim["T"])
    out["adv_wbar_Tprime"] = get_wdTdz(T=comp["T"], w=comp_clim["w"])

    ## snapshot the names first so the new "_ml" variables are not revisited
    adv_names = [name for name in out if "adv" in name]
    for name in adv_names:
        out[f"{name}_ml"] = get_ml_avg(out[name], Hm=clim_mld, delta=delta, H0=H0)

    ## combined thermocline + zonal-advective feedback terms
    out["Th_zaf_ml"] = out["adv_wbar_Tprime_ml"] + out["adv_uprime_Tbar_ml"]
    return out


def get_T_ml_tendency(T, mld, delta=5, H0=None):
    """
    Lag derivative of the mixed-layer average of T.

    The average comes from get_ml_avg(T, Hm=mld, delta=delta, H0=H0) and is
    differentiated along the "lag" coordinate.
    """
    return get_ml_avg(T, Hm=mld, delta=delta, H0=H0).differentiate("lag")


def _regrid_longitude_to_lon(da, lon):
    """Rename `da`'s "longitude" dimension to "lon" and interpolate it onto `lon`."""
    return da.rename({"longitude": "lon"}).interp({"lon": lon})


def get_spatial_composite_wrapper(
    forced_scores,
    components,
    peak_month,
    time_idx,
    delta=5,
    H0=None,
    **composite_kwargs,
):
    """
    Spatial composite of anomalies plus background-state diagnostics.

    Args:
        forced_scores: background scores passed to get_spatial_clim.
        components: dataset of components used for reconstruction; must
            include "u", "w", "T" and "mld".
        peak_month: month the composite is centered on (lag 0).
        time_idx: indexer for the background climatology
            (e.g. dict(time=slice(...))).
        delta, H0: forwarded to the mixed-layer averaging.
        **composite_kwargs: forwarded to get_composite (e.g. idx, data, q,
            is_warm).

    Returns:
        The composite with advection terms (see add_advection_terms) plus:
        "ddt_T"   lag tendency of mixed-layer T,
        "ddt_sst" lag tendency of "sst", regridded from "longitude" to "lon",
        "Q"       net heat flux "nhf" converted to K/mo using the
                  climatological mixed-layer depth.
    """
    shared_args = dict(time_idx=time_idx, peak_month=peak_month)

    ## spatial composite of anomalies
    composite = get_spatial_composite(
        components=components,
        **shared_args,
        **composite_kwargs,
    )

    ## background state along the composite's lags
    comp_clim = get_spatial_clim(
        forced=forced_scores,
        lags=composite.lag,
        components=components[["u", "w", "T", "mld"]],
        **shared_args,
    )

    ## advection terms
    composite = add_advection_terms(
        comp=composite, comp_clim=comp_clim, delta=delta, H0=H0
    )

    mld = comp_clim["mld"]
    ocean_lon = composite.lon.values

    ## mixed-layer temperature tendency
    composite["ddt_T"] = get_T_ml_tendency(
        composite["T"], mld=mld, delta=delta, H0=H0
    )

    ## SST tendency, moved onto the "lon" grid
    composite["ddt_sst"] = _regrid_longitude_to_lon(
        composite["sst"].differentiate("lag"), ocean_lon
    )

    ## net heat flux in K/mo: Q = nhf * (s/mo) / (rho * Cp * H).
    ## Regrid nhf onto "lon" BEFORE dividing by the MLD (assumed to be on the
    ## "lon" grid, like T). Dividing first would broadcast "longitude" against
    ## "lon", and the later rename to "lon" would then conflict.
    sec_per_mo = 8.64e4 * 30  # seconds in a 30-day month
    rho = 1.02e3  # presumably seawater density [kg/m^3]
    Cp = 4.2e3  # presumably seawater specific heat [J/kg/K]
    nhf = _regrid_longitude_to_lon(composite["nhf"], ocean_lon)
    composite["Q"] = nhf * sec_per_mo / (rho * Cp * mld)

    return composite

### Compute