# Multiyear

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## plotting style: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## higher default figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## data / output directories come from environment variables
## (a KeyError is raised if either variable is unset)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## shared (unseeded) random number generator
RNG = np.random.default_rng()

## Funcs

In [None]:
def pad(x, n):
    """Right-pad a 1-D array with ``n`` False (zero) entries.

    Parameters
    ----------
    x : array-like, 1-D
        Array to pad (typically a boolean mask).
    n : int
        Number of entries appended at the end.

    Returns
    -------
    numpy.ndarray
        ``x`` followed by ``n`` entries equal to False (0 for numeric dtypes).
    """
    return np.pad(x, pad_width=(0, n), constant_values=False)


def get_multiyr_idx(bool_arr, n):
    """Return the indices of all years covered by ``n``-year events.

    Parameters
    ----------
    bool_arr : array-like of bool
        True where an event starts. Indices are taken along the last axis,
        so this is intended for 1-D masks.
    n : int
        Event length in years.

    Returns
    -------
    numpy.ndarray of int
        For each start index ``i`` (in increasing order), the indices
        ``i, i + 1, ..., i + n - 1``. An event starting within ``n - 1`` of
        the end yields indices past the array bounds; callers must handle
        that. The result is always an integer array, so it can be used for
        indexing even when there are no events. The previous implementation
        returned an empty float array in that case.
    """
    ## indices of event starts
    start_idx = np.where(bool_arr)[-1]

    ## (n_start, 1) + (n,) -> (n_start, n), flattened row-major so each
    ## event's indices stay contiguous and in start order
    return (start_idx[:, None] + np.arange(n)).ravel()


def get_multiyr_mask(mask_raw, n):
    """Expand a mask of event starts into a mask of every year covered by an
    ``n``-year event.

    Parameters
    ----------
    mask_raw : array-like of bool, 1-D
        True where an ``n``-year event starts.
    n : int
        Event length in years.

    Returns
    -------
    numpy.ndarray
        Same shape and dtype as ``mask_raw``. It is True at each start index
        ``i`` and at ``i + 1, ..., i + n - 1``. Indices that would run past
        the end of the record are dropped (the event is truncated) instead of
        raising ``IndexError``.
    """
    ## indices of event starts
    start_idx = np.where(mask_raw)[-1]

    ## every index covered by an event: start + offset, offset in [0, n);
    ## integer dtype even when there are no events
    idx = (start_idx[:, None] + np.arange(n)).ravel()

    ## fill out new mask, ignoring indices beyond the end of the record
    mask = np.zeros_like(mask_raw)
    mask[idx[idx < len(mask)]] = True

    return mask


def get_events_per_year(mask, n):
    """Weight an 'event in progress' mask by the event length.

    Every True year of an ``n``-year event becomes ``1 / n`` (False becomes
    0.0), so the years of one complete event add up to one event.
    """
    as_float = mask.astype(float)
    return as_float / n


def mask_multiyear_events(x):
    """Split a yearly event mask into single-, double- and triple-year events.

    Parameters
    ----------
    x : array-like of bool, 1-D
        True for each year (winter) that meets the event criterion.

    Returns
    -------
    numpy.ndarray of float, shape (3, len(x))
        Row k holds the contribution of (k+1)-year events to an events-per-year
        count. It is ``1 / (k+1)`` in each year that belongs to such an event
        and 0 elsewhere, so one complete event of up to three years sums to 1.
        Each year is attributed to the longest event (up to three years) that
        contains it. A run longer than three years is covered by overlapping
        triple-year windows.

    To-do: generalize to events up to 'n' years.
    """
    ## boolean copy so the caller's array is never mutated
    single_mask = np.array(x, dtype=bool)
    n_years = len(single_mask)

    ## starts of 2- and 3-year runs (this year and the following 1 / 2 years)
    double_mask = single_mask[:-1] & single_mask[1:]
    triple_mask = double_mask[:-1] & single_mask[2:]

    ## pad back to the full record length. Padding *to* n_years (rather than
    ## by a fixed 1 / 2 entries) keeps all masks the same length even for
    ## records shorter than two years, which previously raised IndexError.
    double_mask = pad(double_mask, n_years - len(double_mask))
    triple_mask = pad(triple_mask, n_years - len(triple_mask))

    ## expand run starts to every year covered by the run
    double_mask = get_multiyr_mask(double_mask, n=2)
    triple_mask = get_multiyr_mask(triple_mask, n=3)

    ## remove duplicates: each year counts only toward its longest event
    double_mask[triple_mask] = False
    single_mask[triple_mask | double_mask] = False

    ## convert to (fractional) events per year and stack as (3, n_years).
    ## NOTE: a run cut off by the end of the record is counted as a shorter
    ## event, which slightly biases toward single-year events.
    return np.stack(
        [
            get_events_per_year(single_mask, 1),
            get_events_per_year(double_mask, 2),
            get_events_per_year(triple_mask, 3),
        ],
        axis=0,
    )


def mask_multiyear_events_geng(n34, month_idx, thresh, is_warm):
    """Mask wintertime single/multiyear ENSO events, following Geng et al. (2023).

    Parameters
    ----------
    n34 : 1-D array
        Monthly Niño 3.4 index.
    month_idx : 1-D int array, same length as ``n34``
        Zero-based calendar month of each sample (11 = December).
    thresh : float, array or None
        Event threshold. If None, half the standard deviation of the first
        80 winter values is used.
    is_warm : bool
        True selects El Niño (index > thresh); False selects La Niña
        (index < -thresh).

    Returns
    -------
    numpy.ndarray, shape (3, n_winters)
        Output of ``mask_multiyear_events`` applied to the winter event mask.
    """
    ## 5-month centered running mean; sampled in December this is the ONDJF
    ## mean. (mode="same" zero-pads, so the first/last two months are damped.)
    smoothed = np.convolve(n34, np.full(5, 1 / 5), mode="same")

    ## one value per winter (December-centered)
    n34_winter = smoothed[month_idx == 11]

    ## default threshold: 1/2 std dev of the first 80 winters
    if thresh is None:
        thresh = 0.5 * n34_winter[:80].std()

    ## El Niño or La Niña winters
    is_event = (n34_winter > thresh) if is_warm else (n34_winter < -thresh)

    return mask_multiyear_events(is_event)


def count_multiyear_laninas(n34, month_idx, thresh=None, is_warm=False, window=41):
    """Count single/multiyear ENSO events in a running window, following
    Geng et al. (2023).

    Parameters
    ----------
    n34, month_idx, thresh, is_warm
        Passed to ``mask_multiyear_events_geng``.
    window : int
        Running-window length in winters (years). Must be between 1 and the
        number of winters in the record.

    Returns
    -------
    numpy.ndarray, shape (3, n_winters - window + 1)
        Events per century for 1-, 2- and 3-year events (rows), one column per
        window position.

    Raises
    ------
    ValueError
        If ``window`` is not between 1 and the number of winters. Previously a
        window longer than the record silently returned partial sums, because
        ``np.convolve`` swaps its arguments when the second is longer.
    """
    ## find single/multiyear events; shape (3, n_winters)
    peak_mask_winter = mask_multiyear_events_geng(
        n34=n34, month_idx=month_idx, thresh=thresh, is_warm=is_warm
    )

    ## guard against a window that does not fit in the record
    n_winters = peak_mask_winter.shape[-1]
    if not 1 <= window <= n_winters:
        raise ValueError(
            f"window={window} must be between 1 and the number of winters ({n_winters})"
        )

    ## running sum over `window` winters, scaled to a rate per century
    conv_filter = np.full(window, 100 / window)

    def count_single(x):
        """Windowed events-per-century for one event-length row."""
        return np.convolve(conv_filter, x, mode="valid")

    return np.apply_along_axis(func1d=count_single, axis=1, arr=peak_mask_winter)


def count_multiyear_laninas_ensemble(T, is_warm=False, thresh=0.57, window=41):
    """Count single/multiyear events separately in every ensemble member.

    Parameters
    ----------
    T : xarray.DataArray
        Monthly index with ``time`` and ``member`` dimensions.
    is_warm, thresh, window
        Passed through to ``count_multiyear_laninas``.

    Returns
    -------
    numpy.ndarray, shape (3, n_windows, n_member)
        Per-member results stacked along the last axis, in ``T.member`` order.

    To-do: compute the threshold using all ensemble members.
    """
    ## zero-based month (0 = Jan), shared by all members
    month_idx = T.time.dt.month.values - 1

    per_member = [
        count_multiyear_laninas(
            n34=T.sel(member=m).values,
            month_idx=month_idx,
            thresh=thresh,
            is_warm=is_warm,
            window=window,
        )
        for m in T.member
    ]

    return np.stack(per_member, axis=2)


def get_rolling_std(data, n=20):
    """Rolling standard deviation, pooled over time and ensemble members.

    For each year the std dev uses a centered window of ``2n + 1`` years and
    all ensemble members, which enlarges the sample for the variance
    estimate. The result is then split into separate year and month
    dimensions.
    """
    rolling_std = src.utils.get_rolling_fn_bymonth(
        data, fn=np.std, n=n, reduce_ensemble_dim=True
    )
    return src.utils.unstack_month_and_year(rolling_std)


def postprocess_counts(counts):
    """Summarize single/multiyear event counts.

    Parameters
    ----------
    counts : xarray.DataArray
        Event counts with an ``n`` coordinate (event length in years) and a
        ``member`` dimension. These are presumably per-century rates from
        ``count_multiyear_laninas`` (TODO: confirm for other callers).

    Returns
    -------
    xarray.Dataset with variables
        counts      : the input
        nyears      : counts weighted by event length (years in event state)
        frac_count  : ensemble-mean 2-year events / all events
        frac_nyear  : ensemble-mean years in 2-year events / all event years
        frac_event  : ensemble-mean event years divided by 100 (the fraction
                      of winters in an event, if counts are per century)
    """
    ## weight each count by its event length
    nyears = counts["n"] * counts

    ## ensemble means, shared by the two ratios below
    counts_mean = counts.mean("member")
    nyears_mean = nyears.mean("member")

    named = {
        "counts": counts,
        "nyears": nyears,
        "frac_count": counts_mean.sel(n=2) / counts_mean.sum("n"),
        "frac_nyear": nyears_mean.sel(n=2) / nyears_mean.sum("n"),
        "frac_event": nyears.sum("n").mean("member") / 100,
    }
    return xr.merge([da.rename(name) for name, da in named.items()])


def generate_simulations(
    model,
    params,
    X0_ds,
    save_fp,
    nyear=3000,
    ncopy=100,
):
    """Generate (or load cached) simulations for each year of a parameter set.

    Parameters
    ----------
    model : object
        Must provide ``simulate(fit_ds=..., nyear=..., ncopy=..., X0_ds=...)``.
    params : xarray.Dataset
        Fitted parameters with a ``year`` dimension; one simulation is run
        per year.
    X0_ds : xarray.Dataset
        Pool of initial conditions with ``member`` and ``time`` dimensions.
        A random (member, time) sample is drawn from ``RNG`` for each year.
    save_fp : pathlib.Path
        netCDF cache file. If it already exists it is loaded and returned, and
        ``nyear``/``ncopy`` are ignored.
    nyear : int
        Simulation length passed to ``model.simulate``.
    ncopy : int
        Number of copies passed to ``model.simulate``. The default of 100 is
        the previously hard-coded value.

    Returns
    -------
    xarray.Dataset
        Simulations concatenated along ``year``.
    """
    ## reuse pre-computed simulations if available
    if save_fp.is_file():
        time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
        return xr.open_dataset(save_fp, decode_times=time_coder)

    ## otherwise, run one simulation per parameter year
    sims = []
    for y in tqdm.tqdm(params.year):

        ## random initial condition (member drawn first, then time)
        X0 = X0_ds.sel(member=RNG.choice(X0_ds.member), time=RNG.choice(X0_ds.time))

        sims.append(
            model.simulate(
                fit_ds=params.sel(year=y), nyear=nyear, ncopy=ncopy, X0_ds=X0
            )
        )

    sims = xr.concat(sims, dim=params.year)

    ## make sure the output directory exists before writing the cache
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    sims.to_netcdf(save_fp)

    return sims


def count_RO_multi_over_time(
    sims,
    thresh=0.73,
    is_warm=True,
    varname="T_34",
    nyear=3000,
):
    """Count single/multiyear events for each parameter year of a simulation set.

    For every ``year`` in ``sims``, events in ``sims[varname]`` are counted
    over a single window of ``nyear`` winters (``window=nyear``), separately
    for each ensemble member.

    Returns
    -------
    xarray.DataArray
        Dimensions (n, year, member); ``n`` is the event length (1, 2, 3).
    """
    per_year = []
    for y in tqdm.tqdm(sims.year):
        sim = sims.sel(year=y)

        raw_counts = count_multiyear_laninas_ensemble(
            T=sim[varname], is_warm=is_warm, thresh=thresh, window=nyear
        )

        per_year.append(
            xr.DataArray(
                data=raw_counts,
                coords=dict(n=[1, 2, 3], year=[y], member=sim.member.values),
                dims=["n", "year", "member"],
            )
        )

    ## stitch the per-year results back together along `year`
    return xr.concat(per_year, dim=sims.year)

## Load data

In [None]:
## open CESM indices and give the long index names short aliases
Th = src.utils.load_cesm_indices(load_z20=True).rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## load the precomputed Hbar scale factor and apply it to h_w
hbar_scale = xr.open_dataarray(SAVE_FP / "cesm_Hbar_scale_v2.nc")
Th["h_w_Hbar-scaled"] = Th["h_w"] * hbar_scale

## Analysis

### CESM diagnostics

In [None]:
## specify varname
VARNAME = "T_34"

## warm (El Niño) or cold (La Niña) events
IS_WARM = False

## specify whether to use fixed cutoff to define event
use_fixed_thresh = True

if use_fixed_thresh:
    ## threshold = 1/2 of the December std dev over 1850-1910,
    ## pooled over time and ensemble members
    Th_early = Th[VARNAME].sel(time=slice("1850", "1910"))
    sigma_early = Th_early.groupby("time.month").std(["time", "member"])
    THRESH = 0.5 * sigma_early.sel(month=12).values.item()
else:
    ## time-varying threshold: 1/2 of the rolling (2*20+1 yr) December std dev
    half_sigma_rolling = 0.5 * get_rolling_std(Th[VARNAME]).sel(month=12).values

    ## one value per year 1850-2100; hold the edge values constant for the
    ## first/last 20 years, where the centered window is incomplete
    thresh = np.nan * np.zeros(2101 - 1850)
    thresh[20:-20] = half_sigma_rolling
    thresh[:20] = half_sigma_rolling[0]
    thresh[-20:] = half_sigma_rolling[-1]

    ## bug fix: previously only the local `thresh` was set, so the call below
    ## raised NameError (or silently reused a stale THRESH from an earlier run).
    ## NOTE(review): later RO cells also pass THRESH; an array threshold only
    ## matches the CESM record length, so those cells expect a scalar.
    THRESH = thresh

## count multiyear events
counts = count_multiyear_laninas_ensemble(T=Th[VARNAME], is_warm=IS_WARM, thresh=THRESH)

## put in xarray. Window centers run from 1870 to 2080, assuming the record
## spans 1850-2100 and the default 41-year window drops 20 years at each end.
counts = xr.DataArray(
    data=counts,
    coords=dict(n=[1, 2, 3], year=np.arange(1870, 2081), member=Th.member.values),
    dims=["n", "year", "member"],
)

## Compute stats
stats_cesm = postprocess_counts(counts)

### RO experiment 

In [None]:
## filename of the cached control simulations
sims_fname = "T3_hw_Hbar_scaled.nc"

## paths: fitted RO parameters and simulation cache
fp = SAVE_FP / "fits_cesm/T3_h_w_Hbar_scaled_ac-order3"
sim_save_fp = DATA_FP / "RO_stoc" / sims_fname

## pool of initial conditions for the simulations
X0_ds = Th[["T_3", "h_w_Hbar-scaled"]].isel(year=0)

## load fitted parameters and build the two perturbation experiments
params = xr.open_dataset(fp)
pparams_bj_noise = src.utils.get_perturbed_multi(
    params=params, idxs=[(0, 0), (1, 1)], fix_others=False, fix_noise=False
)
pparams_wyrtki = src.utils.get_perturbed_multi(
    params=params, idxs=[(0, 1)], fix_others=False, fix_noise=False
)

## recharge-oscillator model
model = XRO(ncycle=12, ac_order=3, is_forward=True)

## run (or load cached) simulations for control and each experiment
sims = generate_simulations(
    model=model, params=params, X0_ds=X0_ds, save_fp=sim_save_fp
)
sims_bj = generate_simulations(
    model=model,
    params=pparams_bj_noise,
    X0_ds=X0_ds,
    save_fp=sim_save_fp.with_name(f"{sim_save_fp.stem}_bj.nc"),
)
sims_wy = generate_simulations(
    model=model,
    params=pparams_wyrtki,
    X0_ds=X0_ds,
    save_fp=sim_save_fp.with_name(f"{sim_save_fp.stem}_wy.nc"),
)

#### Count events

In [None]:
## quick look: stats for the "fix Wyrtki" experiment only.
## Stored under *_wy names so they are not confused with the per-experiment
## lists `counts_RO` / `stats_RO` built in the next cell. The plot cells
## index those lists with [0]; reusing the same names here for a single
## DataArray/Dataset made those cells fail confusingly if run out of order.
counts_RO_wy = count_RO_multi_over_time(
    sims=sims_wy, is_warm=IS_WARM, thresh=THRESH, varname=list(X0_ds)[0]
)
stats_RO_wy = postprocess_counts(counts_RO_wy)

In [None]:
## counts and stats per experiment, in order: control, fix Bjerknes, fix Wyrtki
counts_RO = [
    count_RO_multi_over_time(
        sims=experiment, is_warm=IS_WARM, thresh=THRESH, varname=list(X0_ds)[0]
    )
    for experiment in (sims, sims_bj, sims_wy)
]
stats_RO = [postprocess_counts(c) for c in counts_RO]

### Plot results

In [None]:
## the control experiment is the first entry of each list
counts_RO_control, stats_RO_control = counts_RO[0], stats_RO[0]

## colors for plot
colors = sns.color_palette("colorblind")

fig, ax = plt.subplots(figsize=(4, 3))

## one color per event length; solid = CESM, dashed = RO control
## (only the CESM line carries a legend label)
for c, n, label in zip(colors, counts_RO_control.n, ["1 yr", "2 yr", "3 yr"]):
    for counts_, ls, line_label in (
        (counts, "-", label),
        (counts_RO_control, "--", None),
    ):
        ax.plot(
            counts_.year,
            counts_.sel(n=n).mean("member"),
            ls=ls,
            c=c,
            label=line_label,
        )

## axes formatting
ax.legend(prop={"size": 8})
ax.set_ylabel("# per century")
ax.set_ylim([-2, 25])
ax.set_yticks([0, 10, 20])
ax.axvline(2010, ls="--", c="k", lw=0.8)
ax.set_xticks([1880, 2010, 2080])

plt.show()

In [None]:
## variables to plot and their legend labels
labels = ["La Niña winters / total winters", "2-year La Niña / total La Niñas"]
varnames = ["frac_event", "frac_count"]

fig, ax = plt.subplots(figsize=(4, 3))

## solid = CESM, dashed = RO control; only the CESM line is labeled
for c, n, label_ in zip(["k", colors[1]], varnames, labels):
    for ds, ls, line_label in (
        (stats_cesm, "-", label_),
        (stats_RO_control, "--", None),
    ):
        ax.plot(ds.year, ds[n], ls=ls, c=c, label=line_label)

## axes formatting
ax.legend()
ax.set_ylabel("fraction")
ax.set_ylim([0, None])
ax.set_yticks([0, 0.2, 0.4])
ax.axvline(2010, ls="--", c="k", lw=0.8)
ax.set_xticks([1880, 2010, 2080])

plt.show()

### Compare experiments

In [None]:
## frac_count (2-year events / all events) for each RO experiment.
## (The unused `labels` / `varnames` reassignments and commented-out code
## were removed; this plot only uses "frac_count".)
fig, ax = plt.subplots(figsize=(4, 3))

for stats, c, label in zip(
    stats_RO, ["k", *colors[1:3]], ["Control", "fix Bjerknes", "fix Wyrtki"]
):
    ax.plot(stats.year, stats["frac_count"], c=c, lw=2, label=label)

## axes formatting
ax.legend()
ax.set_ylabel("fraction")
ax.set_ylim([0, None])
ax.set_yticks([0, 0.2, 0.4])
ax.axvline(2010, ls="--", c="k", lw=0.8)
ax.set_xticks([1880, 2010, 2080])
ax.set_title("2-year La Niña / total La Niñas (RO)")

plt.show()