# Mean-state changes
Look at how the mean state changes over time.

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings
import calendar
import pandas as pd
import cartopy.util

# import gsw

## set plotting specs (white axes background, no grid lines)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI
mpl.rcParams["figure.dpi"] = 100

## get filepaths from environment variables (KeyError if either is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Funcs

In [None]:
def get_ensemble_fn(data, fn):
    """Apply a reducing function across ensemble samples via xr.apply_ufunc.

    The "time" and "member" dimensions are stacked into one "sample"
    dimension, and ``fn`` reduces over that dimension, keeping "year".

    Parameters
    ----------
    data : xarray object with "time", "member" and "year" dimensions.
    fn : callable that reduces over axis 0 of a NumPy array (the NumPy/SciPy
        default, e.g. ``scipy.stats.skew``) and keeps the remaining axes.

    Returns
    -------
    xarray object with "time"/"member" reduced away and "year" kept.
    """

    ## stack data
    data_stack = data.stack(sample=["time", "member"])

    ## apply_ufunc moves the core dims to the end, so ``fn`` receives arrays
    ## shaped (..., sample, year). Move "sample" to the front so a function
    ## reducing over axis 0 reduces over samples even when extra broadcast
    ## dims are present (otherwise axis 0 would be a broadcast dim). For the
    ## plain (sample, year) case this is a no-op.
    def _fn_over_samples(arr):
        return fn(np.moveaxis(arr, -2, 0))

    ## kwargs for apply_ufunc
    kwargs = dict(
        input_core_dims=[["sample", "year"]],
        output_core_dims=[["year"]],
    )

    return xr.apply_ufunc(_fn_over_samples, data_stack, **kwargs)


def get_ensemble_fn_bymonth(data, fn):
    """Apply ``fn`` across ensemble samples separately for each calendar month.

    ``data`` is grouped by "time.month" and ``get_ensemble_fn`` is called on
    every group with the same ``fn``.
    """
    monthly_groups = data.groupby("time.month")
    return monthly_groups.map(get_ensemble_fn, fn=fn)


def get_nino3(x, z_t=70, lat_bound=5):
    """Average ``x`` over the Niño 3 longitude band (210-270°E).

    Extra reductions are applied only when the matching dimension exists:
    - "z_t": select the level nearest to ``z_t``;
    - "latitude": average between ``-lat_bound`` and ``lat_bound``.
    """
    has_depth = "z_t" in x.dims
    has_lat = "latitude" in x.dims

    out = x.sel(z_t=z_t, method="nearest") if has_depth else x

    if has_lat:
        lat_window = slice(-lat_bound, lat_bound)
        out = out.sel(latitude=lat_window).mean("latitude")

    nino3_lons = slice(210, 270)
    return out.sel(longitude=nino3_lons).mean("longitude")


def get_nino3_helper(x, lat_bounds):
    """Average ``x`` over Niño 3 longitudes (210-270°E) and a latitude band.

    ``lat_bounds`` is unpacked into ``slice`` to select latitudes. Unlike
    ``get_nino3``, no dimension checks are made: ``x`` is expected to have
    both "latitude" and "longitude" dimensions.
    """
    lat_window = slice(*lat_bounds)
    lon_window = slice(210, 270)
    region = x.sel(dict(latitude=lat_window, longitude=lon_window))
    return region.mean(["longitude", "latitude"])


def get_nino3_north(x):
    """Niño 3 longitude average over the 0 to 10 latitude band (north side)."""
    north_band = (0, 10)
    return get_nino3_helper(x, lat_bounds=north_band)


def get_nino3_south(x):
    """Niño 3 longitude average over the -10 to 0 latitude band (south side)."""
    south_band = (-10, 0)
    return get_nino3_helper(x, lat_bounds=south_band)


def save(fig, fname, dpi=300):
    """Save ``fig`` as a PDF in the "ch3-outline" subfolder of $SAVE_FP.

    Parameters
    ----------
    fig : figure-like object with a ``savefig`` method (e.g. matplotlib Figure).
    fname : file name without extension; ".pdf" is appended.
    dpi : resolution passed to ``fig.savefig``.

    Raises KeyError if the SAVE_FP environment variable is not set.
    """

    ## get save directory; create it if needed, because savefig does not
    ## create missing directories and would raise FileNotFoundError
    save_dir = pathlib.Path(os.environ["SAVE_FP"], "ch3-outline")
    save_dir.mkdir(parents=True, exist_ok=True)

    fig.savefig(save_dir / f"{fname}.pdf", dpi=dpi)

## Load data

In [None]:
## load spatial data
## NOTE(review): presumably `forced` is the forced (ensemble-mean) component
## and `anom_` the anomalies from it — confirm in src.utils.load_consolidated
forced, anom_ = src.utils.load_consolidated()

## get subset of data for total (anomaly + forced)
VARNAMES = ["T", "w", "pr", "tauy"]
total = anom_[VARNAMES] + forced[VARNAMES]
## keep the "<var>_comp" variables from `forced` alongside the totals
## (presumably required by src.utils.reconstruct_wrapper — confirm)
total = xr.merge([forced[[f"{v}_comp" for v in VARNAMES]], total])

## get nino3 averages
total_n3 = src.utils.reconstruct_wrapper(total, fn=get_nino3)

## rename to avoid conflicts
total_n3 = total_n3.rename({v: f"{v}_n3" for v in VARNAMES})

## get precip (and meridional wind stress) over Niño 3 longitudes, averaged
## over the 0-10 (north) and -10-0 (south) latitude bands
pr_N = src.utils.reconstruct_wrapper(
    total[["pr", "pr_comp", "tauy", "tauy_comp"]],
    fn=get_nino3_north,
).rename({"pr": "pr_n", "tauy": "tauy_n"})

pr_S = src.utils.reconstruct_wrapper(
    total[["pr", "pr_comp", "tauy", "tauy_comp"]],
    fn=get_nino3_south,
).rename({"pr": "pr_s", "tauy": "tauy_s"})

## merge
total_n3 = xr.merge([total_n3, pr_N, pr_S])

In [None]:
## get windowed data, then average by calendar month
forced = src.utils.get_windowed(forced).groupby("time.month").mean()

## reconstruct data
forced = src.utils.reconstruct_wrapper(forced)

## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total) and add the tropical SST averages
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")
Th_total = xr.merge([Th_total, trop_sst])

## custom h data (maximum-gradient thermocline definition)
## NOTE(review): h_mg_forced / h_mg_anom are not used by any visible cell;
## the thermocline "Compute" cell further down reloads the same data
h_mg_forced, h_mg_anom = src.utils.load_h_data(max_grad=True)

In [None]:
## compute dTdx: Niño 4 minus Niño 3 temperature (west minus east)
Th_total["dTdx"] = Th_total["T_4"] - Th_total["T_3"]

## load ELI and merge with the Niño 3 averages computed above
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))
Th_total = xr.merge([Th_total, eli, total_n3])


## windowed version
Th_total = src.utils.get_windowed(Th_total)

## relative SST: each Niño index minus the tropical-mean SST "trop_sst_15"
for v in ["T_3", "T_34", "T_4"]:
    Th_total[f"{v}_rel"] = Th_total[v] - Th_total["trop_sst_15"]

## ensemble mean for each calendar month (averaging over time and member)
Th_total_forced = Th_total.groupby("time.month").mean(["time", "member"])

## merge data
forced = xr.merge([forced, Th_total_forced])

In [None]:
## compute quantiles by season, pooling time and ensemble members
## NOTE(review): `quant`, `delta_q`, `spread` and `delta_spread` are all
## overwritten (with more quantile levels) in the "Compute quantiles" cell
quant = Th_total.groupby("time.season").quantile(
    q=[0.05, 0.25, 0.5, 0.75, 0.95], dim=["time", "member"]
)
quant = quant.rename({"quantile": "q"})

## compute difference relative to the first window (year index 0)
delta_q = quant - quant.isel(year=0)

## get spread (95th minus 5th percentile) and its change
spread = quant.sel(q=0.95) - quant.sel(q=0.05)
delta_spread = spread - spread.isel(year=0)

#### Compute variability (standard deviation)

In [None]:
## open data
Th = src.utils.load_cesm_indices()

## get windowed
Th = src.utils.get_windowed(Th)
## add ELI anomalies: total ELI minus its ensemble mean
for v in ["05", "15", "20", "25", "30"]:
    Th[f"eli_{v}"] = Th_total[f"eli_{v}"] - Th_total[f"eli_{v}"].mean("member")

## get standard deviation by month (pooling time and members)
Th_sigma = Th.groupby("time.month").std(["time", "member"])

## fractional change in standard deviation relative to the first window
Th_delta_sigma = (Th_sigma - Th_sigma.isel(year=0)) / Th_sigma.isel(year=0)

#### Compute skewness

In [None]:
## skewness by month, pooling time and ensemble members into one sample
Th_skew = get_ensemble_fn_bymonth(Th, fn=scipy.stats.skew)
## change in skewness relative to the first window
Th_delta_skew = Th_skew - Th_skew.isel(year=0)

#### Compute quantiles

In [None]:
## compute quantiles by season (pooling time and members); this overwrites
## the earlier `quant` and adds the 2/15/85/98% levels
quant = Th_total.groupby("time.season").quantile(
    q=[0.02, 0.05, 0.15, 0.25, 0.5, 0.75, 0.85, 0.95, 0.98], dim=["time", "member"]
)
quant = quant.rename({"quantile": "q"})

## compute difference relative to the first window
delta_q = quant - quant.isel(year=0)

## get spread (95th minus 5th percentile) and its change
spread = quant.sel(q=0.95) - quant.sel(q=0.05)
delta_spread = spread - spread.isel(year=0)

## Analysis

### Spatial plots

In [None]:
## forced state for the three windows used in the spatial plots
forced0 = forced.sel(year=1870)
forced1 = forced.sel(year=1980)
forced2 = forced.sel(year=2085)

#### SST

In [None]:
## left column: relative SST (sst minus tropical mean); right column: its
## change relative to an earlier window
v0 = "sst"
v1 = "trop_sst_15"

import cartopy.crs as ccrs

## average over Dec, Jan, Feb
sel_ = lambda x: x.sel(month=[12, 1, 2]).mean("month")

fig = plt.figure(figsize=(10, 4.375), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=30)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## specify kwargs (shared by both columns, including contour levels)
plot_kwargs = dict(
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(10, 1),
    extend="both",
    transform=ccrs.PlateCarree(),
)

## row i shows window f; its difference is taken against window f_
## (row 0 compares forced0 with itself, so that difference panel is zero)
for i, f, f_ in zip(
    np.arange(3), [forced0, forced1, forced2], [forced0, forced0, forced1]
):

    ## plot relative SST
    axs[i, 0].contourf(
        f.longitude,
        f.latitude,
        sel_(f[v0] - f[v1]),
        **plot_kwargs,
    )

    ## difference, scaled by 5 (drawn with the same levels as the left column)
    axs[i, 1].contourf(
        f.longitude,
        f.latitude,
        5 * (sel_(f[v0] - f[v1]) - sel_(f_[v0] - f_[v1])),
        **plot_kwargs,
    )


plt.show()

#### precip

In [None]:
## left column: precipitation; right column: change vs. an earlier window
v0 = "pr"

import cartopy.crs as ccrs

## average over Dec, Jan, Feb
sel_ = lambda x: x.sel(month=[12, 1, 2]).mean("month")

fig = plt.figure(figsize=(10, 4.375), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=30)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## specify kwargs
plot_kwargs = dict(
    extend="both",
    transform=ccrs.PlateCarree(),
)

## row i shows window f; its difference is taken against window f_
for i, f, f_ in zip(
    np.arange(3), [forced0, forced1, forced2], [forced0, forced0, forced1]
):

    ## plot precipitation, scaled by 8.6e4 (~seconds per day; presumably
    ## converting a per-second rate to per-day — confirm units of "pr")
    axs[i, 0].contourf(
        f.longitude,
        f.latitude,
        8.6e4 * sel_(f[v0]),
        cmap="cmo.rain",
        levels=np.arange(0, 16, 2),
        **plot_kwargs,
    )

    ## difference
    axs[i, 1].contourf(
        f.longitude,
        f.latitude,
        8.6e4 * (sel_(f[v0] - f_[v0])),
        cmap="cmo.balance_r",
        levels=src.utils.make_cb_range(4, 0.4),
        **plot_kwargs,
    )


plt.show()

#### wind

In [None]:
# ## variable to plot
# v0 = "taux"
# SC=8e1

## variable to plot; SC is a display scale factor (a negative SC flips sign)
v0 = "tauy"
SC = -16e1

import cartopy.crs as ccrs

## select March
sel_ = lambda x: x.sel(month=[3]).mean("month")

fig = plt.figure(figsize=(10, 4.375), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=30)
axs = src.utils.subplots_with_proj(fig, nrows=3, ncols=2, format_func=format_func)

## specify kwargs
plot_kwargs = dict(
    extend="both",
    transform=ccrs.PlateCarree(),
)

## row i shows window f; its difference is taken against window f_
for i, f, f_ in zip(
    np.arange(3), [forced0, forced1, forced2], [forced0, forced0, forced1]
):

    ## plot scaled wind stress
    axs[i, 0].contourf(
        f.longitude,
        f.latitude,
        SC * sel_(f[v0]),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(10, 1),
        **plot_kwargs,
    )

    ## difference
    axs[i, 1].contourf(
        f.longitude,
        f.latitude,
        SC * (sel_(f[v0] - f_[v0])),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(2, 0.2),
        **plot_kwargs,
    )

    ## mark the equator
    axs[i, 1].axhline(0, ls="--", c="w")


plt.show()

### Precip / $\tau_x$ on the equator

Absolute

In [None]:
## specify month and lat bound for averaging
MONTH = 5
LAT_BOUND = 5

## func to select data: meridional mean over +/- LAT_BOUND, month MONTH,
## with "year" as the leading dim (rows of the contour plot)
merimean = lambda x, lat_bound=5: x.sel(latitude=slice(-lat_bound, lat_bound)).mean(
    "latitude"
)
sel = lambda x: merimean(x, lat_bound=LAT_BOUND).sel(month=MONTH).transpose("year", ...)

## get longitude for convenience
LON = forced.longitude

## compute zonal center of mass over 140-280E for every variable:
## integral(x * lon) / integral(x)
forced_pac = forced.sel(longitude=slice(140, 280))
com = sel(forced_pac * LON).integrate("longitude") / sel(forced_pac).integrate(
    "longitude"
)

## get longitude where taux is minimum (most negative)
taux_argmin = forced_pac.longitude.isel(
    longitude=sel(forced_pac["taux"]).argmin("longitude")
)

fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## precip (scaled by 8.6e4)
axs[0].contourf(
    forced.longitude,
    forced.year,
    8.6e4 * sel(forced["pr"]),
    cmap="cmo.rain",
    levels=np.arange(0, 13.5, 1.5),
)

## wind
axs[1].contourf(
    forced.longitude,
    forced.year,
    sel(forced["taux"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.075, 0.0075),
)

## plot longitude of minimum taux
axs[1].plot(
    taux_argmin,
    forced.year,
    c="w",
    ls="--",
    alpha=1,
    lw=1,
)

## plot zonal center of mass of precipitation
axs[0].plot(
    com["pr"],
    forced.year,
    c="w",
    ls="--",
    alpha=1,
    lw=1,
)


## formatting
for ax in axs:
    ax.set_xlim([140, 280])
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()

Change

In [None]:
## specify month and lat bound for averaging
MONTH = 5
LAT_BOUND = 5

## func to select data: fractional change relative to the first window
## (year index 0), meridionally averaged over +/- LAT_BOUND, for month MONTH,
## with "year" as the leading dim (rows of the contour plot)
merimean = lambda x, lat_bound=5: x.sel(latitude=slice(-lat_bound, lat_bound)).mean(
    "latitude"
)
diff = lambda x: (x - x.isel(year=0)) / x.isel(year=0)
sel = (
    lambda x: merimean(diff(x), lat_bound=LAT_BOUND)
    .sel(month=MONTH)
    .transpose("year", ...)
)

fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## shared args
kwargs = dict(
    levels=src.utils.make_cb_range(1, 0.1),
    extend="both",
)

## precip (fractional change)
axs[0].contourf(
    forced.longitude,
    forced.year,
    sel(forced["pr"]),
    cmap="cmo.balance_r",
    **kwargs,
)

## wind (fractional change)
axs[1].contourf(
    forced.longitude,
    forced.year,
    sel(forced["taux"]),
    cmap="cmo.balance",
    **kwargs,
)


## formatting
for ax in axs:
    ax.set_xlim([140, 280])
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()

In [None]:
## specify month for averaging
## NOTE(review): MONTH is not used below — `sel` hard-codes months [2, 3, 4]
MONTH = 2

## func to select data: mean over 210-270E, averaged over Feb-Apr, with
## "year" as the leading dim (rows of the contour plot)
zonamean = lambda x: x.sel(longitude=slice(210, 270)).mean("longitude")
sel = lambda x: zonamean(x).sel(month=[2, 3, 4]).mean("month").transpose("year", ...)


fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## precip (scaled by 8.6e4) vs. latitude and year
axs[0].contourf(
    forced.latitude,
    forced.year,
    8.6e4 * sel(forced["pr"]),
    cmap="cmo.rain",
    levels=np.arange(0, 10, 1),
    extend="both",
)

## meridional wind stress: negated change relative to the first window
axs[1].contourf(
    forced.latitude,
    forced.year,
    -sel(forced["tauy"]) + sel(forced["tauy"]).isel(year=0),
    cmap="cmo.balance",
    # levels=src.utils.make_cb_range(5e-2, 5e-3),
    levels=src.utils.make_cb_range(1e-2, 1e-3),
    extend="both",
)

## formatting: mark the equator
for ax in axs:
    ax.axvline(0, c="w", ls="--", lw=0.8)
    ax.set_xlim([-15, 15])
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()

Precipitation

In [None]:
## precip (scaled by 8.6e4) averaged over 210-270E, vs. month and latitude
y = 8.6e4 * forced["pr"].sel(longitude=slice(210, 270)).mean("longitude")
# y = 8.6e4 * forced["pr"].sel(longitude=slice(150, 220)).mean("longitude")
# y = 8.6e4 * forced["pr"].sel(longitude=slice(190, 240)).mean("longitude")
# y = 3e2 * forced["taux"].sel(longitude=slice(190,240)).mean("longitude")

## windows to compare
Y0 = 1870
Y1 = 2080

fig, axs = plt.subplots(3, 1, figsize=(4, 8))
## top: Y0 seasonal cycle
axs[0].contourf(
    y.month,
    y.latitude,
    y.sel(year=Y0).transpose("latitude", ...),
    cmap="cmo.rain",
    levels=np.arange(0, 15, 1),
    extend="both",
)

## middle: Y1 seasonal cycle
axs[1].contourf(
    y.month,
    y.latitude,
    y.sel(year=Y1).transpose("latitude", ...),
    cmap="cmo.rain",
    levels=np.arange(0, 15, 1),
    extend="both",
)

## bottom: Y1 minus Y0
axs[2].contourf(
    y.month,
    y.latitude,
    (y.sel(year=Y1) - y.sel(year=Y0)).transpose("latitude", ...),
    cmap="cmo.balance_r",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

## formatting: restrict to 20S-20N and mark +/- 5 degrees latitude
for ax in axs:
    ax.set_ylim([-20, 20])
    for t in [-5, 5]:
        ax.axhline(t, ls="--", c="gray", lw=0.8)

$\tau_x$

In [None]:
## zonal wind stress (scaled by 3e2) averaged over LON_RANGE
LON_RANGE = slice(190, 240)
y = 3e2 * forced["taux"].sel(longitude=LON_RANGE).mean("longitude")

## windows to compare
Y0 = 1870
Y1 = 2080

fig, axs = plt.subplots(3, 1, figsize=(4, 8))
## top two panels: seasonal cycle for Y0 and Y1
for ax, year in zip(axs[:2], [Y0, Y1]):
    ax.contourf(
        y.month,
        y.latitude,
        y.sel(year=year).transpose("latitude", ...),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(40, 4),
        extend="both",
    )

## bottom panel: Y1 minus Y0
axs[2].contourf(
    y.month,
    y.latitude,
    (y.sel(year=Y1) - y.sel(year=Y0)).transpose("latitude", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(10, 1),
    extend="both",
)

## formatting: restrict to 20S-20N and mark +/- 5 degrees latitude
for ax in axs:
    ax.set_ylim([-20, 20])
    for t in [-5, 5]:
        ax.axhline(t, ls="--", c="gray", lw=0.8)

$\tau_y$

In [None]:
## meridional wind stress over 210-270E, scaled by -2e2 (sign flipped)
y = -2e2 * forced["tauy"].sel(longitude=slice(210, 270)).mean("longitude")
# y = 2e2*forced["taux"].sel(longitude=slice(210,270)).mean("longitude")


fig, axs = plt.subplots(2, 1, figsize=(5, 6))
## top: 1870 seasonal cycle
axs[0].contourf(
    y.month,
    y.latitude,
    y.sel(year=1870).transpose("latitude", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(15, 1.5),
    extend="both",
)

## bottom: 1975 minus 1870
axs[1].contourf(
    y.month,
    y.latitude,
    (y.sel(year=1975) - y.sel(year=1870)).transpose("latitude", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(1.5, 0.15),
    extend="both",
)

for ax in axs:
    ax.set_ylim([-20, 20])

In [None]:
## February data, with members and times stacked into one sample dim "s"
y = src.utils.sel_month(Th_total, [2]).stack(s=["member", "time"])

## NOTE: only axs[0] is used; axs[1] is left empty
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3))
## 1870 window: negated southern-band tauy vs. (northern-band precip minus
## Niño 3 precip)
axs[0].scatter(
    # y["T_3"].sel(year=1870),
    -y["tauy_s"].sel(year=1870),
    (y["pr_n"] - y["pr_n3"]).sel(year=1870),
    s=5,
)

# axs[0].scatter(
#     # y["T_3"].sel(year=1970),
#     -y["tauy_s"].sel(year=1970),
#     (y["pr_n"]-y["pr_n3"]).sel(year=1970),
#     s=5,
# )
# src.utils.set_lims(axs)
plt.show()

### Stratification by season over time

In [None]:
def get_dTdz_sub(Tsub, mld, delta=25):
    """Get a bulk vertical temperature gradient below the mixed layer.

    Computes ``-dT / mld``, where ``dT`` is the temperature difference
    returned by ``src.utils.get_dT_sub`` for temperature ``Tsub``,
    mixed-layer depth ``mld`` and offset ``delta``.

    NOTE(review): the sign convention and units of ``dT`` are defined in
    ``src.utils.get_dT_sub`` (not visible here) — confirm before
    interpreting positive vs. negative values.
    """

    ## get temperature difference
    dT = src.utils.get_dT_sub(Tsub=Tsub, Hm=mld, delta=delta)

    ## get gradient (difference divided by mixed-layer depth)
    dTdz = -dT / mld

    return dTdz

In [None]:
## compute strat
dTdz = get_dTdz_sub(Tsub=forced.T, mld=50, delta=20)

## average over Niño 3
dTdz_n3 = dTdz.sel(longitude=slice(210, 270)).mean("longitude")

## get fractional change
get_frac_change = lambda x: (x - x.isel(year=0)) / x.isel(year=0)


## make plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## precip
ax.contourf(
    dTdz_n3.month,
    dTdz_n3.year,
    # (dTdz_n3 - dTdz_n3.isel(year=0)).transpose("year", ...),
    get_frac_change(dTdz_n3).transpose("year", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(1.5, 0.15),
    extend="both",
)

## formatting
ax.axvline(5, ls="--", lw=0.8, c="k")

plt.show()

$w$ by season

In [None]:
## compute strat
## NOTE(review): dTdz / dTdz_n3 are recomputed exactly as in the previous cell
## and are not used in this cell's plot
dTdz = get_dTdz_sub(Tsub=forced.T, mld=50, delta=20)

## average over Niño 3
dTdz_n3 = dTdz.sel(longitude=slice(210, 270)).mean("longitude")

## make plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## fractional change in Niño 3 vertical velocity (w_n3), vs. month and year
## (get_frac_change is defined in the previous cell)
ax.contourf(
    forced.month,
    forced.year,
    get_frac_change(forced["w_n3"]).transpose("year", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.5, 0.05),
    extend="both",
)

## formatting: mark month 5
ax.axvline(5, ls="--", lw=0.8, c="k")

plt.show()

$w\frac{dT}{dz}$ by season

### Timeseries

In [None]:
## specify reference coordinate and inflection pt
ref = forced["trop_sst_15"]
# ref = forced["T_3"] - forced["trop_sst_15"]
ref_year = 1980

## specify plot style for reference line
ref_kwargs = dict(ls="--", c="gray", lw=0.8)

#### Tropical SST over time

In [None]:
## annual mean (average over months)
sel = lambda x: x.mean("month")
# sel = lambda x : x.sel(month=12)

## tropical-mean SST and the Niño 4 / Niño 3 indices over time
fig, ax = plt.subplots(figsize=(4, 3))
# ax.plot(forced.year, sel(forced["trop_sst_15"]), label="tropical SST", c="k")
ax.plot(forced.year, sel(ref), label="tropical SST", c="k")
ax.plot(forced.year, sel(forced["T_4"]), label="Niño 4", c="r")
ax.plot(forced.year, sel(forced["T_3"]), label="Niño 3", c="b")
ax.axvline(ref_year, **ref_kwargs)
ax.legend()
plt.show()

#### Mean state quantities, plotted against tropical SST

In [None]:
## select January (later cells that call `sel` reuse this definition)
sel = lambda x: x.sel(month=[1]).mean("month")
# sel = lambda x : x.mean("month")

## zonal SST difference (Niño 4 minus Niño 3) over time
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(forced.year, sel(forced["T_4"] - forced["T_3"]))
# ax.plot(sel(ref), sel(forced["T_4"] - forced["T_3"]))
# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.axvline(ref_year, **ref_kwargs)
ax.set_title(r"$\frac{dT}{dx}$")
plt.show()

## precip difference between the northern (0-10) and southern (-10-0) bands
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(forced.year, 8.6e4 * sel(forced["pr_n"] - forced["pr_s"]))
# ax.plot(sel(ref), 8.6e4 * sel(forced["pr_n"] - forced["pr_s"]))
# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"$\frac{dp}{dy}$")
plt.show()

## southern-band meridional wind stress (scaled by -2e2), plotted against `ref`
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(sel(ref), -2e2 * sel(get_nino3_south(forced["tauy"])))
ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"$\tau_y$")
plt.show()


## Niño 3 precip (scaled by 8.6e4), plotted against `ref`
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(sel(ref), 8.6e4 * sel(forced["pr_n3"]))
ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"precip")
plt.show()

## Niño 3 vertical velocity over time
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(forced.year, sel(forced["w_n3"]))
# ax.plot(sel(ref), sel(forced["w_n3"]))
# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.axvline(ref_year, **ref_kwargs)
ax.set_title(r"$w$")
plt.show()

## Niño 3 stratification (dTdz_n3) over time
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(forced.year, sel(dTdz_n3))
# ax.plot(sel(ref), sel(dTdz_n3))
# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.axvline(ref_year, **ref_kwargs)
ax.set_title(r"$\frac{dT}{dz}$")
plt.show()


## Niño 3 SST relative to the tropical mean, plotted against `ref`
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(sel(ref), sel(forced["T_3"] - forced["trop_sst_15"]))
ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"Niño 3 relative SST")
# ax.set_ylim([0,None])
plt.show()

## difference between the "trop_sst_15" and "trop_sst_05" averages over time
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(forced.year, sel(forced["trop_sst_15"] - forced["trop_sst_05"]))
# ax.plot(sel(ref), sel(forced["trop_sst_15"] - forced["trop_sst_05"]))
# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"trop sst (15 minus 5)")
ax.axvline(ref_year, **ref_kwargs)
# ax.set_ylim([0,None])

plt.show()

#### Variance over time

In [None]:
## specify which variable to plot
# VARNAME = "eli_30"
VARNAME = "T_3"

## make colormap (one color per plotted month)
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### plot fractional change in std. dev. against `ref`
## NOTE(review): `sel` is whatever was assigned last (the January selector
## from the mean-state timeseries cell when cells run in order)
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):
    sel_ = lambda x: x.sel(month=m_)
    ax.plot(
        sel(ref),
        sel_(Th_delta_sigma[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )


ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"Niño variance")
# ax.set_ylim([0,None])
plt.show()

## same quantity, plotted against year
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):
    sel_ = lambda x: x.sel(month=m_)
    ax.plot(
        Th_delta_sigma.year,
        sel_(Th_delta_sigma[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )


# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"Niño variance")
ax.axvline(ref_year, **ref_kwargs)
ax.axvline(2030, **ref_kwargs)
# ax.set_ylim([0,None])
ax.legend()
plt.show()

PDFs

In [None]:
## specify varname
# VARNAME="eli_30"
VARNAME = "eli_30"

## get December data, stacking members and times into one sample dim "s"
x = src.utils.sel_month(Th_total[VARNAME], [12]).stack(s=["member", "time"])


## compute PDFs on fixed bin edges for the 1870 and 2085 windows
edges = np.arange(120, 285, 2)
pdf0, _ = src.utils.get_empirical_pdf(x.sel(year=1870), edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(x.sel(year=2085), edges=edges)

fig, ax = plt.subplots(figsize=(4, 3))
ax.stairs(pdf0, edges)
ax.stairs(pdf1, edges)
ax.set_xlim([180, 215])
plt.show()

In [None]:
## specify varname
# VARNAME="eli_30"
VARNAME = "pr_n3"

## get December data, stacking members and times into one sample dim "s"
x = src.utils.sel_month(Th_total[VARNAME], 12).stack(s=["member", "time"])


## compute PDFs of precip (scaled by 8.6e4) for the 1870 and 2085 windows
edges = np.arange(0, 30, 0.5)
pdf0, _ = src.utils.get_empirical_pdf(8.6e4 * x.sel(year=1870), edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(8.6e4 * x.sel(year=2085), edges=edges)

fig, ax = plt.subplots(figsize=(4, 3))
ax.stairs(pdf0, edges)
ax.stairs(pdf1, edges)
ax.set_xlim([0, 15])
plt.show()

#### Skewness over time

In [None]:
## specify which variable to plot
VARNAME = "eli_30"

## make colormap (one color per plotted month)
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### plot change in skewness over time, one line per month
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):
    sel_ = lambda x: x.sel(month=m_)
    ax.plot(
        # sel(ref),
        Th_delta_skew.year,
        sel_(Th_delta_skew[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )


# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.axhline(0, **ref_kwargs)
ax.set_title(r"Niño skewness")

ax.legend()

# ax.set_ylim([0,None])
plt.show()

#### Quantiles over time

In [None]:
## specify which variable and season to plot
VARNAME = "pr_n3"
SEASON = "JJA"
sel_ = lambda x: x[VARNAME].sel(season=SEASON)
# diff_ = lambda x : x-x.isel(year=0)
# prep = lambda x : diff_(sel_(x))
prep = lambda x: sel_(x)

## specify quantile to plot (upper tail Q, lower tail 1 - Q)
Q = 0.95

#### plot quantiles over time
fig, ax = plt.subplots(figsize=(4, 3))

## warm (upper quantile)
ax.plot(
    # sel(ref),
    quant.year,
    prep(quant).sel(q=Q),
    c="r",
)

## median
ax.plot(
    # sel(ref),
    quant.year,
    prep(quant).sel(
        q=0.5,
    ),
    c="k",
)

## cold (lower quantile); method="nearest" is needed because 1 - Q is not
## exactly equal to the stored level 0.05 in floating point
ax.plot(
    # sel(ref),
    quant.year,
    prep(quant).sel(q=1 - Q, method="nearest"),
    c="b",
)

# ax.axhline(0, **ref_kwargs)
ax.set_title(r"Niño quantiles")

# ax.legend()

# ax.set_ylim([0,None])
plt.show()

#### Quantile spread over time

In [None]:
## specify which variable to plot
## NOTE(review): SEASON is not used in this cell (every season is plotted)
VARNAME = "pr_n3"
SEASON = "DJF"

## make colormap (one color per season)
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### plot change in the 5-95% spread over time, one line per season
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(quant.season):
    # sel_ = lambda x: x.sel(season=m_, q=.95) - x.sel(season=m_, q=.05)
    sel_ = lambda x: x.sel(season=m_)
    ax.plot(
        # sel(ref),
        quant.year,
        sel_(delta_spread[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"{m_.values.item()}",
    )


# ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
# ax.axhline(0, **ref_kwargs)
ax.set_title(r"Niño quantiles")

ax.legend()

# ax.set_ylim([0,None])
plt.show()

##### lower or upper width by season

In [None]:
## specify which variable to plot
VARNAME = "pr_n3"

## compute lower (median minus 5%) and upper (95% minus median) widths
## NOTE(review): only width_lower is plotted below
width_lower = quant.sel(q=0.5) - quant.sel(q=0.05)
width_upper = quant.sel(q=0.95) - quant.sel(q=0.5)

## make colormap (one color per season)
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### plot lower width over time, one line per season
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(quant.season):
    ax.plot(
        # sel(ref),
        quant.year,
        width_lower[VARNAME].sel(season=m_),
        c=CMAP(CMAP_NORM(i)),
        label=f"{m_.values.item()}",
    )


ax.set_title(r"Niño quantiles")

ax.legend()

# ax.set_ylim([0,None])
plt.show()

##### warm vs. cold

Compute PDF widths

In [None]:
## specify which variable to plot
VARNAME = "T_34"
# VARNAME = "pr_n3"
# VARNAME = "eli_30"
SEASON = "DJF"

## specify width bound (quantile level of the tails)
wb = 0.05

## get func to select data
sel_ = lambda x: x[VARNAME].sel(season=SEASON)

## compute lower and upper width (distance of each tail from the median)
## and their average
width_lower = quant.sel(q=0.5) - quant.sel(q=wb)
width_upper = quant.sel(q=1 - wb) - quant.sel(q=0.5)
width_mean = 0.5 * (width_lower + width_upper)

Compute PDFs

In [None]:
## get data and subtract median
x = src.utils.sel_month(Th[VARNAME].resample({"time": "QS-DEC"}).mean(), 12)
x = x.stack(s=["member", "time"])
x = x - x.median("s")

## compute PDFs
edges = np.linspace(-4.5, 4.5, 20)
pdf0, _ = src.utils.get_empirical_pdf(x.sel(year=1870).values, edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(x.sel(year=2010).values, edges=edges)
pdf2, _ = src.utils.get_empirical_pdf(x.sel(year=2080).values, edges=edges)

Plot

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3), layout="constrained")

## plot PDFs (median removed) for the three windows
axs[0].stairs(pdf0, edges, fill=True, alpha=0.1, color="k", label=1870)
axs[0].stairs(pdf1, edges, lw=2, label=2010)
axs[0].stairs(pdf2, edges, lw=2, label=2080)
# axs[0].set_xlim([-4.5, 4.3])
axs[0].axvline(ls="--", c="gray", lw=0.8)
axs[0].legend()
## superscript degree sign (same mathtext as the right panel's y-label)
axs[0].set_xlabel(r"$^{\circ}$C")
axs[0].set_ylabel("Density")
axs[0].set_yticks([])


## plot width of distribution: cold (lower) tail, warm (upper) tail, average
axs[1].plot(width_lower.year, sel_(width_lower), c="blue", label="Cold")
axs[1].plot(width_lower.year, sel_(width_upper), c="r", label="Warm")
axs[1].plot(width_lower.year, sel_(width_mean), c="k", label="Avg.")


# ax.set_title(r"Distribution width (95% minus median)")
axs[1].set_ylabel(r"$^{\circ}$C")
axs[1].set_xlabel(r"Year")
axs[1].set_yticks([2, 2.5, 3])
axs[1].set_xticks([1870, 2010, 2080])
axs[1].axvline(2010, ls="--", lw=0.8, c="gray")
axs[1].legend()
axs[1].yaxis.tick_right()
axs[1].yaxis.set_label_position("right")

## save
# save(fig, "pdf-over-time", dpi=300)

plt.show()

#### Precip

Compare the latitude of maximum precipitation (argmax) between periods (PDF/scatter)

In [None]:
## restrict to 20S-20N and average over Niño 3 longitudes (210-270E)
sel = lambda x: x.sel(latitude=slice(-20, 20), longitude=slice(210, 270)).mean(
    "longitude"
)

## windowed precip (stride=120), keeping four windows
pr_cross = src.utils.get_windowed(total[["pr", "pr_comp"]], stride=120)
pr_cross = pr_cross.sel(year=[1870, 1980, 2030, 2080])

## reconstruct with `sel` applied, then keep only precip
pr_cross = src.utils.reconstruct_wrapper(pr_cross, fn=sel)
pr_cross = pr_cross["pr"]

In [None]:
## subset data
x0 = pr_cross.sel(year=1870)
x1 = pr_cross.sel(year=1980)

## funcs to get quantiles
get_quantiles = lambda x: x.quantile(
    q=[0.05, 0.25, 0.5, 0.75, 0.95], dim=["member", "time"]
)
get_quantiles_bymonth = lambda x: x.groupby("time.month").map(get_quantiles)

## compute
q0 = get_quantiles_bymonth(x0)
q1 = get_quantiles_bymonth(x1)

In [None]:
import xeofs as xe

## fit 10 EOF modes (with cos-latitude weighting) to January precip of the
## 1980 window, treating members and times as samples
eofs = xe.single.EOF(use_coslat=True, n_modes=10)
eofs.fit(src.utils.sel_month(x1, 1), dim=["member", "time"])

In [None]:
## project both windows onto the EOF patterns (sum over latitude)
## NOTE(review): x0/x1 contain all months, while the EOFs were fit on January only
x0_proj = (eofs.components() * x0).sum("latitude")
x1_proj = (eofs.components() * x1).sum("latitude")

# plt.plot(x0_proj.std(["member","time"]))
# plt.plot(x1_proj.std(["member","time"]))

In [None]:
## EOF patterns for modes 0 and 2 vs. latitude
fig, ax = plt.subplots(figsize=(3, 2.5))
ax.plot(x0.latitude, eofs.components().isel(mode=0))
ax.plot(x0.latitude, eofs.components().isel(mode=2))
ax.axvline(0, **ref_kwargs)
plt.show()

In [None]:
## scatter of the mode-0 vs. mode-2 scores
plt.scatter(
    eofs.scores().isel(mode=0),
    eofs.scores().isel(mode=2),
)

Compute EOFs.

In [None]:
fig, ax = plt.subplots(figsize=(3, 4))

## 95th percentile of January precip vs. latitude: 1870 (q0)
# ax.plot(q0.latitude, q0.sel(quantile=.75, month=1))
ax.plot(q0.latitude, q0.sel(quantile=0.95, month=1))

## ... and 1980 (q1), on the same latitude axis
# ax.plot(q0.latitude, q1.sel(quantile=.75, month=1))
ax.plot(q0.latitude, q1.sel(quantile=0.95, month=1))

ax.axvline(0, **ref_kwargs)
ax.set_xlim([-10, 10])

In [None]:
## specify month
MONTH = 2

## specify which metric to use
USE_COM = False

## get center of mass, and restrict to month/lat range
sel = lambda x: src.utils.sel_month(x.sel(latitude=slice(-20, 20)), MONTH)

x0_ = sel(pr_cross.sel(year=1870))
x1_ = sel(pr_cross.sel(year=1980))

stack = lambda x: x.stack(sample=["time", "member"])

## compute
get_com = lambda x: (x * x.latitude).sum("latitude") / x.sum("latitude")
com0 = get_com(x0_)
com1 = get_com(x1_)

#### lat of max precip
LAT = x0_.latitude
max_lat0 = LAT.isel(latitude=x0_.argmax("latitude"))
max_lat1 = LAT.isel(latitude=x1_.argmax("latitude"))


## make pdfs
if USE_COM:

    edges = np.arange(-10, 11, 1)
    pdf0, _ = src.utils.get_empirical_pdf(stack(com0), edges=edges)
    pdf1, _ = src.utils.get_empirical_pdf(stack(com1), edges=edges)

else:

    edges = np.append(LAT.values - 1.25, LAT.values[-1] + 1.25)
    pdf0, _ = src.utils.get_empirical_pdf(stack(max_lat0), edges=edges)
    pdf1, _ = src.utils.get_empirical_pdf(stack(max_lat1), edges=edges)

fig, ax = plt.subplots(figsize=(2.5, 2.5))
ax.stairs(pdf0, edges, fill=True, alpha=0.2)
ax.stairs(pdf1, edges, lw=2)
ax.set_xlim([-15, 15])
plt.show()

In [None]:
## NOTE(review): `idx` is not used below
idx = dict(time=x1_.time, year=x1_.year)

fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5), layout="constrained")

## ELI (eli_05) vs. precip center of mass, 1870 window
axs[0].scatter(
    stack(Th_total["eli_05"].sel(time=x0_.time, year=x0_.year)),
    stack(com0),
    # stack(max_lat0),
    s=1,
    alpha=0.5,
)

## same for the 1980 window
axs[1].scatter(
    stack(Th_total["eli_05"].sel(time=x1_.time, year=x1_.year)),
    stack(com1),
    # stack(max_lat1),
    s=1,
    alpha=0.5,
)
axs[1].set_yticks([])

src.utils.set_lims(axs)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5), layout="constrained")

## ELI (eli_05) vs. Niño 3 SST (T_3), 1870 window. Use x0_'s times so the
## time and year selections refer to the same window (as in the ELI vs.
## center-of-mass scatter).
axs[0].scatter(
    Th_total["eli_05"].sel(time=x0_.time, year=x0_.year),
    Th_total["T_3"].sel(time=x0_.time, year=x0_.year),
    s=3,
    alpha=0.5,
)

## same for the 1980 window
axs[1].scatter(
    Th_total["eli_05"].sel(time=x1_.time, year=x1_.year),
    Th_total["T_3"].sel(time=x1_.time, year=x1_.year),
    s=3,
    alpha=0.5,
)
axs[1].set_yticks([])

src.utils.set_lims(axs)

Try with EOFs.

## Figs

In [None]:
## color palette (colorblind-friendly color cycle)
sns.set_palette("colorblind")

## specify plot style for reference line
ref_kwargs = dict(ls="--", c="gray", lw=0.8)

Function to save figures (see `save` in the Funcs section)

### Mean state

#### Timeseries

##### Preprocessing funcs

In [None]:
def lon_avg(x, lon_range):
    """Return the mean of ``x`` over the longitude window ``slice(*lon_range)``."""
    lon_window = slice(*lon_range)
    subset = x.sel(longitude=lon_window)
    return subset.mean("longitude")


def get_H_metrics(x, lons_e, lons_w):
    """Compute zonal summary metrics for the thermocline field ``x``.

    Returns a Dataset with:
    - "Hw": longitude mean over the western window ``lons_w``;
    - "He": longitude mean over the eastern window ``lons_e``;
    - "Hbar": longitude mean from ``lons_w[0]`` to ``lons_e[1]`` (assumes
      the western window lies west of the eastern one);
    - "dHdx": east minus west, ``He - Hw``.
    """
    west = lon_avg(x, lons_w).rename("Hw")
    east = lon_avg(x, lons_e).rename("He")
    basin = lon_avg(x, (lons_w[0], lons_e[1])).rename("Hbar")

    metrics = xr.merge([west, east, basin])
    metrics["dHdx"] = metrics["He"] - metrics["Hw"]
    return metrics

##### Compute

In [None]:
## specify longitude ranges (west and east)
LONS_W = (140, 210)
LONS_E = (210, 270)

## get maximum gradient of thermocline data (element 0 of load_h_data,
## the same data unpacked as `h_mg_forced` in the loading cell), then
## window it and average by calendar month
H_mg = src.utils.load_h_data(max_grad=True)[0]
H_mg = src.utils.get_windowed(H_mg).groupby("time.month").mean()

## get metrics for the max-gradient thermocline and for "z20" (presumably
## the 20°C isotherm depth — confirm) averaged over 2S-2N
kwargs = dict(lons_w=LONS_W, lons_e=LONS_E)
H_mg_stats = get_H_metrics(H_mg, **kwargs)
D20_stats = get_H_metrics(
    forced["z20"].sel(latitude=slice(-2, 2)).mean("latitude"), **kwargs
)

First, plot fractional change in mean-state quantities

In [None]:
def format_ax(ax):
    """Apply the shared time-axis formatting.

    Puts x ticks at 1870, 2030 and 2085 and draws a dashed reference line at
    2030 using the module-level ``ref_kwargs`` style. Returns ``None``.
    """
    ax.set_xticks([1870, 2030, 2085])
    ax.axvline(2030, **ref_kwargs)

In [None]:
## average over "month", then express as a fractional change relative to
## the first window year.
## NOTE(review): a fractional change flips sign if the baseline value is negative.
sel_ = lambda x: x.mean("month")
sel = lambda x: (sel_(x) - sel_(x.isel(year=0))) / sel_(x.isel(year=0))

fig, ax = plt.subplots(figsize=(3.5, 3), layout="constrained")
## `dTdz_n3` is defined earlier in the notebook
ax.plot(forced.year, sel(dTdz_n3), label=r"$\Delta_z T$")
## zonal SST contrast: Niño 4 ("T_4") minus Niño 3 ("T_3")
ax.plot(forced.year, sel(forced["T_4"] - forced["T_3"]), label=r"$\Delta_x T$")
## basin-mean thermocline depth (max-gradient definition)
ax.plot(forced.year, sel(H_mg_stats["Hbar"]), label=r"$h$")
## presumably Niño 3-averaged vertical velocity ("w_n3") -- TODO confirm
ax.plot(forced.year, sel(forced["w_n3"]), label=r"$w$")


## format/labelling
ax.set_yticks([-0.3, 0, 0.6])
ax.axhline(0, **ref_kwargs)
format_ax(ax)
ax.legend()
ax.set_title(r"Mean-state quantities over time")
ax.set_ylabel("Fractional change")

## save
# save(fig, "mean-state_timeseries")

plt.show()

Next, plot change in Niño indices

In [None]:
## annual mean: average over the "month" dimension
sel = lambda x: x.mean("month")
# sel = lambda x : x.sel(month=12)

fig, ax = plt.subplots(figsize=(3.5, 3), layout="constrained")
## tropical-mean SST ("trop_sst_15"), Niño 4 ("T_4") and Niño 3 ("T_3")
ax.plot(forced.year, sel(forced["trop_sst_15"]), label=r"Tropical avg.", c="k")
ax.plot(forced.year, sel(forced["T_4"]), label=r"$T_w$", c="r")
ax.plot(forced.year, sel(forced["T_3"]), label=r"$T_e$", c="b")

## formatting
ax.set_yticks([26, 29, 32])
ax.set_ylabel(r"$^{\circ}C$")
format_ax(ax)
ax.legend()

# save(fig, fname="nino-indices_overtime")

plt.show()

#### Spatial pattern

In [None]:
## `ccrs` is needed for the map transform below; the notebook's import cell
## only imports `cartopy.util`, so import it explicitly here.
import cartopy.crs as ccrs

## seconds per day, used to convert precip to mm/day.
## Assumes "pr" is a flux in kg m^-2 s^-1 (= mm s^-1) -- TODO confirm units.
SEC_PER_DAY = 86400

## get relative SST (local SST minus tropical-mean SST)
forced["sst_rel"] = forced["sst"] - forced["trop_sst_15"]

## get difference between the 2085 and 1870 windows
x0 = forced.sel(year=1870)
x1 = forced.sel(year=2085)
dx = x1 - x0

## annual mean: average over the "month" dimension
sel = lambda x: x.mean("month")
# sel = lambda x: x.sel(month=1)

fig = plt.figure(figsize=(10, 6), layout="constrained")
format_func = lambda ax: src.utils.plot_setup_pac(ax, max_lat=30, lon_range=(120, 285))
axs = src.utils.subplots_with_proj(fig, nrows=2, ncols=1, format_func=format_func)

## shared args
lonlat = (dx.longitude, dx.latitude)
kwargs = dict(extend="both", transform=ccrs.PlateCarree())

## plot change in relative SST
cp_sst = axs[0, 0].contourf(
    *lonlat,
    sel(dx["sst_rel"]),
    levels=src.utils.make_cb_range(2, 0.2),
    cmap="cmo.balance",
    **kwargs,
)

## relative SST baseline (1870): black contours, with the zero line in gray
axs[0, 0].contour(
    *lonlat,
    sel(x0["sst_rel"]),
    levels=src.utils.make_cb_range(10, 2),
    colors="k",
    linewidths=0.75,
    **kwargs,
)
axs[0, 0].contour(
    *lonlat,
    sel(x0["sst_rel"]),
    levels=[0],
    colors="gray",
    linewidths=1.25,
    **kwargs,
)

## plot change in precip, in mm/day
cp_pr = axs[1, 0].contourf(
    *lonlat,
    SEC_PER_DAY * sel(dx["pr"]),
    levels=src.utils.make_cb_range(5, 0.5),
    cmap="cmo.balance_r",
    **kwargs,
)

## baseline (1870): 6 mm/day contour
axs[1, 0].contour(
    *lonlat,
    SEC_PER_DAY * sel(x0["pr"]),
    levels=[6],
    colors="k",
    linewidths=0.75,
    **kwargs,
)

## colorbars
fig.colorbar(cp_sst, label=r"$^{\circ}$C", ticks=[-2, 0, 2])
fig.colorbar(cp_pr, label="mm / day", ticks=[-5, 0, 5])

## boxes
for ax in axs.flatten():
    src.utils.plot_nino4_box(ax, c="magenta", lw=0.8)
    src.utils.plot_nino3_box(ax, c="magenta", lw=0.8)

## save
# save(fig, fname="mean-state-pattern_change")

plt.show()

#### Subsurface

In [None]:
## NOTE(review): this cell depends on `sel` (mean over "month"), `x0` (1870)
## and `dx` (2085 minus 1870) from the spatial-pattern cell above. Re-run that
## cell first if `sel` has been redefined since (later cells redefine it).
fig, axs = plt.subplots(1, 2, figsize=(8, 4), layout="constrained")

## temperature: change (filled) plus 1870 baseline contours (black)
cp_T = axs[0].contourf(
    dx["T"].longitude,
    dx["T"].z_t,
    sel(dx["T"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(5, 0.5),
    extend="both",
)

axs[0].contour(
    x0["T"].longitude,
    x0["T"].z_t,
    sel(x0["T"]),
    colors="k",
    levels=np.arange(12, 34, 4),
    linewidths=0.5,
    extend="both",
)

## vertical velocity: change (filled) plus 1870 baseline contours (black)
cp_w = axs[1].contourf(
    dx["w"].longitude,
    dx["w"].z_t,
    sel(dx["w"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(20, 2),
    extend="both",
)

axs[1].contour(
    x0["w"].longitude,
    x0["w"].z_t,
    sel(x0["w"]),
    colors="k",
    levels=np.arange(0, 100, 10),
    linewidths=0.5,
    extend="both",
)

## plot the 1870 20 degree isotherm on both panels
for ax in axs:
    ax.contour(
        x0.longitude,
        x0.z_t,
        sel(x0["T"]),
        levels=[20],
        colors="magenta",
        linestyles="--",
    )

## formatting: shared subsurface-axis setup from src.utils, then longitude
## limits/ticks with the x-axis moved to the top
src.utils.format_subsurf_axs(axs)
for ax in axs:
    ax.set_xlim([140, 280])
    ax.set_xticks([140, 210, 280])
    ax.xaxis.set_label_position("top")
    ax.xaxis.tick_top()

## colorbars
fig.colorbar(cp_T, ticks=[-5, 0, 5], label=r"$^{\circ}$C", orientation="horizontal")
fig.colorbar(cp_w, ticks=[-20, 0, 20], label=r"m / month", orientation="horizontal")

## save
# save(fig, fname="subsurface-change")

plt.show()

#### warm pool (scratch)

In [None]:
## meridional mean over latitudes -5..5 (also reused as `fn` in later cells)
merimean = lambda x: x.sel(latitude=slice(-5, 5)).mean("latitude")
## meridionally averaged "taux" for April only
sel = lambda x: merimean(x["taux"]).sel(month=[4]).mean("month")
# prep = lambda x : sel(x)-sel(x).max()
prep = lambda x: sel(x)

## NOTE(review): only axs[0] is drawn on; axs[1] stays empty (scratch cell)
fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.5), layout="constrained")

## one line per window year
for year in [1870, 2080]:
    axs[0].plot(
        forced.longitude,
        prep(forced.sel(year=year)),
    )

for ax in axs:
    ax.set_xlim([140, 280])
    ax.set_ylim([-0.06, 0.01])

Get equatorial precip

In [None]:
## keep only the 1870 and 2080 windows of "pr" and "pr_comp"
pr_windowed = src.utils.get_windowed(total[["pr", "pr_comp"]]).sel(year=[1870, 2080])

## apply the meridional mean `merimean` (latitudes -5..5, defined above)
## through src.utils.reconstruct_wrapper
pr_eq = src.utils.reconstruct_wrapper(pr_windowed, fn=merimean)

In [None]:
## mean over months 2, 3, 4 (Feb-Apr) of each calendar year
sel = lambda x: src.utils.sel_month(x, months=[2, 3, 4]).groupby("time.year").mean()

In [None]:
## Feb-Apr mean "pr" for each calendar year, with (year, member) stacked into
## one "sample" dimension: window index 0 (presumably 1870) and 1 (2080).
## NOTE(review): months [2, 3, 4] are FMA, although the variable names say "mam".
pr_eq_mam0 = (
    src.utils.sel_month(pr_eq.isel(year=0), months=[2, 3, 4])
    .groupby("time.year")
    .mean()
    .stack(sample=["year", "member"])["pr"]
)
pr_eq_mam1 = (
    src.utils.sel_month(pr_eq.isel(year=1), months=[2, 3, 4])
    .groupby("time.year")
    .mean()
    .stack(sample=["year", "member"])["pr"]
)

In [None]:
## inspect the stacked early-window samples
pr_eq_mam0

In [None]:
## NOTE(review): only axs[0] is used in this scratch plot
fig, axs = plt.subplots(1, 2, figsize=(6.5, 3))

## one thin translucent gray line per (year, member) sample vs. longitude
for s in pr_eq_mam0.sample:
    axs[0].plot(pr_eq.longitude, pr_eq_mam0.sel(sample=s), c="gray", lw=0.5, alpha=0.2)

axs[0].set_xlim([150, 270])

In [None]:
## longitude (restricted to 150..270) of the precip maximum for each sample,
## for window 0 (`lon`) and window 1 (`lon1`)
lon_idx_max = pr_eq_mam0.sel(longitude=slice(150, 270)).argmax("longitude")
lon = pr_eq_mam0.sel(longitude=slice(150, 270)).longitude.isel(longitude=lon_idx_max)

lon_idx_max1 = pr_eq_mam1.sel(longitude=slice(150, 270)).argmax("longitude")
lon1 = pr_eq_mam1.sel(longitude=slice(150, 270)).longitude.isel(longitude=lon_idx_max1)

In [None]:
## compare distributions of the precip-maximum longitude (window 1 translucent)
plt.hist(lon)
plt.hist(lon1, alpha=0.5)

In [None]:
## NOTE(review): `pr_eq` still has the window "year" dimension here, and `sel`
## groups by "time.year", which creates another "year"; the cells above
## isel(year=...) first to avoid this. Check that the result is what you intend.
pr_eq_mam = sel(pr_eq).stack(sample=["year", "member"])

### Heat flux

In [None]:
# ## func to select data
# sel = lambda x : src.utils.sel_month(x, 12)

# fig, axs = plt.subplots(1,2,figsize=(5.5,2.5), layout="constrained")

# for ax, y_ in zip(axs, [1870,2080]):
#     ax.scatter(
#         sel(Th["T_3"].sel(year=y_)),
#         sel(Th["T_3"].sel(year=y_)),

### Variance
**to-do: plot bounds**

Compute Niño 1+2 indices

#### Time series

In [None]:
def get_nino12(x, lat_bounds=(-10, 0), lon_bounds=(270, 280)):
    """get Niño 1+2 average.

    The standard Niño 1+2 box is 0-10°S, 90°W-80°W. With the 0-360
    longitudes and ascending latitudes used elsewhere in this notebook
    (e.g. ``get_nino3_south`` uses ``(-10, 0)``), that is latitude
    ``(-10, 0)`` and longitude ``(270, 280)``.

    Args:
        x: field with ``latitude`` and ``longitude`` dimensions.
        lat_bounds: (south, north) latitude bounds; defaults to Niño 1+2.
        lon_bounds: (west, east) longitude bounds; defaults to Niño 1+2.

    Returns:
        ``x`` averaged over the box.
    """
    ## previously this selected latitude 0..10, i.e. the northern hemisphere,
    ## which is not the Niño 1+2 region
    region = dict(latitude=slice(*lat_bounds), longitude=slice(*lon_bounds))
    return x.sel(region).mean(["latitude", "longitude"])


## get nino12: apply `get_nino12` to the "sst"/"sst_comp" anomalies through
## src.utils.reconstruct_wrapper and keep the "sst" output
nino12 = src.utils.reconstruct_wrapper(
    anom_[["sst", "sst_comp"]],
    fn=get_nino12,
)["sst"]

## window it, and add to data
Th["T_12"] = src.utils.get_windowed(nino12)

In [None]:
## get variance by month
Th_sigma = Th.groupby("time.season").std(["time", "member"])

## fractional change
Th_delta_sigma = (Th_sigma - Th_sigma.isel(year=0)) / Th_sigma.isel(year=0)

In [None]:
## specify which variables to plot
VARNAMES = ["T_4", "T_34", "T_3", "T_12"]

## specify labels
LABELS = ["Niño 4", "Niño 3.4", "Niño 3", "Niño 1+2"]

## specify months to plot
SEASONS = ["DJF", "MAM", "SON"]

## get colors
COLORS = [sns.color_palette()[i] for i in [0, 2, 1]]

## make colormap
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=2)

fig, axs = plt.subplots(1, 4, figsize=(8.5, 2), layout="constrained")

for ax, v, l in zip(axs, VARNAMES, LABELS):

    ## plot data
    for season, c in zip(SEASONS, COLORS):
        sel_ = lambda x: x.sel(season=season)
        ax.plot(
            Th_delta_sigma.year,
            Th_delta_sigma[v].sel(season=season),
            c=c,
            label=season,
            lw=2,
        )

    ## add reference lines / label
    ax.axvline(2030, **ref_kwargs)
    ax.axhline(0, **ref_kwargs)
    ax.set_yticks([-0.3, 0, 0.3])
    ax.set_xticks([2030])
    ax.set_title(l)


## label
axs[0].set_ylabel("Fractional change")
for ax in axs[1:]:
    ax.set_yticks([])

# ax.set_ylim([0,None])
axs[-1].legend(loc=(1, 0.25))
src.utils.set_lims(axs)
plt.show()

#### Hovmöllers (by longitude)

In [None]:
def get_sigma_bygroup(data, grouper, varname="sst"):
    """get standard deviation (sigma) after grouping data.

    Applies ``src.utils.reconstruct_var_wrapper`` to each group and returns
    the square root of the result. That helper presumably returns a variance
    (TODO confirm), so the output is a standard deviation, not a variance.

    Args:
        data: data to group.
        grouper: groupby key, e.g. ``"time.season"`` or ``"time.month"``.
        varname: currently unused; kept for backward compatibility.

    Returns:
        ``np.sqrt`` of the per-group result of ``reconstruct_var_wrapper``.
    """
    var_bygroup = data.groupby(grouper).map(src.utils.reconstruct_var_wrapper)

    return np.sqrt(var_bygroup)


def get_sigma_byseason(data, varname="sst"):
    """Get standard deviation (sigma) by season.

    Thin wrapper around ``get_sigma_bygroup`` with ``grouper="time.season"``;
    ``varname`` is passed through unchanged.
    """

    return get_sigma_bygroup(data, grouper="time.season", varname=varname)


def get_sigma_bymonth(data, varname="sst"):
    """Get standard deviation (sigma) by calendar month.

    Thin wrapper around ``get_sigma_bygroup`` with ``grouper="time.month"``;
    ``varname`` is passed through unchanged.
    """

    return get_sigma_bygroup(data, grouper="time.month", varname=varname)


def skew_helper(data):
    """Compute skewness over pooled time and ensemble-member samples.

    ``time`` and ``member`` are stacked into a single ``sample`` dimension,
    and ``scipy.stats.skew`` is applied along it; every other dimension is
    kept in the result.
    """
    pooled = data.stack(sample=["time", "member"])
    return xr.apply_ufunc(
        scipy.stats.skew,
        pooled,
        input_core_dims=[["sample"]],
        kwargs={"axis": -1},
    )


def get_skew_byseason(data):
    """compute skewness of meridionally averaged data by season.

    The data are first reduced with ``src.utils.merimean`` (presumably an
    equatorial-band mean -- TODO confirm) via
    ``src.utils.reconstruct_wrapper``. Skewness is then computed for each
    season with ``skew_helper``, pooling time and ensemble members.
    """

    ## get data on equator
    data_eq = src.utils.reconstruct_wrapper(
        data=data,
        fn=src.utils.merimean,
    )
    ## compute skewness for each season
    return data_eq.groupby("time.season").map(skew_helper)


def get_skew_bymonth(data):
    """Return skewness of meridionally averaged data for each calendar month.

    The input is reduced with ``src.utils.merimean`` through
    ``src.utils.reconstruct_wrapper``. ``skew_helper`` is then applied to
    every ``time.month`` group.
    """
    equatorial = src.utils.reconstruct_wrapper(data=data, fn=src.utils.merimean)
    by_month = equatorial.groupby("time.month")
    return by_month.map(skew_helper)


def contour_cyclic(ax, data, **kwargs):
    """Draw line contours of a month-by-longitude field, closing the annual cycle.

    ``data`` is put month-first, and ``cartopy.util.add_cyclic_point``
    appends a copy of the first month after the last one, so contours
    connect across the year boundary. Extra keyword arguments go to
    ``ax.contour``.

    Returns:
        The ContourSet returned by ``ax.contour``.
    """
    month_first = data.transpose("month", ...)
    wrapped_values, wrapped_months = cartopy.util.add_cyclic_point(
        month_first, data.month, axis=0
    )
    return ax.contour(data.longitude, wrapped_months, wrapped_values, **kwargs)

##### Get windowed anomalies

In [None]:
## windowed "sst"/"sst_comp" anomalies used by the variance/skewness cells below
anom_windowed = src.utils.get_windowed(anom_[["sst", "sst_comp"]])

##### Compute variance

By season

In [None]:
## years to compute for
YEARS = anom_windowed.year

## empty list to hold result
sigma_spatial_byseason = []

for year in tqdm.tqdm(YEARS):
    sigma_spatial_byseason.append(get_sigma_byseason(anom_windowed.sel(year=year)))

## convert back to XR
sigma_spatial_byseason = xr.concat(
    sigma_spatial_byseason, dim=pd.Index(YEARS, name="year")
)

## Equatorial avg
sigma_spatial_eq = sigma_spatial_byseason.sel(latitude=slice(-5, 5)).mean("latitude")
delta_sigma_spatial_eq = sigma_spatial_eq - sigma_spatial_eq.isel(year=0)

By month

In [None]:
# years to compute for
YEARS = [1870, 2030, 2080]

## empty list to hold result
sigma_spatial_bymonth = []

for year in tqdm.tqdm(YEARS):
    sigma_spatial_bymonth.append(get_sigma_bymonth(anom_windowed.sel(year=year)))

## convert back to XR
sigma_spatial_bymonth = xr.concat(
    sigma_spatial_bymonth, dim=pd.Index(YEARS, name="year")
)

## equatorial avg
sigma_spatial_bymonth_eq = sigma_spatial_bymonth.sel(latitude=slice(-5, 5)).mean(
    "latitude"
)

#### Compute skewness

By season

In [None]:
## years to compute for
YEARS = anom_windowed.year

## empty list to hold result
skew_spatial_byseason = []

for year in tqdm.tqdm(YEARS):
    skew_spatial_byseason.append(get_skew_byseason(anom_windowed.sel(year=year)))

## convert back to XR
skew_spatial_byseason = xr.concat(
    skew_spatial_byseason, dim=pd.Index(YEARS, name="year")
)

## change in skewness
delta_skew = skew_spatial_byseason - skew_spatial_byseason.isel(year=0)

By month

In [None]:
# years to compute for
YEARS = [1870, 2030, 2080]

## empty list to hold result
skew_spatial_bymonth = []

for year in tqdm.tqdm(YEARS):
    skew_spatial_bymonth.append(get_skew_bymonth(anom_windowed.sel(year=year)))

## convert back to XR
skew_spatial_bymonth = xr.concat(skew_spatial_bymonth, dim=pd.Index(YEARS, name="year"))

##### merge data

In [None]:
## combine equatorial std. dev. ("sigma") and skewness ("skew") into one dataset
## per grouping. NOTE(review): sigma uses an explicit latitude -5..5 mean, while
## skew uses src.utils.merimean; confirm the two bands match.
stats_bymonth = xr.merge(
    [
        sigma_spatial_bymonth_eq.rename({"sst": "sigma"}),
        skew_spatial_bymonth.rename({"sst": "skew"}),
    ]
)
stats_byseason = xr.merge(
    [
        sigma_spatial_eq.rename({"sst": "sigma"}),
        skew_spatial_byseason.rename({"sst": "skew"}),
    ]
)

#### get differences

In [None]:
## absolute change relative to the first window year (seasonal and monthly),
## plus the fractional change for the seasonal stats
delta_stats_byseason = stats_byseason - stats_byseason.isel(year=0)
deltanorm_stats_byseason = delta_stats_byseason / stats_byseason.isel(year=0)
delta_stats_bymonth = stats_bymonth - stats_bymonth.isel(year=0)

#### Plot

Variance over time

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(8, 3), layout="constrained")

for ax, season in zip(axs, ["DJF", "MAM", "JJA", "SON"]):
    ## Hovmöller of the *fractional* change in std. dev. (matches the
    ## "Frac. change" colorbar; the absolute change was plotted before).
    ## extend="both" fills values beyond +/-0.4 instead of leaving them blank.
    p = ax.contourf(
        deltanorm_stats_byseason.longitude,
        deltanorm_stats_byseason.year,
        deltanorm_stats_byseason["sigma"].sel(season=season),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(0.4, 0.04),
        extend="both",
    )
    ax.set_xlim([140, 280])
    ax.set_xticks([150, 210, 270])
    ax.set_title(season)
    ax.axhline(2030, **ref_kwargs)
    ax.axvline(210, **ref_kwargs)
    ax.set_xlabel("Lon.")

    ## overlay the std. dev. profiles for 1870, 2030 and 2080 on a twin axis
    ax_cyc = ax.twinx()
    for y_, c_, ls in zip([1870, 2030, 2080], ["gray", "gray", "k"], ["-", "--", "-"]):
        ax_cyc.plot(
            stats_byseason.longitude,
            stats_byseason["sigma"].sel(season=season, year=y_),
            c=c_,
            ls=ls,
            label=y_,
        )
    ax_cyc.set_ylim([0.3, 4])

    ## std. dev. axis labels (first panel only)
    if season == "DJF":
        ax_cyc.yaxis.set_label_position("left")
        ax_cyc.yaxis.tick_left()
        ax_cyc.set_yticks([0.3, 1.5])
        ax_cyc.text(
            x=-0.65, y=0.14, s=r"$^{\circ}$C", color="gray", transform=ax_cyc.transAxes
        )
        ax.text(x=-0.7, y=0.85, s=r"Year", color="k", transform=ax_cyc.transAxes)
        ax_cyc.tick_params(axis="y", colors="gray")
    else:
        ax_cyc.set_yticks([])


## label/format
fig.colorbar(p, ax=axs[-1], ticks=[-0.4, 0, 0.4], label="Frac. change")
axs[0].set_yticks([2030, 2080])
for ax in axs[1:]:
    ax.set_yticks([])

# save(fig, "sigma-over-time_hov")
plt.show()

Variance by season (v1)

Variance over time (v2)

In [None]:
## get plot data: monthly equatorial std. dev./skewness for 1870 and 2080,
## and their difference
baseline = stats_bymonth.sel(year=1870)
future = stats_bymonth.sel(year=2080)
change = delta_stats_bymonth.sel(year=2080)

## for skew
## NOTE(review): these three are not used below (the skew contours come from
## baseline["skew"] / future["skew"]); kept in case later cells rely on them.
baseline_sk = skew_spatial_bymonth["sst"].sel(year=1870)
future_sk = skew_spatial_bymonth["sst"].sel(year=2080)
change_sk = future_sk - baseline_sk

## specify levels
LEV = np.arange(0.4, 2.2, 0.2)
LEV_DIFF = src.utils.make_cb_range(0.5, 0.05)

## shared args for plotting
plot_kwargs = dict(cmap="cmo.amp", extend="min", xticks=[210, 270])

## Set up plot
fig, axs = plt.subplots(1, 3, figsize=(9, 3), layout="constrained")

## make hövmöllers: std. dev. seasonal cycle in 1870, 2080, and the change
cp0 = src.utils.plot_cycle_hov(axs[0], baseline["sigma"], levels=LEV, **plot_kwargs)
cp1 = src.utils.plot_cycle_hov(axs[1], future["sigma"], levels=LEV, **plot_kwargs)
cp2 = src.utils.plot_cycle_hov(
    axs[2],
    change["sigma"],
    levels=LEV_DIFF,
    cmap="cmo.balance",
    extend="both",
    xticks=[210, 270],
)

## superimpose the skewness = -0.5 contour (1870 dashed, 2080 solid) on its
## own panel and on the difference panel
for ax, x, ls in zip(axs[:-1], [baseline, future], ["--", "-"]):
    for ax_ in [ax, axs[-1]]:
        contour_cyclic(
            ax=ax_,
            data=x["skew"],
            levels=[-0.5],
            colors="k",
            zorder=10,
            linewidths=0.8,
            linestyles=ls,
            alpha=0.8,
        )

## label
axs[0].set_title("1870")
axs[1].set_title("2080")
axs[2].set_title("Difference")
axs[1].set_xlabel("Longitude")

for ax in axs:
    ax.set_xlim([140, 280])
    ax.set_xticks([140, 210, 270])

for ax in axs[1:]:
    ax.set_yticks([])
    ax.set_ylabel(None)

## add colorbars; axs[0] shares LEV/cmap with axs[1], so it gets none.
## Use plain mathtext for the degree label ("\text" is not supported by
## older Matplotlib mathtext, and the rest of the notebook uses this form).
kwargs = dict(ticks=[0.4, 1.2, 2], label=r"$^{\circ}$C")
cb1 = fig.colorbar(cp1, ax=axs[1], **kwargs)
cb2 = fig.colorbar(cp2, ax=axs[2], **dict(kwargs, ticks=[-0.5, 0, 0.5]))

## save
# save(fig, fname="sigma-cycle-change")

plt.show()

Plot skewness over time

In [None]:
## contour levels for the skewness change; the colorbar ticks must lie within
## +/-SKEW_MAX, otherwise Matplotlib does not draw them
SKEW_MAX = 0.75

fig, axs = plt.subplots(1, 4, figsize=(8, 3), layout="constrained")

for ax, season in zip(axs, ["DJF", "MAM", "JJA", "SON"]):
    ## Hovmöller of the change in skewness relative to the first window year
    p = ax.contourf(
        delta_skew.longitude,
        delta_skew.year,
        delta_skew["sst"].sel(season=season),
        cmap="cmo.balance",
        levels=src.utils.make_cb_range(SKEW_MAX, SKEW_MAX / 10),
        extend="both",
    )
    ax.set_xlim([140, 280])
    ax.set_xticks([150, 270])
    ax.set_title(season)
    ax.axhline(2030, **ref_kwargs)
    ax.axvline(210, **ref_kwargs)
    ax.set_xlabel("Lon.")

    ## overlay the skewness profiles for 1870, 2030 and 2080 on a twin axis
    ax_cyc = ax.twinx()
    for y_, c_, ls in zip([1870, 2030, 2080], ["gray", "gray", "k"], ["-", "--", "-"]):
        ax_cyc.plot(
            skew_spatial_byseason.longitude,
            skew_spatial_byseason["sst"].sel(season=season, year=y_),
            c=c_,
            ls=ls,
            label=y_,
        )
    ax_cyc.set_ylim([-1.6, 4])
    ax_cyc.axhline(0, c="k", lw=0.8, alpha=0.5)

    ## skewness axis labels (first panel only)
    if season == "DJF":
        ax_cyc.yaxis.set_label_position("left")
        ax_cyc.yaxis.tick_left()
        ax_cyc.set_yticks([-1.5, 0.0])
        ax_cyc.text(x=-0.7, y=0.14, s=r"Skew", color="gray", transform=ax_cyc.transAxes)
        ax.text(x=-0.7, y=0.85, s=r"Year", color="k", transform=ax_cyc.transAxes)
        ax_cyc.tick_params(axis="y", colors="gray")
    else:
        ax_cyc.set_yticks([])


## label/format; ticks at +/-1 would fall outside the +/-0.75 levels and be
## dropped, so use the level bounds instead
fig.colorbar(
    p, ax=axs[-1], ticks=[-SKEW_MAX, 0, SKEW_MAX], label=r"Skewness change"
)
axs[0].set_yticks([2030, 2080])
for ax in axs[1:]:
    ax.set_yticks([])


# save(fig, fname="skew-over-time_hov")
plt.show()

## Scratch

Compare different metrics for $H$

In [None]:
## compare thermocline-depth definitions vs. longitude for the first (blue) and
## last (red) window years: max-gradient depth (solid) vs. "z20" averaged over
## latitudes -2..2 (dashed)
fig, ax = plt.subplots(figsize=(4, 3))

for y, c in zip([0, -1], ["b", "r"]):
    ax.plot(H_mg.longitude, H_mg.mean("month").isel(year=y), c=c, ls="-")
    ax.plot(
        forced.longitude,
        forced["z20"]
        .sel(latitude=slice(-2, 2))
        .mean(["latitude", "month"])
        .isel(year=y),
        c=c,
        ls="--",
    )

## flip the y-axis so that larger depth values plot lower
ax.set_ylim(ax.get_ylim()[::-1])
ax.set_xlim([140, 280])
plt.show()

Relative SST

In [None]:
## Niño 3 SST relative to the tropical mean, plotted against `ref`.
## NOTE(review): `ref` and `ref_year` are not defined in this part of the
## notebook, and `sel` at this point is whichever lambda was last executed
## (the cell above defines a "time.year"-grouping one). Re-check before use.
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(sel(ref), sel(forced["T_3"] - forced["trop_sst_15"]))
ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"Niño 3 relative SST")
# ax.set_ylim([0,None])
plt.show()