# Compute $h$
Compute the thermocline depth, $h$, from subsurface temperature profiles.

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import pandas as pd

# Import custom modules
import src.utils

## plotting specs: white axes background, grid lines turned off
PLOT_RC = {"axes.facecolor": "white", "axes.grid": False}
sns.set(rc=PLOT_RC)

## raise the default figure resolution
mpl.rcParams["figure.dpi"] = 100

## root directories for input data and saved output, taken from the environment
## (a missing variable raises KeyError)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## cluster

In [None]:
from dask.distributed import Client, LocalCluster

## start a local dask cluster with 6 workers and attach a client to it
N_WORKERS = 6
cluster = LocalCluster(n_workers=N_WORKERS)
client = Client(cluster)

## echo the client so its summary is shown in the notebook
client

## Load data

In [None]:
## load spatial data
forced, anom = src.utils.load_consolidated()

## total scores = forced + anomaly, merged with the spatial components
## (presumably EOF scores/components consumed by reconstruct_wrapper — confirm)
T_scores = forced["T"] + anom["T"]
T = xr.merge([T_scores, forced["T_comp"]])

## chunk the data (for dask): one ensemble member x 180 time steps per chunk
CHUNKS = {"member": 1, "time": 180}
T = T.chunk(CHUNKS)

## Compute

### Threshold version

In [None]:
## set threshold for thermocline depth (passed to src.utils.get_H_int)
THRESH = 0.04

## get save filepath; round() rather than int() so floating-point error in
## THRESH * 1e3 (e.g. 39.999...) cannot truncate the filename tag by one
save_fp = pathlib.Path(SAVE_FP, "h_ests", f"h_int_{round(THRESH * 1e3)}.nc")

if save_fp.is_file():
    print("File exists!")
    H = xr.open_dataarray(save_fp)

else:

    print("Computing")
    T_vals = src.utils.reconstruct_wrapper(T)
    H = src.utils.get_H_int(T_vals["T"], thresh=THRESH)

    ## save (creating the output directory on the first run), then re-open
    ## from disk so H is file-backed in both branches; if H is dask-backed
    ## (T is chunked), later cells would otherwise redo the whole computation
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    H.to_netcdf(save_fp)
    H = xr.open_dataarray(save_fp)

### max-grad version

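This version needs more memory: the full reconstructed temperature field is loaded into memory (`.compute()`) before the maximum-gradient depth is computed.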

In [None]:
## sanity check: mark points where T is NaN at every depth level (z_t).
## `T_vals` is only created when the threshold cell above recomputed H, so
## rebuild it here if that cell loaded H from disk instead
if "T_vals" not in globals():
    T_vals = src.utils.reconstruct_wrapper(T)
np.isnan(T_vals["T"]).all(["z_t"])

In [None]:
## get save filepath (plain string: nothing to format)
save_fp = pathlib.Path(SAVE_FP, "h_ests", "h_max-grad.nc")

## NOTE: this cell rebinds H, replacing the threshold-version estimate
if save_fp.is_file():
    print("File exists!")
    H = xr.open_dataarray(save_fp)

else:

    print("Computing")
    ## load the full reconstructed temperature field into memory (memory-heavy)
    T_vals = src.utils.reconstruct_wrapper(T).compute()
    H = src.utils.get_H(T_vals["T"])

    ## save (creating the output directory on the first run)
    save_fp.parent.mkdir(parents=True, exist_ok=True)
    H.to_netcdf(save_fp)

### Plot results

#### Load $Z_{20}$ (validation data)

In [None]:
## load validation data: Z20 scores (forced + anomaly) and spatial components
forced, anom = src.utils.load_consolidated()
z20_scores = forced["z20"] + anom["z20"]
z20 = xr.merge([z20_scores, forced["z20_comp"]])

## one member x 180 time steps per dask chunk (same chunking as T)
z20 = z20.chunk({"member": 1, "time": 180})


def _equatorial_mean(x):
    """Average ``x`` over latitude labels in the slice [-1.5, 1.5].

    NOTE(review): label slicing assumes latitude is ascending; a descending
    coordinate would select nothing — confirm.
    """
    return x.sel(latitude=slice(-1.5, 1.5)).mean("latitude")


## reconstruct Z20 and reduce it to the equatorial band
z20_eq = src.utils.reconstruct_wrapper(z20, fn=_equatorial_mean)

## load into memory (the trailing ';' suppresses the notebook echo)
z20_eq.load();

In [None]:
def sel(x, time_slice=slice(None, 360)):
    """Average ``x`` over ensemble members and a window of time steps.

    The default window is the first 360 time steps; pass e.g.
    ``slice(-360, None)`` to average over the last 360 instead.
    """
    return x.isel(time=time_slice).mean(["member", "time"])


## member/time means over the first (hbar, z20_bar) and last (hbar1) 360 steps
hbar = sel(H)
hbar1 = sel(H, time_slice=slice(-360, None))
z20_bar = sel(z20_eq)

In [None]:
## compare the time/ensemble-mean h with the equatorial-mean Z20 vs longitude
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(hbar.longitude, hbar, label="$h$")
ax.plot(z20_bar.longitude, z20_bar["z20"], label="$Z_{20}$")
# optional: compare with the mean over the last 360 time steps
# ax.plot(hbar1.longitude, hbar1, label="$h$ (last 360)")

## reversed y-limits so depth increases downward (200 at bottom, 40 at top)
ax.set_ylim([200, 40])
ax.set_xlabel("Longitude")
ax.set_ylabel("Depth")
ax.legend()
plt.show()

### Spatial plot

In [None]:
## restrict to the first 360 time steps, average over time and members,
## and pull the result into memory
z_first = z20.isel(time=slice(None, 360))
z_ = z_first.mean(["time", "member"]).compute()

## reconstruct the averaged spatial field
z = src.utils.reconstruct_wrapper(z_)

In [None]:
## cos(latitude) weights for the Z20 spatial components (via src.utils helper)
coslat_weights = src.utils.get_coslat_weights(z_.z20_comp)

## divide the spatial components by the weights
## (presumably undoes a cos-lat weighting applied to the components — confirm)
fn_eval = z_.z20_comp / coslat_weights

In [None]:
import cartopy.crs as ccrs


def format_func(ax):
    """Apply ``src.utils.plot_setup_pac`` to ``ax`` with ``max_lat=20``."""
    return src.utils.plot_setup_pac(ax, max_lat=20)


fig = plt.figure(figsize=(5, 2.5), layout="constrained")
axs = src.utils.subplots_with_proj(fig, nrows=1, ncols=1, format_func=format_func)

## filled contours of the reconstructed mean Z20 field on a lon/lat grid
cf = axs[0, 0].contourf(
    z.longitude,
    z.latitude,
    z["z20"],
    cmap="cmo.thermal",
    transform=ccrs.PlateCarree(),
)

## colorbar so the contour values can be read off the map
fig.colorbar(cf, ax=axs[0, 0], label="$Z_{20}$")

plt.show()