# RO parameter covariance

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## seaborn theme: white axes background, gridlines switched off
sns.set(
    rc={
        "axes.facecolor": "white",
        "axes.grid": False,
    }
)

## figure resolution (applied after the seaborn theme)
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories are read from environment variables;
## a missing variable raises KeyError here
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Load data

In [None]:
## open the saved fits file under SAVE_FP
save_fp = SAVE_FP / "fits_cesm" / "T3_hw_bymember.nc"
fits = xr.open_dataset(save_fp)

## load the Th indices, dropping everything before 1851
## (the first year has NaN in the h, hw variables)
Th = src.utils.load_cesm_indices().sel(time=slice("1851", None))

## scale by the standard deviation (for convenience)
Th /= Th.std()

## windowed version of the data, used to estimate how parameters change over time
Th_rolling = src.utils.get_windowed(
    Th,
    window_size=480,
    stride=120,
)

## Extract parameter values

In [None]:
## model specification used to unpack the fits
MODEL = XRO(
    ncycle=12,
    ac_order=3,
    is_forward=True,
)

## pull the parameter values out of the fits for this model
params = src.utils.get_params(model=MODEL, fits=fits)

## difference of every entry from the first entry along "year"
delta_params = params - params.isel(year=0)

## Plot stats

In [None]:
def scatter_params(ax, p0, p1, params, **scatter_kwargs):
    """Scatter one parameter against another and title the panel with their correlation.

    For each of ``p0`` and ``p1``, the entry of that name is taken from
    ``params`` at the last index along ``year`` and averaged over ``cycle``.
    The two resulting arrays are scattered against each other, and their
    correlation (``xr.corr`` over all remaining dimensions) is written into
    the axes title.

    Args:
        ax: Axes-like object to draw on; must provide ``scatter``,
            ``set_title``, ``set_xlabel`` and ``set_ylabel``.
        p0: Name of the entry in ``params`` plotted on the x-axis.
        p1: Name of the entry in ``params`` plotted on the y-axis.
        params: Container indexable by name whose entries support
            ``.isel(year=...)`` and ``.mean("cycle")``
            (presumably an ``xarray.Dataset`` -- confirm against caller).
        **scatter_kwargs: Forwarded unchanged to ``ax.scatter``.

    Returns:
        The ``ax`` object that was passed in, after drawing on it.
    """

    def _last_year_cycle_mean(name):
        """Select ``name`` at the final ``year`` index, averaged over ``cycle``."""
        return params[name].isel(year=-1).mean("cycle")

    ## reduce each parameter once and reuse it for both the scatter and the
    ## correlation (avoids repeating the isel/mean reduction)
    x = _last_year_cycle_mean(p0)
    y = _last_year_cycle_mean(p1)

    ## scatter data
    ax.scatter(x, y, **scatter_kwargs)

    ## title: correlation between the two reduced parameters
    r = xr.corr(x, y).values.item()
    ax.set_title(f"$r$({p0}, {p1})$=${r:.2f}")

    ## label
    ax.set_xlabel(p0)
    ax.set_ylabel(p1)

    return ax

In [None]:
## one row of four panels, one per parameter pair
fig, axs = plt.subplots(1, 4, figsize=(9, 2), layout="constrained")

## shared arguments: plot the change from the initial period, small markers
kwargs = {"params": delta_params, "s": 15}

## (x-parameter, y-parameter) for each panel, left to right
param_pairs = [
    ("R", "F1"),
    ("epsilon", "F2"),
    ("F1", "F2"),
    ("R", "epsilon"),
]
for panel, (p0, p1) in zip(axs, param_pairs):
    scatter_params(panel, p0, p1, **kwargs)

## thin dashed black reference lines through zero on both axes
ax_kwargs = {"ls": "--", "c": "k", "lw": 0.8}
for ax in axs:
    ax.axhline(0, **ax_kwargs)
    ax.axvline(0, **ax_kwargs)

plt.show()