# Diagnose changes over time in $\tau_x$ predictability

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## seaborn theme: white axes background, grid lines off
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## figure resolution (applied after the seaborn call above)
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories are read from environment variables;
## a missing variable raises KeyError here
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def merimean(x):
    """Average `x` over latitude inside a fixed longitude/latitude box.

    The box is selected with ``.sel`` using longitude 140 to 285 and
    latitude -5 to 5; the result is the mean over the "latitude" dimension.
    """
    lon_band = slice(140, 285)
    lat_band = slice(-5, 5)

    boxed = x.sel(longitude=lon_band, latitude=lat_band)
    return boxed.mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True, xticks=(190, 240)):
    """Draw a longitude-vs-month contour plot of `data` on `ax`.

    Args:
        ax: axes object; its ``contourf`` (filled) or ``contour`` (lines)
            method does the drawing.
        data: array exposing ``longitude`` and ``month`` attributes. If it has
            a "latitude" coordinate it is first reduced with ``merimean``.
        amp: amplitude passed to ``src.utils.make_cb_range(amp, amp / 5)`` to
            build the contour levels.
        is_filled: if True, draw filled contours with the "cmo.balance"
            colormap; otherwise draw thin black contour lines.
        xticks: longitudes at which to place x ticks and dashed white
            vertical guide lines.

    Returns:
        The object returned by the contour call.
    """

    ## materialize once: the ticks are used for both the axis and guide lines
    xticks = list(xticks)

    ## contour levels are shared by both plot styles
    shared_kwargs = dict(levels=src.utils.make_cb_range(amp, amp / 5), extend="both")

    ## style-specific plotting function and kwargs
    if is_filled:
        plot_fn = ax.contourf
        style_kwargs = dict(cmap="cmo.balance")
    else:
        plot_fn = ax.contour
        style_kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary)
    plot_data = merimean(data) if "latitude" in data.coords else data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.month,
        plot_data,
        **style_kwargs,
        **shared_kwargs,
    )

    ## format ax object (guide-line style kept separate from contour kwargs)
    guide_kwargs = dict(c="w", ls="--", lw=1)
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(xticks)
    for tick in xticks:
        ax.axvline(tick, **guide_kwargs)

    return cp

## Load data

In [None]:
## indices (T, h), loaded with the z20 option switched on
Th = src.utils.load_cesm_indices(load_z20=True)

## spatial data: keep the second item returned by the loader
_, anom = src.utils.load_consolidated()

## keep only the wind-stress variables, merge with the indices,
## and drop everything before 1851
anom = xr.merge([anom[["taux", "taux_comp"]], Th]).sel(time=slice("1851", None))

## split into windows (stride=120; presumably in time steps -- confirm
## against src.utils.get_windowed) and load the result into memory
anom = src.utils.get_windowed(anom, stride=120).compute()

## Analysis

### Compute skill

#### Funcs

In [None]:
def regress_helper(data, x_vars=None):
    """Fit a linear regression of "taux" onto the predictors in `x_vars`.

    Args:
        data: dataset passed to ``src.utils.regress_xr_proj``; it must contain
            the variables named in `x_vars` as well as "taux".
        x_vars: list of predictor variable names. Defaults to ``["nino34"]``.

    Returns:
        The "taux" coefficients converted to a dataset by splitting them along
        dimension "j" (presumably one variable per predictor -- confirm
        against ``src.utils.regress_xr_proj``).
    """

    ## build the default per call so no list object is shared across calls
    if x_vars is None:
        x_vars = ["nino34"]

    ## get coefficients
    coefs = src.utils.regress_xr_proj(data, x_vars=x_vars, y_vars=["taux"])

    ## split dimension "j" into separate variables for mapping
    return coefs["taux"].to_dataset("j")


def get_recon_skill(data, x_vars=None):
    """Return the skill of a linear reconstruction of "taux" from `x_vars`.

    The regression coefficients come from ``regress_helper``. The
    reconstruction is the sum over predictors of predictor times coefficient.
    The result is ``src.utils.reconstruct_cov_da`` (reconstruction vs. "taux",
    both using "taux_comp") divided by ``src.utils.reconstruct_var`` of "taux".

    NOTE(review): the returned value is a covariance/variance ratio, not an
    error; confirm whether it should be read as a correlation or as an
    explained-variance fraction.

    Args:
        data: dataset containing the variables in `x_vars`, "taux" and
            "taux_comp".
        x_vars: list of predictor variable names. Defaults to ``["nino34"]``.

    Returns:
        ``cov / var`` as computed above.
    """

    ## build the default per call so no list object is shared across calls;
    ## it must stay a list so that data[x_vars] selects a sub-dataset
    if x_vars is None:
        x_vars = ["nino34"]

    ## get coefficients
    m = regress_helper(data, x_vars=x_vars)

    ## get reconstruction: stack predictor * coefficient along "k" and sum
    recon = (data[x_vars] * m).to_dataarray(dim="k").sum("k")

    ## covariance between the reconstruction and the target
    cov = src.utils.reconstruct_cov_da(
        V_x=recon,
        V_y=data["taux"],
        U_x=data["taux_comp"],
        U_y=data["taux_comp"],
    )

    ## variance of the target
    var = src.utils.reconstruct_var(
        scores=data["taux"],
        components=data["taux_comp"],
    )

    return cov / var

#### Compute

In [None]:
def get_skill(x):
    """Apply `get_recon_skill` to each "time.month" group of `x`."""
    return x.groupby("time.month").map(get_recon_skill)


## skill for the first and last entries along the "year" dimension
r_early = get_skill(anom.isel(year=0))
r_late = get_skill(anom.isel(year=-1))

### Plot skill

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(4, 2.5), layout="constrained")

## left panel: early window as filled contours, late window as line contours
cp0 = plot_cycle_hov(axs[0], data=r_early, amp=0.8)
plot_cycle_hov(axs[0], data=r_late, amp=0.8, is_filled=False)
axs[0].set_title(r"Correlation")
axs[0].set_ylabel("Month")
axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])

## right panel: late minus early, with y ticks removed
cp2 = plot_cycle_hov(axs[1], data=r_late - r_early, amp=0.4)
axs[1].set_title("Change")
axs[1].set_yticks([])

plt.show()