# Multiyear

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## plotting style: white axes background and no grid lines
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## resolution of rendered figures
mpl.rcParams.update({"figure.dpi": 100})

## input-data and output directories are taken from environment variables
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
## helper function for padding array
pad = lambda x, n: np.pad(x, pad_width=(0, n), constant_values=False)


def get_multiyr_idx(bool_arr, n):
    """Return the indices spanned by ``n``-year events starting where ``bool_arr`` is True.

    For every start index ``i`` (taken along the last axis of ``bool_arr``),
    the indices ``i, i + 1, ..., i + n - 1`` are emitted, grouped by start in
    ascending order. Overlapping events produce duplicate indices, and indices
    may run past the end of ``bool_arr``.

    Args:
        bool_arr: boolean array marking event start points.
        n: event length (number of consecutive indices per event).

    Returns:
        1-D integer array. It is empty, but still integer-typed, when there
        are no events, so the result can always be used for indexing.
    """
    ## indices of event starts
    start_idx = np.where(bool_arr)[-1]

    ## broadcast starts (column) against offsets (row); row-major ravel keeps
    ## the "start-by-start" ordering of the original nested loop
    return (start_idx[:, None] + np.arange(n)).ravel()


def get_multiyr_mask(mask_raw, n):
    """Expand a mask of event start points into a mask of all event years.

    Every True entry at index ``i`` in ``mask_raw`` marks the start of an
    ``n``-year event; the returned mask is True at ``i, ..., i + n - 1``.
    Years that would fall past the end of the record are dropped (the event
    is truncated) instead of raising ``IndexError``.

    Args:
        mask_raw: 1-D boolean array of event start points.
        n: event length in years.

    Returns:
        Array with the same shape and dtype as ``mask_raw``.
    """
    ## indices of event starts
    start_idx = np.where(mask_raw)[-1]

    ## every index covered by an event (integer-typed even when empty)
    idx = (start_idx[:, None] + np.arange(n)).ravel()

    ## fill out new mask, ignoring indices beyond the end of the record
    mask = np.zeros_like(mask_raw)
    mask[idx[idx < len(mask)]] = True

    return mask


def get_events_per_year(mask, n):
    """Weight each event year by ``1 / n``.

    A year belonging to an ``n``-year event counts as ``1 / n`` of an event,
    so summing the result over time yields a number of events.

    Args:
        mask: boolean array, True where an event is ongoing.
        n: length of the event in years.

    Returns:
        Float array equal to ``1 / n`` where ``mask`` is True and 0 elsewhere.
    """
    as_float = mask.astype(float)
    return as_float / n


def mask_multiyear_events(x):
    """Split a yearly event mask into single-, double- and triple-year events.

    Each year flagged in ``x`` is assigned to exactly one category, with the
    longest matching run winning. Runs of three or more consecutive years are
    all assigned to the triple-year row. Each year is weighted by ``1 / k``
    (k = 1, 2, 3), so summing a row over time counts events. This is exact for
    runs of length k; a longer run of length L contributes L / 3 to row 2.
    To-do: generalise to events of up to 'n' years.

    Args:
        x: 1-D boolean sequence, one entry per year; True where an event occurs.
            It is not modified.

    Returns:
        Float array of shape (3, len(x)). Rows are single-, double- and
        triple-year events.
    """
    ## boolean copy so the caller's array is never mutated
    single_mask = np.array(x, dtype=bool)
    n_yrs = len(single_mask)

    ## flag start years of runs of >= 2 and >= 3 consecutive events
    double_mask = single_mask[:-1] & single_mask[1:]
    triple_mask = double_mask[:-1] & single_mask[2:]

    ## pad back to full length (padding by the actual shortfall keeps the
    ## shapes consistent even for records shorter than 3 years)
    double_mask = pad(double_mask, n_yrs - len(double_mask))
    triple_mask = pad(triple_mask, n_yrs - len(triple_mask))

    ## expand start years into every year covered by each event
    double_mask = get_multiyr_mask(double_mask, n=2)
    triple_mask = get_multiyr_mask(triple_mask, n=3)

    ## remove duplicates: longer events take precedence
    double_mask[triple_mask] = False
    single_mask[triple_mask | double_mask] = False

    ## convert to events per year and stack: shape (3, n_yrs)
    return np.stack(
        [
            get_events_per_year(single_mask, 1),
            get_events_per_year(double_mask, 2),
            get_events_per_year(triple_mask, 3),
        ],
        axis=0,
    )


def mask_multiyear_events_geng(n34, month_idx, thresh, is_warm):
    """Mask wintertime multiyear ENSO events, following Geng et al. (2023).

    Args:
        n34: 1-D monthly Niño 3.4 index.
        month_idx: zero-based month of each entry of ``n34`` (11 = December).
        thresh: event threshold, either a scalar or an array with one value
            per December. If None, half the standard deviation of the first
            80 winter values is used.
        is_warm: if True, mask El Niños (index > thresh); otherwise mask
            La Niñas (index < -thresh).

    Returns:
        Array of shape (3, n_winters) from ``mask_multiyear_events``. Rows are
        single-, double- and triple-year events.
    """
    ## 5-month running mean (ONDJF when centered on December). mode="same"
    ## implicitly zero-pads the ends, so divide by the number of real months
    ## in each window rather than by 5; otherwise the first/last two months
    ## (e.g. a final December) are biased toward zero.
    kernel = np.ones(5)
    month_sum = np.convolve(n34, kernel, mode="same")
    month_count = np.convolve(np.ones(len(n34)), kernel, mode="same")
    n34_rolling = month_sum / month_count

    ## winter index: running mean sampled at each December
    is_winter = month_idx == 11
    n34_winter = n34_rolling[is_winter]

    ## default threshold: half the std dev of the first 80 winters
    ## NOTE(review): presumably a fixed reference period; confirm
    if thresh is None:
        thresh = n34_winter[:80].std() * 0.5

    ## find El Niños / La Niñas relative to the threshold
    if is_warm:
        is_event = n34_winter > thresh
    else:
        is_event = n34_winter < -thresh

    ## classify events by how many consecutive winters they last
    return mask_multiyear_events(is_event)


def count_multiyear_laninas(n34, month_idx, thresh=None, is_warm=False, window=41):
    """Count the frequency of multiyear ENSO events, following Geng et al. (2023).

    Despite the name, El Niños are counted instead when ``is_warm`` is True.

    Args:
        n34: monthly Niño 3.4 index.
        month_idx: zero-based month of each entry of ``n34`` (11 = December).
        thresh: event threshold passed to ``mask_multiyear_events_geng``;
            None lets it derive one from the data.
        is_warm: count El Niños if True, La Niñas otherwise.
        window: length of the running window, in winters (years).

    Returns:
        Array of shape (3, n_winters - window + 1). For single-, double- and
        triple-year events, it gives the count in each running window,
        rescaled to events per 100 years.

    Raises:
        ValueError: if the record has fewer winters than ``window``.
    """
    ## find multiyear events: shape (3, n_winters)
    peak_mask_winter = mask_multiyear_events_geng(
        n34=n34, month_idx=month_idx, thresh=thresh, is_warm=is_warm
    )

    ## np.convolve(mode="valid") silently swaps its inputs when the record is
    ## shorter than the window, which would return meaningless counts
    n_winters = peak_mask_winter.shape[-1]
    if n_winters < window:
        raise ValueError(
            f"need at least {window} winters to count events, got {n_winters}"
        )

    ## running sum over the window, rescaled to events per century
    conv_filter = 100 / window * np.ones(window)

    def count_single(x):
        """Running per-century count for one event-length row."""
        return np.convolve(conv_filter, x, mode="valid")

    return np.apply_along_axis(func1d=count_single, axis=1, arr=peak_mask_winter)


def count_multiyear_laninas_ensemble(T, is_warm=False, thresh=0.57):
    """Apply ``count_multiyear_laninas`` to every ensemble member of ``T``.

    Args:
        T: Niño 3.4 index with ``time`` and ``member`` coordinates (assumed to
            be an xarray DataArray, since ``.sel`` and ``.time.dt`` are used).
        is_warm: count El Niños if True, La Niñas otherwise.
        thresh: event threshold passed to each member's count; None makes
            each member derive its own.

    Returns:
        The per-member count arrays stacked along a new axis 2.

    To-do: compute the threshold using all ensemble members.
    """
    ## zero-based calendar month (0 = Jan, ..., 11 = Dec), shared by all members
    months = T.time.dt.month.values - 1

    per_member = [
        count_multiyear_laninas(
            n34=T.sel(member=member).values,
            month_idx=months,
            thresh=thresh,
            is_warm=is_warm,
        )
        for member in T.member
    ]

    return np.stack(per_member, axis=2)

## Load data

In [None]:
## load the CESM climate-mode indices and shorten the long index names
Th = src.utils.load_cesm_indices().rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

In [None]:
## choose between a constant threshold and one that drifts linearly in time
use_fixed_thresh = True

## calendar year of each December in the record (one entry per winter)
t34_time = Th["T_34"].time
winter_yrs = t34_time.dt.year.values[t34_time.dt.month.values == 12]

if use_fixed_thresh:
    thresh = 0.57

else:
    ## one threshold per winter; deriving the years from the data keeps the
    ## array length equal to the number of winters
    thresh = 0.62 + 0.0018 * (winter_yrs - winter_yrs.mean())

## count multiyear events: shape (3 event lengths, n_windows, n_members)
counts = count_multiyear_laninas_ensemble(T=Th["T_34"], is_warm=False, thresh=thresh)

## time axis for plot: centre year of each running window, derived from the
## data (rather than hard-coded) so it stays aligned with `counts`
n_windows = counts.shape[1]
n_trim = (len(winter_yrs) - n_windows) // 2
yrs = winter_yrs[n_trim : n_trim + n_windows]

## set up the plot
fig, ax = plt.subplots(figsize=(4, 3))

## one ensemble-mean curve per event length
## don't plot last 2 years (because we can't count triple year events here)
for i, label in enumerate(["1 yr", "2 yr", "3 yr"]):
    ax.plot(yrs[:-2], counts[i].mean(-1)[:-2], label=label)

## label
ax.legend(prop={"size": 8})
ax.set_ylabel("# per century")
ax.set_ylim([0, 25])
ax.set_yticks([0, 10, 20])

plt.show()