# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## plotting style: white axes background and no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## higher figure resolution
mpl.rcParams["figure.dpi"] = 100

## data and output directories are taken from environment variables
env_path = lambda name: pathlib.Path(os.environ[name])
DATA_FP = env_path("DATA_FP")
SAVE_FP = env_path("SAVE_FP")

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Temperature jump across the base of the mixed layer.

    The mixed layer (ML) is every level with ``z_t <= Hm``; the entrainment
    zone (EZ) is the slab ``Hm < z_t < Hm + delta`` just below it. Following
    Frankignoul et al., the result is the EZ-mean temperature minus the
    ML-mean temperature, so it is positive when the EZ is warmer than the ML.

    Args:
        Tsub: temperature with a ``z_t`` coordinate (presumably an xarray
            DataArray — it is used with ``.where`` / ``.mean("z_t")``).
        Hm: mixed-layer depth, in the same units as ``z_t``.
        delta: thickness of the entrainment zone, same units as ``z_t``.

    Returns:
        ``mean(T over EZ) - mean(T over ML)``, reduced over ``z_t``.
    """
    depth = Tsub.z_t
    ml_mask = depth <= Hm
    ez_mask = (depth > Hm) & (depth < (delta + Hm))

    ## layer means: levels outside each mask become NaN before averaging
    T_ml = Tsub.where(ml_mask).mean("z_t")
    T_ez = Tsub.where(ez_mask).mean("z_t")

    return T_ez - T_ml


def prep(data, h_vars=("h_w", "h", "h_w_z20"), T_var="T_34"):
    """
    Add SST-independent versions of thermocline-depth indices to ``data``.

    For every name ``h`` in ``h_vars`` this stores
    ``src.utils.remove_sst_dependence_v2(data, h_var=h, T_var=T_var)`` as
    ``data[f"{h}_hat"]``. No tendencies are computed here.

    Args:
        data: dataset containing the variables named in ``h_vars`` and ``T_var``.
        h_vars: names of the h indices to process (defaults to the original
            hard-coded set, in the original order).
        T_var: name of the temperature index passed to
            ``remove_sst_dependence_v2``.

    Returns:
        ``data`` itself — note it is modified in place (new ``*_hat`` variables).
    """
    for h_var in h_vars:
        data[f"{h_var}_hat"] = src.utils.remove_sst_dependence_v2(
            data, h_var=h_var, T_var=T_var
        )

    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load the ground-truth CESM heat-budget terms for two time periods.

    Opens the ``adv`` and ``ddt_T`` folders under ``DATA_FP/cesm`` (files are
    stacked along a new "member" dimension), adds ``diff = TEND_TEMP -
    ADV_3D_TEMP``, converts units, keeps only the two periods, interpolates
    onto ``target_grid`` and splits the result with
    ``src.utils.separate_forced``.

    Args:
        t_early: ``.sel`` indexer for the early period,
            e.g. ``dict(time=slice("1851", "1880"))``.
        t_late: ``.sel`` indexer for the late period.
        target_grid: object passed to ``interp_like``.

    Returns:
        ``(forced_bud, anom_bud)`` as returned by ``src.utils.separate_forced``.

    Raises:
        FileNotFoundError: if a variable folder contains no ``*.nc`` files.
    """

    def load_var(varname):
        """Open all ``*.nc`` files in ``DATA_FP/cesm/<varname>_temp``, stacked along "member"."""
        folder = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")
        files = sorted(folder.glob("*.nc"))
        if not files:
            raise FileNotFoundError(f"no *.nc files found in {folder}")

        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## number members 0..N-1 from the files actually opened, rather than
        ## assuming a fixed ensemble size
        n_members = data.sizes["member"]
        return data.assign_coords({"member": np.arange(n_members)})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## residual: TEND_TEMP minus ADV_3D_TEMP (presumably total tendency minus
    ## 3-D advection)
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## replace the "nlon" index with the actual longitude values. Done before
    ## the unit scaling so "lon" can never be multiplied with the budget terms.
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## unit conversion: z_t from cm to m, rates from K/s to K/month (30-day month)
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## keep only the early and late periods, stacked along time
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate onto the target grid
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, month=None):
    """
    Draw the mixed-layer depth ("mld") against longitude on three axes.

    ``axs[0]`` gets the early-period curve (solid black), ``axs[1]`` the
    late-period curve (dashed black) and ``axs[2]`` both, for comparison.
    Each field is first reduced with ``src.utils.merimean`` (presumably a
    meridional mean).

    Args:
        axs: sequence of at least three matplotlib axes.
        bar_early, bar_late: datasets with an "mld" variable that has a
            "month" dimension.
        month: if None, average over "month"; otherwise select that month.
    """

    def reduce_month(da):
        return da.mean("month") if month is None else da.sel(month=month)

    lon = src.utils.merimean(bar_early).longitude
    mld_early = reduce_month(src.utils.merimean(bar_early["mld"]))
    mld_late = reduce_month(src.utils.merimean(bar_late["mld"]))

    ## early period: solid line on the first and comparison panels
    for ax in (axs[0], axs[2]):
        ax.plot(lon, mld_early, c="k")

    ## late period: dashed line on the second and comparison panels
    for ax in (axs[1], axs[2]):
        ax.plot(lon, mld_late, c="k", ls="--")

## Load data

#### $T$, $h$

In [None]:
## open the T/h index time series (the T_3, T_34, T_4, h, h_w, h_w_z20 and
## h_z20 entries are read from `Th` in the next cell)
Th = src.utils.load_cesm_indices(load_z20=True)

### Spatial data

#### Load

In [None]:
## load spatial data: the (forced, anomaly) pair from load_consolidated
forced, anom = src.utils.load_consolidated()

## attach each T/h index, divided by its own standard deviation, to `anom`
for idx_name in ("T_3", "T_34", "T_4", "h", "h_w", "h_w_z20", "h_z20"):
    series = Th[idx_name]
    anom[idx_name] = series / series.std()

#### compute some indices

In [None]:
## SST tendency: get_ddt gives a per-year rate, so scale by 1/12 for per-month
ddt = src.utils.get_ddt(anom[["sst"]], is_forward=True)
anom["ddt_sst"] = 1 / 12 * ddt["ddt_sst"]

#### early/late split

In [None]:
## early and late 30-year windows (an alternative late window is 1981-2010)
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## subset the anomalies to each window, add the "_hat" h indices via prep(),
## and load the result into memory
anom_early, anom_late = (
    prep(anom.sel(window)).compute() for window in (t_early, t_late)
)

## Compute/plot

In [None]:
def get_slope_bymonth(x, **kwargs):
    """
    Apply ``src.utils.regress_proj`` to each calendar month of ``x`` separately.

    ``x`` is grouped by ``time.month``; keyword arguments are forwarded to
    ``regress_proj`` for every group. Later cells index the result with
    ``.sel(month=...)``.
    """
    monthly_groups = x.groupby("time.month")
    return monthly_groups.map(src.utils.regress_proj, **kwargs)

### SST - SST

In [None]:
## month-by-month slope of SST (y) against the Niño 3 index of SST (x)
kwargs = dict(x_var="sst", y_var="sst", fn_x=src.utils.get_nino3)
m_early, m_late = (get_slope_bymonth(ds, **kwargs) for ds in (anom_early, anom_late))

Hovmöller

In [None]:
## shared args
kwargs = dict(amp=1.5, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2 x (late - early). All three use the same
## amplitude, so the shared colorbar overstates the difference panel by 2x.
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early, **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late, **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early), **kwargs)

## colorbar: SST is regressed on (Niño 3) SST, so the slope is in K per K
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~K^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference line at y = 7
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

Spatial

In [None]:
## select month 6 of the monthly regression slopes
sel = lambda x: x.sel(month=6)

## set up plot
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late (shared amplitude) and late - early (own amplitude)
amp, diff_amp = 2, 1.5
contour_kwargs = dict(amp=amp, sel=sel)
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early, **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late, **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], (m_late - m_early), **dict(contour_kwargs, amp=diff_amp)
)

## colorbars: ticks are derived from the contour amplitudes so they stay in
## sync; SST regressed on SST gives K per K
lab = r"$K~K^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-amp, 0, amp], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-diff_amp, 0, diff_amp], label=lab)

## Niño 3.4 box
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

Scatter

In [None]:
## kwargs for plotting (passed through to make_scatter2)
kwargs = dict(months=5, scale=1, x_var="T_3", y_var="sst", fn_y=src.utils.get_nino34)

## set up plot: early period (left) vs. late period (right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0 = src.utils.make_scatter2(axs[0], anom_early, **kwargs)
m1 = src.utils.make_scatter2(axs[1], anom_late, **kwargs)

## label with the value returned by make_scatter2. SST against a temperature
## index is K per K (heat-flux units do not apply to this regression).
## NOTE(review): anom["T_3"] was divided by its standard deviation when it was
## added to `anom`; if make_scatter2 regresses on it directly, the slope is
## K per std of T_3 rather than K per K — confirm.
unit = r" $K~K^{-1}$"
axs[0].set_title(f"{m0:.1f}" + unit)
axs[1].set_title(f"{m1:.1f}" + unit)
axs[1].set_yticks([])

src.utils.set_lims(axs)

### SST - NHF

In [None]:
## month-by-month slope of net heat flux (y) against Niño 3.4 SST (x)
kwargs = dict(x_var="sst", y_var="nhf", fn_x=src.utils.get_nino34)
m_early, m_late = [get_slope_bymonth(ds, **kwargs) for ds in (anom_early, anom_late)]

In [None]:
## seasonal-cycle Hovmöller of the NHF slope; colour range is ±amp
kwargs = dict(amp=40, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early | late | 2 x (late - early), all with the same amplitude
panel_data = (m_early, m_late, 2 * (m_late - m_early))
cp0, cp1, cp2 = (
    src.utils.make_cycle_hov(ax, data=d, **kwargs) for ax, d in zip(axs, panel_data)
)

## one colorbar (taken from the early panel) serves all three panels
amp = kwargs["amp"]
cb = fig.colorbar(cp0, ax=axs[2], ticks=[-amp, 0, amp], label=r"$W~m^{-2}~K^{-1}$")
src.utils.format_hov_axs(axs)

## dashed horizontal reference line at y = 7
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

In [None]:
## NHF against the T_34 index; the value make_scatter2 returns goes in each title
kwargs = dict(months=4, scale=1, x_var="T_34", y_var="nhf", fn_y=src.utils.get_nino34)

## early period (left) vs. late period (right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, **kwargs)
    for ax, ds in zip(axs, (anom_early, anom_late))
)

for ax, value in zip(axs, (m0, m1)):
    ax.set_title(f"{value:.1f}" + r" $W~m^{-2}~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### SST-$\tau_x$

In [None]:
## month-by-month slope of zonal wind stress (y) against Niño 3.4 SST (x)
kwargs = dict(x_var="sst", y_var="taux", fn_x=src.utils.get_nino34)
m_early, m_late = map(lambda ds: get_slope_bymonth(ds, **kwargs), (anom_early, anom_late))

In [None]:
## shared args
kwargs = dict(amp=0.02, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2 x (late - early); all share one amplitude
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early, **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late, **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early), **kwargs)

## colorbar: tau_x regressed on Niño 3.4 SST, so units are Pa per K
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$Pa~K^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference lines at y = 5 and y = 7
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

##### Spatial plot

In [None]:
## regression maps for month 6
sel = lambda x: x.sel(month=6)

fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, late - early (the difference row uses a smaller amplitude)
contour_kwargs = dict(amp=3e-2, sel=sel)
diff_kwargs = {**contour_kwargs, "amp": 1.5e-2}
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early, **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late, **contour_kwargs)
cp2 = src.utils.make_contour_plot(axs[2, 0], m_late - m_early, **diff_kwargs)

## colorbars with ticks at ±amplitude
lab = r"$Pa~K^{-1}$"
a0, a2 = contour_kwargs["amp"], diff_kwargs["amp"]
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-a0, 0, a0], label=lab)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-a2, 0, a2], label=lab)

## outline the Niño 3.4 box on every map
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, **box_kwargs)

plt.show()

In [None]:
## month-6 slope change (late - early), averaged over latitude in slice(-5, 5)
sel = lambda x: x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")
delta_m = m_late - m_early

fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(m_early.longitude, sel(delta_m))
ax.axhline(0, ls="--", c="k", lw=0.8)  # zero line
ax.set_xlim([140, 280])
plt.show()

In [None]:
## tau_x against the T_34 index; the value make_scatter2 returns goes in each title
kwargs = dict(
    months=6, scale=1e3, x_var="T_34", y_var="taux", fn_y=src.utils.get_nino34
)

## early period (left) vs. late period (right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, **kwargs)
    for ax, ds in zip(axs, (anom_early, anom_late))
)

for ax, value in zip(axs, (m0, m1)):
    ax.set_title(f"{value:.1f}" + r" $mPa~K^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $\tau_x-T_{sub}$

In [None]:
## month-by-month slope of subsurface T (y) against a tau_x index (x).
## NOTE(review): the early period uses the Niño 4 tau_x index and the late
## period the Niño 3.4 one, so m_late - m_early compares slopes on different
## predictors — confirm this is intended.
kwargs = dict(x_var="taux", y_var="T")
predictor = {"early": src.utils.get_nino4, "late": src.utils.get_nino34}
m_early = get_slope_bymonth(anom_early, fn_x=predictor["early"], **kwargs)
m_late = get_slope_bymonth(anom_late, fn_x=predictor["late"], **kwargs)

In [None]:
## specify Hm: the z_t level (nearest match) at which T is shown
Hm = 70
sel = lambda x: x.sel(z_t=Hm, method="nearest")

## shared args
kwargs = dict(amp=3e2, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2 x (late - early); all share one amplitude
cp0 = src.utils.make_cycle_hov(axs[0], data=sel(m_early), **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=sel(m_late), **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * sel(m_late - m_early), **kwargs)

## colorbar: T regressed on tau_x, so units are K per Pa (not a rate)
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~Pa^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference lines at y = 5 and y = 7
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

In [None]:
def sel_T(x):
    """Mean over longitude slice(210, 280), then the z_t level nearest the global Hm."""
    east = x.sel(longitude=slice(210, 280)).mean("longitude")
    return east.sel(z_t=Hm, method="nearest")


## kwargs for plotting (passed through to make_scatter2)
kwargs = dict(months=6, scale=1e-2, x_var="taux", y_var="T", fn_y=sel_T)

## early (Niño 4 tau_x index, left) vs. late (Niño 3.4 tau_x index, right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, fn_x=fn, **kwargs)
    for ax, ds, fn in zip(
        axs, (anom_early, anom_late), (src.utils.get_nino4, src.utils.get_nino34)
    )
)

for ax, value in zip(axs, (m0, m1)):
    ax.set_title(f"{value:.1f}" + r" $K~hPa^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $\tau_x$ - $Z_{20}$

In [None]:
## month-by-month slope of Z20 (y) against a tau_x index (x).
## NOTE(review): early uses the Niño 4 index and late the Niño 3.4 index, so
## their difference mixes two predictors — confirm this is intended.
kwargs = dict(x_var="taux", y_var="z20")
predictor = {"early": src.utils.get_nino4, "late": src.utils.get_nino34}
m_early = get_slope_bymonth(anom_early, fn_x=predictor["early"], **kwargs)
m_late = get_slope_bymonth(anom_late, fn_x=predictor["late"], **kwargs)

##### Hovmöller

In [None]:
## shared args
kwargs = dict(amp=2e3, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2 x (late - early); all share one amplitude
cp0 = src.utils.make_cycle_hov(axs[0], data=m_early, **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=m_late, **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * (m_late - m_early), **kwargs)

## colorbar: Z20 regressed on tau_x, i.e. depth per Pa (not K per month).
## NOTE(review): assumes Z20 is in meters, like the z_t levels selected with a
## meter-valued Hm elsewhere — confirm the units of "z20".
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$m~Pa^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference lines at y = 5 and y = 7
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

Scatter

In [None]:
def _box_mean(lon_lo, lon_hi):
    """Return a function averaging over latitude slice(-5, 5) and the given longitude slice."""
    return lambda x: x.sel(
        latitude=slice(-5, 5),
        longitude=slice(lon_lo, lon_hi),
    ).mean(["latitude", "longitude"])


## east-minus-west Z20 difference along the equatorial band
sel_e = _box_mean(210, 280)
sel_w = _box_mean(140, 210)
grad = lambda x: sel_e(x) - sel_w(x)

## kwargs for plotting (passed through to make_scatter2)
kwargs = dict(months=6, scale=1e-3, x_var="taux", y_var="z20", fn_y=grad)

## early (Niño 4 tau_x index, left) vs. late (Niño 3.4 tau_x index, right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, fn_x=fn, **kwargs)
    for ax, ds, fn in zip(
        axs, (anom_early, anom_late), (src.utils.get_nino4, src.utils.get_nino34)
    )
)

for ax, value in zip(axs, (m0, m1)):
    ax.set_title(f"{value:.2f}" + r" $cm~mPa^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)

### $Z_{20} - T_{sub}$

In [None]:
## month-by-month slope of subsurface T (y) against the Niño 3.4 Z20 index (x)
kwargs = dict(x_var="z20", y_var="T", fn_x=src.utils.get_nino34)
m_early, m_late = (get_slope_bymonth(ds, **kwargs) for ds in (anom_early, anom_late))

In [None]:
## specify Hm: the z_t level (nearest match) at which T is shown
Hm = 50
sel = lambda x: x.sel(z_t=Hm, method="nearest")

## shared args
kwargs = dict(amp=0.3, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## plot data: early, late, and 2 x (late - early); all share one amplitude
cp0 = src.utils.make_cycle_hov(axs[0], data=sel(m_early), **kwargs)
cp1 = src.utils.make_cycle_hov(axs[1], data=sel(m_late), **kwargs)
cp2 = src.utils.make_cycle_hov(axs[2], data=2 * sel(m_late - m_early), **kwargs)

## colorbar: T regressed on Z20, i.e. K per unit depth (not K per month).
## NOTE(review): assumes Z20 is in meters (z_t here is selected with a
## meter-valued Hm) — confirm the units of "z20".
cb = fig.colorbar(
    cp0,
    ax=axs[2],
    ticks=[-kwargs["amp"], 0, kwargs["amp"]],
    label=r"$K~m^{-1}$",
)
src.utils.format_hov_axs(axs)

## dashed horizontal reference lines at y = 5 and y = 7
for ax in axs:
    for t in [5, 7]:
        ax.axhline(t, ls="--", c="k", lw=1)

plt.show()

In [None]:
## eastern-Pacific reductions: longitude slice(210, 280) mean, then either a
## latitude slice(-2, 2) mean (Z20) or the z_t level nearest the global Hm (T)
sel_lon = lambda x: x.sel(longitude=slice(210, 280)).mean("longitude")
sel_z20 = lambda x: sel_lon(x).sel(latitude=slice(-2, 2)).mean("latitude")
sel_T = lambda x: sel_lon(x).sel(z_t=Hm, method="nearest")

## kwargs for plotting (passed through to make_scatter2)
kwargs = dict(months=6, scale=1e1, x_var="z20", y_var="T", fn_x=sel_z20, fn_y=sel_T)

## early period (left) vs. late period (right)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, **kwargs)
    for ax, ds in zip(axs, (anom_early, anom_late))
)

for ax, value in zip(axs, (m0, m1)):
    ax.set_title(f"{value:.1f}" + r" $K~cm^{-1}$")
axs[1].set_yticks([])

src.utils.set_lims(axs)