# Mean-state changes
Look at how the mean state changes over time.

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings
import calendar

# import gsw

## plotting defaults: white axes background, no gridlines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution
mpl.rcParams.update({"figure.dpi": 100})

## data / output directories are read from the environment
## (a missing variable raises KeyError)
DATA_FP, SAVE_FP = (
    pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP")
)

## Funcs

In [None]:
def get_ensemble_fn(data, fn):
    """Apply `fn` across the pooled (time, member) samples of `data`.

    The "time" and "member" dimensions are stacked into a single "sample"
    dimension, and `fn` is handed to `xr.apply_ufunc` with core dimensions
    ("sample", "year") on input and ("year",) on output, so `fn` has to
    collapse the "sample" axis and keep the "year" axis.

    NOTE(review): apply_ufunc moves core dimensions to the last axes, so a
    reducer that works along axis 0 by default (e.g. scipy.stats.skew) only
    reduces "sample" when `data` has no further dimensions - confirm for
    the inputs used.
    """
    pooled = data.stack(sample=["time", "member"])
    return xr.apply_ufunc(
        fn,
        pooled,
        input_core_dims=[["sample", "year"]],
        output_core_dims=[["year"]],
    )


def get_ensemble_fn_bymonth(data, fn):
    """Run `get_ensemble_fn` separately on each "time.month" group of `data`."""
    by_month = data.groupby("time.month")
    return by_month.map(get_ensemble_fn, fn=fn)

## Load data

In [None]:
## load spatial data; keep only the anomaly variables used below
forced, anom_ = src.utils.load_consolidated()
anom_ = anom_[["T", "T_comp", "w", "w_comp"]]

## get windowed data, then average by calendar month
forced = src.utils.get_windowed(forced).groupby("time.month").mean()

## reconstruct data
forced = src.utils.reconstruct_wrapper(forced)

## slice longitude to 120-280
forced = forced.sel(longitude=slice(120, 280))

## load tropical SST avg
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))

## Load T,h (total) and attach the tropical SST variables
Th_total = xr.open_dataset(DATA_FP / "cesm" / "Th.nc")
Th_total = xr.merge([Th_total, trop_sst])

## compute dTdx as the plain difference T_4 - T_3 (no division by distance)
Th_total["dTdx"] = Th_total["T_4"] - Th_total["T_3"]

## load ELI
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))
Th_total = xr.merge([Th_total, eli])

## get T and w averaged over longitude 210-270 at the level nearest z_t=70;
## `fn` is handed to reconstruct_wrapper, and the result is renamed to
## T_n3 / w_n3 before merging
fn = (
    lambda x: x.sel(z_t=70, method="nearest")
    .sel(longitude=slice(210, 270))
    .mean("longitude")
)
Th_total = xr.merge(
    [
        Th_total,
        src.utils.reconstruct_wrapper(anom_, fn=fn).rename({"T": "T_n3", "w": "w_n3"}),
    ]
)


## windowed version
Th_total = src.utils.get_windowed(Th_total)

## relative SST: each index minus trop_sst_15
for v in ["T_3", "T_34", "T_4"]:
    Th_total[f"{v}_rel"] = Th_total[v] - Th_total["trop_sst_15"]

## ensemble mean by calendar month (averaged over time and member)
Th_total_forced = Th_total.groupby("time.month").mean(["time", "member"])

## merge with spatial data
forced = xr.merge([forced, Th_total_forced])

## compute w_n3 from the spatial field: mean of w over longitude 210-270 at
## the level nearest z_t=70
## NOTE(review): this presumably replaces the w_n3 that came in through the
## merge above - confirm that is intended
forced["w_n3"] = (
    forced["w"]
    .sel(longitude=slice(210, 270))
    .mean("longitude")
    .sel(z_t=70, method="nearest")
)

In [None]:
## quantiles of every variable by season, pooled over time and member
## NOTE: `quant` and the fields derived from it are recomputed with more
## quantile levels in the "Compute quantiles" cell further down.
quant = (
    Th_total.groupby("time.season")
    .quantile(q=[0.05, 0.25, 0.5, 0.75, 0.95], dim=["time", "member"])
    .rename({"quantile": "q"})
)

## change relative to the first entry along "year"
delta_q = quant - quant.isel(year=0)

## 5%-95% spread, and its change relative to the first entry along "year"
spread = quant.sel(q=0.95) - quant.sel(q=0.05)
delta_spread = spread - spread.isel(year=0)

#### Compute variance

In [None]:
## load the indices and window them
Th = src.utils.get_windowed(src.utils.load_cesm_indices())

## ELI variables: take them from Th_total and remove the ensemble mean
for v in ["05", "15", "20", "25", "30"]:
    Th[f"eli_{v}"] = Th_total[f"eli_{v}"] - Th_total[f"eli_{v}"].mean("member")

## standard deviation by calendar month, pooled over time and member
Th_sigma = Th.groupby("time.month").std(["time", "member"])

## fractional change relative to the first entry along "year"
Th_delta_sigma = (Th_sigma - Th_sigma.isel(year=0)) / Th_sigma.isel(year=0)

#### Compute skewness

In [None]:
## skewness by calendar month, pooled over time and member
Th_skew = Th.groupby("time.month").map(get_ensemble_fn, fn=scipy.stats.skew)

## change relative to the first entry along "year"
Th_delta_skew = Th_skew - Th_skew.isel(year=0)

#### Compute quantiles

In [None]:
## quantiles of every variable by season, pooled over time and member;
## the quantile dimension is renamed to "q"
quant = (
    Th_total.groupby("time.season")
    .quantile(
        q=[0.02, 0.05, 0.15, 0.25, 0.5, 0.75, 0.85, 0.95, 0.98],
        dim=["time", "member"],
    )
    .rename({"quantile": "q"})
)

## change relative to the first entry along "year"
delta_q = quant - quant.isel(year=0)

## 5%-95% spread, and its change relative to the first entry along "year"
spread = quant.sel(q=0.95) - quant.sel(q=0.05)
delta_spread = spread - spread.isel(year=0)

## Analysis

### Spatial plots

### Precip / $\tau_x$ on the equator

In [None]:
## month to plot and latitude bound of the meridional average
MONTH = 5
LAT_BOUND = 5


def merimean(x, lat_bound=5):
    """average over latitudes between -lat_bound and +lat_bound"""
    return x.sel(latitude=slice(-lat_bound, lat_bound)).mean("latitude")


def sel(x):
    """meridional mean for month MONTH, with "year" as the leading dim"""
    return merimean(x, lat_bound=LAT_BOUND).sel(month=MONTH).transpose("year", ...)


## longitude coordinate, for convenience
LON = forced.longitude

## longitude-weighted center of mass (not used by the plots in this cell)
com = sel(forced * LON).integrate("longitude") / sel(forced).integrate("longitude")

## longitude at which taux reaches its minimum
taux_argmin = forced.longitude.isel(longitude=sel(forced["taux"]).argmin("longitude"))

fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## left panel: precip (scaled by 8.6e4), longitude vs. year
axs[0].contourf(
    forced.longitude,
    forced.year,
    8.6e4 * sel(forced["pr"]),
    levels=np.arange(0, 13.5, 1.5),
    cmap="cmo.rain",
)

## right panel: zonal wind stress, longitude vs. year
axs[1].contourf(
    forced.longitude,
    forced.year,
    sel(forced["taux"]),
    levels=src.utils.make_cb_range(0.075, 0.0075),
    cmap="cmo.balance",
)

## overlay the longitude of the taux minimum
axs[1].plot(taux_argmin, forced.year, color="w", linestyle="--", alpha=1, linewidth=1)

## formatting: shared x-limits; y-ticks only on the first panel
for ax in axs:
    ax.set_xlim([140, 280])
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()

### Stratification by season over time

In [None]:
def get_dTdz_sub(Tsub, mld, delta=25):
    """Return -dT / mld, with dT taken from `src.utils.get_dT_sub`.

    `Tsub`, `mld` (passed as `Hm`) and `delta` are forwarded unchanged to
    `src.utils.get_dT_sub`; its result is negated and divided by `mld`.
    Presumably this is the vertical temperature gradient at the base of the
    mixed layer - confirm against `get_dT_sub`.
    """
    temp_diff = src.utils.get_dT_sub(Tsub=Tsub, Hm=mld, delta=delta)
    return -temp_diff / mld

In [None]:
## stratification from the forced temperature field (mld=50, delta=20)
dTdz = get_dTdz_sub(Tsub=forced.T, mld=50, delta=20)

## average over longitude 210-270 (Niño 3)
dTdz_n3 = dTdz.sel(longitude=slice(210, 270)).mean("longitude")


def get_frac_change(x):
    """fractional change relative to the first entry along "year" """
    return (x - x.isel(year=0)) / x.isel(year=0)


## make plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## fractional change in stratification, month vs. year
ax.contourf(
    dTdz_n3.month,
    dTdz_n3.year,
    get_frac_change(dTdz_n3).transpose("year", ...),
    levels=src.utils.make_cb_range(1.5, 0.15),
    cmap="cmo.balance",
    extend="both",
)

## mark month 5
ax.axvline(5, linestyle="--", linewidth=0.8, color="k")

plt.show()

$w$ by season

In [None]:
## stratification (same computation as in the previous cell; not used by
## the plot below)
dTdz = get_dTdz_sub(Tsub=forced.T, mld=50, delta=20)
dTdz_n3 = dTdz.sel(longitude=slice(210, 270)).mean("longitude")

## make plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## fractional change in w_n3, month vs. year
ax.contourf(
    forced.month,
    forced.year,
    get_frac_change(forced["w_n3"]).transpose("year", ...),
    levels=src.utils.make_cb_range(0.5, 0.05),
    cmap="cmo.balance",
    extend="both",
)

## mark month 5
ax.axvline(5, linestyle="--", linewidth=0.8, color="k")

plt.show()

$w\frac{dT}{dz}$ by season

### Timeseries

In [None]:
## reference coordinate (x-axis of the "plotted against tropical SST"
## figures) and the year marked with a reference line
## alternative: ref = forced["T_3"] - forced["trop_sst_15"]
ref = forced["trop_sst_15"]
ref_year = 1980

## line style shared by all reference lines
ref_kwargs = {"ls": "--", "c": "gray", "lw": 0.8}

#### Tropical SST over time

In [None]:
def sel(x):
    """average over calendar months (alternative: x.sel(month=12))"""
    return x.mean("month")


fig, ax = plt.subplots(figsize=(4, 3))

## reference coordinate and the Niño 4 / Niño 3 indices vs. year
ax.plot(forced.year, sel(ref), color="k", label="tropical SST")
ax.plot(forced.year, sel(forced["T_4"]), color="r", label="Niño 4")
ax.plot(forced.year, sel(forced["T_3"]), color="b", label="Niño 3")

## mark the reference year
ax.axvline(ref_year, **ref_kwargs)
ax.legend()
plt.show()

#### Mean state quantities, plotted against tropical SST

In [None]:
def sel(x):
    """select the calendar month to plot (alternative: x.mean("month"))"""
    return x.sel(month=8)


## value of the reference coordinate at `ref_year`, for the vertical marker
ref_x = sel(ref.sel(year=ref_year)).values.item()

## (title, y-data) for each figure; every series is plotted against `ref`.
## The last entry previously reused the "Niño 3 relative SST" title even
## though it shows trop_sst_15 minus trop_sst_05.
panels = [
    (r"$\frac{dT}{dx}$", forced["T_4"] - forced["T_3"]),
    (r"$w$", forced["w_n3"]),
    (r"$\frac{dT}{dz}$", dTdz_n3),
    (r"Niño 3 relative SST", forced["T_3"] - forced["trop_sst_15"]),
    ("trop_sst_15 - trop_sst_05", forced["trop_sst_15"] - forced["trop_sst_05"]),
]

## one figure per quantity
for title, ydata in panels:
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(sel(ref), sel(ydata))
    ax.axvline(ref_x, **ref_kwargs)
    ax.set_title(title)
    plt.show()

#### Variance over time

In [None]:
## specify which variable to plot
VARNAME = "eli_30"

## make colormap; line index i (0..3) is mapped through a norm on [-1, 3]
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### fractional change in standard deviation vs. the reference coordinate
## NOTE(review): `sel` is whatever the most recently run cell defined
## (month=8 in the cell above), so the x-data is not matched to the month
## of each line - confirm this is intended.
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):
    sel_ = lambda x: x.sel(month=m_)
    ax.plot(
        sel(ref),
        sel_(Th_delta_sigma[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )


ax.axvline(sel(ref.sel(year=ref_year)).values.item(), **ref_kwargs)
ax.set_title(r"Niño variance")
## the lines carry labels, so draw the legend (as the second figure does)
ax.legend()
plt.show()

#### same quantity vs. year
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):
    sel_ = lambda x: x.sel(month=m_)
    ax.plot(
        Th_delta_sigma.year,
        sel_(Th_delta_sigma[VARNAME]),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )


ax.set_title(r"Niño variance")
ax.legend()
plt.show()

In [None]:
## variable to plot
VARNAME = "eli_30"

## bin edges for the empirical PDFs
edges = np.arange(120, 285, 2)

## data from `sel_month(..., 12)` (presumably month 12), pooled over member
## and time
x = src.utils.sel_month(Th_total[VARNAME], 12).stack(s=["member", "time"])

## empirical PDFs at year=1870 and year=2085
pdf0, _ = src.utils.get_empirical_pdf(x.sel(year=1870), edges=edges)
pdf1, _ = src.utils.get_empirical_pdf(x.sel(year=2085), edges=edges)

## draw both PDFs as step histograms on shared edges
fig, ax = plt.subplots(figsize=(4, 3))
for pdf_ in (pdf0, pdf1):
    ax.stairs(pdf_, edges)
ax.set_xlim([180, 215])
plt.show()

#### Skewness over time

In [None]:
## variable to plot
VARNAME = "eli_30"

## colormap; line index i (0..3) is mapped through a norm on [-1, 3]
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### change in skewness vs. year, one line per month
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate([1, 4, 7, 10]):

    def sel_(x):
        """select month m_ of the current iteration"""
        return x.sel(month=m_)

    ax.plot(
        Th_delta_skew.year,
        sel_(Th_delta_skew[VARNAME]),
        color=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )

## zero line, title and legend
ax.axhline(0, **ref_kwargs)
ax.set_title(r"Niño skewness")
ax.legend()

plt.show()

#### Quantiles over time

In [None]:
## variable and season to plot
VARNAME = "T_3_rel"
SEASON = "DJF"


def sel_(x):
    """pick VARNAME for the chosen season"""
    return x[VARNAME].sel(season=SEASON)


def prep(x):
    """data preparation before plotting (currently just the selection)"""
    return sel_(x)


## upper quantile to plot; the lower one is 1 - Q
Q = 0.85

#### quantiles vs. year
fig, ax = plt.subplots(figsize=(4, 3))

## warm (upper quantile)
ax.plot(quant.year, prep(quant).sel(q=Q), color="r")

## median
ax.plot(quant.year, prep(quant).sel(q=0.5), color="k")

## cold (lower quantile; nearest match because 1 - Q is not exact in floats)
ax.plot(quant.year, prep(quant).sel(q=1 - Q, method="nearest"), color="b")

ax.set_title(r"Niño quantiles")

plt.show()

#### Quantile spread over time

In [None]:
## variable to plot (SEASON is set but not used by the loop below)
VARNAME = "T_34"
SEASON = "DJF"

## colormap; line index i is mapped through a norm on [-1, 3]
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### change in the 5%-95% spread vs. year, one line per season
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(quant.season):

    def sel_(x):
        """select season m_ of the current iteration"""
        return x.sel(season=m_)

    ax.plot(
        quant.year,
        sel_(delta_spread[VARNAME]),
        color=CMAP(CMAP_NORM(i)),
        label=f"{m_.values.item()}",
    )

ax.set_title(r"Niño quantiles")
ax.legend()

plt.show()

##### lower or upper width by season

In [None]:
## variable to plot
VARNAME = "T_34"

## distance from the median to the 5% / 95% quantiles
## (only the lower width is plotted below)
width_lower = quant.sel(q=0.5) - quant.sel(q=0.05)
width_upper = quant.sel(q=0.95) - quant.sel(q=0.5)

## colormap; line index i is mapped through a norm on [-1, 3]
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### lower width vs. year, one line per season
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(quant.season):
    ax.plot(
        quant.year,
        width_lower[VARNAME].sel(season=m_),
        color=CMAP(CMAP_NORM(i)),
        label=f"{m_.values.item()}",
    )

ax.set_title(r"Niño quantiles")
ax.legend()

plt.show()

##### warm vs. cold

In [None]:
## specify which variable to plot
VARNAME = "T_3"
# VARNAME = "eli_30"
SEASON = "DJF"

## specify width bound: widths are measured from the median to the
## `wb` and `1 - wb` quantiles
wb = 0.05


## specify x-variable
X_VAR = quant.year
# X_VAR = sel(ref)

## get func to select data
sel_ = lambda x: x[VARNAME].sel(season=SEASON)


def _sel_quantile(q):
    """Select quantile level `q` from `quant`, tolerant of float round-off.

    `1 - wb` is not always exactly representable (e.g. 1 - 0.85), so an
    exact `.sel(q=...)` can raise KeyError. Nearest-neighbour lookup with a
    tight tolerance still raises if the level was never computed.
    """
    return quant.sel(q=q, method="nearest", tolerance=1e-6)


## compute upper and lower width, and their average
width_lower = _sel_quantile(0.5) - _sel_quantile(wb)
width_upper = _sel_quantile(1 - wb) - _sel_quantile(0.5)
width_mean = 0.5 * (width_lower + width_upper)

## make colormap
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

#### plot distribution widths over time
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(X_VAR, sel_(width_lower), c="blue", label="lower")

ax.plot(X_VAR, sel_(width_upper), c="r", label="upper")

ax.plot(X_VAR, sel_(width_mean), c="k", label="mean")


ax.set_title(r"Distribution width")

## the lines carry labels, so draw the legend
ax.legend()


plt.show()