# Mean-state changes
Look at how the mean state changes over time.

In [None]:
import datetime
import matplotlib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import src.XRO
import copy
import scipy.stats
import warnings
import calendar

# import gsw

## plotting style: white axes background, no grid lines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## higher figure resolution than the matplotlib default
mpl.rcParams["figure.dpi"] = 100

## input-data and output directories, both read from environment variables
## (a missing variable raises KeyError, DATA_FP being checked first)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Load data

In [None]:
## spatial fields: windowed, grouped into a per-month mean, reconstructed,
## then restricted to longitudes 120-280
forced, _ = src.utils.load_consolidated()
forced = src.utils.get_windowed(forced).groupby("time.month").mean()
forced = src.utils.reconstruct_wrapper(forced).sel(longitude=slice(120, 280))

## index data: tropical-mean SST merged with the T/h indices
trop_sst = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/trop_sst.nc"))
Th_total = xr.merge([xr.open_dataset(DATA_FP / "cesm" / "Th.nc"), trop_sst])

## zonal SST contrast: Niño 4 minus Niño 3
Th_total["dTdx"] = Th_total["T_4"] - Th_total["T_3"]

## add the ELI index
eli = xr.open_dataset(pathlib.Path(DATA_FP, "cesm/eli.nc"))
Th_total = xr.merge([Th_total, eli])

## ensemble mean, then windowed per-month mean (same treatment as the spatial fields)
Th_total = src.utils.get_windowed(Th_total.mean("member")).groupby("time.month").mean()

## combine indices with the spatial data
forced = xr.merge([forced, Th_total])

#### Compute standard deviation by month and its fractional change

In [None]:
## open data
Th = src.utils.load_cesm_indices()

## get windowed
Th = src.utils.get_windowed(Th)

## standard deviation (not variance) for each calendar month,
## pooling over both time and ensemble members
Th_sigma = Th.groupby("time.month").std(["time", "member"])

## fractional change in standard deviation relative to the first year
Th_delta_sigma = (Th_sigma - Th_sigma.isel(year=0)) / Th_sigma.isel(year=0)

## Analysis

### Precip / $\tau_x$ on the equator

In [None]:
## specify month and lat bound for averaging
MONTH = 5
LAT_BOUND = 5

## seconds per day: converts a per-second precipitation rate to a per-day rate
## (presumably kg m^-2 s^-1 -> mm/day; TODO confirm the units of forced["pr"])
SEC_PER_DAY = 86400


def merimean(x, lat_bound=5):
    """Average ``x`` over latitudes between -lat_bound and lat_bound (label-based slice)."""
    return x.sel(latitude=slice(-lat_bound, lat_bound)).mean("latitude")


def sel(x):
    """Meridional mean within +/-LAT_BOUND at month MONTH, with 'year' as the leading dim."""
    return merimean(x, lat_bound=LAT_BOUND).sel(month=MONTH).transpose("year", ...)


## get longitude for convenience
LON = forced.longitude

## zonal center of mass of each variable (computed for reference; not plotted below)
com = sel(forced * LON).integrate("longitude") / sel(forced).integrate("longitude")

## for each year, the longitude where zonal wind stress is most negative
taux_argmin = forced.longitude.isel(longitude=sel(forced["taux"]).argmin("longitude"))

fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## precip, scaled from a per-second to a per-day rate
axs[0].contourf(
    forced.longitude,
    forced.year,
    SEC_PER_DAY * sel(forced["pr"]),
    cmap="cmo.rain",
    levels=np.arange(0, 13.5, 1.5),
)
axs[0].set_title("precip")

## zonal wind stress
axs[1].contourf(
    forced.longitude,
    forced.year,
    sel(forced["taux"]),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.075, 0.0075),
)
axs[1].set_title(r"$\tau_x$")

## overlay the longitude of minimum (most negative) taux
axs[1].plot(
    taux_argmin,
    forced.year,
    c="w",
    ls="--",
    alpha=1,
    lw=1,
)


## formatting: shared longitude range; year ticks only on the left panel
for ax in axs:
    ax.set_xlim([140, 280])
for ax in axs[1:]:
    ax.set_yticks([])

plt.show()

### Stratification by season over time

In [None]:
def get_dTdz_sub(Tsub, mld, delta=25):
    """Estimate the vertical temperature gradient (stratification) below the mixed layer.

    Parameters
    ----------
    Tsub
        Temperature field, passed unchanged to ``src.utils.get_dT_sub``.
    mld
        Mixed-layer depth, passed to ``get_dT_sub`` as ``Hm``; also used as the
        vertical length scale that the temperature difference is divided by.
    delta
        Forwarded unchanged to ``get_dT_sub``.

    Returns
    -------
    ``-dT / mld``, where ``dT`` is the temperature difference returned by
    ``src.utils.get_dT_sub``.
    """

    ## get temperature difference
    dT = src.utils.get_dT_sub(Tsub=Tsub, Hm=mld, delta=delta)

    ## get gradient
    ## NOTE(review): dT is divided by ``mld`` rather than ``delta``; confirm which
    ## vertical separation the dT from get_dT_sub actually spans.
    dTdz = -dT / mld

    return dTdz

In [None]:
## compute stratification (mld=50, delta=20); index the variable explicitly
## rather than via attribute access, since ``.T`` means "transpose" on a DataArray
dTdz = get_dTdz_sub(Tsub=forced["T"], mld=50, delta=20)

## average over Niño 3 longitudes (210-270)
dTdz_n3 = dTdz.sel(longitude=slice(210, 270)).mean("longitude")


def get_frac_change(x):
    """Fractional change of ``x`` relative to its first year (not used in the plot below)."""
    return (x - x.isel(year=0)) / x.isel(year=0)


## make plot
fig, ax = plt.subplots(figsize=(2, 4), layout="constrained")

## change in stratification relative to the first year (absolute, not fractional)
ax.contourf(
    dTdz_n3.month,
    dTdz_n3.year,
    (dTdz_n3 - dTdz_n3.isel(year=0)).transpose("year", ...),
    cmap="cmo.balance",
    levels=src.utils.make_cb_range(0.04, 0.004),
    extend="both",
)

## formatting: mark month 5, the month used for the equatorial section plots
ax.axvline(5, ls="--", lw=0.8, c="k")

plt.show()

In [None]:
## reference year marked as the inflection point in the plots below
ref_year = 1980

## reference coordinate: tropical-mean SST (the "trop_sst_15" variable)
ref = forced["trop_sst_15"]

## line style for the vertical reference-year marker
ref_kwargs = {"ls": "--", "c": "gray", "lw": 0.8}

### Timeseries

#### Tropical SST over time

In [None]:
def sel(x):
    """Average over the month dimension."""
    return x.mean("month")


# alternative: a single month instead of the monthly average
# sel = lambda x : x.sel(month=12)

## (series, legend label, line color), drawn in this order
timeseries = [
    (ref, "tropical SST", "k"),
    (forced["T_4"], "Niño 4", "r"),
    (forced["T_3"], "Niño 3", "b"),
]

fig, ax = plt.subplots(figsize=(4, 3))
for series, series_label, series_color in timeseries:
    ax.plot(forced.year, sel(series), label=series_label, c=series_color)
ax.axvline(ref_year, **ref_kwargs)
ax.legend()
plt.show()

#### Mean state quantities, plotted against tropical SST

In [None]:
def sel(x):
    """Select month 5."""
    return x.sel(month=5)


# alternative: average over all months instead
# sel = lambda x : x.mean("month")

## x-position of the reference-year marker (identical for every panel)
ref_x0 = sel(ref.sel(year=ref_year)).values.item()


def _plot_vs_ref(y, title):
    """Plot ``sel(y)`` against ``sel(ref)``, marking the reference year."""
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(sel(ref), sel(y))
    ax.axvline(ref_x0, **ref_kwargs)
    ax.set_title(title)
    plt.show()


_plot_vs_ref(forced["T_4"] - forced["T_3"], r"$\frac{dT}{dx}$")
_plot_vs_ref(dTdz_n3, r"$\frac{dT}{dz}$")
_plot_vs_ref(forced["T_3"] - forced["trop_sst_15"], r"Niño 3 relative SST")

## previously mis-titled "Niño 3 relative SST" (copy-paste); title now names what is plotted
_plot_vs_ref(forced["trop_sst_15"] - forced["trop_sst_05"], "trop_sst_15 - trop_sst_05")

In [None]:
## specify which variable to plot
VARNAME = "T_34"

## calendar months to draw, one line each
PLOT_MONTHS = [1, 4, 7, 10]

## make colormap; vmin=-1 with indices 0..3 skips the lowest end of the map
## (presumably to avoid its palest colors)
CMAP = cmocean.cm.amp
CMAP_NORM = plt.Normalize(vmin=-1, vmax=3)

## fractional change in std. dev. of the chosen variable (from Th_delta_sigma)
delta_sigma = Th_delta_sigma[VARNAME]

## x-axis for the first figure: month-5 reference SST, selected explicitly
## instead of relying on whatever ``sel`` a previously-run cell left behind
ref_may = ref.sel(month=5)
ref_may_x0 = ref_may.sel(year=ref_year).values.item()

#### plot variance change against reference SST
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(PLOT_MONTHS):
    ax.plot(
        ref_may,
        delta_sigma.sel(month=m_),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )
ax.axvline(ref_may_x0, **ref_kwargs)
ax.set_title(r"Niño variance")
ax.set_ylabel(r"$\Delta\sigma / \sigma_0$")
ax.legend()  # labels were set above but the legend was previously never drawn
plt.show()

#### plot variance change against time
fig, ax = plt.subplots(figsize=(4, 3))
for i, m_ in enumerate(PLOT_MONTHS):
    ax.plot(
        Th_delta_sigma.year,
        delta_sigma.sel(month=m_),
        c=CMAP(CMAP_NORM(i)),
        label=f"Month {m_}",
    )
ax.set_title(r"Niño variance")
ax.set_ylabel(r"$\Delta\sigma / \sigma_0$")
ax.legend()
plt.show()