# RO parameter covariance

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy
import time

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## plotting style: no grid, white axes background
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## root directories for input data and saved outputs, read from the
## environment (a missing variable raises KeyError)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
# ## specify which variables to look at
# VARNAMES = ["T"]
# total = []
# for v in VARNAMES:
#     total.append(forced[v] + anom[v])
#     total.append(forced[f"{v}_comp"])
# total = xr.merge(total)

## Load data

### T, h, and fits

In [None]:
## saved fits (an alternative file is "T3_hwhat_bymember.nc")
save_fp = SAVE_FP / "fits_cesm" / "T3_hw_bymember.nc"
fits = xr.open_dataset(save_fp)

## T/h index data, dropping the first year because the h and hw
## variables contain NaNs there
Th = src.utils.load_cesm_indices().sel(time=slice("1851", None))

## scale by the standard deviation (for convenience)
Th /= Th.std()

## windowed copies of the data, used to estimate how the parameters
## change over time
Th_rolling = src.utils.get_windowed(Th, window_size=480, stride=120)

### thermocline depth, etc.

In [None]:
## load spatial data
## NOTE(review): assumes load_consolidated() returns (forced, anom) in that
## order — confirm. `forced` must be bound here or the loop below raises
## NameError.
forced, anom = src.utils.load_consolidated()

## specify which variables to look at
VARNAMES = ["T"]

## for each variable, collect the total field (forced + anomaly) and the
## "<v>_comp" variable taken from `forced`, then merge into one dataset
total = []
for v in VARNAMES:
    total.append(forced[v] + anom[v])
    total.append(forced[f"{v}_comp"])
total = xr.merge(total)

## split into early/late periods
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## split surface data; .compute() materializes each period's values
total_early = total.sel(t_early).compute()
total_late = total.sel(t_late).compute()

Compute the climatology for each ensemble member (note: this is slow when using the "bymonth" option).

In [None]:
## perf_counter is monotonic and high-resolution, so it is the right clock
## for measuring an elapsed interval (time.time() can jump with clock changes)
t0 = time.perf_counter()
# clim_early = src.utils.reconstruct_clim(total_early)
# clim_late = src.utils.reconstruct_clim(total_late)

## reconstruct the climatology from each period's time-mean state
clim_early = src.utils.reconstruct_wrapper(total_early.mean("time"))
clim_late = src.utils.reconstruct_wrapper(total_late.mean("time"))
t1 = time.perf_counter()
print(f"Elapsed time: {t1-t0:.2f} seconds")

In [None]:
## threshold passed to get_H_int for both periods
H_THRESH = 0.08

## H from each period's climatological temperature
## (alternative: src.utils.get_H(clim["T"]) without a threshold)
H_early = src.utils.get_H_int(clim_early["T"], thresh=H_THRESH)
H_late = src.utils.get_H_int(clim_late["T"], thresh=H_THRESH)

## change in H between the late and early periods
delta_H = H_late - H_early

## Extract parameter values

In [None]:
## model configuration used to interpret the saved fits
## (`XRO` is the class imported from src.XRO at the top of the notebook)
MODEL = XRO(ncycle=12, ac_order=3, is_forward=True)

## parameter values for every entry along the `year` dimension
params = src.utils.get_params(fits=fits, model=MODEL)

## change of each parameter relative to the first `year` entry
params_initial = params.isel(year=0)
delta_params = params - params_initial

## Plot stats

### RO parameter covariance

#### Plotting funcs

In [None]:
def scatter_params(ax, p0, p1, params, **scatter_kwargs):
    """Scatter parameter `p1` against `p0` on `ax`, titled with their correlation.

    Each parameter is first averaged over its "cycle" dimension.

    Args:
        ax: matplotlib axes to draw on.
        p0: name of the variable in `params` shown on the x-axis.
        p1: name of the variable in `params` shown on the y-axis.
        params: dataset-like object indexable by variable name, whose
            variables have a "cycle" dimension (presumably an xarray
            Dataset — confirm against callers).
        **scatter_kwargs: forwarded unchanged to `ax.scatter`.

    Returns:
        The same `ax`, so calls can be chained.
    """
    ## reduce each parameter over the cycle once and reuse it for both the
    ## scatter and the correlation (avoids recomputing the means)
    x = params[p0].mean("cycle")
    y = params[p1].mean("cycle")

    ## scatter data
    ax.scatter(x, y, **scatter_kwargs)

    ## title: Pearson correlation over all remaining dimensions
    r = xr.corr(x, y).values.item()
    ax.set_title(f"$r$({p0}, {p1})$=${r:.2f}")

    ## label
    ax.set_xlabel(p0)
    ax.set_ylabel(p1)

    return ax


def compare_scatter_params(axs, params, **scatter_kwargs):
    """Draw five pairwise parameter scatter plots, one per axis.

    Panel i is drawn on ``axs[i]`` by ``scatter_params``. The pairs, in
    axis order, are (R, F1), (epsilon, F2), (F1, F2), (R, epsilon) and
    (BJ_ac, wyrtki).

    Args:
        axs: indexable collection of at least five matplotlib axes.
        params: parameter dataset, passed through to ``scatter_params``.
        **scatter_kwargs: forwarded to ``scatter_params`` and on to
            ``ax.scatter``.

    Returns:
        None.
    """
    panel_pairs = (
        ("R", "F1"),
        ("epsilon", "F2"),
        ("F1", "F2"),
        ("R", "epsilon"),
        ("BJ_ac", "wyrtki"),
    )

    for panel_idx, (x_name, y_name) in enumerate(panel_pairs):
        scatter_params(
            axs[panel_idx], x_name, y_name, params=params, **scatter_kwargs
        )

    return None

#### make plots

In [None]:
## line style for the zero reference lines on the "change" row
ax_kwargs = dict(ls="--", c="k", lw=0.8)

## one row of five scatter panels per (header, parameter data, draw zero lines?)
for header, panel_data, draw_zero_lines in (
    ("Early:", params.isel(year=0), False),
    ("\n\nLate:", params.isel(year=-1), False),
    ("\n\nChange:", delta_params.isel(year=-1), True),
):
    print(header)
    fig, axs = plt.subplots(1, 5, figsize=(11, 2), layout="constrained")
    compare_scatter_params(axs, panel_data, s=15)

    if draw_zero_lines:
        for ax in axs:
            ax.axhline(0, **ax_kwargs)
            ax.axvline(0, **ax_kwargs)

    plt.show()

### Mean state and RO param covariance

In [None]:
## early-period F2 (averaged over the cycle) vs. early-period H averaged
## over longitudes 190-240; the correlation is the cell's displayed output
F2_early = params["F2"].isel(year=0).mean("cycle")
H_early_box = H_early.sel(longitude=slice(190, 240)).mean("longitude")
xr.corr(F2_early, H_early_box)

In [None]:
## early-period xi_h (averaged over the cycle) vs. early-period H averaged
## over longitudes 190-240. Constrained layout keeps the axis labels from
## being clipped on this small figure.
fig, ax = plt.subplots(figsize=(2.5, 2.5), layout="constrained")

ax.scatter(
    params["xi_h"].isel(year=0).mean("cycle"),
    H_early.sel(longitude=slice(190, 240)).mean("longitude"),
)

## label axes so the panel can be read on its own
ax.set_xlabel("xi_h (early)")
ax.set_ylabel("H (lon 190-240 mean, early)")

plt.show()

In [None]:
## change in xi_T (last minus first `year` entry, averaged over the cycle)
## vs. change in H (late minus early) averaged over longitudes 190-240
fig, ax = plt.subplots(figsize=(2.5, 2.5), layout="constrained")

ax.scatter(
    delta_params["xi_T"].isel(year=-1).mean("cycle"),
    delta_H.sel(longitude=slice(190, 240)).mean("longitude"),
)

## zero reference lines, in the same style as the parameter-change
## scatter plots, so the sign of each change is easy to read
zero_line_kwargs = dict(ls="--", c="k", lw=0.8)
ax.axhline(0, **zero_line_kwargs)
ax.axvline(0, **zero_line_kwargs)

## label axes so the panel can be read on its own
ax.set_xlabel("change in xi_T")
ax.set_ylabel("change in H (lon 190-240 mean)")

plt.show()