# ENSO diversity over time

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import pathlib
import cmocean
import os
import cartopy.crs as ccrs

# Import custom modules
import src.utils

## plotting style: white axes background and no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## render figures at 100 DPI
mpl.rcParams["figure.dpi"] = 100

## input / output directories, read from environment variables
## (a missing variable raises KeyError)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[var]) for var in ("DATA_FP", "SAVE_FP"))

## Helper functions

In [None]:
def get_djf(data):
    """Return December-January-February (DJF) seasonal means of ``data``.

    ``data`` is averaged into 3-month bins anchored on December
    (``"QS-DEC"`` resampling along ``time``). ``src.utils.sel_month`` is then
    applied with ``months=12``, presumably keeping only the bins labelled
    December, i.e. the DJF seasons. The first and last remaining seasons are
    discarded.
    """

    seasonal_means = data.resample({"time": "QS-DEC"}).mean()
    djf_means = src.utils.sel_month(seasonal_means, months=12)

    # drop the first and last DJF seasons; presumably these are only partly
    # covered by the record (TODO confirm against the input time span)
    return djf_means.isel(time=slice(1, -1))


def get_trim(data_proj, *, lat_range=(-5, 5)):
    """Re-project SST scores onto their components over a latitude band.

    The SST field is reconstructed from ``data_proj["sst"]`` (scores) and
    ``data_proj["sst_comp"]`` (components) with ``src.utils.reconstruct_fn``,
    keeping only latitudes inside ``lat_range``. It is then projected back
    onto the same components by summing over longitude and latitude.

    No seasonal subsetting happens here. In this notebook the function is
    called with DJF means produced by ``get_djf``.

    Parameters
    ----------
    data_proj : xr.Dataset
        Must contain ``"sst"`` (scores) and ``"sst_comp"`` (components with
        ``latitude`` / ``longitude`` dims).
    lat_range : tuple of float, default (-5, 5)
        ``(lower, upper)`` latitude bounds, used as ``slice(lower, upper)``.
        NOTE(review): assumes ``latitude`` is ascending; with a descending
        coordinate the slice would be empty. Confirm.

    Returns
    -------
    xr.DataArray
        Projection coefficients of the band-limited field. The summed
        ``longitude`` / ``latitude`` dims are removed.
    """

    lat_slice = slice(*lat_range)

    ## reconstruct the field, restricted to the latitude band
    data_trim = src.utils.reconstruct_fn(
        scores=data_proj["sst"],
        components=data_proj["sst_comp"],
        fn=lambda x: x.sel(latitude=lat_slice),
    )

    ## project back on EOFs. xarray arithmetic uses an inner join, so the
    ## full-latitude components are aligned to the trimmed latitudes and the
    ## sum covers only the band (assuming reconstruct_fn keeps the coords).
    data_proj_trim = (data_trim * data_proj["sst_comp"]).sum(["longitude", "latitude"])

    return data_proj_trim


def compute_eofs(data):
    """Compute EOFs of DJF SST in the [-5, 5] latitude band.

    Steps:

    - take DJF seasonal means (``get_djf``);
    - re-project them onto the SST components using only [-5, 5] latitude
      (``get_trim``);
    - run a thin SVD of the resulting (mode x sample) matrix, where the
      samples are the stacked ``member`` / ``time`` pairs.

    Parameters
    ----------
    data : xr.Dataset
        Must contain ``"sst"`` (scores with ``member`` and ``time`` dims) and
        ``"sst_comp"`` (spatial components).

    Returns
    -------
    xr.Dataset
        ``U`` (mode, eof_mode)
            Left singular vectors (EOFs in component space).
        ``V`` (sample, eof_mode)
            Right singular vectors (unit-norm PCs).
        ``s`` (eof_mode)
            Singular values.
        ``Vs``
            Projection of each sample onto ``U``, equal to ``V * s``.
        ``patterns``
            Full-field spatial patterns, normalized to unit std over
            [-5, 5] latitude.
        ``exp_var``
            ``s**2`` as a fraction of its sum, i.e. the share of the trimmed
            data's total squared norm carried by each EOF.
    """

    ## subset for djf
    data_djf = get_djf(data)

    ## re-project onto the components using only the [-5, 5] band
    data_djf_trim = get_trim(data_djf)

    ## stack ensemble members and seasons into a single sample dimension
    X = data_djf_trim.stack(sample=["member", "time"]).transpose("mode", "sample")

    ## thin SVD: U is (mode, k), s is (k,), Vt is (k, sample),
    ## with k = min(n_mode, n_sample)
    U, s, Vt = np.linalg.svd(X.values, full_matrices=False)

    ## put results in xr
    eofs = xr.Dataset(
        data_vars=dict(
            U=(("mode", "eof_mode"), U),
            V=(("sample", "eof_mode"), Vt.T),
            s=("eof_mode", s),
        ),
        coords=dict(
            mode=X.mode,
            sample=X.sample,
            # size the EOF axis from the SVD output; a hard-coded length
            # conflicts with the data whenever k differs from it
            eof_mode=np.arange(s.size),
        ),
    )

    ## projection of the samples onto the EOFs: X^T U = V * s
    eofs["Vs"] = eofs["V"] * eofs["s"]

    ## project the full (untrimmed) DJF scores onto the unit-norm PCs, then
    ## reconstruct the corresponding spatial patterns
    U_full = (data_djf["sst"].stack(sample=["member", "time"]) * eofs["V"]).sum(
        "sample"
    )
    eofs["patterns"] = src.utils.reconstruct_fn(
        scores=U_full,
        components=data["sst_comp"],
        fn=lambda x: x,
    )

    ## normalize spatial patterns to unit std over [-5, 5] latitude
    scale = eofs["patterns"].sel(latitude=slice(-5, 5)).std(["latitude", "longitude"])
    eofs["patterns"] = eofs["patterns"] / scale

    ## get explained variance (fraction of sum of squared singular values)
    eofs["exp_var"] = eofs["s"] ** 2 / (eofs["s"] ** 2).sum()

    return eofs

## Load data

In [None]:
## Load data
_, anom = src.utils.load_consolidated()

## subset for SST (scores "sst" and spatial components "sst_comp")
data = anom[["sst", "sst_comp"]]

## specify early/late periods (32 calendar years each)
t_early = dict(time=slice("1850", "1881"))
t_late = dict(time=slice("2069", "2100"))

## subset for the early period; this must be t_early because every
## downstream result built from it is labelled "early"
data_early = data.sel(t_early).compute()

## Compute EOFs

In [None]:
## EOF decomposition of the data in `data_early`
eofs_early = compute_eofs(data=data_early)

## Plot

### Plot PC space

In [None]:
## first two PCs for every member/season sample; the second is sign-flipped
## (SVD leaves each mode's sign arbitrary)
pc1 = eofs_early["V"].isel(eof_mode=0)
pc2 = -eofs_early["V"].isel(eof_mode=1)

fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(pc1, pc2, s=1)
plt.show()

### Plot spatial pattern

In [None]:
## plot the first two normalized EOF spatial patterns, one map per row
fig = plt.figure(figsize=(6, 10 / 3), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=2, ncols=1, format_func=src.utils.plot_setup_pac
)

## identical contour levels in both panels, so a single colorbar applies to both
levels = src.utils.make_cb_range(3, 0.3)

for mode_idx in range(2):
    cp = axs[mode_idx, 0].contourf(
        eofs_early.longitude,
        eofs_early.latitude,
        eofs_early["patterns"].isel(eof_mode=mode_idx),
        transform=ccrs.PlateCarree(),
        cmap="cmo.balance",
        levels=levels,
        extend="both",
    )

## attach the shared colorbar to every panel (not only the last one drawn)
## so it spans the whole column; axs is indexed as a 2-D array above, which
## Figure.colorbar accepts for `ax`
cb = fig.colorbar(cp, ax=axs)