In [None]:
import numpy as np
import xarray as xr 
import pandas as pd
import glob 

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from scipy.signal import detrend, windows
from scipy.stats.distributions import chi2
from scipy.signal import butter, filtfilt, detrend


from matplotlib.ticker import LogLocator, MaxNLocator, ScalarFormatter, FuncFormatter

import sys
sys.path.append('..//')
from utils_mitgcm import open_mitgcm_ds_from_config

from datetime import datetime

# Load data and select a depth level
- select the model level nearest to depth `zz` and keep the full time series at that level

In [None]:
# Open the MITgcm run registered under `model` in the shared configuration file.
config_path = '..//config.json'
model = 'geneva_dummy_extended'
mitgcm_config, ds = open_mitgcm_ds_from_config(config_path, model)

In [None]:
# Replace the horizontal coordinates by along-grid distances for a uniform spacing of
# `grid_resolution` (metres, as used for dx/dy in the KE calculation):
# cell centres (XC, YC) sit half a cell from the origin, cell faces (XG, YG) start at 0.
grid_resolution = 200
for centre_dim, face_dim in (('YC', 'YG'), ('XC', 'XG')):
    ds[centre_dim] = np.arange(1, len(ds[centre_dim]) + 1) * grid_resolution - grid_resolution / 2
    ds[face_dim] = np.arange(0, len(ds[face_dim])) * grid_resolution

In [None]:
zz = 0  # target depth (m, per the plot titles); levels are picked with method='nearest' below

In [None]:
# Velocity components on the model level nearest to zz. W is indexed by the
# interface levels Zl instead of the centre levels Z.
u, v = (ds[name].sel(Z=zz, method='nearest') for name in ('UVEL', 'VVEL'))
w = ds['WVEL'].sel(Zl=zz, method='nearest')

In [None]:
# Load the selected level into memory (the last expression also displays w).
u.load()
v.load()
w.load()

## Chunking data prior to FFT

# apply_ufunc(dask='parallelized') needs `time` as a single chunk. Let dask size the
# spatial chunks ('auto') instead of 1x1 chunks, which create one task per grid column.
u = u.chunk({'XG': 'auto', 'YC': 'auto', 'time': -1})
v = v.chunk({'XC': 'auto', 'YG': 'auto', 'time': -1})
# Chunk w the same way so all three components go through the same code paths.
w = w.chunk({'XC': 'auto', 'YC': 'auto', 'time': -1})

# Compute freq. spectrum 

## Function for FFT (Fast Fourier Transform)

In [None]:
def compute_mean_fft(data, M, alpha=0.1):
    """
    Segment-averaged power spectrum of a 1-D series with chi-squared error bounds.

    The series is truncated to ``M * p`` samples (``p = len(data) // M``) and split
    into ``M`` non-overlapping segments. Each segment is processed in order:

    1. Demeaned with a NaN-aware mean.
    2. NaNs are set to 0.
    3. Linearly detrended.
    4. Hann-windowed and Fourier transformed.

    The one-sided power ``|X_k / p|**2`` is doubled for ``k >= 1`` and averaged
    over the segments.

    Parameters
    ----------
    data : numpy.ndarray
        1-D time series.
    M : int
        Number of segments to average (>= 1).
    alpha : float, optional
        Significance level of the confidence interval. The default 0.1 gives a
        90 % interval.

    Returns
    -------
    amp_mean, err_up, err_low : numpy.ndarray
        Mean power for the first ``p // 2`` FFT bins and its upper / lower
        confidence bounds, assuming ``2 * M`` degrees of freedom.

    Raises
    ------
    ValueError
        If ``M < 1`` or the segment length ``p`` is below 2.

    Notes
    -----
    The Hann window's power loss is not compensated. Absolute levels are therefore
    biased low by roughly ``mean(window**2)``; relative comparisons are unaffected.
    """
    if M < 1:
        raise ValueError(f"Number of segments M must be >= 1, got M={M}.")
    N = len(data)
    p = N // M
    if p < 2:
        raise ValueError(f"Segment length (p={p}) is too small for FFT. Choose a smaller M.")

    segments = np.asarray(data)[:p * M].reshape(M, p)
    window = np.hanning(p)  # broadcasts over the M segments

    # Remove each segment's mean, then zero any gaps so the FFT sees finite values.
    segments_demean = segments - np.nanmean(segments, axis=1, keepdims=True)
    segments_demean[np.isnan(segments_demean)] = 0
    data_dtrend = detrend(segments_demean, axis=1, type='linear')

    fft_segments = np.fft.fft(data_dtrend * window, axis=1)
    amp_segments = np.abs(fft_segments[:, :p // 2] / p) ** 2
    amp_segments[:, 1:] *= 2  # one-sided spectrum: fold negative frequencies (DC not doubled)
    amp_mean = amp_segments.mean(axis=0)

    # Chi-squared interval: nu * S / chi2(nu) bounds the true spectrum.
    nu = 2 * M
    err_up = nu / chi2.ppf(alpha / 2, df=nu) * amp_mean
    err_low = nu / chi2.ppf(1 - alpha / 2, df=nu) * amp_mean

    return amp_mean, err_up, err_low


def xr_compute_meanfft(data, M):
    """
    Segment-averaged power spectrum along ``time`` at every other index of ``data``.

    Wraps ``compute_mean_fft`` with ``xr.apply_ufunc``. The computation is
    vectorized over the non-time dimensions and dask-parallelized when ``data`` is
    chunked; in that case ``time`` must be a single chunk.

    Parameters
    ----------
    data : xarray.DataArray
        Field with a ``time`` dimension sampled at a constant step (the step is
        taken from the first two time values).
    M : int
        Number of segments averaged per spectrum.

    Returns
    -------
    xarray.Dataset
        ``amp_mean``, ``err_up`` and ``err_low`` on a new ``freq`` dimension, in
        cycles per second when ``time`` is datetime/timedelta-valued.

    Raises
    ------
    ValueError
        If ``M < 1`` or the resulting segment length is below 2 samples.
    """
    if M < 1:
        raise ValueError(f"Number of segments M must be >= 1, got M={M}.")

    time = data['time']
    step = (time[1] - time[0]).values
    if np.issubdtype(step.dtype, np.timedelta64):
        # Valid for any datetime/timedelta resolution (ns, us, s, ...).
        dt = float(step / np.timedelta64(1, 's'))
    else:
        # Non-timedelta axis: keep the nanosecond interpretation used previously.
        dt = step.astype('float') / 1e9

    N = time.size
    p = N // M
    if p < 2:
        raise ValueError(f"Segment length (p={p}) is too small for FFT. Choose a smaller M.")
    freq = np.fft.fftfreq(p, dt)[:p // 2]  # must match the bins kept by compute_mean_fft

    amp_mean, err_up, err_low = xr.apply_ufunc(
        compute_mean_fft,
        data,
        M,
        input_core_dims=[["time"], []],
        output_core_dims=[["freq"], ["freq"], ["freq"]],
        exclude_dims={"time"},
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.float64, np.float64, np.float64],
        # Size of the new core dimension for dask; the top-level `output_sizes`
        # argument of apply_ufunc is deprecated in favour of dask_gufunc_kwargs.
        dask_gufunc_kwargs={"output_sizes": {"freq": freq.size}},
    )

    return xr.Dataset(
        {"amp_mean": amp_mean, "err_up": err_up, "err_low": err_low},
        coords={"freq": freq},
    )

In [None]:
def compute_fft_period(fft_freq):
    """Return the period ``1 / f`` for an array of frequencies; zero frequency maps to +inf."""
    is_nonzero = fft_freq != 0
    # Division by zero is masked out below, so silence its warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        reciprocal = 1 / fft_freq
    return np.where(is_nonzero, reciprocal, np.inf)

## Compute freq. spectrum

In [None]:
# Domain-mean V time series at the selected level (quick look before the FFT).
v.mean(dim=['XC','YG']).plot()

In [None]:
# Domain-mean U time series at the selected level (quick look before the FFT).
u.mean(dim=['XG','YC']).plot()

In [None]:
# Inspect the selected U field (dimensions, coordinates, chunking).
u

In [None]:
# Number of segments averaged per spectrum (M=1: a single periodogram with 2 degrees
# of freedom, hence wide error bars).
m_seg = 1
u_fft, v_fft, w_fft = (xr_compute_meanfft(component, M=m_seg) for component in (u, v, w))

## Select an XY sub-region for the mean frequency spectrum

In [None]:
# Lower-left corner (m) and size (cells) of the sub-region averaged below.
xx1 = 0
yy1 = 0

xcells = 336
ycells = 132

# Upper-right corner: reuse the grid spacing instead of a second hard-coded 200.
xx2 = xx1 + grid_resolution * xcells
yy2 = yy1 + grid_resolution * ycells

fig, ax = plt.subplots(1, figsize=(15, 8))
v.isel(time=24).plot(ax=ax)

# Mark the four corners of the selection window on the same axes.
ax.scatter([xx1, xx1, xx2, xx2], [yy1, yy2, yy1, yy2], marker='.', color='k')

ax.grid()


In [None]:

# Average the spectra over the sub-region. Each component uses its own staggered
# horizontal coordinates (u on XG/YC, v on XC/YG, w on XC/YC).
x_window = slice(xx1, xx2)
y_window = slice(yy1, yy2)

u_fft_mean = u_fft.sel(XG=x_window, YC=y_window).mean(dim=['XG', 'YC'])
v_fft_mean = v_fft.sel(XC=x_window, YG=y_window).mean(dim=['XC', 'YG'])
w_fft_mean = w_fft.sel(XC=x_window, YC=y_window).mean(dim=['XC', 'YC'])


## Plot freq. spectrum

In [None]:
# Mean U power spectrum over the selected sub-region on log-log axes, with the
# confidence band and a secondary top axis labelled in period (hours).
fig, ax = plt.subplots(1, figsize=(15, 6))

# Region-mean power spectrum (segment-averaged |FFT|^2 from compute_mean_fft)
u_fft_mean.amp_mean.plot(ax=ax)

# Shade the confidence interval between err_low and err_up
ax.fill_between(u_fft_mean['freq'],
                u_fft_mean.err_low,
                u_fft_mean.err_up,
                color='gray', alpha=0.5, label='Uncertainty')

# Both axes on a log scale
ax.set_xscale('log')
ax.set_yscale('log')

# Grid and labels
ax.grid()
ax.set_xlabel('Frequency $(s^{-1})$', fontsize=16)
ax.set_ylabel('Power Spectrum', fontsize=18)

# Denser major ticks on the log frequency axis
ax.xaxis.set_major_locator(LogLocator(base=10, subs='auto', numticks=8))

# Secondary top axis. No transform functions are given, so it shares the frequency
# coordinates; only its tick labels are rewritten as periods.
secax = ax.secondary_xaxis('top')
secax.set_xscale('log')
secax.set_xlabel('Period (hours)', fontsize=16)

# Periods (hours) at the current primary tick positions (Hz). The ticks are read
# before set_xlim below, i.e. from the autoscaled view.
fft_ticks = ax.get_xticks()
fft_period = compute_fft_period(fft_ticks)  # seconds
fft_hr = fft_period / 3600

secax.set_xticks(fft_ticks)

# Whole-hour labels. Round rather than truncate: int() would label a 3.97 h tick "3".
secax.set_xticklabels([f'{x:.0f}' for x in fft_hr], fontsize=16)

# Replace matplotlib's offset text with a manual "x 10^n" annotation
ax.xaxis.get_offset_text().set_visible(False)

# Common power of ten taken from the current upper x limit
x_min, x_max = ax.get_xlim()
power_offset = int(np.floor(np.log10(x_max)))

# The exponent must be braced: for frequencies in Hz power_offset is typically
# negative, and mathtext would otherwise superscript only the minus sign.
ax.annotate(f'$\\times 10^{{{power_offset}}}$', xy=(1, 0), xycoords='axes fraction',
            fontsize=16, xytext=(-30, -30), textcoords='offset points',
            ha='center', va='center')

# Primary tick labels in units of 10**power_offset.
# NOTE(review): reads the global power_offset when the figure is drawn; later
# cells reassign it.
def adjusted_formatter(x, pos):
    return f'{x / 10**power_offset:.1f}'

ax.xaxis.set_major_formatter(FuncFormatter(adjusted_formatter))

# Axis limits
ax.set_ylim(1e-9, None)
ax.set_xlim(0.01e-4, None)

# Larger tick labels on all axes
plt.setp(ax.get_xticklabels(), fontsize=16)
plt.setp(ax.get_yticklabels(), fontsize=16)
plt.setp(secax.get_xticklabels(), fontsize=16)

plt.title('U - Depth:{}m - Segments for FFT:{}'.format(zz,m_seg), fontsize=18)


# Frequency-band filtering (time-domain Butterworth filters)

## Defining cutoffs 
- Inertial period is at 16.4 hrs 

In [None]:
# Band edges in hours, as used by the filters below:
# - periods longer than cutoff1_hr are kept by the low-pass filter;
# - periods between cutoff2_hr and cutoff1_hr form the seiche band;
# - periods shorter than cutoff2_hr are kept by the high-pass filter.
cutoff1_hr = 82
cutoff2_hr = 46

# The same edges expressed as frequencies (Hz)
cutoff1, cutoff2 = (1 / (hours * 3600) for hours in (cutoff1_hr, cutoff2_hr))

# Same spectrum plot as above on a larger canvas, with the filter cutoffs marked.
fig, ax = plt.subplots(1, figsize=(30, 10))

# Region-mean power spectrum (segment-averaged |FFT|^2 from compute_mean_fft)
u_fft_mean.amp_mean.plot(ax=ax)

# Shade the confidence interval between err_low and err_up
ax.fill_between(u_fft_mean['freq'],
                u_fft_mean.err_low,
                u_fft_mean.err_up,
                color='gray', alpha=0.5, label='Uncertainty')

# Both axes on a log scale
ax.set_xscale('log')
ax.set_yscale('log')

# Grid and labels
ax.grid()
ax.set_xlabel('Frequency $(s^{-1})$', fontsize=16)
ax.set_ylabel('Power Spectrum', fontsize=18)

# Denser major ticks on the log frequency axis
ax.xaxis.set_major_locator(LogLocator(base=10, subs='auto', numticks=8))

# Secondary top axis. No transform functions are given, so it shares the frequency
# coordinates; only its tick labels are rewritten as periods.
secax = ax.secondary_xaxis('top')
secax.set_xscale('log')
secax.set_xlabel('Period (hours)', fontsize=16)

# Periods (hours) at the current primary tick positions (Hz), read before set_xlim below
fft_ticks = ax.get_xticks()
fft_period = compute_fft_period(fft_ticks)  # seconds
fft_hr = fft_period / 3600

secax.set_xticks(fft_ticks)

# Whole-hour labels. Round rather than truncate: int() would label a 3.97 h tick "3".
secax.set_xticklabels([f'{x:.0f}' for x in fft_hr], fontsize=16)

# Replace matplotlib's offset text with a manual "x 10^n" annotation
ax.xaxis.get_offset_text().set_visible(False)

# Common power of ten taken from the current upper x limit
x_min, x_max = ax.get_xlim()
power_offset = int(np.floor(np.log10(x_max)))

# The exponent must be braced: for frequencies in Hz power_offset is typically
# negative, and mathtext would otherwise superscript only the minus sign.
ax.annotate(f'$\\times 10^{{{power_offset}}}$', xy=(1, 0), xycoords='axes fraction',
            fontsize=16, xytext=(-30, -30), textcoords='offset points',
            ha='center', va='center')

# Primary tick labels in units of 10**power_offset.
# NOTE(review): reads the global power_offset when the figure is drawn.
def adjusted_formatter(x, pos):
    return f'{x / 10**power_offset:.1f}'

ax.xaxis.set_major_formatter(FuncFormatter(adjusted_formatter))

# Axis limits
ax.set_ylim(1e-12, None)
ax.set_xlim(0.01e-4, None)

# Band edges used by the low-, band- and high-pass filters below
ax.axvline(x = cutoff1, linestyle="--", color="k",label="cutoff1")
ax.axvline(x = cutoff2, linestyle="--", color="k",label="cutoff2")
ax.legend()

# Larger tick labels on all axes
plt.setp(ax.get_xticklabels(), fontsize=16)
plt.setp(ax.get_yticklabels(), fontsize=16)
plt.setp(secax.get_xticklabels(), fontsize=16)

plt.title('U - Depth:{}m - Segments for FFT:{}'.format(zz,m_seg), fontsize=18)


## Low-pass filtering (slow motions)

### Function - low pass filter 

In [None]:
def lowpass_filter_timeseries(timeseries, dt=1.0, period_cutoff=110.0, order=5):
    """
    Apply a zero-phase low-pass Butterworth filter (``filtfilt``) to a 1-D series.

    Gaps (NaNs) are filled by linear interpolation when at least half of the
    samples are valid; otherwise an all-NaN array is returned.

    The linear trend is removed before filtering, to limit edge effects, and added
    back afterwards. The mean and trend are the lowest frequencies, so they belong
    in the low-pass output. This keeps ``low + band + high`` (with matching
    cutoffs) close to the original signal.

    Parameters
    ----------
    timeseries : array-like
        1D time series data.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff : float, optional
        Cutoff period in same units as dt. Default is 110.0.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_ts : ndarray
        Low-pass filtered time series (float). If the cutoff is at or above the
        Nyquist frequency, every resolvable frequency passes and the (gap-filled)
        input is returned.
    """
    # Work on a float copy so the caller's array is never modified.
    timeseries = np.array(timeseries, dtype=float)

    # Fill NaNs by linear interpolation (give up when more than half are missing)
    if np.any(np.isnan(timeseries)):
        valid_mask = ~np.isnan(timeseries)
        if np.sum(valid_mask) < len(timeseries) // 2:
            return np.full(len(timeseries), np.nan)
        timeseries[~valid_mask] = np.interp(
            np.flatnonzero(~valid_mask),
            np.flatnonzero(valid_mask),
            timeseries[valid_mask]
        )

    detrended = detrend(timeseries, type='linear')
    trend = timeseries - detrended

    # Cutoff normalised by the Nyquist frequency
    fs = 1.0 / dt
    nyq = fs / 2.0
    normal_cutoff = (1.0 / period_cutoff) / nyq

    if normal_cutoff >= 1.0:
        return timeseries

    b, a = butter(order, normal_cutoff, btype='low', analog=False)

    # Zero-phase filter of the detrended signal, then restore the trend.
    return filtfilt(b, a, detrended) + trend


def lowpass_filter_xarray(data, time_dim='time', dt=1.0, period_cutoff=110.0, order=5):
    """
    Apply a low-pass Butterworth filter along ``time_dim`` at every other index.

    Parameters
    ----------
    data : xarray.DataArray
        Input array containing ``time_dim``; other dimensions may be in any order.
        If dask-backed, ``time_dim`` must be a single chunk.
    time_dim : str, optional
        Name of the time dimension. Default is 'time'.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff : float, optional
        Cutoff period in same units as dt. Default is 110.0.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_result : xarray.DataArray
        Low-pass filtered data with the input's coordinates. ``apply_ufunc`` places
        the core dimension ``time_dim`` last. ``attrs`` are copied from the input
        and extended with the filter settings.
    """
    use_dask = hasattr(data.data, 'chunks')

    # Filter each 1-D series along time_dim; the scalar settings are broadcast.
    filtered_result = xr.apply_ufunc(
        lowpass_filter_timeseries,
        data,
        dt,
        period_cutoff,
        order,
        input_core_dims=[[time_dim], [], [], []],
        output_core_dims=[[time_dim]],
        output_dtypes=[np.float64],
        dask='parallelized' if use_dask else 'forbidden',
        vectorize=True,
    )

    # Re-attach the input coordinates
    filtered_result = filtered_result.assign_coords(data.coords)

    # Input metadata (including 'units' when present) plus the filter settings.
    # An input without 'units' does not get a placeholder 'units' attribute.
    filtered_result.attrs = data.attrs.copy()
    filtered_result.attrs['long_name'] = f'Low-pass filtered {data.attrs.get("long_name", "data")}'
    filtered_result.attrs['filter_type'] = 'Butterworth low-pass (filtfilt)'
    filtered_result.attrs['cutoff_period'] = period_cutoff
    filtered_result.attrs['filter_order'] = order

    return filtered_result


### Filtering to low freq. motions

In [None]:
# Keep periods longer than cutoff1_hr (all times in seconds).
# NOTE(review): dt=3600 s assumes hourly model output -- confirm against ds.time.
lowpass_kwargs = dict(time_dim='time', dt=3600, period_cutoff=(cutoff1_hr * 3600), order=5)
ulow, vlow, wlow = (lowpass_filter_xarray(component, **lowpass_kwargs) for component in (u, v, w))


### Plotting low freq. motions 

In [None]:
i_time_to_plot = 24  # snapshot index, shared with the seiche and high-pass map cells

vmax_low = 7.5e-2
low_snapshot = ulow.isel(time=i_time_to_plot)
low_snapshot.plot(vmax=vmax_low)
plt.grid()

## Bandpass filtering (seiches)

### Function - bandpass filter 

In [None]:
def bandpass_filter_timeseries(timeseries, dt=1.0, period_cutoff_low=500.0, period_cutoff_high=110.0, order=5):
    """
    Apply a zero-phase band-pass Butterworth filter to a single time series.

    The pass band spans periods between ``period_cutoff_high`` (short-period edge)
    and ``period_cutoff_low`` (long-period edge). The filter is designed as
    second-order sections and applied forward-backward with ``sosfiltfilt``. This
    form is numerically robust for narrow low-frequency bands, which matters
    because a band-pass of order ``order`` has ``2 * order`` poles.

    Gaps (NaNs) are filled by linear interpolation when at least half of the
    samples are valid; otherwise an all-NaN array is returned. The linear trend is
    removed first, since it lies below the band.

    Degenerate bands are handled as follows:

    - Only the short-period edge is at or above Nyquist: the filter becomes a
      high-pass at the long-period edge.
    - The whole band is at or above Nyquist: zeros are returned.
    - ``period_cutoff_low`` is infinite: the filter becomes a low-pass at the
      short-period edge.

    Parameters
    ----------
    timeseries : array-like
        1D time series data.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff_low : float
        Lower cutoff frequency edge = longer period, in same units as dt.
    period_cutoff_high : float
        Upper cutoff frequency edge = shorter period, in same units as dt.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_ts : ndarray
        Band-pass filtered time series.

    Raises
    ------
    ValueError
        If ``period_cutoff_low`` is not longer than ``period_cutoff_high``.
    """
    from scipy.signal import butter, detrend, sosfiltfilt
    import numpy as np

    # Fill NaNs by linear interpolation (give up when more than half are missing)
    if np.any(np.isnan(timeseries)):
        valid_mask = ~np.isnan(timeseries)
        if np.sum(valid_mask) < len(timeseries) // 2:
            return np.full(len(timeseries), np.nan)
        timeseries = np.copy(timeseries)
        timeseries[~valid_mask] = np.interp(
            np.flatnonzero(~valid_mask),
            np.flatnonzero(valid_mask),
            timeseries[valid_mask]
        )

    timeseries = detrend(timeseries, type='linear')

    f_low = 1.0 / period_cutoff_low    # long-period edge of the band
    f_high = 1.0 / period_cutoff_high  # short-period edge of the band
    if f_low >= f_high:
        raise ValueError(
            "period_cutoff_low must be longer than period_cutoff_high "
            f"(got {period_cutoff_low} and {period_cutoff_high})"
        )

    # Edges normalised by the Nyquist frequency
    nyq = 0.5 / dt
    wn_low = f_low / nyq
    wn_high = f_high / nyq

    if wn_low >= 1.0:
        # The whole band is above Nyquist: nothing in it is resolvable.
        return np.zeros(len(timeseries))
    if wn_low <= 0.0 and wn_high >= 1.0:
        # Both edges open: every resolvable frequency passes.
        return timeseries
    if wn_high >= 1.0:
        sos = butter(order, wn_low, btype='high', output='sos')
    elif wn_low <= 0.0:
        sos = butter(order, wn_high, btype='low', output='sos')
    else:
        sos = butter(order, [wn_low, wn_high], btype='band', output='sos')

    # Zero-phase (forward-backward) filtering
    return sosfiltfilt(sos, timeseries)



def bandpass_filter_xarray(data, time_dim='time', dt=1.0, period_cutoff_low=500.0, period_cutoff_high=110.0, order=5):
    """
    Apply a band-pass Butterworth filter along ``time_dim`` at every other index.

    Parameters
    ----------
    data : xarray.DataArray
        Input array containing ``time_dim``; other dimensions may be in any order.
        If dask-backed, ``time_dim`` must be a single chunk.
    time_dim : str, optional
        Name of the time dimension. Default is 'time'.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff_low : float
        Longer-period edge of the band, in same units as dt.
    period_cutoff_high : float
        Shorter-period edge of the band, in same units as dt.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_result : xarray.DataArray
        Band-pass filtered data with the input's coordinates. ``apply_ufunc``
        places the core dimension ``time_dim`` last. ``attrs`` are copied from the
        input and extended with the filter settings.
    """
    use_dask = hasattr(data.data, 'chunks')

    # Filter each 1-D series along time_dim; the scalar settings are broadcast.
    filtered_result = xr.apply_ufunc(
        bandpass_filter_timeseries,
        data,
        dt,
        period_cutoff_low,
        period_cutoff_high,
        order,
        input_core_dims=[[time_dim], [], [], [], []],
        output_core_dims=[[time_dim]],
        output_dtypes=[np.float64],
        dask='parallelized' if use_dask else 'forbidden',
        vectorize=True,
    )

    # Re-attach the input coordinates
    filtered_result = filtered_result.assign_coords(data.coords)

    # Input metadata (including 'units' when present) plus the filter settings.
    # An input without 'units' does not get a placeholder 'units' attribute.
    filtered_result.attrs = data.attrs.copy()
    filtered_result.attrs['long_name'] = f'Band-pass filtered {data.attrs.get("long_name", "data")}'
    filtered_result.attrs['filter_type'] = 'Butterworth band-pass (filtfilt)'
    filtered_result.attrs['cutoff_period_low'] = period_cutoff_low
    filtered_result.attrs['cutoff_period_high'] = period_cutoff_high
    filtered_result.attrs['filter_order'] = order

    return filtered_result



### Filtering to seiches 

In [None]:
# Keep periods between cutoff2_hr and cutoff1_hr, the seiche band (all times in seconds).
# NOTE(review): dt=3600 s assumes hourly model output -- confirm against ds.time.
bandpass_kwargs = dict(
    time_dim='time',
    dt=3600,
    period_cutoff_low=(cutoff1_hr * 3600),
    period_cutoff_high=(cutoff2_hr * 3600),
    order=5,
)
useiche, vseiche, wseiche = (bandpass_filter_xarray(component, **bandpass_kwargs) for component in (u, v, w))


### Plotting seiches 

In [None]:
# NOTE(review): this cell is not idempotent. Each run advances i_time_to_plot by one
# step (useful for stepping through snapshots), and the high-pass map cell below
# plots whatever index this cell left behind.
i_time_to_plot += 1
print(i_time_to_plot)
vmax_seiche = 7.5e-2
vseiche.isel(time=i_time_to_plot).plot(vmax=vmax_seiche)
plt.grid()

## High-pass filtering (high-frequency waves)

### Function - high pass filter

In [None]:
def highpass_filter_timeseries(timeseries, dt=1.0, period_cutoff=110.0, order=5):
    """
    Apply a zero-phase high-pass Butterworth filter (``filtfilt``) to a 1-D series.

    Gaps (NaNs) are filled by linear interpolation when at least half of the
    samples are valid; otherwise an all-NaN array is returned. The linear trend is
    removed before filtering; it belongs to the low frequencies that a high-pass
    rejects anyway.

    Parameters
    ----------
    timeseries : array-like
        1D time series data.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff : float, optional
        Cutoff period in same units as dt. Default is 110.0.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_ts : ndarray
        High-pass filtered time series. If the cutoff frequency is at or above the
        Nyquist frequency, no resolvable frequency lies in the pass band and an
        array of zeros is returned.
    """
    # Fill NaNs by linear interpolation (give up when more than half are missing)
    if np.any(np.isnan(timeseries)):
        valid_mask = ~np.isnan(timeseries)
        if np.sum(valid_mask) < len(timeseries) // 2:
            return np.full(len(timeseries), np.nan)
        timeseries = np.copy(timeseries)
        timeseries[~valid_mask] = np.interp(
            np.flatnonzero(~valid_mask),
            np.flatnonzero(valid_mask),
            timeseries[valid_mask]
        )

    timeseries = detrend(timeseries, type='linear')

    # Cutoff normalised by the Nyquist frequency
    fs = 1.0 / dt
    fc = 1.0 / period_cutoff
    nyq = fs / 2.0
    normal_cutoff = fc / nyq

    if normal_cutoff >= 1.0:
        # The pass band starts at/above Nyquist: everything is in the stop band.
        return np.zeros(len(timeseries))

    b, a = butter(order, normal_cutoff, btype='high', analog=False)

    # Zero-phase filter
    return filtfilt(b, a, timeseries)

def highpass_filter_xarray(data, time_dim='T', dt=1.0, period_cutoff=110.0, order=5):
    """
    Apply a high-pass Butterworth filter along ``time_dim`` at every other index.

    Parameters
    ----------
    data : xarray.DataArray
        Input array containing ``time_dim``; other dimensions may be in any order.
        If dask-backed, ``time_dim`` must be a single chunk.
    time_dim : str, optional
        Name of the time dimension. Default is 'T'. Unlike the low-/band-pass
        wrappers, which default to 'time', so pass it explicitly.
    dt : float, optional
        Time step (sampling interval). Default is 1.0.
    period_cutoff : float, optional
        Cutoff period in same units as dt. Default is 110.0.
    order : int, optional
        Filter order. Default is 5.

    Returns
    -------
    filtered_result : xarray.DataArray
        High-pass filtered data with the input's coordinates. ``apply_ufunc``
        places the core dimension ``time_dim`` last. ``attrs`` are copied from the
        input and extended with the filter settings.
    """
    use_dask = hasattr(data.data, 'chunks')

    # Filter each 1-D series along time_dim; the scalar settings are broadcast.
    filtered_result = xr.apply_ufunc(
        highpass_filter_timeseries,
        data,
        dt,
        period_cutoff,
        order,
        input_core_dims=[[time_dim], [], [], []],
        output_core_dims=[[time_dim]],
        output_dtypes=[np.float64],
        dask='parallelized' if use_dask else 'forbidden',
        vectorize=True,
    )

    # Re-attach the input coordinates
    filtered_result = filtered_result.assign_coords(data.coords)

    # Input metadata (including 'units' when present) plus the filter settings.
    # An input without 'units' does not get a placeholder 'units' attribute.
    filtered_result.attrs = data.attrs.copy()
    filtered_result.attrs['long_name'] = f'High-pass filtered {data.attrs.get("long_name", "data")}'
    filtered_result.attrs['filter_type'] = 'Butterworth high-pass (filtfilt)'
    filtered_result.attrs['cutoff_period'] = period_cutoff
    filtered_result.attrs['filter_order'] = order

    return filtered_result




### Filtering to high freq. waves 

In [None]:
# Keep periods shorter than cutoff2_hr (all times in seconds).
# NOTE(review): dt=3600 s assumes hourly model output -- confirm against ds.time.
highpass_kwargs = dict(time_dim='time', dt=3600, period_cutoff=(cutoff2_hr * 3600), order=5)
uhigh, vhigh, whigh = (highpass_filter_xarray(component, **highpass_kwargs) for component in (u, v, w))


### Plot high freq. internal waves

In [None]:
vmax_high = 7.5e-3
high_snapshot = vhigh.isel(time=i_time_to_plot)  # index as left by the seiche map cell
high_snapshot.plot(vmax=vmax_high)
plt.grid()

## Comparing time series

In [None]:
xx = 40000; yy = 20000

fig, ax = plt.subplots(2,1, figsize=(18, 8))

# upper plot: the three frequency bands at the chosen point
ulow.sel(XG=xx, YC=yy, method="nearest").plot(ax=ax[0], label="low freq.")
useiche.sel(XG=xx, YC=yy, method="nearest").plot(ax=ax[0], label="seiche")
uhigh.sel(XG=xx, YC=yy, method="nearest").plot(ax=ax[0], label="high freq.")

# bottom plot: original signal vs. the sum of the three bands
u.sel(XG=xx, YC=yy, method="nearest").plot(ax=ax[1], label="Original")
(ulow + useiche + uhigh).sel(XG=xx, YC=yy, method="nearest").plot(ax=ax[1], label="low freq + seiche + high freq")

# Use a separate loop name: `for ax in ...` would rebind `ax` to the last Axes.
for axis in ax:
    axis.grid()
    axis.set_xlabel('')
    axis.legend(loc='upper right')
    axis.set_ylabel('U (m/s)')



In [None]:
xx = 40000; yy = 15000

fig, ax = plt.subplots(2,1, figsize=(18, 8))

# upper plot: the three frequency bands at the chosen point
vlow.sel(XC=xx, YG=yy, method="nearest").plot(ax=ax[0], label="low freq.")
vseiche.sel(XC=xx, YG=yy, method="nearest").plot(ax=ax[0], label="seiche")
vhigh.sel(XC=xx, YG=yy, method="nearest").plot(ax=ax[0], label="high freq.")

# bottom plot: original signal vs. the sum of the three bands
v.sel(XC=xx, YG=yy, method="nearest").plot(ax=ax[1], label="Original")
(vlow + vseiche + vhigh).sel(XC=xx, YG=yy, method="nearest").plot(ax=ax[1], label="low freq + seiche + high freq")

# Use a separate loop name: `for ax in ...` would rebind `ax` to the last Axes.
for axis in ax:
    axis.grid()
    axis.set_xlabel('')
    axis.legend(loc='upper right')
    axis.set_ylabel('V (m/s)')



In [None]:
xx = 40000; yy = 15000

fig, ax = plt.subplots(2,1, figsize=(18, 8))

# upper plot: the three frequency bands at the chosen point
wlow.sel(XC=xx, YC=yy, method="nearest").plot(ax=ax[0], label="low freq.")
wseiche.sel(XC=xx, YC=yy, method="nearest").plot(ax=ax[0], label="seiche")
whigh.sel(XC=xx, YC=yy, method="nearest").plot(ax=ax[0], label="high freq.")

# bottom plot: original signal vs. the sum of the three bands
w.sel(XC=xx, YC=yy, method="nearest").plot(ax=ax[1], label="Original")
(wlow + wseiche + whigh).sel(XC=xx, YC=yy, method="nearest").plot(ax=ax[1], label="low freq + seiche + high freq")

# Use a separate loop name: `for ax in ...` would rebind `ax` to the last Axes.
for axis in ax:
    axis.grid()
    axis.set_xlabel('')
    axis.legend(loc='upper right')
    axis.set_ylabel('W (m/s)')



In [None]:
# Save the seiche-band U/V fields. "46-82h" in the file names mirrors
# cutoff2_hr/cutoff1_hr and must be updated by hand if those change.
# NOTE(review): absolute, user-specific output paths -- consider a configurable directory.
useiche.to_netcdf(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/u_surface_filtered_46-82h.nc")
vseiche.to_netcdf(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/v_surface_filtered_46-82h.nc")

In [None]:
# Save the seiche-band W field (same naming convention and absolute path as U/V).
wseiche.to_netcdf(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/w_surface_filtered_46-82h.nc")

# Energy analysis

# Reload the saved seiche-band fields. Each file holds one variable, so open it as a
# DataArray. With open_dataset, u and v would be Datasets whose variables have
# different names (from UVEL/VVEL), and their sum in the KE calculation would be an
# empty Dataset. wseiche is still taken from memory.
useiche = xr.open_dataarray(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/u_surface_filtered_46-82h.nc")
vseiche = xr.open_dataarray(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/v_surface_filtered_46-82h.nc")

In [None]:
# Put the seiche-band u, v and w on common (XC, YC, Z) labels so they can be combined
# cell by cell:
# - u (on XG) and v (on YG) are relabelled with w's cell-centre coordinates;
# - w's Zl is relabelled with u's Z.
# NOTE(review): this ignores the half-cell staggering between face and centre points
# (no averaging to centres) -- confirm this is acceptable for the bulk KE estimate.
aligned_useiche = useiche.rename({'XG':'XC'})
aligned_useiche['XC'] = wseiche['XC']

aligned_vseiche = vseiche.rename({'YG':'YC'})
aligned_vseiche['YC'] = wseiche['YC']

aligned_wseiche = wseiche.rename({'Zl':'Z'})
aligned_wseiche['Z'] = useiche['Z']

In [None]:
def compute_ke_snapshot(uvel, vvel, wvel, dx, dy, dz, rho=1000.0):
    """
    Kinetic energy contained in each grid cell, in megajoules.

    Evaluates ``0.5 * rho * (u**2 + v**2 + w**2) * dx * dy * dz`` element-wise, so
    it works for scalars, NumPy arrays or aligned xarray objects.

    Parameters
    ----------
    uvel, vvel, wvel : velocity components (m/s for a result in J before scaling).
    dx, dy, dz : cell dimensions (m).
    rho : float, optional
        Water density (kg/m^3). Default 1000.0.

    Returns
    -------
    Energy per cell divided by 1e6 (MJ).
    """
    speed_squared = uvel ** 2 + vvel ** 2 + wvel ** 2
    energy_joules = 0.5 * rho * speed_squared * dx * dy * dz
    return energy_joules / 1e6

In [None]:
# Thickness of the selected level: look it up by depth (nearest to zz), as done for
# u/v/w, rather than treating zz as a positional index. The two only agree when zz
# happens to equal the level index (e.g. zz = 0 for the top level).
ke_mj_seiche = compute_ke_snapshot(
    aligned_useiche, aligned_vseiche, aligned_wseiche,
    grid_resolution, grid_resolution,
    ds.drF.sel(Z=zz, method='nearest').values,
)

In [None]:
# Seiche-band kinetic energy per grid cell (MJ) at the last time step.
ke_mj_seiche.isel(time=-1).plot()

In [None]:
# Domain-integrated seiche-band kinetic energy (MJ) through time.
ke_mj_seiche.sum(dim=['XC','YC']).plot()

In [None]:
# Domain-integrated seiche KE as a two-column table (time, ke_mj_seiche).
df_ke = (
    ke_mj_seiche.sum(dim=['XC', 'YC'])
    .rename('ke_mj_seiche')
    .to_series()
    .reset_index()
)

In [None]:
# Export the KE time series (the default pandas index is written as the first column).
# NOTE(review): absolute, user-specific output path -- consider a configurable directory.
df_ke.to_csv(r"/home/leroquan@eawag.wroot.emp-eaw.ch/work_space/dummy_extended/analysis/ke_surface_seiche.csv")