# Scratch
Code scraps which might be useful later...

## Imports

In [None]:
import importlib.util
import subprocess
import sys
import warnings
import datetime
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr
from climpred import HindcastEnsemble
from dateutil.relativedelta import *
from matplotlib.ticker import AutoMinorLocator
import warnings
import tqdm
import pathlib
import cmocean

## Set plotting specs: apply seaborn's settings, overriding the axes
## background to white and switching the grid off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## Set the figure resolution (dots per inch).
## NOTE(review): `matplotlib` is already imported earlier in this cell; this
## second import only adds the `mpl` alias.
import matplotlib as mpl

mpl.rcParams["figure.dpi"] = 100

# Import necessary modules after ensuring files are available
from src.XRO import XRO, xcorr
from src.XRO_utils import (
    SkewAccessor,
    plot_above_below_shading,
    plot_fill_between,
    pmtm,
)

print("All required libraries are installed and imported successfully!")

## Observed ENSO properties related to RO

Two indices are used to describe the oscillatory behaviour of ENSO:
 - **$T$**: Sea surface temperature (SST) anomalies averaged over the Niño3.4 region (170°–120° W, 5° S–5° N).
 - **$h$**: Thermocline depth anomalies averaged over the equatorial Pacific (120° E–80° W, 5° S–5° N), i.e., the warm water volume (WWV) index (up to a constant factor given by the area it covers).


In [None]:
# Load the observed XRO state vector (ENSO SST, WWV and the SST indices of
# other climate modes). Variable order matters: the first two variables must
# be the ENSO SST index and WWV.
obs_file = "../data/XRO_indices_oras5.nc"

# Restrict the record to the analysis period
analysis_period = slice("1979-01", "2024-12")
obs_ds = xr.open_dataset(obs_file).sel(time=analysis_period)

# Keep only the two recharge-oscillator (RO) variables
RO_variables = ["Nino34", "WWV"]
obs_RO_ds = obs_ds[RO_variables]
print(obs_RO_ds)

### 1 ENSO time series

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

## Left axis: Nino3.4 SST anomaly, with shading above +0.5 and below -0.5
shading_style = dict(
    c="black",
    lw=0.5,
    above=0.5,
    above_c="orangered",
    below=-0.5,
    below_c="deepskyblue",
)
plot_above_below_shading(
    obs_RO_ds["Nino34"],
    xtime=obs_RO_ds.time,
    ax=ax,
    **shading_style,
)
ax.set_ylim([-3, 3])
ax.set_ylabel("Nino34 SSTA [degC]")

## Right axis: WWV on its own scale, with a dashed zero reference line
axR = ax.twinx()
obs_RO_ds["WWV"].plot(ax=axR, c="blue", lw=1.5)
axR.set_ylim([-30, 30])
axR.axhline(0, ls="--", c="gray", lw=0.5)

## Common time range and title
time_start = datetime.datetime(1979, 1, 1)
time_end = datetime.datetime(2026, 1, 1)
ax.set_xlim([time_start, time_end])
ax.set_title("ORAS5 Nino3.4 and WWV indices")

## Look at parameter variation over time

#### Fit model to 20-yr rolling windows

In [None]:
## Total number of (monthly) timesteps in the record
n = len(obs_RO_ds.time)

## Nominal length of each fitting window
window_size = 240  # units: months

## Windows cut short by the end of the record are only fitted if more than
## this many samples remain
min_samples = 0.7 * window_size

## Accumulators: one parameter set and one start date per fitted window
params_by_year = []
start_dates = []

## Slide the window start forward 12 samples (one year) at a time;
## warnings are suppressed for the duration of the loop
with warnings.catch_warnings(action="ignore"):
    for i in range(0, n - 12, 12):

        ## skip start points that leave too few samples for a robust estimate
        if (n - i) <= min_samples:
            continue

        ## subset of data for this window (shorter near the end of the record)
        data_subset = obs_RO_ds.isel(time=slice(i, i + window_size))
        start_dates.append(data_subset.time.isel(time=0))

        ## fit a fresh model to the window and store its RO parameters
        model = XRO(ncycle=12, ac_order=1, is_forward=True)
        fit_LRO = model.fit_matrix(data_subset, maskNT=[], maskNH=[])
        params_by_year.append(model.get_RO_parameters(fit_LRO))

## Stack the per-window results along a "time" dimension whose coordinate
## is the start date of each window
start_dates = xr.concat(start_dates, dim="time")
params_by_year = xr.concat(params_by_year, dim=start_dates)

#### Plot change over time in annual-mean $R$ and seasonal cycle

In [None]:
## Growth-rate parameter for every window start date and calendar cycle step
growth_rate = params_by_year["BJ_ac"]

fig, axs = plt.subplots(1, 2, figsize=(9, 3.5))

## left panel: annual-mean growth rate as a function of window start date
annual_mean = growth_rate.mean("cycle")
axs[0].plot(params_by_year.time, annual_mean)
axs[0].set_ylabel(r"$R$")

## right panel: annual cycle averaged over all windows, with a zero line
mean_cycle = growth_rate.mean("time")
axs[1].plot(mean_cycle)
axs[1].axhline(0, c="k", ls="--")

plt.show()

#### Partial corr vs linear regression

In [None]:
import numpy as np
import scipy.stats
import copy
def corr(x, y):
    """Return the Pearson correlation coefficient between `x` and `y`."""
    ## pearsonr returns (statistic, p-value); only the statistic is needed
    return scipy.stats.pearsonr(x, y)[0]


## Random number generator (unseeded, so draws differ between runs)
rng = np.random.default_rng()

Consider finding $R$ and $F$ to approximate:
\begin{align}
    y &= RT + Fh
\end{align}
Standardize the data (zero mean, unit variance) and let $\left<a,b\right>$ denote the sample covariance, such that:
\begin{align}
    1 &= \left<y,y\right> = \left<T,T\right> = \left<h,h\right>\\
    r &\equiv \left<T,h\right>
\end{align}
Define $\hat{h}$ such that:
\begin{align}
    0 &= \left<\hat{h}, T\right>\\
    \implies \hat{h} &= h-rT
\end{align}
Then we have:
\begin{align}
    R &= \left<y,T\right> - \frac{r}{1-r^2}\left<y,\hat{h}\right>\\
    F &= \frac{1}{1-r^2}\left<y,\hat{h}\right>
\end{align}

Next, consider the partial correlation of $h$ with $y$ (controlling for $T$). Define the component of $y$ that is uncorrelated with $T$ as:
\begin{align}
    \hat{y}\equiv y-\left<y,T\right>T
\end{align}
Then partial correlation for $h$ is defined as:
\begin{align}
    r_{p,h} &= \frac{\left<\hat{y}, \hat{h}\right>}{\sigma_{\hat{y}}\sigma_{\hat{h}}}
\end{align}

Note for the **numerator**, we have:
\begin{align}
    \left<\hat{y},\hat{h}\right> &= \left<y-\left<y,T\right>T, \hat{h}\right>\\
    &= \left<y,\hat{h}\right> - \left<y,T\right>\left<T,\hat{h}\right>\\
    &= \left<y,\hat{h}\right>,
\end{align}
because $\left<T,\hat{h}\right>=0$ by definition. For the **denominator**, we have:
\begin{align}
    \sigma_{\hat{y}}^2 &= \sigma_y^2 - 2\left<y,T\right>\left<y,T\right> + \left<y,T\right>^2\sigma_T^2\\
    &= 1-\left<y,T\right>^2\\
    \sigma_{\hat{h}}^2 &= \sigma_h^2 - 2r\left<T,h\right> + r^2\sigma_T^2\\
    &= 1 - r^2
\end{align}

We can write this as:
\begin{align}
    r_{p,h} &= \frac{\left<y,\hat{h}\right>}{\sqrt{\left(1-\left<y,T\right>^2\right)\left(1-r^2\right)}}\\
    &= F\left[ \frac{1-r^2}{1-\left<y,T\right>^2} \right]^{1/2}
\end{align}

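Similarly, for the partial correlation of $T$ with $y$ (controlling for $h$), define the components of $y$ and $T$ that are uncorrelated with $h$:
\begin{align}
    \tilde{y} &\equiv y-\left<y,h\right>h\\
    \tilde{T} &\equiv T-rh
\end{align}

Repeating the steps above with the roles of $T$ and $h$ swapped, we have:
\begin{align}
    r_{p,T} &= \frac{\left<\tilde{y},\tilde{T}\right>}{\sigma_{\tilde{y}}\sigma_{\tilde{T}}} = R\left[\frac{1-r^2}{1-\left<y,h\right>^2}\right]^{1/2}
\end{align}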

#### summarizing:

\begin{align}
    R &= \left<y,T\right> - \frac{r}{1-r^2}\left<y,\hat{h}\right>\\
    F &= \frac{1}{1-r^2}\left<y,\hat{h}\right>\\
    r_{p,h} &= F\left[ \frac{1-r^2}{1-\left<y,T\right>^2} \right]^{1/2}\\
    r_{p,T} &= R\left[\frac{1-r^2}{1-\left<y,h\right>^2}\right]^{1/2}
\end{align}

In [None]:
## Draw a random target (y) and two random predictors (the rows of X)
n = 100
y = rng.normal(size=n)
X = rng.normal(size=(2, n))


def norm(Z):
    """Standardize `Z` along its last axis (zero mean, unit std)."""
    anomalies = Z - Z.mean(-1, keepdims=True)
    return anomalies / Z.std(-1, keepdims=True)


y = norm(y)
X = norm(X)

## Reference solution: least-squares coefficients from the pseudo-inverse
RF = y[np.newaxis, :] @ np.linalg.pinv(X)

## Split the predictors into T and h (independent copies, for convenience)
T = X[0].copy()
h = X[1].copy()


def cov(x0, x1):
    """Return the mean of the elementwise product of `x0` and `x1`.

    For standardized (zero-mean) inputs this is their covariance.
    """
    return np.mean(x0 * x1)


## "Analytic" solution for the regression coefficients
r = cov(T, h)
cov_yT = cov(y, T)
cov_yh = cov(y, h)
one_minus_r2 = 1 - r**2

h_hat = h - r * T  # part of h uncorrelated with T
cov_y_hhat = cov(y, h_hat)
R = cov_yT - r / one_minus_r2 * cov_y_hhat
F = 1 / one_minus_r2 * cov_y_hhat

## Partial correlation of h with y: directly from the residuals (rph) and
## from the closed-form expression in terms of F (rph2)
y_hat = y - cov_yT * T
rph = corr(y_hat, h_hat)
rph2 = F * np.sqrt(one_minus_r2 / (1 - cov_yT**2))

## Partial correlation of T with y: same two routes, with the roles of T
## and h swapped
y_tilde = y - cov_yh * h
T_tilde = T - r * h
rpT = corr(y_tilde, T_tilde)
rpT2 = R * np.sqrt(one_minus_r2 / (1 - cov_yh**2))

## Check stuff worked: every line printed should be True
checks = [
    np.allclose(RF.squeeze(), np.array([R, F])),  # analytic == least squares
    np.allclose(rph, rph2),  # partial correlation for h
    np.allclose(rpT, rpT2),  # partial correlation for T
    np.allclose(cov_yT, R + r * F),  # normal equation for T
    np.allclose(cov_yh, F + r * R),  # normal equation for h
]
for passed in checks:
    print(passed)