# Diagnose changes over time in $\tau_x$ predictability

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## plotting style: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution
mpl.rcParams["figure.dpi"] = 100

## input-data and output directories are supplied through environment variables
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Funcs

In [None]:
def merimean(x):
    """Return the meridional (latitude) mean of ``x`` over 5S-5N, 130-280E.

    Label-based selection: assumes ``x`` is an xarray object whose
    ``latitude``/``longitude`` coordinates are ascending, with longitude on a
    0-360 grid -- TODO confirm (a descending axis would give an empty slice).
    """
    equatorial_band = dict(longitude=slice(130, 280), latitude=slice(-5, 5))
    return x.sel(**equatorial_band).mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True, xticks=(190, 240)):
    """Plot a month-vs-longitude (seasonal-cycle Hovmöller) diagram on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        data: xarray object with ``month`` and ``longitude`` dimensions. If it
            also has a ``latitude`` coordinate, it is first averaged over the
            equatorial band with ``merimean``.
        amp: contour-level amplitude; levels come from
            ``src.utils.make_cb_range(amp, amp / 5)``.
        is_filled: if True draw filled contours (``cmo.balance``), otherwise
            thin black contour lines.
        xticks: longitudes to tick, each also marked with a dashed white line.

    Returns:
        The contour set returned by ``ax.contourf`` / ``ax.contour``.
    """

    ## kwargs shared by filled and line contours
    shared_kwargs = dict(levels=src.utils.make_cb_range(amp, amp / 5), extend="both")

    ## filled vs. line-contour specifics
    if is_filled:
        plot_fn = ax.contourf
        kwargs = dict(cmap="cmo.balance")
    else:
        plot_fn = ax.contour
        kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary)
    plot_data = merimean(data) if "latitude" in data.coords else data

    ## contour(X, Y, Z) expects Z with shape (len(Y), len(X)), i.e.
    ## (month, longitude); enforce that order regardless of the input's dims
    plot_data = plot_data.transpose("month", "longitude")

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.month,
        plot_data,
        **kwargs,
        **shared_kwargs,
    )

    ## format ax object
    line_kwargs = dict(c="w", ls="--", lw=1)
    ax.set_xlim([130, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(xticks)
    for tick in xticks:
        ax.axvline(tick, **line_kwargs)

    return cp

## Load data

In [None]:
## T, h indices
Th = src.utils.load_cesm_indices(load_z20=True)

## spatial fields: keep only zonal wind stress and its components,
## then combine with the indices and drop everything before 1851
_, anom = src.utils.load_consolidated()
anom = xr.merge([anom[["taux", "taux_comp"]], Th])
anom = anom.sel(time=slice("1851", None))

## split into "windowed" segments (stride=120) and load into memory
anom = src.utils.get_windowed(anom, stride=120).compute()

## Analysis

### Compute skill

#### Funcs

In [None]:
def get_recon_error(data):
    """Return the standard deviation of the ``taux`` reconstruction error.

    The residual ``taux_recon - taux`` is passed as scores, together with
    ``taux_comp`` as components, to ``src.utils.reconstruct_std``.

    Args:
        data: dataset containing ``taux``, ``taux_recon`` and ``taux_comp``.

    Returns:
        Whatever ``src.utils.reconstruct_std`` returns for the residual.
        Presumably this is a spatial field of error std -- TODO confirm.
    """

    ## error = reconstruction minus truth
    residual = data["taux_recon"] - data["taux"]

    ## std of the residual, reconstructed using the taux components
    std = src.utils.reconstruct_std(
        scores=residual,
        components=data["taux_comp"],
    )

    return std


def get_recon_error_bymonth(data, x_vars=None, max_order=None):
    """Return the std of the ``taux`` reconstruction error for each calendar month.

    Args:
        data: dataset with ``time``, ``taux``, ``taux_comp`` and the predictors.
        x_vars: predictor variable names; defaults to ``["nino34"]``.
            A fresh list is built on each call, so no default is shared
            between calls.
        max_order: passed to ``get_recon`` (None = separate fit per month).

    Returns:
        Result of ``get_recon_error`` evaluated per ``time.month`` group.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    return eval_metric_bymonth(
        data=data, fn=get_recon_error, x_vars=x_vars, max_order=max_order
    )


def get_recon_skill(data):
    """Return cov(taux_recon, taux) / var(taux) as the reconstruction skill.

    The covariance and variance are obtained with the ``src.utils``
    reconstruct helpers, using ``taux_comp`` as the components for both
    fields.

    NOTE(review): this ratio is normalized by var(taux) only, so it is a
    regression-type ratio rather than a Pearson correlation (which would
    divide by std(taux_recon) * std(taux)). Confirm this is intended, since
    downstream it is stored as "r" and plotted as "Corr.".

    Args:
        data: dataset containing ``taux``, ``taux_recon`` and ``taux_comp``.
    """
    truth = data["taux"]
    comps = data["taux_comp"]

    numerator = src.utils.reconstruct_cov_da(
        V_x=data["taux_recon"], V_y=truth, U_x=comps, U_y=comps
    )
    denominator = src.utils.reconstruct_var(scores=truth, components=comps)

    return numerator / denominator


def eval_metric_bymonth(data, fn, x_vars=None, max_order=None):
    """Reconstruct ``taux`` and evaluate ``fn`` on each calendar month.

    Args:
        data: dataset with ``time``, ``taux`` and the predictor variables.
            It is NOT modified; the reconstruction is attached to a new
            dataset.
        fn: callable applied to each ``time.month`` group. Each group
            contains a ``taux_recon`` variable.
        x_vars: predictor variable names; defaults to ``["nino34"]``.
        max_order: passed to ``get_recon`` (None = separate fit per month).

    Returns:
        Result of ``data.groupby("time.month").map(fn)``.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    ## attach the reconstruction to a new dataset. Assigning into ``data``
    ## would silently mutate the caller's object.
    recon = get_recon(data, x_vars=x_vars, max_order=max_order)
    data = data.assign(taux_recon=recon)

    return data.groupby("time.month").map(fn)


def get_recon_skill_bymonth(data, x_vars=None, max_order=None):
    """Return the ``taux`` reconstruction skill for each calendar month.

    Args:
        data: dataset with ``time``, ``taux``, ``taux_comp`` and the predictors.
        x_vars: predictor variable names; defaults to ``["nino34"]``
            (a fresh list on each call).
        max_order: passed to ``get_recon`` (None = separate fit per month).

    Returns:
        Result of ``get_recon_skill`` evaluated per ``time.month`` group.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    return eval_metric_bymonth(
        data=data, fn=get_recon_skill, x_vars=x_vars, max_order=max_order
    )


def regress_helper(data, x_vars=None):
    """Fit a linear regression of ``taux`` onto the predictors ``x_vars``.

    Args:
        data: dataset containing ``taux`` and every variable in ``x_vars``.
        x_vars: predictor variable names; defaults to ``["nino34"]``
            (a fresh list on each call).

    Returns:
        Dataset built by splitting the ``"j"`` dimension of the ``taux``
        coefficients from ``src.utils.regress_xr_proj`` into separate
        variables. Presumably ``"j"`` indexes the predictors -- TODO confirm.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    ## get coefficients
    coefs = src.utils.regress_xr_proj(data, x_vars=x_vars, y_vars=["taux"])

    ## rename variables for mapping: one variable per "j" entry, so the result
    ## can be multiplied variable-by-variable against data[x_vars]
    coefs = coefs["taux"].to_dataset("j")

    return coefs


def regress_helper_harm(data, max_order=3, x_vars=None):
    """Fit a regression of ``taux`` onto ``x_vars`` with harmonic coefficients.

    The fit is delegated to ``src.utils.regress_harm_wrapper``.

    Args:
        data: dataset containing ``taux`` and every variable in ``x_vars``.
        max_order: ``max_order`` passed to ``regress_harm_wrapper``.
        x_vars: predictor variable names; defaults to ``["nino34"]``
            (a fresh list on each call).

    Returns:
        Dataset built by splitting the ``"ell"`` dimension of the ``taux``
        coefficients into separate variables. Presumably ``"ell"`` indexes
        the predictors -- TODO confirm.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    ## get coefficients
    coefs = src.utils.regress_harm_wrapper(
        data,
        y_vars=["taux"],
        x_vars=x_vars,
        max_order=max_order,
    )

    ## rename: one variable per "ell" entry
    coefs = coefs["taux"].to_dataset("ell")

    return coefs


def get_recon(data, x_vars=None, max_order=None):
    """Reconstruct ``taux`` as a month-dependent linear combination of predictors.

    Args:
        data: dataset with a ``time`` coordinate, ``taux`` and ``x_vars``.
        x_vars: predictor variable names; defaults to ``["nino34"]``
            (a fresh list on each call).
        max_order: if None, coefficients are fit separately for each calendar
            month (``regress_helper``). Otherwise a harmonic fit of this
            order is used (``regress_helper_harm``).

    Returns:
        DataArray equal to the sum over predictors of
        coefficient(month) * predictor. No intercept term is added.
    """
    if x_vars is None:
        x_vars = ["nino34"]

    ## get reconstruction coefficients
    if max_order is None:
        ## by-month version
        m = data.groupby("time.month").map(regress_helper, x_vars=x_vars)
    else:
        ## harmonic version
        m = regress_helper_harm(data, x_vars=x_vars, max_order=max_order)

    ## broadcast each month's coefficients onto the matching time steps,
    ## then sum the per-predictor contributions
    recon_ = data[x_vars].groupby("time.month") * m
    recon = recon_.to_dataarray(dim="k").sum("k")

    return recon

#### Compute

In [None]:
## specify predictor variables to use
kwargs = dict(x_vars=["T_3", "T_34", "T_4"], max_order=3)

## set save filepath (SAVE_FP is read from the environment at the top)
save_fp = SAVE_FP / f"taux_skill_ac_order{kwargs['max_order']}.nc"

if save_fp.is_file():

    ## load existing file
    res = xr.open_dataset(save_fp)

else:
    ## empty lists to hold results
    r_vals = []
    sigma_vals = []

    ## loop thru years (windows)
    for y in tqdm.tqdm(anom.year):

        ## get data subset and fit its reconstruction ONCE. Both metrics
        ## below need the same (deterministic) reconstruction, so there is no
        ## need to refit it per metric.
        anom_ = anom.sel(year=y)
        anom_ = anom_.assign(taux_recon=get_recon(anom_, **kwargs))

        ## evaluate skill and error on the same monthly groups
        by_month = anom_.groupby("time.month")
        r_vals.append(by_month.map(get_recon_skill))
        sigma_vals.append(by_month.map(get_recon_error))

    ## append into xr.DataArrays; then merge
    sigma_vals = xr.concat(sigma_vals, dim=anom.year)
    r_vals = xr.concat(r_vals, dim=anom.year)
    res = xr.merge([sigma_vals.rename("sigma"), r_vals.rename("r")])

    ## save to file
    res.to_netcdf(save_fp)

### Plot skill

In [None]:
## the two windows compared below: early (y0) and late (y1)
y0, y1 = 1871, 1971

#### Hovmöller diagrams

Correlation

In [None]:
## skill ("r") for the early and late windows; the difference panel is
## amplified by DIFF_SCALE so it shares the same color levels
r0, r1 = (res["r"].sel(year=yr) for yr in (y0, y1))
DIFF_SCALE = 8

## arguments shared by every panel
kwargs = dict(amp=0.8, lat_bound=5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, scaled (late - early)
cp0, cp1, cp2 = (
    src.utils.make_cycle_hov(ax, data=field, **kwargs)
    for ax, field in zip(axs, [r0, r1, DIFF_SCALE * (r1 - r0)])
)

## shared colorbar, attached next to the last panel
cb = fig.colorbar(
    cp0, ax=axs[2], ticks=[-kwargs["amp"], 0, kwargs["amp"]], label="Corr."
)
src.utils.format_hov_axs(axs)

plt.show()

Error

In [None]:
## reconstruction-error std for the early and late windows. The difference
## is early - late, so positive values mean the error decreased.
sigma0, sigma1 = (res["sigma"].sel(year=yr) for yr in (y0, y1))
DIFF_SCALE = 8

## arguments shared by every panel
kwargs = dict(amp=1.2e-2, lat_bound=5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, scaled (early - late)
cp0, cp1, cp2 = (
    src.utils.make_cycle_hov(ax, data=field, **kwargs)
    for ax, field in zip(axs, [sigma0, sigma1, DIFF_SCALE * (sigma0 - sigma1)])
)

## shared colorbar, attached next to the last panel
cb = fig.colorbar(
    cp0, ax=axs[2], ticks=[-kwargs["amp"], 0, kwargs["amp"]], label="Pa"
)
src.utils.format_hov_axs(axs)

plt.show()

#### Spatial

In [None]:
## map of May skill ("r")
def sel(x):
    return x["r"].sel(month=5)


## set up a 3-row stack of map panels over the tropical Pacific
fig = plt.figure(figsize=(7, 3.9), layout="constrained")


def format_func(ax):
    return src.utils.plot_setup_pac(ax, max_lat=20)


axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, late - early (narrower color range for the difference)
contour_kwargs = dict(amp=0.8, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], res.sel(year=y0), **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], res.sel(year=y1), **contour_kwargs)
diff = res.sel(year=y1) - res.sel(year=y0)
cp2 = src.utils.make_contour_plot(axs[2, 0], diff, **{**contour_kwargs, "amp": 0.2})

In [None]:
## map of May reconstruction-error std ("sigma")
def sel(x):
    return x["sigma"].sel(month=5)


## set up a 3-row stack of map panels over the equatorial Pacific
fig = plt.figure(figsize=(7, 1.5), layout="constrained")


def format_func(ax):
    return src.utils.plot_setup_pac(ax, max_lat=5)


axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, early - late (positive = error decreased)
contour_kwargs = dict(amp=1.5e-2, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], res.sel(year=y0), **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], res.sel(year=y1), **contour_kwargs)
diff = res.sel(year=y0) - res.sel(year=y1)
cp2 = src.utils.make_contour_plot(axs[2, 0], diff, **{**contour_kwargs, "amp": 3.75e-3})