In [None]:
import sys
import os
from multiprocessing.pool import ExceptionWithTraceback

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from numpy.f2py.auxfuncs import throw_error

sys.path.append('..//')
from utils_mitgcm import open_mitgcm_ds_from_config
import pylake

In [None]:
# Open the MITgcm run selected by `model`, as described in the shared JSON config file.
model = 'geneva_dummy_extended'
config_path = '..//config.json'
mitgcm_config, ds_to_plot = open_mitgcm_ds_from_config(config_path, model)

In [None]:
# Folder receiving the potential-energy figure and CSV written at the end of the notebook.
# NOTE(review): absolute, user-specific path — adjust when running on another machine.
output_folder = r'/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/seiche_analysis/potential_energy'
# savefig / to_csv do not create missing directories, so make sure it exists up front.
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Replace the native horizontal coordinates by distances [m] from the domain origin,
# assuming a uniform grid spacing of `grid_resolution` metres.
grid_resolution = 200  # [m]
# Cell centres (XC/YC) lie half a cell beyond the cell corners (XG/YG).
ds_to_plot['YC'] = np.arange(1, len(ds_to_plot['YC'])+1) * grid_resolution - grid_resolution/2
ds_to_plot['XC'] = np.arange(1, len(ds_to_plot['XC'])+1) * grid_resolution - grid_resolution/2
ds_to_plot['YG'] = np.arange(0, len(ds_to_plot['YG'])) * grid_resolution
ds_to_plot['XG'] = np.arange(0, len(ds_to_plot['XG'])) * grid_resolution

# Express Z as positive depth. Taking abs() rather than flipping the sign keeps this cell
# idempotent: re-running it leaves positive depths unchanged instead of negating them again.
# Assumes the native MITgcm Z is <= 0 everywhere — TODO confirm.
ds_to_plot['Z'] = np.abs(ds_to_plot.Z)

In [None]:
# Water-cell mask (plain numpy bool array) from the first snapshot: cells whose THETA is
# exactly 0 are treated as non-water (presumably MITgcm's fill for dry cells — confirm).
mask = (ds_to_plot.THETA.isel(time=0) != 0).values

In [None]:
# Quick look: masked temperature of the first Z level at the last time step.
fig, ax = plt.subplots(figsize=(10, 3))
surface_theta = ds_to_plot.THETA.isel(time=-1, Z=0).where(mask[0], np.nan)
image = ax.imshow(surface_theta)
ax.invert_yaxis()  # put row 0 at the bottom
fig.colorbar(image, ax=ax)

# Get temperature profile averaged over the entire lake

In [None]:
# THETA with the cells excluded by `mask` set to NaN, so they drop out of NaN-skipping reductions.
ds_to_plot = ds_to_plot.assign(theta_nan=ds_to_plot['THETA'].where(mask, np.nan))

In [None]:
# Lake-mean temperature per time and depth: unweighted average over the horizontal dims,
# ignoring masked (NaN) cells; evaluated eagerly.
ds_to_plot = ds_to_plot.assign(
    mean_temp_profile=ds_to_plot['theta_nan'].mean(dim=('XC', 'YC')).compute()
)

In [None]:
# Time-depth view of the lake-mean temperature.
ds_to_plot.mean_temp_profile.plot()

# Get squared buoyancy frequency $N^2$ (Brunt–Väisälä)

In [None]:
def buoyancy_freq(Temp, depth=None, g=9.81):
    '''
    Description: Calculate the squared buoyancy frequency N^2 (Brunt-Vaisala frequency squared)
    for temperature profiles. Copied from PyLake module and adapted to take xarray as input.
    https://github.com/eawag-surface-waters-research/PyLake/blob/main/pylake

    N^2 = g / rho * d(rho)/dZ is evaluated with forward differences between consecutive Z levels,
    using pylake.dens0 with a fixed salinity of 0.2 for the density. With Z as depth increasing
    downward, a stably stratified column yields positive N^2.

    Parameters
    -----------
    Temp: xarray.DataArray
        Water temperature in degrees C, with a vertical dimension/coordinate named "Z".
    depth: array_like, optional
        Depths (in m) of the Temp levels; only used to label the output levels.
        Defaults to Temp.Z when omitted.
    g: scalar, default: 9.81
        gravity acceleration (m/s2)

    Returns
    ----------
    n2: xarray.DataArray
        N^2 in s^-2 (not N), on a new dimension "avg_depth" holding the midpoints between
        consecutive depths, i.e. one level fewer than the input.
    '''
    if depth is None:
        # Label the output from the profile's own vertical coordinate.
        depth = Temp.Z

    rho = pylake.dens0(s=0.2, t=Temp)
    # Forward difference; diff() labels each difference with the coordinate of the later
    # level (label='upper' default).
    drho_dz = rho.diff('Z') / Temp.Z.diff('Z')
    # Density of the first level of each pair, relabelled onto drho_dz's Z so the two align.
    rho_2 = rho.isel(Z=slice(0, rho.sizes['Z'] - 1))
    rho_2["Z"] = drho_dz.Z
    n2 = g / rho_2 * drho_dz

    # Midpoints between consecutive depths, computed on plain arrays rather than on a
    # list of 0-d DataArrays.
    depth_values = np.asarray(depth, dtype=float)
    n2["Z"] = (depth_values[:-1] + depth_values[1:]) / 2
    n2 = n2.rename({"Z": "avg_depth"})

    return n2

In [None]:
# N^2 (despite the name, not N) for every column and time step, evaluated eagerly.
xr_N = buoyancy_freq(ds_to_plot.theta_nan, depth=ds_to_plot.Z).compute()

In [None]:
# Sanity check: N^2 profile of a single water column at one time step, log-scaled values.
fig, ax = plt.subplots()
xr_N.isel(time=24, XC=150, YC=80).plot(ax=ax)
ax.set_yscale('log')

In [None]:
# Move N^2 from the layer midpoints (avg_depth) onto the cell-centre depths Z so it lines up
# with the displacement field. Guarded so re-running the cell does not fail once avg_depth
# has already been replaced by Z.
# NOTE(review): interp does not extrapolate, so the first and last Z levels (outside the range
# of midpoints) become NaN and are skipped by the later NaN-skipping depth sum.
if 'avg_depth' in xr_N.dims:
    xr_N = xr_N.interp({'avg_depth': ds_to_plot.Z})

In [None]:
# Store N^2 in the seiche-analysis folder, i.e. the parent of output_folder, instead of
# repeating the absolute path.
xr_N.to_netcdf(os.path.join(os.path.dirname(os.path.normpath(output_folder)), 'buoyancy_frequency.nc'))

# Get vertical displacement of isotherms relative to the lake-mean temperature profile

In [None]:
def inverse_interp_1d(T_prof, T_target, Z):
    """
    Invert one temperature profile: find the Z at which it reaches given temperatures.

    For each value of ``T_target`` the two adjacent valid levels of ``T_prof`` that bracket
    it are located and ``Z`` is linearly interpolated between them. Meant to be called on one
    water column at a time, e.g. via ``xr.apply_ufunc(..., vectorize=True)``.

    Parameters
    ----------
    T_prof : 1D array_like
        Temperatures along ``Z``. NaN entries are ignored. The remaining values should be
        monotonic (either increasing or decreasing); for non-monotonic profiles the
        bracketing, and hence the result, is unreliable.
    T_target : 1D array_like
        Temperatures whose Z is sought.
    Z : 1D array_like
        Vertical coordinate of ``T_prof`` (same length as ``T_prof``).

    Returns
    -------
    numpy.ndarray
        Float array with the shape of ``T_target``. Entries are NaN when the target is NaN,
        lies outside the valid temperature range, falls in a layer of uniform temperature,
        or when fewer than two valid levels exist.
    """
    T_prof = np.asarray(T_prof, dtype=float)
    T_target = np.asarray(T_target, dtype=float)
    Z = np.asarray(Z, dtype=float)
    out = np.full(T_target.shape, np.nan)

    # Keep only valid (non-NaN) levels of the profile.
    valid = ~np.isnan(T_prof)
    T_valid, Z_valid = T_prof[valid], Z[valid]
    if T_valid.size < 2:
        return out

    # np.searchsorted requires ascending values: flip profiles stored warm-to-cold
    # (e.g. surface-first ordering, or inverse winter stratification).
    if T_valid[0] > T_valid[-1]:
        T_valid, Z_valid = T_valid[::-1], Z_valid[::-1]

    idx = np.searchsorted(T_valid, T_target, side='left')
    # With side='left' a target equal to the first (coldest) value lands on index 0;
    # map it to the first bracket so it resolves to Z_valid[0] instead of NaN.
    idx = np.where((idx == 0) & (T_target == T_valid[0]), 1, idx)
    inside = (idx > 0) & (idx < T_valid.size)

    # Clip so the gathers are always in bounds; out-of-range entries are masked by `inside`.
    hi = np.clip(idx, 1, T_valid.size - 1)
    T1, T2 = T_valid[hi - 1], T_valid[hi]
    Z1, Z2 = Z_valid[hi - 1], Z_valid[hi]
    ok = inside & (T2 != T1)  # a uniform-temperature layer has no unique crossing depth
    with np.errstate(divide='ignore', invalid='ignore'):
        Z_interp = Z1 + (Z2 - Z1) * (T_target - T1) / (T2 - T1)
    out[ok] = Z_interp[ok]

    return out

In [None]:
# Order every water column by decreasing depth (deepest level first). In a stably stratified
# column temperature then increases along the array, the order used by the bisection in
# inverse_interp_1d.
ds_to_plot = ds_to_plot.sortby("Z", ascending=False)
z_profile = ds_to_plot.Z             # (Z)

# Lake-mean profile repeated for every water column. Renaming Z -> Z_temp makes it an
# independent core dimension for apply_ufunc (target temperatures vs. local profile levels).
expanded_mean, _ = xr.broadcast(ds_to_plot.mean_temp_profile, ds_to_plot.theta_nan)
mean_temperature_profile = expanded_mean.rename({'Z': 'Z_temp'})  # (time, Z_temp, horizontal dims)

# Z is a core dimension below, and dask="parallelized" requires core dimensions to be a single
# chunk, hence 'Z': -1. The other dimensions are chunked to parallelise the column-wise work.
temperature_profiles_whole_lake = ds_to_plot.theta_nan.chunk({'time': 1, 'Z': -1, 'XC': 30, 'YC': 30})

# For every column, time step and mean-profile level: the local depth at which the local
# temperature equals the lake-mean temperature of that level.
z_corresponding = xr.apply_ufunc(
    inverse_interp_1d,
    temperature_profiles_whole_lake,  # T_prof: (time, Z, horizontal dims)
    mean_temperature_profile,         # T_target: (time, Z_temp, horizontal dims)
    z_profile,                        # Z: (Z)
    input_core_dims=[["Z"], ["Z_temp"], ["Z"]],
    output_core_dims=[["Z_temp"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[float],
)
z_corresponding = z_corresponding.rename({'Z_temp': 'Z'}).compute()

In [None]:
# Compare the lake-mean and the local temperature profile for one column and time step.
column = dict(time=48, XC=150, YC=80)
mean_temperature_profile.isel(column).plot()
temperature_profiles_whole_lake.isel(column).plot()

In [None]:
# Isotherm displacement: depth of each lake-mean profile level minus the local depth at which
# the same temperature is reached. With Z treated as positive depth (assumes the native MITgcm
# Z was negative before the sign change at the top), positive values mean the isotherm sits
# shallower locally, i.e. is displaced upwards.
vertical_displacement = z_profile - z_corresponding

In [None]:
# Cache the displacement field in the seiche-analysis folder (parent of output_folder) and read it back.
vertical_displacement_file = os.path.join(os.path.dirname(os.path.normpath(output_folder)), 'vertical_displacement.nc')
vertical_displacement.to_netcdf(vertical_displacement_file)

# open_dataarray handles the unnamed-variable placeholder. .load() pulls the data into memory
# and the context manager then closes the file, so re-running this cell can overwrite the file
# instead of failing on a handle that is still open.
with xr.open_dataarray(vertical_displacement_file) as stored_displacement:
    vertical_displacement = stored_displacement.load()

# Compute potential energy

In [None]:
rho = 1000                           # reference water density [kg m^-3]
surface_cell = grid_resolution ** 2  # horizontal area of one grid cell [m^2] (uniform grid)
# Per-layer integrand zeta^2 * N^2 * dz. Note that xr_N holds N^2 [s^-2], not N.
# drF is presumably the MITgcm layer thickness [m] — TODO confirm.
val = (vertical_displacement ** 2) * xr_N * ds_to_plot.drF

In [None]:
# Available potential energy per water column for small isotherm displacements:
#   E_p = 1/2 * rho_0 * A * sum_z( N^2 * zeta^2 * dz )   [J], converted to MJ.
# The 1/2 comes from integrating the restoring force rho_0 * N^2 * s over s = 0..zeta.
E_pot_profile = 0.5 * rho * surface_cell * val.sum(dim='Z') / 1e6

In [None]:
# Store the per-column energy map time series in the seiche-analysis folder (parent of output_folder).
E_pot_profile.to_netcdf(os.path.join(os.path.dirname(os.path.normpath(output_folder)), 'potential_energy.nc'))

In [None]:
# Map of per-column potential energy at time index 48, fixed colour range 0-100 MJ.
pe_snapshot = E_pot_profile.isel(time=48)
plt_obj = pe_snapshot.plot(vmin=0, vmax=100)
plt_obj.colorbar.set_label('Potential Energy [MJ]')

In [None]:
# Save one potential-energy map per time step (from index 25 on) as PNG frames in
# <seiche_analysis>/e_pot_map, where <seiche_analysis> is the parent of output_folder.
e_pot_map_folder = os.path.join(os.path.dirname(os.path.normpath(output_folder)), 'e_pot_map')
os.makedirs(e_pot_map_folder, exist_ok=True)  # savefig does not create missing directories
for i in range(25, E_pot_profile.sizes['time']):
    fig, ax = plt.subplots(figsize=(15, 7))
    mesh = E_pot_profile.isel(time=i).plot(ax=ax, vmin=0, vmax=100)
    mesh.colorbar.set_label('Potential Energy [MJ]')
    ax.set_title(E_pot_profile.time.isel(time=i).values)
    fig.savefig(os.path.join(e_pot_map_folder, f'potential_energy_time{i}.png'))
    plt.close(fig)  # free each frame right away instead of accumulating open figures

In [None]:
# Lake-wide total: add up the per-column energies over both horizontal dimensions.
total_potential_energy = E_pot_profile.sum(dim=['XC', 'YC'])

In [None]:
# Time series of the lake-wide potential energy, saved next to the CSV export.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
total_potential_energy.plot(ax=ax)
ax.set_ylabel('Potential Energy [MJ]')
ax.set_xlabel('')
fig.savefig(os.path.join(output_folder, "potential_energy.png"))

In [None]:
# Two-column table (time, pe_mj_seiche) of the total potential energy in MJ.
df_pe = total_potential_energy.rename('pe_mj_seiche').to_series().reset_index()

In [None]:
# Write the total-energy time series next to the figure.
pe_csv_path = os.path.join(output_folder, "potential_energy.csv")
df_pe.to_csv(pe_csv_path)