# Spaghetti
Make spaghetti plots for ENSO events

## Imports

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import pandas as pd

# Import custom modules
import src.utils

## seaborn theme: white axes background, no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories come from environment variables
## (a missing variable raises KeyError here)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_spaghetti(idx, data, peak_month, event_type=None, q=0.95, is_warm=True):
    """
    Get the un-averaged composite members ("spaghetti") for the specified:
    - idx: index forwarded to src.utils.make_composite (used to pick events)
    - data: data forwarded to src.utils.make_composite
    - peak_month: month to center composite on
    - event_type: forwarded unchanged to src.utils.make_composite
    - q: quantile threshold; the cold case uses 1 - q instead
    - is_warm: if True, the cutoff test is `x > cut`; otherwise `x < cut`

    Returns whatever src.utils.make_composite returns with avg=False
    (presumably one entry per selected event -- confirm against src.utils).
    """

    ## pick quantile and comparison for the warm/cold case
    if is_warm:
        quantile = q
        passes_cutoff = lambda x, cut: x > cut
    else:
        quantile = 1 - q
        passes_cutoff = lambda x, cut: x < cut

    ## avg=False: request the composite without averaging
    return src.utils.make_composite(
        q=quantile,
        check_cutoff=passes_cutoff,
        avg=False,
        peak_month=peak_month,
        idx=idx,
        data=data,
        event_type=event_type,
    )

## Load data

### T, h

In [None]:
## T and h indices (load_z20=True presumably adds the z20-based variables used below)
Th = src.utils.load_cesm_indices(load_z20=True)

## start at 1851: omit first year (bc of NaN in h,hw vars)
Th = Th.sel(time=slice("1851", None))

## scale by the standard deviation (for convenience)
Th /= Th.std()

## windowed data (used to estimate change in parameters over time)
Th_rolling = src.utils.get_windowed(Th, window_size=480, stride=120)

## get h_e
Th_rolling["h_e"] = Th_rolling["h"] - Th_rolling["h_w"]

## SST-dependence removal: (output name, T index, h variable)
sst_removal_specs = [
    ("h_w_hat", "T_3", "h_w"),
    ("h_w_z20_hat", "T_34", "h_w_z20"),
]
for out_name, t_name, h_name in sst_removal_specs:
    Th_rolling[out_name] = src.utils.remove_sst_dependence_v2(
        Th_rolling,
        T_var=t_name,
        h_var=h_name,
        dims=["time", "member"],
    )

### Fits

In [None]:
## specify variables
T_VAR = "T_34"
H_VAR = "h_w_z20_hat"

## fit file for each supported H_VAR when T_VAR is "T_34"
fnames_T34 = {
    "h": "T34_h.nc",
    "h_w": "T34_hw_bymember.nc",
    "h_w_z20": "T34_hw_z20",
    "h_w_z20_hat": "T34_hw_z20_hat_nsr",
}

## fit file for each supported H_VAR for any other T_VAR (files are T3 fits)
fnames_other = {
    "h_w": "T3_hw_bymember.nc",
    "h_w_hat": "T3_hwhat_bymember.nc",
    "h_w_z20": "T3_hw_z20",
}

## get name of file; fail loudly on an unsupported combination instead of
## leaving `fname` undefined (or stale from an earlier run of this cell)
fnames_by_h = fnames_T34 if T_VAR == "T_34" else fnames_other
if H_VAR not in fnames_by_h:
    raise ValueError(f"No fit file known for T_VAR={T_VAR!r}, H_VAR={H_VAR!r}")
fname = fnames_by_h[H_VAR]

## get filepath
save_fp = pathlib.Path(SAVE_FP, "fits_cesm", fname)
fits = xr.open_dataset(save_fp)

## compute params
## NOTE(review): only `src.utils` is imported in this file; `src.XRO` must be
## importable as an attribute of `src` for this line to work -- confirm
model = src.XRO.XRO(ncycle=12, ac_order=3, is_forward=True)
params = model.get_RO_parameters(fits)

## ensemble avg if necessary
if "member" in params.dims:
    params = params.mean("member")

## convert to composite lags: calendar month for each lag in -12..12,
## arranged so that lag 0 falls on month 12
months = 1 + np.mod(np.arange(12, 37) - 1, 12)
params = xr.concat(
    [params.sel(cycle=m) for m in months],
    dim=pd.Index(np.arange(-12, 13), name="lag"),
)

### Early and late

In [None]:
## first and last entries along "year" are the early and late periods
Th_early, Th_late = (Th_rolling.isel(year=i) for i in (0, -1))
params_early, params_late = (params.isel(year=i) for i in (0, -1))

## compute spaghetti/composite

In [None]:
## settings forwarded to get_spaghetti
comp_kwargs = {
    "is_warm": False,
    "event_type": 1,
    "peak_month": 12,
    "q": 0.95,
}

## variable used as the index for the composite
VARNAME = "T_34"

## un-averaged composites for each period
spag_early = get_spaghetti(data=Th_early, idx=Th_early[VARNAME], **comp_kwargs)
spag_late = get_spaghetti(data=Th_late, idx=Th_late[VARNAME], **comp_kwargs)

## composites are the mean over samples
comp_early, comp_late = (s.mean("sample") for s in (spag_early, spag_late))

## plots

### Mean

#### T, h

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## dashed black reference lines
f_kwargs = dict(ls="--", c="k", lw=0.8)

## composites to draw in each panel, keyed by legend label
composites = {"early": comp_early, "late": comp_late}

for ax, v in zip(axs, (T_VAR, H_VAR)):

    ## one curve per period
    for l, comp in composites.items():
        ax.plot(comp.lag, comp[v], label=l)

    ## reference lines and ticks
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.axvline(6, c="gray", lw=0.5)
    ax.set_xlabel("Lag (months)")
    ax.set_xticks([-12, 0, 6, 12])

## left panel
## NOTE(review): Th is divided by its std when loaded, so "K"/"m" may not be
## the actual units of the plotted values -- confirm
axs[0].set_title(T_VAR)
axs[0].set_ylabel("K")

## right panel: y-axis on the right-hand side, carries the legend
axs[1].set_title(H_VAR)
axs[1].set_ylabel("m")
axs[1].legend(prop=dict(size=8))
axs[1].yaxis.set_label_position("right")
axs[1].yaxis.tick_right()

plt.show()

#### $\frac{dh}{dt}$ terms

##### early v. late

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## dashed black reference lines
f_kwargs = dict(ls="--", c="k", lw=0.8)

## (parameter, state variable) shown in each panel
panel_terms = [("F2", T_VAR), ("epsilon", H_VAR)]

## (composite, parameters, legend label) for each period
periods = [
    (comp_early, params_early, "early"),
    (comp_late, params_late, "late"),
]

for ax, (v, n) in zip(axs, panel_terms):

    ## tendency term: minus (state variable) times (parameter)
    for comp, params_, l in periods:
        term = -comp[n] * params_[v]
        ax.plot(comp.lag, term, label=l)

    ## reference lines and ticks
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.axvline(6, c="gray", lw=0.5)
    ax.set_xlabel("Lag (months)")
    ax.set_xticks([-12, 0, 6, 12])

## titles
axs[0].set_title(r"Recharge ($-F_2 \cdot T$)")
axs[1].set_title(r"Damping ($-\varepsilon \cdot h$)")

## right panel: legend, y-axis on the right-hand side
axs[1].legend(prop=dict(size=8))
axs[1].yaxis.set_label_position("right")
axs[1].yaxis.tick_right()

src.utils.set_lims(axs)

plt.show()

##### recharge vs damping

In [None]:
## one panel per period (early, late); each panel compares the two dh/dt terms
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

for ax, comp, params_, c in zip(
    axs,
    [comp_early, comp_late],
    [params_early, params_late],
    sns.color_palette()[:2],
):

    ## label is "recharge" to agree with this section's heading and with the
    ## "Recharge ($-F_2 \cdot T$)" title used for the same term elsewhere
    for v, n, ls, l in zip(
        ["F2", "epsilon"], [T_VAR, H_VAR], ["-", "--"], ["recharge", "damping"]
    ):

        ## term is minus (state variable) times (parameter)
        ax.plot(comp.lag, -comp[n] * params_[v], ls=ls, c=c, label=l)

    ## formatting: zero line, lag-0 line, thin marker at lag 6
    f_kwargs = dict(ls="--", c="k", lw=0.8)
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.set_xlabel("Lag (months)")
    ax.axvline(6, c="gray", lw=0.5)
    ax.set_xticks([-12, 0, 6, 12])

## formattings
axs[0].set_title(r"Early")
axs[1].set_title(r"Late")
axs[1].yaxis.set_label_position("right")
axs[1].yaxis.tick_right()
axs[0].legend(prop=dict(size=8))
src.utils.set_lims(axs)

plt.show()

#### $\frac{dT}{dt}$ terms

##### Early vs. Late

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## dashed black reference lines
f_kwargs = dict(ls="--", c="k", lw=0.8)

## (parameter, state variable) shown in each panel
panel_terms = [("R", T_VAR), ("F1", H_VAR)]

## (composite, parameters, legend label) for each period
periods = [
    (comp_early, params_early, "early"),
    (comp_late, params_late, "late"),
]

for ax, (v, n) in zip(axs, panel_terms):

    ## tendency term: (state variable) times (parameter)
    for comp, params_, l in periods:
        term = comp[n] * params_[v]
        ax.plot(comp.lag, term, label=l)

    ## reference lines and ticks
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.axvline(6, c="gray", lw=0.5)
    ax.set_xlabel("Lag (months)")
    ax.set_xticks([-12, 0, 6, 12])

## titles
axs[0].set_title(r"Bjerknes ($R \cdot T$)")
axs[1].set_title(r"Coupling ($F_1 \cdot h$)")

## right panel: legend, y-axis on the right-hand side
axs[1].legend(prop=dict(size=8))
axs[1].yaxis.set_label_position("right")
axs[1].yaxis.tick_right()

src.utils.set_lims(axs)

plt.show()

##### growth vs. coupling

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

## dashed black reference lines
f_kwargs = dict(ls="--", c="k", lw=0.8)

## (parameter, state variable, linestyle, legend label) for each dT/dt term
terms = [
    ("R", T_VAR, "-", "Bjerknes"),
    ("F1", H_VAR, "--", "coupling"),
]

## (composite, parameters, color) for each panel
panels = zip(
    [comp_early, comp_late],
    [params_early, params_late],
    sns.color_palette()[:2],
)

for ax, (comp, params_, c) in zip(axs, panels):

    ## both terms share the panel's color and differ by linestyle
    for v, n, ls, l in terms:
        ax.plot(comp.lag, comp[n] * params_[v], ls=ls, c=c, label=l)

    ## reference lines and ticks
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.axvline(6, c="gray", lw=0.5)
    ax.set_xlabel("Lag (months)")
    ax.set_xticks([-12, 0, 6, 12])

## left panel carries the legend
axs[0].set_title(r"Early")
axs[0].legend(prop=dict(size=8))

## right panel: y-axis on the right-hand side
axs[1].set_title(r"Late")
axs[1].yaxis.set_label_position("right")
axs[1].yaxis.tick_right()

src.utils.set_lims(axs)

plt.show()

### spaghetti

In [None]:
## specify variable to plot
PLOT_VAR = "T_4"
YLABEL = "K"

## plot style for individual events (thin gray lines)
spag_kwargs = dict(c="gray", lw=0.5, alpha=0.5)

## plot style for the sample mean (thick black line); kept under its own name
## so the `comp_kwargs` passed to get_spaghetti is not overwritten
mean_kwargs = dict(c="k", lw=2)

fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")

for ax, spag, ls in zip(axs, [spag_early, spag_late], ["-", "--"]):

    ## plot data against this dataset's own lag coordinate (no dependence on
    ## loop variables left over from earlier cells); no label is set because
    ## this figure has no legend
    ax.plot(spag.lag, spag[PLOT_VAR], **spag_kwargs)
    ax.plot(spag.lag, spag[PLOT_VAR].mean("sample"), ls=ls, **mean_kwargs)

    ## formatting
    f_kwargs = dict(ls="--", c="k", lw=0.8)
    ax.axhline(0, **f_kwargs)
    ax.axvline(0, **f_kwargs)
    ax.set_xticks([-12, 0, 12])
    ax.set_xlabel("Lag (months)")

## plot early composite in late
axs[1].plot(spag_early.lag, spag_early[PLOT_VAR].mean("sample"), ls="-", **mean_kwargs)

## formattings (y-label comes from YLABEL only)
axs[0].set_title(f"{PLOT_VAR} (Early)")
axs[1].set_title(f"{PLOT_VAR} (Late)")
axs[0].set_ylabel(YLABEL)
axs[1].set_yticks([])
src.utils.set_lims(axs)

plt.show()

### Statistics

In [None]:
def get_pdf(x, edges, lag=12):
    """
    Get empirical PDF of `x` at a single lag:
    - x: data with a "lag" coordinate (selected with `x.sel(lag=lag)`)
    - edges: bin edges (any array-like), passed to src.utils.get_empirical_pdf
    - lag: lag value to select

    Returns a DataArray along dimension "centers", whose coordinate holds
    the midpoints of consecutive edges.
    """

    ## accept lists/tuples too: the midpoint arithmetic below needs an array
    edges = np.asarray(edges)

    ## compute pdf (second return value of the helper is not needed)
    pdf, _ = src.utils.get_empirical_pdf(x.sel(lag=lag), edges=edges)

    ## get center points
    bin_centers = 0.5 * (edges[1:] + edges[:-1])

    ## return in DataArray format; dims given explicitly rather than inferred
    ## from the coords dict
    return xr.DataArray(pdf, dims=["centers"], coords=dict(centers=bin_centers))


def plot_pdf_on_ax(ax, name, pdf0, pdf1, edges):
    """
    Plot the "early" and "late" PDFs for the given name on the given ax object:
    - ax: axes to draw on
    - name: key used to look up the PDF in `pdf0`/`pdf1`; also used as title
    - pdf0: PDFs drawn as an outline, labeled "early"
    - pdf1: PDFs drawn filled, labeled "late"
    - edges: bin edges passed to `ax.stairs` for both PDFs

    Returns None; the axes are modified in place.
    """

    ## plot data (uses the `edges` argument, not a module-level variable)
    ax.stairs(pdf0[name], edges, lw=1.5, label="early")
    ax.stairs(pdf1[name], edges, fill=True, alpha=0.3, label="late")

    ## format: vertical line at zero, fixed ticks and y-range
    ax.axvline(0, ls="-", c="k")
    ax.set_title(name)
    ax.set_yticks([])
    ax.set_xticks([-2, 0, 2])
    ax.set_ylim([0, 0.9])

    return

#### Compute PDFs

In [None]:
## specify lag
LAG = 12

## specify edges for pdf
EDGES = np.arange(-3.75, 4.25, 0.5)
# EDGES = np.arange(-4.75, 5.25, 0.5)

## compute PDFs of every variable at +LAG
## (Dataset.map replaces the deprecated Dataset.apply alias)
pdf_early_post = spag_early.map(get_pdf, edges=EDGES, lag=LAG)
pdf_late_post = spag_late.map(get_pdf, edges=EDGES, lag=LAG)

## same, at -LAG
pdf_early_pre = spag_early.map(get_pdf, edges=EDGES, lag=-LAG)
pdf_late_pre = spag_late.map(get_pdf, edges=EDGES, lag=-LAG)

#### Plot

In [None]:
## variables to plot (one column each)
# PLOT_VARS = ["h_w", "h", "h_e"]
PLOT_VARS = ["T_4", "T_34", "T_3"]

## grid of panels: top row uses the -LAG PDFs, bottom row the +LAG PDFs
fig, axs = plt.subplots(2, 3, figsize=(8, 4), layout="constrained")

pdf_kwargs_pre = dict(pdf0=pdf_early_pre, pdf1=pdf_late_pre, edges=EDGES)
pdf_kwargs_post = dict(pdf0=pdf_early_post, pdf1=pdf_late_post, edges=EDGES)

## fill the grid row by row
for ax_row, pdf_kwargs in zip(axs, (pdf_kwargs_pre, pdf_kwargs_post)):
    for ax, name in zip(ax_row, PLOT_VARS):
        plot_pdf_on_ax(ax, name=name, **pdf_kwargs)

## legend on the first panel only
axs[0, 0].legend(prop=dict(size=8))

## top row: no x ticks; bottom row: no titles
for ax in axs[0]:
    ax.set_xticks([])
for ax in axs[1]:
    ax.set_title(None)


plt.show()