# Changes in thermal expansion coefficient and thermocline depth

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import xarray as xr
import tqdm
import pathlib
import cmocean
import os
import copy
import time
import gsw

# Import custom modules
import src.utils

## set plotting specs: white axes background, grid lines off
sns.set(rc={"axes.grid": False, "axes.facecolor": "white"})

## bump up DPI for sharper inline figures
mpl.rcParams.update({"figure.dpi": 100})

## get filepaths (both environment variables must be set)
DATA_FP, SAVE_FP = (pathlib.Path(os.environ[key]) for key in ("DATA_FP", "SAVE_FP"))

## Functions

In [None]:
def avg_upper_ocn(T, H=70, lon_range=(140, 280)):
    """Average ``T`` over the upper ocean within a longitude band.

    Selects ``z_t`` from its first level down to ``H`` and ``longitude``
    within ``lon_range`` (both label-based slices, so endpoints that are
    grid values are included), then averages over both dims.

    Args:
        T: xarray object with ``z_t`` and ``longitude`` coordinates.
            Label slicing assumes both coordinates are ascending (TODO confirm).
        H: depth cutoff, in the units of ``z_t`` (presumably metres — confirm).
        lon_range: ``(lon_min, lon_max)`` band to average over. A tuple is used
            as the default so the shared default object can never be mutated.

    Returns:
        ``T`` averaged over ``z_t`` and ``longitude``; all other dims are kept.
    """
    T_ = T.sel(z_t=slice(None, H), longitude=slice(*lon_range))

    return T_.mean(["z_t", "longitude"])

## Load data

In [None]:
## load the consolidated data (the second return value is not used here)
forced, _ = src.utils.load_consolidated()

## keep only the temperature variables from 1851 onward, then window them
T = src.utils.get_windowed(
    forced[["T", "T_comp"]].sel(time=slice("1851", None)),
    stride=120,
)

## $\alpha$

#### Avg. temperature over near-surface layer

In [None]:
## reconstruct the climatology of the upper-ocean average; keep variable "T"
T_upper = src.utils.reconstruct_clim(
    T,
    fn=avg_upper_ocn,
)["T"]

## collapse the seasonal cycle by averaging over months
T_upper_mean = T_upper.mean(dim="month")

#### Compute thermal expansion

In [None]:
## specify reference values
## NOTE(review): SA = 25 g/kg is well below typical open-ocean absolute
## salinity (~35 g/kg); confirm this is intentional and not a typo.
SA = 25  # g/kg
P = 35  # dbar
H = 70  # m; presumably meant to match avg_upper_ocn's default depth — confirm

## compute alpha
## gsw broadcasts the scalar SA and p against CT, so this works for any shape
## of T_upper_mean. The previous 1-D arrays of length nyear only worked when
## "year" was the sole dimension.
## NOTE(review): gsw.alpha expects Conservative Temperature; if "T" is model
## potential temperature, convert first with gsw.CT_from_pt — confirm.
alpha = T_upper_mean.copy(data=gsw.alpha(SA=SA, CT=T_upper_mean.values, p=P))

## get fractional change
delta_a = src.utils.frac_change(alpha, inv=False)

## get scaling factor for RO (alpha in first window relative to each window)
alpha_scale = alpha.isel(year=0) / alpha

## save to file; an existing file is left untouched (delete it to regenerate)
save_fp = SAVE_FP / "cesm_alpha_scale.nc"
if not save_fp.is_file():
    alpha_scale.to_netcdf(save_fp)

### Plot

In [None]:
## fractional change in alpha for each window
fig, ax = plt.subplots(1, 1, figsize=(4, 3))
ax.plot(T_upper.year, delta_a)
ax.set(ylabel="Frac. change")
plt.show()

## $h_w$

### Functions

In [None]:
def lon_avg(x, lon_range):
    """Return ``x`` averaged over the longitude band ``lon_range``.

    ``lon_range`` is unpacked into a label-based ``slice`` on the
    ``longitude`` coordinate, and the selection is then averaged over
    ``longitude``.
    """
    band = slice(*lon_range)
    subset = x.sel(longitude=band)
    return subset.mean("longitude")


def get_H_metrics(x, lons_e, lons_w):
    """Compute thermocline metrics from a longitude-resolved depth field.

    Returns a Dataset with:
        Hw   -- ``x`` averaged over the western band ``lons_w``
        He   -- ``x`` averaged over the eastern band ``lons_e``
        Hbar -- ``x`` averaged from the start of ``lons_w`` to the end of ``lons_e``
        dHdx -- He minus Hw
    """
    full_band = (lons_w[0], lons_e[1])
    bands = {"Hw": lons_w, "He": lons_e, "Hbar": full_band}

    metrics = xr.merge(
        [lon_avg(x, rng).rename(label) for label, rng in bands.items()]
    )
    metrics["dHdx"] = metrics["He"] - metrics["Hw"]
    return metrics


def frac_change(x, inv=False):
    """Return the fractional change of ``x`` relative to its first ``year``.

    If ``inv`` is true, the change is computed for ``1 / x`` instead of ``x``.
    """
    base = (1 / x) if inv else x
    reference = base.isel(year=0)
    return base / reference - 1

### Load and compute

In [None]:
## specify longitude ranges for the western and eastern boxes
LONS_W = (140, 210)
LONS_E = (210, 270)

## get maximum gradient of thermocline data
H_mg_forced, _ = src.utils.load_h_data(max_grad=True)

## NOTE: for v2 the thermocline data are deliberately NOT trimmed to start in
## 1851. Earlier versions applied `.sel(time=slice("1851", None))` here for
## consistency with the other data. T and ssh_w are still trimmed.
## NOTE(review): with stride-120 windowing, this may offset H's windows
## relative to T/ssh_w; confirm the "year" coordinates line up.

## get windowed
H_mg_forced = src.utils.get_windowed(H_mg_forced, stride=120)

## get metrics (Hw, He, Hbar, dHdx)
H_mg_forced = get_H_metrics(H_mg_forced, lons_w=LONS_W, lons_e=LONS_E)

## compute Hbar (use this for scaling): each window relative to the first
Hbar = H_mg_forced["Hbar"].mean("time")
Hbar_scale = Hbar / Hbar.isel(year=0)

## load ssh data
ssh_w = src.utils.load_cesm_indices()["h_w"].sel(time=slice("1851", None))
ssh_w = src.utils.get_windowed(ssh_w, stride=120)

## save to file; an existing file is left untouched (delete it to regenerate)
save_fp = SAVE_FP / "cesm_Hbar_scale_v2.nc"
if not save_fp.is_file():
    Hbar_scale.to_netcdf(save_fp)

In [None]:
fig, ax = plt.subplots(figsize=(3, 2.5))

## plot fractional change in 1/h (inverse depth), relative to the first window
ax.plot(
    H_mg_forced.year,
    frac_change(H_mg_forced["Hw"].mean("time"), inv=True),
    label=r"$1/H_w$",
)
ax.plot(H_mg_forced.year, frac_change(Hbar, inv=True), label=r"$1/\bar{H}$")

## plot change in sigma: std. dev. of h_w over the "time" and "member" dims
ax.plot(
    ssh_w.year,
    frac_change(ssh_w.std(["time", "member"])),
    label=r"$\sigma(h_w)$",
)

## label the three otherwise-indistinguishable curves
ax.set_ylabel("Frac. change")
ax.legend()

plt.show()