# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## set plotting specs (white axes background, no grid)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables
## (os.environ[...] raises KeyError if DATA_FP or SAVE_FP is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Get temperature difference b/n entrainment zone and mixed layer.

    The mixed layer (ML) is every level with ``z_t <= Hm``. The entrainment
    zone is the slab ``Hm < z_t < Hm + delta`` directly below it. Following
    Frankignoul et al., the result is ``Tplus - Tbar``: the entrainment-zone
    mean minus the ML mean. It is positive when the entrainment zone is
    warmer than the ML.

    Args:
        Tsub: array with a ``z_t`` depth coordinate (xarray-style ``where`` /
            ``mean``); depth units must match ``Hm`` and ``delta``.
        Hm: depth of the ML base.
        delta: thickness of the entrainment zone below the ML base.

    Returns:
        Depth-averaged temperature difference, with ``z_t`` reduced away.
    """
    depth = Tsub.z_t

    ## boolean masks selecting each layer
    ml_mask = depth <= Hm
    ez_mask = (depth > Hm) & (depth < (delta + Hm))

    ## layer means: levels outside the mask become NaN before averaging
    mean_over = lambda mask: Tsub.where(mask).mean("z_t")
    T_ml = mean_over(ml_mask)
    T_ez = mean_over(ez_mask)

    return T_ez - T_ml


def prep(data):
    """
    Add SST-independent versions of the ``h`` indices.

    For ``"h_w"`` and then ``"h"``, stores
    ``src.utils.remove_sst_dependence_v2(data, h_var=<name>, T_var="T_34")``
    as a new variable ``"<name>_hat"``. ``data`` is modified in place and
    also returned.
    """
    remove = src.utils.remove_sst_dependence_v2
    for h_name in ("h_w", "h"):
        data[h_name + "_hat"] = remove(data, h_var=h_name, T_var="T_34")
    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load ground-truth heat budget data and split it into forced/anomaly parts.

    Args:
        t_early, t_late: indexers passed to ``.sel`` (e.g.
            ``dict(time=slice("1851", "1880"))``) selecting the two periods
            to keep; the selections are concatenated along ``time``.
        target_grid: object passed to ``interp_like``; the budget data are
            interpolated onto its coordinates.

    Returns:
        ``(forced, anomalies)`` as returned by ``src.utils.separate_forced``.
    """

    def load_var(varname):
        """
        Open every ``*.nc`` file in ``<DATA_FP>/cesm/<varname>_temp`` as one
        dataset, concatenated (in sorted filename order) along a new
        ``member`` dimension.
        """
        files = sorted(pathlib.Path(DATA_FP, "cesm", f"{varname}_temp").glob("*.nc"))
        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1 where N is the number of files actually found,
        ## instead of assuming exactly 100 (assign_coords fails on a mismatch)
        return data.assign_coords({"member": np.arange(data.sizes["member"])})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## get difference (total tendency minus 3D advection)
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## convert (i) z coord from cm to m and (ii) units from K/s to K/mo
    ## (a "month" here is 30 days)
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## fix longitude coordinate: use the "lon" values as the "nlon" index,
    ## then rename it to "longitude"
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## trim in time: keep only the early and late periods
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate onto the target grid
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, sel):
    """
    Overlay mixed-layer depth (``mld``) lines on a row of three axes.

    Panel 0 gets the early MLD (solid), panel 1 the late MLD (dashed), and
    panel 2 both. Each line is the meridional mean (``src.utils.merimean``)
    of ``mld`` after applying ``sel`` (e.g. a month selection), plotted
    against longitude.
    """
    lon = src.utils.merimean(bar_early).longitude

    early_style = dict(c="k")
    late_style = dict(c="k", ls="--")
    panels = [
        (axs[0], bar_early, early_style),
        (axs[1], bar_late, late_style),
        (axs[2], bar_early, early_style),
        (axs[2], bar_late, late_style),
    ]
    for ax, bar, style in panels:
        ax.plot(lon, sel(src.utils.merimean(bar["mld"])), **style)

## Load data

### $T$, $h$

In [None]:
## open index data (T, h and climate-mode indices) via the project helper
Th = src.utils.load_cesm_indices()

## rename indices for convenience (long names -> short aliases)
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total)
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## compute relative sst: total Niño index minus the tropical-mean SST
## ("trop_sst_10"), stored on `Th` as "<index>_rel"
for n in ["T_3", "T_34", "T_4"]:
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## load spatial data: forced component and anomalies (consolidated files)
CONS_DIR = pathlib.Path(DATA_FP, "cesm", "consolidated")
forced = xr.open_dataset(CONS_DIR / "forced.nc")
anom = xr.open_dataset(CONS_DIR / "anom.nc")

## add T,h index time series to the anomaly dataset
for n in ["T_3", "T_34", "T_4", "h", "h_w"]:
    anom[n] = Th[n]

#### compute some indices

In [None]:
## add other variables used in BJ computation:
## for each region (Niño 3, 3.4, 4) reconstruct a regional index of nhf, ssh
## and taux from the stored "<var>_comp" components and "<var>" scores;
## results are stored as "<var>_<region>" (e.g. "taux_4", "ssh_3")
fns = [src.utils.get_nino3, src.utils.get_nino34, src.utils.get_nino4]
labels = ["3", "34", "4"]
for fn, label in zip(fns, labels):
    for n in ["nhf", "ssh", "taux"]:
        kwargs = dict(components=anom[f"{n}_comp"], scores=anom[n], fn=fn)
        anom[f"{n}_{label}"] = src.utils.reconstruct_fn(**kwargs)

## get sst tendency (and convert from 1/yr to 1/mo)
## (assumes get_ddt returns a per-year rate -- dividing by 12 gives per-month)
anom["ddt_sst"] = 1 / 12 * src.utils.get_ddt(anom[["sst"]])["ddt_sst"]

#### early/late split

In [None]:
## split into early/late periods (1851-1880 and 2071-2100)
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## split surface data
anom_early = anom.sel(t_early)
anom_late = anom.sel(t_late)

## preprocess: adds "h_w_hat" and "h_hat" (h indices with SST dependence removed)
anom_early = prep(anom_early)
anom_late = prep(anom_late)

## load to memory (trailing ";" suppresses notebook output)
anom_early.load()
anom_late.load();

## BJ couplings

In [None]:
def get_alpha(data, dim, r="34"):
    """
    Compute alpha: linear dependence of nhf on sst.

    Regresses ``nhf_<r>`` on ``T_<r>`` along ``dim`` with
    ``src.utils.regress_core``.

    Args:
        data: dataset containing ``nhf_<r>`` and ``T_<r>``.
        dim: dimension to regress over.
        r: region suffix ("3", "34" or "4").
    """

    return src.utils.regress_core(Y=data[f"nhf_{r}"], X=data[f"T_{r}"], dim=dim)


def get_mu(data, dim, r="34"):
    """
    Compute mu: linear dependence of taux on sst.

    The response is always ``taux_4`` (independent of ``r``); the predictor
    is ``T_<r>``. The regression runs along ``dim`` via
    ``src.utils.regress_core``.
    """
    response = data["taux_4"]
    predictor = data[f"T_{r}"]
    return src.utils.regress_core(Y=response, X=predictor, dim=dim)


def get_beta(data, dim, r="34"):
    """
    Compute beta: linear dependence of ssh on taux.

    Regresses ``ssh_<r>`` on ``taux_4`` (the predictor does not depend on
    ``r``) along ``dim`` with ``src.utils.regress_core``.
    """

    return src.utils.regress_core(Y=data[f"ssh_{r}"], X=data["taux_4"], dim=dim)


def get_xi(data, dim, r="34"):
    """
    Compute xi: linear dependence of sst on ssh.

    Regresses ``T_<r>`` on ``ssh_<r>`` along ``dim`` with
    ``src.utils.regress_core``.
    """

    return src.utils.regress_core(Y=data[f"T_{r}"], X=data[f"ssh_{r}"], dim=dim)


def get_params(data, dim="time", r="34"):
    """
    Compute all BJ parameters along ``dim``.

    Returns a dataset with ``alpha``, ``mu``, ``beta`` and ``xi`` (see the
    individual getters), plus ``coupling = mu * beta * xi``: the product of
    the SST->taux, taux->ssh and ssh->SST regressions. ``alpha`` is not part
    of that product.
    """
    estimators = [
        ("alpha", get_alpha),
        ("mu", get_mu),
        ("beta", get_beta),
        ("xi", get_xi),
    ]
    params = xr.merge(
        [estimate(data, dim=dim, r=r).rename(name) for name, estimate in estimators]
    )

    params["coupling"] = params["mu"] * params["beta"] * params["xi"]
    return params


def get_rolling_params(data, n=10, reduce_ensemble=True):
    """
    Compute BJ parameters in a centered rolling window along ``time``.

    Args:
        data: dataset of index variables with ``time`` (and, when
            ``reduce_ensemble`` is True, ``member``) dimensions.
        n: window half-width; each window spans ``2 * n + 1`` time steps.
        reduce_ensemble: if True, pool members and window positions into a
            single ``sample`` dimension; otherwise regress over the window
            only (one estimate per member).
    """
    window_len = 2 * n + 1
    windows = data.rolling({"time": window_len}, center=True).construct("window")

    if not reduce_ensemble:
        samples = windows.rename({"window": "sample"})
    else:
        samples = windows.stack(sample=["member", "window"])

    return get_params(samples, dim="sample")


def get_rolling_params_bymonth(data, **kwargs):
    """
    Apply ``get_rolling_params`` to each calendar month separately.

    ``kwargs`` are forwarded to ``get_rolling_params``. Because each group
    holds one sample per year, the rolling window is counted in years.
    """
    by_month = data.groupby("time.month")
    return by_month.map(get_rolling_params, **kwargs)

In [None]:
## get names of 'index' variables (scalars): names ending in a digit,
## e.g. "T_34", "nhf_3", "taux_4"
idx_vars = [n for n in list(anom) if n[-1].isdigit()]

## get params over time: per calendar month, rolling window of 2*16+1 = 33 years
params = get_rolling_params_bymonth(anom[idx_vars], n=16)

## unstack month and year to separate dims
params = src.utils.unstack_month_and_year(params)

## subtract off reference state: mean over the first 30 years
## (not a whole-record climatology)
delta_params = params - params.isel(year=slice(None, 30)).mean("year")

In [None]:
## setup plot: one panel per BJ parameter change
fig, axs = plt.subplots(1, 5, figsize=(7, 3), layout="constrained")

## plot change in each parameter relative to the first 30 years
src.utils.plot_hov2(
    fig, axs[0], delta_params["alpha"].T, amp=20, label=r"$\Delta~ \alpha$"
)
src.utils.plot_hov2(
    fig, axs[1], delta_params["mu"].T, amp=0.01, label=r"$\tau_x-\text{SST}$"
)
src.utils.plot_hov2(
    fig, axs[2], delta_params["beta"].T, amp=200, label=r"$\text{SSH}-\tau_x$"
)
src.utils.plot_hov2(
    fig, axs[3], delta_params["xi"].T, amp=7e-2, label=r"$\text{SST}-\text{SSH}$"
)
src.utils.plot_hov2(fig, axs[4], delta_params["coupling"].T, amp=0.5, label=r"coupling")

## dashed reference line at month 7
for ax in axs:
    ax.axvline(7, c="k", lw=1, ls="--")

## label
axs[0].set_yticks(np.linspace(1870, 2082, 5))
axs[0].set_ylabel("Year")

## share the year range of the first panel (which set_yticks may have
## expanded) with every other panel, not just the second one
for ax in axs[1:]:
    ax.set_ylim(axs[0].get_ylim())
plt.show()

## Compute BJ coefficients

#### Components of BJ feedback

Function

In [None]:
def get_bj_coupling(data):
    """
    Get couplings needed to compute the Bjerknes feedback.

    Regresses each response on its predictor, by calendar month, with
    ``src.utils.regress_xr_bymonth``:

        predictor  response  output variable
        T_34       taux      "T_34-taux"
        taux_4     ssh       "taux_4-ssh"
        ssh_3      sst       "ssh_3-sst"
        taux_4     T         "taux_4-T"

    Returns:
        Dataset with one coefficient variable per pair, named
        ``"<x_var>-<y_var>"`` (singleton dims squeezed out).
    """

    ## empty dataset to hold results
    couplings = xr.Dataset()

    ## (response, predictor) pairs; a list (unlike a bare zip) has a length,
    ## so tqdm can show progress against the total
    pairs = [
        ("taux", "T_34"),
        ("ssh", "taux_4"),
        ("sst", "ssh_3"),
        ("T", "taux_4"),
    ]

    for y_var, x_var in tqdm.tqdm(pairs):
        kwargs = dict(data=data, y_vars=[y_var], x_vars=[x_var])
        coefs = src.utils.regress_xr_bymonth(**kwargs)[y_var].squeeze(drop=True)
        couplings[f"{x_var}-{y_var}"] = coefs

    return couplings

Compute

In [None]:
## BJ coupling regressions, computed separately for each period
m_early_bj = get_bj_coupling(anom_early)
m_late_bj = get_bj_coupling(anom_late)

#### Regression on $T_{34}$ and $\hat{h}_w$

helper function

In [None]:
def fit_wrapper(data, y_vars, x_vars=None):
    """
    Fit a by-month linear regression and return the first predictor's coefficient.

    Args:
        data: dataset containing ``y_vars`` and ``x_vars``.
        y_vars: names of response variables.
        x_vars: names of predictors; defaults to ``["T_34", "h_w_hat"]``.
            Only the coefficient of ``x_vars[0]`` is returned (the others act
            as controls).

    Returns:
        Output of ``src.utils.regress_xr_bymonth`` selected at ``j=x_vars[0]``.
    """
    ## avoid a shared mutable default list
    if x_vars is None:
        x_vars = ["T_34", "h_w_hat"]
    x_vars = list(x_vars)

    ## get coeffs
    kwargs = dict(y_vars=y_vars, x_vars=x_vars)
    coefs = src.utils.regress_xr_bymonth(data, **kwargs)

    ## get first coefficient
    return coefs.sel(j=x_vars[0])

Compute

In [None]:
## keep track of time
t0 = time.time()

## do regression: each y_var on [T_34, h_w_hat] by month; fit_wrapper keeps
## only the T_34 coefficient
fit_kwargs = dict(
    y_vars=["ddt_sst", "sst", "nhf", "taux", "T", "w", "u", "uvel", "vvel"],
    x_vars=["T_34", "h_w_hat"],
)
m_early = fit_wrapper(anom_early, **fit_kwargs)
m_late = fit_wrapper(anom_late, **fit_kwargs)

## print out elapsed time
t1 = time.time()
print(f"Elapsed time: {t1-t0:.1f} seconds")

#### compute feedback terms

##### Subsurface

In [None]:
## get mean-state for each period: reconstruct the forced climatology of the
## listed variables (and their "_comp" companions) in each time window
sub_vars = ["u", "w", "T", "sst", "uvel", "vvel", "mld"]
sub_vars += [f"{s}_comp" for s in sub_vars]
bar_early = src.utils.reconstruct_clim(forced[sub_vars].sel(t_early))
bar_late = src.utils.reconstruct_clim(forced[sub_vars].sel(t_late))

## get early/late feedbacks (mean state "bar" x regression pattern "prime")
feedbacks_early = src.utils.get_feedbacks(bar=bar_early, prime=m_early)
feedbacks_late = src.utils.get_feedbacks(bar=bar_late, prime=m_late)

## decompose changes (late vs early) into component terms
feedback_changes = src.utils.decompose_feedback_changes(
    bar_early=bar_early,
    bar_late=bar_late,
    prime_early=m_early,
    prime_late=m_late,
)

## merge with other data (the change decomposition is stored on m_late only)
m_early = xr.merge([m_early, feedbacks_early])
m_late = xr.merge([m_late, feedbacks_late, feedback_changes])

Compute quantities that *do* depend on the mixed-layer depth (MLD)

In [None]:
## specify mixed layer kwargs (identical for both periods here).
## H0 is reused below as the ML base depth for get_dT_sub; the meaning of
## Hm=None is defined in src.utils.get_ml_avg_ds -- TODO confirm
ml_early_kwargs = dict(H0=70, Hm=None)
ml_late_kwargs = dict(H0=70, Hm=None)

## Integrate over mixed layer
m_early = src.utils.get_ml_avg_ds(m_early, **ml_early_kwargs)
m_late = src.utils.get_ml_avg_ds(m_late, **ml_late_kwargs)

##### Surface (advection)

In [None]:
## compute surface feedbacks from each period's mean state and regressions
feedbacks_surf_early = src.utils.get_surface_feedbacks(bar_early, m_early)
feedbacks_surf_late = src.utils.get_surface_feedbacks(bar_late, m_late)

## add to array
m_early = xr.merge([m_early, feedbacks_surf_early])
m_late = xr.merge([m_late, feedbacks_surf_late])

#### Compute $T_{sub}$

In [None]:
## height of entrainment zone (from base of ML)
ez_height = 20

## set kwargs: the ML base is the fixed depth H0 used for the ML average
dT_kwargs_early = dict(Hm=ml_early_kwargs["H0"], delta=ez_height)
dT_kwargs_late = dict(Hm=ml_late_kwargs["H0"], delta=ez_height)

## T_sub - T_ml (entrainment-zone mean minus ML mean; positive when the
## entrainment zone is warmer), from the regressions on Niño 3.4 and on taux_4.
## Each period uses its own kwargs.
m_early["dT_n34"] = get_dT_sub(Tsub=m_early["T"], **dT_kwargs_early)
m_late["dT_n34"] = get_dT_sub(Tsub=m_late["T"], **dT_kwargs_late)
m_early_bj["dT_taux"] = get_dT_sub(Tsub=m_early_bj["taux_4-T"], **dT_kwargs_early)
m_late_bj["dT_taux"] = get_dT_sub(Tsub=m_late_bj["taux_4-T"], **dT_kwargs_late)

## Plot feedback Hovmöller diagrams
For example, the thermocline feedback ($\overline{w}~\frac{\partial T'}{\partial z}$) and the Ekman feedback ($w'~\frac{\partial \overline{T}}{\partial z}$).

### Plot mixed layer integral

In [None]:
## specify plot amplitude
amp = 1

## for each mixed-layer-averaged feedback term: early, late, and
## late-minus-early panels
for n in ["THF_ml", "EKM_ml", "ADV_ml"]:

    print(f"\n{n}")
    fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

    ## plot data (difference panel uses a narrower range, 0.75 * amp)
    cp0 = src.utils.make_cycle_hov(axs[0], data=m_early[n], amp=amp)
    cp1 = src.utils.make_cycle_hov(axs[1], data=m_late[n], amp=amp)
    cp2 = src.utils.make_cycle_hov(axs[2], data=(m_late - m_early)[n], amp=amp * 0.75)

    ## make it look nicer
    ## NOTE(review): the only colorbar is built from cp0 (range +/- amp), so
    ## it does not reflect the 0.75 * amp range of the difference panel
    cb = fig.colorbar(
        cp0,
        ax=axs[2],
        ticks=[-amp, 0, amp],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )
    src.utils.format_hov_axs(axs)
    ## dashed reference line at month 7
    for ax in axs:
        ax.axhline(7, ls="--", c="k", lw=1)

    plt.show()

### Decomp. Ekman feedback

In [None]:
## specify colorbar amp
amp = 0.5

## specify plot data and titles: total change of the ML-averaged Ekman term,
## then its mean-state / anomaly / nonlinear parts from
## decompose_feedback_changes (stored on m_late)
plot_data = [
    (m_late - m_early)["EKM_ml"],
    m_late["delta_EKM_mean_ml"],
    m_late["delta_EKM_anom_ml"],
    m_late["delta_EKM_nl_ml"],
]
titles = [
    r"$\Delta\left(w~\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$w_0'~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$\Delta\left(w'\right)~\frac{\partial \overline{T}_0}{\partial z}$",
    r"$\Delta\left(w'\right)~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs)

## plot data (all panels share the same amp)
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## remove y ticks on the last panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer (colorbar from the last panel)
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Decomp. thermocline feedback

In [None]:
## specify colorbar amp
amp = 0.5

## specify plot data and titles: total change of the ML-averaged thermocline
## term, then its mean-state / anomaly / nonlinear parts (stored on m_late).
## In the titles, a prime marks the anomalous (regressed) field and subscript
## 0 the early-period value, so the early anomalous gradient is T_0'.
plot_data = [
    (m_late - m_early)["THF_ml"],
    m_late["delta_THF_mean_ml"],
    m_late["delta_THF_anom_ml"],
    m_late["delta_THF_nl_ml"],
]
titles = [
    r"$\Delta\left(\overline{w}~\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\frac{\partial T_0'}{\partial z}$",
    r"$\overline{w}_0~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs)

## plot data (all panels share the same amp)
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## remove y ticks on the last panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer (colorbar from the last panel)
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Decomp. ZAF feedback

In [None]:
## specify colorbar amp
amp = 0.25

## specify plot data and titles: total change of the ML-averaged zonal
## advective (ZAF) term, then its mean-state / anomaly / nonlinear parts
## (stored on m_late)
plot_data = [
    (m_late - m_early)["ZAF_ml"],
    m_late["delta_ZAF_mean_ml"],
    m_late["delta_ZAF_anom_ml"],
    m_late["delta_ZAF_nl_ml"],
]
titles = [
    r"$\Delta\left(u'~\frac{\partial \overline{T}}{\partial x}\right)$",
    r"$u_0'~\Delta\left(\frac{\partial \overline{T}}{\partial x}\right)$",
    r"$\Delta\left(u'\right)~\frac{\partial \overline{T}_0}{\partial x}$",
    r"$\Delta\left(u'\right)~\Delta\left(\frac{\partial \overline{T}}{\partial x}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs)

## plot data (all panels share the same amp)
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## remove y ticks on the last panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer (colorbar from the last panel)
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Decomp. DD feedback

In [None]:
## specify colorbar amp
amp = 0.25

## specify plot data and titles: total change of the ML-averaged DD term,
## then its mean-state / anomaly / nonlinear parts (stored on m_late).
## In the titles, a prime marks the anomalous (regressed) field and subscript
## 0 the early-period value, so the early anomalous gradient is T_0'.
plot_data = [
    (m_late - m_early)["DD_ml"],
    m_late["delta_DD_mean_ml"],
    m_late["delta_DD_anom_ml"],
    m_late["delta_DD_nl_ml"],
]
titles = [
    r"$\Delta\left(\overline{u}~\frac{\partial T'}{\partial x}\right)$",
    r"$\Delta\left(\overline{u}\right)~\frac{\partial T_0'}{\partial x}$",
    r"$\overline{u}_0~\Delta\left(\frac{\partial T'}{\partial x}\right)$",
    r"$\Delta\left(\overline{u}\right)~\Delta\left(\frac{\partial T'}{\partial x}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs)

## plot data (all panels share the same amp)
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## remove y ticks on the last panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer (colorbar from the last panel)
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Ground truth budget

In [None]:
## should we compute/plot this? (time-consuming)
## If True, later cells also define m_early_bud / m_late_bud
PLOT_BUDGET = False

Load data

In [None]:
if PLOT_BUDGET:

    ## load budget data, interpolated onto the anomaly grid (longitude, z_t)
    budg_forced, budg_anom = load_budget_data(
        t_early=t_early,
        t_late=t_late,
        target_grid=anom[["longitude", "z_t"]],
    )

    ## split into early/late and compute (materialize) each period
    budg_early = budg_anom.sel(t_early).compute()
    budg_late = budg_anom.sel(t_late).compute()

    ## add T,h predictors so the budget terms can be regressed on them
    for n in ["T_34", "h_w_hat"]:
        budg_early[n] = anom_early[n]
        budg_late[n] = anom_late[n]

Compute regression coefficients

In [None]:
if PLOT_BUDGET:

    ## get datasets to hold budget coefficients
    m_early_bud = xr.Dataset()
    m_late_bud = xr.Dataset()

    ## get ground truth tendencies
    for n in ["TEND_TEMP", "ADV_3D_TEMP"]:

        ## regression coefficients on [T_34, h_w_hat]; keep the T_34 coefficient
        kwargs = dict(x_vars=["T_34", "h_w_hat"], y_var=n)
        m_early_bud[n] = src.utils.regress_bymonth(budg_early, **kwargs)["T_34"]
        m_late_bud[n] = src.utils.regress_bymonth(budg_late, **kwargs)["T_34"]

        ## mixed layer avg (same ML settings as the reconstructed terms)
        m_early_bud[f"{n}_ml"] = src.utils.get_ml_avg(m_early_bud[n], **ml_early_kwargs)
        m_late_bud[f"{n}_ml"] = src.utils.get_ml_avg(m_late_bud[n], **ml_late_kwargs)

    ## get d/dt(SST): total tendency at the first z_t level
    m_early_bud["ddt_sst"] = m_early_bud["TEND_TEMP"].isel(z_t=0)
    m_late_bud["ddt_sst"] = m_late_bud["TEND_TEMP"].isel(z_t=0)

Plot

In [None]:
if PLOT_BUDGET:

    amp = 1

    ## early, late and late-minus-early panels for each ground-truth term
    for n in ["ADV_3D_TEMP_ml", "ddt_sst"]:

        print(f"\n{n}")
        fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

        ## plot data (difference panel uses 0.75 * amp)
        cp0 = src.utils.make_cycle_hov(axs[0], data=m_early_bud[n], amp=amp)
        cp1 = src.utils.make_cycle_hov(axs[1], data=m_late_bud[n], amp=amp)
        cp2 = src.utils.make_cycle_hov(
            axs[2], data=(m_late_bud - m_early_bud)[n], amp=amp * 0.75
        )

        ## make it look nicer (single colorbar built from cp0, range +/- amp)
        cb = fig.colorbar(
            cp0,
            ax=axs[2],
            ticks=[-amp, 0, amp],
            label=r"$K~\left(\text{month}\right)^{-1}$",
        )
        src.utils.format_hov_axs(axs)
        for ax in axs:
            ax.axhline(7, ls="--", c="k", lw=1)

        plt.show()

## Plot BJ couplings

### $R$

##### Hovmoller

In [None]:
## shared args (same amp for all three panels)
kwargs = dict(amp=0.5, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: regression of ddt_sst on T_34 for early, late and late-minus-early
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early["ddt_sst"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late["ddt_sst"], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=(m_late - m_early)["ddt_sst"], **kwargs)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

##### Spatial plot

In [None]:
## select month (7)
sel = lambda x: x.sel(month=7)

## set up plot: early, late, and late-minus-early maps
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

contour_kwargs = dict(amp=0.5, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early["ddt_sst"], **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late["ddt_sst"], **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (m_late - m_early)["ddt_sst"], **contour_kwargs
)

## colorbar: ddt_sst was converted to a per-month tendency when computed,
## so the regression on T_34 is in K per month per K of T_34
amp = contour_kwargs["amp"]
cb_kwargs = dict(
    label=r"$K ~\text{mo}^{-1}~ \left(T_{34}\right)^{-1}$", ticks=[-amp, 0, amp]
)
cb0 = fig.colorbar(cp0, ax=axs[:2], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **cb_kwargs)

## Niño 3.4 box
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

##### Scatter plot

In [None]:
## kwargs for plotting: month 7, ddt_sst (Niño 3.4 average) vs T_34;
## make_scatter2 returns a value shown in the titles (presumably the fitted slope)
kwargs = dict(months=7, x_var="T_34", y_var="ddt_sst", fn_y=src.utils.get_nino34)

# kwargs = dict(
#     months=7, x_var="sst", y_var="ddt_sst", fn_x=src.utils.get_nino4, fn_y=src.utils.get_nino4
# )

## set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
axs[0].set_title(f"{m0:.2f}" + r" $mo^{-1}$")
axs[1].set_title(f"{m1:.2f}" + r" $mo^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### Surface feedbacks

In [None]:
## specify which feedback to plot (a variable name in m_early / m_late;
## also used by the next cell)
PLOTVAR = "ADV_"

## shared args
kwargs = dict(amp=0.5, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, late-minus-early
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early[PLOTVAR], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late[PLOTVAR], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=(m_late - m_early)[PLOTVAR], **kwargs)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

In [None]:
## select month (7)
sel = lambda x: x.sel(month=7)

## set up plot: early, late, and late-minus-early maps of PLOTVAR
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

contour_kwargs = dict(amp=0.5, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early[PLOTVAR], **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late[PLOTVAR], **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (m_late - m_early)[PLOTVAR], **contour_kwargs
)

## colorbar: per-month units, matching the Hovmöller of the same variable
## (K month^-1 per K of T_34)
amp = contour_kwargs["amp"]
cb_kwargs = dict(
    label=r"$K ~\text{mo}^{-1}~ \left(T_{34}\right)^{-1}$", ticks=[-amp, 0, amp]
)
cb0 = fig.colorbar(cp0, ax=axs[:2], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **cb_kwargs)

## Niño 3.4 box
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

### Niño 3.4 - SST

In [None]:
## shared args
kwargs = dict(amp=1.5, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: SST regressed on T_34; the difference panel shows
## 2 x (late - early) on the same color range
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early["sst"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late["sst"], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early)["sst"], **kwargs)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~K^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

### SST - NHF

In [None]:
## shared args
kwargs = dict(amp=40, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: nhf regressed on T_34; the difference panel shows
## 2 x (late - early) on the same color range
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early["nhf"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late["nhf"], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early)["nhf"], **kwargs)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$W~m^{-2}~K^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

### SST-$\tau_x$

##### Hovmoller

In [None]:
## shared args
kwargs = dict(amp=0.02, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: taux regressed on T_34; the difference panel shows
## 2 x (late - early) on the same color range
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early["taux"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late["taux"], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early)["taux"], **kwargs)

## make it look nicer: wind stress per K of T_34 (same units as the
## spatial taux plot), not a tendency
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$Pa~K^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

##### Spatial plot

In [None]:
## select month (6)
sel = lambda x: x.sel(month=6)

## set up plot: early, late, and late-minus-early maps
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## difference panel uses half the amplitude (1.5e-2) of the period panels
contour_kwargs = dict(amp=3e-2, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early["taux"], **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late["taux"], **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (m_late - m_early)["taux"], **dict(contour_kwargs, amp=1.5e-2)
)

## colorbar (ticks match each panel group's amplitude)
lab = r"$Pa~K^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-3e-2, 0, 3e-2], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-1.5e-2, 0, 1.5e-2], label=lab)

## Niño 3.4 box
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

In [None]:
## month-6 change in taux regression, averaged over latitude -5..5, vs longitude
sel = lambda x: x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")

fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(m_early["taux"].longitude, sel(m_late["taux"] - m_early["taux"]))
ax.axhline(0, ls="--", c="k", lw=0.8)
ax.set_xlim([140, 280])

In [None]:
## kwargs for plotting: month 7, taux (Niño 3.4 average) vs T_34,
## scaled by 1e3 so the titles read in mPa per K
kwargs = dict(
    months=7, scale=1e3, x_var="T_34", y_var="taux", fn_y=src.utils.get_nino34
)

## set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $mPa~K^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $mPa~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $\tau_x$ - SSH

##### Hovmoller

In [None]:
## shared args
kwargs = dict(amp=800, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: ssh regressed on taux_4 (from get_bj_coupling); the difference
## panel shows 2 x (late - early) on the same color range
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early_bj["taux_4-ssh"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late_bj["taux_4-ssh"], **kwargs)
cp2 = src.utils.make_cycle_hov(
    axs[2], data=2 * (m_late_bj - m_early_bj)["taux_4-ssh"], **kwargs
)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$cm~\left(\text{Pa}\right)^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

##### Scatter

In [None]:
## kwargs for plotting: month 7, ssh_3 vs taux_4, scaled by 1e-2
## (ssh presumably in cm, so the titles read in m per Pa -- TODO confirm)
kwargs = dict(
    months=7,
    scale=1e-2,
    x_var="taux_4",
    y_var="ssh_3",
)

## set up plot
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label
axs[0].set_title(f"{m0:.1f}" + r" $m~\text{Pa}^{-1}$")
axs[1].set_title(f"{m1:.1f}" + r" $m~\text{Pa}^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $T_{sub}-$Niño3.4

In [None]:
## shared args
kwargs = dict(amp=2, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: dT_n34 (entrainment-zone minus ML temperature, from T regressed
## on T_34); the difference panel shows 2 x (late - early)
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early["dT_n34"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late["dT_n34"], **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early)["dT_n34"], **kwargs)

## make it look nicer: a temperature difference per K of T_34, not a tendency
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~K^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

### $T_{sub}-\tau_x$

In [None]:
## specify which month to plot (7)
sel = lambda x: x.sel(month=7)

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

## early and late: subsurface T regressed on taux_4 (longitude x depth)
for ax, m in zip(axs[:2], [m_early_bj, m_late_bj]):

    cp = ax.contourf(
        m.longitude,
        m.z_t,
        sel(m["taux_4-T"]),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(200, 20),
        extend="both",
    )

## difference (late minus early) on a narrower +/-100 range; coordinates are
## taken explicitly from m_early_bj rather than from the loop variable above
axs[2].contourf(
    m_early_bj.longitude,
    m_early_bj.z_t,
    sel(m_late_bj - m_early_bj)["taux_4-T"],
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(100, 10),
    extend="both",
)
## plot MLD
# plot_mlds(axs=axs, bar_early=bar_early, bar_late=bar_late, sel=sel)

## label
## NOTE: the colorbar is built from the early/late panels (+/-200 levels);
## the difference panel uses +/-100 levels
cb = fig.colorbar(cp, ax=axs[2], ticks=[-200, 0, 200], label=r"$K~\text{Pa}^{-1}$")
src.utils.format_subsurf_axs(axs)
for ax in axs:
    ax.set_ylim([100, 5])
    ax.set_xlim([140, 280])
    ## Niño 3.4 longitude bounds (190E, 240E)
    ax.axvline(190, ls="--", c="w", lw=0.8)
    ax.axvline(240, ls="--", c="w", lw=0.8)

plt.show()

In [None]:
## shared args
kwargs = dict(amp=150, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: dT_taux (entrainment-zone minus ML temperature, from T regressed
## on taux_4) for early, late and late-minus-early (difference not rescaled)
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early_bj["dT_taux"], **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late_bj["dT_taux"], **kwargs)
cp2 = src.utils.make_cycle_hov(
    axs[2], data=(m_late_bj - m_early_bj)["dT_taux"], **kwargs
)

## make it look nicer
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~Pa^{-1}$",
)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

### Subsurface terms

#### Budget terms

In [None]:
## specify vars to plot
if PLOT_BUDGET:
    plot_vars = ["THF", "EKM", "ZAF", "ADV", "ADV_3D_TEMP"]
else:
    plot_vars = ["THF", "EKM", "ZAF", "ADV"]

## ground-truth budget terms live in m_early_bud / m_late_bud (only defined
## when PLOT_BUDGET is True), not in m_early / m_late
budget_vars = {"ADV_3D_TEMP"}

## specify which month to plot (8); loop-invariant, so defined once
sel = lambda x: x.sel(month=8)

for n in plot_vars:

    print(f"\n\n{n}")

    ## pick the datasets that actually contain this variable
    if n in budget_vars:
        data_early, data_late = m_early_bud, m_late_bud
    else:
        data_early, data_late = m_early, m_late

    fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

    ## early and late sections of the term (longitude x depth)
    for ax, m in zip(axs[:2], [data_early, data_late]):

        cp = ax.contourf(
            m.longitude,
            m.z_t,
            sel(m)[n],
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(2, 0.2),
            extend="both",
        )

    ## difference (late minus early) on a narrower +/-1 range
    axs[2].contourf(
        data_early.longitude,
        data_early.z_t,
        sel(data_late - data_early)[n],
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(1, 0.1),
        extend="both",
    )

    ## plot MLD
    plot_mlds(axs, bar_early=bar_early, bar_late=bar_late, sel=sel)

    ## colorbar (built from the early/late panels, +/-2 levels)
    cb = fig.colorbar(
        cp, ax=axs[2], ticks=[-2, 0, 2], label=r"$K~\left(\text{month}\right)^{-1}$"
    )

    ## set ax limits and plot Niño 3.4 longitude bounds (190E, 240E)
    src.utils.format_subsurf_axs(axs)
    for ax in axs:
        ax.set_xlim([145, 280])
        ax.set_ylim([100, 5])
        ax.axvline(190, ls="--", c="w", lw=0.8)
        ax.axvline(240, ls="--", c="w", lw=0.8)

    plt.show()

#### Ekman / ZAF

In [None]:
## specify which period/month to plot
sel = lambda x: x.sel(month=7)

## get aspect ratio (for scaling arrows): w is multiplied by dx/dz so arrows
## follow the plotted depth range (~300 m) vs longitude span (~150 deg)
dz = 300  # units: m
dx = src.utils.get_dx(lat_deg=0, dlon_deg=150)
aspect = dx / dz

fig, axs = plt.subplots(1, 3, figsize=(12, 3.5), layout="constrained")

for ax, bar, prime in zip(axs[:2], [bar_early, bar_late], [m_early, m_late]):

    ## mean-state temperature
    ax.contourf(
        bar.longitude,
        bar.z_t,
        sel(bar["T"]),
        cmap="cmo.thermal",
        levels=np.arange(10, 34, 2),
        extend="both",
    )

    ## regressed (anomalous) u and w, thinned to every 2nd depth level and
    ## every 4th longitude; the [::2, ::4] slicing assumes the arrays are
    ## ordered (z_t, longitude) -- TODO confirm
    ax.quiver(
        prime.longitude.values[::4],
        prime.z_t.values[::2],
        sel(prime["u"]).values[::2, ::4],
        sel(prime["w"]).values[::2, ::4] * aspect,
        pivot="middle",
        alpha=0.7,
        scale=1.5e7,
    )


## plot difference
delta_bar = bar_late - bar_early

## change in mean-state temperature
axs[2].contourf(
    delta_bar.longitude,
    delta_bar.z_t,
    sel(delta_bar["T"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

# ## vertical velocity
# ## u and w
# NOTE: arrows here are the EARLY-period anomalous (u, w), not a difference
axs[2].quiver(
    m_early.longitude.values[::4],
    m_early.z_t.values[::2],
    sel(m_early["u"]).values[::2, ::4],
    sel(m_early["w"]).values[::2, ::4] * aspect,
    pivot="middle",
    alpha=0.7,
    scale=1.5e7,
)

## label
plot_mlds(axs=axs, bar_early=bar_early, bar_late=bar_late, sel=sel)
src.utils.format_subsurf_axs(axs)
for ax in axs:
    ax.set_ylim([305, 10])
    ax.set_xlim([140, 279])


plt.show()