# Look at climate change in CESM

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import cartopy.util
import copy

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## set plotting specs (white axes background, grid lines off)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables (KeyError if either is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def plot_hov(ax, data, amp, label=None):
    """Plot hovmoller of longitude vs. year on a single axis.

    Args:
        ax: matplotlib axes to draw on.
        data: array exposing `longitude` and `year` coordinates; `data.T` is
            what gets contoured.
        amp: colorbar amplitude. Contour levels come from
            `src.utils.make_cb_range(amp, amp / 10)` and the colorbar ticks
            are placed at [-amp, 0, amp].
        label: optional colorbar label.

    Returns:
        None. The contour plot, colorbar, and formatting are applied to `ax`.
    """

    plot_data = ax.contourf(
        data.longitude,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="both",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## attach the colorbar via the axis' own figure, so the function does not
    ## depend on a notebook-level `fig` variable
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## label (only the axis passed in; previously every axis in a global
    ## `axs` was re-formatted on each call, duplicating the guide lines)
    line_kwargs = dict(ls="--", c="w", lw=0.8)
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])
    ax.axvline(190, **line_kwargs)
    ax.axvline(240, **line_kwargs)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def plot_hov2(ax, data, amp, label=None):
    """Plot hovmoller of month vs. year on a single axis.

    Args:
        ax: matplotlib axes to draw on.
        data: array exposing `month` and `year` coordinates; `data.T` is
            what gets contoured.
        amp: colorbar amplitude. Contour levels come from
            `src.utils.make_cb_range(amp, amp / 10)` and the colorbar ticks
            are placed at [-amp, 0, amp].
        label: optional colorbar label.

    Returns:
        None. The contour plot, colorbar, and formatting are applied to `ax`.
    """

    plot_data = ax.contourf(
        data.month,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="max",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## attach the colorbar via the axis' own figure, so the function does not
    ## depend on a notebook-level `fig` variable
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## strip ticks and move the x-axis decorations to the top (only for the
    ## axis passed in, rather than every axis in a global `axs`)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def get_rolling_var(data, n=10):
    """
    Rolling variance of `data`, evaluated by month.

    Thin wrapper that hands `np.var` and the window parameter `n` to
    `src.utils.get_rolling_fn_bymonth`. Per the original note, the window is
    meant to span 2n+1 years centered on each year, pooling time and
    ensemble member (the helper itself is not visible here).
    """

    rolling_kwargs = dict(fn=np.var, n=n)
    return src.utils.get_rolling_fn_bymonth(data, **rolling_kwargs)


def get_ml_avg(data, Hm, delta=5, H0=None):
    """Average `data` over `z_t` between the surface and a lower bound.

    The lower bound is `Hm + delta` (with `Hm` interpolated onto `data.lon`)
    unless `H0` is given, in which case the constant `H0` is used everywhere.
    """

    ## put the mixed layer depth on the data's longitude grid
    mld_on_grid = Hm.rename({"longitude": "lon"}).interp({"lon": data.lon})

    ## pick the lower integration bound
    if H0 is not None:
        lower_bound = H0 * xr.ones_like(mld_on_grid)
    else:
        lower_bound = mld_on_grid + delta

    ## mask out everything deeper than the bound, then average vertically
    above_bound = data.z_t <= lower_bound
    return data.where(above_bound).mean("z_t")


def get_ml_avg_wrapper(data, Hm, delta=5):
    """Compute the mixed-layer average and reshape it for plotting.

    Calls `get_ml_avg`, renames the `lon` coordinate to `longitude`, and
    moves `month` to the leading dimension.
    """

    averaged = get_ml_avg(data=data, Hm=Hm, delta=delta)
    renamed = averaged.rename({"lon": "longitude"})
    return renamed.transpose("month", ...)


def plot_mld_bounds(ax, clim, m):
    """Draw the MLD climatology and the curves offset by +m and -m.

    Colors: climatology in black, `clim + m` in red (labelled El Niño in the
    original), `clim - m` in blue (labelled La Niña).
    """

    lon = clim.longitude
    curves = ((clim, "k"), (clim + m, "r"), (clim - m, "b"))
    for curve, color in curves:
        ax.plot(lon, curve, c=color)

    return


def get_wT(w, T):
    """Return the product w * T after relabelling w onto T's vertical grid.

    The `z_w_top` dimension of `w` is renamed to `z_t` and its coordinate
    values are overwritten with `T.z_t` (no interpolation is performed).
    """

    ## work on a copy so the caller's `w` is untouched
    w_on_t = copy.deepcopy(w)
    w_on_t = w_on_t.rename({"z_w_top": "z_t"}).assign_coords({"z_t": T.z_t})

    return w_on_t * T


def get_wdTdz(w, T):
    """Return w * dT/dz after relabelling w onto T's vertical grid.

    The `z_w_top` dimension of `w` is renamed to `z_t` and its coordinate
    values are overwritten with `T.z_t` (no interpolation is performed).
    """

    ## work on a copy so the caller's `w` is untouched
    w_on_t = copy.deepcopy(w)
    w_on_t = w_on_t.rename({"z_w_top": "z_t"}).assign_coords({"z_t": T.z_t})

    ## vertical temperature gradient, in units of T per unit of `z_t`
    ## NOTE(review): the original comment mentioned a 1/cm -> 1/m conversion,
    ## but no conversion factor is applied here -- confirm intended units
    vertical_gradient = T.differentiate("z_t")

    return w_on_t * vertical_gradient


def get_udTdx(u, T):
    """Zonal advection term u * dT/dx, evaluated at the equator.

    `T.differentiate("lon")` is a derivative per degree of longitude, so it
    is converted to a per-meter derivative using the length of ONE degree of
    longitude at the equator. (Previously it was divided by the width of one
    grid cell, which is only equivalent when the grid spacing is exactly 1
    degree.)

    Args:
        u: zonal velocity, broadcastable against `T`.
        T: field with a `lon` coordinate in degrees.

    Returns:
        u * dT/dx, with dx in meters.
    """

    ## meters spanned by one degree of longitude at the equator
    lat_deg = 0.0
    m_per_deg_lon = get_dx(lat_deg=lat_deg, dlon_deg=1.0)

    ## differentiate (per degree) and convert to per meter
    u_dfdx_ = u * T.differentiate("lon") * 1 / m_per_deg_lon

    return u_dfdx_


def get_u_adv(u, T):
    """Temperature tendency from zonal advection, -u * dT/dx, at the equator.

    `T.differentiate("lon")` is a derivative per degree of longitude, so it
    is converted to a per-meter derivative using the length of ONE degree of
    longitude at the equator. (Previously it was divided by the width of one
    grid cell, which is only equivalent when the grid spacing is exactly 1
    degree.)

    Args:
        u: zonal velocity, broadcastable against `T`.
        T: field with a `lon` coordinate in degrees.

    Returns:
        -u * dT/dx scaled by 12 (months per year).
    """

    ## meters spanned by one degree of longitude at the equator
    lat_deg = 0.0
    m_per_deg_lon = get_dx(lat_deg=lat_deg, dlon_deg=1.0)

    ## differentiate and convert units to K/yr
    ## NOTE(review): the factor of 12 presumes u is expressed per month
    ## -- confirm against the data
    mo_per_yr = 12
    u_dfdx_ = u * T.differentiate("lon") * 1 / m_per_deg_lon * mo_per_yr

    return -u_dfdx_


def recon_clim(data, components, varname="sst"):
    """Reconstruct the monthly climatology of `varname` along the equator.

    The monthly mean of `data` is passed, together with `components`, to
    `src.utils.reconstruct_fn`; the spatial function applied is a 2S-2N
    latitude average. Exact zeros in the result are replaced with NaN.
    """

    def _equatorial_mean(x):
        return x.sel(latitude=slice(-2, 2)).mean("latitude")

    ## monthly climatology of the input
    clim_scores = data.groupby("time.month").mean()

    ## reconstruct
    recon = src.utils.reconstruct_fn(
        components[varname], clim_scores[varname], fn=_equatorial_mean
    )

    ## blank out zero values in place
    is_zero = recon.values == 0
    recon.values[is_zero] = np.nan

    return recon


def get_monthly_eli(t_bnds):
    """Monthly mean of the notebook-level `eli_forced` over an index range.

    `t_bnds` is unpacked into `slice(*t_bnds)` and applied positionally
    along `time`.
    """

    period = slice(*t_bnds)
    eli_period = eli_forced.isel(time=period)

    return eli_period.groupby("time.month").mean()


def get_monthly_eli_std(t_bnds):
    """Monthly std of the notebook-level `eli_anom` over an index range.

    `t_bnds` is unpacked into `slice(*t_bnds)` and applied positionally
    along `time`; the standard deviation pools `time` and `member`.
    """

    period = slice(*t_bnds)
    eli_period = eli_anom.isel(time=period)

    return eli_period.groupby("time.month").std(["time", "member"])


def plot_cyclic(ax, data, sigma=None, **kwargs):
    """Plot data against month, with a wrap-around point appended.

    Args:
        ax: matplotlib axes to draw on.
        data: array with a `month` coordinate along axis 0; plotted as x,
            with month as y.
        sigma: optional spread; if given, `data ± sigma` are drawn as thin
            (lw=0.8) lines.
        **kwargs: forwarded to `ax.plot` for every line.
    """

    ## add cyclic point
    data_cyclic, dim_cyclic = cartopy.util.add_cyclic_point(data, data.month, axis=0)

    ## plot data
    ax.plot(data_cyclic, dim_cyclic, **kwargs)

    ## plot bounds if they exist
    if sigma is not None:
        sigma_cyclic, _ = cartopy.util.add_cyclic_point(sigma, data.month, axis=0)

        ## bounds always use the thin line width; drop any caller-supplied
        ## width so `lw` is not passed twice (which raised a TypeError)
        bound_kwargs = {
            k: v for k, v in kwargs.items() if k not in ("lw", "linewidth")
        }
        bound_kwargs["lw"] = 0.8

        ax.plot(data_cyclic + sigma_cyclic, dim_cyclic, **bound_kwargs)
        ax.plot(data_cyclic - sigma_cyclic, dim_cyclic, **bound_kwargs)

    return


def plot_cyclic_quantiles(ax, data, quantiles=(0.5, 0.15, 0.85), **kwargs):
    """Plot monthly quantiles of data against month, with a wrap-around point.

    Args:
        ax: matplotlib axes to draw on.
        data: array with `time` and `member` dimensions; quantiles are
            computed per calendar month over both.
        quantiles: sequence of quantiles. The first entry is drawn with the
            caller's style; the remaining ones as thin (lw=0.8) lines.
        **kwargs: forwarded to `ax.plot` for every line.
    """

    ## compute quantiles (list() keeps the argument type the original passed)
    q = data.groupby("time.month").quantile(q=list(quantiles), dim=["time", "member"])

    ## convert to numpy
    month = q.month.values
    q = q.transpose("quantile", "month").values

    ## add cyclic point
    q_cyclic, dim_cyclic = cartopy.util.add_cyclic_point(q, month, axis=1)

    ## plot first quantile (the median, for the default quantiles)
    ax.plot(q_cyclic[0], dim_cyclic, **kwargs)

    ## remaining quantiles always use the thin line width; drop any
    ## caller-supplied width so `lw` is not passed twice (TypeError)
    thin_kwargs = {k: v for k, v in kwargs.items() if k not in ("lw", "linewidth")}
    thin_kwargs["lw"] = 0.8
    for q_line in q_cyclic[1:]:
        ax.plot(q_line, dim_cyclic, **thin_kwargs)

    return


def format_subsurf_axs(axs):
    """Apply shared formatting to a row of depth-vs-longitude panels.

    Every axis gets its y-axis flipped, x-limits of [140, 281], no y-ticks
    and a "Longitude" label; the first axis then gets depth ticks/label.
    """

    for panel in axs:
        bottom, top = panel.get_ylim()
        panel.set_ylim((top, bottom))
        panel.set_xlim([140, 281])
        panel.set_yticks([])
        panel.set_xlabel("Longitude")

    ## only the left-most panel carries the depth axis
    first = axs[0]
    first.set_yticks([300, 150, 0])
    first.set_ylabel("Depth (m)")

    return


def format_hov_axs(axs):
    """Apply the standard early/late/difference layout to three hovmollers."""

    font_kwargs = dict(size=8)

    ## axis label and panel titles
    axs[0].set_ylabel("Month", **font_kwargs)
    titles = ("Early", "Late", "Difference (x2)")
    for idx, title in enumerate(titles):
        axs[idx].set_title(title, **font_kwargs)

    ## month ticks only on the first panel
    for idx in (1, 2):
        axs[idx].set_yticks([])
    axs[0].set_yticks([1, 5, 9, 12], labels=["Jan", "May", "Sep", "Dec"])

    ## longitude ticks and dashed guide lines on every panel
    for panel in axs:
        panel.set_xticks([190, 240])
        for lon in (240, 190):
            panel.axvline(lon, ls="--", c="w", lw=1)

    return


def get_w_int(w):
    """Mean of `w` over the `z_t` levels selected by `slice(None, 200)`."""
    upper_levels = w.sel(z_t=slice(None, 200))
    return upper_levels.mean("z_t")


def get_dTdz(Tsub):
    """Temperature difference between the levels nearest z_t=0 and z_t=200.

    Note this is a difference (surface minus subsurface), not divided by the
    layer thickness.
    """

    def _nearest_level(depth):
        return Tsub.sel(z_t=depth, method="nearest").squeeze(drop=True)

    return _nearest_level(0) - _nearest_level(200)


def get_diags(data):
    """Merge the `dTdz` and `w_int` diagnostics computed from `data`.

    Reads `data["T"]` and `data["w"]`.
    """
    dTdz = get_dTdz(data["T"]).rename("dTdz")
    w_int = get_w_int(data["w"]).rename("w_int")
    return xr.merge([dTdz, w_int])


def get_dTdz_sub(Tsub, mld, delta=25):
    """Temperature gradient estimate: -(dT from `get_dT_sub`) / mld.

    `src.utils.get_dT_sub` supplies the temperature difference (its exact
    definition is not visible here); the result is negated and divided by
    the mixed layer depth.
    """

    temp_diff = src.utils.get_dT_sub(Tsub=Tsub, Hm=mld, delta=delta)

    return -temp_diff / mld


def get_nino34(data):
    """Average `data` over longitudes 190-240."""
    lon_band = slice(190, 240)
    return data.sel(longitude=lon_band).mean("longitude")


def get_w_int_idx(data):
    """Apply `get_w_int`, then average over the 190-240 longitude band."""
    w_int = get_w_int(data)
    return get_nino34(w_int)


def get_dTdz_idx(data):
    """Apply `get_dTdz`, then average over the 190-240 longitude band."""
    dTdz = get_dTdz(data)
    return get_nino34(dTdz)


def reconstruct_fn(components, scores, fn):
    """Reconstruct a field: apply `fn` to `components`, weight by `scores`,
    and sum over the `mode` dimension."""

    transformed = fn(components)
    weighted = transformed * scores

    return weighted.sum("mode")


def eq_avg(x):
    """Latitude mean over 5S-5N, restricted to longitudes 125-279."""
    region = dict(latitude=slice(-5, 5), longitude=slice(125, 279))
    return x.sel(**region).mean("latitude")


def get_var(data, year_center, n=15, fn=None):
    """
    Reconstruct variance for the window around `year_center`.

    Keeps every sample whose year is within `n` years of `year_center`
    (inclusive) and passes that subset, plus `fn`, to
    `src.utils.reconstruct_var_wrapper`.
    """

    ## boolean mask of samples inside the window
    year_offset = data.time.dt.year - year_center
    in_window = np.abs(year_offset) <= n

    ## variance of the windowed samples
    windowed = data.isel(time=in_window)

    return src.utils.reconstruct_var_wrapper(data=windowed, fn=fn)


def get_var_bymonth(data, year_center, n=15, fn=None):
    """Apply `get_var` separately to each calendar month of `data`."""

    def _var_for_month(month_data):
        return get_var(month_data, year_center=year_center, n=n, fn=fn)

    by_month = data.groupby("time.month")

    return by_month.map(_var_for_month)


def get_var_for_periods(data, periods, n=15, fn=None, by_month=True):
    """Compute the windowed variance for each center year in `periods`.

    Uses `get_var_bymonth` when `by_month` is true, otherwise `get_var`, and
    concatenates the results along a new `period` dimension.
    """

    ## choose the variance function
    get_var_fn = get_var if not by_month else get_var_bymonth

    ## evaluate one period at a time (with a progress bar)
    var_by_period = []
    for center in tqdm.tqdm(periods):
        var_by_period.append(get_var_fn(year_center=center, data=data, n=n, fn=fn))

    ## stack along a labelled "period" dimension
    return xr.concat(var_by_period, dim=pd.Index(periods, name="period"))


def avg_mon_range(data, m0, m1):
    """Yearly mean of `data` over months m0..m1 (both inclusive).

    The result's `year` coordinate is renamed to `time`. A range that wraps
    past December (m0 > m1) selects nothing.
    """

    ## mask of samples inside the month range
    sample_month = data.time.dt.month
    in_range = (sample_month >= m0) & (sample_month <= m1)

    ## average the selected samples within each year
    yearly = data.isel(time=in_range).groupby("time.year").mean()

    return yearly.rename({"year": "time"})


def get_mam(data):
    """Yearly March-May (months 3 to 5) average of `data`."""

    first_month, last_month = 3, 5
    return avg_mon_range(data, m0=first_month, m1=last_month)


def set_ylims(axs):
    """Give every axis in `axs` the same y-limits.

    The shared limits are the smallest lower bound and the largest upper
    bound found across the axes.

    Args:
        axs: array (any shape) or sequence of matplotlib axes.
    """

    ## flatten once, so 2-D grids of axes are handled when setting limits too
    ## (iterating a 2-D array directly yields rows, not axes)
    flat_axs = np.ravel(axs)

    lims = np.stack([ax.get_ylim() for ax in flat_axs], axis=0)

    lb = lims[:, 0].min()
    ub = lims[:, 1].max()

    for ax in flat_axs:
        ax.set_ylim([lb, ub])

    return


def set_xlims(axs):
    """Give every axis in `axs` the same x-limits.

    The shared limits are the smallest lower bound and the largest upper
    bound found across the axes.

    Args:
        axs: array (any shape) or sequence of matplotlib axes.
    """

    ## flatten once, so 2-D grids of axes are handled when setting limits too
    ## (iterating a 2-D array directly yields rows, not axes)
    flat_axs = np.ravel(axs)

    lims = np.stack([ax.get_xlim() for ax in flat_axs], axis=0)

    lb = lims[:, 0].min()
    ub = lims[:, 1].max()

    for ax in flat_axs:
        ax.set_xlim([lb, ub])

    return


def get_dy(dlat_deg):
    """Get the meridional distance spanned by `dlat_deg` degrees of latitude.

    Args:
        dlat_deg: latitude spacing in degrees (scalar or array).

    Returns:
        Distance in CENTIMETERS, because the earth radius below is given in
        centimeters. (Note `get_dx` uses a radius in meters instead.)
    """

    ## convert from degrees to radians
    ## (was `dlat`, an undefined name that raised NameError)
    dlat_rad = dlat_deg / 180.0 * np.pi

    ## multiply by radius of earth
    R = 6.378e8  # earth radius (centimeters)
    dlat_cm = R * dlat_rad

    return dlat_cm


def get_dx(lat_deg, dlon_deg):
    """Zonal distance (meters) spanned by `dlon_deg` degrees of longitude
    at latitude `lat_deg`."""

    earth_radius_m = 6.378e6  # earth radius (meters)

    ## degrees -> radians
    lat_rad = lat_deg / 180 * np.pi
    dlon_rad = dlon_deg / 180.0 * np.pi

    ## arc length along a circle of latitude
    return earth_radius_m * np.cos(lat_rad) * dlon_rad


def get_dydx(data):
    """Get gridcell height (dy) and width (dx) for the grid of `data`.

    The grid spacing is taken from the first two `latitude` / `longitude`
    values, i.e. a uniform grid is assumed. Because the earth radius below
    is in centimeters, `dy` and `dx` are returned in centimeters.

    Returns:
        Dataset with variables `dy` and `dx`, both varying with latitude.
    """

    ## empty array to hold result
    grid = xr.Dataset(
        coords=dict(
            latitude=data["latitude"].values,
            longitude=data["longitude"].values,
        ),
    )

    ## grid spacing in degrees (from the first two points only)
    grid["dlat"] = grid["latitude"].values[1] - grid["latitude"].values[0]
    grid["dlon"] = grid["longitude"].values[1] - grid["longitude"].values[0]

    ## grid spacing in radians
    grid["dlat_rad"] = grid["dlat"] / 180.0 * np.pi
    grid["dlon_rad"] = grid["dlon"] / 180.0 * np.pi
    R = 6.378e8  # earth radius (centimeters)

    ## height of gridcell doesn't depend on longitude
    grid["dy"] = R * grid["dlat_rad"]  # unit: centimeters (same as R)
    grid["dy"] = grid["dy"] * xr.ones_like(grid["latitude"])

    ## Compute width of gridcell
    grid["lat_rad"] = grid["latitude"] / 180 * np.pi  # latitude in radians
    grid["dx"] = R * np.cos(grid["lat_rad"]) * grid["dlon_rad"]

    return grid[["dy", "dx"]]


def u_dfdx(u, f):
    """Zonal advection term u * df/dx, scaled to a per-year rate.

    `f.differentiate("longitude")` is a derivative per degree of longitude,
    so it must be divided by the length of ONE degree of longitude.
    `get_dydx` returns the width of one grid cell, so that width is divided
    by the grid spacing in degrees first. (Previously the cell width was used
    directly, which is only equivalent on a 1-degree grid.)

    Args:
        u: zonal velocity, broadcastable against `f`.
        f: field with `latitude` and `longitude` coordinates in degrees.

    Returns:
        u * df/dx multiplied by the number of seconds in a 365-day year.
    """

    ## grid spacing in degrees, computed the same way as in `get_dydx`
    lon_values = f["longitude"].values
    dlon_deg = lon_values[1] - lon_values[0]

    ## width of one degree of longitude (centimeters, per `get_dydx`)
    cm_per_deg_lon = get_dydx(f)["dx"] / dlon_deg

    sec_per_year = 86400 * 365
    factor = sec_per_year / cm_per_deg_lon

    u_dfdx_ = u * f.differentiate("longitude") * factor

    return u_dfdx_


def v_dfdy(v, f):
    """Meridional advection term v * df/dy, scaled to a per-year rate.

    `f.differentiate("latitude")` is a derivative per degree of latitude, so
    it must be divided by the length of ONE degree of latitude. `get_dydx`
    returns the height of one grid cell, so that height is divided by the
    grid spacing in degrees first. (Previously the cell height was used
    directly, which is only equivalent on a 1-degree grid.)

    Args:
        v: meridional velocity, broadcastable against `f`.
        f: field with `latitude` and `longitude` coordinates in degrees.

    Returns:
        v * df/dy multiplied by the number of seconds in a 365-day year.
    """

    ## grid spacing in degrees, computed the same way as in `get_dydx`
    lat_values = f["latitude"].values
    dlat_deg = lat_values[1] - lat_values[0]

    ## height of one degree of latitude (centimeters, per `get_dydx`)
    cm_per_deg_lat = get_dydx(f)["dy"] / dlat_deg

    sec_per_year = 86400 * 365
    factor = sec_per_year / cm_per_deg_lat

    v_dfdy_ = v * f.differentiate("latitude") * factor

    return v_dfdy_


def get_adv(uv, T):
    """
    Temperature tendency from horizontal advection:
        -(u * dT/dx + v * dT/dy)

    Reads the `uvel` and `vvel` variables of `uv`.
    """

    ## zonal and meridional contributions
    zonal = u_dfdx(u=uv["uvel"], f=T)
    meridional = v_dfdy(v=uv["vvel"], f=T)

    ## advection acts against the gradient, hence the sign flip
    total = zonal + meridional
    return -total

## Change in $T$, $h$

### Load data
Load the data, then compute its variance and skewness.

##### Load ELI

In [None]:
## load ELI data from the "cesm/eli.nc" file under DATA_FP
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))

## split into forced and anomalous components
## (both are reused by later cells and by get_monthly_eli / get_monthly_eli_std)
eli_forced, eli_anom = src.utils.separate_forced(eli)

##### Load $T$, $h$, and climate mode indices

In [None]:
## open T, h and climate-mode indices
Th = src.utils.load_cesm_indices()

## rename indices to shorter names for convenience
Th = Th.rename(
    {
        "north_tropical_atlantic": "natl",
        "atlantic_nino": "nino_atl",
        "tropical_indian_ocean": "iobm",
        "indian_ocean_dipole": "iod",
        "north_pacific_meridional_mode": "npmm",
        "south_pacific_meridional_mode": "spmm",
    }
)

## add ELI data (currently disabled)
# Th = xr.merge([Th, eli_anom["eli_15"]])

#### Compute relative SST for $T$ indices

In [None]:
## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total)
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")

## compute relative sst: total index minus the tropical-mean SST
## NOTE(review): this uses "trop_sst_10", whereas the spatial relative-SST
## maps further down use "trop_sst_05" -- confirm the difference is intended
for n in ["T_3", "T_34", "T_4"]:
    Th[f"{n}_rel"] = Th_total[n] - trop_sst["trop_sst_10"]

#### Compute stats

In [None]:
## get rolling variance, by month (window parameter n=15)
Th_var = get_rolling_var(Th, n=15)
Th_var_bymonth = src.utils.unstack_month_and_year(Th_var)

## get rolling skew, by month (same window parameter)
Th_skew = src.utils.get_rolling_fn_bymonth(Th, fn=scipy.stats.skew, n=15)
Th_skew_bymonth = src.utils.unstack_month_and_year(Th_skew)

## Get % increase in variance, relative to the mean over the first 30 years
baseline = Th_var_bymonth.isel(year=slice(None, 30)).mean("year")
Th_var_bymonth_pct = 100 * (Th_var_bymonth - baseline) / baseline

### Variance

Hövmöller

In [None]:
## setup plot: one narrow panel per variable in Th
fig, axs = plt.subplots(
    1, len(list(Th)), figsize=(0.9 * len(list(Th)), 2), layout="constrained"
)

## plot % change in variance for each variable (color scale of ±100%)
for ax, n in zip(axs, list(Th)):
    plot_hov2(ax, Th_var_bymonth_pct[n].T, amp=100)
    ax.set_title(n, fontsize=9)
    ## dashed guide line at month 8 (August)
    ax.axvline(8, ls="--", c="w", lw=1)


plt.show()

Compare August and December variance over time

In [None]:
fig, ax = plt.subplots(figsize=(2.5, 2))

## plot Niño 3.4 variance over time for August and December
ax.plot(
    Th_var_bymonth.year, Th_var_bymonth["T_34"].sel(month=8), label="Aug", c="k", ls="-"
)
## December is month 12 (month=11 selected November while labelled "Dec")
ax.plot(
    Th_var_bymonth.year,
    Th_var_bymonth["T_34"].sel(month=12),
    label="Dec",
    c="gray",
    ls="--",
)

## label and style
ax.set_title(r"$\sigma^2\left(\text{Niño 3.4}\right)$")
ax.set_ylim([0, None])
ax.legend(prop=dict(size=8))
ax.set_ylabel(r"$^{\circ}\text{C}^2$")

plt.show()

### Skewness

In [None]:
## setup plot: one panel per index
fig, axs = plt.subplots(1, 5, figsize=(7.2, 4), layout="constrained")

## plot rolling skewness for each index (color scale of ±1.5)
for i, n in enumerate(["T_4", "T_34", "T_3", "h", "h_w"]):
    plot_hov2(axs[i], Th_skew_bymonth[n].T, amp=1.5)
    axs[i].set_title(n, fontsize=9)
    ## faint dashed guide line at month 5 (May)
    axs[i].axvline(5, ls="--", c="w", lw=1, alpha=0.5)

plt.show()

### ELI

In [None]:
## get mean by month (smoothed with a rolling average along year)
eli_mean_bymonth = src.utils.unstack_month_and_year(eli_forced)
eli_mean_bymonth = src.utils.get_rolling_avg(eli_mean_bymonth, n=15, dim="year")
## change in mean relative to the first year
delta_eli = eli_mean_bymonth - eli_mean_bymonth.isel(year=0)

## get rolling variance, by month
eli_var = get_rolling_var(eli_anom, n=15)
eli_var_bymonth = src.utils.unstack_month_and_year(eli_var)

## Get % increase in variance, relative to the mean over the first 30 years
## (note: this overwrites eli_var_bymonth with the percent change)
baseline = eli_var_bymonth.isel(year=slice(None, 30)).mean("year")
eli_var_bymonth = 100 * (eli_var_bymonth - baseline) / baseline

## get rolling skew, by month
eli_skew = src.utils.get_rolling_fn_bymonth(eli_anom, fn=scipy.stats.skew, n=15)
eli_skew_bymonth = src.utils.unstack_month_and_year(eli_skew)

In [None]:
## specify which ELI index to plot
eli_idx = "eli_05"

## setup plot: mean change, variance change, skewness
fig, axs = plt.subplots(1, 3, figsize=(3.16, 3), layout="constrained")

## plot each statistic with its own color-scale amplitude
## (eli_var_bymonth already holds the percent change at this point)
plot_hov2(axs[0], delta_eli[eli_idx].T, amp=15)
plot_hov2(axs[1], eli_var_bymonth[eli_idx].T, amp=300)
plot_hov2(axs[2], eli_skew_bymonth[eli_idx].T, amp=3)

## label
kwargs = dict(size=8)
axs[0].set_title(r"$\Delta$ mean", **kwargs)
axs[1].set_title(r"$\Delta$ variance (%)", **kwargs)
axs[2].set_title("skew", **kwargs)

plt.show()

### stats

Scatter ELI vs. Niño indices

In [None]:
## select data: Niño indices (raw and relative) plus the ELI anomaly
data = xr.merge(
    [Th[["T_3_rel", "T_34_rel", "T_4_rel", "T_4", "T_3", "T_34"]], eli_anom["eli_05"]]
)
## first / last 360 time steps
data_early = data.isel(time=slice(None, 360))
data_late = data.isel(time=slice(-360, None))

## get corresponding data for total (matched on the same time stamps)
total = xr.merge([Th_total, eli])
total_early = total.sel(time=data_early.time)
total_late = total.sel(time=data_late.time)

In [None]:
## specify which T-variable to plot
T_var = "T_3_rel"

## specify month
month = 4

## function to select month
## (reads the global `month` at call time; later cells reuse `sel_mon`)
sel_mon = lambda x: x.isel(time=x.time.dt.month == month)

## plot kwargs
kwargs = dict(s=2, alpha=0.8)

fig, axs = plt.subplots(2, 1, figsize=(3, 5), layout="constrained")

## early period (top panel)
axs[0].scatter(
    sel_mon(data_early[T_var]),
    sel_mon(data_early["eli_05"]),
    **kwargs,
)

## late period (bottom panel)
axs[1].scatter(
    sel_mon(data_late[T_var]),
    sel_mon(data_late["eli_05"]),
    **kwargs,
)

## format: shared axis limits and zero reference lines
axs[0].set_xticks([])
axs[1].set_xlabel(r"$T$")
set_ylims(axs)
set_xlims(axs)
for ax in axs:
    ## note: rebinds `kwargs` (previously the scatter style) to the line style
    kwargs = dict(c="k", lw=0.8, ls="--")
    ax.axhline(0, **kwargs)
    ax.axvline(0, **kwargs)
    ax.set_ylabel(r"ELI")

plt.show()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5), layout="constrained")

## kwargs
sc_kwargs = dict(s=2, alpha=0.6)

## scatter total ELI against the Niño 4 minus Niño 3 difference
## (`sel_mon` comes from the previous cell and uses its `month`)
axs[0].scatter(
    sel_mon(total_early["eli_05"]),
    sel_mon(total_early["T_4"] - total_early["T_3"]),
    **sc_kwargs,
)

axs[1].scatter(
    sel_mon(total_late["eli_05"]),
    sel_mon(total_late["T_4"] - total_late["T_3"]),
    **sc_kwargs,
)

## shared axis limits across both panels
set_ylims(axs)
set_xlims(axs)

for ax in axs:
    kwargs = dict(c="k", lw=0.8, ls="--")
    ax.axhline(0, **kwargs)
    ax.set_xlabel("ELI")
    ax.set_xticks([190, 240])
    ## heavy tick marks on the zero line at each x-tick
    for t in ax.get_xticks():
        ax.scatter(t, 0, c="k", s=100, marker="|", linewidths=3)

axs[0].set_ylabel(r"(Niño 4) – (Niño 3)")
axs[1].set_yticks([])

plt.show()

#### Look at change relative to median

In [None]:
## specify plot variable
plot_var = "T_34"

## choose histogram edges and x-tick labelling based on the variable
if "eli" in plot_var:
    edges = np.arange(-64, 68, 4)
    tick_kwargs = dict(ticks=[-50, 0, 50], labels=["-50", "median", "+50"])
    # edges = np.arange(-25,27.5, 2.5)

elif "h" in plot_var:
    edges = np.arange(-10e-2, 11e-2, 1e-2)
    ## no custom ticks are defined for h-variables; keep matplotlib's default
    ## (previously `tick_kwargs` was left undefined or stale in this branch)
    tick_kwargs = None

else:
    edges = np.arange(-4, 4.25, 0.25)
    tick_kwargs = dict(ticks=[-3, 0, 3], labels=["-3", "median", "+3"])


## get data (`sel_mon` is defined in an earlier cell)
y0 = sel_mon(data_early)[plot_var].values.flatten()
y1 = sel_mon(data_late)[plot_var].values.flatten()

## compute pdfs
pdf0, _ = src.utils.get_empirical_pdf(y0, edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(y1, edges=edges)

## compute skewness
s0 = scipy.stats.skew(y0)
s1 = scipy.stats.skew(y1)

## make plot, shifting each histogram so its median sits at zero
fig, ax = plt.subplots(figsize=(4, 3), layout="constrained")
ax.stairs(pdf0, edges - np.median(y0), label=f"early: skew = {s0:.2f}")
ax.stairs(pdf1, edges - np.median(y1), label=f"late: skew = {s1:.2f}")
ax.axvline(0, c="k", lw=1)
if tick_kwargs is not None:
    ax.set_xticks(**tick_kwargs)

ax.legend(prop=dict(size=8))

plt.show()

In [None]:
## variable to show in the histogram
plot_var = "eli_05"

## get data (`sel_mon` is defined in an earlier cell)
y0 = sel_mon(total_early)[plot_var].values.flatten()
y1 = sel_mon(total_late)[plot_var].values.flatten()

## compute pdfs on 5-wide bins from 140 to 275
edges = np.arange(140, 280, 5)
pdf0, _ = src.utils.get_empirical_pdf(y0, edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(y1, edges=edges)

## make plot
fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
ax.stairs(pdf0, edges, label=f"early")
ax.stairs(pdf1, edges, label=f"late")
ax.set_xticks(ticks=[190, 240])
ax.set_ylim([-2e-3, None])
ax.set_xlim([None, 255])
ax.set_yticks([])
ax.set_ylabel("Prob. density")
ax.set_xlabel("ELI (longitude)")

## heavy tick marks on the zero line at each x-tick
for t in ax.get_xticks():
    ax.scatter(t, 0, marker="|", s=100, linewidths=3, c="k")


ax.legend(prop=dict(size=8))

plt.show()

## Change in spatial patterns

In [None]:
## load spatial data (forced and anomaly files from the consolidated folder)
CONS_DIR = pathlib.Path(DATA_FP, "cesm", "consolidated")
forced = xr.open_dataset(CONS_DIR / "forced.nc")
anom = xr.open_dataset(CONS_DIR / "anom.nc")

### Change in mean

In [None]:
def merimean_safe(data):
    """Average over 2S-2N when `data` has a latitude coordinate.

    Objects without a `latitude` coordinate are returned unchanged.
    """
    if "latitude" not in data.coords:
        return data

    equatorial_band = data.sel(latitude=slice(-2, 2))
    return equatorial_band.mean("latitude")

#### Computation

Get climatologies

In [None]:
## split into early/late periods (30 calendar years each)
t_early = dict(time=slice("1851", "1880"))
t_late = dict(time=slice("2071", "2100"))

## split surface data
anom_early = anom.sel(t_early)
anom_late = anom.sel(t_late)

## climatologies (equatorial mean applied via merimean_safe)
clim_early = src.utils.reconstruct_clim(forced.sel(t_early), fn=merimean_safe)
clim_late = src.utils.reconstruct_clim(forced.sel(t_late), fn=merimean_safe)

## get difference (late minus early)
clim_diff = clim_late - clim_early

## subset ELI to the same periods
eli_early = eli.sel(t_early)
eli_late = eli.sel(t_late)

#### Plot

In [None]:
## specify which ELI index to plot
ELI = "eli_05"

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs
kwargs = dict(cmap="cmo.thermal", levels=np.arange(23, 32), extend="both")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], clim_early["sst"], **kwargs)

## plot late, with the contour levels shifted up by 3 (23..31 -> 26..34)
kwargs["levels"] = kwargs["levels"] + 3
cp1 = src.utils.plot_cycle_hov(axs[1], clim_late["sst"], **kwargs)

## plot difference (late minus early)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    clim_diff["sst"],
    cmap="cmo.amp",
    levels=np.arange(2.6, 5.2, 0.2),
    extend="both",
)

## plot eli quantiles: early in white, late in black, both on the difference panel
plot_cyclic_quantiles(ax=axs[0], data=eli_early[ELI], c="w")
plot_cyclic_quantiles(ax=axs[1], data=eli_late[ELI], c="k")
plot_cyclic_quantiles(ax=axs[2], data=eli_early[ELI], c="w")
plot_cyclic_quantiles(ax=axs[2], data=eli_late[ELI], c="k")
# plot_cyclic(ax=axs[0], data=eli_early[ELI], sigma=eli_std_early[ELI], c="w")
# plot_cyclic(ax=axs[1], data=eli_late[ELI], sigma=eli_std_late[ELI], c="w", ls="--")
# plot_cyclic(ax=axs[2], data=eli_early[ELI], c="w")
# plot_cyclic(ax=axs[2], data=eli_late[ELI], c="w", ls="--")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## colorbars (ticks match each panel's level range)
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[23, 27, 31], label=r"$^{\circ}C$")
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[26, 30, 34], label=r"$^{\circ}C$")
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[3, 4, 5], label=r"$^{\circ}C$")

plt.show()

#### Mixed layer depth

In [None]:
# ## compute clims
# kwargs = dict(components=components, varname="mld")
# mld_clim_early = recon_clim(forced.isel(time=slice(None, 360)), **kwargs)
# mld_clim_late = recon_clim(forced.isel(time=slice(-360, None)), **kwargs)

# ## get difference
# mld_clim_diff = mld_clim_late - mld_clim_early

## get fractional change in MLD, in percent
## pct1: change relative to the early climatology
## pct2: NEGATIVE change relative to (early + diff)
## NOTE(review): the sign flip in pct2 looks deliberate but is undocumented -- confirm
mld_clim_diff_pct1 = 100 * (clim_diff / clim_early)["mld"]
mld_clim_diff_pct2 = 100 * (-clim_diff / (clim_early + clim_diff))["mld"]

In [None]:
## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs (shared by the early and late panels so they are comparable)
kwargs = dict(cmap="cmo.amp", levels=np.arange(0, 65, 5), extend="max")
cb_kwargs = dict(ticks=[0, 30, 60], label=r"m")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], clim_early["mld"], **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
## (the `levels + 3` shift copied from the SST cell was removed: it offset
## the MLD levels to 3..63 while the colorbar ticks stayed at 0/30/60)
cp1 = src.utils.plot_cycle_hov(axs[1], clim_late["mld"], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot % change
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    mld_clim_diff_pct2,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(80, 8),
    extend="both",
)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-80, 0, 80], label="% change")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

plt.show()

#### zonal velocity

In [None]:
## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs (shared by the early and late panels so they are comparable)
kwargs = dict(cmap="cmo.balance", levels=src.utils.make_cb_range(80, 8), extend="both")

## NOTE(review): the label "m" looks copied from the MLD cell; zonal velocity
## should carry velocity units -- confirm the units of "uvel" and update
cb_kwargs = dict(ticks=[-80, 0, 80], label=r"m")

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], clim_early["uvel"], **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
## (the `levels + 3` shift copied from the SST cell was removed: it pushed
## the diverging color scale off zero while the ticks stayed at -80/0/80)
cp1 = src.utils.plot_cycle_hov(axs[1], clim_late["uvel"], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference (late minus early)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    clim_diff["uvel"],
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(40, 4),
    extend="both",
)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-40, 0, 40], label="Difference")

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])
## dashed guide line at month 7 (July)
for ax in axs:
    ax.axhline(7, ls="--", c="k", lw=1)

plt.show()

#### spatial pattern for precip/relative SST

In [None]:
def recon_sst_pr_clim(data, trop_sst):
    """Reconstruct the SST / precip climatology and add relative SST.

    Relative SST (`sst_rel`) is the reconstructed SST minus the monthly mean
    of `trop_sst`, averaged over `time` and `member`.
    """

    ## variables handed to the reconstruction
    recon_vars = ["sst", "pr", "sst_comp", "pr_comp"]
    clim = src.utils.reconstruct_clim(data[recon_vars])

    ## monthly-mean tropical SST
    monthly_trop_sst = trop_sst.groupby("time.month").mean(["time", "member"])

    ## relative SST
    clim["sst_rel"] = clim["sst"] - monthly_trop_sst

    return clim

In [None]:
## reconstruct spatial climatologies (with relative SST) for each period,
## using the "trop_sst_05" tropical-mean SST subset to the same years
spatial_clim_early = recon_sst_pr_clim(
    forced.sel(t_early),
    trop_sst["trop_sst_05"].sel(t_early),
)
spatial_clim_late = recon_sst_pr_clim(
    forced.sel(t_late),
    trop_sst["trop_sst_05"].sel(t_late),
)

In [None]:
def plot_sst_rel(ax, data):
    """Filled-contour map of `data["sst_rel"]` on `ax`; returns the contour set."""

    style = dict(
        cmap="cmo.balance",
        transform=ccrs.PlateCarree(),
        levels=src.utils.make_cb_range(10, 1),
        extend="both",
    )

    return ax.contourf(data.longitude, data.latitude, data["sst_rel"], **style)


def plot_contour(ax, data, lev, lw, c="k"):
    """Draw the single `lev` contour of `data["sst_rel"]` on `ax`.

    `lw` sets the line width and `c` the color; returns the contour set.
    """

    style = dict(
        transform=ccrs.PlateCarree(),
        levels=[lev],
        linewidths=lw,
        colors=c,
    )

    return ax.contour(data.longitude, data.latitude, data["sst_rel"], **style)


def plot_pr(ax, data):
    """Draw filled contours of precipitation (data["pr"]) on `ax`; return the contour set."""

    ## unit conversion to mm/day:
    ## 1e-3 (m3 / kg) * 1e3 (mm / m) * 8.6e4 (s / day)
    ## NOTE: 8.6e4 is a rounded seconds-per-day (exact value is 86400)
    factor = 8.6e4
    pr_scaled = factor * data["pr"]

    contour_kwargs = dict(
        cmap="cmo.rain",
        transform=ccrs.PlateCarree(),
        levels=np.arange(0, 28, 4),
        extend="max",
    )

    return ax.contourf(data.longitude, data.latitude, pr_scaled, **contour_kwargs)

In [None]:
## specify which month to look at
sel = lambda x: x.sel(month=4)

## figure with 3 rows (early / late / difference) x 2 cols (relative SST / precip)
fig = plt.figure(figsize=(12, 3.5), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## plot early
cp0 = plot_sst_rel(axs[0, 0], sel(spatial_clim_early))
cp1 = plot_pr(axs[0, 1], sel(spatial_clim_early))

## plot late
plot_sst_rel(axs[1, 0], sel(spatial_clim_late))
plot_pr(axs[1, 1], sel(spatial_clim_late))

## plot diff (late minus early); computed once and reused for both panels
spatial_clim_diff = spatial_clim_late - spatial_clim_early

cp0_ = axs[2, 0].contourf(
    spatial_clim_early.longitude,
    spatial_clim_early.latitude,
    sel(spatial_clim_diff["sst_rel"]),
    cmap="cmo.balance",
    transform=ccrs.PlateCarree(),
    levels=src.utils.make_cb_range(2, 0.2),
    extend="both",
)

## same rounded unit factor as used in plot_pr
cp1_ = axs[2, 1].contourf(
    spatial_clim_early.longitude,
    spatial_clim_early.latitude,
    8.6e4 * sel(spatial_clim_diff["pr"]),
    cmap="cmo.balance_r",
    transform=ccrs.PlateCarree(),
    levels=src.utils.make_cb_range(12, 2),
    extend="both",
)

## plot convective bounds (relative-SST contours at 0 and -1)
for ax in axs[0, :]:
    plot_contour(ax, sel(spatial_clim_early), lev=0, lw=2)
    plot_contour(ax, sel(spatial_clim_early), lev=-1, lw=1)

for ax in axs[1, :]:
    plot_contour(ax, sel(spatial_clim_late), lev=0, lw=2)
    plot_contour(ax, sel(spatial_clim_late), lev=-1, lw=1)

## difference row: early bound in gray, late bound in black
## (previously the early contour was drawn twice, hiding the gray line)
for ax in axs[2]:
    plot_contour(ax, sel(spatial_clim_early), lev=-1, lw=1, c="gray")
    plot_contour(ax, sel(spatial_clim_late), lev=-1, lw=1)

## mark the equator
for ax in axs.flatten():
    ax.axhline(0, c="magenta", lw=1, alpha=0.5)

## colorbars (early/late rows share one colorbar per column)
fig.colorbar(cp0, ax=axs[:2, 0], ticks=[-10, 0, 10])
fig.colorbar(cp1, ax=axs[:2, 1], ticks=[0, 12, 24])
fig.colorbar(cp0_, ax=axs[2, 0], ticks=[-2, 0, 2])
fig.colorbar(cp1_, ax=axs[2, 1], ticks=[-12, 0, 12])

plt.show()

### Subsurface

### Mixed layer depth

In [None]:
# def get_eq_mld(data):
#     """Get equatorial mixed layer depth"""

#     data_ = data.sel(latitude=slice(-1.5, 1.5)).mean("latitude")

#     return data_.sel(longitude=slice(140, 280))


# def recon_eq_mld(t_bnds):
#     """reconstruct equatorial MLD"""
#     return src.utils.reconstruct_fn(
#         scores=forced["mld"].isel(time=slice(*t_bnds)).groupby("time.month").mean(),
#         components=components["mld"],
#         fn=get_eq_mld,
#     )


# eq_mld_early = recon_eq_mld((None, 360))
# eq_mld_late = recon_eq_mld((-360, None))

Plot temperature (shading), vertical velocity (contours), and mixed layer depth

In [None]:
## choose what to plot (swap in the commented line for the annual mean)
# sel = lambda x: x.mean("month")
sel = lambda x: x.sel(month=7)

fig, axs = plt.subplots(1, 3, figsize=(8, 2.5), layout="constrained")

## line style shared by all vertical-velocity contours
w_kwargs = dict(colors="k", extend="both", linewidths=1)

## early/late panels: temperature shading with vertical-velocity contours on top
for ax, clim in zip(axs[:2], [clim_early, clim_late]):
    ax.contourf(
        clim["T"].longitude,
        clim["T"].z_t,
        sel(clim["T"]),
        cmap="cmo.thermal",
        levels=np.arange(10, 34, 2),
        extend="both",
    )
    ax.contour(
        clim["w"].longitude,
        clim["w"].z_t,
        sel(clim["w"]),
        levels=src.utils.make_cb_range(70, 7),
        **w_kwargs,
    )

## difference panel (late minus early)
## NOTE: `clim` below is the last value of the loop variable (clim_late);
## it is only used for its coordinates
diff = clim_late - clim_early

axs[2].contourf(
    clim["T"].longitude,
    clim["T"].z_t,
    sel(diff["T"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)
axs[2].contour(
    clim["w"].longitude,
    clim["w"].z_t,
    sel(diff["w"]),
    levels=src.utils.make_cb_range(35, 3.5),
    **w_kwargs,
)

## overlay MLD: early is solid, late is dashed; difference panel shows both
lon = clim_early.longitude
mld_early = sel(clim_early["mld"])
mld_late = sel(clim_late["mld"])
for ax in (axs[0], axs[2]):
    ax.plot(lon, mld_early, c="w")
for ax in (axs[1], axs[2]):
    ax.plot(lon, mld_late, c="w", ls="--")

## label
format_subsurf_axs(axs)

plt.show()

In [None]:
## specify which period/month to plot
sel = lambda x: x.sel(month=12)

## late-minus-early difference, computed here so the cell does not depend on a
## `diff` left over from another cell (the name is rebound to a function later
## in the notebook)
diff = clim_late - clim_early

## get aspect ratio (for scaling arrows)
dz = 300  # units: m
dx = get_dx(lat_deg=0, dlon_deg=150)
aspect = dx / dz

fig, axs = plt.subplots(1, 3, figsize=(12, 3.5), layout="constrained")

## early/late panels: temperature shading with (u, w) arrows on top
for ax, clim in zip(axs[:2], [clim_early, clim_late]):

    ## temperature
    ax.contourf(
        clim["T"].longitude,
        clim["T"].z_t,
        sel(clim["T"]),
        cmap="cmo.thermal",
        levels=np.arange(10, 34, 2),
        extend="both",
    )

    ## u and w (subsampled: every 4th longitude, every 2nd depth);
    ## w is multiplied by the aspect ratio so arrows follow the plotted geometry
    ax.quiver(
        clim.longitude.values[::4],
        clim.z_t.values[::2],
        sel(clim.u).values[::2, ::4],
        sel(clim.w).values[::2, ::4] * aspect,
        pivot="middle",
        alpha=0.7,
        scale=3e7,
    )

## difference panel: temperature change
axs[2].contourf(
    diff["T"].longitude,
    diff["T"].z_t,
    sel(diff["T"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

## difference panel: change in u and w (smaller `scale` than the other panels)
axs[2].quiver(
    diff.longitude.values[::4],
    diff.z_t.values[::2],
    sel(diff.u).values[::2, ::4],
    sel(diff.w).values[::2, ::4] * aspect,
    pivot="middle",
    alpha=0.7,
    scale=1.5e7,
)

## plot MLD (early is solid, late is dashed)
lon = clim_early.longitude
axs[0].plot(lon, sel(clim_early["mld"]), c="w")
axs[1].plot(lon, sel(clim_late["mld"]), c="w", ls="--")
axs[2].plot(lon, sel(clim_early["mld"]), c="w")
axs[2].plot(lon, sel(clim_late["mld"]), c="w", ls="--")

## label
format_subsurf_axs(axs)


plt.show()

#### Hovmoller

In [None]:
## diagnostics for each period, plus the late-minus-early change
diags_early, diags_late = get_diags(clim_early), get_diags(clim_late)
diags_diff = diags_late - diags_early

##### Stratification

In [None]:
## three stacked hövmöller panels: early, late, difference
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## settings shared by the early and late panels
kwargs = dict(cmap="cmo.amp", levels=np.arange(10, 19, 1), extend="both")

## early and late
cp0, cp1 = (
    src.utils.plot_cycle_hov(ax_, d_["dTdz"], **kwargs)
    for ax_, d_ in zip(axs[:2], (diags_early, diags_late))
)

## difference (own level range)
diff_kwargs = dict(cmap="cmo.amp", levels=np.arange(0, 6.6, 0.6), extend="both")
cp2 = src.utils.plot_cycle_hov(axs[2], diags_diff["dTdz"], **diff_kwargs)

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Early", "Late", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel
unit_label = r"$^{\circ}C$"
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[10, 14, 18], label=unit_label)
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[10, 14, 18], label=unit_label)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[0, 3, 6], label=unit_label)

plt.show()

##### Vertical velocity

In [None]:
## three stacked hövmöller panels of "w_int": early, late, difference
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## settings shared by the early and late panels
kwargs = dict(cmap="cmo.balance", levels=src.utils.make_cb_range(40, 4), extend="both")

## early and late
cp0, cp1 = (
    src.utils.plot_cycle_hov(ax_, d_["w_int"], **kwargs)
    for ax_, d_ in zip(axs[:2], (diags_early, diags_late))
)

## difference (half the amplitude of the early/late panels)
diff_kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(20, 4), extend="both"
)
cp2 = src.utils.plot_cycle_hov(axs[2], diags_diff["w_int"], **diff_kwargs)

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Early", "Late", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel
unit_label = r"$m/mo$"
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[-40, 0, 40], label=unit_label)
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[-40, 0, 40], label=unit_label)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-20, 0, 20], label=unit_label)

plt.show()

##### vertical temperature flux
\begin{align}
    \overline{w}/\overline{H}
\end{align}

##### Functions to use in computation

In [None]:
def get_w_sub(data):
    """Interpolate data["w"] onto the mixed layer depth (data["mld"]).

    The interpolation along "z_t" is done separately for each longitude and
    the results are concatenated back along a "longitude" dimension.
    """

    lons = data.longitude

    ## one interpolated slice per longitude
    w_at_mld = [
        data["w"].sel(longitude=lon).interp(z_t=data["mld"].sel(longitude=lon))
        for lon in lons
    ]

    return xr.concat(w_at_mld, dim=pd.Index(lons.values, name="longitude"))


def get_v_timescale(data):
    """Return the inverse timescale for mean upwelling (w/H).

    Uses w from `get_w_sub`, replaces every value that is not strictly
    positive with zero, and divides by data["mld"].
    """

    w_base = get_w_sub(data)

    ## keep upwelling only; everything else becomes 0
    upwelling = w_base.where(w_base > 0, other=0)

    return upwelling / data["mld"]

##### Compute

In [None]:
## vertical velocity at the base of the mixed layer, per period
w_sub_early, w_sub_late = get_w_sub(clim_early), get_w_sub(clim_late)

## (inverse) upwelling timescale, per period
t_inv_early, t_inv_late = get_v_timescale(clim_early), get_v_timescale(clim_late)

## absolute and percent change (late relative to early)
## NOTE(review): t_inv_early is exactly 0 wherever w was clipped, so the
## percent change is inf/NaN there
t_inv_diff = t_inv_late - t_inv_early
t_inv_diff_pct = (100 * t_inv_diff) / t_inv_early

##### Plot for single month

In [None]:
## reduction applied before plotting (here: mean over months)
sel = lambda x: x.mean("month")

fig, axs = plt.subplots(1, 3, figsize=(5, 1.5), layout="constrained")

## longitude for the x-axis
lon = clim_early.longitude

## (early, late) pairs, one per panel: MLD, w at MLD base, inverse timescale
panel_pairs = [
    (clim_early["mld"], clim_late["mld"]),
    (w_sub_early, w_sub_late),
    (t_inv_early, t_inv_late),
]
for ax_, (early_, late_) in zip(axs, panel_pairs):
    ax_.plot(lon, sel(early_))
    ax_.plot(lon, sel(late_))

## flip the y-axis of the MLD panel
axs[0].set_ylim(axs[0].get_ylim()[::-1])

plt.show()

Hovmoller version of plot

In [None]:
## three stacked hövmöller panels of the inverse timescale
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## settings shared by the early and late panels
kwargs = dict(cmap="cmo.amp", levels=np.arange(0, 1.65, 0.15), extend="max")

## early and late
cp0, cp1 = (
    src.utils.plot_cycle_hov(ax_, t_inv_, **kwargs)
    for ax_, t_inv_ in zip(axs[:2], (t_inv_early, t_inv_late))
)

## difference (diverging colormap)
diff_kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.75, 0.075),
    extend="both",
)
cp2 = src.utils.plot_cycle_hov(axs[2], t_inv_diff, **diff_kwargs)

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Early", "Late", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel
unit_label = r"month$^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[0, 1.5], label=unit_label)
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[0, 1.5], label=unit_label)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-0.75, 0.75], label=unit_label)

plt.show()

Look at temperature gradient more carefully

In [None]:
## dT/dz from get_dTdz_sub (inputs: temperature and MLD), per period
dTdz_early, dTdz_late = (
    get_dTdz_sub(Tsub=c_["T"], mld=c_["mld"]) for c_ in (clim_early, clim_late)
)

In [None]:
## three stacked hövmöller panels of dT/dz
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## settings shared by the early and late panels
kwargs = dict(cmap="cmo.amp", levels=np.arange(0, 0.3, 0.03), extend="max")

## early and late
cp0, cp1 = (
    src.utils.plot_cycle_hov(ax_, dTdz_, **kwargs)
    for ax_, dTdz_ in zip(axs[:2], (dTdz_early, dTdz_late))
)

## difference (late minus early, diverging colormap)
diff_kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.15, 0.015),
    extend="both",
)
cp2 = src.utils.plot_cycle_hov(axs[2], dTdz_late - dTdz_early, **diff_kwargs)

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Early", "Late", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## one colorbar per panel
unit_label = r"$K~m^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs[0], ticks=[0, 0.3], label=unit_label)
cb1 = fig.colorbar(cp1, ax=axs[1], ticks=[0, 0.3], label=unit_label)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-0.15, 0.15], label=unit_label)

plt.show()

#### Change over time

Compute indices

In [None]:
## compute the w and dT/dz indices from the forced fields
w_idx = reconstruct_fn(forced["w_comp"], forced["w"], get_w_int_idx)
dTdz_idx = reconstruct_fn(forced["T_comp"], forced["T"], get_dTdz_idx)

## combine into one dataset with variables "w" and "dTdz"
idxs = xr.merge([w_idx.rename("w"), dTdz_idx.rename("dTdz")])

## split time into (month, year), then smooth along "year" (n=15)
idxs = src.utils.get_rolling_avg(
    src.utils.unstack_month_and_year(idxs), n=15, dim="year"
)

## absolute change relative to the first year
delta_idxs = idxs - idxs.isel(year=0)

Hovmoller indices

In [None]:
## two side-by-side hövmöllers
fig, axs = plt.subplots(1, 2, figsize=(2, 3), layout="constrained")

## (variable, color amplitude, title) for each panel
panel_specs = [
    ("dTdz", 4, r"$\Delta \frac{\partial T}{\partial z}$"),
    ("w", 12, r"$\Delta w$"),
]

## plot the changes
for ax_, (name_, amp_, _) in zip(axs, panel_specs):
    plot_hov2(ax_, delta_idxs[name_].T, amp=amp_)

## titles
kwargs = dict(size=12)
for ax_, (_, _, title_) in zip(axs, panel_specs):
    ax_.set_title(title_, **kwargs)

plt.show()

### Change in variance

#### Equatorial region

In [None]:
## center years of the two periods being compared
period_centers = np.array([1868, 2082])

## variance along the equatorial strip (data reduced by eq_avg)
kwargs = dict(data=anom[["sst", "sst_comp"]], periods=period_centers, n=15, fn=eq_avg)
var_by_period_eq = get_var_for_periods(**kwargs)

## variance of the full field, split by month (first and last period)
kwargs = {
    **kwargs,
    "fn": None,
    "periods": period_centers[[0, -1]],
    "by_month": True,
}
var_by_period = get_var_for_periods(**kwargs)

#### Spatial (Hovmoller)

In [None]:
## first period is the baseline, last period is the future
sst_var_eq = var_by_period_eq["sst"]
baseline = sst_var_eq.isel(period=0)
future = sst_var_eq.isel(period=-1)
change = future - baseline

## settings shared by the baseline/future panels
plot_kwargs = dict(cmap="cmo.amp", extend="max")

## three stacked panels
fig, axs = plt.subplots(3, 1, figsize=(4, 6), layout="constrained")

## baseline and future use the same levels
cp0, cp1 = (
    src.utils.plot_cycle_hov(ax_, d_, levels=np.arange(0, 3.3, 0.3), **plot_kwargs)
    for ax_, d_ in zip(axs[:2], (baseline, future))
)

## change uses a diverging colormap
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    change,
    levels=src.utils.make_cb_range(1.5, 0.15),
    cmap="cmo.balance",
    extend="both",
)

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Baseline", "Future", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## colorbars (difference panel overrides the ticks)
kwargs = dict(ticks=[0, 1, 2, 3], label=r"$^{\circ}\text{C}^2$")
cb0 = fig.colorbar(cp0, ax=axs[0], **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], **{**kwargs, "ticks": [-1.5, 0, 1.5]})

plt.show()

#### Skew

In [None]:
def skew_helper(data):
    """Return the skewness of `data` over the pooled "time" and "member" dims."""

    ## pool time and ensemble member into a single "sample" dimension
    pooled = data.stack(sample=["time", "member"])

    ## "sample" is the core dim, so apply_ufunc moves it last (axis=-1)
    return xr.apply_ufunc(
        scipy.stats.skew,
        pooled,
        input_core_dims=[["sample"]],
        kwargs=dict(axis=-1),
    )


def skew_by_month(data):
    """Compute skewness by calendar month for the 2°S–2°N latitude mean."""

    def eq_mean(x):
        """Average over latitudes -2 to 2."""
        return x.sel(latitude=slice(-2, 2)).mean("latitude")

    ## evaluate the equatorial mean via the reconstruction wrapper
    data_eq = src.utils.reconstruct_wrapper(data=data, fn=eq_mean)

    ## skewness for each calendar month
    by_month = data_eq.groupby("time.month")
    return by_month.map(skew_helper)

Compute

In [None]:
## skewness of SST and SSH (with their components) for each period
SKEW_VARNAMES = ["sst", "sst_comp", "ssh", "ssh_comp"]
skew_early, skew_late = (
    skew_by_month(data=anom_[SKEW_VARNAMES]) for anom_ in (anom_early, anom_late)
)

Plot

In [None]:
## settings shared by every panel
plot_kwargs = dict(cmap="cmo.balance", extend="both")
skew_levels = src.utils.make_cb_range(1.5, 0.15)

## 3 rows (baseline / future / difference) x 2 columns (sst / ssh)
fig, axs = plt.subplots(3, 2, figsize=(6, 6), layout="constrained")

## late-minus-early skewness, computed once for both columns
skew_diff = skew_late - skew_early

for j, n in enumerate(["sst", "ssh"]):
    ## one column: baseline, future, difference
    cp0 = src.utils.plot_cycle_hov(
        axs[0, j], skew_early[n], levels=skew_levels, **plot_kwargs
    )
    cp1 = src.utils.plot_cycle_hov(
        axs[1, j], skew_late[n], levels=skew_levels, **plot_kwargs
    )
    cp2 = src.utils.plot_cycle_hov(
        axs[2, j], skew_diff[n], levels=skew_levels, **plot_kwargs
    )

    ## titles and x-axis of the bottom panel
    for row_, title_ in enumerate(("Baseline", "Future", "Difference")):
        axs[row_, j].set_title(title_)
    axs[-1, j].set_xlabel("Longitude")
    axs[-1, j].set_xticks([140, 190, 240])

## one colorbar per row (handles come from the last column plotted)
kwargs = dict(ticks=[-1.5, 0, 1.5])
cb0 = fig.colorbar(cp0, ax=axs[0], label=r"skew", **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], label=r"skew", **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], label=r"$\Delta$ skew", **kwargs)

## right-hand column: no y ticks or y label
for ax in axs[:, 1]:
    ax.set_yticks([])
    ax.set_ylabel(None)

## overlay the variance change (`change`) as contours on the SST difference panel
axs[2, 0].contour(
    change.longitude,
    change.month,
    change,
    colors="k",
    levels=src.utils.make_cb_range(3, 0.3),
    linewidths=1,
)

plt.show()

#### Spatial

Compute

In [None]:
## month to plot
month = 7

## three stacked map panels
fig = plt.figure(figsize=(6, 5), layout="constrained")
axs = src.utils.subplots_with_proj(
    fig, nrows=3, ncols=1, format_func=src.utils.plot_setup_pac
)

## SST variance for the chosen month; first and last periods are compared
sst_var_month = var_by_period["sst"].sel(month=month)
kwargs = dict(
    var0=sst_var_month.isel(period=0),
    var1=sst_var_month.isel(period=-1),
    amp=3.5,
    amp_diff=1,
    show_colorbars=True,
    cbar_label=r"$^{\circ}$C$^2$",
)
fig, axs = src.utils.make_variance_subplots(fig, axs, **kwargs)

## panel titles
panel_titles = ("Early (1853 – 1883)", "Late (2067 – 2097)", "Difference")
for row_, title_ in enumerate(panel_titles):
    axs[row_, 0].set_title(title_)

## outline the Niño 3.4 box on every panel
for ax in axs.flatten():
    src.utils.plot_nino34_box(ax, c="w")

plt.show()

#### Look at change in thermocline depth over time

##### First: change in MLD over time

In [None]:
## mixed layer depth index
## NOTE(review): evaluated with src.utils.get_nino3, although this was
## originally described as Niño 3.4 — confirm which region is intended
H = src.utils.reconstruct_fn(
    scores=forced["mld"],
    components=components["mld"],
    fn=src.utils.get_nino3,
)


def mean(x):
    """Centered 15-step rolling mean along "time" (15 years when applied per month)."""
    return x.rolling({"time": 15}, center=True).mean()


## smooth each calendar month separately, then trim the NaN edges
H = H.groupby("time.month").map(func=mean).sel(time=slice("1857", "2093"))


## alternative: fractional change of x itself
# diff = lambda x : (x - x.isel(time=0)) / x.isel(time=0)
def diff(x):
    """Fractional change of 1/x relative to its first time step."""
    inv_first = 1 / x.isel(time=0)
    return (1 / x - inv_first) / inv_first


delta_H = H.groupby("time.month").map(diff)

In [None]:
def sel_mon(x, m):
    """Keep only the time steps of `x` that fall in calendar month `m`."""
    return x.sel(time=x.time.dt.month == m)


## years for the x-axis (taken from the January samples)
yr = sel_mon(delta_H, 1).time.dt.year

## one line per season
fig, ax = plt.subplots(figsize=(3, 2.5))
for m_, label_ in [(12, "Dec"), (3, "Mar"), (6, "Jun"), (9, "Sep")]:
    ax.plot(yr, sel_mon(delta_H, m_), label=label_)

ax.legend(prop=dict(size=8))
plt.show()

Compute thermocline depth

In [None]:
def get_n34_lon_mean(x):
    """Mean over longitudes 190–240 (the "lon" coordinate)."""
    return x.sel(lon=slice(190, 240)).mean("lon")


## evaluate the longitude mean on the subsurface fields
Tw = src.utils.reconstruct_fn(
    components=components_sub,
    scores=forced_sub,
    fn=get_n34_lon_mean,
)

## 15-year mean for each calendar month (uses the notebook-level `mean`)
Tw = Tw.groupby("time.month").map(func=mean)

## trim NaN values, then drop the first 12 remaining time steps
Tw = Tw.sel(time=slice("1857", "2093")).isel(time=slice(12, None))

## vertical temperature gradient
Tw["dTdz"] = Tw["T"].differentiate("z_t")

## thermocline depth: the z_t level where dT/dz is most negative
idx_min_dTdz = Tw["dTdz"].argmin("z_t")
H1 = Tw.z_t.isel(z_t=idx_min_dTdz)

## change over time (uses the notebook-level `diff`)
delta_H1 = H1.groupby("time.month").map(diff)

In [None]:
## change in thermocline depth for Dec / Mar / Jun / Sep (x-axis is sample index)
fig, ax = plt.subplots(figsize=(3, 2.5))

for m_ in (12, 3, 6, 9):
    ax.plot(sel_mon(delta_H1, m_))

plt.show()

In [None]:
## month to plot
month = 12

## mean profiles over the first and last 30 samples of that month
Tw_month = sel_mon(Tw, month)
early = Tw_month.isel(time=slice(None, 30)).mean("time")
late = Tw_month.isel(time=slice(-30, None)).mean("time")

## left panel: temperature; right panel: vertical gradient
fig, axs = plt.subplots(1, 2, figsize=(4, 4))

for ax, name_ in zip(axs, ("T", "dTdz")):
    ax.plot(early[name_], Tw.z_t)
    ax.plot(late[name_], Tw.z_t)

    ## flip the y-axis (z_t)
    ax.set_ylim(ax.get_ylim()[::-1])
plt.show()

In [None]:
## December temperature profiles: first vs. last 30 samples
fig, ax = plt.subplots(figsize=(2, 4))

T_dec = sel_mon(Tw["T"], 12)
for window_ in (slice(None, 30), slice(-30, None)):
    ax.plot(T_dec.isel(time=window_).mean("time"), Tw.z_t)

## flip the y-axis (z_t)
ax.set_ylim(ax.get_ylim()[::-1])
plt.show()

### Change in Bjerknes coupling

#### Hövmöller

In [None]:
def make_scatter(ax, data, x_var, y_var, fn_x, fn_y, scale=1):
    """Scatter fn_y(y_var) against fn_x(x_var) on `ax` with a best-fit line.

    If `data` has a "mode" dimension, the functions are evaluated through
    ``src.utils.reconstruct_fn`` (using the "<var>_comp" variables) and the
    "member"/"time" dims are pooled; otherwise they are applied directly and
    the regression runs over "time". The y-values are multiplied by `scale`.

    Returns:
        The regression slope as a Python scalar.
    """

    ## pick how to evaluate the functions and which dim to regress over
    if "mode" in data.dims:

        def evaluate(var, fn):
            recon = src.utils.reconstruct_fn(
                scores=data[var], components=data[f"{var}_comp"], fn=fn
            )
            return recon.stack(sample=["member", "time"])

        dim = "sample"

    else:

        def evaluate(var, fn):
            return fn(data[var])

        dim = "time"

    x_vals = evaluate(x_var, fn_x)
    y_vals = scale * evaluate(y_var, fn_y)

    ## slope of the best-fit line, as a plain scalar
    slope = src.utils.regress_core(X=x_vals, Y=y_vals, dim=dim).values.item()

    ## data points
    ax.scatter(x_vals, y_vals, s=0.5)

    ## best-fit line through the origin, spanning the range of x
    x_np = x_vals.values
    xtest = np.linspace(x_np.min(), x_np.max())
    ax.plot(xtest, slope * xtest, c="k", lw=1)

    ## zero guidelines
    guide_kwargs = dict(ls="--", lw=0.8, c="k")
    ax.axhline(0, **guide_kwargs)
    ax.axvline(0, **guide_kwargs)

    return slope


def get_alpha(data, dim):
    """Compute alpha: linear dependence of nhf on sst (regression along `dim`)."""

    predictor, response = data["sst"], data["nhf"]
    return src.utils.regress_core(X=predictor, Y=response, dim=dim)


def get_mu(data, dim):
    """Compute mu: linear dependence of taux on sst (regression along `dim`)."""

    predictor, response = data["sst"], data["taux"]
    return src.utils.regress_core(X=predictor, Y=response, dim=dim)


def get_beta(data, dim):
    """Compute beta: linear dependence of ssh on taux (regression along `dim`)."""

    predictor, response = data["taux"], data["ssh"]
    return src.utils.regress_core(X=predictor, Y=response, dim=dim)


def get_xi(data, dim):
    """Compute xi: linear dependence of sst on ssh (regression along `dim`)."""

    predictor, response = data["ssh"], data["sst"]
    return src.utils.regress_core(X=predictor, Y=response, dim=dim)


def get_params(data, dim="time"):
    """Compute alpha, mu, beta, xi and their derived "coupling" term.

    Returns a merged dataset with one variable per parameter, plus
    "coupling" = mu * beta * xi.
    """

    ## parameter name -> function that computes it
    estimators = dict(alpha=get_alpha, mu=get_mu, beta=get_beta, xi=get_xi)

    params = xr.merge(
        [fn(data, dim=dim).rename(name) for name, fn in estimators.items()]
    )

    params["coupling"] = params["mu"] * params["beta"] * params["xi"]

    return params


def get_rolling_params(data, n=10, reduce_ensemble=True):
    """Compute the parameters in centered rolling windows of 2n+1 time steps.

    With `reduce_ensemble=True` the "member" and "window" dims are pooled into
    one "sample" dim; otherwise only the window serves as "sample", so
    "member" is kept as a separate dim.
    """

    ## expand each rolling window along a new "window" dimension
    windows = data.rolling({"time": 2 * n + 1}, center=True).construct("window")

    ## build the "sample" dimension used for the regressions
    if reduce_ensemble:
        samples = windows.stack(sample=["member", "window"])
    else:
        samples = windows.rename({"window": "sample"})

    return get_params(samples, dim="sample")


def get_rolling_params_bymonth(data, **kwargs):
    """Apply `get_rolling_params` to each calendar month separately.

    Extra keyword arguments are forwarded to `get_rolling_params`.
    """

    monthly_groups = data.groupby("time.month")
    return monthly_groups.map(get_rolling_params, **kwargs)


def get_fractional_change(data, dim="year"):
    """Return (data - baseline) / baseline along `dim`.

    The baseline is the mean of the first 30 entries along `dim`.
    """

    baseline = data.isel({dim: slice(None, 30)}).mean(dim)
    departure = data - baseline

    return departure / baseline

In [None]:
## split `anom`; keep the first output as `components`, discard the second
components, _ = src.utils.split_components(anom)

In [None]:
## variables that enter the coupling parameters
varnames = ["sst", "nhf", "ssh", "taux"]
kwargs = dict(scores=anom[varnames], components=components[varnames])

## evaluate the Niño 3.4 and Niño 4 functions on all of them
nino34, nino4 = (
    src.utils.reconstruct_fn(fn=fn_, **kwargs)
    for fn_ in (src.utils.get_nino34, src.utils.get_nino4)
)

## keep sst/nhf/ssh from Niño 3.4 and taux from Niño 4
idxs = xr.merge([nino34[["sst", "nhf", "ssh"]], nino4["taux"]])

In [None]:
## rolling parameters per calendar month, with time split into (month, year)
params = src.utils.unstack_month_and_year(get_rolling_params_bymonth(idxs, n=16))

## change relative to the mean of the first 30 years
baseline_params = params.isel(year=slice(None, 30)).mean("year")
delta_params = params - baseline_params

In [None]:
## one hövmöller per parameter
fig, axs = plt.subplots(1, 5, figsize=(9, 4), layout="constrained")

## (parameter, color amplitude, colorbar label) for each panel
panel_specs = [
    ("alpha", 20, r"$\Delta~ \alpha$"),
    ("mu", 0.01, r"$\tau_x-\text{SST}$"),
    ("beta", 200, r"$\text{SSH}-\tau_x$"),
    ("xi", 7e-2, r"$\text{SST}-\text{SSH}$"),
    ("coupling", 0.5, r"coupling"),
]
for ax_, (name_, amp_, label_) in zip(axs, panel_specs):
    plot_hov2(ax_, delta_params[name_].T, amp=amp_, label=label_)

## mark month 7 on every panel
for ax in axs:
    ax.axvline(7, c="k", lw=1, ls="--")

## y-axis labelling
axs[0].set_yticks(np.linspace(1870, 2082, 5))
axs[0].set_ylabel("Year")
axs[1].set_ylim(axs[0].get_ylim())
plt.show()

#### preprocessing

In [None]:
def merimean(x):
    """Average over latitudes -5 to 5, restricted to longitudes 140–285."""
    region = x.sel(longitude=slice(140, 285), latitude=slice(-5, 5))
    return region.mean("latitude")


def plot_cycle_hov(ax, data, amp, is_filled=True, xticks=(190, 240)):
    """Plot a longitude-vs-month hövmöller of `data` on `ax`.

    Args:
        ax: matplotlib axes to draw on.
        data: array with "longitude" and "month" coordinates. If it also has
            a "latitude" coordinate it is first reduced with `merimean`.
        amp: amplitude of the contour levels (step is amp / 5).
        is_filled: filled contours ("cmo.balance") if True, otherwise thin
            black contour lines.
        xticks: longitudes used as x ticks; a white dashed vertical line is
            drawn at each one.

    Returns:
        The contour set returned by contourf/contour.
    """

    ## level settings shared by both plot types
    shared_kwargs = dict(levels=src.utils.make_cb_range(amp, amp / 5), extend="both")

    ## plot-type specific settings
    if is_filled:
        plot_fn = ax.contourf
        style_kwargs = dict(cmap="cmo.balance")
    else:
        plot_fn = ax.contour
        style_kwargs = dict(colors="k", linewidths=0.8)

    ## average over latitudes (if necessary)
    plot_data = merimean(data) if "latitude" in data.coords else data

    ## do the plotting
    cp = plot_fn(
        plot_data.longitude,
        plot_data.month,
        plot_data,
        **style_kwargs,
        **shared_kwargs,
    )

    ## format ax object
    ax.set_xlim([145, 280])
    ax.set_xlabel("Lon")
    ax.set_xticks(list(xticks))
    for tick in xticks:
        ax.axvline(tick, c="w", ls="--", lw=1)

    return cp

Add EOF info to dataset

In [None]:
## copy each component field into `anom` as "<name>_comp" (if not already there)
for v in list(components):
    comp_name = f"{v}_comp"
    if comp_name not in anom.data_vars:
        anom[comp_name] = components[v]

## indices to add: (index names, functions, source variable)
index_groups = [
    (
        ["nino3", "nino34", "nino4"],
        [src.utils.get_nino3, src.utils.get_nino34, src.utils.get_nino4],
        "sst",
    ),
    (
        ["h_w", "h"],
        [src.utils.get_RO_hw, src.utils.get_RO_h],
        "ssh",
    ),
]

## add T indices (from sst) and h indices (from ssh), skipping existing ones
for names, fns, src_var in index_groups:
    for n, fn in zip(names, fns):
        if n not in anom.data_vars:
            anom[n] = src.utils.reconstruct_fn(
                scores=anom[src_var], components=anom[f"{src_var}_comp"], fn=fn
            )

In [None]:
### To-do: compute Tsub...
# get_dT_sub(Tsub =

Partition into early/late

In [None]:
## early period: 360 time steps, skipping the first 12
## (30 years if the data are monthly — TODO confirm)
anom_early = anom.isel(time=slice(12, 372))
## late period: 360 time steps, skipping the last 12
anom_late = anom.isel(time=slice(-372, -12))

Remove linear dependence on SST and get tendencies

In [None]:
def prep(data):
    """Remove the SST dependence from the h fields and compute tendencies.

    Adds "ssh_hat" and "ssh_hat_comp" (only when "ssh" is present) and
    "h_w_hat" / "h_hat" to `data` in place, then returns the output of
    ``src.utils.get_ddt``.
    """

    remove_dep = src.utils.remove_sst_dependence_v2

    ## SSH field with the Niño 3.4 dependence removed; components are reused
    if "ssh" in data.data_vars:
        data["ssh_hat"] = remove_dep(data, h_var="ssh", T_var="nino34")
        data["ssh_hat_comp"] = data["ssh_comp"]

    ## same for the h indices
    for h_idx in ("h_w", "h"):
        data[f"{h_idx}_hat"] = remove_dep(data, h_var=h_idx, T_var="nino34")

    ## tendencies
    return src.utils.get_ddt(data)

In [None]:
## preprocess both periods (early first, then late)
anom_early, anom_late = prep(anom_early), prep(anom_late)

#### $R$ 
$\frac{d T}{dt}$ vs $T_{34}$

In [None]:
## regress ddt_sst on nino34 and h_w_hat by month; keep the nino34 coefficient
kwargs = dict(x_vars=["nino34", "h_w_hat"], y_var="ddt_sst")

m_early, m_late = (
    src.utils.multi_regress_bymonth(anom_, **kwargs)["nino34"]
    for anom_ in (anom_early, anom_late)
)

In [None]:
## month to plot
sel_mon = lambda x: x.sel(month=7)

## three stacked map panels: early, late, difference
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## all three panels share the same amplitude and month selection
contour_kwargs = dict(amp=5, sel=sel_mon)
cp0, cp1, cp2 = (
    src.utils.make_contour_plot(ax_, field_, **contour_kwargs)
    for ax_, field_ in zip(axs[:, 0], (m_early, m_late, m_late - m_early))
)

## single colorbar for all panels
lab = r"$K ~\text{yr}^{-1}~ \left(T_{34}\right)^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs, ticks=[-5, 0, 5], label=lab)

## Niño 4 box
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_nino4_box(ax, **box_kwargs)

plt.show()

Scatter plot of Niño 4 vs. $\frac{d}{dt}\left(\text{Niño 4}\right)$ in July

In [None]:
## get data for july
def get_month(data, month):
    """Select one calendar month from `data` and pool members and time.

    Args:
        data: object with "time" and "member" dimensions.
        month: calendar month number to keep.

    Returns:
        The selected data with "member" and "time" stacked into "sample".
    """
    ## build the mask from `data`'s own time axis (not from a global dataset),
    ## so the function works for any period
    is_month = data.time.dt.month == month
    return data.sel(time=is_month).stack(sample=["member", "time"])


## July samples for each period
jul_early = get_month(anom_early, 7)
jul_late = get_month(anom_late, 7)


## one scatter panel per period: nino4 vs. its tendency
fig, axs = plt.subplots(1, 2, figsize=(5, 2.4))

panels = zip(axs, (jul_early, jul_late), ("Early", "Late"))
for ax, jul_, title_ in panels:
    ax.scatter(jul_["nino4"], jul_["ddt_nino4"], s=2, alpha=0.5)

    ## common limits and zero guidelines
    ax.set_xlim([-3, 3])
    ax.set_ylim([-25, 15])
    ax.axhline(0, ls="--", c="k", lw=1)
    ax.axvline(0, ls="--", c="k", lw=1)
    ax.set_title(title_)

## y ticks only on the left panel
axs[1].set_yticks([])

plt.show()

In [None]:
## meridional mean over latitudes -5 to 5
## (rebinds the notebook-level name `merimean`)
merimean = lambda x: x.sel(latitude=slice(-5, 5)).mean("latitude")

## make hövmöllers
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## kwargs
kwargs = dict(
    cmap="cmo.balance", levels=src.utils.make_cb_range(7.5, 0.75), extend="both"
)

## the plotted field is the ddt_sst regression on Niño 3.4, i.e. K/yr per K
## (the label previously carried the SSH units, m/yr per K)
unit_label = r"$K~yr^{-1}~K^{-1}$"
cb_kwargs = dict(ticks=[-5, 0, 5], label=unit_label)

## plot early
cp0 = src.utils.plot_cycle_hov(axs[0], merimean(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## plot late
cp1 = src.utils.plot_cycle_hov(axs[1], merimean(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## plot difference
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    merimean(m_late - m_early),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(7.5, 0.75),
    extend="both",
)

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-7.5, 0, 7.5], label=unit_label)

## label
axs[0].set_title("Early")
axs[1].set_title("Late")
axs[2].set_title("Difference")
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

## guidelines at months 6 and 7 (style defined once, outside the loop)
line_kwargs = dict(ls="--", c="k", lw=1, alpha=0.5)
for ax in axs:
    ax.axhline(6, **line_kwargs)
    ax.axhline(7, **line_kwargs)

plt.show()

#### $F_2$ 
$\frac{d h}{dt}$ vs $T_{34}$

In [None]:
## regress ddt_ssh_hat on nino34 and h_w_hat by month; keep the nino34 coefficient
kwargs = dict(x_vars=["nino34", "h_w_hat"], y_var="ddt_ssh_hat")

m_early, m_late = (
    src.utils.multi_regress_bymonth(anom_, **kwargs)["nino34"]
    for anom_ in (anom_early, anom_late)
)

Spatial

In [None]:
## month range to average over (here a single month: 5)
sel_mon = lambda x: x.sel(month=slice(5, 5)).mean("month")

## three stacked map panels: early, late, difference
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=20)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## all three panels share the same amplitude and month selection
contour_kwargs = dict(amp=50, sel=sel_mon)
cp0, cp1, cp2 = (
    src.utils.make_contour_plot(ax_, field_, **contour_kwargs)
    for ax_, field_ in zip(axs[:, 0], (m_early, m_late, m_late - m_early))
)

## single colorbar for all panels
lab = r"$m ~\text{yr}^{-1}~ \left(T_{34}\right)^{-1}$"
cb0 = fig.colorbar(cp0, ax=axs, ticks=[-50, 0, 50], label=lab)

## h_w box (120–210 longitude, 5S–5N)
box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
for ax in axs.flatten():
    src.utils.plot_box(ax, lons=[120, 210], lats=[-5, 5], **box_kwargs)

plt.show()

Hovmoller

In [None]:
## meridional mean over latitudes -5 to 5
merimean = lambda x: x.sel(latitude=slice(-5, 5)).mean("latitude")

## three stacked hövmöller panels: early, late, difference
fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

## settings shared by the early and late panels
kwargs = dict(cmap="cmo.balance", levels=src.utils.make_cb_range(40, 4), extend="both")
cb_kwargs = dict(ticks=[-40, 0, 40], label=r"$m~yr^{-1}~K^{-1}$")

## early
cp0 = src.utils.plot_cycle_hov(axs[0], merimean(m_early), **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## late
cp1 = src.utils.plot_cycle_hov(axs[1], merimean(m_late), **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## difference (own level range)
diff_kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(32, 3.2),
    extend="both",
)
cp2 = src.utils.plot_cycle_hov(axs[2], merimean(m_late - m_early), **diff_kwargs)

# ## contour MLD changes
# axs[2].contour(
#     mld_clim_diff_pct2.longitude,
#     mld_clim_diff_pct2.month,
#     mld_clim_diff_pct2.T,
#     levels=src.utils.make_cb_range(100,25),
#     colors="k",
#     linewidths=.5,
#     alpha=.5,
# )

cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-32, 0, 32], label=r"$m~yr^{-1}~K^{-1}$")

## titles and x-axis of the bottom panel
for ax_, title_ in zip(axs, ("Early", "Late", "Difference")):
    ax_.set_title(title_)
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

plt.show()

In [None]:
def scatter_data(ax, data):
    """Scatter `ddt_h_w_hat` against `nino34` and overlay a fitted line.

    Args:
        ax: matplotlib axes to draw on.
        data: object indexable by "nino34" and "ddt_h_w_hat"; the regression
            is taken over the "sample" dimension.

    Returns:
        None. The axes are modified in place.
    """

    ## predictor (horizontal axis) and response (vertical axis)
    nino34 = data["nino34"]
    tendency = data["ddt_h_w_hat"]

    ## regression coefficient of the response on the predictor, as a plain float
    slope = src.utils.regress_core(Y=tendency, X=nino34, dim="sample").values.item()

    ## raw samples
    ax.scatter(nino34, tendency, s=2, alpha=0.5)

    ## fitted line over the observed predictor range; drawn as slope * x only
    ## (no intercept term is added)
    lo = nino34.min().values.item()
    hi = nino34.max().values.item()
    grid = np.linspace(lo, hi)
    fit_label = f"{slope:.1f}" + r" $m~yr^{-1}~K^{-1}$"
    ax.plot(grid, slope * grid, c="k", lw=1, label=fit_label)

    ## fixed axis limits and dashed zero lines
    zero_line = dict(ls="--", c="k", lw=1)
    ax.set_xlim([-3, 3])
    ax.set_ylim([-80, 80])
    ax.axhline(0, **zero_line)
    ax.axvline(0, **zero_line)

    ## labels: y ticks are cleared here (callers may re-add them on chosen panels)
    ax.set_xlabel("Niño 3.4")
    ax.set_yticks([])
    ax.legend(prop={"size": 8})

In [None]:
## month to plot
month = 5

## pull out that month from the early and late anomaly sets
may_early, may_late = (
    get_month(anoms, month) for anoms in (anom_early, anom_late)
)

## side-by-side panels: early on the left, late on the right
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(5, 2.4), layout="constrained")

for panel, samples, panel_title in zip(
    axs, (may_early, may_late), ("Early", "Late")
):
    scatter_data(panel, samples)
    panel.set_title(panel_title)

## only the left panel carries y ticks and the y label
left_panel = axs[0]
left_panel.set_yticks([-60, 0, 60])
left_panel.set_ylabel("Recharge rate ($m~yr^{-1}$)", size=10)

plt.show()

## Change in zonal gradient

### Compute zonal gradient

In [None]:
def get_zonal_grad(x):
    """
    Return the zonal gradient of `x` as an east-minus-west box difference.
    Ref: Fig 7 in Maher et al, 2023.

    Both boxes use the latitude slice -5 to 5; the eastern box uses the
    longitude slice 210-270 and the western box 120-180. Each box is reduced
    with `src.utils.spatial_avg` before differencing.
    """

    lat_band = slice(-5, 5)

    def box_mean(lon_band):
        """Spatial average of `x` inside the given longitude band."""
        box = x.sel(longitude=lon_band, latitude=lat_band)
        return src.utils.spatial_avg(box)

    ## eastern box first, then western box
    east_mean = box_mean(slice(210, 270))
    west_mean = box_mean(slice(120, 180))

    return east_mean - west_mean

In [None]:
## zonal gradient, obtained by passing `get_zonal_grad` to the reconstruction helper
zonal_grad_forced = src.utils.reconstruct_fn(components, forced, fn=get_zonal_grad)

## split the time axis into separate year and month coordinates
zonal_grad_forced_bymonth = src.utils.unstack_month_and_year(zonal_grad_forced)

## baseline climatology: mean over the first 30 entries along "year"
baseline_years = zonal_grad_forced_bymonth.isel(year=slice(30))
clim = baseline_years.mean("year")

## change relative to that baseline
zonal_grad_change = zonal_grad_forced_bymonth - clim

In [None]:
## one narrow panel: month on the horizontal axis, year on the vertical axis
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## filled contours of the "sst" change
kwargs = {"levels": src.utils.make_cb_range(2, 0.2), "cmap": "cmo.balance"}
plot_data = ax.contourf(
    zonal_grad_change.month,
    zonal_grad_change.year,
    zonal_grad_change["sst"],
    **kwargs,
)

## axis labels, ticks and title
ax.set_xlabel("Month")
ax.set_xticks([1, 6, 12])
ax.set_ylabel("Year")
ax.set_title(r"$\Delta \left(\partial_x T\right)$")

## start the year axis at 1950, leaving the upper limit unchanged
ax.set_ylim(bottom=1950)

plt.show()