# Thermocline
Compute thermocline depth, then examine how sensitive it is to the choice of definition and how it changes over time.

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import pandas as pd

# Import custom modules
import src.utils

## seaborn defaults: white axes background, grid lines switched off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution (dots per inch)
mpl.rcParams["figure.dpi"] = 100

## directories are read from the DATA_FP / SAVE_FP environment variables
## (a missing variable raises KeyError)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])
SAVE_FP = pathlib.Path(os.environ["SAVE_FP"])

## Math...

Center of mass of a distribution $m(x)$:
\begin{align}
    \left(\int m(x)~dx\right)^{-1}\int x \cdot m(x)~dx
\end{align}

\begin{align}
    \int_{300}^0 \frac{dT}{dz} ~dz = T(z=0)-T(z=300)
\end{align}

## Load data

In [None]:
## load the consolidated spatial data; only the first returned item is used here
forced, _ = src.utils.load_consolidated()

## split into the temperature variables and the mixed-layer-depth variables
T, MLD = (forced[names] for names in (["T", "T_comp"], ["mld", "mld_comp"]))

## Compute thermocline depth

In [None]:
## early (1850–1879) and late (2070–2099) time windows
t_early = {"time": slice("1850", "1879")}
t_late = {"time": slice("2070", "2099")}


def _lat_band_mean(x):
    """Average `x` over latitudes in the slice -2..2."""
    return x.sel(latitude=slice(-2, 2)).mean("latitude")


## temperature climatology for each window
T_early = src.utils.reconstruct_clim(data=T.sel(t_early))["T"]
T_late = src.utils.reconstruct_clim(data=T.sel(t_late))["T"]

## thermocline depth from get_H (early & late)
H_early = src.utils.get_H(T_early)
H_late = src.utils.get_H(T_late)

## thermocline depth from get_H_int, same threshold for both windows
H_early_int = src.utils.get_H_int(T_early, thresh=0.04)
H_late_int = src.utils.get_H_int(T_late, thresh=0.04)

## mixed-layer depth climatology; the latitude-band mean is passed to
## reconstruct_clim as `fn`
MLD_early = src.utils.reconstruct_clim(MLD.sel(t_early), fn=_lat_band_mean)["mld"]
MLD_late = src.utils.reconstruct_clim(MLD.sel(t_late), fn=_lat_band_mean)["mld"]

### Plot climatology (early vs. late)

In [None]:
def sel(x):
    """Average `x` over months 1–3 (use ``x.mean("month")`` for all months)."""
    return x.sel(month=[1, 2, 3]).mean("month")


## three panels: early, late, and late-minus-early
fig, axs = plt.subplots(1, 3, figsize=(12, 3), layout="constrained")

## zip stops after the two (early, late) entries, so the last axis is left
## free for the difference plot below
panels = zip(
    axs,
    (T_early, T_late),
    (H_early, H_late),
    (H_early_int, H_late_int),
    (MLD_early, MLD_late),
)
temp_kwargs = {
    "levels": np.arange(10, 34, 2),
    "cmap": "cmo.thermal",
    "extend": "both",
}

for ax, temp, h_pt, h_int, mld in panels:

    ## filled temperature contours
    cp = ax.contourf(temp.longitude, temp.z_t, sel(temp), **temp_kwargs)

    ## thermocline estimates in white: get_H dashed, get_H_int dotted
    for depth, style in ((h_pt, "--"), (h_int, ":")):
        ax.plot(depth.longitude, sel(depth), c="w", ls=style)

    ## mixed-layer depth in gray
    ax.plot(mld.longitude, sel(mld), c="gray", ls="--")

## late-minus-early temperature difference on the last panel
diff_ax = axs[-1]
cp_diff = diff_ax.contourf(
    T_early.longitude,
    T_early.z_t,
    sel(T_late - T_early),
    levels=src.utils.make_cb_range(5, 0.5),
    cmap="cmo.balance",
    extend="both",
)

## overlay thermocline (black) and MLD (gray); solid = early, dashed = late
overlays = zip((H_early_int, H_late_int), (MLD_early, MLD_late), ("-", "--"))
for h_int, mld, style in overlays:
    diff_ax.plot(h_int.longitude, sel(h_int), c="k", ls=style)
    diff_ax.plot(mld.longitude, sel(mld), c="gray", ls=style)

## shared axis formatting and longitude range
for ax in axs:
    src.utils.format_subsurf_axs([ax])
    ax.set_xlim([140, 280])

## only the left-most panel keeps its y ticks and y label
for ax in axs[1:]:
    ax.set_yticks([])
    ax.set_ylabel(None)

## colorbar for the difference panel only
fig.colorbar(cp_diff, ticks=[-5, 0, 5], label=r"$^{\circ}$C")

plt.show()

### Early vs. late comparison

Look at some profiles

In [None]:
def sel_mj(x):
    """Average over May–June and over longitudes 190–240 (Niño 3.4 band)."""
    subset = x.sel(month=[5, 6], longitude=slice(190, 240))
    return subset.mean(["month", "longitude"])


def sel_year(x, yr, n=15):
    """Reconstruct the climatology of `x` over the years yr-n through yr+n.

    The window spans 2n+1 years centered on `yr`. The result of
    `reconstruct_clim` is converted to a DataArray and squeezed.
    """
    ## label-based time window around the center year
    window = slice(f"{yr - n}", f"{yr + n}")
    clim = src.utils.reconstruct_clim(x.sel(time=window))

    return clim.to_dataarray().squeeze()


def sel(x, yr):
    """Apply `sel_year` around `yr`, then `sel_mj` (May–June, Niño 3.4 mean)."""
    windowed_clim = sel_year(x, yr)
    return sel_mj(windowed_clim)


## line colors, indexed by position of the plotted year
colors = sns.color_palette("mako")

## six evenly spaced center years between 2000 and 2085
yrs = pd.Index(np.round(np.linspace(2000, 2085, 6)).astype("int"), name="year")

## one profile per center year, stacked along a new "year" dimension
profs = xr.concat([sel(T, yr) for yr in yrs], dim=yrs)

## vertical temperature gradient and the two thermocline-depth estimates
dTdz = profs.differentiate("z_t")
H0 = src.utils.get_H(profs)
H1 = src.utils.get_H_int(profs, thresh=0.06)

## left: gradient profiles; right: thermocline depth vs. year
fig, axs = plt.subplots(
    1, 2, figsize=(5, 2.5), width_ratios=[1.5, 3.5], layout="constrained"
)
prof_ax, ts_ax = axs

for i, y in enumerate(dTdz.year):
    color = colors[i]
    prof_ax.plot(dTdz.sel(year=y), T.z_t, c=color)
    ## markers for the two depth estimates, drawn at fixed x positions
    ## (0 for get_H, 0.03 for get_H_int)
    prof_ax.scatter(0, H0.sel(year=y), color=color, s=25)
    prof_ax.scatter(0.03, H1.sel(year=y), color=color, s=25)

## flip the y axis so that it is reversed relative to the autoscaled limits
y_lo, y_hi = prof_ax.get_ylim()
prof_ax.set_ylim(y_hi, y_lo)
prof_ax.axvline(-0.1)

## thermocline depth for each center year, one line per estimate
for depth in (H0, H1):
    ts_ax.plot(profs.year, depth)

plt.show()

## Change over time

### Compute rolling $T$

In [None]:
## center years of the rolling windows (every 5 years)
years = np.arange(1865, 2090, 5)

## climatology of T over [y-15, y+15] for every center year, concatenated
## along a new "year" dimension (tqdm shows progress)
T_rolling = xr.concat(
    [
        src.utils.reconstruct_clim(T.sel(time=slice(f"{y-15}", f"{y+15}")))["T"]
        for y in tqdm.tqdm(years)
    ],
    dim=pd.Index(years, name="year"),
)

### Compute $H$

In [None]:
## thermocline depth via get_H_int
## (alternative definition: src.utils.get_H(T_rolling))
H_rolling = src.utils.get_H_int(T_rolling, thresh=0.04)

## zonal mean over longitudes 160–210
## NOTE(review): the variable is named `H_n34`, but the Niño 3.4 band used
## elsewhere in this file is slice(190, 240) — confirm which band is intended
lon_band = slice(160, 210)
H_n34 = H_rolling.sel(longitude=lon_band).mean("longitude")

## fractional change of thermocline depth
delta_H = src.utils.frac_change(H_n34)


def _rolling_30(x):
    """Centered rolling mean over 30 samples along `time`."""
    return x.rolling({"time": 30}, center=True).mean()


## MLD reduced with get_nino34, then smoothed separately for each calendar
## month (30 samples per window; presumably 30 years — verify time sampling)
MLD_n34 = src.utils.reconstruct_wrapper(MLD, fn=src.utils.get_nino34)
MLD_n34 = MLD_n34.groupby("time.month").map(_rolling_30)

## reshape to (month, year) and keep only the years present in H_n34
MLD_n34 = src.utils.unstack_month_and_year(MLD_n34).sel(year=H_n34.year)["mld"]
delta_MLD = src.utils.frac_change(MLD_n34)

### Plot

#### Hovmoller

In [None]:
## two side-by-side panels: thermocline depth (left), MLD (right)
fig, axs = plt.subplots(1, 2, figsize=(4, 4), layout="constrained")

## NOTE(review): the same colorbar label (inverse-depth notation) is passed
## for both panels, although delta_H / delta_MLD are computed without
## inv=True — confirm the label
hov_label = r"$\frac{\Delta~H^{-1}}{H_0^{-1}}$"

for ax, delta in zip(axs, (delta_H, delta_MLD)):

    ## month on the first dimension, year on the second
    src.utils.plot_hov2(
        fig, ax, delta.transpose("month", "year"), amp=0.4, label=hov_label
    )

    ## vertical guide at x=6, start the y axis at 1975, hide ticks by default
    ax.axvline(6, c="k", lw=1, ls="--")
    ax.set_ylim([1975, None])
    ax.set_yticks([])

## titles for both panels
for ax, title in zip(axs, ("Thermocline depth", "MLD depth")):
    ax.set_title(title)

## year ticks and label on the left panel only
axs[0].set_yticks(np.arange(1975, 2090, 30))
axs[0].set_ylabel("Year")


plt.show()

#### Line plots

In [None]:
def sel_mj(x):
    """Average `x` over May and June (months 5 and 6).

    The previous version selected ``month=[4]`` (April), which contradicted
    this function's name and the "May/June" legend label below.
    """
    return x.sel(month=[5, 6]).mean("month")


## left: thermocline depth; right: fractional change
fig, axs = plt.subplots(1, 2, figsize=(6, 3), layout="constrained")

## May/June mean and all-month mean on each panel; p0/p1 keep the line
## handles from the last (right-hand) panel so the MLD lines can reuse colors
for ax, x in zip(axs, [H_n34, delta_H]):
    p0 = ax.plot(x.year, sel_mj(x), label="May/June")
    p1 = ax.plot(x.year, x.mean("month"), label="All months")

## MLD fractional change (dashed), color-matched to the thermocline lines;
## only the May/June MLD line gets a legend entry
axs[1].plot(
    delta_MLD.year, sel_mj(delta_MLD), c=p0[0].get_color(), ls="--", label="MLD"
)
axs[1].plot(delta_MLD.year, delta_MLD.mean("month"), c=p1[0].get_color(), ls="--")


## reverse the depth axis on the left panel, then add legends and titles
axs[0].set_ylim(axs[0].get_ylim()[::-1])
axs[0].legend()
axs[1].legend()
axs[0].set_title("Thermocline depth")
axs[1].set_title("Frac. change")

plt.show()

#### Seasonal hovmoller

##### Absolute difference

In [None]:
## first and last rolling windows
H0 = H_rolling.isel(year=0)
H1 = H_rolling.isel(year=-1)

## fractional change (frac_change with inv=True), taken at the last window
delta_H_inv = src.utils.frac_change(H_rolling, inv=True).isel(year=-1)

## four stacked Hovmöller panels
fig, axs = plt.subplots(4, 1, figsize=(3.5, 7), layout="constrained")

## contour / colorbar settings for the two absolute-depth panels
kwargs = {"cmap": "cmo.amp", "levels": np.arange(0, 220, 20), "extend": "max"}
cb_kwargs = {"ticks": [0, 100, 200], "label": "m"}

## settings shared by the two difference panels
diverging = {"cmap": "cmo.balance", "extend": "both"}

## early window
cp0 = src.utils.plot_cycle_hov(axs[0], H0, **kwargs)
cb0 = fig.colorbar(cp0, ax=axs[0], **cb_kwargs)

## late window
## NOTE(review): contour levels are shifted by +3 relative to the early panel,
## so the two panels do not share levels — confirm this is intentional.
## A new array is built (not an in-place add) so cp0's levels are untouched.
kwargs.update(levels=kwargs["levels"] + 3)
cp1 = src.utils.plot_cycle_hov(axs[1], H1, **kwargs)
cb1 = fig.colorbar(cp1, ax=axs[1], **cb_kwargs)

## absolute difference (late minus early)
cp2 = src.utils.plot_cycle_hov(
    axs[2], H1 - H0, levels=src.utils.make_cb_range(80, 8), **diverging
)
cb2 = fig.colorbar(cp2, ax=axs[2], ticks=[-80, 0, 80], label="change (m)")

## fractional change, expressed in percent
cp3 = src.utils.plot_cycle_hov(
    axs[3], 100 * delta_H_inv, levels=src.utils.make_cb_range(50, 5), **diverging
)
cb3 = fig.colorbar(cp3, ax=axs[3], ticks=[-50, 0, 50], label="% change")

## panel titles, top to bottom
for ax, title in zip(axs, ("Early", "Late", "Diff.", "% Diff.")):
    ax.set_title(title)

## x-axis label and ticks on the bottom panel only
axs[-1].set_xlabel("Longitude")
axs[-1].set_xticks([140, 190, 240])

plt.show()