# RO param. covariance

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy
import time

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## plotting style: white axes background, grid lines off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution (dots per inch)
mpl.rcParams.update({"figure.dpi": 100})

## data / save directories, read from environment variables of the same name
## (os.environ[...] raises KeyError if either one is unset)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Functions

In [None]:
def format_subsurf_ax(ax, xlim=(140, 280), yticks=(300, 150, 0)):
    """set up ax for plotting subsurface (longitude-depth) data

    Makes depth increase downward, sets the longitude range and the depth
    ticks, and labels both axes.

    Args:
        ax: matplotlib axes holding a longitude-depth plot.
        xlim: longitude limits (default 140-280).
        yticks: depth tick locations (default 300, 150, 0).
    """
    ## flip the y-axis so depth increases downward. Only flip when it is not
    ## already flipped, so formatting the same axes twice is harmless.
    bottom, top = ax.get_ylim()
    if bottom < top:
        ax.set_ylim(top, bottom)

    ax.set_xlim(list(xlim))
    ax.set_xlabel("Longitude")
    ax.set_yticks(list(yticks))
    ax.set_ylabel("Depth (m)")

## Load data

### T, h, and fits

In [None]:
## fitted RO parameters
save_fp = SAVE_FP / "fits_cesm" / "T3_hw_bymember.nc"
# alternative fit file: SAVE_FP / "fits_cesm" / "T3_hwhat_bymember.nc"
fits = xr.open_dataset(save_fp)

## T/h indices, starting in 1851 to drop the first year (NaNs in the h, hw vars)
Th = src.utils.load_cesm_indices().sel(time=slice("1851", None))

## rescale by the standard deviation (for convenience); note no mean is removed
Th /= Th.std()

## windowed data, used to estimate how the parameters change over time
## (window_size=480 / stride=120 -- presumably months, i.e. 40-yr windows every 10 yr)
Th_rolling = src.utils.get_windowed(Th, window_size=480, stride=120)

## extract param values

In [None]:
## specify model, using the XRO class imported directly from src.XRO
MODEL = XRO(ncycle=12, ac_order=3, is_forward=True)

## extract parameters from the fits and rename the "cycle" dim to "month"
params = src.utils.get_params(fits=fits, model=MODEL).rename({"cycle": "month"})

## first and last entries along "year" (early and late periods)
params_early = params.isel(year=0)
params_late = params.isel(year=-1)

## change in parameters from the early to the late period
delta_params = params_late - params_early

### thermocline depth, etc.

In [None]:
## load spatial data
forced, anom = src.utils.load_consolidated()

## variables to look at: for each one keep the total field (forced + anom)
## and the "<name>_comp" field from `forced`
VARNAMES = ["T", "sst", "mld"]
total = xr.merge(
    [
        field
        for v in VARNAMES
        for field in (forced[v] + anom[v], forced[f"{v}_comp"])
    ]
)

In [None]:
## early/late periods (chosen to match the RO fits)
t_early = {"time": slice("1851", "1890")}
t_late = {"time": slice("2061", "2100")}

## subset the data to each period and load it into memory
total_early, total_late = (
    total.sel(period).compute() for period in (t_early, t_late)
)

## reconstruct the climatology of each period
clim_early, clim_late = (
    src.utils.reconstruct_clim(subset) for subset in (total_early, total_late)
)

Compute mean-state quantities (zonal SST gradient, thermocline depth, mixed-layer depth)

In [None]:
def get_dTdx(x):
    """Niño-4 minus Niño-3 SST of dataset x (zonal SST difference)."""
    return src.utils.get_nino4(x["sst"]) - src.utils.get_nino3(x["sst"])


def lon_avg(x):
    """Average x over longitudes 120-210."""
    return x.sel(longitude=slice(120, 210)).mean("longitude")


def get_H(x):
    """Thermocline depth from the temperature field of x (thresh=0.08)."""
    return src.utils.get_H_int(x["T"], thresh=0.08)


## add derived mean-state variables to each climatology (datasets modified in place)
for clim in [clim_early, clim_late]:
    clim["dTdx"] = get_dTdx(clim)
    clim["H"] = get_H(clim)
    # reuse the thermocline depth computed above rather than recomputing it
    clim["H_hw"] = lon_avg(clim["H"])
    clim["H_inv"] = 1 / clim["H"]
    clim["H_hw_inv"] = 1 / clim["H_hw"]
    clim["mld_eq"] = clim["mld"].mean("latitude")
    clim["mld_eq_inv"] = 1 / clim["mld_eq"]

## change in climatology (late minus early)
delta_clim = clim_late - clim_early

## Plot stats

### RO parameter covariance

#### Plotting funcs

In [None]:
def scatter_params(ax, p0, p1, params, **scatter_kwargs):
    """scatter month-averaged params p0 (x-axis) vs. p1 (y-axis) on ax

    Args:
        ax: matplotlib axes to draw on.
        p0: name of the parameter on the x-axis.
        p1: name of the parameter on the y-axis.
        params: parameter collection indexable by name, whose entries have a
            "month" dimension (presumably an xarray Dataset -- confirm).
        **scatter_kwargs: forwarded to ``ax.scatter``.

    Returns:
        ax, titled with the correlation of the two month-averaged parameters
        (``xr.corr`` with no ``dim``, i.e. over all remaining dimensions).
    """
    ## average over months once; reused for both the scatter and the correlation
    x = params[p0].mean("month")
    y = params[p1].mean("month")

    ## scatter data
    ax.scatter(x, y, **scatter_kwargs)

    ## title
    r = xr.corr(x, y).values.item()
    ax.set_title(f"$r$({p0}, {p1})$=${r:.2f}")

    ## label
    ax.set_xlabel(p0)
    ax.set_ylabel(p1)

    return ax


def compare_scatter_params(
    axs,
    params,
    *,
    pairs=(
        ("R", "F1"),
        ("epsilon", "F2"),
        ("F1", "F2"),
        ("R", "epsilon"),
        ("BJ_ac", "wyrtki"),
    ),
    **scatter_kwargs,
):
    """comparison plot of pairs of parameters

    Args:
        axs: sequence of axes; ``axs[i]`` receives the i-th entry of ``pairs``,
            so it needs at least ``len(pairs)`` entries (IndexError otherwise).
        params: parameters to plot, passed to ``scatter_params``.
        pairs: (x-parameter, y-parameter) name pairs to scatter. Defaults to
            R-F1, epsilon-F2, F1-F2, R-epsilon and BJ_ac-wyrtki.
        **scatter_kwargs: forwarded to ``scatter_params``.
    """
    for i, (p0, p1) in enumerate(pairs):
        scatter_params(axs[i], p0, p1, params=params, **scatter_kwargs)

#### make plots

In [None]:
## scatter parameter pairs for the early period, the late period, and the change
## (zero reference lines are drawn only on the change panels)
ax_kwargs = dict(ls="--", c="k", lw=0.8)
for header, p_vals, show_zero in [
    ("Early:", params_early, False),
    ("\n\nLate:", params_late, False),
    ("\n\nChange:", delta_params, True),
]:
    print(header)
    fig, axs = plt.subplots(1, 5, figsize=(11, 2), layout="constrained")
    compare_scatter_params(axs, p_vals, s=15)
    if show_zero:
        for ax in axs:
            ax.axhline(0, **ax_kwargs)
            ax.axvline(0, **ax_kwargs)
    plt.show()

### Mean state and RO param covariance

In [None]:
## mean-state variable to plot on the x-axis
p = "H_hw_inv"


## how to reduce over months
def sel(x):
    return x.mean("month")
    # alternative: return x.sel(month=slice(5, 8)).mean("month")


## each row pairs a mean state with the RO parameters from the SAME period,
## matching the row label drawn on the right
rows = [
    ("early", clim_early, params_early),
    ("late", clim_late, params_late),
    ("Change", delta_clim, delta_params),
]

fig, axs = plt.subplots(3, 2, figsize=(6, 7.5), layout="constrained")

for row_axs, (row_label, state, row_params) in zip(axs, rows):
    row_axs[0].scatter(sel(state[p]), sel(row_params["BJ_ac"]))
    row_axs[1].scatter(sel(state[p]), sel(row_params["wyrtki"]))

    ## row label on the right-hand side
    row_axs[1].set_ylabel(row_label)
    row_axs[1].yaxis.set_label_position("right")

axs[0, 0].set_title("BJ")
axs[0, 1].set_title("Wyrtki")
axs[-1, 0].set_xlabel(p)
axs[-1, 1].set_xlabel(p)

plt.show()

### Plot mean state correlation with param change

In [None]:
## how to reduce over months (alternative: x.sel(month=5))
def sel(x):
    return x.mean("month")


## correlation across members between the change in mean-state temperature
## and the change in each month-averaged RO parameter
corr_BJ, corr_wyrtki = (
    xr.corr(sel(delta_clim["T"]), delta_params[name].mean("month"), dim="member")
    for name in ("BJ_ac", "wyrtki")
)

## plot
fig, axs = plt.subplots(1, 2, figsize=(8, 3), layout="constrained")

## contour settings shared by both panels (same color scale)
plot_kwargs = dict(levels=src.utils.make_cb_range(0.6, 0.06), cmap="cmo.balance")

for ax, corr in zip(axs, (corr_BJ, corr_wyrtki)):
    ## correlation in the longitude-depth plane
    cp = ax.contourf(delta_clim.longitude, delta_clim.z_t, corr, **plot_kwargs)

    ## overlay early-period, member-mean mixed-layer depth (dashed) and H (solid)
    ax.plot(
        clim_early.longitude, sel(clim_early["mld_eq"]).mean("member"), c="k", ls="--"
    )
    ax.plot(clim_early.longitude, sel(clim_early["H"]).mean("member"), c="k")

    format_subsurf_ax(ax)

## right panel: drop the depth ticks and label (they repeat the left panel's)
axs[1].set_yticks([])
axs[1].set_ylabel(None)
fig.colorbar(cp, ax=axs[1], label="corr", ticks=[-0.6, 0, 0.6])
plt.show()

In [None]:
## how to reduce over months (alternative: x.sel(month=6))
def sel(x):
    return x.mean("month")


## correlation across members between the change in mean-state SST and the
## change in each month-averaged RO parameter
corr_BJ = xr.corr(
    sel(delta_clim["sst"]), delta_params["BJ_ac"].mean("month"), dim="member"
)
corr_wyrtki = xr.corr(
    sel(delta_clim["sst"]), delta_params["wyrtki"].mean("month"), dim="member"
)

fig = plt.figure(figsize=(6, 10 / 3), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=2, ncols=1, format_func=src.utils.plot_setup_pac
)

## contour settings shared by both panels (same color scale)
plot_kwargs = dict(
    levels=src.utils.make_cb_range(0.6, 0.06),
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
)

for ax, corr, title in zip(
    axs.flatten(), [corr_BJ, corr_wyrtki], ["BJ", "Wyrtki"]
):
    cp = ax.contourf(
        delta_clim.longitude,
        delta_clim.latitude,
        corr,
        **plot_kwargs,
    )
    ## identify which parameter each panel shows
    ax.set_title(title)

## one colorbar for both panels (they share the same contour levels)
fig.colorbar(cp, ax=list(axs.flatten()), label="corr", ticks=[-0.6, 0, 0.6])

plt.show()

In [None]:
## how to reduce over months (alternative: x.sel(month=4))
def sel(x):
    return x.mean("month")


## correlation across members between the change in 1/H and the change in
## each month-averaged RO parameter, plotted against longitude
corr_BJ = xr.corr(
    sel(delta_clim["H_inv"]), delta_params["BJ_ac"].mean("month"), dim="member"
)
corr_wyrtki = xr.corr(
    sel(delta_clim["H_inv"]), delta_params["wyrtki"].mean("month"), dim="member"
)

fig, axs = plt.subplots(2, 1, figsize=(3, 4.5), layout="constrained")

for ax, corr, title in zip(axs, [corr_BJ, corr_wyrtki], ["BJ", "Wyrtki"]):
    ax.plot(delta_clim.longitude, corr)
    ## zero reference line separating positive and negative correlation
    ax.axhline(0, ls="--", c="k", lw=0.8)
    ax.set_title(title)
    ax.set_ylabel("corr")
axs[-1].set_xlabel("Longitude")

plt.show()