# ENSO diversity over time

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## Plotting style: white axes background, no grid lines
_style_rc = {"axes.facecolor": "white", "axes.grid": False}
sns.set(rc=_style_rc)

## Figure resolution (dots per inch) for rendered plots
mpl.rcParams.update({"figure.dpi": 100})

## Input-data and output directories come from environment variables
## (a missing variable raises KeyError)
DATA_FP, SAVE_FP = (
    pathlib.Path(os.environ[env_name]) for env_name in ("DATA_FP", "SAVE_FP")
)

## Funcs

## Load data

In [None]:
## Load the consolidated data; only the second (anomaly) output is used
_, anom = src.utils.load_consolidated()

## Keep the SST variables: "sst" (time-varying, has a "mode" dim used later)
## and "sst_comp" (passed to reconstruct_fn later as `components`)
data = anom[["sst", "sst_comp"]]

## Early period: string slice on the time index, 1850 through 1881 inclusive
t_early = dict(time=slice("1850", "1881"))
data_early = data.sel(t_early).compute()

## Seasonal means in quarterly bins anchored on December (Dec-Jan-Feb, ...),
## then keep the December-starting (DJF) bins
sst_seasonal = data_early["sst"].resample({"time": "QS-DEC"}).mean()
data_djf = src.utils.sel_month(sst_seasonal, months=12)

## Drop the first and last DJF values; given the 1850-1881 window these bins
## are presumably incomplete seasons at the period edges
data_djf = data_djf.isel(time=slice(1, -1))

In [None]:
# src.utils.reconstruct_fn(
#     scores=data_djf, components=data["sst_comp"], fn=lambda x : x,
# )

## Compute EOFs

In [None]:
## Stack ensemble members and DJF times into a single "sample" dimension
X = data_djf.stack(sample=["member", "time"])

## Thin SVD of the (mode x sample) matrix:
##   U  : (n_mode, k)   EOFs expressed in the "mode" basis of the data
##   s  : (k,)          singular values
##   Vt : (k, n_sample) unit-norm principal-component series
## with k = min(n_mode, n_sample).
## NOTE(review): no mean is removed before the SVD; this assumes the DJF
## subset of `anom` is already (close to) centered -- confirm.
U, s, Vt = np.linalg.svd(X.transpose("mode", "sample").values, full_matrices=False)

## put results in xr; the number of EOFs is taken from the SVD output
## rather than hard-coded, since k depends on the data sizes
eofs = xr.Dataset(
    data_vars=dict(
        U=(("mode", "eof_mode"), U),
        V=(("sample", "eof_mode"), Vt.T),
        s=("eof_mode", s),
    ),
    coords=dict(
        mode=X.mode,
        sample=X.sample,
        eof_mode=np.arange(s.size),
    ),
)

## get spatial patterns: presumably maps each EOF from the "mode" basis
## onto the latitude/longitude grid using the "sst_comp" components
eofs["patterns"] = src.utils.reconstruct_fn(
    scores=eofs["U"],
    components=data["sst_comp"],
    fn=lambda x: x,
)

## normalize each spatial pattern to unit spatial standard deviation
## (only "patterns" is rescaled; "U", "V" and "s" are unchanged)
scale = eofs["patterns"].std(["latitude", "longitude"])
eofs["patterns"] = eofs["patterns"] / scale

### Plot PC space (PC 1 vs. PC 2)

In [None]:
## First two PC series (eof_mode 0 and 1), one point per member/DJF sample
pc_first = eofs["V"].isel(eof_mode=0)
pc_second = eofs["V"].isel(eof_mode=1)

fig, ax = plt.subplots(figsize=(3, 3))
ax.scatter(pc_first, pc_second, s=1)
plt.show()

### Plot spatial patterns

In [None]:
import cartopy.crs as ccrs

## Two stacked map panels (Pacific formatting) that share one colour scale
fig = plt.figure(figsize=(6, 10 / 3), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=2, ncols=1, format_func=src.utils.plot_setup_pac
)

## contour levels are identical for every panel, so compute them once
cb_levels = src.utils.make_cb_range(4, 0.4)

## NOTE(review): plots eof_mode 0 and 2 (mode 1 is skipped) -- confirm intended.
for j, i in enumerate([0, 2]):
    cp = axs[j, 0].contourf(
        eofs.longitude,
        eofs.latitude,
        ## contourf(x=lon, y=lat, Z) needs Z shaped (lat, lon); make the order
        ## explicit instead of relying on reconstruct_fn's output dim order
        eofs["patterns"].isel(eof_mode=i).transpose("latitude", "longitude"),
        transform=ccrs.PlateCarree(),
        cmap="cmo.balance",
        levels=cb_levels,
        extend="both",
    )

## Both panels share the same levels, so attach one colorbar to all axes;
## without `ax`, the colorbar is tied to a single panel and shrinks only it
cb = fig.colorbar(cp, ax=axs)
plt.show()