# Look at climate change in MPI

## Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import matplotlib as mpl
import numpy as np
import scipy.stats
import seaborn as sns
import xarray as xr
import warnings
import tqdm
import pathlib
import cmocean
import os
import calendar

# Import custom modules
import src.utils
from src.XRO import XRO, xcorr

## set plotting specs
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Functions

In [None]:
def plot_hov(ax, data, amp, label=None):
    """
    Plot a Hovmöller diagram (longitude vs. year) on a single axis.

    Args:
    - ax: matplotlib axis to draw on
    - data: array with 'longitude' and 'year' coordinates; it is transposed
        before being passed to contourf
    - amp: amplitude of the symmetric color range; contour levels come from
        src.utils.make_cb_range(amp, amp / 10) and colorbar ticks sit at
        [-amp, 0, amp]
    - label: optional label for the colorbar

    Only `ax` (and the figure that owns it) is modified; the function does
    not rely on module-level `fig` / `axs` variables.
    """

    ## filled contours on a diverging color scale
    plot_data = ax.contourf(
        data.longitude,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="both",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## horizontal colorbar attached to this axis, created via the axis' own figure
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## label this axis only (callers invoke plot_hov once per axis)
    kwargs = dict(ls="--", c="w", lw=0.8)
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])
    ax.axvline(190, **kwargs)
    ax.axvline(240, **kwargs)
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def plot_hov2(ax, data, amp, label=None):
    """
    Plot a Hovmöller-style diagram of calendar month vs. year on a single axis.

    Args:
    - ax: matplotlib axis to draw on
    - data: array with 'month' and 'year' coordinates; it is transposed
        before being passed to contourf
    - amp: amplitude of the color range; contour levels come from
        src.utils.make_cb_range(amp, amp / 10) and colorbar ticks sit at
        [-amp, 0, amp]
    - label: optional label for the colorbar

    Only `ax` (and the figure that owns it) is modified; the function does
    not rely on module-level `fig` / `axs` variables.
    """

    ## filled contours; only the upper end of the color scale is extended
    plot_data = ax.contourf(
        data.month,
        data.year,
        data.T,
        cmap="cmo.balance",
        extend="max",
        levels=src.utils.make_cb_range(amp, amp / 10),
    )

    ## horizontal colorbar attached to this axis, created via the axis' own figure
    ax.figure.colorbar(
        plot_data,
        ax=ax,
        orientation="horizontal",
        ticks=[-amp, 0, amp],
        label=label,
    )

    ## label this axis only (callers invoke plot_hov2 once per axis)
    ax.set_xlabel("Month")
    ax.set_xticks([1, 12])
    ax.set_yticks([])
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")

    return


def get_rolling_var(data, n=10):
    """
    Rolling variance of `data`.

    Hands np.var and the half-width `n` to src.utils.get_rolling_fn_bymonth.
    Per the original author's description, the variance is taken over time
    and ensemble member within a window of 2n+1 years centered on each year,
    which enlarges the sample used for every estimate.
    """

    rolling_kwargs = dict(fn=np.var, n=n)
    return src.utils.get_rolling_fn_bymonth(data, **rolling_kwargs)


def get_empirical_pdf(x, bin_edges=None):
    """
    Empirical probability density of the samples `x`, as a normalized
    histogram (the bin heights times the bin widths sum to 1).

    Args:
    - x: samples to bin
    - bin_edges: optional array of bin edges; if None, numpy's default
        binning is used

    Returns (pdf, bin_edges): the normalized bin heights and the edges
    that were used.
    """

    ## raw counts: only let numpy pick the bins when no edges were supplied
    if bin_edges is not None:
        counts, _ = np.histogram(x, bins=bin_edges)
    else:
        counts, bin_edges = np.histogram(x)

    ## area under the un-normalized histogram
    upper_edges, lower_edges = bin_edges[1:], bin_edges[:-1]
    widths = upper_edges - lower_edges
    total_area = (counts * widths).sum()

    ## dividing by the area makes the histogram integrate to 1
    return counts / total_area, bin_edges

## Change in $T$, $h$

### Load data
And compute variance/skewness

In [None]:
## open the MPI T/h dataset
mpi_load_fp = DATA_FP / "mpi_Th" / "Th.nc"
Th = xr.open_dataset(mpi_load_fp)

## rolling variance and skewness (same window half-width, n=15)
Th_var = get_rolling_var(Th, n=15)
Th_skew = src.utils.get_rolling_fn_bymonth(Th, fn=scipy.stats.skew, n=15)

## split the time axis into separate month / year dimensions
Th_var_bymonth = src.utils.unstack_month_and_year(Th_var)
Th_skew_bymonth = src.utils.unstack_month_and_year(Th_skew)

## % change in variance, relative to the mean over the first 30 entries
## along `year`
baseline = Th_var_bymonth.isel({"year": slice(None, 30)}).mean("year")
Th_var_bymonth_pct = 100 * (Th_var_bymonth - baseline) / baseline

### Variance

Hovmöller

In [None]:
## setup plot: one panel per variable (T_34 on the left, h on the right)
fig, axs = plt.subplots(1, 2, figsize=(3.6, 4), layout="constrained")

## plot % change in variance as a function of month (x) and year (y)
plot_hov2(axs[0], Th_var_bymonth_pct["T_34"].T, amp=100, label="% change (SST)")
plot_hov2(axs[1], Th_var_bymonth_pct["h"].T, amp=100, label="% change (SSH)")

## label
## NOTE: despite its name, `xticks` holds year positions that are applied to
## the y-axis here; the next cell reuses it for the x-axis of a line plot
xticks = np.linspace(1870, 2082, 5)
axs[0].set_yticks(xticks)
axs[0].set_ylabel("Year")
axs[1].set_ylim(axs[0].get_ylim())
for ax in axs:
    ## mark month 8 with a dashed white line and add it to the month ticks
    ax.axvline(8, ls="--", c="w", alpha=1, lw=0.5)
    ax.set_xticks([1, 8, 12])

plt.show()

Compare August and December variance over time

In [None]:
fig, ax = plt.subplots(figsize=(2.5, 2))

## plot data: T_34 variance over time for August (month 8) and
## December (month 12)
ax.plot(
    Th_var_bymonth.year, Th_var_bymonth["T_34"].sel(month=8), label="Aug", c="k", ls="-"
)
ax.plot(
    Th_var_bymonth.year,
    Th_var_bymonth["T_34"].sel(month=12),  # was month=11 (Nov) under a "Dec" label
    label="Dec",
    c="gray",
    ls="--",
)

## label and style
ax.set_title(r"$\sigma^2\left(\text{Niño 3.4}\right)$")
ax.set_ylim([0, None])
## `xticks` (year positions) is defined in the preceding Hovmöller cell;
## only the first and last year are labelled here
ax.set_xticks(xticks[[0, -1]])
ax.legend(prop=dict(size=8))
ax.set_ylabel(r"$^{\circ}\text{C}^2$")

plt.show()

### Skewness

In [None]:
## setup plot: one panel per variable (T_34 on the left, h on the right)
fig, axs = plt.subplots(1, 2, figsize=(3.6, 4), layout="constrained")

## plot rolling skewness as a function of month (x) and year (y)
plot_hov2(axs[0], Th_skew_bymonth["T_34"].T, amp=0.6, label="SST")
plot_hov2(axs[1], Th_skew_bymonth["h"].T, amp=0.6, label="SSH")

## label: year ticks on the left panel only; right panel shares its y-range
axs[0].set_yticks(np.linspace(1870, 2082, 5))
axs[0].set_ylabel("Year")
axs[1].set_ylim(axs[0].get_ylim())
plt.show()

### Seasonality

## Change in spatial patterns

In [None]:
## specify sliding window size for climatology
## (passed as `n` to src.utils.separate_forced below)
n_years = 3

## Load EOF data (one EOF object per variable: ts, ssh, taux)
eofs_fp = pathlib.Path(DATA_FP, "mpi", "eofs300")
eofs_sst = src.utils.load_eofs(eofs_fp / "ts.nc")
eofs_ssh = src.utils.load_eofs(eofs_fp / "ssh.nc")
eofs_taux = src.utils.load_eofs(eofs_fp / "taux.nc")

## for convenience, put components and scores into datasets
## (variables are renamed to "sst", "ssh", "taux" so both datasets share keys)
components = xr.merge(
    [
        eofs_sst.components().rename("sst"),
        eofs_ssh.components().rename("ssh"),
        eofs_taux.components().rename("taux"),
    ]
)
scores = xr.merge(
    [
        eofs_sst.scores().rename("sst"),
        eofs_ssh.scores().rename("ssh"),
        eofs_taux.scores().rename("taux"),
    ]
)

## get forced/anomalous component (both are computed from the EOF scores)
forced, anom = src.utils.separate_forced(scores, n=n_years)

### Reconstruct equatorial strip

In [None]:
## get rolling avg
forced_rolling = src.utils.get_rolling_avg(forced, n=1)

## get NDJ season
## (keeps the December time steps of the rolling average; presumably n=1
## gives a 3-month window centered on December — confirm in src.utils)
month = forced_rolling.time.dt.month
forced_ndj = forced_rolling.isel(time=(month == 12))

## reconstruct equatorial average
## eq_avg: latitude mean over the 2S-2N, 125E-279E box; longitude is kept.
## It is also reused later by the (optional) variance computation.
eq_avg = lambda x: x.sel(latitude=slice(-2, 2), longitude=slice(125, 279)).mean(
    "latitude"
)
kwargs = dict(scores=forced_ndj, components=components, fn=eq_avg)
forced_ndj_recon = src.utils.reconstruct_fn(**kwargs)

## replace time axis with year for easier plotting
forced_ndj_recon["time"] = forced_ndj_recon["time"].dt.year
forced_ndj_recon = forced_ndj_recon.rename({"time": "year"})

## subtract leading 30-year mean
## (mean over the first 30 entries along `year`)
clim = forced_ndj_recon.isel(year=slice(None, 30)).mean("year")
forced_anom = forced_ndj_recon - clim

### Hovmöller plot

In [None]:
## setup plot: panels 0-1 show the total change (sst, ssh),
## panels 2-3 show the departure from the zonal mean (sst, ssh)
fig, axs = plt.subplots(1, 4, figsize=(7.2, 4), layout="constrained")

## each variable gets its own color amplitude for the total / difference panels
for j, (varname, total_amp, diff_amp) in enumerate(
    zip(["sst", "ssh"], [3, 8], [0.5, 2])
):

    ## get total plot and difference
    ## (`diff` removes the mean over longitude for each year)
    total = forced_anom[varname]
    diff = total - total.mean("longitude")

    ## plot data
    plot_hov(axs[j], total, amp=total_amp, label=f"{varname} (total)")
    plot_hov(axs[j + 2], diff, amp=diff_amp, label=f"{varname} (diff)")

## year ticks/label on the left-most panel only
axs[0].set_yticks(np.arange(1860, 2100, 55))
axs[0].set_ylabel("Year")
plt.show()

### Line plots

Function to do formatting. Not sure about units for SSH: data attribute says "m", but data itself suggests "cm"...

In [None]:
def format_line_plots(axs):
    """
    Format the 2x3 grid of line plots.

    `axs` is indexed as axs[row, col]. Per the labels set here, row 0 is SST
    and row 1 is SSH; the columns are the zonal mean (vs. year), the zonal
    anomaly (vs. longitude) and the late-minus-early change (vs. longitude).
    Sets ticks, labels and titles, adds a zero line and two shaded longitude
    bands to the change panels, rescales their y-range and adds legends.
    """

    ## y-ticks for every panel, keyed by (row, col)
    ytick_table = {
        (0, 0): [25.5, 28.5],
        (1, 0): [91, 84],
        (0, 1): [-3, 3],
        (1, 1): [-24, 24],
        (0, 2): [-0.5, 0.5],
        (1, 2): [-4, 4],
    }
    for (row, col), ticks in ytick_table.items():
        axs[row, col].set_yticks(ticks)

    ## x-axis: no ticks on the top row; years in the first bottom panel,
    ## longitudes in the remaining bottom panels
    for top_ax in axs[0, :]:
        top_ax.set_xticks([])
    axs[1, 0].set_xticks([1860, 1970, 2080])
    axs[1, 0].set_xlabel("Year")
    for lon_ax in axs[1, 1:]:
        lon_ax.set_xticks([140, 210, 280])
        lon_ax.set_xlabel("Longitude")

    ## row labels (first column) and column titles (top row)
    axs[0, 0].set_ylabel(r"SST ($^{\circ}$C)")
    axs[1, 0].set_ylabel(r"SSH (cm?)")
    column_titles = ["Zonal mean", "Zonal anomaly", "Change (late minus early)"]
    for col, title in enumerate(column_titles):
        axs[0, col].set_title(title)

    ## change panels: dashed zero line plus two shaded longitude bands
    zero_line_style = dict(ls="--", lw=0.5, c="k")
    for change_ax in axs[:, 2]:
        change_ax.axhline(0, **zero_line_style)
        change_ax.axvspan(160, 210, color="b", alpha=0.15, label="4")
        change_ax.axvspan(190, 240, color="r", alpha=0.15, label="3.4")

    ## scale axes: change panels span 1/6 of the zonal-anomaly y-range
    for row in (0, 1):
        anomaly_ylim = np.array(axs[row, 1].get_ylim())
        axs[row, 2].set_ylim(1 / 6 * anomaly_ylim)

    ## legends
    axs[0, 1].legend(prop=dict(size=6))
    axs[1, 2].legend(prop=dict(size=6), loc="upper right")

    return

Make plot

In [None]:
## get zonal mean, tilt (departure from the zonal mean), and its
## zonal gradient
zonal_mean = forced_ndj_recon.mean("longitude")
tilt = forced_ndj_recon - zonal_mean
grad = tilt.differentiate(coord="longitude")

## get plot style for beginning/end periods
## (first / last 30 entries along `year`)
start_kwargs = dict(color="b", ls="-", label="First 30 yrs")
end_kwargs = dict(color="r", ls="--", label="Last 30 yrs")
t_idxs = [slice(None, 30), slice(-30, None)]

## make plot
fig, axs = plt.subplots(2, 3, figsize=(6, 3), layout="constrained")

## plot SST and SSH on separate rows
for j, varname in enumerate(["sst", "ssh"]):

    ## plot zonal mean
    axs[j, 0].plot(zonal_mean.year, zonal_mean[varname])

    ## tilt at beginning and end of period
    tilts = [tilt[varname].isel(year=t_).mean("year") for t_ in t_idxs]

    ## plot diff from zonal mean
    for tilt_, plot_kwargs in zip(tilts, [start_kwargs, end_kwargs]):

        ## tilt
        axs[j, 1].plot(tilt.longitude, tilt_, **plot_kwargs)

    ## plot change (late minus early); use this cell's own longitude
    ## coordinate rather than a variable left over from another cell
    axs[j, 2].plot(tilt.longitude, tilts[1] - tilts[0], c="k")


## label/format plot
format_line_plots(axs)

plt.show()

### Change in variance

Should we compute/plot this? (It is computationally expensive.)

In [None]:
## switch for the variance reconstruction below: the compute and plot cells
## in this section only run when this is True
COMPUTE_VAR = False

#### Equatorial region

In [None]:
def get_var(data, year_center, n=15, fn=None):
    """
    Reconstruct the variance of `data` for a window of years.

    The window keeps every sample whose calendar year lies within `n` years
    of `year_center` (inclusive). Those samples are passed as `scores`,
    together with the module-level `components`, to
    src.utils.reconstruct_var; `fn` is forwarded unchanged.
    """

    ## boolean mask along time: True for samples inside the window
    years_from_center = np.abs(data.time.dt.year - year_center)
    windowed_scores = data.isel(time=(years_from_center <= n))

    ## variance of the windowed samples
    return src.utils.reconstruct_var(
        components=components, scores=windowed_scores, fn=fn
    )


def get_var_bymonth(data, year_center, n=15, fn=None):
    """
    Apply get_var separately to each calendar month of `data`.

    `data` is grouped by "time.month" and every group is reduced with
    get_var using the same `year_center`, `n` and `fn`.
    """

    def _var_for_month(month_data):
        """get_var with the window settings fixed"""
        return get_var(month_data, year_center=year_center, n=n, fn=fn)

    return data.groupby("time.month").map(_var_for_month)


def get_var_for_periods(data, periods, n=15, fn=None, by_month=True):
    """
    Reconstruct variance for several window centers.

    For every entry of `periods` (used as `year_center`), the variance is
    computed with get_var_bymonth (by_month=True) or get_var (by_month=False).
    Results are concatenated along a new "period" dimension whose coordinate
    values are `periods`. A tqdm progress bar tracks the loop.
    """

    ## pick the per-period variance function
    if by_month:
        var_fn = get_var_bymonth
    else:
        var_fn = get_var

    ## one reconstruction per window center
    results = []
    for center in tqdm.tqdm(periods):
        results.append(var_fn(year_center=center, data=data, n=n, fn=fn))

    ## stack along a labelled "period" dimension
    return xr.concat(results, dim=pd.Index(periods, name="period"))


if COMPUTE_VAR:

    ## window centers (years) for the variance estimates
    period_centers = np.array([1868, 1939, 2010, 2082])

    ## equatorial-strip variance, by month, for every period
    kwargs = {"data": anom, "periods": period_centers, "n": 15, "fn": eq_avg}
    var_by_period_eq = get_var_for_periods(**kwargs)

    ## variance without a reducing function (fn=None), all months pooled,
    ## for the first and last period only
    kwargs = {
        **kwargs,
        "fn": None,
        "periods": period_centers[[0, -1]],
        "by_month": False,
    }
    var_by_period = get_var_for_periods(**kwargs)

Set plot style for plots

In [None]:
def label_axs(axs):
    """
    Label the two variance panels: axs[0] is titled "SST" and axs[1] "SSH".
    The second panel's y-axis ticks and label are moved to the right-hand
    side; every axis in `axs` gets a "Longitude" x-label.
    """

    sst_ax, ssh_ax = axs[0], axs[1]

    ## left panel (carries the legend)
    sst_ax.legend(prop=dict(size=6))
    sst_ax.set_title("SST")
    sst_ax.set_ylabel(r"$^{\circ}$C$^2$")
    sst_ax.set_yticks([0, 1, 2])

    ## right panel (y-axis on the right-hand side)
    ssh_ax.set_title("SSH")
    ssh_ax.set_ylabel(r"cm$^2$")
    ssh_ax.set_yticks([0, 35, 70])
    ssh_ax.yaxis.tick_right()
    ssh_ax.yaxis.set_label_position("right")

    ## shared x-label
    for panel in axs:
        panel.set_xlabel("Longitude")

    return


def add_xticks(ax, ticks):
    """Set x-ticks at `ticks` and draw a thin dashed vertical line at each one"""
    ax.set_xticks(ticks)

    guide_style = dict(ls="--", c="k", lw=0.5)
    for position in ticks:
        ax.axvline(position, **guide_style)
    return

In [None]:
if COMPUTE_VAR:

    ## select month for plot
    ## NOTE: this rebinds the notebook-level name `month`, which an earlier
    ## cell used for a month-of-time-step array
    month = 8

    ## colorbar
    ## (palette supplies one line color per period)
    colors = sns.color_palette("mako")

    ## set up plot: SST variance on the left, SSH variance on the right
    fig, axs = plt.subplots(1, 2, figsize=(5, 2), layout="constrained")

    for j, (p, c) in enumerate(zip(var_by_period_eq.period, colors)):

        ## Get subset of data for period
        plot_data = var_by_period_eq.sel(period=p, month=month)

        ## plot variance vs. longitude; only the SST line carries a legend label
        lon = var_by_period_eq.longitude
        axs[0].plot(lon, plot_data["sst"], label=f"{period_centers[j]}", c=c)
        axs[1].plot(lon, plot_data["ssh"], c=c)

    ## format/label axes
    label_axs(axs)
    add_xticks(axs[0], [190, 240])
    add_xticks(axs[1], [160, 210])

    axs[0].set_ylim([0, 2.5])

    plt.show()

#### Spatial (Hovmöller)

In [None]:
if COMPUTE_VAR:

    ## get plot data: SST variance for the first and last period, and
    ## their difference
    ## NOTE: `baseline` rebinds a notebook-level name that an earlier cell
    ## used for the variance baseline
    baseline = var_by_period_eq["sst"].isel(period=0)
    future = var_by_period_eq["sst"].isel(period=-1)
    change = future - baseline

    ## shared args for plotting
    plot_kwargs = dict(cmap="cmo.amp", extend="max")

    ## Set up plot
    fig, axs = plt.subplots(3, 1, figsize=(4, 6), layout="constrained")

    ## make Hovmöllers: baseline and future share levels; the difference
    ## panel uses a diverging colormap with symmetric levels
    cp0 = src.utils.plot_cycle_hov(
        axs[0], baseline, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
    )
    cp1 = src.utils.plot_cycle_hov(
        axs[1], future, levels=np.arange(0, 3.3, 0.3), **plot_kwargs
    )
    cp2 = src.utils.plot_cycle_hov(
        axs[2],
        change,
        levels=src.utils.make_cb_range(1.5, 0.15),
        cmap="cmo.balance",
        extend="both",
    )

    ## label
    axs[0].set_title("Baseline")
    axs[1].set_title("Future")
    axs[2].set_title("Difference")
    axs[-1].set_xlabel("Longitude")
    axs[-1].set_xticks([140, 190, 240])

    ## add colorbars (the difference panel overrides the tick positions)
    kwargs = dict(ticks=[0, 1, 2, 3], label=r"$^{\circ}\text{C}^2$")
    cb0 = fig.colorbar(cp0, ax=axs[0], **kwargs)
    cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
    cb2 = fig.colorbar(cp2, ax=axs[2], **dict(kwargs, ticks=[-1.5, 0, 1.5]))

    plt.show()

#### Spatial

Compute

In [None]:
if COMPUTE_VAR:

    ## set up paneled subplot (3 rows x 1 column of map axes)
    fig = plt.figure(figsize=(6, 5), layout="constrained")
    axs = src.utils.subplots_with_proj(
        fig, nrows=3, ncols=1, format_func=src.utils.plot_setup_pac
    )

    ## plot data: SST variance for the first (var0) and last (var1) period
    kwargs = dict(
        var0=var_by_period["sst"].isel(period=0),
        var1=var_by_period["sst"].isel(period=-1),
        amp=2,
        amp_diff=1,
        show_colorbars=True,
        cbar_label=r"$^{\circ}$C$^2$",
    )
    fig, axs = src.utils.make_variance_subplots(fig, axs, **kwargs)
    ## titles: the year ranges match window centers 1868 and 2082 with a
    ## half-width of 15 years
    axs[0, 0].set_title("Early (1853 – 1883)")
    axs[1, 0].set_title("Late (2067 – 2097)")
    axs[2, 0].set_title("Difference")

    plt.show()

### ENSO composite

#### Hovmöller

Function to compute composite

In [None]:
def get_hov_composite(data, peak_month, q, idx_fn):
    """
    Build a Hovmöller composite (lag vs. longitude).

    Args:
    - data: projected (EOF-score) data; its "sst" scores define the event
        index and the whole dataset is composited
    - peak_month: month to center the composite on
    - q: quantile threshold for the composite
    - idx_fn: function computing the index from reconstructed SST

    Returns the composite reconstructed with get_merimean and transposed
    so that "lag" is the leading dimension. Uses the module-level
    `components`.
    """

    ## index used to pick events, reconstructed from the SST scores
    event_idx = src.utils.reconstruct_fn(
        scores=data["sst"], components=components["sst"], fn=idx_fn
    )

    ## composite in projected (score) space
    comp_proj = src.utils.make_composite(
        idx=event_idx, data=data, peak_month=peak_month, q=q
    )

    ## back to physical space, averaged over latitude
    comp_merimean = src.utils.reconstruct_fn(
        components=components, scores=comp_proj, fn=get_merimean
    )

    return comp_merimean.transpose("lag", ...)


def get_centroid_posn_dec(lag):
    """
    Centroid longitude (center of the averaging box) as a function of lag.

    Piecewise-linear track:
        x = x0 + lag * (alpha_minus * H(-lag) + alpha_plus * H(lag))
    with x0 = 195 at lag 0, a slope of -55/5 deg E per month for negative
    lags and +55/7 deg E per month for positive lags. The result is capped
    at 250, i.e. larger values are replaced by 250.
    """

    def step(z):
        """Heaviside step function, with step(0) = 1"""
        return np.heaviside(z, 1)

    ## track parameters
    center_at_peak = 195
    slope_after = 55 / 7  # units: degrees E / mon
    slope_before = -55 / 5

    ## piecewise-linear position
    slope = slope_after * step(lag) + slope_before * step(-lag)
    centroid = center_at_peak + lag * slope

    ## cap eastward extent: remove whatever exceeds 250
    overshoot = centroid - 250
    return centroid - step(overshoot) * overshoot


def get_centroid_posn_jul(lag):
    """
    Centroid longitude (center of the averaging box) as a function of lag.

    Piecewise-linear track:
        x = x0 + lag * (alpha_minus * H(-lag) + alpha_plus * H(lag))
    with x0 = 230 at lag 0, a slope of +55/7 deg E per month for negative
    lags and -55/5 deg E per month for positive lags. The result is floored
    at 195, i.e. smaller values are replaced by 195.
    """

    def step(z):
        """Heaviside step function, with step(0) = 1"""
        return np.heaviside(z, 1)

    ## track parameters
    center_at_peak = 230
    slope_after = -55 / 5
    slope_before = 55 / 7  # units: degrees E / mon

    ## piecewise-linear position
    slope = slope_after * step(lag) + slope_before * step(-lag)
    centroid = center_at_peak + lag * slope

    ## cap westward extent: add back whatever falls short of 195
    shortfall = 195 - centroid
    return centroid + step(shortfall) * shortfall


def get_centroid_mask(centroids, longitude):
    """
    Mask selecting longitudes closer than 25 (degrees) to the centroid.

    The comparison |longitude - centroids| < 25 is evaluated and the result
    is passed through its own `.where`, so only entries that satisfy the
    test are kept.
    """

    ## True where a longitude lies within 25 of the centroid
    is_near = np.abs(longitude - centroids) < 25

    ## keep the mask only where it is True
    return is_near.where(is_near)


def get_merimean(x):
    """Meridional mean: select latitudes between -5 and 5, average over latitude"""
    equatorial_band = x.sel({"latitude": slice(-5, 5)})
    return equatorial_band.mean("latitude")

#### New version

In [None]:
def get_composite(idx, data, peak_month, time_idx, q=0.95, is_warm=True):
    """
    Composite of `data` for events selected from `idx`.

    Args:
    - idx: index used to select events
    - data: data to composite
    - peak_month: month to center the composite on
    - time_idx: indexer passed to `.sel` to subset both `idx` and `data`
    - q: quantile threshold for the composite
    - is_warm: if True, pass quantile q with an "index > cutoff" test;
        otherwise pass quantile 1 - q with an "index < cutoff" test

    Returns the result of src.utils.make_composite.
    """

    ## warm events sit above the cutoff, cold events below it
    if is_warm:
        quantile = q
        check_cutoff = lambda x, cut: x > cut
    else:
        quantile = 1 - q
        check_cutoff = lambda x, cut: x < cut

    ## composite of the subsetted data
    return src.utils.make_composite(
        q=quantile,
        check_cutoff=check_cutoff,
        peak_month=peak_month,
        idx=idx.sel(time_idx),
        data=data.sel(time_idx),
    )


def get_spatial_composite(components, **composite_kwargs):
    """
    Get spatial composite.

    Args:
    - components: EOF components used to reconstruct spatial fields
    - **composite_kwargs: forwarded unchanged to get_composite

    Returns the composite with every projected field reconstructed via
    reconstruct_helper (identity function) and the "mode" variable dropped.
    If the composite contains "trop_sst_05", a relative-SST field
    "sst_rel" = "sst_total" - "trop_sst_05" is added.
    """

    ## get projected composite
    comp_proj = get_composite(**composite_kwargs)

    ## reconstruct spatial fields
    comp = reconstruct_helper(comp_proj, components, func=lambda x: x).drop_vars("mode")

    ## reconstruct relative SST; the guard tests the same variable that is
    ## subtracted (previously it tested "trop_sst" but read "trop_sst_05")
    if "trop_sst_05" in comp:
        comp["sst_rel"] = comp["sst_total"] - comp["trop_sst_05"]

    return comp


def reconstruct_helper(composite, components, func):
    """
    Reconstruction helper function for composite.

    Args:
    - composite: projected composite (one entry per variable)
    - components: EOF components, keyed by variable name
    - func: function forwarded as `fn` to src.utils.reconstruct_fn

    Returns a deep copy of `composite` in which every variable present in
    `components` is replaced by its reconstruction. Variables named
    "<name>_total" are reconstructed with the components of "<name>".
    """

    ## imported locally: `copy` is not among the notebook's top-level imports
    import copy

    ## copy to hold reconstructed results
    composite_recon = copy.deepcopy(composite)

    ## reconstruct anomalies
    for c in list(components):
        composite_recon[c] = src.utils.reconstruct_fn(
            components=components[c],
            scores=composite[c],
            fn=func,
        )

    ## check for "total" fields; match the suffix only, so that stripping
    ## it always yields the base variable name
    suffix = "_total"
    for c in list(composite):
        if c.endswith(suffix):
            n = c[: -len(suffix)]
            composite_recon[c] = src.utils.reconstruct_fn(
                components=components[n],
                scores=composite[c],
                fn=func,
            )

    return composite_recon

In [None]:
## create data array for computing composite
comp_data = xr.merge(
    [
        anom,  # anomalies
        scores[["sst"]].rename({"sst": "sst_total"}),
        # eli["eli_05"],
        # trop_sst["trop_sst_05"],
    ]
)

In [None]:
## `copy` is needed by reconstruct_helper, which is called (indirectly) below
import copy

## specify shared args
## (events are selected from Th["T_3"], centered on month 6, warm events only)
kwargs = dict(
    peak_month=6,
    q=0.95,
    idx=Th["T_3"],
    is_warm=True,
    data=comp_data,
    components=components,
)

## specify early/late times
## NOTE(review): plot titles further down read "1853-1883" / "2067-2097",
## which differ by one year from these slices — confirm which is intended
t_idx_early = dict(time=slice("1853", "1882"))
t_idx_late = dict(time=slice("2068", "2097"))


## compute
comp_early = get_spatial_composite(time_idx=t_idx_early, **kwargs)
comp_late = get_spatial_composite(time_idx=t_idx_late, **kwargs)

## get centroid positions for composites (one longitude per lag)
center_dec = get_centroid_posn_dec(comp_early.lag)
center_jul = get_centroid_posn_jul(comp_early.lag)

New plotting func

In [None]:
## set up plot: early composite, late composite, and their difference
fig, axs = plt.subplots(1, 3, figsize=(7, 3), layout="constrained")

## plot composite for each period
for ax, comp in zip(axs[:2], [comp_early, comp_late]):

    ## meridional-mean wind stress, with "lag" as the leading dimension
    merimean = get_merimean(comp["taux"]).transpose("lag", ...)

    cf = ax.contourf(
        merimean.longitude,
        merimean.lag,
        merimean,
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(30, 6),
        extend="both",
    )


## plot difference (late minus early)
merimean_diff = get_merimean(comp_late - comp_early)["taux"].transpose("lag", ...)
cf_diff = axs[2].contourf(
    merimean_diff.longitude,
    merimean_diff.lag,
    merimean_diff,
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(20, 4),
    extend="both",
)

for ax in axs:

    ## label x axis
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])
    ax.set_xlim([140, 280])

    ## reference lines at the tick longitudes and at lag 6
    ax.axvline(190, c="w", ls="--", lw=0.8)
    ax.axvline(240, c="w", ls="--", lw=0.8)
    ax.axhline(6, c="w", ls="--", lw=0.8, alpha=0.5)

## label
font_kwargs = dict(size=10)
axs[0].set_title("Early (1853-1883)", **font_kwargs)
axs[1].set_title("Late (2067-2097)", **font_kwargs)
axs[2].set_title("Difference", **font_kwargs)

## colorbars: ticks match the contour range of each panel
## (the difference panel spans +/-20, not +/-30)
cb = fig.colorbar(cf, ax=axs[1], ticks=[-30, 0, 30])
cb_diff = fig.colorbar(cf_diff, ax=axs[2], ticks=[-20, 0, 20], label="stress")
src.utils.label_hov_yaxis(axs[0], peak_mon=kwargs["peak_month"])


plt.show()

Original plotting func

In [None]:
## specify amplitudes for plots
## (one `beta` per panel: early, late, difference)
scales = np.array([1, 1, 0.5])

## set up plot
fig, axs = plt.subplots(1, 3, figsize=(6, 3), layout="constrained")

## one panel each for the early composite, late composite and their difference
for ax, comp, scale in zip(
    axs, [comp_early, comp_late, comp_late - comp_early], scales
):
    ## src.utils.plot_hov is the project helper, not this notebook's plot_hov
    cf, _ = src.utils.plot_hov(
        ax=ax, x=get_merimean(comp).transpose("lag", ...), beta=scale
    )

    ## label x axis
    ax.set_xlabel("Longitude")
    ax.set_xticks([190, 240])
    ax.set_yticks([])

## label
axs[0].set_title("Early (1853-1883)")
axs[1].set_title("Late (2067-2097)")
axs[2].set_title("Difference")
## `kwargs` is the composite-settings dict defined in the compute cell
src.utils.label_hov_yaxis(axs[0], peak_mon=kwargs["peak_month"])


plt.show()

#### Spaghetti

##### Funcs

In [None]:
def make_spaghetti_helper(data, peak_month, q, idx_fn):
    """
    Collect projected composite samples for a spaghetti plot.

    Args:
    - data: projected (EOF-score) data; its "sst" scores define the index
    - peak_month: month to center the composite on
    - q: quantile threshold for the composite
    - idx_fn: function computing the index from reconstructed SST

    Returns the output of src.utils.make_composite_helper. Uses the
    module-level `components`.
    """

    ## get index for composite
    event_idx = src.utils.reconstruct_fn(
        scores=data["sst"], components=components["sst"], fn=idx_fn
    )

    ## composite of projected data
    return src.utils.make_composite_helper(
        idx=event_idx, data=data, peak_month=peak_month, q=q
    )


def make_spaghetti_by_lag(spag_proj, fn_by_lag):
    """
    Reconstruct spaghetti samples with a lag-dependent function.

    Args:
    - spag_proj: projected spaghetti data with a "lag" coordinate
    - fn_by_lag: callable fn_by_lag(x, lag=...) applied during the
        reconstruction at each lag

    Returns the per-lag reconstructions concatenated along `spag_proj.lag`.
    Uses the module-level `components`.
    """

    def _reconstruct_at(lag):
        """reconstruction for a single lag"""
        return src.utils.reconstruct_fn(
            components=components["sst"],
            scores=spag_proj.sel(lag=lag),
            fn=lambda x: fn_by_lag(x, lag=lag),
        )

    ## one reconstruction per lag, in coordinate order
    per_lag = [_reconstruct_at(lag) for lag in spag_proj.lag]

    ## concatenate result
    return xr.concat(per_lag, dim=spag_proj.lag)


def make_spaghetti(data, peak_month, q, idx_fn, eval_fn, is_sliding=False):
    """
    Build the samples for a spaghetti plot.

    Args:
    - data: projected (EOF-score) data; only its "sst" scores are used
    - peak_month: month to center the composite on
    - q: quantile threshold for the composite
    - idx_fn: function computing the event index from reconstructed SST
    - eval_fn: function evaluated on the samples; called as
        eval_fn(x, lag=...) when is_sliding is True, eval_fn(x) otherwise
    - is_sliding: whether the evaluation function depends on lag

    Uses the module-level `components`.
    """

    sst_scores = data["sst"]
    sst_components = components["sst"]

    ## get index for composite
    event_idx = src.utils.reconstruct_fn(
        components=sst_components, scores=sst_scores, fn=idx_fn
    )

    ## composite of projected data
    samples = src.utils.make_composite_helper(
        data=sst_scores, idx=event_idx, peak_month=peak_month, q=q
    )

    ## evaluate: lag-dependent functions are handled one lag at a time
    if is_sliding:
        return make_spaghetti_by_lag(samples, fn_by_lag=eval_fn)

    return src.utils.reconstruct_fn(
        components=sst_components, scores=samples, fn=eval_fn
    )

##### Compute spaghettis

In [None]:
## specify peak month
peak_mon = 6

## should we use sliding index?
is_sliding = True

## centroid track depends on the peak month; the event index function is
## src.utils.get_nino34 in either case
centroids = center_dec if peak_mon == 12 else center_jul
idx_fn = src.utils.get_nino34

## get mask (longitudes near the centroid, for every lag)
mask = get_centroid_mask(centroids=centroids, longitude=comp_early.longitude)


def sliding_idx(x, lag):
    """masked meridional mean at the given lag, averaged over longitude"""
    return (mask.sel(lag=lag) * get_merimean(x)).mean("longitude")


## get evaluation function
eval_fn = sliding_idx if is_sliding else src.utils.get_nino3

## specify shared args
kwargs = {
    "peak_month": peak_mon,
    "q": 0.95,
    "idx_fn": idx_fn,
    "is_sliding": is_sliding,
    "eval_fn": eval_fn,
}

## get corresponding month name (3-letter abbreviation)
mon_name = calendar.month_name[peak_mon][:3]

## get early/late spaghetti (first / last 360 time steps)
early_steps, late_steps = slice(None, 360), slice(-360, None)
spag_early = make_spaghetti(data=anom.isel(time=early_steps), **kwargs)
spag_late = make_spaghetti(data=anom.isel(time=late_steps), **kwargs)

##### Plot spaghettis

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(7, 3))

for ax, spag, ls in zip(axs, [spag_early, spag_late], ["--", "-"]):

    ## plot individual samples
    for s in spag.sample:
        ax.plot(spag.lag, spag.sel(sample=s), lw=0.8, alpha=0.5)

    ## plot mean
    ax.plot(spag.lag, spag.mean("sample"), c="k", lw=2.5, ls=ls)

## superimpose early mean on later mean
axs[1].plot(spag_early.lag, spag_early.mean("sample"), c="k", lw=2.5, ls="--")

## tick labels: calendar month at each marked lag, derived from the peak
## month used to build the composites. A "-" / "+" subscript marks lags
## before / after the peak (for peak_mon = 12 this gives Aug-, Dec, Jul+).
tick_lags = [-4, 0, 7]
tick_labels = [
    calendar.month_abbr[(peak_mon - 1 + lag) % 12 + 1]
    + (r"$_{-}$" if lag < 0 else r"$_{+}$" if lag > 0 else "")
    for lag in tick_lags
]

## format plot
kwargs_ = dict(c="k", lw=0.8)
for ax in axs:

    ax.axhline(0, **kwargs_)
    ax.axvline(tick_lags[-1], ls="--", **kwargs_)
    ax.axvline(tick_lags[0], ls="--", **kwargs_)
    ax.set_xticks(tick_lags, labels=tick_labels)
    ax.set_ylim([-2.5, 4])
    ax.set_yticks([-2, 0, 2])
    ax.axhline(3.5, c="r", ls="--", lw=1)

plt.show()

In [None]:
## subset for months
## NOTE(review): despite its name, `Th_early` currently holds the LAST 360
## time steps; the first-360 alternative is commented out — confirm which
## period is intended
# Th_early = Th.isel(time=slice(None,360))
Th_early = Th.isel(time=slice(-360, None))
Th_dec = Th_early["T_34"].isel(time=Th_early.time.dt.month == 12)
## NOTE(review): `Th_jul` selects month == 8 (August), while the name and the
## plot labels in the following cells say "Jul" — confirm the intended month
Th_jul = Th_early["T_34"].isel(time=Th_early.time.dt.month == 8)

In [None]:
## paired samples, flattened over all remaining dimensions:
## - same position along time in both series
## - `Th_jul` shifted one step later along time than `Th_dec`
jul_same = Th_jul.values.flatten()
dec_same = Th_dec.values.flatten()
jul_next = Th_jul.isel(time=slice(1, None)).values.flatten()
dec_prev = Th_dec.isel(time=slice(None, -1)).values.flatten()

## compute correlations
corr_jul_leads, _ = scipy.stats.pearsonr(jul_same, dec_same)
corr_dec_leads, _ = scipy.stats.pearsonr(jul_next, dec_prev)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))

## Scatter data
axs[0].scatter(jul_same, dec_same, s=10)
axs[1].scatter(jul_next, dec_prev, s=10)

## format/label
axs[0].set_title(f"Corr: {corr_jul_leads:.2f}")
axs[1].set_title(f"Corr: {corr_dec_leads:.2f}")
axs[0].set_xlabel("Jul(-1)")
axs[0].set_ylabel("Dec(0)")
axs[1].set_xlabel("Jul(+1)")
axs[0].set_yticks([-3, 0, 3])
axs[1].set_yticks([])

## 1:1 reference line (drawn once per panel), shared ticks, square panels
for ax in axs:
    ax.plot(np.linspace(-3, 3), np.linspace(-3, 3), c="k", lw=1)
    ax.set_xticks([-3, 0, 3])
    ax.set_aspect("equal")

plt.show()

##### Histograms

In [None]:
## compute skewness (over all values of each subset)
skew_dec = scipy.stats.skew(Th_dec.values.flatten())
skew_jul = scipy.stats.skew(Th_jul.values.flatten())

## shared bin edges (width 1/3) so both PDFs are directly comparable
edges = np.arange(-3.333, 3.666, 1 / 3)
pdf_dec, _ = get_empirical_pdf(Th_dec, bin_edges=edges)
pdf_jul, _ = get_empirical_pdf(Th_jul, bin_edges=edges)

fig, ax = plt.subplots(figsize=(4, 3))

## step-style histograms, with skewness reported in the legend
ax.stairs(pdf_dec, edges, label=f"Dec (skew = {skew_dec:.2f})")
ax.stairs(pdf_jul, edges, label=f"Jul (skew = {skew_jul:.2f})")
ax.legend(prop=dict(size=8))
plt.show()

#### Simpler spaghetti

In [None]:
def make_spaghetti(idx_name, data, peak_month, q):
    """
    Collect composite samples for a spaghetti plot.

    NOTE: this rebinds the name `make_spaghetti` defined earlier in the
    notebook (which takes different arguments).

    Args:
    - idx_name: name of the variable in `data` used as the event index
    - data: dataset to composite
    - peak_month: month to center the composite on
    - q: quantile threshold for the composite

    Returns the output of src.utils.make_composite_helper.
    """

    return src.utils.make_composite_helper(
        data=data, idx=data[idx_name], peak_month=peak_month, q=q
    )

In [None]:
## specify peak month
peak_mon = 12

## specify index
idx_name = "T_34"

# ## specify shared args
kwargs = dict(
    peak_month=peak_mon,
    q=0.95,
    idx_name=idx_name,
)

## get corresponding month name
mon_name = calendar.month_name[kwargs["peak_month"]][:3]

## normalize data (makes plotting easier...)
## (each variable is divided by its own standard deviation over all dimensions)
Th_norm = Th / Th.std()

## get early/late spagetti
## (first / last 360 time steps)
spag_early = make_spaghetti(data=Th_norm.isel(time=slice(None, 360)), **kwargs)
spag_late = make_spaghetti(data=Th_norm.isel(time=slice(-360, None)), **kwargs)
# spag_late = make_spaghetti(data=Th_norm.isel(time=slice(-720, -360)), **kwargs)

In [None]:
## specify plot variable
## NOTE: this value is immediately overwritten by the loop variable below
plot_var = "T_34"

## rows: early (top) / late (bottom); columns: one per plotted variable
fig, axs = plt.subplots(2, 2, figsize=(4, 4), layout="constrained")

for j, plot_var in enumerate(["T_34", "T_3"]):

    for i, (spag, ls) in enumerate(zip([spag_early, spag_late], ["--", "-"])):

        ## plot individual samples
        for s in spag.sample:
            axs[i, j].plot(
                spag.lag, spag[plot_var].sel(sample=s), lw=0.5, alpha=0.5, c="gray"
            )

        ## plot mean and bounds
        axs[i, j].plot(spag.lag, spag[plot_var].mean("sample"), c="k", lw=2.5, ls=ls)

    ## superimpose early mean on later mean
    axs[1, j].plot(
        spag_early.lag, spag_early[plot_var].mean("sample"), c="k", lw=2.5, ls="--"
    )

## format plot
kwargs_ = dict(c="k", lw=0.8)
for ax in axs.flatten():

    ax.axhline(0, **kwargs_)
    ax.axvline(7, ls="--", **kwargs_)
    ax.axvline(-4, ls="--", **kwargs_)
    ## labels assume the composite is centered on December (lag 0)
    ax.set_xticks([-4, 0, 7], labels=[r"Aug$_{-}$", r"Dec", r"Jul$_{+}$"])
    ax.set_ylim([-3, 5])

## add titles/labels
axs[0, 0].set_title("Niño 3.4")
axs[0, 1].set_title("Niño 3")
axs[0, 1].set_ylabel("1850-1880")
axs[1, 1].set_ylabel("2070-2100")

## y ticks on the left column only
for ax in axs[:, 0]:
    ax.set_yticks([-3, 0, 3])

## right column: no y ticks; its y-label (the period) sits on the right
for ax in axs[:, 1]:
    ax.set_yticks([])
    ax.yaxis.set_label_position("right")

## top row: no x ticks (this overrides the month labels set above)
for ax in axs[0, :]:
    ax.set_xticks([])

plt.show()

#### Scatter plots
Idea: push forward (or pull back) trajectories with maxima in June/Aug

In [None]:
## specify time index
t_idx = dict(time=slice(None, 360))
# t_idx = dict(time=slice(-360,None))

## get composites for each month
data_kwargs = dict(data=Th_norm.isel(t_idx), q=0.95)
Th_jun = make_spaghetti(peak_month=6, idx_name="T_3", **data_kwargs)
Th_dec = make_spaghetti(peak_month=12, idx_name="T_34", **data_kwargs)
Th_aug = make_spaghetti(peak_month=8, idx_name="T_34", **data_kwargs)

In [None]:
## these are lags for plotting:
## jun, dec, aug lags: [-6, 0, 4]

## get colors (useful for plotting)
colors = sns.color_palette()

## set up plot
fig, ax = plt.subplots(figsize=(3.5, 3.5), layout="constrained")

## helper func for plotting
Th_scat = lambda Th, **kwargs: ax.scatter(Th["T_34"], Th["h"], **kwargs)

## Plot for june
kwargs = dict(color=colors[0], label="Jun (pull back)")
Th_scat(Th_jun.sel(lag=-6), **kwargs)

## plot for Dec
kwargs = dict(color="none", edgecolor="k", label="Dec")
Th_scat(Th_dec.sel(lag=0), **kwargs)

## plot for Aug
kwargs = dict(color=colors[1], marker="x", label="Aug (push forward)")
Th_scat(Th_aug.sel(lag=4), **kwargs)

## set plot format
ax.set_xlim([-1, 4])
ax.set_ylim([-3, 4])
line_kwargs = dict(c="k", lw=0.8)
ax.axhline(0, **line_kwargs)
ax.axvline(0, **line_kwargs)
# ax.set_aspect("equal")
ax.legend(prop=dict(size=7), loc="upper left")
ax.set_xlabel(r"$T$")
ax.set_ylabel(r"$h_m$")

plt.show()

In [None]:
## Fit RO to data
RO = src.XRO.XRO(ncycle=12, ac_order=3, is_forward=True)
fit = RO.fit_matrix(data_kwargs["data"][["T_34", "h"]], ac_mask_idx=[(1, 1)])

## Get operator in december
L_dec = fit.Lac.isel(cycle=-1)
w, U, V = scipy.linalg.eig(L_dec.values, left=True, right=True)

## evaluate eigenfunctions on data
proj = lambda Th: U[:, :1].T @ Th[["T_34", "h"]].to_dataarray().values
proj_jun = proj(Th_jun.sel(lag=-6))
proj_dec = proj(Th_dec.sel(lag=0))
proj_aug = proj(Th_aug.sel(lag=4))

In [None]:
## get colors (useful for plotting)
colors = sns.color_palette()

## set up plot
fig, ax = plt.subplots(figsize=(3, 3), layout="constrained")

## helper func for plotting
## (real part on the x-axis, imaginary part on the y-axis)
proj_scat = lambda proj_, **kwargs: ax.scatter(proj_.real, proj_.imag, **kwargs)

## Plot for june
kwargs = dict(color=colors[0], label="Jun (pull back)")
proj_scat(proj_jun, **kwargs)

## plot for Dec
kwargs = dict(color="none", edgecolor="k", label="Dec")
proj_scat(proj_dec, **kwargs)

## plot for Aug
kwargs = dict(color=colors[1], marker="x", label="Aug (push forward)")
proj_scat(proj_aug, **kwargs)

## set plot format: axes through the origin, equal aspect ratio
line_kwargs = dict(c="k", lw=0.8)
ax.axhline(0, **line_kwargs)
ax.axvline(0, **line_kwargs)
ax.set_aspect("equal")
ax.legend(prop=dict(size=8))

plt.show()

## Change in Bjerknes coupling over time

In [None]:
def get_alpha(data, dim):
    """compute alpha: regress "nhf" (Y) onto "sst" (X) along `dim`"""

    predictor, target = data["sst"], data["nhf"]
    return src.utils.regress_core(Y=target, X=predictor, dim=dim)


def get_mu(data, dim):
    """compute mu: regress "taux" (Y) onto "sst" (X) along `dim`"""

    predictor, target = data["sst"], data["taux"]
    return src.utils.regress_core(Y=target, X=predictor, dim=dim)


def get_beta(data, dim):
    """compute beta: regress "ssh" (Y) onto "taux" (X) along `dim`"""

    predictor, target = data["taux"], data["ssh"]
    return src.utils.regress_core(Y=target, X=predictor, dim=dim)


def get_xi(data, dim):
    """compute xi: regress "sst" (Y) onto "ssh" (X) along `dim`"""

    predictor, target = data["ssh"], data["sst"]
    return src.utils.regress_core(Y=target, X=predictor, dim=dim)


def get_params(data, dim="time"):
    """Compute alpha, mu, beta and xi along `dim` and merge them.

    The returned dataset also carries "coupling", the product mu * beta * xi
    (alpha is not part of that product).
    """
    estimators = {
        "alpha": get_alpha,
        "mu": get_mu,
        "beta": get_beta,
        "xi": get_xi,
    }
    params = xr.merge(
        [fn(data, dim=dim).rename(name) for name, fn in estimators.items()]
    )

    params["coupling"] = params["mu"] * params["beta"] * params["xi"]

    return params


def get_rolling_params(data, n=10, reduce_ensemble=True):
    """Estimate parameters in a centered rolling window along "time".

    The window spans 2 * n + 1 time steps. With `reduce_ensemble=True` the
    "member" and "window" dims are stacked into one "sample" dim; otherwise
    "window" alone is renamed to "sample". Parameters are then computed over
    "sample" with `get_params`.
    """
    window_len = 2 * n + 1

    ## materialize the rolling windows along a new "window" dimension
    windowed = data.rolling({"time": window_len}, center=True).construct("window")

    if reduce_ensemble:
        samples = windowed.stack(sample=["member", "window"])
    else:
        samples = windowed.rename({"window": "sample"})

    return get_params(samples, dim="sample")


def get_rolling_params_bymonth(data, **kwargs):
    """Apply `get_rolling_params` to each "time.month" group separately.

    Extra keyword arguments are forwarded to `get_rolling_params`.
    """
    monthly_groups = data.groupby("time.month")

    return monthly_groups.map(get_rolling_params, **kwargs)


def get_fractional_change(data, dim="year"):
    """Return (data - baseline) / baseline.

    The baseline is the mean over the first 30 entries along `dim`.
    """
    baseline = data.isel({dim: slice(30)}).mean(dim)
    departure = data - baseline

    return departure / baseline

In [None]:
## directory holding the EOF files
eofs_fp = DATA_FP / "mpi" / "eofs300"

## on-disk variable names and the names used in this notebook
names = ["ts", "ssh", "taux", "nhf"]
newnames = ["sst", "ssh", "taux", "nhf"]


def load_var(x):
    """Load the EOF object stored in "<x>.nc" under `eofs_fp`."""
    return src.utils.load_eofs(eofs_fp / f"{x}.nc")


## one EOF object per variable, keyed by the new name
eofs = {new: load_var(old) for old, new in zip(names, newnames)}

## gather spatial patterns and scores (PCs) into one dataset each
components = xr.merge([e.components().rename(key) for key, e in eofs.items()])
pc_data = xr.merge([e.scores().rename(key) for key, e in eofs.items()])

## rescale taux in place (check with Yann about this...)
pc_data["taux"].values *= 1e-3

## ensemble mean = forced signal; the remainder = anomalies
forced = pc_data.mean("member")
anom = pc_data - forced

In [None]:
## reconstruct the Niño 3 and Niño 4 indices from the anomaly scores
kwargs = {"scores": anom, "components": components}
nino3 = src.utils.reconstruct_fn(fn=src.utils.get_nino3, **kwargs)
nino4 = src.utils.reconstruct_fn(fn=src.utils.get_nino4, **kwargs)

## keep sst/nhf/ssh from Niño 3 and taux from Niño 4
nino3_subset = nino3[["sst", "nhf", "ssh"]]
idxs = xr.merge([nino3_subset, nino4["taux"]])

Get value of params over time

In [None]:
## rolling parameter estimates per calendar month (window = 2 * 16 + 1 steps),
## with month and year split into separate dims
params = src.utils.unstack_month_and_year(get_rolling_params_bymonth(idxs, n=16))

## baseline = mean over the first 30 entries along "year"
params_baseline = params.isel(year=slice(30)).mean("year")
delta_params = params - params_baseline

## change in coupling relative to baseline (a fraction, not multiplied by 100)
coupling_pct_change = delta_params["coupling"] / params_baseline["coupling"]

Hovmoller

In [None]:
## one panel per parameter
fig, axs = plt.subplots(1, 5, figsize=(9, 4), layout="constrained")

## (parameter, colorbar amplitude, colorbar label) for each panel, left to right
panels = [
    ("alpha", 6, "SST - Net heat flux\n" + r"($W ~m^{-2}~K^{-1}$)"),
    ("mu", 0.006, r"SST-$\tau_x$" + "\n(Pa/K)"),
    ("beta", 150, r"$\tau_x - \text{SSH}$" + "\n(cm/Pa)"),
    ("xi", 7e-2, "SSH-SST\n(K/cm)"),
    ("coupling", 0.5, "coupling\n(unitless)"),
]
for ax, (name, amp, label) in zip(axs, panels):
    plot_hov2(ax, delta_params[name].T, amp=amp, label=label)

## dashed vertical guide at x=7 on every panel
for ax in axs:
    ax.axvline(7, c="k", lw=1, ls="--")

## year ticks/label on the first panel only
axs[0].set_yticks(np.linspace(1870, 2082, 5))
axs[0].set_ylabel("Year")
axs[1].set_ylim(axs[0].get_ylim())
plt.show()

Plot annual mean

In [None]:
colors = sns.color_palette()

## month-averaged parameters
mu_annual = params["mu"].mean("month")
alpha_annual = params["alpha"].mean("month")

## mu on the left axis
fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(params.year, mu_annual, label=r"$\mu$", c=colors[0])

## alpha on a twin (right) axis
ax2 = ax.twinx()
ax2.plot(params.year, alpha_annual, label=r"$\alpha$", c=colors[1])

## left axis: 0 to 0.01
ax.set_ylim([0, 0.01])
ax.set_yticks([0, 0.005, 0.01])

## right axis: limits given as [0, -10], so 0 sits at the bottom
ax2.set_ylim([0, -10])
ax2.set_yticks([0, -5, -10])


plt.show()

## Look at change in spatial patterns...

### add patterns to data

In [None]:
## attach each spatial pattern to `anom` as "<var>_comp" (skip if present)
for v in list(components):
    comp_name = f"{v}_comp"
    if comp_name in list(anom):
        continue
    anom[comp_name] = components[v]

## add T indices, reconstructed from the sst scores/pattern
names = ["nino3", "nino34", "nino4"]
fns = [src.utils.get_nino3, src.utils.get_nino34, src.utils.get_nino4]
for n, fn in zip(names, fns):
    if n in list(anom):
        continue
    anom[n] = src.utils.reconstruct_fn(
        scores=anom["sst"], components=anom["sst_comp"], fn=fn
    )

## add h indices, reconstructed from the ssh scores/pattern
names = ["h_w", "h"]
fns = [src.utils.get_RO_hw, src.utils.get_RO_h]
for n, fn in zip(names, fns):
    if n in list(anom):
        continue
    anom[n] = src.utils.reconstruct_fn(
        scores=anom["ssh"], components=anom["ssh_comp"], fn=fn
    )

Partition into early/late

In [None]:
## first and last 360 time steps
anom_early = anom.isel(time=slice(360))
anom_late = anom.isel(time=slice(-360, None))

Remove linear dependence on SST and get tendencies

In [None]:
def prep(data):
    """Add SST-independent ("hat") variables to `data`, then compute tendencies.

    Adds "ssh_hat" (with "ssh_hat_comp" copied from "ssh_comp"), "h_w_hat" and
    "h_hat" to `data` by item assignment, each computed by
    `src.utils.remove_sst_dependence_v2` with "nino34" as the T variable.
    Returns the output of `src.utils.get_ddt` applied to the result.
    """

    def strip_sst(h_var):
        """Remove the "nino34" dependence from one variable of `data`."""
        return src.utils.remove_sst_dependence_v2(data, h_var=h_var, T_var="nino34")

    ## SSH field (it reuses the ssh spatial pattern)
    data["ssh_hat"] = strip_sst("ssh")
    data["ssh_hat_comp"] = data["ssh_comp"]

    ## h indices
    for h_idx in ("h_w", "h"):
        data[f"{h_idx}_hat"] = strip_sst(h_idx)

    ## tendencies
    return src.utils.get_ddt(data)

In [None]:
## preprocess both periods (early first, then late)
anom_early, anom_late = prep(anom_early), prep(anom_late)

#### Compute RO coefficients

In [None]:
def get_RO_coefs(data, h_var):
    """Regress both tendencies onto ("nino34", h_var) and merge the coefficients.

    The "ddt_sst" row yields "R" (nino34 coefficient) and "F1" (h_var
    coefficient); the SSH-tendency row yields "F2" and "eps". If `h_var`
    contains "hat", the SSH tendency used is "ddt_ssh_hat", else "ddt_ssh".
    """
    ssh_var = "ssh_hat" if "hat" in h_var else "ssh"

    ## target variable -> new names for its two regression coefficients
    rows = {
        "ddt_sst": {"nino34": "R", h_var: "F1"},
        f"ddt_{ssh_var}": {"nino34": "F2", h_var: "eps"},
    }

    return xr.merge(
        [
            get_RO_coefs_helper(data=data, h_var=h_var, y_var=y_var).rename(new_names)
            for y_var, new_names in rows.items()
        ]
    )


def get_RO_coefs_helper(data, h_var, y_var):
    """Regress `y_var` onto ["nino34", h_var] with `multi_regress_bymonth`."""

    predictors = ["nino34", h_var]

    return src.utils.multi_regress_bymonth(data, x_vars=predictors, y_var=y_var)

In [None]:
## h variable used as the second predictor
h_var = "h_w_hat"

## RO coefficients for the early and late periods
m_early, m_late = (get_RO_coefs(d, h_var) for d in (anom_early, anom_late))

##### Plot results

In [None]:
def plot_spatial_comp(early, late, amp, amp_diff, label, month):
    """Plot early, late and (late - early) maps in three stacked panels.

    Args:
        early: field for the early period (top panel).
        late: field for the late period (middle panel).
        amp: amplitude passed to `make_contour_plot` for the early/late
            panels; also sets their shared colorbar ticks [-amp, 0, amp].
        amp_diff: same as `amp`, for the difference panel.
        label: label for both colorbars.
        month: either a value passed to `.sel(month=...)` (e.g. an int) or a
            callable applied to the field (e.g. `lambda x: x.mean("month")`).

    Returns:
        (fig, axs): the figure and the axes from `subplots_with_proj`.
    """

    ## build the selector: callables are used as-is, anything else is treated
    ## as a month value (this also accepts numpy integers, not only `int`)
    if callable(month):
        sel_mon = month
    else:

        def sel_mon(x):
            return x.sel(month=month)

    ## set up plot
    fig = plt.figure(figsize=(7, 3.9), layout="constrained")

    def format_func(ax):
        return src.utils.plot_setup_pac(ax, max_lat=20)

    axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

    ## early / late share one amplitude; the difference gets its own
    contour_kwargs = dict(amp=amp, sel=sel_mon)
    cp0 = src.utils.make_contour_plot(axs[0, 0], early, **contour_kwargs)
    src.utils.make_contour_plot(axs[1, 0], late, **contour_kwargs)
    cp2 = src.utils.make_contour_plot(
        axs[2, 0], late - early, **dict(contour_kwargs, amp=amp_diff)
    )

    ## colorbars: one shared by the top two panels, one for the difference
    fig.colorbar(cp0, ax=axs[:-1], ticks=[-amp, 0, amp], label=label)
    fig.colorbar(cp2, ax=axs[-1], ticks=[-amp_diff, 0, amp_diff], label=label)

    ## Niño 3.4 box
    box_kwargs = dict(c="k", linewidth=0.9, alpha=0.5)
    for ax in axs.flatten():
        src.utils.plot_nino34_box(ax, **box_kwargs)

    return fig, axs


def plot_hov_comp(early, late, amp, amp_diff, label):
    """Plot early, late and (late - early) hovmollers in three stacked panels.

    Each field is averaged over latitude -5 to 5 before plotting. The early
    and late panels use `amp` for contour levels and colorbar ticks; the
    difference panel uses `amp_diff`. Returns (fig, axs).
    """

    def merimean(x):
        """Mean over the latitude=slice(-5, 5) band."""
        return x.sel(latitude=slice(-5, 5)).mean("latitude")

    def hov_kwargs(a):
        """Contour settings for amplitude `a`."""
        return dict(
            cmap="cmo.balance",
            levels=src.utils.make_cb_range(a, a / 10),
            extend="both",
        )

    fig, axs = plt.subplots(3, 1, figsize=(3.5, 5), layout="constrained")

    ## (title, field, amplitude) for each panel, top to bottom
    panels = [
        ("Early", early, amp),
        ("Late", late, amp),
        ("Difference", late - early, amp_diff),
    ]
    for ax, (title, field, a) in zip(axs, panels):
        cp = src.utils.plot_cycle_hov(ax, merimean(field), **hov_kwargs(a))
        fig.colorbar(cp, ax=ax, ticks=[-a, 0, a], label=label)
        ax.set_title(title)

    ## longitude axis is labeled on the bottom panel only
    axs[-1].set_xlabel("Longitude")
    axs[-1].set_xticks([140, 190, 240])

    return fig, axs

#### $R$ 
$\frac{d T}{dt}$ vs $T_{34}$

In [None]:
## shared settings for the R comparison plots
R_kwargs = {
    "early": m_early["R"],
    "late": m_late["R"],
    "amp": 5,
    "amp_diff": 5,
    "label": r"$K ~\text{yr}^{-1}~ \left(T_{34}\right)^{-1}$",
}

## maps for month 7
fig, axs = plot_spatial_comp(month=7, **R_kwargs)

## hovmollers
fig, axs = plot_hov_comp(**R_kwargs)

#### $F_2$ 
$\frac{d \hat{h}}{dt}$ vs $T_{34}$

In [None]:
## shared settings for the F2 comparison plots
F2_kwargs = {
    "early": m_early["F2"],
    "late": m_late["F2"],
    "amp": 20,
    "amp_diff": 10,
    "label": r"$m ~\text{yr}^{-1}~ \left(T_{34}\right)^{-1}$",
}

## maps for month 5
fig, axs = plot_spatial_comp(month=5, **F2_kwargs)

## outline the 130-210 lon, -5 to 5 lat box on every panel
box_style = {"c": "white", "ls": "--", "lw": 1}
for ax in axs.flatten():
    src.utils.plot_box(ax, lons=[130, 210], lats=[-5, 5], **box_style)
plt.show()

## hovmollers
fig, axs = plot_hov_comp(**F2_kwargs)

#### $F_1$ 
$\frac{d T}{dt}$ vs $h_w$

In [None]:
## shared settings for the F1 comparison plots
F1_kwargs = {
    "early": m_early["F1"],
    "late": m_late["F1"],
    "amp": 1,
    "amp_diff": 0.5,
    "label": r"$k ~\text{yr}^{-1}~ \left(h_w\right)^{-1}$",
}

## maps for month 5
fig, axs = plot_spatial_comp(month=5, **F1_kwargs)

## hovmollers
fig, axs = plot_hov_comp(**F1_kwargs)

#### $\varepsilon$ 
$\frac{d h_w}{dt}$ vs $h_w$

In [None]:
## shared settings for the eps comparison plots
## (kept under their own name so the F1 settings are not overwritten)
eps_kwargs = dict(
    early=m_early["eps"],
    late=m_late["eps"],
    amp=10,
    amp_diff=5,
    label=r"$m ~\text{yr}^{-1}~ \left(h_w\right)^{-1}$",
)

## spatial plot for month 6
fig, axs = plot_spatial_comp(
    month=6,
    **eps_kwargs,
)

## outline the 130-210 lon, -5 to 5 lat box on every panel
for ax in axs.flatten():
    src.utils.plot_box(ax, lons=[130, 210], lats=[-5, 5], c="white", ls="--", lw=1)
plt.show()


## hovmoller
fig, axs = plot_hov_comp(
    **eps_kwargs,
)

### $\tau_x$–SST

##### compute coefs

In [None]:
## regress taux onto the nino3 index, separately for each period
kwargs = {"x_vars": ["nino3"], "y_var": "taux"}

## keep only the nino3 coefficient
m_early, m_late = (
    src.utils.multi_regress_bymonth(d, **kwargs)["nino3"]
    for d in (anom_early, anom_late)
)

##### Spatial plot

In [None]:
kwargs = {
    "early": m_early,
    "late": m_late,
    "amp": 1.5e-2,
    "amp_diff": 0.75e-2,
    "label": r"$Pa~K^{-1}$",
}

## maps of the mean over "month"
fig, axs = plot_spatial_comp(month=lambda x: x.mean("month"), **kwargs)

## hovmollers: early, late, and difference (at half the amplitude)
fig, axs = plt.subplots(3, 1, figsize=(3, 6), layout="constrained")
cp0 = src.utils.plot_cycle_hov_v3(axs[0], data=m_early, amp=1.5e-2)
cp1 = src.utils.plot_cycle_hov_v3(axs[1], data=m_late, amp=1.5e-2)
cp2 = src.utils.plot_cycle_hov_v3(axs[2], data=m_late - m_early, amp=0.75e-2)

## single horizontal colorbar under the bottom panel
## NOTE(review): the ticks (+/-0.02) exceed the amp passed for cp2 (0.75e-2);
## if amp sets the contour range, only the 0 tick can show -- confirm
## the intended mappable/ticks.
cb = fig.colorbar(
    cp2,
    ax=axs[2],
    orientation="horizontal",
    ticks=[-0.02, 0, 0.02],
    label=r"$Pa ~K^{-1}$",
)
axs = src.utils.format_hov_v3(axs)

## panel labels
for idx, ylabel in enumerate(["Early", "Late", "Change (x2)"]):
    axs[idx].set_ylabel(ylabel)

plt.show()

### Heatflux

In [None]:
## regress nhf onto the nino3 index, separately for each period
kwargs = {"x_vars": ["nino3"], "y_var": "nhf"}

## keep only the nino3 coefficient
m_early, m_late = (
    src.utils.multi_regress_bymonth(d, **kwargs)["nino3"]
    for d in (anom_early, anom_late)
)

##### Plot

In [None]:
kwargs = {
    "early": m_early,
    "late": m_late,
    "amp": 20,
    "amp_diff": 10,
    "label": r"$W~m^{-2}~K^{-1}$",
}

## maps of the mean over "month"
fig, axs = plot_spatial_comp(month=lambda x: x.mean("month"), **kwargs)

##### Hovmoller: early, late, and difference (at half the amplitude)
fig, axs = plt.subplots(3, 1, figsize=(3, 6), layout="constrained")
cp0 = src.utils.plot_cycle_hov_v3(axs[0], data=m_early, amp=25)
cp1 = src.utils.plot_cycle_hov_v3(axs[1], data=m_late, amp=25)
cp2 = src.utils.plot_cycle_hov_v3(axs[2], data=m_late - m_early, amp=12.5)

## single horizontal colorbar under the bottom panel
## NOTE(review): the ticks (+/-20) exceed the amp passed for cp2 (12.5);
## if amp sets the contour range, only the 0 tick can show -- confirm
## the intended mappable/ticks.
cb = fig.colorbar(
    cp2,
    ax=axs[2],
    orientation="horizontal",
    ticks=[-20, 0, 20],
    label=r"$W~m^{-2} ~K^{-1}$",
)
axs = src.utils.format_hov_v3(axs)

## panel labels
for idx, ylabel in enumerate(["Early", "Late", "Change (x2)"]):
    axs[idx].set_ylabel(ylabel)

plt.show()

### SSH - $\tau_x$

In [None]:
## settings for the ssh-vs-taux regression (fn_x is the Niño 4 function)
kwargs = {"x_var": "taux", "y_var": "ssh", "fn_x": src.utils.get_nino4}


def get_slope(x):
    """Apply `src.utils.regress_proj` to each "time.month" group of `x`."""
    return x.groupby("time.month").map(src.utils.regress_proj, **kwargs)


## regression coefficient for each period
m_early = get_slope(anom_early)
m_late = get_slope(anom_late)

##### Spatial plot

In [None]:
def sel_mon(x):
    """Average over "month" (swap for x.sel(month=7) to view a single month)."""
    return x.mean("month")


def format_func(ax):
    """Format one map axis with `plot_setup_pac`, capped at max_lat=20."""
    return src.utils.plot_setup_pac(ax, max_lat=20)


## three stacked map panels
fig = plt.figure(figsize=(7, 3.9), layout="constrained")
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=1, format_func=format_func)

## early / late at amp=4e2, difference at half that
contour_kwargs = {"amp": 4e2, "sel": sel_mon}
cp0 = src.utils.make_contour_plot(axs[0, 0], m_early, **contour_kwargs)
cp1 = src.utils.make_contour_plot(axs[1, 0], m_late, **contour_kwargs)
cp2 = src.utils.make_contour_plot(
    axs[2, 0], m_late - m_early, **{**contour_kwargs, "amp": 2e2}
)

## colorbars: one shared by the top two panels, one for the difference
cb_kwargs = {"label": r"$cm~Pa^{-1}$"}
cb0 = fig.colorbar(cp0, ax=axs[:2], ticks=[-400, 0, 400], **cb_kwargs)
cb2 = fig.colorbar(cp2, ax=axs[-1], ticks=[-200, 0, 200], **cb_kwargs)

## Niño 4 box on every panel
box_kwargs = {"c": "k", "linewidth": 0.9, "alpha": 0.5}
for ax in axs.flatten():
    src.utils.plot_nino4_box(ax, **box_kwargs)

plt.show()

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(3, 6), layout="constrained")

## early and late periods
cp0 = src.utils.plot_cycle_hov_v3(axs[0], data=m_early, amp=400)
cp1 = src.utils.plot_cycle_hov_v3(axs[1], data=m_late, amp=400)

## late-minus-early difference (at half the amplitude)
cp2 = src.utils.plot_cycle_hov_v3(axs[2], data=m_late - m_early, amp=200)

## single horizontal colorbar under the bottom panel
## NOTE(review): the ticks (+/-400) exceed the amp passed for cp2 (200);
## if amp sets the contour range, only the 0 tick can show -- confirm
## the intended mappable/ticks.
cb = fig.colorbar(
    cp2,
    ax=axs[2],
    orientation="horizontal",
    ticks=[-400, 0, 400],
    label=r"$cm~Pa^{-1}$",
)
axs = src.utils.format_hov_v3(axs)

## panel labels
for idx, ylabel in enumerate(["Early", "Late", "Change (x2)"]):
    axs[idx].set_ylabel(ylabel)

plt.show()

### SST-SSH

In [None]:
## settings for the sst-vs-ssh regression
kwargs = {"x_var": "ssh", "y_var": "sst"}


def get_slope(x):
    """Apply `src.utils.regress_proj` to each "time.month" group of `x`."""
    return x.groupby("time.month").map(src.utils.regress_proj, **kwargs)


## regression coefficient for each period
m_early = get_slope(anom_early)
m_late = get_slope(anom_late)

In [None]:
## shared settings for the SST-SSH comparison plots
kwargs = {
    "early": m_early,
    "late": m_late,
    "amp": 0.3,
    "amp_diff": 0.15,
    "label": r"$K~cm^{-1}$",
}

## maps of the mean over "month"
fig, axs = plot_spatial_comp(month=lambda x: x.mean("month"), **kwargs)

## hovmollers
fig, axs = plot_hov_comp(**kwargs)

## Change in zonal gradient

### Compute zonal gradient

In [None]:
def get_zonal_grad(x):
    """
    Compute the zonal gradient as an east-minus-west box difference.
    Ref: Fig 7 in Maher et al, 2023.

    Both boxes span latitude -5 to 5; the east box covers longitude 210-270
    and the west box 120-180. Each box is averaged with
    `src.utils.spatial_avg`.
    """

    lat_band = slice(-5, 5)
    lon_bands = {"east": slice(210, 270), "west": slice(120, 180)}

    ## box averages
    box_avg = {
        side: src.utils.spatial_avg(x.sel(longitude=lon_band, latitude=lat_band))
        for side, lon_band in lon_bands.items()
    }

    return box_avg["east"] - box_avg["west"]

In [None]:
## zonal gradient of the forced (ensemble-mean) signal
zonal_grad_forced = src.utils.reconstruct_fn(components, forced, fn=get_zonal_grad)

## split time into separate year and month dims
zonal_grad_forced_bymonth = src.utils.unstack_month_and_year(zonal_grad_forced)

## change relative to the mean over the first 30 entries along "year"
clim = zonal_grad_forced_bymonth.isel(year=slice(30)).mean("year")
zonal_grad_change = zonal_grad_forced_bymonth - clim

In [None]:
## one tall, narrow panel
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## filled contours of the sst gradient change: month on x, year on y
kwargs = {"levels": src.utils.make_cb_range(1, 0.1), "cmap": "cmo.balance"}
plot_data = ax.contourf(
    zonal_grad_change.month,
    zonal_grad_change.year,
    zonal_grad_change["sst"],
    **kwargs,
)

## labels, ticks, and a lower y-limit of 1950
ax.set_xlabel("Month")
ax.set_xticks([1, 6, 12])
ax.set_ylabel("Year")
ax.set_ylim([1950, None])
ax.set_title(r"$\Delta \left(\partial_x T\right)$")

plt.show()

## Look at mean state-dependence

### Compare Bjerknes growth rate to zonal gradient

#### Load $T,h$ data

In [None]:
## open the MPI T,h dataset
mpi_load_fp = DATA_FP / "mpi_Th" / "Th.nc"
Th = xr.open_dataset(mpi_load_fp)

#### Fit RO to $T,h$ data

In [None]:
## restrict to 1979-2024
Th_sub = Th.sel(time=slice("1979", "2024"))

## RO model (same settings as the earlier fit)
model = XRO(ncycle=12, ac_order=3, is_forward=True)

## fit with get_RO_ensemble, using T_3 and h_w; only the second output is kept
kwargs = {"model": model, "T_var": "T_3", "h_var": "h_w", "verbose": True}
_, fits = src.utils.get_RO_ensemble(Th_sub, **kwargs)

## pull the RO parameters out of the fits
params = model.get_RO_parameters(fits)

#### Look at intra-ensemble spread
$\partial_x T$ vs. BJ index  
Compare to Maher et al (2023)

In [None]:
## restrict anomalies to 1979-2024
anom_ = anom.sel(time=slice("1979", "2024"))

## zonal gradient at every time step, then averaged per calendar month
zonal_grad = src.utils.reconstruct_fn(components, anom_, fn=get_zonal_grad)
zonal_grad_by_month = zonal_grad.groupby("time.month").mean()

## scatter: month-mean sst gradient (x) vs. cycle-mean R (y)
grad_mean = zonal_grad_by_month["sst"].mean("month")
R_mean = params["R"].mean("cycle")

fig, ax = plt.subplots(figsize=(2, 2))
ax.scatter(grad_mean, R_mean, s=10)
ax.set_xlabel("Zonal SST gradient")
ax.set_ylabel("Bjerknes feedback")
plt.show()