# Discharge efficiency change over time

## imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time

# Import custom modules
import src.utils

## plotting defaults: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## render figures at a higher resolution
mpl.rcParams.update({"figure.dpi": 100})

## input/output directories are read from environment variables
## (raises KeyError if either is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def prep(data, is_forward=True, T_var="T_3"):
    """Remove the SST dependence from h/SSH indices and add tendencies.

    New variables are added to ``data`` in place, and ``data`` is also
    returned.

    Parameters
    ----------
    data : xarray.Dataset
        Must contain ``sst``, ``h``, ``h_w`` and (presumably, as read by
        ``src.utils.remove_sst_dependence_v2``) the variable named ``T_var``.
        ``ssh`` is optional; if it is present, ``ssh_comp`` must be too.
    is_forward : bool
        Passed straight through to ``src.utils.get_ddt``.
    T_var : str
        Name of the temperature index whose influence is removed.

    Returns
    -------
    xarray.Dataset
        ``data`` with ``h_w_hat``, ``h_hat``, ``ddt_sst`` and, when SSH is
        available, ``ssh_hat``, ``ssh_hat_comp``, ``ddt_ssh`` and
        ``ddt_ssh_hat`` added.
    """
    # ``list(data)`` in the original iterated data variables, so check those
    has_ssh = "ssh" in data.data_vars

    ## remove SST dependence from SSH field (only when SSH is available)
    if has_ssh:
        data["ssh_hat"] = src.utils.remove_sst_dependence_v2(
            data,
            h_var="ssh",
            T_var=T_var,
        )
        data["ssh_hat_comp"] = data["ssh_comp"]

    ## remove from h indices
    for h_idx in ["h_w", "h"]:
        data[f"{h_idx}_hat"] = src.utils.remove_sst_dependence_v2(
            data,
            h_var=h_idx,
            T_var=T_var,
        )

    ## get tendencies (and convert from 1/yr to 1/mo).
    ## SSH-based tendencies are only computed when SSH exists; previously a
    ## dataset without ``ssh`` raised KeyError here despite the guard above.
    tendency_vars = ["sst", "ssh", "ssh_hat"] if has_ssh else ["sst"]
    for v in tendency_vars:
        data[f"ddt_{v}"] = (
            1 / 12 * src.utils.get_ddt(data[[v]], is_forward=is_forward)[f"ddt_{v}"]
        )

    return data

## Load data

### $T$, $h$

In [None]:
## open the CESM climate indices
Th = src.utils.load_cesm_indices()

## shorten index names for convenience (old name -> new name)
Th = Th.rename(
    dict(
        north_tropical_atlantic="natl",
        atlantic_nino="nino_atl",
        tropical_indian_ocean="iobm",
        indian_ocean_dipole="iod",
        north_pacific_meridional_mode="npmm",
        south_pacific_meridional_mode="spmm",
    )
)

## tropical-average SST and the total T,h indices
trop_sst = xr.open_dataset(DATA_FP / "cesm" / "trop_sst.nc")
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## relative SST: total index minus the tropical-average SST
for T_name in ("T_3", "T_34", "T_4"):
    Th[f"{T_name}_rel"] = Th_total[T_name] - trop_sst["trop_sst_10"]

### Spatial data

#### Load

In [None]:
## consolidated spatial fields (forced and anomaly datasets)
CONS_DIR = DATA_FP / "cesm" / "consolidated"
forced = xr.open_dataset(CONS_DIR / "forced.nc")
anom = xr.open_dataset(CONS_DIR / "anom.nc")

## attach the T,h indices to the anomaly dataset (in place)
anom.update({idx: Th[idx] for idx in ("T_3", "T_34", "T_4", "h", "h_w")})

#### Early/late split

In [None]:
## 30-year early (1851-1880) and late (2071-2100) periods
t_early = {"time": slice("1851", "1880")}
t_late = {"time": slice("2071", "2100")}

## subset each period, preprocess it (is_forward=False), and load it into memory;
## Dataset.load() returns the dataset itself, so it can be chained
anom_early = prep(anom.sel(t_early), is_forward=False).load()
anom_late = prep(anom.sel(t_late), is_forward=False).load()

#### Regression on $T_3$ and $\hat{h}_w$

Helper function

In [None]:
def fit_wrapper(data, y_vars, x_vars=None):
    """Fit a month-by-month linear regression of ``y_vars`` on ``x_vars``.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset containing every variable named in ``y_vars`` and ``x_vars``.
    y_vars : list of str
        Names of the predictand variables.
    x_vars : list of str, optional
        Names of the predictor variables. Defaults to ``["T_34", "h_w_hat"]``.

    Returns
    -------
    Whatever ``src.utils.regress_xr_bymonth`` returns: the full set of
    coefficients for all predictors, not only the first.
    """
    # build the default inside the call to avoid a shared mutable default list
    if x_vars is None:
        x_vars = ["T_34", "h_w_hat"]

    return src.utils.regress_xr_bymonth(data, y_vars=y_vars, x_vars=x_vars)

Compute

In [None]:
## specify: should we use h or h_hat?
USE_HAT = True

## choose regressors/predictands and the SSH variable used in the rest of the script
if USE_HAT:
    fit_kwargs = dict(x_vars=["T_3", "h_w_hat"], y_vars=["ddt_ssh_hat", "ssh_hat"])
    H_VAR = "ssh_hat"
else:
    fit_kwargs = dict(x_vars=["T_3", "h_w"], y_vars=["ddt_ssh", "ssh"])
    H_VAR = "ssh"

## keep track of time (perf_counter is monotonic, unlike time.time,
## so wall-clock adjustments cannot distort the measurement)
t0 = time.perf_counter()

## do regression for each period
m_early = fit_wrapper(anom_early, **fit_kwargs)
m_late = fit_wrapper(anom_late, **fit_kwargs)

## print out elapsed time
t1 = time.perf_counter()
print(f"Elapsed time: {t1-t0:.1f} seconds")

## Plot feedbacks

### $F_2$

##### Hovmöller

In [None]:
## colorbar label: units of the SSH-tendency coefficient on T_3
LABEL = r"$m~\left(K \cdot \text{month}\right)^{-1}$"


## pick the SSH-tendency coefficient on T_3
def sel(x):
    return x[f"ddt_{H_VAR}"].sel(j="T_3")


## color range (+/- amp) and other make_cycle_hov settings shared by all panels
amp = 2
hov_kwargs = dict(amp=amp, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, late minus early
panels = [sel(m_early), sel(m_late), sel(m_late - m_early)]
cp0, cp1, cp2 = [
    src.utils.make_cycle_hov(ax, data=panel, **hov_kwargs)
    for ax, panel in zip(axs, panels)
]

## one colorbar (from the first panel, same amp for all) next to the last panel
cb = fig.colorbar(cp0, ax=axs[2], ticks=[-amp, 0, amp], label=LABEL)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)  # dashed horizontal reference line at 7

plt.show()

##### Spatial plot

In [None]:
## SSH-tendency coefficient on T_3 for month 6
def sel(x):
    return x[f"ddt_{H_VAR}"].sel(j="T_3", month=6)


## three stacked map panels formatted by plot_setup_pac
fig = plt.figure(figsize=(7, 3.9), layout="constrained")


def format_func(ax):
    return src.utils.plot_setup_pac(ax, max_lat=20)


axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, late minus early, all on a +/- amp color range
amp = 5
contour_kwargs = dict(amp=amp, sel=sel)
cp0, cp1, cp2 = [
    src.utils.make_contour_plot(ax, ds, **contour_kwargs)
    for ax, ds in zip(axs[:, 0], (m_early, m_late, m_late - m_early))
]

## colorbars: one shared by the top two rows, one for the difference row
cb_kwargs = dict(label=LABEL, ticks=[-amp, 0, amp])
cb0 = fig.colorbar(cp0, ax=axs[:2], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **cb_kwargs)

## outline drawn by plot_hw_box (originally labelled "Niño 3 box";
## NOTE(review): the helper name suggests the h_w region — confirm)
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flat:
    src.utils.plot_hw_box(ax, **box_kwargs)

plt.show()

##### Scatter plot

In [None]:
## month-6 scatter of the SSH tendency against T_3 (y transformed by get_RO_hw)
scatter_kwargs = dict(
    months=[6],
    x_var="T_3",
    y_var=f"ddt_{H_VAR}",
    fn_y=src.utils.get_RO_hw,
)

## one panel per period; make_scatter2 returns the number shown in each title
## (presumably the fitted slope — confirm against src.utils)
fig, axs = plt.subplots(1, 2, figsize=(5, 2.5), layout="constrained")
m0, m1 = (
    src.utils.make_scatter2(ax, ds, **scatter_kwargs)
    for ax, ds in zip(axs, (anom_early, anom_late))
)

## titles use the LABEL defined in the preceding Hovmöller cell
for ax, slope in zip(axs, (m0, m1)):
    ax.set_title(f"{slope:.2f} " + LABEL)
axs[1].set_yticks([])

src.utils.set_lims(axs)

### SSH

#### Hovmöller

In [None]:
## colorbar label: units of the SSH coefficient on T_3
LABEL = r"$m~\left(K\right)^{-1}$"


## pick the SSH coefficient on T_3
def sel(x):
    return x[H_VAR].sel(j="T_3")


## color range (+/- amp) and other make_cycle_hov settings shared by all panels
amp = 8
hov_kwargs = dict(amp=amp, lat_bound=5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, and 2 x (late minus early)
## NOTE(review): the difference is doubled, but the shared colorbar does not say so
panels = [sel(m_early), sel(m_late), 2 * sel(m_late - m_early)]
cp0, cp1, cp2 = [
    src.utils.make_cycle_hov(ax, data=panel, **hov_kwargs)
    for ax, panel in zip(axs, panels)
]

## one colorbar (from the first panel, same amp for all) next to the last panel
cb = fig.colorbar(cp0, ax=axs[2], ticks=[-amp, 0, amp], label=LABEL)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)  # dashed horizontal reference line at 7

plt.show()

#### Spatial

In [None]:
## SSH coefficient on T_3 for month 5 (NOTE(review): other maps use month 6)
def sel(x):
    return x[H_VAR].sel(j="T_3", month=5)


## three stacked map panels formatted by plot_setup_pac
fig = plt.figure(figsize=(7, 3.9), layout="constrained")


def format_func(ax):
    return src.utils.plot_setup_pac(ax, max_lat=20)


axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, 2 x (late minus early), all on a +/- amp color range
amp = 8
contour_kwargs = dict(amp=amp, sel=sel)
cp0, cp1, cp2 = [
    src.utils.make_contour_plot(ax, ds, **contour_kwargs)
    for ax, ds in zip(axs[:, 0], (m_early, m_late, 2 * (m_late - m_early)))
]

## colorbars: one shared by the top two rows, one for the (doubled) difference row
cb_kwargs = dict(label=LABEL, ticks=[-amp, 0, amp])
cb0 = fig.colorbar(cp0, ax=axs[:2], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **cb_kwargs)

## outline drawn by plot_hw_box (originally labelled "Niño 3 box";
## NOTE(review): the helper name suggests the h_w region — confirm)
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flat:
    src.utils.plot_hw_box(ax, **box_kwargs)

plt.show()

### $\varepsilon$

In [None]:
## colorbar label: units of the epsilon coefficient
LABEL = r"$\left(\text{month}\right)^{-1}$"


## SSH-tendency coefficient on the second regressor (j index 1;
## presumably h_w_hat or h_w, per fit_kwargs["x_vars"] — confirm ordering)
def sel(x):
    return x[f"ddt_{H_VAR}"].isel(j=1)


## color range (+/- amp) and other make_cycle_hov settings shared by all panels
amp = 80
hov_kwargs = dict(amp=amp, lat_bound=1.5)

fig, axs = plt.subplots(1, 3, figsize=(6, 2.5), layout="constrained")

## panels: early, late, and 2 x (late minus early)
## NOTE(review): the difference is doubled, but the shared colorbar does not say so
panels = [sel(m_early), sel(m_late), 2 * sel(m_late - m_early)]
cp0, cp1, cp2 = [
    src.utils.make_cycle_hov(ax, data=panel, **hov_kwargs)
    for ax, panel in zip(axs, panels)
]

## one colorbar (from the first panel, same amp for all) next to the last panel
cb = fig.colorbar(cp0, ax=axs[2], ticks=[-amp, 0, amp], label=LABEL)
src.utils.format_hov_axs(axs)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)  # dashed horizontal reference line at 7

plt.show()

In [None]:
## SSH-tendency coefficient on the second regressor (j index 1) for month 6
def sel(x):
    return x[f"ddt_{H_VAR}"].isel(j=1).sel(month=6)


## three stacked map panels formatted by plot_setup_pac
fig = plt.figure(figsize=(7, 3.9), layout="constrained")


def format_func(ax):
    return src.utils.plot_setup_pac(ax, max_lat=20)


axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## rows: early, late, 2 x (late minus early), all on a +/- amp color range
amp = 80
contour_kwargs = dict(amp=amp, sel=sel)
cp0, cp1, cp2 = [
    src.utils.make_contour_plot(ax, ds, **contour_kwargs)
    for ax, ds in zip(axs[:, 0], (m_early, m_late, 2 * (m_late - m_early)))
]

## colorbars: one shared by the top two rows, one for the (doubled) difference row
cb_kwargs = dict(label=LABEL, ticks=[-amp, 0, amp])
cb0 = fig.colorbar(cp0, ax=axs[:2], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], **cb_kwargs)

## outline drawn by plot_hw_box (originally labelled "Niño 3 box";
## NOTE(review): the helper name suggests the h_w region — confirm)
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flat:
    src.utils.plot_hw_box(ax, **box_kwargs)

plt.show()