In [None]:
import numpy as np
import xarray as xr 
import pandas as pd
import glob
import os

import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, detrend

from matplotlib.ticker import LogLocator, FuncFormatter

import sys
sys.path.append('..//')
from utils_mitgcm import open_mitgcm_ds_from_config
from utils_signal_processing import *

# Load data and select ZZ

In [None]:
# Key of the model entry looked up in the JSON configuration file.
model = 'lucerne_2025'
config_path = '..//config.json'

# The helper hands back two objects: the model's configuration (indexed with
# 'datapath' further down) and the dataset holding the velocity fields.
mitgcm_config, ds = open_mitgcm_ds_from_config(config_path, model)

In [None]:
# Results go to a "seiche_analysis" folder created next to the directory part
# of the configured data path.
folder_path, _ = os.path.split(mitgcm_config['datapath'])
output_folder = os.path.join(folder_path, "seiche_analysis")
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Replace the horizontal coordinates of the dataset by regularly spaced values
# measured from the grid origin (spacing = grid_resolution; presumably metres).
grid_resolution = 100

# Cell-centre coordinates: offset by half a cell from the cell edges.
for centre_name in ('YC', 'XC'):
    n_points = len(ds[centre_name])
    ds[centre_name] = np.arange(1, n_points + 1) * grid_resolution - grid_resolution/2

# Cell-edge coordinates: start at 0 and advance by one cell.
for edge_name in ('YG', 'XG'):
    n_points = len(ds[edge_name])
    ds[edge_name] = np.arange(0, n_points) * grid_resolution

In [None]:
# Target vertical coordinate: the velocity fields are taken at the model level
# nearest to this value (Z for UVEL/VVEL, Zl for WVEL).
zz = 0

In [None]:
# Horizontal slices of the three velocity components at the level closest to zz.
# UVEL and VVEL are indexed by Z, WVEL by Zl.
nearest_level = dict(method='nearest')
u = ds['UVEL'].sel(Z=zz, **nearest_level)
v = ds['VVEL'].sel(Z=zz, **nearest_level)
w = ds['WVEL'].sel(Zl=zz, **nearest_level)

In [None]:
# Quick look: UVEL at the first Z index for the last available time step.
ds.UVEL.isel(Z=0, time=-1).plot()

# Put the complete time axis of every component into a single chunk.
u, v, w = (component.chunk({'time': -1}) for component in (u, v, w))

# Pull the three components into memory (w last, so its repr is the cell output).
for component in (u, v):
    component.load()
w.load()

# Compute freq. spectrum

## Compute freq. spectrum

In [None]:
# Value passed as the M argument of xr_compute_meanfft (also forwarded to the
# spectrum plots as m_segm).
m_seg = 1

# Spectrum of each velocity component, with time held in one chunk.
u_fft, v_fft, w_fft = (
    xr_compute_meanfft(component.chunk({'time': -1}), M=m_seg)
    for component in (u, v, w)
)

## Select an XY sub-window for the mean freq. spectrum

In [None]:
# Lower-left corner of the averaging window (horizontal coordinate units).
xx1 = yy1 = 10000

# Window extent, expressed as a number of 200-unit steps in each direction.
# NOTE(review): 200 is hard-coded here while grid_resolution is 100 - confirm
# which spacing is intended.
xcells = ycells = 2
xx2 = xx1 + 200 * xcells
yy2 = yy1 + 200 * ycells

# Map of V at time index 24 with the four window corners marked.
fig, ax = plt.subplots(1, figsize=(15, 8))
v.isel(time=24).plot(ax=ax)
for corner_x, corner_y in ((xx1, yy1), (xx1, yy2), (xx2, yy1), (xx2, yy2)):
    ax.scatter(corner_x, corner_y, marker=".", color="k")

ax.grid()

In [None]:
# Time series of U averaged over the window (U is indexed by XG / YC).
u_window = u.sel(XG=slice(xx1, xx2), YC=slice(yy1, yy2))
u_window.mean(dim=['XG', 'YC']).plot()

In [None]:
# Time series of V averaged over the window (V is indexed by XC / YG).
v_window = v.sel(XC=slice(xx1, xx2), YG=slice(yy1, yy2))
v_window.mean(dim=['XC', 'YG']).plot()

In [None]:
# Window bounds expressed on each component's own horizontal coordinates.
u_window_sel = dict(XG=slice(xx1, xx2), YC=slice(yy1, yy2))
v_window_sel = dict(XC=slice(xx1, xx2), YG=slice(yy1, yy2))
w_window_sel = dict(XC=slice(xx1, xx2), YC=slice(yy1, yy2))

# Average every spectrum over the two horizontal dimensions of its window.
u_fft_mean = u_fft.sel(**u_window_sel).mean(dim=list(u_window_sel))
v_fft_mean = v_fft.sel(**v_window_sel).mean(dim=list(v_window_sel))
w_fft_mean = w_fft.sel(**w_window_sel).mean(dim=list(w_window_sel))

## Plot freq. spectrum

In [None]:
# Window-averaged U spectrum; the keyword set is shared by the spectrum plots.
u_spectrum_kwargs = dict(depth=zz, m_segm=m_seg, y_lim_min=1e-9, x_lim_min=0.01e-4, fontsize=10)
fig, ax = plot_freq_spectrum(u_fft_mean, 'U', **u_spectrum_kwargs)
# Uncomment to write the figure to the output folder:
#fig.savefig(os.path.join(output_folder, 'freq_spectrum.png'))

In [None]:
# Window-averaged V spectrum, drawn with the same limits as the U spectrum.
v_spectrum_kwargs = dict(depth=zz, m_segm=m_seg, y_lim_min=1e-9, x_lim_min=0.01e-4, fontsize=10)
fig, ax = plot_freq_spectrum(v_fft_mean, 'V', **v_spectrum_kwargs)

# Filtering in Spectral space  

## Defining cutoffs 
- The inertial period is 16.4 h.

In [None]:
# Cutoff periods in hours (cutoff1 is the longer period).
cutoff1_hr, cutoff2_hr = 72, 50

# Matching cutoff frequencies in cycles per second.
seconds_per_hour = 3600
cutoff1 = 1/(cutoff1_hr * seconds_per_hour)
cutoff2 = 1/(cutoff2_hr * seconds_per_hour)

In [None]:
# U spectrum again, with a lower y-limit, overlaid with the two cutoff frequencies.
fig, ax = plot_freq_spectrum(
    u_fft_mean, 'U',
    depth=zz, m_segm=m_seg, y_lim_min=1e-11, x_lim_min=0.01e-4, fontsize=10,
)
for cutoff_freq, cutoff_label in ((cutoff1, "cutoff1"), (cutoff2, "cutoff2")):
    ax.axvline(x=cutoff_freq, linestyle="--", color="k", label=cutoff_label)
ax.legend()

# The figure name carries both cutoff periods.
cutoff_figure_name = f'freq_spectrum_with_cutoffs_{cutoff2_hr}_{cutoff1_hr}h.png'
fig.savefig(os.path.join(output_folder, cutoff_figure_name))

In [None]:
# One sub-folder per cutoff pair, so runs with different bands do not overwrite each other.
cutoff_folder_name = f'{cutoff2_hr}_{cutoff1_hr}h'
path_cutoff_folder = os.path.join(output_folder, cutoff_folder_name)
os.makedirs(path_cutoff_folder, exist_ok=True)

## Low-pass filtering (slow motions)

### Filtering to low freq. motions

In [None]:
# Low-pass settings shared by the three components: the cutoff is the longer
# period (cutoff1_hr) converted to seconds, with dt=3600 as the time step
# handed to the filter helper.
lowpass_kwargs = dict(
    btype='lowpass',
    time_dim='time',
    dt=3600,
    period_cutoff_high=(cutoff1_hr*3600),
    order=5,
)
ulow, vlow, wlow = (
    filter_signal_xarray(component, **lowpass_kwargs) for component in (u, v, w)
)

### Plotting low freq. motions 

In [None]:
# Time index of the snapshot and colour-scale maximum for the low-pass field.
i_time_to_plot = 48
vmax_low = 7.5e-2

ulow_snapshot = ulow.isel(time=i_time_to_plot)
ulow_snapshot.plot(vmax=vmax_low)
plt.gca().grid()

## Bandpass filtering (seiches)

### Filtering to seiches

In [None]:
# Band-pass settings shared by the three components: the band lies between
# the two cutoff periods (cutoff1_hr and cutoff2_hr, converted to seconds).
bandpass_kwargs = dict(
    btype='bandpass',
    time_dim='time',
    dt=3600,
    period_cutoff_low=(cutoff1_hr*3600),
    period_cutoff_high=(cutoff2_hr*3600),
    order=5,
)
useiche, vseiche, wseiche = (
    filter_signal_xarray(component, **bandpass_kwargs) for component in (u, v, w)
)

### Plotting seiches 

In [None]:
# Advance the snapshot index by one step and show which index is plotted.
# NOTE: this cell is not idempotent - every re-run moves one step further in
# time, and later cells reuse the updated index.
i_time_to_plot = i_time_to_plot + 1
print(i_time_to_plot)

vmax_seiche = 7.5e-2
vseiche_snapshot = vseiche.isel(time=i_time_to_plot)
vseiche_snapshot.plot(vmax=vmax_seiche)
plt.gca().grid()

## High-pass filtering (high freq. waves)

### Filtering to high freq. waves 

In [None]:
# High-pass settings shared by the three components: the cutoff is the shorter
# period (cutoff2_hr) converted to seconds.
highpass_kwargs = dict(
    btype='highpass',
    time_dim='time',
    dt=3600,
    period_cutoff_low=(cutoff2_hr*3600),
    order=5,
)
uhigh, vhigh, whigh = (
    filter_signal_xarray(component, **highpass_kwargs) for component in (u, v, w)
)


### Plot high freq. internal waves

In [None]:
# Colour-scale maximum for the high-pass field (ten times smaller than the one
# used for the low-pass and band-pass snapshots).
vmax_high = 7.5e-3

vhigh_snapshot = vhigh.isel(time=i_time_to_plot)
vhigh_snapshot.plot(vmax=vmax_high)
plt.gca().grid()

## Comparing time series

In [None]:
def plot_comparison_timeseries(xx, yy, vel, low, seiche, high, name_var, x_name="XG", y_name="YC"):
    """Plot the filtered components and their sum against the original signal.

    All fields are sampled at the grid point nearest to (xx, yy). The upper
    panel shows the three filtered components; the lower panel overlays the
    original series and the sum of the three components.

    Parameters
    ----------
    xx, yy : scalar
        Target position along the ``x_name`` and ``y_name`` coordinates.
    vel, low, seiche, high : xarray.DataArray-like
        Original field and its three filtered parts. Each must support
        ``.sel(..., method="nearest")`` on ``x_name`` / ``y_name`` and ``.plot``.
    name_var : str
        Variable name used in the y-axis label.
    x_name, y_name : str
        Names of the horizontal coordinates of the fields.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : array of matplotlib.axes.Axes
        The two panels (upper, lower), so callers can tweak or save the figure
        explicitly instead of relying on the current pyplot figure.
    """
    def at_point(field):
        # Time series of `field` at the grid point nearest to (xx, yy).
        return field.sel({x_name: xx, y_name: yy}, method="nearest")

    # Sample every field once. Summing the point series, rather than summing
    # the full fields and then selecting, avoids adding whole arrays just to
    # plot a single location.
    low_pt = at_point(low)
    seiche_pt = at_point(seiche)
    high_pt = at_point(high)
    vel_pt = at_point(vel)

    fig, ax = plt.subplots(2, 1, figsize=(18, 8))

    # upper plot: the individual filtered components
    low_pt.plot(ax=ax[0], label="low freq.")
    seiche_pt.plot(ax=ax[0], label="seiche")
    high_pt.plot(ax=ax[0], label="high freq.")

    # bottom plot: original signal against the recombined components
    vel_pt.plot(ax=ax[1], label="Original")
    (low_pt + seiche_pt + high_pt).plot(ax=ax[1], label="low freq + seiche + high freq")

    for ax_i in (ax[0], ax[1]):
        ax_i.grid()
        ax_i.set_xlabel('')
        ax_i.legend(loc='upper right')
        ax_i.set_ylabel(f'{name_var} (m/s)')

    return fig, ax


In [None]:
# Compare the series at the lower-left corner of the averaging window.
xx, yy = xx1, yy1
plot_comparison_timeseries(xx, yy, u, ulow, useiche, uhigh, 'U', x_name="XG", y_name="YC")

# plt.savefig writes the current figure, i.e. the one just created above.
u_figure_path = os.path.join(path_cutoff_folder, f'plot_filtered_timeseries_{cutoff2_hr}-{cutoff1_hr}h_U.png')
plt.savefig(u_figure_path)

In [None]:
# Same comparison for V, which is indexed by XC / YG.
plot_comparison_timeseries(xx, yy, v, vlow, vseiche, vhigh, 'V', x_name="XC", y_name="YG")

v_figure_path = os.path.join(path_cutoff_folder, f'plot_filtered_timeseries_{cutoff2_hr}-{cutoff1_hr}h_V.png')
plt.savefig(v_figure_path)

In [None]:
# Same comparison for W, which is indexed by XC / YC.
plot_comparison_timeseries(xx, yy, w, wlow, wseiche, whigh, 'W', x_name="XC", y_name="YC")

w_figure_path = os.path.join(path_cutoff_folder, f'plot_filtered_timeseries_{cutoff2_hr}-{cutoff1_hr}h_W.png')
plt.savefig(w_figure_path)

In [None]:
# Write the band-pass filtered components to netCDF; file names carry the
# depth selector zz and the two cutoff periods.
u_seiche_path = os.path.join(path_cutoff_folder, rf"u_filtered_z{zz}_{cutoff2_hr}-{cutoff1_hr}h.nc")
useiche.to_netcdf(u_seiche_path)

v_seiche_path = os.path.join(path_cutoff_folder, rf"v_filtered_z{zz}_{cutoff2_hr}-{cutoff1_hr}h.nc")
vseiche.to_netcdf(v_seiche_path)

w_seiche_path = os.path.join(path_cutoff_folder, rf"w_filtered_z{zz}_{cutoff2_hr}-{cutoff1_hr}h.nc")
wseiche.to_netcdf(w_seiche_path)

In [None]:
# --- Extract the U window and V co-located on the same points ---
# U is stored on (YC, XG) and V on (YG, XC). Slicing both with the same
# coordinate bounds yields windows of different shapes (the edge and centre
# coordinates are offset by half a cell), which cannot be multiplied
# element-wise in the advection term. V is therefore linearly interpolated
# onto the U points of the window.
u_dims = ('time', 'YC', 'XG')
v_dims = ('time', 'YG', 'XC')
u_window_bounds = dict(XG=slice(xx1, xx2+200), YC=slice(yy1, yy2+200))

u_tot_da = u.transpose(*u_dims).sel(**u_window_bounds)
u_tot = u_tot_da.values
u_iw  = useiche.transpose(*u_dims).sel(**u_window_bounds).values

# Horizontal positions of the selected U points.
x_u = u_tot_da['XG'].values
y_u = u_tot_da['YC'].values

v_tot = v.transpose(*v_dims).interp(XC=x_u, YG=y_u).transpose(*v_dims).values
v_iw  = vseiche.transpose(*v_dims).interp(XC=x_u, YG=y_u).transpose(*v_dims).values

# All four arrays must be (time, y, x) on the same points before they are combined.
assert u_tot.shape == u_iw.shape == v_tot.shape == v_iw.shape, (
    u_tot.shape, u_iw.shape, v_tot.shape, v_iw.shape
)

# Time step in seconds (the same value is passed to the filters as dt=3600).
dt = 3600.0
# The finite-difference spacing has to match the spacing of the coordinates,
# which were built from grid_resolution.
dx = float(grid_resolution)
dy = float(grid_resolution)


# --- Nonlinear advection operator for scalar u ---
def advect_u(u, v, dx, dy, periodic=True):
    """
    Compute (u · ∇)u = u ∂x u + v ∂y u with centred differences.

    Parameters
    ----------
    u, v : array_like
        Velocity components with shape (time, y, x), given on the same points.
    dx, dy : float
        Grid spacing along x (axis 2) and y (axis 1).
    periodic : bool, optional
        If True (default, the original behaviour), neighbours are taken with
        ``np.roll``, so the first and last cells along x and y wrap around.
        If False, ``np.gradient`` is used instead: centred differences in the
        interior and one-sided differences at the edges. Use this for a
        window cut out of a larger, non-periodic domain.

    Returns
    -------
    numpy.ndarray
        The advective term, with the broadcast shape of ``u`` and ``v``.
    """
    if periodic:
        dudx = (np.roll(u, -1, axis=2) - np.roll(u,  1, axis=2)) / (2.0 * dx)
        dudy = (np.roll(u, -1, axis=1) - np.roll(u,  1, axis=1)) / (2.0 * dy)
    else:
        dudx = np.gradient(u, dx, axis=2)
        dudy = np.gradient(u, dy, axis=1)

    return u * dudx + v * dudy


# --- Local time derivatives (np.gradient along the time axis) ---
du_dt_tot = np.gradient(u_tot, dt, axis=0)
du_dt_iw  = np.gradient(u_iw,  dt, axis=0)

# --- Accelerations: local derivative plus advective term ---
adv_tot = advect_u(u_tot, v_tot, dx, dy)
adv_iw  = advect_u(u_iw,  v_iw,  dx, dy)
a_tot = du_dt_tot + adv_tot
a_iw  = du_dt_iw  + adv_iw

# --- Residual acceleration: total minus band-pass (seiche) part ---
a_res = a_tot - a_iw

# --- Integrate the residual acceleration in time ---
# Start from the exact residual velocity at the first time step.
u_res_accel = np.zeros_like(u_tot)
u_res_accel[0] = u_tot[0] - u_iw[0]

# Trapezoidal increments between consecutive time steps; increments[k] takes
# the state from step k to step k + 1.
increments = 0.5 * (a_res[1:] + a_res[:-1]) * dt

# Accumulate step by step so every state is stored in u_res_accel's dtype
# before the next increment is added.
n_steps = u_tot.shape[0]
for t in range(1, n_steps):
    u_res_accel[t] = u_res_accel[t-1] + increments[t-1]

print("Acceleration-based residual velocity shape:", u_res_accel.shape)


In [None]:
# Time series at one cell of the window (index 1 along both horizontal axes):
# total U, band-pass U and the acceleration-based residual.
probe = (slice(None), 1, 1)
plt.plot(u_tot[probe])
plt.plot(u_iw[probe])
plt.plot(u_res_accel[probe])