# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import cartopy.crs as ccrs

# Import custom modules
import src.utils

## plotting style: white axes background, grid lines off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## render figures at 100 DPI
mpl.rcParams.update({"figure.dpi": 100})

## root data / output directories, read from environment variables
## (a missing variable raises KeyError, DATA_FP checked first)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Temperature jump across the base of the mixed layer.

    Returns T_plus - T_bar (following Frankignoul et al.), where T_bar is the
    mean of ``Tsub`` over depths ``z_t <= Hm`` and T_plus is its mean over the
    entrainment zone ``Hm < z_t < Hm + delta``. The result is positive when the
    entrainment zone is warmer than the mixed layer.
    """
    depth = Tsub.z_t

    ## masks for the mixed layer and the entrainment zone just below it
    masks = {
        "ml": depth <= Hm,
        "ez": (depth > Hm) & (depth < Hm + delta),
    }

    ## vertical mean of Tsub within each layer
    layer_mean = {name: Tsub.where(mask).mean("z_t") for name, mask in masks.items()}

    return layer_mean["ez"] - layer_mean["ml"]


def prep(data, h_vars=("h_w", "h", "h_w_z20"), T_var="T_34"):
    """
    Add SST-independent versions of thermocline-depth indices to ``data``.

    For each name in ``h_vars`` this stores
    ``src.utils.remove_sst_dependence_v2(data, h_var=name, T_var=T_var)``
    as a new variable ``f"{name}_hat"``. (No tendencies are computed here.)

    Args:
        data: container supporting item assignment (presumably an xr.Dataset)
            that holds every variable in ``h_vars`` and ``T_var``.
        h_vars: names of the h indices to process.
        T_var: name of the SST index whose influence is removed.

    Returns:
        ``data`` itself, with the ``*_hat`` variables added in place.
    """

    ## remove SST dependence from each h index
    for h_var in h_vars:
        data[f"{h_var}_hat"] = src.utils.remove_sst_dependence_v2(
            data, h_var=h_var, T_var=T_var
        )

    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load CESM heat-budget terms and split them into forced/anomaly parts.

    Steps:
    1. Open the "adv" and "ddt_T" ensembles and form the residual
       ``diff = TEND_TEMP - ADV_3D_TEMP``.
    2. Rename ``nlon`` to ``longitude``.
    3. Convert depth from cm to m and rates from K/s to K/month (30-day month).
    4. Keep only the early and late periods and interpolate onto ``target_grid``.
    5. Separate the forced part from the anomalies.

    Args:
        t_early, t_late: indexers passed positionally to ``Dataset.sel`` to pick
            the early / late periods; the two selections are concatenated
            along "time" (presumably ``{"time": slice(...)}`` — confirm).
        target_grid: xarray object passed to ``interp_like``.

    Returns:
        ``(forced_bud, anom_bud)`` as returned by ``src.utils.separate_forced``.

    Raises:
        FileNotFoundError: if a variable folder contains no ``*.nc`` files.
    """

    def load_var(varname):
        """Open every ``*.nc`` file in ``DATA_FP/cesm/{varname}_temp``, concatenated along "member"."""

        folder = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")
        files = sorted(folder.glob("*.nc"))
        if not files:
            raise FileNotFoundError(f"no .nc files found in {folder}")

        ## files are concatenated along "member" in sorted-filename order
        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1, with N taken from the data instead of hard-coded
        return data.assign_coords({"member": np.arange(data.sizes["member"])})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## residual: total tendency minus 3D advective tendency
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## fix longitude coordinate; done BEFORE unit scaling so that a "lon" stored
    ## as a data variable is not multiplied by SEC_PER_MO
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## unit conversion factors
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30  # seconds in a 30-day month

    ## convert (i) the z coord from cm to m and (ii) data variables from K/s to K/mo
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## keep only the early and late periods
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate onto the target grid
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, month=None):
    """
    Plot meridional-mean mixed layer depth against longitude on three axes.

    The axes are used as follows:
    - ``axs[0]``: early period (solid black).
    - ``axs[1]``: late period (dashed black).
    - ``axs[2]``: both periods overlaid.

    Args:
        axs: sequence of (at least) three matplotlib Axes.
        bar_early, bar_late: datasets with an "mld" variable having a "month"
            dimension; each is meridionally averaged with ``src.utils.merimean``.
        month: month to plot; if None, the mean over "month" is plotted.

    Returns:
        None
    """

    def sel(x):
        """Pick one month, or average over months when ``month`` is None."""
        return x.mean("month") if month is None else x.sel(month=month)

    ## meridional mean of MLD only (not the whole dataset), computed once per period
    mld_early_mm = src.utils.merimean(bar_early["mld"])
    mld_late_mm = src.utils.merimean(bar_late["mld"])

    lon = mld_early_mm.longitude
    mld_early = sel(mld_early_mm)
    mld_late = sel(mld_late_mm)

    ## plot
    axs[0].plot(lon, mld_early, c="k")
    axs[1].plot(lon, mld_late, c="k", ls="--")
    axs[2].plot(lon, mld_early, c="k")
    axs[2].plot(lon, mld_late, c="k", ls="--")


def _ensure_latitude_dim(x):
    """Return ``x`` with a length-1 "latitude" dimension added if it has none."""

    if "latitude" not in x.dims:
        x = x.expand_dims("latitude")

    return x


def get_nino3_da(x):
    """Get the Niño-3 index of a DataArray (adds a latitude dim if missing)."""

    return src.utils.get_nino3(_ensure_latitude_dim(x))


def get_nino34_da(x):
    """Get the Niño-3.4 index of a DataArray (adds a latitude dim if missing)."""

    return src.utils.get_nino34(_ensure_latitude_dim(x))


def get_nino4_da(x):
    """Get the Niño-4 index of a DataArray (adds a latitude dim if missing)."""

    return src.utils.get_nino4(_ensure_latitude_dim(x))

## Load data

### $T$, $h$

In [None]:
## Load CESM index time series (with z20-based thermocline indices), shorten
## their names, and derive "relative" SST indices.

## open data
Th = src.utils.load_cesm_indices(load_z20=True)

## rename indices for convenience (short names)
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total)
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## compute relative sst: index from Th_total (not Th) minus the tropical-mean
## SST "trop_sst_10" (presumably a 10°-band average — TODO confirm)
for n in ["T_3", "T_34", "T_4"]:
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## load spatial data (presumably forced response and anomalies, per the names)
forced, anom = src.utils.load_consolidated()

## add normalized Th data: each index divided by its std over all dimensions
anom = xr.merge([anom, Th / Th.std()])

In [None]:
## window the forced mixed layer depth (stride=120; see src.utils.get_windowed)
## and reconstruct its monthly climatology for each window
mld = src.utils.get_windowed(forced[["mld", "mld_comp"]], stride=120)
mld_clim = src.utils.reconstruct_clim(mld)["mld"]

In [None]:
## annual-mean MLD averaged over the Niño-3.4 box (5°S–5°N, 190–240°E), per window
mld_over_time = (
    mld_clim.mean("month")
    .sel(latitude=slice(-5, 5), longitude=slice(190, 240))
    .mean(["latitude", "longitude"])
)

## add buffer zone to capture mixing
## NOTE(review): +25 assumes mld is in the same units as z_t (m) — confirm
mld_base = mld_over_time + 25

#### compute $T$ tendency

In [None]:
## get sst tendency (and convert from 1/yr to 1/mo)
## is_forward=False selects the non-forward (backward) option of src.utils.get_ddt
anom["ddt_sst"] = 1 / 12 * src.utils.get_ddt(anom[["sst"]], is_forward=False)["ddt_sst"]
## same for the subsurface temperature field
anom["ddt_T"] = 1 / 12 * src.utils.get_ddt(anom[["T"]], is_forward=False)["ddt_T"]

#### Get data

In [None]:
## specify vars to look at: indices, tendencies, fluxes and 3D fields,
## plus the "_comp" companions of the spatial fields
VARNAMES = ["T_3", "T_34", "h_w", "h_w_z20", "ddt_sst", "ddt_T", "nhf", "T", "w", "u"]
VARNAMES += [f"{v}_comp" for v in ["sst", "nhf", "T", "w", "u"]]

## Get windowed data (loaded into memory with .compute())
anom = src.utils.get_windowed(anom[VARNAMES], stride=120).compute()
forced = src.utils.get_windowed(
    forced[["T", "T_comp", "u", "u_comp", "w", "w_comp"]], stride=120
).compute()

Get the mixed-layer average of $T_3$ (Niño-3 subsurface temperature averaged with `get_ml_avg`, `H0=70`)

In [None]:
## mixed-layer (H0=70) average of the Niño-3 subsurface temperature, computed
## through src.utils.reconstruct_wrapper on the T / T_comp pair
anom["T_3_ml"] = src.utils.reconstruct_wrapper(
    anom[["T", "T_comp"]],
    fn=lambda x: src.utils.get_ml_avg(get_nino3_da(x), H0=70, Hm=None),
)["T"]

### Regress subsurface data

helper function

In [None]:
def fit_wrapper(data, y_vars, x_vars=None):
    """
    Fit month-by-month linear regressions of ``y_vars`` onto ``x_vars``.

    Thin wrapper around ``src.utils.regress_xr_bymonth``.

    Args:
        data: dataset holding all ``y_vars`` and ``x_vars``.
        y_vars: names of the predictand variables.
        x_vars: names of the predictors; defaults to ``["T_34", "h_w_hat"]``.
            A fresh list is built on every call, so the callee cannot mutate a
            shared default.

    Returns:
        Regression coefficients as returned by ``src.utils.regress_xr_bymonth``.
    """

    if x_vars is None:
        x_vars = ["T_34", "h_w_hat"]

    ## get coeffs
    return src.utils.regress_xr_bymonth(data, y_vars=y_vars, x_vars=x_vars)

Compute

In [None]:
## do regression
fit_kwargs = dict(
    y_vars=["ddt_T", "ddt_sst", "nhf", "T", "w", "u"],
    # alternative predictor pairs:
    # x_vars=["T_34", "h_w_z20"],
    # x_vars=["T_3_ml", "h_w_z20"],
    x_vars=["T_3", "h_w_z20"],
)

## cache filepath, named after the two predictors.
## - A dedicated name is used so the global SAVE_FP root is not overwritten.
## - Single quotes inside the f-string keep this valid on Python < 3.12.
coefs_fp = pathlib.Path(
    os.environ["SAVE_FP"],
    "bjerknes",
    f"{fit_kwargs['x_vars'][0]}_{fit_kwargs['x_vars'][1]}_coefs_ddt_sst.nc",
)

if coefs_fp.is_file():
    coefs = xr.open_dataset(coefs_fp)

else:

    ## fit month-by-month regressions separately for each window year
    coefs = [fit_wrapper(anom.sel(year=y), **fit_kwargs) for y in tqdm.tqdm(anom.year)]

    ## stack along the window-year dimension
    coefs = xr.concat(coefs, dim=anom.year)

    ## create the output folder if needed, then cache
    coefs_fp.parent.mkdir(parents=True, exist_ok=True)
    coefs.to_netcdf(coefs_fp)

### regress budget data

In [None]:
## cache filepath, named after the two predictors.
## - A dedicated name is used so the global SAVE_FP root is not overwritten.
## - Single quotes inside the f-string keep this valid on Python < 3.12.
budget_fp = pathlib.Path(
    os.environ["SAVE_FP"],
    "bjerknes",
    f"{fit_kwargs['x_vars'][0]}_{fit_kwargs['x_vars'][1]}_budget.nc",
)

if budget_fp.is_file():
    coefs_bud = xr.open_dataset(budget_fp)

else:

    # load budget data (anomaly part)
    budg_anom = src.utils.load_budget_data()[1]

    ## add normalized Th data
    budg_anom = xr.merge([budg_anom, Th / Th.std()])

    ## get windowed data
    budg_anom = src.utils.get_windowed(budg_anom, stride=120)

    ## regress each budget term on the predictors, one window year at a time
    coefs_bud = []
    for y in tqdm.tqdm(budg_anom.year):

        ## select the window once; reused for every budget term
        budg_y = budg_anom.sel(year=y)

        ## dataset holding this year's coefficients (stacked along "j")
        coefs_bud_y = xr.Dataset()
        for n in ["TEND_TEMP", "ADV_3D_TEMP"]:
            coefs_ = src.utils.regress_bymonth(
                budg_y, x_vars=fit_kwargs["x_vars"], y_var=n
            )
            coefs_bud_y[n] = coefs_.to_dataarray(dim="j")

        coefs_bud.append(coefs_bud_y)

    ## stack along the window-year dimension
    coefs_bud = xr.concat(coefs_bud, dim=budg_anom.year)

    ## create the output folder if needed, then cache
    budget_fp.parent.mkdir(parents=True, exist_ok=True)
    coefs_bud.to_netcdf(budget_fp)

## residual (total tendency minus 3D advection) and top-level (z_t=0) tendency
coefs_bud["diff"] = coefs_bud["TEND_TEMP"] - coefs_bud["ADV_3D_TEMP"]
coefs_bud["ddt_sst_v2"] = coefs_bud["TEND_TEMP"].isel(z_t=0)

### merge data

In [None]:
coefs = xr.merge([coefs, coefs_bud])  # field regressions + budget-term regressions

### Integrate over ML

In [None]:
## get climatology of the (windowed) forced fields
bar = src.utils.reconstruct_clim(forced)

## get feedbacks from the mean state (bar) and the regression coefficients (prime)
feedbacks = src.utils.get_feedbacks(bar=bar, prime=coefs)

## merge with other results; coefs_v2 is an independent copy that is integrated
## over the time-varying depth mld_base instead of the fixed H0 below
coefs = xr.merge([coefs, feedbacks])
coefs_v2 = copy.deepcopy(coefs)

## specify mixed layer kwargs (fixed integration depth)
ml_kwargs = dict(H0=70, Hm=None)

## Integrate over mixed layer
coefs = src.utils.get_ml_avg_ds(coefs, **ml_kwargs)
coefs_v2 = src.utils.get_ml_avg_ds(coefs_v2, H0=mld_base, Hm=None)


## get NHF in units of K/mo: Q = nhf * (seconds per 30-day month) / (rho * Cp * H)
## rho, Cp: constant seawater values (presumably kg/m^3 and J/(kg K))
sec_per_mo = 8.64e4 * 30
rho = 1.02e3
Cp = 4.2e3
coefs["Q"] = coefs["nhf"] * sec_per_mo / (rho * Cp * ml_kwargs["H0"])
coefs_v2["Q"] = coefs_v2["nhf"] * sec_per_mo / (rho * Cp * mld_base)

## Plot feedback Hovmöller diagrams
E.g., the thermocline feedback ($\overline{w}~\frac{\partial T'}{\partial z}$) and the Ekman feedback ($w'~\frac{\partial \overline{T}}{\partial z}$).

In [None]:
## select the early and late window years; isel(j=0) keeps the first regression
## coefficient (presumably the one for x_vars[0] — confirm)
## NOTE(review): author's open question — "why is last year 2090?"
m_early = coefs.sel(year=1870).isel(j=0)
m_late = coefs.sel(year=2030).isel(j=0)
# m_early = coefs.sel(year=2030).isel(j=0)
# m_late = coefs.sel(year=2090).isel(j=0)

### Plot mixed layer integral

In [None]:
## For each mixed-layer term, draw three src.utils.make_cycle_hov panels:
## early period, late period, and late-minus-early difference.

## specify plot amplitude (early/late panels) and amplitude of the difference panel
amp = 0.75
amp_diff = 0.375

## specify latitude bound for spatial plots
lat_bound = 1.5

for n in [
    "THF_ml",
    "EKM_ml",
    "ZAF_ml",
    "DD_ml",
    "Q",
    "diff_ml",
    "ADV_3D_TEMP_ml",
    "ADV_ml",
    "TEND_TEMP_ml",
    "ddt_sst",
]:
    #     "TEND_TEMP_ml",]:
    # for n in [
    #     "ADV_ml",
    #     "ADV_3D_TEMP_ml",
    #     "ddt_T_ml",
    #     "TEND_TEMP_ml",
    #     "ddt_sst",
    #     "ddt_sst_v2",
    #     "diff_ml",
    # ]:

    print(f"\n{n}")
    fig, axs = plt.subplots(1, 3, figsize=(7, 2.5), layout="constrained")

    ## plot data (x ticks at the Niño-3 longitudes 210/270°E)
    kwargs = dict(lat_bound=lat_bound, xticks=[210, 270])
    cp0 = src.utils.make_cycle_hov(axs[0], data=m_early[n], amp=amp, **kwargs)
    cp1 = src.utils.make_cycle_hov(axs[1], data=m_late[n], amp=amp, **kwargs)
    cp2 = src.utils.make_cycle_hov(
        axs[2], data=(m_late - m_early)[n], amp=amp_diff, **kwargs
    )

    ## make it look nicer: one colorbar for the early/late panels (same amp),
    ## one for the difference panel
    cb01 = fig.colorbar(
        cp0,
        ax=axs[1],
        ticks=[-amp, 0, amp],
    )
    cb02 = fig.colorbar(
        cp2,
        ax=axs[2],
        ticks=[-amp_diff, 0, amp_diff],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )
    src.utils.format_hov_axs(axs, ticks=None)
    ## guide lines at y=3 and y=5 (presumably months — confirm) and x limit 280°E
    for ax in axs:
        ax.axhline(5, ls="--", c="k", lw=1)
        ax.axhline(3, ls="--", c="k", lw=1)
        ax.set_xlim([None, 280])

    plt.show()

In [None]:
## March value (month slice 3..3) vs. longitude: first line early, second line late;
## dashed gray lines mark the Niño-3 longitude bounds (210°E, 270°E)
for v in ["ADV_3D_TEMP_ml", "ZAF_ml", "THF_ml"]:
    print(f"\n\n{v}")

    fig, ax = plt.subplots(figsize=(3, 2.5))
    for m in [m_early, m_late]:
        ax.plot(m.longitude, m[v].sel(month=slice(3, 3)).mean("month"))

    for t in [210, 270]:
        ax.axvline(t, ls="--", c="gray", lw=0.8)
    plt.show()

### Decompose changes

In [None]:
## decompose late-minus-early feedback changes; the result provides the
## delta_*_mean / delta_*_anom / delta_*_nl terms plotted below
delta = src.utils.decompose_feedback_changes(
    bar_early=bar.sel(year=m_early.year),
    bar_late=bar.sel(year=m_late.year),
    prime_early=m_early,
    prime_late=m_late,
)

## mixed layer avg (same fixed depth as for coefs)
delta = src.utils.get_ml_avg_ds(delta, **ml_kwargs)

#### Thermocline

In [None]:
## Thermocline-feedback change and its three decomposition terms:
## mean-state change, anomaly change, and nonlinear (product of changes).

## specify colorbar amp
amp = 0.25

## specify plot data and titles
plot_data = [
    (m_late - m_early)["THF_ml"],
    delta["delta_THF_mean_ml"],
    delta["delta_THF_anom_ml"],
    delta["delta_THF_nl_ml"],
]
## the mean-state-change term multiplies the *baseline anomaly* gradient,
## hence T_0' (with prime), matching the u_0' notation in the ZAF titles
titles = [
    r"$\Delta\left(\overline{w}~\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\frac{\partial T_0'}{\partial z}$",
    r"$\overline{w}_0~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs, ticks=None)

## plot data
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp, xticks=[210, 270])
    ax.set_title(t)
    ax.axhline(2, ls="--", c="k", lw=1)
    ax.axhline(7, ls="--", c="k", lw=1)

## NOTE(review): only the 4th panel loses its y ticks; maybe axs[1:] was intended
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

#### ZAF

In [None]:
## Zonal-advective-feedback change and its three decomposition terms:
## mean-state change, anomaly change, and nonlinear (product of changes).

## specify colorbar amp
amp = 0.25

## specify plot data and titles
plot_data = [
    (m_late - m_early)["ZAF_ml"],
    delta["delta_ZAF_mean_ml"],
    delta["delta_ZAF_anom_ml"],
    delta["delta_ZAF_nl_ml"],
]
titles = [
    r"$\Delta\left(u'~\frac{\partial \overline{T}}{\partial x}\right)$",
    r"$u_0'~\Delta\left(\frac{\partial \overline{T}}{\partial x}\right)$",
    r"$\Delta\left(u'\right)~\frac{\partial \overline{T}_0}{\partial x}$",
    r"$\Delta\left(u'\right)~\Delta\left(\frac{\partial \overline{T}}{\partial x}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

src.utils.format_hov_axs(axs)

## plot data
for ax, p, t in zip(axs, plot_data, titles):
    cp = src.utils.make_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## NOTE(review): only the 4th panel loses its y ticks; maybe axs[1:] was intended
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Plot subsurface changes

#### Zonal advective feedback (ZAF); $u$ and $T$ options are commented out in the cell

In [None]:
# Longitude-depth sections (Mar-Apr mean) of a regression field: early, late, and
# late-minus-early (half amplitude). Alternative variables are commented out.

# variable to plot
# VARNAME = "u"
# AMP = 7e5

VARNAME = "ZAF"
AMP = 1.5

# VARNAME = "T"
# # AMP = 3
# AMP=.05

## func to select data (mean over months 3-4)
# sel = lambda x : x[VARNAME].sel(month=slice(3,3)).mean("month").differentiate("z_t")
sel = lambda x: x[VARNAME].sel(month=slice(3, 4)).mean("month")

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")


for ax, m in zip(axs[:2], [m_early, m_late]):

    ## early / late field
    cp = ax.contourf(
        m.longitude,
        m.z_t,
        sel(m),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(AMP, AMP / 10),
        extend="both",
    )

## difference (levels at half the panel amplitude)
axs[2].contourf(
    m_early.longitude,
    m_early.z_t,
    sel(m_late - m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(AMP / 2, AMP / 20),
    extend="both",
)

## plot the 20-degree isotherm of the mean state: black = early, white = late
for y, c, l in zip([m_early.year, m_late.year], ["k", "w"], [20, 20]):
    for ax in axs:
        ax.contour(
            bar.longitude,
            bar.z_t,
            bar["T"].sel(year=y).sel(month=slice(3, 4)).mean("month"),
            colors=c,
            levels=[l],
        )

## colorbar: scale of the early/late panels (cp), although drawn next to the
## difference panel, whose levels are half as large
cb = fig.colorbar(cp, ax=axs[2], ticks=[-AMP, 0, AMP])
src.utils.format_subsurf_axs(axs)

## limits (depth axis inverted: 5 at top, 200 at bottom) and Niño-3 bounds (210, 270°E)
for ax in axs:
    ax.set_xlim([145, 280])
    ax.set_ylim([200, 5])
    ax.set_xticks([210, 270])
    for t in ax.get_xticks():
        ax.axvline(t, ls="--", c="w", lw=0.8)

plt.show()

#### Mean state changes

Zonal velocity

In [None]:
# Mean-state zonal velocity sections (Mar-Apr mean): early, late, late-minus-early.
VARNAME = "u"
AMP = 2.5e6

## func to select data (mean over months 3-4)
sel = lambda x: x[VARNAME].sel(month=slice(3, 4)).mean("month")

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")


for ax, y in zip(axs[:2], [m_early.year, m_late.year]):

    ## mean zonal velocity for each period
    cp = ax.contourf(
        bar.longitude,
        bar.z_t,
        sel(bar.sel(year=y)),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(AMP, AMP / 10),
        extend="both",
    )

## difference (half amplitude)
## NOTE(review): reversed colormap here but not on the early/late panels —
## confirm the sign flip is intended
axs[2].contourf(
    m_early.longitude,
    m_early.z_t,
    sel(bar.sel(year=m_late.year) - bar.sel(year=m_early.year)),
    cmap="cmo.balance_r",
    levels=src.utils.make_cb_range(AMP / 2, AMP / 20),
    extend="both",
)

## mean-state isotherms: 20 (black, early) and 25 (white, late)
for y, c, l in zip([m_early.year, m_late.year], ["k", "w"], [20, 25]):
    for ax in axs:
        ax.contour(
            bar.longitude,
            bar.z_t,
            bar["T"].sel(year=y).sel(month=slice(3, 4)).mean("month"),
            colors=c,
            levels=[l],
        )

## colorbar: scale of the early/late panels (cp), drawn next to the difference panel
cb = fig.colorbar(cp, ax=axs[2], ticks=[-AMP, 0, AMP])
src.utils.format_subsurf_axs(axs)

## limits (depth axis inverted) and Niño-3 bounds (210, 270°E)
for ax in axs:
    ax.set_xlim([145, 280])
    ax.set_ylim([200, 5])
    ax.set_xticks([210, 270])
    for t in ax.get_xticks():
        ax.axvline(t, ls="--", c="w", lw=0.8)

plt.show()

Stratification

In [None]:
# Mean-state stratification dT/dz (Mar-Apr mean): early, late, late-minus-early.
VARNAME = "T"

## contour levels / colorbar ticks for the early and late panels; the ticks
## must lie inside the level range (the previous version reused AMP=2.5e6 from
## the zonal-velocity cell, so no ticks were drawn)
STRAT_LEVELS = np.arange(-0.2, 0, 0.02)
STRAT_TICKS = [-0.2, -0.1]

## func to select data (vertical derivative of the months 3-4 mean)
sel = lambda x: x[VARNAME].sel(month=slice(3, 4)).mean("month").differentiate("z_t")

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")


for ax, y in zip(axs[:2], [m_early.year, m_late.year]):

    ## mean stratification for each period
    cp = ax.contourf(
        bar.longitude,
        bar.z_t,
        sel(bar.sel(year=y)),
        cmap="cmo.amp_r",
        levels=STRAT_LEVELS,
        extend="both",
    )

## difference
axs[2].contourf(
    m_early.longitude,
    m_early.z_t,
    sel(bar.sel(year=m_late.year) - bar.sel(year=m_early.year)),
    cmap="cmo.balance_r",
    levels=src.utils.make_cb_range(0.1, 0.01),
    extend="both",
)

## mean-state isotherms: 20 (black, early) and 25 (white, late)
for y, c, l in zip([m_early.year, m_late.year], ["k", "w"], [20, 25]):
    for ax in axs:
        ax.contour(
            bar.longitude,
            bar.z_t,
            bar["T"].sel(year=y).sel(month=slice(3, 4)).mean("month"),
            colors=c,
            levels=[l],
        )

## colorbar for the early/late panels, drawn next to the difference panel
cb = fig.colorbar(cp, ax=axs[2], ticks=STRAT_TICKS)
src.utils.format_subsurf_axs(axs)

## limits (depth axis inverted) and Niño-3 bounds (210, 270°E)
for ax in axs:
    ax.set_xlim([145, 280])
    ax.set_ylim([200, 5])
    ax.set_xticks([210, 270])
    for t in ax.get_xticks():
        ax.axvline(t, ls="--", c="w", lw=0.8)

plt.show()

## Change over time

In [None]:
def plot_hovs(varname, amp0, amp1, month=1, annual_mean=False, delta=False, data=None):
    """
    Plot how a regression-based field evolves across window years (two panels).

    The two panels are:
    - Left: longitude vs. year, averaged over ``latitude=slice(-5, 5)``, for one
      month (or the annual mean).
    - Right: calendar month vs. year for the Niño-3 average
      (``src.utils.get_nino3``).

    Args:
        varname: variable to plot; only coefficient ``j=0`` is used.
        amp0, amp1: colour-scale amplitude of the left / right panel.
        month: month shown in the left panel when ``annual_mean`` is False.
        annual_mean: if True, the left panel shows the mean over "month".
        delta: if True, plot changes relative to the first year.
        data: dataset containing ``varname``; defaults to the notebook-global
            ``coefs`` (the previous, implicit behavior).

    Returns:
        ``(fig, axs)`` of the 1x2 figure.
    """

    ## fall back to the notebook-global regression results
    if data is None:
        data = coefs

    ## get data for plot
    plot_data = data[varname].isel(j=0)

    ## make sure latitude is a dimension
    if "latitude" not in plot_data.dims:
        plot_data = plot_data.expand_dims("latitude")

    ## left panel: one month (or annual mean), averaged over latitude
    if annual_mean:
        x0 = plot_data.mean("month")
    else:
        x0 = plot_data.sel(month=month)
    x0 = x0.sel(latitude=slice(-5, 5)).mean("latitude").transpose("year", ...)

    ## right panel: Niño-3 average for every month
    x1 = src.utils.get_nino3(plot_data).transpose("year", ...)

    ## get delta if necessary
    if delta:
        x0 = x0 - x0.isel(year=0)
        x1 = x1 - x1.isel(year=0)

    fig, axs = plt.subplots(1, 2, figsize=(6, 3.5), layout="constrained")

    ## longitude on x-axis
    axs[0].contourf(
        x0.longitude,
        x0.year,
        x0,
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(amp0, amp0 / 10),
        extend="both",
    )

    ## month on x-axis
    axs[1].contourf(
        plot_data.month,
        plot_data.year,
        x1,
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(amp1, amp1 / 10),
        extend="both",
    )

    ## Niño-3 longitude bounds and fixed ticks
    axs[0].set_xlim([140, 280])
    axs[0].axvline(210, ls="--", c="k", lw=0.8)
    axs[0].axvline(270, ls="--", c="k", lw=0.8)
    axs[0].set_xticks([160, 210, 270])
    axs[0].set_yticks([1870, 1980, 2090])
    axs[1].set_xticks([2, 7, 12])
    axs[1].set_yticks([])

    return fig, axs

In [None]:
# Hovmöllers (via plot_hovs) of each feedback term in January, with a dashed
# reference line at 2030; earlier variable lists are kept as comments.
# for VARNAME in ["ADV_3D_TEMP_ml", "ddt_sst", "ZAF_ml", "DD_ml", "Q"]:
# for VARNAME in ["ADV_3D_TEMP_ml", "TEND_TEMP_ml", "ddt_sst", "ddt_sst_v2"]:
for VARNAME in ["THF_ml", "EKM_ml", "ZAF_ml", "DD_ml", "diff_ml"]:

    print(f"\n\n{VARNAME}")
    fig, axs = plot_hovs(
        varname=VARNAME,
        amp0=1,
        amp1=1,
        annual_mean=False,
        delta=False,
        month=1,
    )

    for ax in axs:
        ax.axhline(2030, ls="--", c="k", lw=0.8)

    plt.show()

In [None]:
def get_nino3_da_eq(x, lat_bound=1.5):
    """
    Average a DataArray over a narrow equatorial Niño-3 box.

    The box spans longitudes 210–270°E and latitudes ``-lat_bound``..``lat_bound``
    (default 1.5°). If ``x`` has no "latitude" dimension, only the longitude
    average is taken.

    Args:
        x: DataArray with a "longitude" coordinate and, optionally, "latitude".
        lat_bound: half-width of the latitude band, in degrees.

    Returns:
        ``x`` averaged over the box; other dimensions are kept.
    """

    lon_box = slice(210, 270)

    ## A coordinate-less "latitude" dim (from expand_dims) would be indexed
    ## positionally by .sel, where float slice bounds are rejected; averaging a
    ## length-1 dim is a no-op anyway, so skip the latitude selection instead.
    if "latitude" not in x.dims:
        return x.sel(longitude=lon_box).mean("longitude")

    return x.sel(longitude=lon_box, latitude=slice(-lat_bound, lat_bound)).mean(
        ["latitude", "longitude"]
    )

In [None]:
## func to select data: change relative to the first year, coefficient j=0, annual mean
sel = lambda x: (x - x.isel(year=0)).isel(j=0).mean("month")
# sel = lambda x: (x - x.isel(year=0)).isel(j=0).sel(month=slice(7,None)).mean("month")

fig, axs = plt.subplots(1, 2, figsize=(7, 3))


## positive feedbacks
# for f in ["EKM", "THF", "ZAF", "ADV", "ddt_T"]:
# for f in ["ADV", "ddt_T"]:
# for f in ["ADV_3D_TEMP_ml", "TEND_TEMP_ml", "diff_ml", "ddt_sst"]:
# for f in ["ddt_sst", "ddt_sst_v2"]:
# axs[0].plot(coefs.year, sel(get_nino3_da_eq(coefs[f"{f}"])), label=f)

## SST tendency: narrow-latitude Niño-3 ("1.5") vs. src.utils Niño-3 ("5")
axs[0].plot(coefs.year, sel(get_nino3_da_eq(coefs["ddt_sst"])), label="1.5")
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ddt_sst"])), label="5")
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ddt_sst_v2"])), label="ddt_sst_v2")

## negative feedbacks
# for f in ["DD_ml", "Q"]:
#     axs[1].plot(coefs.year, sel(get_nino3_da(coefs[f"{f}"])), label=f)

## surface flux term, same two averaging regions
axs[1].plot(coefs.year, sel(get_nino3_da_eq(coefs["Q"])), label="Q_1.5")
axs[1].plot(coefs.year, sel(get_nino3_da(coefs["Q"])), label="Q_5")

## legends
axs[0].legend(prop=dict(size=8))
axs[1].legend(prop=dict(size=8))

for ax in axs:
    ax.axhline(0, ls="--", c="k", lw=0.8)
    ax.axvline(2030, ls="--", c="k", lw=0.8)

src.utils.set_lims(axs)


plt.show()

In [None]:
## baseline: value in the first window year
get_bl = lambda x: x.isel(year=slice(None, 1)).mean("year")

## func to select data: coefficient j=0, annual mean, change from baseline
sel_helper = lambda x: x.isel(j=0).mean("month")
sel = lambda x: sel_helper(x - get_bl(x))

fig, axs = plt.subplots(1, 2, figsize=(7, 3))

## plot temp tendency: budget tendency vs. advection + 1.5 x narrow-band Q
sum_ = sel(get_nino3_da(coefs["ADV_3D_TEMP_ml"])) + 1.5 * sel(
    get_nino3_da_eq(coefs["Q"])
)
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["TEND_TEMP_ml"])), label="ddt_T")
axs[0].plot(coefs.year, sum_, label="sum")
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ddt_sst"])), label="ddt_sst")

## plot residual term (sign-flipped) against the scaled surface flux and advection
axs[1].plot(coefs.year, -sel(get_nino3_da(coefs["diff_ml"])), label="diff")
axs[1].plot(coefs.year, -1.5 * sel(get_nino3_da_eq(coefs["Q"])), label="Q_1.5 (scaled)")
axs[1].plot(coefs.year, sel(get_nino3_da(coefs["ADV_3D_TEMP_ml"])), label="ADV")

## legends
for ax in axs:
    ax.legend(prop=dict(size=8))

for ax in axs:
    ax.axhline(0, ls="--", c="k", lw=0.8)
    ax.axvline(2030, ls="--", c="k", lw=0.8)

# src.utils.set_lims(axs)


plt.show()

In [None]:
## baseline: value in the first window year
get_bl = lambda x: x.isel(year=slice(None, 1)).mean("year")

## func to select data: coefficient j=0, annual mean, change from baseline
sel_helper = lambda x: x.isel(j=0).mean("month")
sel = lambda x: sel_helper(x - get_bl(x))

fig, axs = plt.subplots(1, 2, figsize=(7, 3))

## compare model advection with the regression-based estimate
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ADV_3D_TEMP_ml"])), label="adv (true)")
axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ADV_ml"])), label="adv (est)")

## sum of thermocline and DD terms
sum_ = sel(get_nino3_da(coefs["THF_ml"])) + sel(get_nino3_da(coefs["DD_ml"]))
# sum_ = sel(get_nino3_da(coefs["EKM_ml"])) + sel(get_nino3_da(coefs["ZAF_ml"]))
axs[0].plot(coefs.year, sum_, label="sum")

# axs[0].plot(coefs.year, sel(get_nino3_da(coefs["THF_ml"])), label="THF")
# axs[0].plot(coefs.year, sel(get_nino3_da(coefs["EKM_ml"])), label="EKM")
# axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ZAF_ml"])), label="ZAF")
# axs[0].plot(coefs.year, sel(get_nino3_da(coefs["DD_ml"])), label="DD")
# axs[0].plot(coefs.year, -2*sel(get_nino3_da(coefs["Q"])), label="Q")
# axs[0].plot(coefs.year, sel(get_nino3_da(coefs["ddt_sst"])), label="ddt_sst")

## legends
for ax in axs:
    ax.legend(prop=dict(size=8))

for ax in axs:
    ax.axhline(0, ls="--", c="k", lw=0.8)
    ax.axvline(2030, ls="--", c="k", lw=0.8)

# src.utils.set_lims(axs)


plt.show()

### spatial plots

#### $\frac{d}{dt}\,\mathrm{SST}$

In [None]:
## specify plot years (note: Y1=1980 here, whereas other cells use 2030)
Y0 = 1870
Y1 = 1980
# Y0 = 2030
# Y1 = 2090


## specify which month to look at (month 2 of coefficient j=0)
# plot_data = coefs["ddt_sst"].isel(j=0).mean("month")
plot_data = coefs["ddt_sst"].isel(j=0).sel(month=2)

## three stacked map panels (up to 15° latitude)
fig = plt.figure(figsize=(7.5, 3), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=15)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

kwargs = dict(
    cmap="cmo.balance",
    # levels=src.utils.make_cb_range(3e-1, 3e-2),
    levels=src.utils.make_cb_range(6e-1, 6e-2),
    extend="both",
    transform=ccrs.PlateCarree(),
)

## get lon/lat data
LON = plot_data.longitude
LAT = plot_data.latitude

## plot early/late
axs[0, 0].contourf(LON, LAT, plot_data.sel(year=Y0), **kwargs)
axs[1, 0].contourf(LON, LAT, plot_data.sel(year=Y1), **kwargs)

## plot diff, multiplied by 4 so it is visible on the same levels as the panels
diff = plot_data.sel(year=Y1) - plot_data.sel(year=Y0)
axs[2, 0].contourf(LON, LAT, 4 * diff, **kwargs)


## plot guidelines (Niño-3 box outline)
box_kwargs = dict(ls="--", c="k", lw=0.8)
for ax in axs.flatten():
    src.utils.plot_nino3_box(ax, **box_kwargs)

plt.show()

In [None]:
## coefficient j=0, annual mean
sel = lambda x: x.isel(j=0).mean("month")

fig, ax = plt.subplots(figsize=(3, 2.5))

## solid: Niño-3 of the single-month field left in plot_data by the previous cell
ax.plot(plot_data.year, src.utils.get_nino3(plot_data))
## dashed: src.utils Niño-3 of the annual-mean ddt_sst
ax.plot(coefs.year, get_nino3_da(sel(coefs["ddt_sst"])), ls="--")

## solid: narrow-latitude Niño-3 (get_nino3_da_eq) of the annual-mean ddt_sst
ax.plot(coefs.year, get_nino3_da_eq(sel(coefs["ddt_sst"])), ls="-")
# ax.plot(coefs.year, get_nino3_da(sel(coefs["ddt_sst_v2"])), ls="--")

## NOTE(review): reference line at 2035, whereas other cells mark 2030 — confirm
ax.axvline(2035)

plt.show()

#### Q

In [None]:
## specify plot years
Y0 = 1870
Y1 = 2030
# Y0 = 2030
# Y1 = 2090


## specify which month to look at (month 3 of coefficient j=0)
# plot_data = coefs["Q"].isel(j=0).mean("month")
plot_data = coefs["Q"].isel(j=0).sel(month=3)

## three stacked map panels (up to 15° latitude)
fig = plt.figure(figsize=(7.5, 3), layout="constrained")
format_func = lambda ax,: src.utils.plot_setup_pac(ax, max_lat=15)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(3e-1, 3e-2),
    extend="both",
    transform=ccrs.PlateCarree(),
)

## get lon/lat data
LON = plot_data.longitude
LAT = plot_data.latitude

## plot early/late
axs[0, 0].contourf(LON, LAT, plot_data.sel(year=Y0), **kwargs)
axs[1, 0].contourf(LON, LAT, plot_data.sel(year=Y1), **kwargs)

## plot diff, multiplied by 2 so it is visible on the same levels as the panels
diff = plot_data.sel(year=Y1) - plot_data.sel(year=Y0)
axs[2, 0].contourf(LON, LAT, 2 * diff, **kwargs)


## plot guidelines (Niño-3 box outline)
box_kwargs = dict(ls="--", c="k", lw=0.8)
for ax in axs.flatten():
    src.utils.plot_nino3_box(ax, **box_kwargs)

plt.show()

In [None]:
## specify plot years
Y0 = 1870
Y1 = 2030
# Y0 = 2030
# Y1 = 2090

## averaging helpers: 5°S–5°N band, Niño-3 longitudes (210–270°E), and
## coefficient j=0 annual mean
sel_lat = lambda x: x.sel(latitude=slice(-5, 5)).mean("latitude")
sel_lon = lambda x: x.sel(longitude=slice(210, 270)).mean("longitude")
sel = lambda x: x.isel(j=0).mean("month")

## late-minus-early change as a function of longitude (diff_lat) and of latitude (diff_lon)
sel1 = lambda x: sel(sel_lat(x))
sel2 = lambda x: sel(sel_lon(x))
diff_lat = sel1(coefs[["ddt_sst", "Q"]].sel(year=Y1)) - sel1(
    coefs[["ddt_sst", "Q"]].sel(year=Y0)
)
diff_lon = sel2(coefs[["ddt_sst", "Q"]].sel(year=Y1)) - sel2(
    coefs[["ddt_sst", "Q"]].sel(year=Y0)
)


## get diff ml expanded (diff_ml gets a length-1 latitude dim so sel1 applies)
Z = coefs["diff_ml"].expand_dims("latitude")
diff_lat_diff = sel1(Z).sel(year=Y1) - sel1(Z).sel(year=Y0)

In [None]:
## change in ddt_sst and Q vs. longitude (5°S–5°N mean); Niño-3 bounds marked
fig, ax = plt.subplots(figsize=(4, 3))

ax.plot(diff_lat.longitude, diff_lat["ddt_sst"])
ax.plot(diff_lat.longitude, diff_lat["Q"])

ax.axhline(**box_kwargs)
for t in [210, 270]:
    ax.axvline(t, **box_kwargs)

plt.show()

## change in ddt_sst and Q vs. latitude (210–270°E mean); ±5° marked
fig, ax = plt.subplots(figsize=(4, 3))

ax.plot(diff_lon["ddt_sst"], diff_lon.latitude)
ax.plot(diff_lon["Q"], diff_lon.latitude)
ax.set_ylim([-15, 15])

ax.axvline(0, **box_kwargs)
for t in [-5, 5]:
    ax.axhline(t, **box_kwargs)

plt.show()

#### How correlated is $Q$ with the residual?

Load data

In [None]:
## budget data (anomaly part) and its residual: total tendency minus 3D advection
budg_anom = src.utils.load_budget_data()[1]
diff = budg_anom["TEND_TEMP"] - budg_anom["ADV_3D_TEMP"]

## NHF (anomaly part of the consolidated data)
nhf = src.utils.load_consolidated()[1][["nhf", "nhf_comp"]]

Aggregate

In [None]:
## integrate difference over ml (fixed depth H0=70)
diff_ml = src.utils.get_ml_avg(diff, H0=70, Hm=None)

## funcs to integrate in space: Niño-3 longitudes and 2°S–2°N
get_lon = lambda x: x.sel(longitude=slice(210, 270)).mean("longitude")
get_lat = lambda x: x.sel(latitude=slice(-2, 2)).mean("latitude")

## do integration (NHF via reconstruct_wrapper; residual averaged in longitude only)
nhf_ = src.utils.reconstruct_wrapper(nhf, fn=lambda x: get_lat(get_lon(x)))["nhf"]
diff_ = get_lon(diff_ml)

Plot

In [None]:
## optional subsetting (identity by default; alternatives commented out)
# sel = lambda x : src.utils.sel_month(x,5)
# sel = lambda x : x.sel(time=slice(None,"1890"))
sel = lambda x: x

## Pearson correlation between surface flux and budget residual
print(f"Corr: {xr.corr(sel(nhf_), sel(diff_)).values.item():.2f}")

fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(sel(diff_), sel(nhf_), s=3, alpha=0.5)
ax.set_xlabel("Residual")
ax.set_ylabel("Surface energy flux")

plt.show()