# ENSO diversity over time

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import pathlib
import cmocean
import os
import cartopy.crs as ccrs

# Import custom modules
import src.utils

## plotting style: white axes background, grid lines switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution (dots per inch)
mpl.rcParams["figure.dpi"] = 100

## data / output directories, read from the environment (KeyError if unset)
DATA_FP, SAVE_FP = (
    pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP")
)

## Funcs

In [None]:
def get_djf(data):
    """Return December-anchored seasonal means of ``data``.

    ``data`` is resampled along ``time`` into quarters that start in
    December ("QS-DEC") and each quarter is averaged. The result is passed
    through ``src.utils.sel_month(..., months=12)`` (presumably keeping only
    the quarters that start in December, i.e. DJF -- confirm against the
    helper), and the first and last remaining time steps are dropped.
    """

    ## quarterly means, with quarters starting in December
    quarterly_mean = data.resample({"time": "QS-DEC"}).mean()

    ## keep the quarters labelled with month 12
    december_quarters = src.utils.sel_month(quarterly_mean, months=12)

    ## discard the first and last entries
    ## NOTE(review): presumably because the edge seasons are incomplete -- confirm
    return december_quarters.isel(time=slice(1, -1))


def get_trim(data_proj, v="sst", lat_range=(-15, 15)):
    """Re-project ``data_proj[v]`` using only a latitude band.

    The scores ``data_proj[v]`` and the components ``data_proj[f"{v}_comp"]``
    are handed to ``src.utils.reconstruct_fn`` together with a function that
    selects ``latitude`` within ``lat_range``. The trimmed result is then
    multiplied by the components and summed over ``longitude`` and
    ``latitude``.

    Args:
        data_proj: object indexable by ``v`` and ``f"{v}_comp"`` (an xarray
            Dataset in this notebook).
        v: name of the variable to trim.
        lat_range: ``(south, north)`` latitude bounds passed to
            ``slice(south, north)``. Defaults to ``(-15, 15)``, the band the
            code has always used (the previous docstring's "[-5,5]" did not
            match the implementation).

    Returns:
        The trimmed field projected back onto the components, with the
        ``longitude`` and ``latitude`` dimensions summed out.
    """

    lat_min, lat_max = lat_range

    ## reconstruct the field, keeping only the requested latitude band
    data_trim = src.utils.reconstruct_fn(
        scores=data_proj[v],
        components=data_proj[f"{v}_comp"],
        fn=lambda x: x.sel(latitude=slice(lat_min, lat_max)),
    )

    ## project back on EOFs
    data_proj_trim = (data_trim * data_proj[f"{v}_comp"]).sum(["longitude", "latitude"])

    return data_proj_trim


def compute_eofs(data, v="sst"):
    """Compute EOFs of the latitude-trimmed DJF data via an SVD.

    Steps:
      1. Reduce ``data`` to DJF seasonal means with ``get_djf``.
      2. Trim and re-project variable ``v`` with ``get_trim``.
      3. Stack ``member`` and ``time`` into a ``sample`` dimension and take
         the SVD of the resulting (mode, sample) matrix.
      4. Build spatial patterns from the untrimmed DJF scores and scale each
         pattern by its standard deviation over latitudes -5 to 5.

    Args:
        data: dataset holding ``v`` and ``f"{v}_comp"``; ``v`` must carry the
            ``member``, ``time`` and ``mode`` dimensions used below.
        v: name of the variable to decompose.

    Returns:
        xr.Dataset with
          - ``U`` (mode, eof_mode): left singular vectors,
          - ``V`` (sample, eof_mode): right singular vectors,
          - ``s`` (eof_mode): singular values,
          - ``Vs``: ``V`` scaled by ``s``,
          - ``patterns``: normalized spatial patterns,
          - ``exp_var``: fraction of total squared singular value per mode.
    """

    ## subset for djf
    data_djf = get_djf(data)

    ## get trimmed version
    data_djf_trim = get_trim(data_djf, v=v)

    ## stack data
    X = data_djf_trim.stack(sample=["member", "time"]).transpose("mode", "sample")

    ## Do SVD (reduced form: k = min(n_mode, n_sample) singular triplets)
    U, s, Vt = np.linalg.svd(X.values, full_matrices=False)

    ## put results in xr
    ## the eof_mode coordinate is sized from the SVD output rather than
    ## hard-coded, so any number of modes/samples is supported
    eofs = xr.Dataset(
        data_vars=dict(
            U=(("mode", "eof_mode"), U),
            V=(("sample", "eof_mode"), Vt.T),
            s=("eof_mode", s),
        ),
        coords=dict(
            mode=X.mode,
            sample=X.sample,
            eof_mode=np.arange(s.size),
        ),
    )

    ## get equally-weighted projection
    eofs["Vs"] = eofs["V"] * eofs["s"]

    ## get "full" spatial patterns: project the untrimmed DJF scores onto the
    ## right singular vectors, then map back to physical space
    U_full = (data_djf[v].stack(sample=["member", "time"]) * eofs["V"]).sum(
        "sample"
    )
    eofs["patterns"] = src.utils.reconstruct_fn(
        scores=U_full,
        components=data[f"{v}_comp"],
        fn=lambda x: x,
    )

    ## normalize spatial patterns by their spread over latitudes -5 to 5
    scale = eofs["patterns"].sel(latitude=slice(-5, 5)).std(["latitude", "longitude"])
    eofs["patterns"] = eofs["patterns"] / scale

    ## get explained variance
    eofs["exp_var"] = eofs["s"] ** 2 / (eofs["s"] ** 2).sum()

    return eofs

## Load data

In [None]:
## Load the consolidated data; only the second returned item is used
## (named `anom`, presumably anomalies -- confirm against src.utils)
_, anom = src.utils.load_consolidated()

## select the variable to analyze together with its "_comp" companion
## (currently "pr"; swap in the commented line below to analyze "sst")
# data = anom[["sst", "sst_comp"]]
data = anom[["pr", "pr_comp"]]

## load the indices; later cells read "T_3" and "h_w" from this object
Th = src.utils.load_cesm_indices()

### Subset for early/late

In [None]:
## time windows for the two periods being compared
t_early = {"time": slice("1850", "1881")}
t_late = {"time": slice("2069", "2100")}

## index subsets for each period
Th_early = Th.sel(t_early)
Th_late = Th.sel(t_late)

## data subsets for each period (.compute() evaluates them up front)
data_early = data.sel(t_early).compute()
data_late = data.sel(t_late).compute()

## Compute EOFs

In [None]:
## EOF decomposition of "pr" for each period, computed independently
eofs_early, eofs_late = (
    compute_eofs(period_data, v="pr") for period_data in (data_early, data_late)
)

## Plot

### Plot variance

In [None]:
## specify markers/labels
markers = ["o", "x"]
labels = ["early", "late"]

fig, ax = plt.subplots(figsize=(3, 2.5))

## explained-variance fraction per EOF mode, one marker style per period
for eofs, marker, label in zip([eofs_early, eofs_late], markers, labels):
    ax.scatter(eofs.eof_mode, eofs["exp_var"], marker=marker, label=label)

## format: show only the leading modes (set once, not per iteration)
ax.set_xlim([-0.5, 5.5])

## draw the legend so the "early"/"late" labels are actually shown
ax.legend()

plt.show()

In [None]:
## project the "pr" scores of each period onto the early-period U vectors
## NOTE(review): the late projection also uses eofs_early["U"]; this looks
## like a deliberate common basis -- confirm it is not a copy/paste slip
proj_early, proj_late = (
    (eofs_early["U"] * period_data["pr"]).sum("mode")
    for period_data in (data_early, data_late)
)

In [None]:
## display the early-period projection as the cell output
proj_early

### Plot PC space

In [None]:
## style of the dashed zero lines
kwargs = {"ls": "--", "lw": 0.8, "c": "k"}

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 1.5), layout="constrained")

## one panel per period: scaled scores of eof_mode 0 (x) against eof_mode 1 (y)
for ax, eofs in zip(axs, (eofs_early, eofs_late)):
    ax.scatter(
        eofs["Vs"].isel(eof_mode=0),
        eofs["Vs"].isel(eof_mode=1),
        s=1,
    )
    ax.set_aspect("equal")

    ## reference lines through the origin
    ax.axhline(0, **kwargs)
    ax.axvline(0, **kwargs)

## limits are set by a project helper (presumably shared across panels)
src.utils.set_lims(axs)

plt.show()

### Compare to ($T,h$) space

In [None]:
## reduction applied to each index before plotting
## (alternative: sel = lambda x : src.utils.sel_month(x, months=12))
sel = get_djf

## style of the dashed zero lines
kwargs = {"ls": "--", "lw": 0.8, "c": "k"}

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(5.5, 2.5), layout="constrained")

## one panel per period: "T_3" on the x-axis against "h_w" on the y-axis
for ax, Th_ in zip(axs, (Th_early, Th_late)):
    ax.scatter(sel(Th_["T_3"]), sel(Th_["h_w"]), s=1)
    # ax.scatter(sel(Th_["T_3"]), sel(Th_["h"]-Th_["h_w"]), s=1)

    ## reference lines through the origin
    ax.axhline(0, **kwargs)
    ax.axvline(0, **kwargs)

## limits are set by a project helper (presumably shared across panels)
src.utils.set_lims(axs)

plt.show()

### Plot spatial pattern

#### plot pair of modes

In [None]:
## which period's EOFs to map (alternative: PLOT_EOFS = eofs_early)
PLOT_EOFS = eofs_late

## two stacked map panels built by the project helper
fig = plt.figure(layout="constrained", figsize=(6, 10 / 3))
axs = src.utils.subplots_with_proj(
    fig,
    nrows=2,
    ncols=1,
    format_func=src.utils.plot_setup_pac,
)

## row j shows the normalized pattern of EOF mode i (modes 0 and 1)
for j, i in enumerate((0, 1)):
    cp = axs[j, 0].contourf(
        PLOT_EOFS.longitude,
        PLOT_EOFS.latitude,
        PLOT_EOFS["patterns"].isel(eof_mode=i),
        levels=src.utils.make_cb_range(3, 0.3),
        cmap="cmo.balance",
        extend="both",
        transform=ccrs.PlateCarree(),
    )

## single colorbar shared by both rows (taken from the last contour set)
cb = fig.colorbar(cp, ax=axs[:, 0])

#### Corresponding modes for different EOFs

In [None]:
## index of the EOF mode compared between the two periods
MODE = 0

## two stacked map panels built by the project helper
fig = plt.figure(layout="constrained", figsize=(6, 10 / 3))
axs = src.utils.subplots_with_proj(
    fig,
    nrows=2,
    ncols=1,
    format_func=src.utils.plot_setup_pac,
)

## row 0: early-period pattern; row 1: late-period pattern
for j, eofs in enumerate((eofs_early, eofs_late)):
    cp = axs[j, 0].contourf(
        eofs.longitude,
        eofs.latitude,
        eofs["patterns"].isel(eof_mode=MODE),
        levels=src.utils.make_cb_range(3, 0.3),
        cmap="cmo.balance",
        extend="both",
        transform=ccrs.PlateCarree(),
    )

## single colorbar shared by both rows (taken from the last contour set)
cb = fig.colorbar(cp, ax=axs[:, 0])