# Thermocline
Compute the thermocline depth, then examine its sensitivity to the definition used and how it changes over time.

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import pandas as pd

# Import custom modules
import src.utils

## plotting defaults: white axes background, no gridlines
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure resolution to 100 DPI
mpl.rcParams.update({"figure.dpi": 100})

## data and save directories, read from the DATA_FP / SAVE_FP environment variables
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[var]) for var in ("DATA_FP", "SAVE_FP"))

## Functions

In [None]:
def get_H(T):
    """Estimate thermocline depth from temperature.

    The thermocline is taken to be the ``z_t`` level at which the vertical
    temperature gradient dT/dz is most negative, i.e. where temperature falls
    off fastest as ``z_t`` increases.

    Args:
        T: xarray DataArray of temperature with a ``z_t`` coordinate.

    Returns:
        DataArray of ``z_t`` values, one for each position along the remaining
        (non-``z_t``) dimensions of ``T``.
    """

    ## vertical gradient along the z_t coordinate (uses actual level spacing)
    dTdz = T.differentiate("z_t")

    ## position of the most negative gradient in each column
    steepest = dTdz.argmin("z_t")

    return T.z_t.isel(z_t=steepest)


def frac_change(x, inv=True):
    """Return the fractional change of ``x`` relative to its first year.

    Computes ``(y - y0) / y0``, where ``y0`` is the first entry of ``y`` along
    the ``year`` dimension and ``y`` is ``1 / x`` when ``inv`` is True
    (the default) or ``x`` itself otherwise.

    Args:
        x: xarray object with a ``year`` dimension.
        inv: if True, measure the change of the reciprocal ``1 / x``.

    Returns:
        Object like ``x`` holding the fractional change; zero at the first
        year (unless that reference value is zero or NaN).
    """

    y = (1 / x) if inv else x

    ## reference value: first entry along "year"
    y_ref = y.isel(year=0)

    return (y - y_ref) / y_ref

## Load data

In [None]:
## load spatial data from the consolidated CESM output
## (xarray opens the file lazily; the handle stays open so variables load on demand)
CONS_DIR = DATA_FP / "cesm" / "consolidated"
forced = xr.open_dataset(CONS_DIR / "forced.nc")

## split out the temperature and MLD variable pairs
T, MLD = forced[["T", "T_comp"]], forced[["mld", "mld_comp"]]

## Compute thermocline depth

In [None]:
## specify time interval (inclusive year slicing -> 30 years, 1850-1879)
t_early = dict(time=slice("1850", "1879"))

## get climatology over exactly the interval specified above, so the early
## baseline has the same 30-year length as the rolling windows used later
T_early = src.utils.reconstruct_clim(
    data=T.sel(t_early),
)["T"]

## compute thermocline depth (level of steepest temperature decrease)
H_early = get_H(T_early)

### Plot

In [None]:
## function to select data
sel = lambda x: x.sel(month=[3]).mean("month")

fig, ax = plt.subplots(figsize=(5, 3), layout="constrained")

## plot temperature
cp = ax.contourf(
    T_early.longitude,
    T_early.z_t,
    sel(T_early),
    levels=np.arange(10, 32, 2),
    cmap="cmo.thermal",
    extend="both",
)

## plot estimated thermocline
ax.plot(
    H_early.longitude,
    sel(H_early),
    c="w",
    ls="--",
)

## formatting
src.utils.format_subsurf_axs([ax])
ax.set_xlim([140, 280])
fig.colorbar(cp, ticks=[10, 30], label=r"$^{\circ}$C")

plt.show()

## Change over time

### Compute

In [None]:
## window centers every 5 years, 1865-2085
years = np.arange(1865, 2090, 5)

## climatology for each window [y-15, y+15]; year-string slicing is inclusive
## at both ends, so each window spans 31 calendar years.
## NOTE(review): the MLD rolling mean in the next cell uses a 30-year window;
## confirm the one-year difference is intended.
T_rolling = [
    src.utils.reconstruct_clim(T.sel(time=slice(f"{y-15}", f"{y+15}")))["T"]
    for y in tqdm.tqdm(years)
]

## stack the per-window climatologies along a new "year" dimension
T_rolling = xr.concat(T_rolling, dim=pd.Index(years, name="year"))

## thermocline depth for every window
H_rolling = get_H(T_rolling)

## average over Niño 3.4 longitudes (190-240 E, i.e. 170-120 W)
H_n34 = H_rolling.sel(longitude=slice(190, 240)).mean("longitude")
# alternative: average over the full plotted longitude range
# H_n34 = H_rolling.sel(longitude=slice(140, 280)).mean("longitude")

## fractional change of inverse thermocline depth relative to the first window
delta_H = frac_change(H_n34)

In [None]:
## Niño 3.4 average of the MLD variables
MLD_n34 = src.utils.reconstruct_wrapper(MLD, fn=src.utils.get_nino34)


def _centered_30_step_mean(x):
    """Centered rolling mean over 30 consecutive time steps.

    Applied per calendar-month group; assuming monthly input, each group has
    one sample per year, so this is a 30-year window.
    """
    return x.rolling({"time": 30}, center=True).mean()


## smooth each calendar month separately
MLD_n34 = MLD_n34.groupby("time.month").map(_centered_30_step_mean)

## split time into month/year and keep only the years used for the thermocline
MLD_n34 = src.utils.unstack_month_and_year(MLD_n34).sel(year=H_n34.year)["mld"]

## fractional change of inverse MLD relative to the first retained year
delta_MLD = frac_change(MLD_n34)

### Plot

#### Hovmöller

In [None]:
## scratch arithmetic (evaluates to 5.6); the result is not assigned or used below
## NOTE(review): looks like a leftover sizing calculation — consider removing
14 / 2.5

In [None]:
## setup plot
fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## one label per panel: the second panel shows MLD, not thermocline depth,
## so it must not reuse the H label
labels = [
    r"$\frac{\Delta~H^{-1}}{H_0^{-1}}$",
    r"$\frac{\Delta~\mathrm{MLD}^{-1}}{\mathrm{MLD}_0^{-1}}$",
]

for ax, delta, label in zip(axs, [delta_H, delta_MLD], labels):

    ## plot fractional change with month and year as the two axes
    src.utils.plot_hov2(
        fig,
        ax,
        delta.transpose("month", "year"),
        amp=0.6,
        label=label,
    )

    ## dashed reference lines at x=4 and x=7
    ## NOTE(review): presumably bracketing May/June on the month axis — confirm
    ax.axvline(4, c="k", lw=1, ls="--")
    ax.axvline(7, c="k", lw=1, ls="--")
    ax.set_ylim([1975, None])
    ax.set_yticks([])

## label (year ticks only on the left panel)
axs[0].set_yticks(np.arange(1975, 2090, 30))
axs[0].set_ylabel("Year")
axs[0].set_title("Thermocline depth")
axs[1].set_title("MLD depth")


plt.show()

#### Line plots

In [None]:
def sel_mj(x):
    """Average over May and June (months 5 and 6)."""
    return x.sel(month=[5, 6]).mean("month")


fig, axs = plt.subplots(1, 2, figsize=(6, 3), layout="constrained")

## left panel: thermocline depth; right panel: its fractional change
for ax, da in zip(axs, [H_n34, delta_H]):
    p0 = ax.plot(da.year, sel_mj(da), label="May/June")
    p1 = ax.plot(da.year, da.mean("month"), label="All months")

## overlay MLD fractional change as dashed lines, reusing the line colors above
c_mj, c_all = p0[0].get_color(), p1[0].get_color()
axs[1].plot(delta_MLD.year, sel_mj(delta_MLD), c=c_mj, ls="--", label="MLD")
axs[1].plot(delta_MLD.year, delta_MLD.mean("month"), c=c_all, ls="--")

## flip the left y-axis so larger depths plot lower
axs[0].set_ylim(axs[0].get_ylim()[::-1])
for ax in axs:
    ax.legend()
axs[0].set_title("Thermocline depth")
axs[1].set_title("Frac. change")

plt.show()