# Bjerknes feedback changes over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## set plotting specs: white axes background, no grid
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI of displayed figures
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables
## (raises KeyError if DATA_FP or SAVE_FP is not set)
## NOTE: SAVE_FP is re-bound further down (regression cell) to the path of
## the saved coefficient file, so it only holds the save directory until then
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_dT_sub(Tsub, Hm, delta=25):
    """
    Temperature difference between the entrainment zone and the mixed layer.

    Levels with `z_t <= Hm` count as mixed layer (ML); levels with
    `Hm < z_t < Hm + delta` count as entrainment zone. Both groups are
    averaged over "z_t"; the result is positive if the entrainment zone is
    warmer than the ML.
    """

    depth = Tsub.z_t

    ## masks for the entrainment zone and the mixed layer
    ez_mask = (depth > Hm) & (depth < (delta + Hm))
    ml_mask = depth <= Hm

    ## Tplus (entrainment zone) minus Tbar (ML), following Frankignoul et al paper
    return Tsub.where(ez_mask).mean("z_t") - Tsub.where(ml_mask).mean("z_t")


def prep(data):
    """
    Add a "<name>_hat" variable to `data` for each of the h indices.

    Each "_hat" variable is the output of
    `src.utils.remove_sst_dependence_v2` for that index, using "T_34" as the
    temperature variable. `data` is modified by item assignment and the same
    object is returned.
    """

    h_names = ("h_w", "h", "h_w_z20")

    for name in h_names:
        h_hat = src.utils.remove_sst_dependence_v2(data, h_var=name, T_var="T_34")
        data[f"{name}_hat"] = h_hat

    return data


def load_budget_data(t_early, t_late, target_grid):
    """
    Load ground truth heat budget data and split it into forced/anomaly parts.

    Args:
        t_early: indexers handed to `Dataset.sel` to pick the early period
            (presumably a dict such as `dict(time=slice(...))` -- TODO confirm
            against callers).
        t_late: same as `t_early`, for the late period.
        target_grid: object handed to `interp_like`; the budget data is
            interpolated onto its coordinates.

    Returns:
        `(forced_bud, anom_bud)` as returned by `src.utils.separate_forced`.
    """

    def load_var(varname):
        """open all *.nc files in DATA_FP/cesm/{varname}_temp, stacked along "member" """

        ## files are sorted so the member order is reproducible
        files = sorted(pathlib.Path(DATA_FP, "cesm", f"{varname}_temp").glob("*.nc"))

        ## open data
        data = xr.open_mfdataset(
            files,
            concat_dim="member",
            combine="nested",
            parallel=True,
        )

        ## label members 0..N-1 based on how many were actually loaded
        ## (a hard-coded count fails whenever the ensemble size differs)
        return data.assign_coords({"member": np.arange(data.sizes["member"])})

    ## load data
    budget_data = xr.merge([load_var(v) for v in ["adv", "ddt_T"]])

    ## get difference between total tendency and 3D advection
    budget_data["diff"] = budget_data["TEND_TEMP"] - budget_data["ADV_3D_TEMP"]

    ## conversion factors (a month is taken to be 30 days)
    M_PER_CM = 1e-2
    SEC_PER_MO = 8.64e4 * 30

    ## convert from (i) cm to m and (ii) K/s to K/mo
    budget_data = budget_data.assign_coords({"z_t": budget_data.z_t * M_PER_CM})
    budget_data = budget_data * SEC_PER_MO

    ## fix longitude coordinate: index by the "lon" values instead of "nlon"
    budget_data = budget_data.assign_coords({"nlon": budget_data.lon.values})
    budget_data = budget_data.drop_vars("lon").rename({"nlon": "longitude"})

    ## trim in time: keep only the early and late periods
    budget_data = xr.concat(
        [budget_data.sel(t_early), budget_data.sel(t_late)], dim="time"
    )

    ## interpolate
    budget_data = budget_data.interp_like(target_grid)

    ## separate forced/anomalies
    forced_bud, anom_bud = src.utils.separate_forced(budget_data)

    return forced_bud, anom_bud


def plot_mlds(axs, bar_early, bar_late, month=None):
    """
    Plot mixed layer depth ("mld") against longitude on three axes.

    axs[0] gets the early profile (solid), axs[1] the late profile (dashed)
    and axs[2] gets both. If `month` is None the profiles are averaged over
    "month"; otherwise that month is selected.
    """

    ## get longitude
    lon = src.utils.merimean(bar_early).longitude

    def pick(x):
        """select `month`, or average over all months if none was given"""
        if month is None:
            return x.mean("month")
        return x.sel(month=month)

    ## profiles to draw
    mld_early = pick(src.utils.merimean(bar_early["mld"]))
    mld_late = pick(src.utils.merimean(bar_late["mld"]))

    ## line styles: early is solid, late is dashed
    early_style = dict(c="k")
    late_style = dict(c="k", ls="--")

    ## plot
    axs[0].plot(lon, mld_early, **early_style)
    axs[1].plot(lon, mld_late, **late_style)
    axs[2].plot(lon, mld_early, **early_style)
    axs[2].plot(lon, mld_late, **late_style)


def get_nino3_da(x):
    """apply src.utils.get_nino3 to `x`, adding a "latitude" dim if it is missing"""

    has_lat = "latitude" in x.dims

    return src.utils.get_nino3(x if has_lat else x.expand_dims("latitude"))


def get_nino34_da(x):
    """apply src.utils.get_nino34 to `x`, adding a "latitude" dim if it is missing"""

    has_lat = "latitude" in x.dims

    return src.utils.get_nino34(x if has_lat else x.expand_dims("latitude"))

## Load data

### $T$, $h$

In [None]:
## load indices (z20 variables requested)
Th = src.utils.load_cesm_indices(load_z20=True)

## shorten the long index names
Th = Th.rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

## load T,h (total) and the tropical SST avg
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")

## relative sst = total index minus tropical SST avg
for n in ("T_3", "T_34", "T_4"):
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## load spatial data (returned as a forced part and an anomaly part)
forced, anom = src.utils.load_consolidated()

## add normalized Th data: every variable in Th is divided by Th.std()
## (no dim is passed to std, so presumably one value per variable -- TODO confirm)
anom = xr.merge([anom, Th / Th.std()])

#### compute $T$ tendency

In [None]:
## tendencies of sst and T, each scaled by 1/12 (convert from 1/yr to 1/mo)
anom = anom.assign(
    ddt_sst=1 / 12 * src.utils.get_ddt(anom[["sst"]], is_forward=False)["ddt_sst"],
    ddt_T=1 / 12 * src.utils.get_ddt(anom[["T"]], is_forward=False)["ddt_T"],
)

#### Get data

In [None]:
## variables to keep: indices first, then the fields that also have a "_comp" version
VARNAMES = ["T_3", "T_34", "h_w", "h_w_z20", "ddt_T", "nhf", "T", "w", "u"]
VARNAMES.extend([f"{v}_comp" for v in VARNAMES[-4:]])

## windowed data; `forced` only gets nhf/T/w/u and their "_comp" versions
## (the entries from position 5 onwards)
anom = src.utils.get_windowed(anom[VARNAMES], stride=120).compute()
forced = src.utils.get_windowed(forced[VARNAMES[5:]], stride=120).compute()

#### Regression on $T_{34}$ and $\hat{h}_w$

helper function

In [None]:
def fit_wrapper(data, y_vars, x_vars=None):
    """
    Fit linear regression model to data via src.utils.regress_xr_bymonth.

    Args:
        data: passed straight through to `src.utils.regress_xr_bymonth`.
        y_vars: names of the regression targets.
        x_vars: names of the predictors; if None, ["T_34", "h_w_hat"] is used.

    Returns:
        The coefficients returned by `src.utils.regress_xr_bymonth`.
    """

    ## build the default list per call, so no list object is shared between calls
    if x_vars is None:
        x_vars = ["T_34", "h_w_hat"]

    ## get coeffs
    return src.utils.regress_xr_bymonth(data, y_vars=y_vars, x_vars=x_vars)

Compute

In [None]:
## do regression
fit_kwargs = dict(
    y_vars=["ddt_T", "nhf", "T", "w", "u"],
    x_vars=["T_3", "h_w_z20"],
)

## save filepath, named after the two predictors
## NOTE: this re-binds SAVE_FP from the save directory to the coefficient file
## (single quotes inside the f-string keep it valid on Python < 3.12)
SAVE_FP = pathlib.Path(
    os.environ["SAVE_FP"],
    "bjerknes",
    f"{fit_kwargs['x_vars'][0]}_{fit_kwargs['x_vars'][1]}_coefs_ddt_T.nc",
)

if SAVE_FP.is_file():
    ## reuse previously saved coefficients
    coefs = xr.open_dataset(SAVE_FP)

else:

    ## empty array to hold coefficients (one fit per windowed year)
    coefs = []
    for y in tqdm.tqdm(anom.year):
        coefs.append(fit_wrapper(anom.sel(year=y), **fit_kwargs))

    ## put in xr.Dataarray
    coefs = xr.concat(coefs, dim=anom.year)

    ## make sure the "bjerknes" folder exists before writing
    SAVE_FP.parent.mkdir(parents=True, exist_ok=True)
    coefs.to_netcdf(SAVE_FP)

In [None]:
## climatology reconstructed from the forced component
bar = src.utils.reconstruct_clim(forced)

## feedback terms from the climatology (bar) and regression coefficients (prime)
feedbacks = src.utils.get_feedbacks(bar=bar, prime=coefs)

## keep coefficients and feedbacks in a single dataset
coefs = xr.merge([coefs, feedbacks])

## mixed layer settings
ml_kwargs = {"H0": 70, "Hm": None}

## integrate over mixed layer
coefs = src.utils.get_ml_avg_ds(coefs, **ml_kwargs)

## constants for expressing NHF in units of K/mo (a month is taken to be 30 days)
sec_per_mo = 8.64e4 * 30
rho = 1.02e3
Cp = 4.2e3

## scale by seconds per month and divide by (rho * Cp * H0)
coefs["Q"] = coefs["nhf"] * sec_per_mo / (rho * Cp * ml_kwargs["H0"])

Convert the net heat flux to a temperature tendency (K/month)

In [None]:
## coefficients for the 1870 and 2080 windows (first entry along "j")
m_early, m_late = (coefs.sel(year=yr).isel(j=0) for yr in (1870, 2080))

## Plot feedback hovmollers 
E.g., thermocline ($\overline{w}~\frac{\partial T'}{\partial z}$) and Ekman feedback ($w'~\frac{\partial \overline{T}}{\partial z}$)

### Plot mixed layer integral

In [None]:
## amplitudes: `amp` for the early/late panels, `amp_diff` for their difference
amp = 1.5
amp_diff = 0.75

for n in ["THF_ml", "EKM_ml", "ZAF_ml", "DD_ml", "ADV_ml", "ddt_T_ml", "Q"]:

    print(f"\n{n}")
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5), layout="constrained")

    ## panels: early, late, and late minus early
    cp0 = src.utils.make_cycle_hov(axs[0], data=m_early[n], amp=amp)
    cp1 = src.utils.make_cycle_hov(axs[1], data=m_late[n], amp=amp)
    cp2 = src.utils.make_cycle_hov(axs[2], data=(m_late - m_early)[n], amp=amp_diff)

    ## colorbar for the early/late panels (attached to the middle axis)
    cb01 = fig.colorbar(cp0, ax=axs[1], ticks=[-amp, 0, amp])

    ## colorbar for the difference panel
    cb02 = fig.colorbar(
        cp2,
        ax=axs[2],
        ticks=[-amp_diff, 0, amp_diff],
        label=r"$K~\left(\text{month}\right)^{-1}$",
    )

    ## make it look nicer, and mark y=5 with a dashed line on every panel
    src.utils.format_hov_axs(axs)
    for ax in axs:
        ax.axhline(5, ls="--", c="k", lw=1)

    plt.show()

In [None]:
def plot_hovs(varname, amp0, amp1, month=1):
    """
    Plot two year-vs-x contour panels for `coefs[varname]` (first "j" entry).

    Left panel: the 5S-5N latitude mean for the given `month`, with longitude
    on the x-axis. Right panel: the output of src.utils.get_nino34, with
    month on the x-axis. `amp0` and `amp1` set the contour ranges of the left
    and right panels. Returns `(fig, axs)`.
    """

    ## data to plot, with latitude guaranteed to be a dimension
    field = coefs[varname].isel(j=0)
    if "latitude" not in field.dims:
        field = field.expand_dims("latitude")

    ## left panel data: latitude-band mean for the chosen month
    eq_band = field.sel(month=month, latitude=slice(-5, 5)).mean("latitude")
    eq_band = eq_band.transpose("year", ...)

    ## right panel data: nino3.4 index
    nino34 = src.utils.get_nino34(field).transpose("year", ...)

    fig, axs = plt.subplots(1, 2, figsize=(6, 3.5), layout="constrained")
    ax_lon, ax_month = axs

    ## settings shared by both panels
    shared = dict(cmap="cmo.balance", extend="both")

    ## longitude on x-axis
    ax_lon.contourf(
        eq_band.longitude,
        eq_band.year,
        eq_band,
        levels=src.utils.make_cb_range(amp0, amp0 / 10),
        **shared,
    )

    ## month on x-axis
    ax_month.contourf(
        field.month,
        field.year,
        nino34,
        levels=src.utils.make_cb_range(amp1, amp1 / 10),
        **shared,
    )

    ## left panel: limits, dashed guides at 210 and 270, ticks
    ax_lon.set_xlim([140, 280])
    for lon_mark in (210, 270):
        ax_lon.axvline(lon_mark, ls="--", c="k", lw=0.8)
    ax_lon.set_xticks([160, 210, 270])
    ax_lon.set_yticks([1870, 1980, 2090])

    ## right panel: month ticks, no year ticks
    ax_month.set_xticks([2, 7, 12])
    ax_month.set_yticks([])

    return fig, axs

In [None]:
## month shown in the longitude panel
MONTH = 4


## one figure per feedback term
for VARNAME in ("ddt_T_ml", "THF_ml", "EKM_ml", "ZAF_ml", "DD_ml", "Q"):

    print(f"\n\n{VARNAME}")
    fig, axs = plot_hovs(VARNAME, 1.0, 0.25, MONTH)
    plt.show()

In [None]:
## specify month
MONTH = 5


def sel(x):
    """change relative to the first year, for the "T_3" entry of `j` in month MONTH"""
    return (x - x.isel(year=0)).sel(j="T_3").sel(month=MONTH)


## alternative selector: first `j` entry, averaged over months
# sel = lambda x : (x-x.isel(year=0)).isel(j=0).mean("month")

fig, axs = plt.subplots(1, 2, figsize=(7, 3))

## positive feedbacks (left panel)
for f in ["EKM", "THF", "ZAF", "ADV", "ddt_T"]:
    axs[0].plot(coefs.year, sel(get_nino3_da(coefs[f"{f}_ml"])), label=f)

## negative feedbacks (right panel)
for f in ["DD_ml", "Q"]:
    axs[1].plot(coefs.year, sel(get_nino3_da(coefs[f])), label=f)

## legend and zero line on both panels
for ax in axs:
    ax.legend(prop=dict(size=8))
    ax.axhline(0, ls="--", c="k", lw=0.8)

plt.show()

In [None]:
## display the merged coefficient/feedback dataset
coefs