# Bjerknes feedback changes over time

## imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## set plotting specs
## plotting style: white axes background, no gridlines
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## higher figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories are read from environment variables
## (a missing variable raises KeyError)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def plot_hov(ax, data, amp, label=None):
    """Plot a Hovmöller diagram of longitude (x) vs. year (y) on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        data: array with ``longitude`` and ``year`` coordinates. It is
            transposed (``data.T``) before plotting, so it presumably has
            dims ordered (longitude, year) — TODO confirm.
        amp: half-width of the symmetric colour range; contour levels come
            from ``src.utils.make_cb_range(amp, amp / 10)``.
        label: optional colourbar label.
    """

    plot_data = ax.contourf(
        data.longitude,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="both",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## attach the colorbar to this axes' own figure, stealing space from
    ## this axes (previously relied on a global ``fig``)
    ax.figure.colorbar(
        plot_data, ax=ax, orientation="horizontal", ticks=[-amp, 0, amp], label=label
    )

    ## label / format only this axes (previously looped over a global
    ## ``axs``, shadowing ``ax`` and adding duplicate lines on every call)
    line_kwargs = dict(ls="--", c="w", lw=0.8)
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])
    ax.axvline(190, **line_kwargs)
    ax.axvline(240, **line_kwargs)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def plot_hov2(ax, data, amp, label=None):
    """Plot a Hovmöller diagram of month (x) vs. year (y) on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        data: array with ``month`` and ``year`` coordinates, transposed
            (``data.T``) before plotting — presumably dims (month, year);
            TODO confirm.
        amp: half-width of the colour range; levels are
            ``src.utils.make_cb_range(amp, amp / 10)``. Only the upper end is
            extended (``extend="max"``) even though the levels are symmetric —
            presumably intentional for non-negative data; confirm.
        label: optional colourbar label.
    """

    plot_data = ax.contourf(
        data.month,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="max",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## colorbar tied to this axes' own figure (previously a global ``fig``)
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## format only this axes (previously looped over a global ``axs``)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def get_rolling_var(data, n=10):
    """
    Rolling variance, computed separately for each calendar month.

    Thin wrapper around ``src.utils.get_rolling_fn_bymonth`` with ``np.var``
    as the statistic. The intent is to pool over time and ensemble member
    in a window of ``2n+1`` years centred on each year. The windowing
    itself lives in the helper and is not verified here.
    """
    rolling_by_month = src.utils.get_rolling_fn_bymonth
    return rolling_by_month(data, fn=np.var, n=n)


def get_ml_avg(data, Hm, delta=5, H0=None):
    """
    Average ``data`` over all depth levels at or above a lower bound.

    The bound is ``Hm + delta``, where ``Hm`` is interpolated onto ``data``'s
    ``lon`` grid. If ``H0`` is given, the bound is instead the constant
    ``H0`` everywhere and ``delta`` is ignored.
    """

    ## MLD on the data's longitude grid
    mld_on_grid = Hm.rename({"longitude": "lon"}).interp({"lon": data.lon})

    ## choose the lower bound of the average
    if H0 is not None:
        bound = H0 * xr.ones_like(mld_on_grid)
    else:
        bound = mld_on_grid + delta

    ## mask levels below the bound, then average in depth
    return data.where(data.z_t <= bound).mean("z_t")


def get_ml_avg_wrapper(data, Hm, delta=5, H0=None):
    """
    Compute ``get_ml_avg`` and put the result in plotting layout.

    The output has ``lon`` renamed to ``longitude`` and ``month`` as its
    leading dimension.
    """
    averaged = get_ml_avg(data=data, Hm=Hm, delta=delta, H0=H0)
    renamed = averaged.rename({"lon": "longitude"})
    return renamed.transpose("month", ...)


def plot_mld_bounds(ax, clim, m):
    """
    Plot an MLD climatology with upper and lower bounds against longitude.

    Draws ``clim`` in black, ``clim + m`` in red (El Niño) and ``clim - m``
    in blue (La Niña).
    """
    lon = clim.longitude
    ax.plot(lon, clim, c="k")  # climatology
    ax.plot(lon, clim + m, c="r")  # El Niño
    ax.plot(lon, clim - m, c="b")  # La Niña
    return


def _w_on_T_grid(w, T):
    """Relabel ``w``'s ``z_w_top`` axis as ``z_t``, using ``T``'s depth values.

    No interpolation is done. Level k of ``z_w_top`` is simply assigned
    depth ``T.z_t[k]``, so both grids presumably have the same number of
    levels. ``rename`` and ``assign_coords`` return new objects, so ``w``
    is not modified and no deep copy is needed.
    """
    return w.rename({"z_w_top": "z_t"}).assign_coords({"z_t": T.z_t})


def get_wT(w, T):
    """Vertical flux ``w * T``, with ``w`` relabelled onto ``T``'s depth grid."""
    return _w_on_T_grid(w, T) * T


def get_wdTdz(w, T):
    """Vertical advection product ``w * dT/dz`` on ``T``'s depth grid.

    ``dT/dz`` is ``T.differentiate("z_t")``, i.e. per unit of the ``z_t``
    coordinate. No unit conversion is applied here; convert ``z_t`` first
    (e.g. with ``convert_cm_to_m``) if a per-metre gradient is wanted.
    """
    return _w_on_T_grid(w, T) * T.differentiate("z_t")


def get_udTdx(u, T):
    """Zonal advection product ``u * dT/dx`` at the equator.

    ``T.differentiate("lon")`` is a derivative per *degree* of longitude
    (xarray differentiates with respect to the coordinate values). It is
    converted to per-metre by dividing by the length of one degree of
    longitude at the equator. The old divisor was the length of one grid
    step, which is only correct for exactly 1° spacing.
    """

    ## metres per degree of longitude at latitude 0
    m_per_deg = get_dx(lat_deg=0.0, dlon_deg=1.0)

    return u * T.differentiate("lon") / m_per_deg


def get_u_adv(u, T):
    """Temperature tendency from zonal advection, ``-u * dT/dx``.

    Units are those of ``u`` (per metre) times ``T``. For example, with
    ``u`` in m/month and ``T`` in K the result is in K/month; no time-unit
    conversion happens here.
    """
    return -get_udTdx(u=u, T=T)


def recon_clim(data, components, varname="sst"):
    """
    Reconstruct the equatorial monthly climatology of ``varname``.

    ``data`` is averaged by calendar month (in EOF/PC space). Then
    ``components[varname]`` and the monthly ``data[varname]`` are passed to
    ``src.utils.reconstruct_fn``, with an equatorial mean over latitudes
    -2..2 as the post-processing function. Exact zeros in the
    reconstruction are replaced by NaN (presumably masked points — confirm).
    """

    ## climatology in PC space
    monthly_clim = data.groupby("time.month").mean()

    def equatorial_mean(x):
        """Average over latitudes -2..2."""
        return x.sel(latitude=slice(-2, 2)).mean("latitude")

    recon = src.utils.reconstruct_fn(
        components[varname], monthly_clim[varname], fn=equatorial_mean
    )

    ## mask zeros as NaN. ``where`` returns a new array; the previous
    ## in-place write into ``recon.values`` is silently lost for dask-backed
    ## arrays (``.values`` is then a freshly computed copy).
    return recon.where(recon != 0)


def get_monthly_eli(t_bnds):
    """
    Monthly climatology of the forced ELI over a range of time indices.

    ``t_bnds`` is unpacked into ``slice(*t_bnds)`` and used with ``isel``,
    so it holds positional indices. Reads the module-level ``eli_forced``.
    """
    window = slice(*t_bnds)
    subset = eli_forced.isel(time=window)
    return subset.groupby("time.month").mean()


def get_monthly_eli_std(t_bnds):
    """
    Per-calendar-month standard deviation of ELI anomalies.

    The standard deviation is pooled over ``time`` and ``member`` within the
    positional window ``slice(*t_bnds)``. Reads the module-level
    ``eli_anom``.
    """
    window = slice(*t_bnds)
    by_month = eli_anom.isel(time=window).groupby("time.month")
    return by_month.std(["time", "member"])


def plot_cyclic(ax, data, sigma=None, **kwargs):
    """Plot monthly ``data`` (x) against month (y), closing the annual cycle.

    ``cartopy.util.add_cyclic_point`` appends a wrap-around copy of the
    first month so the curve closes. If ``sigma`` is given, thinner lines
    (lw=0.8) are drawn at ``data ± sigma``. Any ``lw``/``linewidth`` in
    ``kwargs`` applies to the main line only; it used to raise a duplicate
    keyword error on the bound lines.
    """

    ## add cyclic point along the month axis
    data_cyclic, month_cyclic = cartopy.util.add_cyclic_point(data, data.month, axis=0)

    ## main curve
    ax.plot(data_cyclic, month_cyclic, **kwargs)

    if sigma is None:
        return

    ## ± sigma bounds, thinner than the main curve
    sigma_cyclic, _ = cartopy.util.add_cyclic_point(sigma, data.month, axis=0)
    bound_kwargs = {**kwargs, "lw": 0.8}
    bound_kwargs.pop("linewidth", None)  # avoid lw/linewidth alias clash
    ax.plot(data_cyclic + sigma_cyclic, month_cyclic, **bound_kwargs)
    ax.plot(data_cyclic - sigma_cyclic, month_cyclic, **bound_kwargs)

    return


def plot_cyclic_quantiles(ax, data, quantiles=(0.5, 0.15, 0.85), **kwargs):
    """Plot per-month quantiles of ``data`` (x) against month (y), closing the cycle.

    Quantiles are pooled over ``time`` and ``member`` for each calendar
    month. The first quantile (the median with the default) is drawn with
    ``kwargs``. The remaining quantiles are drawn thinner (lw=0.8), and any
    caller ``lw``/``linewidth`` is overridden for them instead of raising a
    duplicate-keyword error.
    """

    ## quantiles per calendar month (list() so tuples are accepted too)
    q = data.groupby("time.month").quantile(q=list(quantiles), dim=["time", "member"])

    ## convert to numpy: rows are quantiles, columns months
    month = q.month.values
    q_vals = q.transpose("quantile", "month").values

    ## add cyclic point along the month axis
    q_cyclic, month_cyclic = cartopy.util.add_cyclic_point(q_vals, month, axis=1)

    ## first quantile in the caller's style
    ax.plot(q_cyclic[0], month_cyclic, **kwargs)

    ## remaining quantiles as thin lines
    thin_kwargs = {**kwargs, "lw": 0.8}
    thin_kwargs.pop("linewidth", None)
    for q_row in q_cyclic[1:]:
        ax.plot(q_row, month_cyclic, **thin_kwargs)

    return


def format_subsurf_axs(axs):
    """
    Apply shared formatting to a row of depth-vs-longitude section axes.

    Each axes gets a flipped y-axis, an x upper limit of 281, no y ticks and
    an x label. The first axes then gets depth ticks (300/150/0) and a
    "Depth (m)" label.
    """
    for panel in axs:
        lower, upper = panel.get_ylim()
        panel.set_ylim((upper, lower))  # flip the vertical axis
        panel.set_xlim([None, 281])
        panel.set_yticks([])
        panel.set_xlabel("Longitude")

    first = axs[0]
    first.set_yticks([300, 150, 0])
    first.set_ylabel("Depth (m)")

    return


def format_hov_axs(axs):
    """
    Format three month-vs-longitude Hovmöller axes (early / late / difference).

    Sets small (size 8) titles and a month label, keeps month ticks only on
    the first panel, and on every panel puts x ticks at 190 and 240 with
    white dashed vertical lines there.
    """
    small_font = dict(size=8)

    axs[0].set_ylabel("Month", **small_font)
    for idx, title in enumerate(("Early", "Late", "Difference (x2)")):
        axs[idx].set_title(title, **small_font)

    ## month ticks on the first panel only
    for idx in (1, 2):
        axs[idx].set_yticks([])
    axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])

    ## longitude ticks and reference lines on every panel
    for panel in axs:
        panel.set_xticks([190, 240])
        for lon in (240, 190):
            panel.axvline(lon, ls="--", c="w", lw=1)

    return


def get_w_int(w):
    """Average ``w`` over the ``z_w_top`` levels from the surface down to 200.

    Despite the "_int" name, this is an unweighted *mean* over levels, not a
    thickness-weighted integral. With ``z_w_top`` in metres it covers the
    top 200 m.
    """
    return w.sel(z_w_top=slice(None, 200)).mean("z_w_top")


def get_dTdz(Tsub):
    """Bulk vertical temperature difference ``T(z_t≈0) - T(z_t≈200)``.

    Uses the levels nearest to ``z_t = 0`` and ``z_t = 200``
    (``method="nearest"``). This is a difference, not divided by depth.
    """
    T_surf = Tsub.sel(z_t=0, method="nearest").squeeze(drop=True)
    T_subsurf = Tsub.sel(z_t=200, method="nearest").squeeze(drop=True)

    return T_surf - T_subsurf


def get_diags(data):
    """
    Merge two diagnostics into one dataset with a ``longitude`` coordinate.

    The diagnostics are the surface-minus-200 temperature difference
    (``dTdz``, from ``data["T"]``) and the upper-200 mean vertical velocity
    (``w_int``, from ``data["w"]``).
    """
    dTdz = get_dTdz(data["T"]).rename("dTdz")
    w_int = get_w_int(data["w"]).rename("w_int")
    merged = xr.merge([dTdz, w_int])
    return merged.rename({"lon": "longitude"})


def get_dT_sub(Tsub, mld, delta=25):
    """
    Temperature difference between the mixed layer and the entrainment zone.

    Following Frankignoul et al. (per the original note):
      - Tbar  = mean of ``Tsub`` over levels with ``z_t <= mld``
      - Tplus = mean of ``Tsub`` over levels with ``mld < z_t < mld + delta``
    Returns ``Tbar - Tplus`` with ``lon`` renamed to ``longitude``.
    Longitudes where the interpolated MLD is NaN for every month are
    dropped.
    """

    ## MLD interpolated onto Tsub's longitudes
    mld_on_grid = mld.interp({"longitude": Tsub.lon.values}).rename({"longitude": "lon"})

    ## keep longitudes with at least one non-NaN MLD value
    keep = ~np.isnan(mld_on_grid).all("month")
    mld_on_grid = mld_on_grid.isel(lon=keep)
    T_kept = Tsub.isel(lon=keep)

    ## masks: mixed layer and entrainment zone below it
    depth = T_kept.z_t
    mixed_layer = depth <= mld_on_grid
    entrain_zone = (depth > mld_on_grid) & (depth < (delta + mld_on_grid))

    T_bar = T_kept.where(mixed_layer).mean("z_t")
    T_plus = T_kept.where(entrain_zone).mean("z_t")

    return (T_bar - T_plus).rename({"lon": "longitude"})


def get_dTdz_sub(Tsub, mld, delta=25):
    """Bulk vertical temperature gradient at the base of the mixed layer.

    Computed as ``get_dT_sub(Tsub, mld, delta) / mld``: the mixed-layer
    minus entrainment-zone temperature difference, divided by the MLD. The
    MLD is interpolated onto ``Tsub``'s longitudes, and longitudes whose
    MLD is NaN for every month are dropped, matching ``get_dT_sub``.
    """

    ## temperature difference between mixed layer and entrainment zone
    dT = get_dT_sub(Tsub=Tsub, mld=mld, delta=delta)

    ## MLD interpolated onto Tsub's longitudes
    mld_interp = mld.interp({"longitude": Tsub.lon.values}).rename({"longitude": "lon"})

    ## drop longitudes with no valid MLD (same subset as in get_dT_sub)
    valid_lon_idx = ~np.isnan(mld_interp).all("month")
    mld_interp = mld_interp.isel(lon=valid_lon_idx)

    ## bulk gradient
    dTdz = dT / mld_interp.rename({"lon": "longitude"})

    return dTdz


def get_nino34(data):
    """Average ``data`` over longitudes 190–240 (the Niño-3.4 band, 170°W–120°W)."""
    nino34_band = slice(190, 240)
    return data.sel(lon=nino34_band).mean("lon")


def get_w_int_idx(data):
    """Niño-3.4-band (lon 190–240) average of the upper-200 mean vertical velocity."""
    w_upper = get_w_int(data)
    return get_nino34(w_upper)


def get_dTdz_idx(data):
    """Niño-3.4-band (lon 190–240) average of the surface-minus-200 temperature difference."""
    return get_nino34(get_dTdz(data))


def eq_avg(x):
    """Mean over latitudes -5..5, keeping longitudes 125..279."""
    band = x.sel(latitude=slice(-5, 5), longitude=slice(125, 279))
    return band.mean("latitude")


def avg_mon_range(data, m0, m1):
    """
    Average ``data`` within each year over calendar months ``m0``..``m1`` (inclusive).

    The selection does not wrap across the year boundary (if ``m0 > m1`` no
    months are selected). The resulting ``year`` dimension is renamed to
    ``time``.
    """
    months = data.time.dt.month
    in_window = (months >= m0) & (months <= m1)
    yearly = data.isel(time=in_window).groupby("time.year").mean()
    return yearly.rename({"year": "time"})


def get_mam(data):
    """Yearly March–May (MAM) mean of ``data``, via ``avg_mon_range``."""
    march, may = 3, 5
    return avg_mon_range(data, m0=march, m1=may)


def _shared_range(lims):
    """Return (smallest lower limit, largest upper limit) from (lower, upper) pairs."""
    stacked = np.stack(lims, axis=0)
    return stacked[:, 0].min(), stacked[:, 1].max()


def set_ylims(axs):
    """Give every axes in ``axs`` the same y-range covering all their current ranges.

    ``axs`` may be a single Axes, a list, or an array of any shape (e.g. 2-D
    from ``plt.subplots``). It is flattened with ``np.ravel``. Previously
    the limits were gathered from the flattened array but applied by
    iterating ``axs`` directly, which fails for 2-D arrays.
    """
    flat_axs = np.ravel(axs)
    lb, ub = _shared_range([ax.get_ylim() for ax in flat_axs])
    for ax in flat_axs:
        ax.set_ylim([lb, ub])

    return


def set_xlims(axs):
    """Give every axes in ``axs`` the same x-range covering all their current ranges.

    Accepts the same inputs as ``set_ylims``: a single Axes, a list, or an
    array of any shape, flattened with ``np.ravel``.
    """
    flat_axs = np.ravel(axs)
    lb, ub = _shared_range([ax.get_xlim() for ax in flat_axs])
    for ax in flat_axs:
        ax.set_xlim([lb, ub])

    return


def get_dy(dlat_deg):
    """Meridional distance in metres spanned by ``dlat_deg`` degrees of latitude.

    Uses a spherical Earth of radius 6.378e6 m, consistent with ``get_dx``.
    Previously this raised NameError (it used an undefined ``dlat``) and
    used a radius in centimetres, contradicting this docstring.
    """

    ## convert from degrees to radians
    dlat_rad = dlat_deg / 180.0 * np.pi

    ## multiply by radius of earth
    R = 6.378e6  # earth radius (meters)
    dlat_meters = R * dlat_rad

    return dlat_meters


def get_dx(lat_deg, dlon_deg):
    """
    Zonal distance in metres spanned by ``dlon_deg`` degrees of longitude at ``lat_deg``.

    Uses a spherical Earth of radius 6.378e6 m: ``R * cos(lat) * dlon``,
    with both angles converted to radians.
    """
    earth_radius_m = 6.378e6
    lat_rad = np.pi * (lat_deg / 180)
    dlon_rad = np.pi * (dlon_deg / 180.0)
    return earth_radius_m * np.cos(lat_rad) * dlon_rad


def get_dydx(data):
    """
    Grid-cell height ``dy`` and width ``dx`` for ``data``'s lat/lon grid, in centimetres.

    Spacing is taken from the first two ``latitude`` and ``longitude``
    values, so a uniform grid is assumed. A spherical Earth of radius
    6.378e8 cm is used, and ``dx`` scales with cos(latitude). Returns a
    Dataset with ``dy`` and ``dx``, both varying with latitude only.
    """

    ## empty dataset on the data's grid to hold the result
    grid = xr.Dataset(
        coords=dict(
            latitude=data["latitude"].values,
            longitude=data["longitude"].values,
        ),
    )

    grid["dlat"] = grid["latitude"].values[1] - grid["latitude"].values[0]
    grid["dlon"] = grid["longitude"].values[1] - grid["longitude"].values[0]

    grid["dlat_rad"] = grid["dlat"] / 180.0 * np.pi
    grid["dlon_rad"] = grid["dlon"] / 180.0 * np.pi
    R = 6.378e8  # earth radius (centimeters)

    ## height of gridcell doesn't depend on longitude
    grid["dy"] = R * grid["dlat_rad"]  # unit: centimeters
    grid["dy"] = grid["dy"] * xr.ones_like(grid["latitude"])

    ## width of gridcell (unit: centimeters)
    grid["lat_rad"] = grid["latitude"] / 180 * np.pi  # latitude in radians
    grid["dx"] = R * np.cos(grid["lat_rad"]) * grid["dlon_rad"]

    return grid[["dy", "dx"]]


def u_dfdx(u, f):
    """Zonal advection product ``u * df/dx``, scaled by seconds per 365-day year.

    ``f.differentiate("longitude")`` is a derivative per *degree*, so it is
    divided by the length of one degree, ``dx / dlon``. Here ``dx`` from
    ``get_dydx`` is the length of one grid step in cm. Dividing by ``dx``
    itself, as before, is only right for 1° spacing. Assuming ``u`` is in
    cm/s, the result is in units of ``f`` per year.
    """

    ## grid-step length (cm) and longitude spacing (degrees)
    dx_cm = get_dydx(f)["dx"]
    lon = f["longitude"].values
    dlon_deg = lon[1] - lon[0]

    ## centimetres per degree of longitude (varies with latitude)
    cm_per_deg = dx_cm / dlon_deg

    sec_per_year = 86400 * 365
    factor = sec_per_year / cm_per_deg

    u_dfdx_ = u * f.differentiate("longitude") * factor

    return u_dfdx_


def v_dfdy(v, f):
    """Meridional advection product ``v * df/dy``, scaled by seconds per 365-day year.

    ``f.differentiate("latitude")`` is a derivative per *degree*, so it is
    divided by the length of one degree, ``dy / dlat``. Here ``dy`` from
    ``get_dydx`` is one grid step in cm. Assuming ``v`` is in cm/s, the
    result is in units of ``f`` per year.
    """

    ## grid-step height (cm) and latitude spacing (degrees)
    dy_cm = get_dydx(f)["dy"]
    lat = f["latitude"].values
    dlat_deg = lat[1] - lat[0]

    ## centimetres per degree of latitude
    cm_per_deg = dy_cm / dlat_deg

    sec_per_year = 86400 * 365
    factor = sec_per_year / cm_per_deg

    v_dfdy_ = v * f.differentiate("latitude") * factor

    return v_dfdy_


def get_adv(uv, T):
    """
    Temperature tendency from horizontal advection, ``-(u dT/dx + v dT/dy)``.

    ``uv`` must provide ``"uvel"`` and ``"vvel"``. Both terms come from
    ``u_dfdx`` / ``v_dfdy`` and carry their per-year scaling.
    """
    zonal_term = u_dfdx(u=uv["uvel"], f=T)
    meridional_term = v_dfdy(v=uv["vvel"], f=T)
    return -(zonal_term + meridional_term)


def merimean(x, lat_bound=5):
    """Mean over latitudes [-lat_bound, lat_bound], restricted to longitudes 140–285."""
    subset = x.sel(longitude=slice(140, 285), latitude=slice(-lat_bound, lat_bound))
    return subset.mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True, xticks=[190, 240], lat_bound=5):
    """
    Draw a longitude (x) vs. month (y) Hovmöller of ``data`` on ``ax``.

    With ``is_filled`` the plot uses filled contours (``cmo.balance``);
    otherwise thin black line contours. Levels are
    ``src.utils.make_cb_range(amp, amp / 5)`` with both ends extended. If
    ``data`` has a ``latitude`` coordinate it is first averaged with
    ``merimean(..., lat_bound)``. The x-range is 145–280, with white dashed
    lines at every entry of ``xticks``. Returns the contour set.
    """
    level_kwargs = {"levels": src.utils.make_cb_range(amp, amp / 5), "extend": "both"}

    if is_filled:
        draw, style_kwargs = ax.contourf, {"cmap": "cmo.balance"}
    else:
        draw, style_kwargs = ax.contour, {"colors": "k", "linewidths": 0.8}

    plot_data = merimean(data, lat_bound=lat_bound) if "latitude" in data.coords else data

    cp = draw(
        plot_data.longitude,
        plot_data.month,
        plot_data,
        **style_kwargs,
        **level_kwargs,
    )

    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(xticks)
    for tick in xticks:
        ax.axvline(tick, c="w", ls="--", lw=1)

    return cp


def prep(data):
    """Add ``h_w_hat`` and ``h_hat``: the ``h_w`` and ``h`` indices with ``T_34`` dependence removed.

    The removal is done by ``src.utils.remove_sst_dependence_v2``. ``data``
    is modified in place (new variables are added) and is also returned.
    No tendencies are computed here.
    """

    ## remove SST dependence from each h index
    for h_idx in ["h_w", "h"]:
        data[f"{h_idx}_hat"] = src.utils.remove_sst_dependence_v2(
            data, h_var=h_idx, T_var="T_34"
        )

    return data


def regress(data, y_var, x_vars):
    """
    Multiple linear regression of ``data[y_var]`` on the variables ``x_vars``.

    Solves the normal equations with covariances pooled over ``member`` and
    ``time``: ``m = cov(Y, X) @ cov(X, X)^{-1}``. Returns a Dataset with one
    slope variable per predictor, named as in ``x_vars``.
    """

    ## stack predictors along a new "i" dimension
    predictors = data[x_vars].to_dataarray(dim="i")
    target = data[y_var]

    ## cross- and auto-covariance (the latter as an (i, j) matrix)
    cov_yx = xr.cov(target, predictors, dim=["member", "time"])
    cov_xx = xr.cov(predictors, predictors.rename({"i": "j"}), dim=["member", "time"])

    ## invert the covariance matrix, keeping its coordinates
    cov_xx_inv = xr.zeros_like(cov_xx)
    cov_xx_inv.values = np.linalg.inv(cov_xx.values)

    ## contract over i: m_j = sum_i cov_yx_i * inv_ij
    slopes = (cov_yx * cov_xx_inv).sum("i")

    return slopes.to_dataset(dim="j")


def regress_bymonth(data, y_var, x_vars):
    """Run ``regress`` independently for each calendar month of ``data``."""
    by_month = data.groupby("time.month")
    return by_month.map(regress, y_var=y_var, x_vars=x_vars)

## initialize cluster

In [None]:
# from dask.distributed import LocalCluster, Client

# cluster = LocalCluster(n_workers=4)
# client = Client(cluster)
# client

## Load data

### $T$, $h$

In [None]:
## open data: climate-index dataset from the project helper
## (exact contents defined by src.utils.load_cesm_indices)
Th = src.utils.load_cesm_indices()

## rename indices for convenience
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total, i.e. not anomalies)
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## compute relative SST: total regional SST minus the tropical-mean SST
## ("trop_sst_10"), stored as e.g. "T_34_rel"
## NOTE(review): the loop variable `n` stays defined at module scope
for n in ["T_3", "T_34", "T_4"]:
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

In [None]:
## load ELI data
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))

## get forced/anomalous component
eli_forced, eli_anom = src.utils.separate_forced(eli)

### Spatial data

#### MMLEA data

In [None]:
## path to EOF data
eofs_fp = pathlib.Path(DATA_FP, "cesm")

## variables to load (file names) and their short names:
## tos -> sst, mlotst -> mld, tauu -> taux, nhf -> nhf
names = [
    "tos",
    "mlotst",
    "tauu",
    "nhf",
]
newnames = ["sst", "mld", "taux", "nhf"]

# ## load the EOFs (one file per variable: eofs_<name>.nc)
load_var = lambda x: src.utils.load_eofs(pathlib.Path(eofs_fp, f"eofs_{x}.nc"))
eofs = {y: load_var(x) for (y, x) in zip(newnames, names)}

## for convenience, put spatial patterns / components in single dataset
components = xr.merge([eofs_.components().rename(y) for (y, eofs_) in eofs.items()])

# reset member dimension so they all match (NHF labeled differently...);
# members are relabelled 0..99 (assumes 100 ensemble members)
member_coord = dict(member=np.arange(100))
get_scores = lambda x, n: x.scores().assign_coords(member_coord).rename(n)
scores = xr.merge([get_scores(eofs_, n) for (n, eofs_) in eofs.items()])

## convert from stress on atm to stress on ocn (sign flip, in place)
scores["taux"].values *= -1

## convert MLD from cm to m (scaling the scores scales the reconstruction,
## assuming reconstruction is linear in the scores)
scores["mld"] = scores["mld"] / 100

# ## get forced/anomalous component
forced, anom = src.utils.separate_forced(scores)

## add T,h information (Niño-region T and h indices) to the anomalies
for n in ["T_3", "T_34", "T_4", "h", "h_w"]:
    anom[n] = Th[n]

#### Subsurface

In [None]:
def fix_lon_coord(data_sub):
    """
    Turn the ``nlon`` dimension into a proper ``lon`` coordinate.

    Longitude values are read from the top level (``z_t`` index 0) of the
    existing ``lon`` variable and assigned to ``nlon``. The old ``lon`` is
    dropped and ``nlon`` is renamed to ``lon``.
    """
    top_level_lon = data_sub.lon.isel(z_t=0).values
    relabelled = data_sub.assign_coords({"nlon": top_level_lon})
    return relabelled.drop_vars("lon").rename({"nlon": "lon"})


def convert_cm_to_m_helper(data, z_coord_name):
    """Return ``data`` with coordinate ``z_coord_name`` divided by 100 (cm -> m)."""
    depth_m = data[z_coord_name].values / 100
    return data.assign_coords({z_coord_name: depth_m})


def convert_cm_to_m(data):
    """Convert both vertical coordinates, ``z_t`` then ``z_w_top``, from cm to m."""
    converted = data
    for coord_name in ("z_t", "z_w_top"):
        converted = convert_cm_to_m_helper(converted, z_coord_name=coord_name)
    return converted

In [None]:
## path to EOF data
eofs_fp = pathlib.Path(DATA_FP, "cesm")

## variables to load (file names) and their short names:
## temp -> T, wvel -> w, uvel_sub -> u
names = [
    "temp",
    "wvel",
    "uvel_sub",
]
newnames = ["T", "w", "u"]

## load the EOFs
load_var = lambda x: src.utils.load_eofs(pathlib.Path(eofs_fp, f"eofs_{x}.nc"))
eofs_sub = {y: load_var(x) for (y, x) in zip(newnames, names)}

## for convenience, put spatial patterns / components in single dataset
components_sub = xr.merge(
    [eofs_.components().rename(y) for (y, eofs_) in eofs_sub.items()]
)

## fix longitude coord (1-D "lon" coordinate in place of "nlon")
components_sub = fix_lon_coord(components_sub)

# relabel the "member_id" coordinate as 0..99 so scores of all variables
# align when merged (assumes 100 ensemble members)
member_coord = dict(member_id=np.arange(100))
get_scores = lambda x, n: x.scores().assign_coords(member_coord).rename(n)
scores_sub = xr.merge([get_scores(eofs_, n) for (n, eofs_) in eofs_sub.items()])

## convert z coords from cm to m
components_sub = convert_cm_to_m(components_sub)

## convert u and w from cm/s to m/month (applied to the scores only; the
## reconstruction scales accordingly if it is linear in the scores)

# conversion factors (30-day months)
m_per_cm = 1 / 100
s_per_day = 86400
s_per_month = s_per_day * 30

# do conversion
scores_sub["w"] = scores_sub["w"] * m_per_cm * s_per_month
scores_sub["u"] = scores_sub["u"] * m_per_cm * s_per_month

## get forced/anomalous component (renaming member_id -> member to match
## the surface data)
forced_sub, anom_sub = src.utils.separate_forced(
    scores_sub.rename({"member_id": "member"})
)

## add subsurface anomalies and their spatial patterns to the surface
## `anom` / `components` datasets
for n in list(anom_sub):
    anom[n] = anom_sub[n]
    components[n] = components_sub[n]

#### Budget data

In [None]:
## path to the CESM data directory
CESM_FP = DATA_FP / "cesm"


def load_var(varname):
    """
    Open all ``*.nc`` files in ``$DATA_FP/cesm/<varname>_temp`` as one dataset.

    Files are sorted by path and concatenated along a new ``member``
    dimension (in parallel), which is relabelled 0..99. This assumes
    exactly 100 files.
    """
    folder = pathlib.Path(DATA_FP, "cesm", f"{varname}_temp")
    files = sorted(folder.glob("*.nc"))
    combined = xr.open_mfdataset(
        files,
        concat_dim="member",
        combine="nested",
        parallel=True,
    )
    return combined.assign_coords({"member": np.arange(100)})


## load advection and total-tendency terms of the temperature budget
budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

## get difference: total tendency minus 3-D advection
## (the non-advective remainder of the budget)
budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

## convert z coord from cm to m, and unit from K/s to K/mo
M_PER_CM = 1e-2
SEC_PER_MO = 8.64e4 * 30  # 30-day month, same as the u/w conversion above

## convert from (i) cm to m and (ii) K/s to K/mo
## (the multiplication scales every data variable, including "diff")
budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
budget_data = budget_data * SEC_PER_MO

## fix longitude coordinate (assumes `lon` here is 1-D along `nlon` — confirm)
budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
budget_data = budget_data.drop_vars("lon").rename({"nlon": "lon"})

## trim in time and load to memory: keep the first and last 480 steps of a
## 3012-step record (40 years each if monthly; hard-coded — presumably chosen
## to cover t_early/t_late, confirm)
t_idx = np.concatenate([np.arange(0, 480), np.arange(3012 - 480, 3012)])
budget_data = budget_data.isel(time=t_idx).compute()

## separate forced/anomalies
forced_bud, anom_bud = src.utils.separate_forced(budget_data)

## add anomalies to data
for n in list(anom_bud):
    anom[n] = anom_bud[n]

## Bjerknes coupling

#### prep data

Add EOF spatial patterns to the surface data

Add EOF spatial patterns to the subsurface data

In [None]:
## attach each EOF spatial pattern to the anomaly dataset as "<var>_comp",
## skipping names that already exist
for v in list(components):
    if f"{v}_comp" not in list(anom):
        anom[f"{v}_comp"] = components[v]

Split into early/late periods, and remove the SST dependence from the `h` and `h_w` indices

In [None]:
## split into early/late periods (30 years each)
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## split anomaly data (surface, subsurface and budget terms)
anom_early = anom.sel(t_early)
anom_late = anom.sel(t_late)

## remove SST (T_34) dependence from the h indices: adds h_hat and h_w_hat
anom_early = prep(anom_early)
anom_late = prep(anom_late)

#### Mixed layer info

Funcs

In [None]:
def plot_mld_bounds(ax, clim, m):
    """
    Plot an MLD climatology and its ± bounds against longitude.

    Black: ``clim``. Red: ``clim + m`` (El Niño). Blue: ``clim - m`` (La Niña).
    """
    lon = clim.longitude
    ax.plot(lon, clim, c="k")  # climatology
    ax.plot(lon, clim + m, c="r")  # El Niño
    ax.plot(lon, clim - m, c="b")  # La Niña
    return


def get_eq_mld(data):
    """Equatorial MLD: mean over latitudes -1.5..1.5, restricted to longitudes 140..280."""
    equatorial = data.sel(latitude=slice(-1.5, 1.5)).mean("latitude")
    return equatorial.sel(longitude=slice(140, 280))


def recon_eq_mld(time_dict):
    """
    Monthly climatology of the forced equatorial MLD for the period ``time_dict``.

    Forced MLD scores are averaged by calendar month and passed, with the
    MLD spatial patterns, to ``src.utils.reconstruct_fn`` using
    ``get_eq_mld``.
    """
    monthly_scores = forced["mld"].sel(time_dict).groupby("time.month").mean()
    return src.utils.reconstruct_fn(
        scores=monthly_scores, components=components["mld"], fn=get_eq_mld
    )


def plot_mlds(axs, sel):
    """
    Plot early (solid) and late (dashed) MLD curves against longitude.

    ``sel`` is applied to the module-level ``mld_early`` / ``mld_late`` to
    choose what is plotted. Panel 0 shows early, panel 1 late, panel 2 both.
    """
    early_style = dict(c="k")
    late_style = dict(c="k", ls="--")

    axs[0].plot(mld_early.longitude, sel(mld_early), **early_style)
    axs[1].plot(mld_late.longitude, sel(mld_late), **late_style)
    for mld, style in ((mld_early, early_style), (mld_late, late_style)):
        axs[2].plot(mld.longitude, sel(mld), **style)

    return

Get mixed layer for early/late periods

In [None]:
## forced equatorial MLD monthly climatology for each period
mld_early = recon_eq_mld(t_early)
mld_late = recon_eq_mld(t_late)

#### Get Tsub

In [None]:
def get_Tsub_early(T):
    """Apply ``get_dT_sub`` to ``T`` using the early-period MLD (``mld_early``) and ``delta=20``."""
    early_mld = mld_early
    return get_dT_sub(T, mld=early_mld, delta=20)


def get_Tsub_recon(data):
    """
    Reconstruct the subsurface temperature difference from EOF data.

    Passes ``data["T"]`` (scores) and ``data["T_comp"]`` (spatial patterns)
    to ``src.utils.reconstruct_fn`` with ``get_Tsub_early`` as the function.
    """
    return src.utils.reconstruct_fn(
        scores=data["T"],
        fn=get_Tsub_early,
        components=data["T_comp"],
    )

### Compute BJ coefficients

In [None]:
def get_clim_sub(t_dict):
    """
    Monthly climatology of the forced subsurface fields for one period.

    Forced subsurface scores are selected with ``t_dict`` and averaged by
    calendar month. They are then reconstructed with ``components_sub`` via
    ``src.utils.reconstruct_fn``, with identity post-processing.
    """
    monthly_scores = forced_sub.sel(t_dict).groupby("time.month").mean()

    def identity(field):
        return field

    return src.utils.reconstruct_fn(
        components=components_sub,
        scores=monthly_scores,
        fn=identity,
    )


## monthly mean-state subsurface fields (T, w, u) for each period
clim_sub_early = get_clim_sub(t_early)
clim_sub_late = get_clim_sub(t_late)

In [None]:
def fit(y_var, data, x_vars=["T_34", "h_w_hat"]):
    """
    Fit a per-month multiple regression of ``y_var`` on ``x_vars``.

    Uses ``src.utils.multi_regress_bymonth`` and returns the coefficient on
    the first predictor (``T_34`` by default).
    """
    first_predictor = x_vars[0]
    coefs = src.utils.multi_regress_bymonth(data, y_var=y_var, x_vars=x_vars)
    return coefs[first_predictor]


def fit_early(**kwargs):
    """Call ``fit`` on the early-period anomalies (``anom_early``)."""
    early = anom_early
    return fit(data=early, **kwargs)


def fit_late(**kwargs):
    """Convenience func to fit to late data (``anom_late``).

    All keyword arguments are forwarded to ``fit``. Passing ``data``
    explicitly is an error.
    """
    return fit(data=anom_late, **kwargs)

Compute things which don't depend on MLD

In [None]:
## Regression ("coupling") coefficients and feedback terms for the early and
## late periods. Most entries are slopes on the first predictor of `fit`
## ("T_34") from a month-by-month regression. Later cells read m_early,
## m_late and several module-level names defined here.

## empty datasets to hold results
m_early = xr.Dataset()
m_late = xr.Dataset()


## compute regression coefficients for given variables
## (each variable regressed on ["T_34", "h_w_hat"]; the T_34 slope is kept)
for n in tqdm.tqdm(["sst", "nhf", "taux", "T", "w", "u"]):
    m_early[n] = fit_early(y_var=n)
    m_late[n] = fit_late(y_var=n)


## taux-Tsub
## regress T on taux separately for each calendar month via
## src.utils.regress_proj; fn_x=src.utils.get_nino4 is applied to the
## predictor (presumably a Niño 4 average of taux -- TODO confirm)
kwargs = dict(x_var="taux", y_var="T")
get_slope = lambda x, fn_x: x.groupby("time.month").map(
    src.utils.regress_proj, fn_x=fn_x, **kwargs
)
m_early["taux_T"] = get_slope(anom_early, fn_x=src.utils.get_nino4)
m_late["taux_T"] = get_slope(anom_late, fn_x=src.utils.get_nino4)

## Thermocline feedback: climatological w times the regressed (anomalous) T
m_early["THF"] = get_wdTdz(w=clim_sub_early["w"], T=m_early["T"])
m_late["THF"] = get_wdTdz(w=clim_sub_late["w"], T=m_late["T"])

## Ekman feedback: regressed (anomalous) w times the climatological T
m_early["EKM"] = get_wdTdz(T=clim_sub_early["T"], w=m_early["w"])
m_late["EKM"] = get_wdTdz(T=clim_sub_late["T"], w=m_late["w"])

## zonal advective feedback and dynamical damping
## ZAF: anomalous u acting on climatological T; DD: climatological u acting
## on anomalous T
m_early["ZAF"] = get_u_adv(T=clim_sub_early["T"], u=m_early["u"])
m_early["DD"] = get_u_adv(T=m_early["T"], u=clim_sub_early["u"])
m_late["ZAF"] = get_u_adv(T=clim_sub_late["T"], u=m_late["u"])
m_late["DD"] = get_u_adv(T=m_late["T"], u=clim_sub_late["u"])

## Decompose changes in Ekman feedback
## Δ(w' T̄) split into w0'·ΔT̄ ("mean"), Δw'·T̄0 ("anom") and Δw'·ΔT̄ ("nl");
## the three sum to the total change assuming get_wdTdz is bilinear in
## (w, T) -- NOTE(review): confirm
delta_w = m_late["w"] - m_early["w"]
delta_T = clim_sub_late["T"] - clim_sub_early["T"]
m_late["delta_EKM_mean"] = get_wdTdz(T=delta_T, w=m_early["w"])
m_late["delta_EKM_anom"] = get_wdTdz(T=clim_sub_early["T"], w=delta_w)
m_late["delta_EKM_nl"] = get_wdTdz(T=delta_T, w=delta_w)

## Decompose changes in Thermocline feedback
## Δ(w̄ T') split into Δw̄·T0' ("mean"), w̄0·ΔT' ("anom") and Δw̄·ΔT' ("nl");
## note delta_w / delta_T are rebound here and keep these values afterwards
delta_w = clim_sub_late["w"] - clim_sub_early["w"]
delta_T = m_late["T"] - m_early["T"]
m_late["delta_THF_mean"] = get_wdTdz(T=m_early["T"], w=delta_w)
m_late["delta_THF_anom"] = get_wdTdz(T=delta_T, w=clim_sub_early["w"])
m_late["delta_THF_nl"] = get_wdTdz(T=delta_T, w=delta_w)

## sum up: total linearized advective feedback
for m in [m_early, m_late]:
    m["ADV"] = m["THF"] + m["EKM"] + m["ZAF"] + m["DD"]

## get ground truth tendencies (model-diagnosed total and 3-D advective
## temperature tendencies, regressed the same way as above)
## NOTE(review): `regress_bymonth` is not defined in this part of the file
## (`fit` uses src.utils.multi_regress_bymonth) -- confirm it exists upstream
for n in ["TEND_TEMP", "ADV_3D_TEMP"]:

    kwargs = dict(x_vars=["T_34", "h_w_hat"], y_var=n)
    m_early[n] = regress_bymonth(anom_early, **kwargs)["T_34"]
    m_late[n] = regress_bymonth(anom_late, **kwargs)["T_34"]

## get difference: part of the total tendency not explained by 3-D advection
m_early["diff"] = m_early["TEND_TEMP"] - m_early["ADV_3D_TEMP"]
m_late["diff"] = m_late["TEND_TEMP"] - m_late["ADV_3D_TEMP"]

Compute things which *do* depend on MLD

In [None]:
## should we use fixed depth MLD?
USE_FIXED_MLD = True

## mixed-layer depth per period (Hm_*) and the `delta` offset passed to
## get_ml_avg_wrapper when averaging over the mixed layer
if not USE_FIXED_MLD:
    ## experimental setting: different constant depths per period
    ## (alternative: Hm_early = mld_early, Hm_late = mld_late)
    Hm_early = 70 * xr.ones_like(mld_early)
    Hm_late = 50 * xr.ones_like(mld_late)
    delta = 15
else:
    ## same constant depth in both periods
    H0 = 70
    Hm_early, Hm_late = (H0 * xr.ones_like(mld) for mld in (mld_early, mld_late))
    delta = 0

## height of entrainment zone (from base of ML)
ez_height = 20

## interpolate a field onto the longitude grid of m_early
update_lon = lambda x: x.interp({"longitude": m_early.longitude})

## T_ml - T_sub (regr on Niño 3.4 -> "dT_n34", regr on taux -> "dT_taux")
fit_Tsub = lambda **kwargs: update_lon(-get_dT_sub(**kwargs, delta=ez_height))
periods = [(m_early, Hm_early), (m_late, Hm_late)]
for key, reg_var in [("dT_n34", "T"), ("dT_taux", "taux_T")]:
    for m, Hm in periods:
        m[key] = fit_Tsub(Tsub=m[reg_var], mld=Hm)

## Integrate over mixed layer: add "<var>_ml" for each variable with depth (z_t)
for m, Hm in periods:
    for v in list(m):
        if "z_t" not in m[v].coords:
            continue
        m[f"{v}_ml"] = update_lon(get_ml_avg_wrapper(m[v], Hm=Hm, delta=delta))

## get d/dt(SST): total tendency at the first z_t level, on m_early longitudes
for m in (m_early, m_late):
    m["ddt_sst"] = update_lon(m["TEND_TEMP"].rename({"lon": "longitude"}).isel(z_t=0))

### Feedback Hovmöller diagrams
E.g., the thermocline feedback ($\overline{w}~\frac{\partial T'}{\partial z}$) and the Ekman feedback ($w'~\frac{\partial \overline{T}}{\partial z}$).

##### Plot mixed layer integral

In [None]:
## specify plot amplitude
amp = 1

## the difference panel uses a tighter scale than the early/late panels
diff_amp = amp * 0.75

for n in ["THF_ml", "EKM_ml", "ADV_ml", "ADV_3D_TEMP_ml", "ddt_sst", "diff_ml"]:
    # for n in ["THF_ml", "EKM_ml", "ZAF_ml", "DD_ml", "ADV_ml", "ADV_3D_TEMP_ml", "ddt_sst"]:

    print(f"\n{n}")
    fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

    ## plot early, late and late-minus-early. Only variable `n` is
    ## differenced, rather than subtracting the whole Dataset each iteration.
    cp0 = plot_cycle_hov(axs[0], data=m_early[n], amp=amp)
    cp1 = plot_cycle_hov(axs[1], data=m_late[n], amp=amp)
    cp2 = plot_cycle_hov(axs[2], data=m_late[n] - m_early[n], amp=diff_amp)

    ## one shared colorbar on the early/late scale (cp0).
    ## NOTE(review): the difference panel uses 0.75 * amp levels, so its
    ## own scale is not shown by this colorbar.
    cb = fig.colorbar(
        cp0,
        ax=axs[2],
        ticks=[-amp, 0, amp],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )
    format_hov_axs(axs)
    for ax in axs:
        ax.axhline(7, ls="--", c="k", lw=1)

    plt.show()

##### Decompose changes in Ekman feedback

In [None]:
## specify colorbar amp
amp = 0.5

## specify plot data and titles: total change in the Ekman feedback, then its
## mean-state, anomaly and nonlinear contributions. Only the needed variable
## is differenced, not the whole Dataset.
plot_data = [
    m_late["EKM_ml"] - m_early["EKM_ml"],
    m_late["delta_EKM_mean_ml"],
    m_late["delta_EKM_anom_ml"],
    m_late["delta_EKM_nl_ml"],
]
titles = [
    r"$\Delta\left(w~\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$w_0'~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
    r"$\Delta\left(w'\right)~\frac{\partial \overline{T}_0}{\partial z}$",
    r"$\Delta\left(w'\right)~\Delta\left(\frac{\partial \overline{T}}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

format_hov_axs(axs)

## plot data (all panels share one amplitude, so the last contour set `cp`
## is valid for the shared colorbar)
for ax, p, t in zip(axs, plot_data, titles):
    cp = plot_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## hide y ticks on the fourth panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

##### Decompose changes in thermocline feedback

In [None]:
## specify colorbar amp
amp = 0.5

## specify plot data and titles: total change in the thermocline feedback,
## then its mean-state, anomaly and nonlinear contributions. Only the needed
## variable is differenced, not the whole Dataset.
plot_data = [
    m_late["THF_ml"] - m_early["THF_ml"],
    m_late["delta_THF_mean_ml"],
    m_late["delta_THF_anom_ml"],
    m_late["delta_THF_nl_ml"],
]

## delta_THF_mean = Δ(w̄) acting on the early-period *anomalous* T, hence T_0'
titles = [
    r"$\Delta\left(\overline{w}~\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\frac{\partial T_0'}{\partial z}$",
    r"$\overline{w}_0~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
    r"$\Delta\left(\overline{w}\right)~\Delta\left(\frac{\partial T'}{\partial z}\right)$",
]

fig, axs = plt.subplots(1, 4, figsize=(8, 2.5), layout="constrained")

format_hov_axs(axs)

## plot data (all panels share one amplitude, so the last contour set `cp`
## is valid for the shared colorbar)
for ax, p, t in zip(axs, plot_data, titles):
    cp = plot_cycle_hov(ax, data=p, amp=amp)
    ax.set_title(t)
    ax.axhline(7, ls="--", c="k", lw=1)

## hide y ticks on the fourth panel
for ax in axs[3:]:
    ax.set_yticks([])

## make it look nicer
cb = fig.colorbar(
    cp,
    ax=axs[-1],
    ticks=[-amp, 0, amp],
    label=r"$K~\left(\text{month}\right)^{-1}$",
)

plt.show()

### Plot BJ couplings

#### Niño 3.4 - SST

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
hov_kw = dict(lat_bound=1.5)

## left: early coupling (filled) with late coupling overlaid as contours
cp0 = plot_cycle_hov(axs[0], data=m_early["sst"], amp=2, **hov_kw)
plot_cycle_hov(axs[0], data=m_late["sst"], amp=2, is_filled=False, **hov_kw)

## right: late minus early
cp2 = plot_cycle_hov(axs[1], data=m_late["sst"] - m_early["sst"], amp=1, **hov_kw)

## ticks, labels, titles
left, right = axs
right.set_yticks([])
left.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
left.set_ylabel("Month")
left.set_title(r"Niño 3.4-SST coupling")
right.set_title("Change")

for ax in axs:
    ax.axhline(6, c="k", ls="--", lw=1)

plt.show()

### SST - NHF

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
nhf_early, nhf_late = m_early["nhf"], m_late["nhf"]
hov_kw = dict(lat_bound=1.5)

## left: early coupling (filled) with late coupling overlaid as contours
cp0 = plot_cycle_hov(axs[0], data=nhf_early, amp=40, **hov_kw)
plot_cycle_hov(axs[0], data=nhf_late, amp=40, is_filled=False, **hov_kw)

## right: late minus early
cp2 = plot_cycle_hov(axs[1], data=nhf_late - nhf_early, amp=20, **hov_kw)

## ticks, labels, titles
ax_cpl, ax_chg = axs
ax_chg.set_yticks([])
ax_cpl.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
ax_cpl.set_ylabel("Month")
ax_cpl.set_title(r"$NHF$-SST coupling")
ax_chg.set_title("Change")

for ax in axs:
    ax.axhline(7, c="k", ls="--", lw=1)

plt.show()

### SST-$\tau_x$

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(4, 2.5), layout="constrained")
ax_cpl, ax_chg = axs

## left: early coupling (filled) with late coupling overlaid as contours
cp0 = plot_cycle_hov(ax_cpl, data=m_early["taux"], amp=0.015)
plot_cycle_hov(ax_cpl, data=m_late["taux"], amp=0.015, is_filled=False)

## right: late minus early
cp2 = plot_cycle_hov(ax_chg, data=m_late["taux"] - m_early["taux"], amp=0.0075)

## ticks, labels, titles
ax_chg.set_yticks([])
ax_cpl.set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])
ax_cpl.set_ylabel("Month")
ax_cpl.set_title(r"$\tau_x$-SST coupling")
ax_chg.set_title("Change")

## reference month line plus white longitude markers at 160 and 210
for ax in axs:
    ax.axhline(6, c="k", ls="--", lw=1)
    for lon in (160, 210):
        ax.axvline(lon, ls="--", c="w")

plt.show()

In [None]:
## June values averaged over latitudes -5..5
sel = lambda x: x.sel(month=6, latitude=slice(-5, 5)).mean("latitude")

## change in taux-SST coupling (late minus early)
taux_change = m_late["taux"] - m_early["taux"]

fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(m_early["taux"].longitude, sel(taux_change))
ax.axhline(0, ls="--", c="k", lw=0.8)
ax.set_xlim([140, 280])

### $T_{sub}-$Niño3.4

In [None]:
## helper func to pick the (T_ml - T_sub) regression on Niño 3.4
sel_data = lambda x: x["dT_n34"]

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs shared by the early and late panels
kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(1e0, 1e-1), extend="both"
)
cb_kwargs = dict(ticks=[-1, 0, 1], label=r"$K~/~K$")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], sel_data(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
cp1 = src.utils.plot_cycle_hov(axs[1], sel_data(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference (difference only the selected variable, not the whole Dataset)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    sel_data(m_late) - sel_data(m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.5e0, 0.5e-1),
    extend="both",
)

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-0.5, 0, 0.5], label=r"$K~/~K$")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])
axs[-1].axhline(7, c="k", lw=0.8, ls="--")

plt.show()

### $T_{sub}-\tau_x$

In [None]:
## specify which period/month to plot
# sel = lambda x: x.mean("month")
sel = lambda x: x.sel(month=7)

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

for ax, m in zip(axs[:2], [m_early, m_late]):

    ## regression of subsurface temperature on taux (levels span ±200)
    cp = ax.contourf(
        m.lon,
        m.z_t,
        sel(m["taux_T"]),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(200, 20),
        extend="both",
    )

## difference: use m_early's grid explicitly instead of the loop-leaked `m`,
## and difference only the plotted variable
axs[2].contourf(
    m_early.lon,
    m_early.z_t,
    sel(m_late["taux_T"] - m_early["taux_T"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(100, 10),
    extend="both",
)
## plot MLD
plot_mlds(axs=axs, sel=sel)

## label: ticks must match the ±200 levels of `cp` (previously ±10, which
## bunched all ticks around zero). NOTE(review): the difference panel uses
## ±100 levels, so its own scale is not shown by this colorbar.
cb = fig.colorbar(cp, ax=axs[2], ticks=[-200, 0, 200], label=r"$K~\text{Pa}^{-1}$")
format_subsurf_axs(axs)
for ax in axs:
    ax.set_ylim([100, 5])
    ax.axvline(190, ls="--", c="w", lw=0.8)
    ax.axvline(240, ls="--", c="w", lw=0.8)

plt.show()

In [None]:
## helper func to pick the (T_ml - T_sub) regression on taux
sel_data = lambda x: x["dT_taux"]

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs shared by the early and late panels
kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(1e2, 1e1), extend="both"
)
cb_kwargs = dict(ticks=[-100, 0, 100], label=r"$K~/~Pa$")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], sel_data(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
cp1 = src.utils.plot_cycle_hov(axs[1], sel_data(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference (difference only the selected variable, not the whole Dataset)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    sel_data(m_late) - sel_data(m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5e1, 5e0),
    extend="both",
)

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-50, 0, 50], label=r"$K~/~Pa$")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

plt.show()

##### Longitude–depth cross-section

In [None]:
## specify which period/month to plot (loop-invariant, so defined once)
sel = lambda x: x.sel(month=6)

for n in ["THF", "EKM", "ZAF", "ADV", "ADV_3D_TEMP"]:

    print(f"\n\n{n}")

    fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

    for ax, m in zip(axs[:2], [m_early, m_late]):

        ## feedback term for this period (levels span ±2)
        cp = ax.contourf(
            m.lon,
            m.z_t,
            sel(m)[n],
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(2, 0.2),
            extend="both",
        )

    ## difference: subtract only variable `n` instead of differencing the
    ## entire Dataset on every iteration
    axs[2].contourf(
        m_early.lon,
        m_early.z_t,
        sel(m_late)[n] - sel(m_early)[n],
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(1, 0.1),
        extend="both",
    )

    ## plot MLD
    plot_mlds(axs, sel=sel)

    ## colorbar on the early/late scale (the difference panel uses ±1 levels);
    ## set ax limit and plot Niño 3.4 bounds
    cb = fig.colorbar(
        cp, ax=axs[2], ticks=[-2, 0, 2], label=r"$K~\left(\text{month}\right)^{-1}$"
    )
    format_subsurf_axs(axs)
    for ax in axs:
        ax.set_ylim([100, 5])
        ax.axvline(190, ls="--", c="w", lw=0.8)
        ax.axvline(240, ls="--", c="w", lw=0.8)

    plt.show()