# Sverdrup balance
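How well does the Sverdrup balance hold in each season? Does wind stress drive the zonal thermocline tilt more efficiently in the future (late period) than in the early period?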

## Imports

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import pandas as pd

# Import custom modules
import src.utils

## plotting defaults: white axes background and no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution (the inline backend may otherwise use a lower DPI)
mpl.rcParams.update({"figure.dpi": 100})

## data / output directories are read from environment variables
## (raises KeyError if either is unset)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Funcs

In [None]:
def make_scatter(ax, data, x_var, y_var, scale=1, months=None):
    """Scatter ``data[y_var]`` against ``data[x_var]`` on ``ax`` and overlay a fit line.

    Both variables are subset with ``src.utils.sel_month(..., months=months)``.
    Their ``member`` and ``time`` dimensions are then stacked into a single
    ``sample`` dimension. The scaled y is regressed on x with
    ``src.utils.regress_core``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : xarray Dataset
        Must contain ``x_var`` and ``y_var``. Each needs ``member`` and ``time``
        dims, since both are stacked.
    x_var, y_var : str
        Names of the predictor / response variables in ``data``.
    scale : float, optional
        Multiplicative factor applied to y before fitting and plotting.
    months : optional
        Month selection forwarded unchanged to ``src.utils.sel_month``.

    Returns
    -------
    float
        Slope of (scaled) y regressed on x. The overlay line is drawn as
        ``slope * x``, i.e. through the origin (no intercept is plotted).
    """

    def prep(da):
        ## restrict to the requested months, then pool ensemble members and times
        return src.utils.sel_month(da, months=months).stack(sample=["member", "time"])

    ## evaluate data
    x = prep(data[x_var])
    y = prep(data[y_var]) * scale

    ## best-fit slope, converted to a plain Python scalar
    slope = src.utils.regress_core(X=x, Y=y, dim="sample").values.item()

    ## plot data
    ax.scatter(x, y, s=0.5)

    ## fit line over the data range; nan-aware so a single missing sample
    ## does not turn the whole line into NaN (which would draw nothing)
    xtest = np.linspace(np.nanmin(x.values), np.nanmax(x.values))
    ax.plot(xtest, slope * xtest, c="k", lw=1)

    ## zero reference lines
    ax.axhline(0, ls="--", lw=0.8, c="k")
    ax.axvline(0, ls="--", lw=0.8, c="k")

    return slope

## Load data

In [None]:
## load the consolidated spatial fields: a "forced" dataset and an "anom" dataset
## (presumably forced response vs. anomalies about it -- confirm in src.utils)
(forced, anom) = src.utils.load_consolidated()

### Compute indices

In [None]:
def _box_mean(x, lon_range, lat_range=(-5, 5)):
    """Average ``x`` over a latitude/longitude box (label-based, inclusive bounds)."""
    return x.sel(
        latitude=slice(*lat_range),
        longitude=slice(*lon_range),
    ).mean(["latitude", "longitude"])


def sel_e(x):
    """Eastern box average: latitude -5..5, longitude 210..280."""
    return _box_mean(x, lon_range=(210, 280))


def sel_w(x):
    """Western box average: latitude -5..5, longitude 120..210."""
    return _box_mean(x, lon_range=(120, 210))


## specify funcs and variables: i-th entries give (source variable, index fn, new name)
varnames = ["ssh", "ssh", "taux", "sst", "z20", "z20"]
fns = [sel_e, sel_w, src.utils.get_nino4, src.utils.get_nino34, sel_e, sel_w]
newnames = ["h_e", "h_w", "taux4", "T_34", "h_e_z20", "h_w_z20"]

## zip() would silently drop trailing entries if these lists ever drift apart
if not len(varnames) == len(fns) == len(newnames):
    raise ValueError("varnames, fns and newnames must have the same length")

## compute indices: each variable is passed together with its "_comp" partner
idxs = xr.merge(
    [
        src.utils.reconstruct_wrapper(anom[[v, f"{v}_comp"]], fn=fn).rename({v: n})
        for v, fn, n in zip(varnames, fns, newnames)
    ]
)

## east-minus-west differences (ssh and z20)
idxs["dh"] = idxs["h_e"] - idxs["h_w"]
idxs["dh_z20"] = idxs["h_e_z20"] - idxs["h_w_z20"]

### Split by period

In [None]:
## analysis windows (string year slices include the whole end year)
t_early = {"time": slice("1851", "1880")}
t_late = {"time": slice("2071", "2100")}

## subset the indices to each window
idxs_early, idxs_late = (idxs.sel(window) for window in (t_early, t_late))

### Thermocline in each period

In [None]:
## climatological temperature for each period, rebuilt from the forced fields
T_forced = forced[["T", "T_comp"]]
T_early, T_late = (
    src.utils.reconstruct_clim(data=T_forced.sel(window))["T"]
    for window in (t_early, t_late)
)


def sel(x):
    """Average over longitudes 140..280 (label-based, inclusive)."""
    return x.sel(longitude=slice(140, 280)).mean("longitude")


## thermocline depth metric (get_H_int, thresh=0.08) for each period
H_early, H_late = (
    sel(src.utils.get_H_int(T_clim, thresh=0.08)) for T_clim in (T_early, T_late)
)

## stack the periods along a "year" dim (0 = early, 1 = late), then keep
## the last entry of src.utils.frac_change along that dim
H = xr.concat([H_early, H_late], dim=pd.Index([0, 1], name="year"))
delta_H = src.utils.frac_change(H).isel(year=-1).squeeze(drop=True)

## Analysis

### Scatter plot

In [None]:
## shared args: regress (scaled) dh_z20 on taux4
kwargs = dict(
    scale=1e-2,
    x_var="taux4",
    y_var="dh_z20",
)

## set up plot: rows are months, columns are periods
fig, axs = plt.subplots(2, 2, figsize=(5, 5), layout="constrained")

## plot data and best fit
for j, month in enumerate([12, 5]):
    for i, (label, idxs_) in enumerate(
        zip(["early", "late"], [idxs_early, idxs_late])
    ):
        ax = axs[j, i]

        ## scatter plus fit line over the data range; returns the slope
        m = make_scatter(ax, idxs_, months=month, **kwargs)

        ## extend the fit line across the full x-axis (lighter)
        zz = np.linspace(*ax.get_xlim())
        ax.plot(zz, m * zz, c="k", ls="-", alpha=0.5)

        ## title with the slope; \mathrm renders in every mathtext version,
        ## whereas \text is only supported from Matplotlib 3.7
        ax.set_title(f"{m:.1f}" + r" $m~\mathrm{Pa}^{-1}$")

        ## identify each panel's period (bottom row) and month (left column)
        if j == axs.shape[0] - 1:
            ax.set_xlabel(label)
        if i == 0:
            ax.set_ylabel(f"month {month}")

## formatting: shared limits, ticks only on the outer edges
src.utils.set_lims(axs)
for ax in axs[0, :]:
    ax.set_xticks([])
for ax in axs[:, 1]:
    ax.set_yticks([])

plt.show()

### Line plot of slope by calendar month

In [None]:
## shared args
## (to test the eastern-box ssh alone, swap y_var for "h_e")
kwargs = dict(
    scale=1e-2,
    x_var="taux4",
    y_var="dh_z20",
)

## empty array to hold results; has the same "year" and "month" dims as H
m = xr.zeros_like(H)

## make_scatter always draws, so fit on a throwaway figure
fig, ax = plt.subplots()
for i, month in enumerate(H.month.values):
    for j, idxs_ in enumerate([idxs_early, idxs_late]):
        ## write by dimension name (independent of H's dim order) and pass a
        ## plain int month, as in the scatter-plot cell
        m[{"year": j, "month": i}] = make_scatter(
            ax, idxs_, months=int(month), **kwargs
        )
plt.close(fig)

## get fractional change
delta_m = src.utils.frac_change(m, inv=False).isel(year=-1)

In [None]:
fig, ax = plt.subplots(figsize=(3, 2.5))

## slope in each month for the early (year=0) and late (year=-1) periods
ax.plot(m.month, m.isel(year=0), label="early")
ax.plot(m.month, m.isel(year=-1), label="late")

## early slope rescaled by (1 + delta_H), where delta_H comes from
## src.utils.frac_change(H) -- check that helper's `inv` default for its exact meaning
ax.plot(m.month, m.isel(year=0) * (delta_H + 1), ls="--", label="predicted")

## axis labels (slope units match the scatter-panel titles)
ax.set_xlabel("month")
ax.set_ylabel(r"slope ($m~\mathrm{Pa}^{-1}$)")

## legend
ax.legend(prop=dict(size=8))

plt.show()