In [None]:
import sys
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

sys.path.append('..//')
from utils_mitgcm import open_mitgcm_ds_from_config
import pylake

In [None]:
# MITgcm run to open; presumably a key of ../config.json — TODO confirm.
model = 'geneva_dummy_extended'
mitgcm_config, ds_to_plot = open_mitgcm_ds_from_config(
    '..//config.json',
    model,
)

In [None]:
# Overwrite the horizontal coordinates with a uniform grid of spacing
# `grid_resolution`: cell centres (XC/YC) sit half a cell from the origin,
# cell edges (XG/YG) start at 0.
grid_resolution = 200
for centre_dim in ('YC', 'XC'):
    ds_to_plot[centre_dim] = (np.arange(len(ds_to_plot[centre_dim])) + 0.5) * grid_resolution
for edge_dim in ('YG', 'XG'):
    ds_to_plot[edge_dim] = np.arange(len(ds_to_plot[edge_dim])) * grid_resolution

In [None]:
# Boolean wet-cell mask from the first snapshot: a cell is kept unless its THETA
# is exactly 0 (presumably the model's fill value for dry cells — TODO confirm).
# Later cells index mask[k] as level k, assuming THETA is ordered (time, Z, YC, XC).
mask = ~(ds_to_plot['THETA'].isel(time=0).values == 0)

In [None]:
# Map of THETA on the first Z level at the last time step, with cells outside
# the wet mask blanked out.
fig, ax = plt.subplots(figsize=(10, 3))
surface_theta = ds_to_plot.THETA.isel(time=-1, Z=0).where(mask[0], np.nan)
# origin='lower' puts row 0 (first YC) at the bottom, like the former invert_yaxis().
im = ax.imshow(surface_theta, origin='lower')
fig.colorbar(im, ax=ax, label='THETA')
ax.set(title='THETA at Z index 0, last time step', xlabel='XC index', ylabel='YC index');

# Get temperature profile averaged over the entire lake

In [None]:
# THETA with every cell outside `mask` replaced by NaN, so it drops out of averages.
ds_to_plot['theta_nan'] = ds_to_plot['THETA'].where(mask, np.nan)

# Lake-wide temperature profile: unweighted mean over XC/YC, evaluated eagerly.
mean_profile = ds_to_plot['theta_nan'].mean(dim=['XC', 'YC'])
ds_to_plot['mean_temp_profile'] = mean_profile.compute()

ds_to_plot['mean_temp_profile'].plot()

# Get buoyancy frequency N

# Buoyancy frequency of the lake-mean profile, one pylake result per time step.
N_arr = []
depth_levels = ds_to_plot.Z.values
for t_idx in range(ds_to_plot.sizes['time']):
    profile_t = ds_to_plot['mean_temp_profile'].isel(time=t_idx).values
    snapshot_N = pylake.buoyancy_freq(profile_t, depth=depth_levels, g=9.81)
    # Tag each result with its timestamp so the list can be concatenated along 'time'.
    snapshot_N['time'] = [ds_to_plot.time.isel(time=t_idx).values]
    N_arr.append(snapshot_N)

In [None]:
def get_N_profile(temp_profile, depths, g=9.81):
    """Buoyancy-frequency profile of a single temperature column.

    Thin wrapper around ``pylake.buoyancy_freq`` with a positional signature,
    so it can be vectorised over columns with ``xr.apply_ufunc``.

    Parameters
    ----------
    temp_profile : array-like
        Temperatures of one water column, ordered like ``depths``.
    depths : array-like
        Depth of each level (this notebook passes ``ds_to_plot.Z``).
    g : float, default 9.81
        Gravitational acceleration forwarded to ``pylake.buoyancy_freq``
        (previously hard-coded; the default keeps the old behaviour).

    Returns
    -------
    Whatever ``pylake.buoyancy_freq`` returns. The ``apply_ufunc`` call in this
    notebook expects ``len(depths) - 1`` values per column — TODO confirm.
    """
    return pylake.buoyancy_freq(temp_profile, depth=depths, g=g)

In [None]:
# N for every (time, YC, XC) column. The sizes below assume get_N_profile returns
# len(Z) - 1 values per column, exposed as the new core dimension 'avg_depth'.
xr_N = xr.apply_ufunc(
    get_N_profile,
    ds_to_plot.theta_nan,
    ds_to_plot.Z,
    input_core_dims=[['Z'], ['Z']],
    output_core_dims=[['avg_depth']],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[float],
    # Passing output_sizes directly to apply_ufunc is deprecated; sizes of new
    # output dimensions for dask belong in dask_gufunc_kwargs.
    dask_gufunc_kwargs={'output_sizes': {'avg_depth': ds_to_plot.Z.size - 1}},
).compute()

In [None]:
# N profile of a single column: first time step, grid cell XC=0, YC=0.
# NOTE(review): this corner cell may lie outside the wet mask, in which case the
# profile is all-NaN — pick a wet cell if the plot comes out empty.
xr_N.isel({'time': 0, 'XC': 0, 'YC': 0}).plot()

In [None]:
# Stack the per-snapshot N profiles along time and relabel their vertical axis
# 'Z' so they can later be interpolated onto the model levels.
# NOTE(review): the relabelled axis still carries pylake's 'avg_depth' positions,
# presumably mid-level depths rather than the model's Z levels.
N_mean = xr.concat(N_arr, dim='time').rename({'avg_depth': 'Z'})

# Get vertical displacement

In [None]:
import numpy as np
import xarray as xr

def compute_isotherm_displacement(temp: xr.DataArray, isotherms, ref_time=0,
                                  z_dim="z", time_dim="time"):
    """
    Compute vertical displacement of isotherms.

    For every column of ``temp``, the depth at which the temperature crosses
    each requested isotherm is found by linear interpolation along ``z_dim``.
    The displacement is that depth minus the depth at ``ref_time``.

    Parameters
    ----------
    temp : xr.DataArray
        Temperature field with a vertical dimension ``z_dim`` and a time
        dimension ``time_dim``, e.g. ('time', 'z', 'y', 'x').
    isotherms : float or list/tuple/np.ndarray of float
        Target isotherm value(s) in same units as `temp`.
    ref_time : int or label (e.g. np.datetime64), default 0
        Reference time. Integers are positional indices along ``time_dim``;
        any other value is looked up by label with ``.sel``.
    z_dim : str, default "z"
        Name of the vertical dimension (e.g. "Z" for MITgcm output).
    time_dim : str, default "time"
        Name of the time dimension.

    Returns
    -------
    z_iso : xr.DataArray
        Depth of each isotherm, dims ('isotherm', *dims of `temp` except z_dim).
        NaN where the column has fewer than 2 finite levels or does not span
        the isotherm.
    dz : xr.DataArray
        ``z_iso`` minus its value at ``ref_time`` (same dims as ``z_iso``),
        named "isotherm_displacement".

    Notes
    -----
    Each profile is assumed to be monotonic in temperature along ``z_dim``;
    for non-monotonic profiles the interpolated depth is not meaningful.
    """

    if not isinstance(isotherms, (list, tuple, np.ndarray)):
        isotherms = [isotherms]

    z_coord = temp[z_dim]

    def z_from_temp(temp_prof, z_prof, T0):
        """Depth where one profile crosses T0 (linear interpolation), NaN if it does not."""
        valid = np.isfinite(temp_prof) & np.isfinite(z_prof)
        temp_ok = temp_prof[valid]
        z_ok = z_prof[valid]
        if temp_ok.size < 2:
            return np.nan
        # np.interp requires increasing sample points: flip profiles whose
        # temperature decreases along the vertical axis.
        if temp_ok[0] > temp_ok[-1]:
            temp_ok = temp_ok[::-1]
            z_ok = z_ok[::-1]
        # left/right=NaN: an isotherm outside the column's range is absent, rather
        # than being pinned to the top/bottom level by np.interp's clamping.
        return np.interp(T0, temp_ok, z_ok, left=np.nan, right=np.nan)

    z_iso_all = []
    for iso in isotherms:
        z_iso = xr.apply_ufunc(
            z_from_temp,
            temp,
            z_coord,
            input_core_dims=[[z_dim], [z_dim]],
            kwargs={"T0": iso},
            vectorize=True,
            dask="parallelized",
            output_dtypes=[float],
        )
        z_iso_all.append(z_iso.assign_coords(isotherm=iso))

    # Combine into one DataArray with a leading 'isotherm' dimension
    z_iso = xr.concat(z_iso_all, dim="isotherm")

    # Reference depth: positional for integers, label-based otherwise
    if isinstance(ref_time, (int, np.integer)):
        z0 = z_iso.isel({time_dim: ref_time})
        ref_label = temp[time_dim].values[ref_time]
    else:
        z0 = z_iso.sel({time_dim: ref_time})
        ref_label = z0[time_dim].values

    # Vertical displacement
    dz = z_iso - z0

    dz.name = "isotherm_displacement"
    dz.attrs["description"] = "Vertical displacement of isotherms (relative to ref_time)"
    dz.attrs["reference_time"] = str(ref_label)

    return z_iso, dz


In [None]:
def compute_iso_displacement_slice(theta_arr, ds, idx_z, mask=None):
    """Displacement of the lake-mean isotherm of one reference level.

    For each time step, the reference temperature is
    ``ds['mean_temp_profile']`` at level ``idx_z``. In every water column, the
    wet grid level whose temperature is closest to that value is located, and
    its depth minus the reference depth ``ds['Z'][idx_z]`` is returned. The
    result is quantised to the model levels; no interpolation is done.

    Parameters
    ----------
    theta_arr : np.ndarray
        Temperature indexed as ``[time, Z, YC, XC]`` (e.g. ``ds_to_plot.THETA.values``).
    ds : xr.Dataset
        Must provide ``'Z'`` (level depths) and ``'mean_temp_profile'``
        (dims ``time`` and ``Z``).
    idx_z : int
        Index of the reference level.
    mask : np.ndarray of bool, optional
        Wet-cell mask indexed as ``[Z, YC, XC]``. Defaults to
        ``theta_arr[0] != 0``, which matches the notebook-level ``mask`` cell
        when ``theta_arr`` is THETA.

    Returns
    -------
    list of np.ndarray
        One ``(YC, XC)`` array per time step with the displacement (in units of
        ``ds['Z']``). NaN where level ``idx_z`` itself is outside the mask.
    """
    if mask is None:
        mask = np.asarray(theta_arr[0]) != 0

    z_arr = ds['Z'].values
    ref_depth = ds['Z'].isel(Z=idx_z).values
    # Reference temperature of level idx_z for every time step (loop-invariant lookup).
    ref_temps = ds['mean_temp_profile'].isel(Z=idx_z).values
    wet_ref_level = mask[idx_z]

    d_vert_slice = []
    for i in range(theta_arr.shape[0]):
        diff = np.abs(theta_arr[i] - ref_temps[i])
        # Ignore dry cells and NaNs: otherwise a dry cell holding 0 (or a NaN,
        # which argmin returns first) could be chosen as the closest level.
        diff = np.where(mask & np.isfinite(diff), diff, np.inf)

        z_closest = z_arr[diff.argmin(axis=0)]  # (YC, XC)

        # Displacement relative to the reference depth, only where it is wet
        d_vert_slice.append(np.where(wet_ref_level, z_closest - ref_depth, np.nan))

    return d_vert_slice

In [None]:
# Materialise the full temperature field as a plain NumPy array
# (indexed [time, Z, YC, XC] if THETA keeps that dimension order).
theta_arr = ds_to_plot['THETA'].values

In [None]:
# Displacement for every reference level: d_vert[k][t] is the (YC, XC) field of
# reference level k at time step t.
d_vert = [
    compute_iso_displacement_slice(theta_arr, ds_to_plot, level_idx)
    for level_idx in range(ds_to_plot.sizes['Z'])
]

In [None]:
# Spot check: displacement fields for reference level 80 only.
test = compute_iso_displacement_slice(theta_arr=theta_arr, ds=ds_to_plot, idx_z=80)

In [None]:
# Lake-mean temperature of level 80 at the first time step.
ds_to_plot.mean_temp_profile.isel({'time': 0, 'Z': 80}).values

In [None]:
# Nearest-level search for level 80 at the first time step. The reference
# temperature is read from the data instead of the hand-pasted literal
# 7.26599979, so the check stays valid if the run or its precision changes.
ref_temp_check = ds_to_plot['mean_temp_profile'].isel(time=0, Z=80).values
diff = np.abs(theta_arr[0] - ref_temp_check)
z_closest_idx = diff.argmin(axis=0)

In [None]:
# Display the nearest-level indices computed in the previous cell.
z_closest_idx

In [None]:
# All-NaN container with THETA's shape, dims and coordinates. It also keeps
# THETA's name, which is why the saved variable is later read back as THETA.
v_disp = xr.full_like(ds_to_plot['THETA'], np.nan)

In [None]:
# d_vert is nested [Z][time]; exchange the first two axes to get [time, Z, YC, XC],
# the layout assumed for THETA.
arr_swapped = np.asarray(d_vert).swapaxes(0, 1)

In [None]:
# Fill the template positionally with the displacement values; this relies on
# arr_swapped having exactly the same shape and axis order as v_disp (i.e. THETA).
v_disp[:] = arr_swapped

In [None]:
# Cache the displacement field on disk so later cells can reload it.
# NOTE(review): absolute, user-specific path — consider deriving it from the
# run configuration so the notebook is portable.
vdisp_path = r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/vertical_displacement.nc"
v_disp.to_netcdf(vdisp_path)

In [None]:
# Displacement at level 80, last time step. At this point v_disp is still the
# DataArray built with xr.full_like (it only becomes a Dataset with a THETA
# variable once re-read from netCDF below), so `.THETA` would raise
# AttributeError on a fresh Run All — plot the DataArray directly.
v_disp.isel(time=-1, Z=80).plot();

# Compute potential energy

In [None]:
# Reload the cached displacement field; it comes back as a Dataset whose data
# variable is named THETA.
# NOTE(review): absolute, user-specific path.
vdisp_file = r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/vertical_displacement.nc"
v_disp = xr.open_dataset(vdisp_file)

In [None]:
# Lake-wide (horizontal) mean displacement per time step and level; NaN cells
# are skipped by xarray's default skipna behaviour for floats.
v_disp_mean = v_disp.mean(['XC', 'YC'])

In [None]:
# Interpolate the lake-mean N profiles onto the Z levels of the displacement field.
target_levels = v_disp_mean['Z'].values
N_interp = N_mean.interp(Z=target_levels)

In [None]:
# Displacement field (stored under the name THETA) at level 70, first time step.
v_disp['THETA'].isel(Z=70, time=0).plot()

In [None]:
# NOTE(review): the section heading says "potential energy", but this is only
# N times the lake-mean displacement. The linear APE density is usually
# 0.5 * rho0 * N**2 * zeta**2, with zeta squared before horizontal averaging —
# confirm whether pylake.buoyancy_freq returns N or N**2 and adjust.
N_times_disp = N_interp * v_disp_mean
N_times_disp