# Compute $h$
Compute thermocline depth from subsurface temperature profiles

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import pandas as pd

# Import custom modules
import src.utils

## plotting defaults: white axes background, grid lines off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## figure resolution
mpl.rcParams["figure.dpi"] = 100

## data / output directories are read from environment variables
## (a KeyError is raised if either variable is unset)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[name]) for name in ("DATA_FP", "SAVE_FP"))

## cluster

Note: loading the data requires a lot of memory (~16 GB per worker).

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 6 workers and attach a client to it
cluster = LocalCluster(n_workers=6)
client = Client(cluster)
## bare expression: shown as the cell output when run in a notebook
client

## Load data

In [None]:
## load the consolidated forced / anomaly datasets
forced, anom = src.utils.load_consolidated()

## total temperature field = forced part + anomaly part
T_scores = forced["T"] + anom["T"]

## bundle with the "T_comp" variable and chunk the result (for dask):
## one ensemble member and 120 time steps per chunk
T = xr.merge([T_scores, forced["T_comp"]]).chunk(dict(member=1, time=120))

## Compute

### threshold-version

In [None]:
## set threshold for thermocline depth
THRESH = 0.04

## cached result lives under SAVE_FP/h_ests; the filename encodes the
## threshold scaled by 1e3. round() is used instead of int() so that a float
## product landing just below an integer is not truncated to the wrong label.
save_fp = pathlib.Path(SAVE_FP, "h_ests", f"h_int_{round(THRESH * 1e3)}.nc")

if save_fp.is_file():
    ## reuse the previously saved estimate
    print("File exists!")
    H = xr.open_dataarray(save_fp)

else:

    print("Computing")
    T_vals = src.utils.reconstruct_wrapper(T)
    H = src.utils.get_H_int(T_vals["T"], thresh=THRESH)

    ## make sure the output directory exists, then save
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    H.to_netcdf(save_fp)

### Contour version

#### Function to compute

In [None]:
def find_level(T_recon, level=20):
    """Estimate the depth at which temperature is closest to ``level``.

    This is a nearest-grid-point estimate: the depth returned is always one
    of the existing ``z_t`` coordinate values (no vertical interpolation).

    Args:
        T_recon: temperature array with a ``z_t`` dimension and coordinate
            (presumably an ``xarray.DataArray`` -- confirm against callers).
        level: target temperature value. Defaults to 20.

    Returns:
        The ``z_t`` value of the closest level for every profile, set to NaN
        wherever the first ``z_t`` level of ``T_recon`` is NaN.
    """

    ## absolute misfit from the target value at every vertical level
    misfit = np.abs(T_recon - level)

    ## missing values get a huge misfit so that argmin never selects them
    nearest_idx = misfit.fillna(1e20).argmin("z_t")

    ## translate the index into a depth
    depth = T_recon.z_t.isel(z_t=nearest_idx)

    ## restore NaNs for profiles whose first level is missing
    first_level_valid = ~np.isnan(T_recon.isel(z_t=0))
    return depth.where(first_level_valid, other=np.nan)


def load_level(T_recon, lev):
    """Load the depth estimate for ``lev`` from disk, or compute and save it.

    The result is cached at ``SAVE_FP/h_ests/z_{lev}.nc``.

    Args:
        T_recon: temperature array passed on to ``find_level`` when no cached
            file exists.
        lev: target temperature level; also used to build the cache filename.

    Returns:
        The depth estimate: opened from the cached file if it exists,
        otherwise the freshly computed result of ``find_level``.
    """

    ## get save filepath
    save_fp = pathlib.Path(SAVE_FP, "h_ests", f"z_{lev}.nc")

    if save_fp.is_file():
        print("File exists!")
        return xr.open_dataarray(save_fp)

    ## compute
    H = find_level(T_recon, level=lev)

    ## the output directory may not exist yet on a fresh SAVE_FP;
    ## create it so the write below cannot fail on a missing folder
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    H.to_netcdf(save_fp)

    return H

#### Test function

In [None]:
## get data
T_recon = src.utils.reconstruct_wrapper(T.isel(time=300, member=30))["T"].compute()

## find thermocline
z20_est = find_level(T_recon)

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## coordinates and data shared by the filled and line contours
section = (T_recon.longitude, T_recon.z_t, T_recon)

## temperature section (shading)
ax.contourf(
    *section, cmap="cmo.thermal", levels=np.arange(10, 34, 2), extend="both"
)

## contour at 20 (white), for comparison with the estimate
ax.contour(*section, colors="w", levels=[20])

## nearest-level estimate from find_level (dashed magenta)
ax.plot(z20_est.longitude, z20_est, ls="--", c="magenta")

## y-limits given as (200, 5) so the axis is inverted
ax.set_ylim([200, 5])
ax.set_xlim([120, 280])
plt.show()

#### Compute

Load to memory (this part takes the longest)

In [None]:
## reconstruct the full temperature field, then pull it into memory
T_recon = src.utils.reconstruct_wrapper(T)["T"]
T_recon = T_recon.compute()

In [None]:
## compute (or reuse the cached file for) each level from 20 through 24;
## the return value is not needed here -- results are read back from disk later
levels = np.arange(20, 25)
for lev in tqdm.tqdm(levels):
    load_level(T_recon, lev=lev)

#### Plot results to make sure they look ok

In [None]:
## read two of the saved level estimates back from disk
h_dir = pathlib.Path(SAVE_FP, "h_ests")
z20 = xr.open_dataarray(h_dir / "z_20.nc")
z22 = xr.open_dataarray(h_dir / "z_22.nc")

## combine into one dataset with variables "z20" and "z22"
Z = xr.merge([z20.rename("z20"), z22.rename("z22")])

In [None]:
def sel(x):
    """Select a single snapshot (time index 40, member index 60)."""
    return x.isel(time=40, member=60)


fig, ax = plt.subplots(figsize=(4, 3))

## coordinates and data shared by the filled and line contours
section = (T_recon.longitude, T_recon.z_t, sel(T_recon))

## temperature section (shading)
ax.contourf(
    *section, cmap="cmo.thermal", levels=np.arange(10, 34, 2), extend="both"
)

## contour at 20 (white)
ax.contour(*section, colors="w", levels=[20])

## saved estimates for levels 20 and 22 (dashed magenta)
ax.plot(Z.longitude, sel(Z["z20"]), ls="--", c="magenta")
ax.plot(Z.longitude, sel(Z["z22"]), ls="--", c="magenta")

## y-limits given as (250, 5) so the axis is inverted
ax.set_ylim([250, 5])
ax.set_xlim([120, 280])
plt.show()

### max-grad version

#### function to find max gradient

In [None]:
def find_maxgrad(T_recon):
    """Estimate thermocline depth as the level of the minimum ``z_t``-derivative.

    Args:
        T_recon: temperature array with a ``z_t`` dimension and coordinate
            (presumably an ``xarray.DataArray`` -- confirm against callers).

    Returns:
        The ``z_t`` value at which the derivative of ``T_recon`` along ``z_t``
        is smallest, set to NaN wherever the first ``z_t`` level is NaN.
    """

    ## derivative of temperature along the vertical coordinate
    dT_dz = T_recon.differentiate("z_t")

    ## argmin picks the most negative derivative; this is the strongest
    ## gradient only if temperature decreases along z_t (assumed here --
    ## TODO confirm). NaNs are filled with a huge value so they never win.
    steepest_idx = dT_dz.fillna(1e20).argmin("z_t")

    ## translate the index into a depth
    depth = T_recon.z_t.isel(z_t=steepest_idx)

    ## restore NaNs for profiles whose first level is missing
    first_level_valid = ~np.isnan(T_recon.isel(z_t=0))
    return depth.where(first_level_valid, other=np.nan)

#### Compute

In [None]:
## get save filepath
save_fp = pathlib.Path(SAVE_FP, "h_ests", f"h_max-grad.nc")

if save_fp.is_file():
    print("File exists!")
    h = xr.open_dataarray(save_fp)

else:

    print("Computing")
    T_recon = src.utils.reconstruct_wrapper(T)["T"].compute()
    h = find_maxgrad(T_recon)

    ## save
    h.to_netcdf(save_fp)

In [None]:
## bare expression: shows the max-gradient depth estimate as the cell output
h

#### Test it makes sense

In [None]:
## snapshot used for the sanity check
idx = {"time": 100, "member": 6}

## reconstructed temperature and max-gradient depth for that snapshot
T_ = src.utils.reconstruct_wrapper(T.isel(idx))["T"].compute()
h_ = h.isel(idx)

fig, ax = plt.subplots(figsize=(4, 3))

## coordinates and data shared by the filled and line contours
section = (T_.longitude, T_.z_t, T_)

## temperature section (shading)
ax.contourf(
    *section, cmap="cmo.thermal", levels=np.arange(10, 34, 2), extend="both"
)

## contour at 20 (white)
ax.contour(*section, colors="w", levels=[20])

## max-gradient thermocline estimate (black)
ax.plot(h_.longitude, h_, c="k")

## y-limits given as (250, 5) so the axis is inverted
ax.set_ylim([250, 5])
ax.set_xlim([120, 270])
plt.show()

### threshold max-grad version

In [None]:
# def find_maxgrad(T_recon):
#     """function to find depth of thermocline"""

#     ## find index of max gradient
#     idx = T_recon.differentiate("z_t").fillna(1e20).argmin("z_t")

#     ## get thermocline depth
#     h = T_recon.z_t.isel(z_t=idx)

#     ## fill in NaN values
#     h = h.where(~np.isnan(T_recon.isel(z_t=0)), other=np.nan)

#     return h

### Plot results

#### Load $Z_{20}$ (validation data)

In [None]:
## load spatial data
forced, anom = src.utils.load_consolidated()
z20_scores = forced["z20"] + anom["z20"]
z20 = xr.merge([z20_scores, forced["z20_comp"]])

## chunk it
z20 = z20.chunk({"member": 1, "time": 180})

## get value on equator
z20_eq = src.utils.reconstruct_wrapper(
    z20, fn=lambda x: x.sel(latitude=slice(-1.5, 1.5)).mean("latitude")
)

## load into memory
z20_eq.load();

In [None]:
def sel(x):
    """Average over all members and the first 360 time steps."""
    return x.isel(time=slice(None, 360)).mean(["member", "time"])


## threshold-based estimate: mean over the first and over the last 360 steps
hbar = sel(H)
hbar1 = H.isel(time=slice(-360, None)).mean(["member", "time"])

## validation data averaged over the same initial window as hbar
z20_bar = sel(z20_eq)

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

## mean threshold-based estimate
ax.plot(hbar.longitude, hbar)

## mean z20 from the validation data
ax.plot(z20_bar.longitude, z20_bar["z20"])

## limits passed as (200, 40) so the axis is inverted
ax.set_ylim(200, 40)
plt.show()

### Spatial plot

In [None]:
## average over the first 360 time steps and all members, then load to memory
z_first = z20.isel(time=slice(None, 360))
z_ = z_first.mean(["time", "member"]).compute()

## reconstruct the averaged dataset
z = src.utils.reconstruct_wrapper(z_)

In [None]:
## weights computed by the project helper from the "z20_comp" variable
z20_comp = z_["z20_comp"]
coslat_weights = src.utils.get_coslat_weights(z20_comp)

## evaluate function on spatial components: divide the weights back out
fn_eval = (z20_comp * 1) / coslat_weights

In [None]:
import cartopy.crs as ccrs


def format_func(ax):
    """Apply the project's Pacific map setup with max_lat=20."""
    return src.utils.plot_setup_pac(ax, max_lat=20)


fig = plt.figure(figsize=(5, 2.5), layout="constrained")
axs = src.utils.subplots_with_proj(fig, nrows=1, ncols=1, format_func=format_func)

## filled contours of the reconstructed mean z20 on the single map panel
map_ax = axs[0, 0]
map_ax.contourf(
    z.longitude,
    z.latitude,
    z["z20"],
    transform=ccrs.PlateCarree(),
    cmap="cmo.thermal",
)

plt.show()