In [None]:
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from eofs.multivariate.standard import MultivariateEof
from eofs.xarray import Eof

from utils_mitgcm import open_mitgcm_ds_from_config
from utils_signal_processing import *

In [None]:
# Name of the model entry handed to the config loader.
model = 'geneva_dummy_7ms'
# The helper returns a pair: the model configuration (indexed by key further
# down, e.g. 'datapath') and the opened dataset.
mitgcm_config, ds = open_mitgcm_ds_from_config('..//config.json', model)

In [None]:
# Grid spacing read from the first dxC entry and applied to both horizontal
# axes (i.e. the grid is treated as uniform).
horizontal_resolution = ds.dxC.isel(XG=0, YC=0).values
half_cell = horizontal_resolution / 2

# Replace the YC and XC coordinates by distances to the cell centres:
# cell number * spacing - half a cell.
for axis_name in ('YC', 'XC'):
    cell_numbers = np.arange(1, len(ds[axis_name]) + 1)
    ds[axis_name] = cell_numbers * horizontal_resolution - half_cell

In [None]:
# Positional (index-based) subset used by every analysis below.
depths = range(50)  # first 50 indices along Z
times = range(15*24)  # first 15*24 time indices; alternative kept for reference: range(15*24, len(ds.time))
ds_select = ds.isel(Z=depths, time=times)

In [None]:
# Relabel the staggered dimension of each velocity component (XG, YG, Zl) with
# the corresponding dimension of ds_select (XC, YC, Z) and take over its
# coordinate values, so that all components share the same labels.
aligned_u = ds_select.UVEL.rename({'XG': 'XC'}).assign_coords(XC=ds_select['XC'])

aligned_v = ds_select.VVEL.rename({'YG': 'YC'}).assign_coords(YC=ds_select['YC'])

aligned_w = ds_select.WVEL.rename({'Zl': 'Z'}).assign_coords(Z=ds_select['Z'])

In [None]:
# Results are written to an "eof" sub-folder next to the configured data path.
folder_path = os.path.dirname(mitgcm_config['datapath'])
output_folder = os.path.join(folder_path, "eof")
# exist_ok=True makes the cell safe to re-run.
os.makedirs(output_folder, exist_ok=True)

# Multivariate EOF analysis

In [None]:
# Weight shared by both components: sqrt(drF * horizontal_resolution**2),
# i.e. the square root of the cell volume, so that squared weighted velocities
# are volume-weighted.
sqrt_cell_volume = np.sqrt(ds_select.drF * horizontal_resolution**2)

u_weighted = aligned_u * sqrt_cell_volume  # u x sqrt(volume)
v_weighted = aligned_v * sqrt_cell_volume  # v x sqrt(volume)

In [None]:
# Complex velocity u + i*v built from the volume-weighted components.
# NOTE(review): not referenced by any later cell in this notebook — confirm it is still needed.
complex_var = u_weighted + 1j * v_weighted

In [None]:
# Joint EOF solver over the weighted U and V fields, passed as plain numpy
# arrays (so the xarray coordinates are re-attached by hand in the cells below).
solver_multi = MultivariateEof([u_weighted.values, v_weighted.values])

In [None]:
# Principal components as a labelled (time, mode) array.
# The mode axis is sized from the array the solver actually returned instead of
# assuming that there are exactly as many modes as time steps.
pcs_values = solver_multi.pcs()
pcs = xr.DataArray(
    pcs_values,
    coords={"time": ds_select.time, "mode": np.arange(1, pcs_values.shape[1] + 1)},
    dims=["time", "mode"],
    name='PCs')

In [None]:
# EOF patterns as a labelled (var, mode, Z, YC, XC) array; the solver returns
# one array per input variable, stacked here along the leading "var" axis.
# The mode axis is sized from the returned data instead of assuming that there
# are exactly as many modes as time steps.
eof_values = np.asarray(solver_multi.eofs())
eofs = xr.DataArray(
    eof_values,
    coords={
        "var": ["U", "V"],
        "mode": np.arange(1, eof_values.shape[1] + 1),
        "Z": ds_select.Z.values,
        "YC": ds_select.YC.values,
        "XC": ds_select.XC.values,
    },
    dims=["var", "mode", "Z", "YC", "XC"],
    name='eofs')

In [None]:
# Mode number used by the plotting cells below (they index with mode-1).
mode=1
# Positional index of the vertical level to plot.
z_index = 0

In [None]:
# Quick look at the U pattern (var=0) of the selected mode at vertical index 25.
# NOTE(review): the level is hard-coded here rather than taken from z_index — confirm this is intended.
eofs.isel(var=0, mode=mode-1, Z=25).plot(cmap='RdBu_r')

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(12, 15))

# Top panel: principal-component time series of the selected mode.
pcs.isel(mode=mode-1).plot(ax=axs[0])
axs[0].set_title(f'PC - Mode={mode}')

# U and V patterns for all modes and levels.
U_pattern = eofs.isel(var=0)
V_pattern = eofs.isel(var=1)

# Compute horizontal amplitude
amp = np.sqrt(U_pattern**2 + V_pattern**2)

# Horizontal (YC, XC) slices for the selected mode and vertical level.
# Both the mode and the Z index have to be selected by name: positional
# indexing such as amp[z_index, :, :] acts on the leading "mode" axis and
# returns a 3-D array that can be neither imaged nor drawn as arrows.
amp_slice = amp.isel(mode=mode-1, Z=z_index)
U_slice = U_pattern.isel(mode=mode-1, Z=z_index)
V_slice = V_pattern.isel(mode=mode-1, Z=z_index)

# Bottom panel: amplitude map with the (sub-sampled) pattern vectors on top.
vmax = np.abs(amp_slice).max()
im1 = amp_slice.plot(ax=axs[1], cmap='RdBu_r', vmin=-vmax, vmax=vmax)

subsetting_factor=5
# Draw through the Axes method: pyplot's quiver has no `ax` keyword and would
# otherwise target the current axes instead of the bottom panel.
axs[1].quiver(ds_select.XC.values[::subsetting_factor],
              ds_select.YC.values[::subsetting_factor],
              U_slice.values[::subsetting_factor, ::subsetting_factor],
              V_slice.values[::subsetting_factor, ::subsetting_factor],
              scale=1e-5)
axs[1].set_title(f'Amp - Mode={mode}')

#fig.savefig(os.path.join(output_folder, f'eof_mode{mode}_eddies.png'))
# Re-running this cell steps to the next mode.
mode+=1

### Total KE

In [None]:
# Kinetic energy of the volume-weighted velocities in every cell and time step.
# NOTE(review): 0.5 * 1000 / 1e6 presumably is 1/2 * density (1000 kg/m^3) with a
# J -> MJ conversion (the stacked plot further down labels KE in MJ) — confirm.
KE_all = (u_weighted**2+v_weighted**2) * 0.5 * 1000 / 1e6
# Sum over all spatial dimensions, leaving a time series.
KE_all_sum = KE_all.sum(dim=['XC','YC', 'Z'])

### KE per mode

In [None]:
def get_KE_from_mode(ds, solver_multi, mode):
    """Kinetic energy carried by a single EOF mode.

    Parameters
    ----------
    ds : dataset-like
        Provides the ``time``, ``Z``, ``YC`` and ``XC`` coordinates attached to
        the reconstructed fields.
    solver_multi : solver
        Object whose ``reconstructedField`` is called with ``[mode]``; its
        result is interpreted as the (U, V) pair along a leading ``var`` axis.
    mode : int
        Mode number, forwarded as a single-element list.
        NOTE(review): eofs documents iterable mode numbers as 1-based — confirm
        that callers never pass 0.

    Returns
    -------
    tuple
        ``(KE_mode, KE_mode_sum)``: the energy field with dims
        (time, Z, YC, XC) and its sum over XC, YC and Z (a time series).
    """
    uv_fields = xr.DataArray(
        solver_multi.reconstructedField([mode]),
        dims=["var", "time", "Z", "YC", "XC"],
        coords={"var":["U", "V"], "time": ds.time, "Z": ds.Z, "YC": ds.YC,"XC": ds.XC},
        name='Reconstructed data')

    u_part = uv_fields.isel(var=0)
    v_part = uv_fields.isel(var=1)

    # Same scaling as the total-KE cell: 0.5 * 1000 / 1e6.
    KE_mode = (u_part**2 + v_part**2) * 0.5 * 1000 / 1e6

    return KE_mode, KE_mode.sum(dim=['XC','YC', 'Z'])


In [None]:
# Number of modes included in the per-mode KE breakdown below.
nb_modes = 20

In [None]:
# Collect all modes
KE_modes = []

for mode in range(nb_modes):
    _, KE_mode_sum = get_KE_from_mode(ds_select, solver_multi, mode)
    KE_modes.append(KE_mode_sum)

# Combine into one DataArray
KE_modes = xr.concat(KE_modes, dim='mode')

In [None]:
KE_modes['mode'] = range(nb_modes)
df = KE_modes.T.to_pandas()

fig, ax = plt.subplots(figsize=(12,8))
KE_all_sum.to_pandas().plot(color='red', ax=ax)

# Stacked area plot
df.plot.area(figsize=(10, 6), cmap='viridis', ax=ax)
plt.title("Stacked KE from Modes")
plt.ylabel("KE (MJ)")
plt.show()

### KE from a combination of modes

In [None]:
# Number of modes combined in the reconstruction below; the residual
# reconstruction further down also reads this value.
nb_modes = 20

In [None]:
# Field reconstructed from several modes at once, labelled like the input data.
# NOTE(review): per the eofs documentation an integer argument selects the
# leading nb_modes modes — confirm against the installed version.
reconstructed_selected_modes_data = xr.DataArray(
    solver_multi.reconstructedField(nb_modes),
    coords={"var":["U", "V"], "time": ds_select.time, "Z": ds_select.Z, "YC": ds_select.YC,"XC": ds_select.XC},
    dims=["var", "time", "Z", "YC", "XC"],
    name='Reconstructed data')

In [None]:
# Selection settings for the cells below.
var = 0  # index along "var": 0 = u
mode = 1  # mode number; the spectral-analysis cell indexes with mode-1
z_index = 0  # positional index along Z
y_index = 80  # positional index along the y axis
x_index = 150  # positional index along the x axis

In [None]:
# Kinetic energy of the field rebuilt from the selected modes, using the same
# 0.5 * 1000 / 1e6 scaling as the total KE.
u_selected = reconstructed_selected_modes_data.isel(var=0)
v_selected = reconstructed_selected_modes_data.isel(var=1)
KE_selected_modes = (u_selected**2 + v_selected**2) * 0.5 * 1000 / 1e6

# Sum over all spatial dimensions, leaving a time series.
KE_selected_modes_sum = KE_selected_modes.sum(dim=['XC','YC', 'Z'])

In [None]:
# Total KE and the KE of the selected-mode reconstruction, one line each in the
# same cell output.
KE_all_sum.plot()
KE_selected_modes_sum.plot()

In [None]:
# KE_mode1 is not assigned anywhere else in the notebook, so on a fresh kernel
# this cell raised a NameError. Compute it here: the spatially summed KE time
# series of mode 1.
# NOTE(review): the summed series is assumed, by analogy with the residual
# export below — confirm that the full field was not intended.
_, KE_mode1 = get_KE_from_mode(ds_select, solver_multi, 1)
KE_mode1.to_dataframe(name='ke_mode1_mj')['ke_mode1_mj'].reset_index().to_csv(os.path.join(output_folder, "ke_mode1_z50.csv"))

### Comparison with residual energy

In [None]:
# Residual = every mode that is not part of the leading nb_modes.
# Mode numbers given as an iterable are 1-based in eofs, so the last valid mode
# number equals the number of modes; np.arange excludes its stop value, hence
# the +1 (the previous upper bound silently dropped the last mode).
n_modes_total = solver_multi.pcs().shape[1]
reconstructed_residual_data = xr.DataArray(
    solver_multi.reconstructedField(np.arange(nb_modes + 1, n_modes_total + 1)),
    coords={"var":["U", "V"], "time": ds_select.time, "Z": ds_select.Z, "YC": ds_select.YC,"XC": ds_select.XC},
    dims=["var", "time", "Z", "YC", "XC"],
    name='Reconstructed data')

In [None]:
# Kinetic energy of the residual reconstruction, using the same
# 0.5 * 1000 / 1e6 scaling as the total KE.
u_residual = reconstructed_residual_data.isel(var=0)
v_residual = reconstructed_residual_data.isel(var=1)
KE_resisual = (u_residual**2 + v_residual**2) * 0.5 * 1000 / 1e6

# Sum over all spatial dimensions, leaving a time series.
KE_resisual_sum = KE_resisual.sum(dim=['XC','YC', 'Z'])

In [None]:
# Export the residual KE time series as CSV into the output folder.
KE_resisual_sum.to_dataframe(name='ke_res_mj')['ke_res_mj'].reset_index().to_csv(os.path.join(output_folder, "ke_res_z50.csv"))

### Spectral analysis

In [None]:
# Spectrum of the PC of the currently selected mode, via the project helper.
# NOTE(review): M=1 presumably means a single segment (no averaging) — confirm in utils_signal_processing.
u_fft = xr_compute_meanfft(pcs.isel(mode=mode-1), M=1)

In [None]:
# Plot the PC spectrum with the project helper; the axis limits and font size
# are passed through as given.
fig,ax = plot_freq_spectrum(u_fft, 'PC', depth=0, m_segm=1, y_lim_min=1e-3, x_lim_min=0.01e-4, fontsize=10)

# Rotary EOF with complex variable

In [None]:
def rotary_eof_xarray(ds, dx, dy, rho=None, n_modes=None):
    """
    Perform complex (rotary) EOF analysis on xarray Dataset with UVEL, VVEL.

    The velocities are combined into the complex field ``u + i*v``, weighted by
    ``sqrt(rho * dz * dx * dy)`` and decomposed with an SVD. No time mean is
    removed before the decomposition.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain 'UVEL' (dims time, Z, YC, XG), 'VVEL' (dims time, Z, YG, XC),
        the layer thickness 'drF' (one value per Z level) and the coordinates
        'Z', 'YC', 'XC' and 'time'.
    dx, dy : float
        Horizontal grid spacing in meters (assumed uniform)
    rho : float or array-like, optional
        Density [kg/m³], either scalar or 1D array over z.
        Default: 1025 kg/m³ constant.
    n_modes : int, optional
        Number of EOF modes to retain. If None (or larger than the number of
        available modes), keep all.

    Returns
    -------
    result : xarray.Dataset
        Contains complex EOFs (real=u pattern, imag=v pattern) on the dims
        ("mode", "z", "y", "x"), principal components (PCs), modal KE,
        total KE, and explained variance.

    Raises
    ------
    ValueError
        If ``rho`` is neither a scalar nor has exactly one value per level.
    """

    # === 1. Align velocities ===
    U = ds.UVEL.rename({'XG':'XC'})
    U['XC'] = ds['XC']

    V = ds.VVEL.rename({'YG':'YC'})
    V['YC'] = ds['YC']

    # The flattening below relies on this exact axis order, so enforce it
    # instead of depending on how the variables happen to be stored.
    U = U.transpose('time', 'Z', 'YC', 'XC')
    V = V.transpose('time', 'Z', 'YC', 'XC')

    nt, nz, ny, nx = U.shape
    dz = np.asarray(ds.drF.values, dtype=float).reshape(nz)

    # === 2. Prepare weights ===
    # Accept both documented forms of rho: a scalar is expanded to one value
    # per level (a plain reshape of a scalar fails whenever nz != 1), an array
    # must provide exactly nz values.
    rho_levels = np.asarray(1025.0 if rho is None else rho, dtype=float)
    if rho_levels.size == 1:
        rho_levels = np.full(nz, rho_levels.item())
    else:
        rho_levels = rho_levels.reshape(nz)  # ValueError on a size mismatch
    w_z = np.sqrt(rho_levels * dz * dx * dy)  # sqrt(mass per cell), one value per level

    # complex weighted velocities
    w_col = w_z[np.newaxis, :, np.newaxis, np.newaxis]
    Uw = (U.values * w_col).reshape(nt, nz*ny*nx)
    Vw = (V.values * w_col).reshape(nt, nz*ny*nx)
    Z = (Uw + 1j*Vw).T  # shape (nstate, nt)

    # === 3. Complex SVD ===
    Umat, S, VT = np.linalg.svd(Z, full_matrices=False)
    n_total = S.size
    if n_modes is None or n_modes > n_total:
        n_modes = n_total

    # truncate
    phi = Umat[:, :n_modes]
    PCs = (phi.conj().T @ Z).T  # (nt, n_modes)
    modal_ke_ts = 0.5 * np.abs(PCs) ** 2
    total_ke_ts = 0.5 * np.sum(np.abs(Z) ** 2, axis=0)

    # === 4. Undo the weighting of the EOFs ===
    # The state vector is flattened in (z, y, x) order, so each level weight
    # is repeated for the ny*nx cells of its level.
    w_vec = np.repeat(w_z, ny * nx)
    phi_phys = phi / w_vec[:, None]
    EOFs_phys = phi_phys.reshape(nz, ny, nx, n_modes).transpose(3, 0, 1, 2)

    # === 5. Compute explained variance fraction ===
    lam = 0.5 * S**2  # KE per mode
    frac = lam / lam.sum()  # normalised by ALL modes, not only the retained ones

    # === 6. Package into xarray ===
    # The EOF dims are deliberately named "z", "y", "x" (index-based, without
    # coordinate values); the physical Z / YC / XC coordinates are stored next
    # to them for reference.
    coords = {
        "mode": np.arange(1, n_modes + 1),
        "Z": ds.Z,
        "YC": ds.YC,
        "XC": ds.XC,
        "time": ds.time,
    }

    ds_out = xr.Dataset(
        {
            "EOF": (("mode", "z", "y", "x"), EOFs_phys),
            "PC": (("time", "mode"), PCs),
            "modal_KE": (("time", "mode"), modal_ke_ts),
            "KE_total": ("time", total_ke_ts),
            "variance_fraction": ("mode", frac[:n_modes]),
        },
        coords=coords,
        attrs={
            "description": "Complex (rotary) EOF decomposition of UVEL + iVVEL",
            "weighting": "sqrt(rho * dz * dx * dy)",
            # With mass weighting the PCs carry sqrt(kg) * m/s = sqrt(J), the
            # energies are in J, and the un-weighted EOFs are in kg^-1/2.
            "units": {
                "EOF": "kg^-1/2 (complex: real=u, imag=v); EOF * PC is in m/s",
                "PC": "sqrt(J)",
                "modal_KE": "J",
                "KE_total": "J",
            },
        },
    )

    return ds_out


In [None]:
# Use the grid spacing read from the dataset instead of a hard-coded 200 m, so
# the weights stay correct when another model configuration is loaded.
grid_spacing = float(horizontal_resolution)
rotary = rotary_eof_xarray(ds_select, dx=grid_spacing, dy=grid_spacing, rho=None, n_modes=5)

In [None]:
# Zero-based index into the mode dimension: mode = 1 selects the SECOND mode,
# which is why the titles below print mode+1.
mode = 1
EOF = rotary.EOF.isel(mode=mode)

# U and V components
U_pattern = EOF.real
V_pattern = EOF.imag

# Compute horizontal amplitude
amp = np.sqrt(U_pattern**2 + V_pattern**2)

# Plot quiver for a single depth slice
z_slice = 0  # index of the vertical level to plot (0 = first level)
plt.figure(figsize=(15,6))
subsetting_factor=5
plt.imshow(amp[z_slice,:,:])
# rotary.x / rotary.y are index-based, matching the pixel coordinates of imshow.
plt.quiver(rotary.x[::subsetting_factor], rotary.y[::subsetting_factor],
           U_pattern[z_slice,:,:][::subsetting_factor,::subsetting_factor],
           V_pattern[z_slice,:,:][::subsetting_factor,::subsetting_factor],
           scale=1e-5)
plt.gca().invert_yaxis()
plt.title(f'Rotary EOF mode {mode+1}, depth level {z_slice}')
plt.xlabel('X')
plt.ylabel('Y')
plt.colorbar(label='Amplitude')
plt.show()

PC = rotary.PC.isel(mode=mode)
plt.figure(figsize=(10,4))
# The PC is complex: plot its modulus. Passing the complex values directly
# makes matplotlib discard the imaginary part, which is not the amplitude the
# labels announce.
plt.plot(rotary.time, np.abs(PC), label='Amplitude')
plt.ylabel('Amplitude (sqrt(KE))')
plt.xlabel('Time')
plt.title(f'PC amplitude, mode {mode+1}')
plt.grid()
plt.show()


In [None]:
# Zero-based mode index (the titles print mode+1).
mode = 1
PC = rotary.PC.isel(mode=mode)

# One figure per quantity derived from the complex PC: its modulus (amplitude)
# and its argument (phase, i.e. the rotation).
pc_views = (
    (np.abs(PC), 'Amplitude', 'Amplitude (sqrt(KE))', 'PC amplitude'),
    (np.angle(PC), 'Phase', 'Phase (radians)', 'PC phase'),
)
for series, curve_label, y_label, title_prefix in pc_views:
    plt.figure(figsize=(10,4))
    plt.plot(rotary.time, series, label=curve_label)
    plt.ylabel(y_label)
    plt.xlabel('Time')
    plt.title(f'{title_prefix}, mode {mode+1}')
    plt.grid()
    plt.show()


In [None]:
from numpy.fft import fft, fftfreq

# Complex PC of the selected mode and its sampling interval in seconds, taken
# from the first two time stamps.
pc = PC.values
nt = len(rotary.time)
dt = (rotary.time[1] - rotary.time[0]).values / np.timedelta64(1, 's')  # seconds as float

# Power spectrum of the complex series on the signed frequency axis.
freq = fftfreq(nt, d=dt)
spectrum = np.abs(fft(pc))**2

# Compare the energy at positive frequencies (counted as counter-clockwise)
# with the energy at negative frequencies (counted as clockwise).
is_positive = freq > 0
is_negative = freq < 0
ccw_energy = spectrum[is_positive].sum()
cw_energy = spectrum[is_negative].sum()

if ccw_energy > cw_energy:
    rotation = "CCW"
else:
    rotation = "CW"
print(f"Mode {mode+1} is predominantly {rotation} rotating")


In [None]:
# Modal KE time series, one line per retained mode.
plt.figure(figsize=(10,4))
# Use .sizes for the dimension length: mapping-style access through
# Dataset.dims is deprecated in recent xarray releases.
for m in range(rotary.sizes['mode']):
    plt.plot(rotary.time, rotary.modal_KE.isel(mode=m), label=f'Mode {m+1}')
plt.ylabel('Modal KE')
plt.xlabel('Time')
plt.title('Rotary EOF Modal KE')
plt.legend()
plt.grid()
plt.show()

In [None]:
import matplotlib.animation as animation

fig, ax = plt.subplots(figsize=(15,6))

# Vertical level (index) shown in the animation.
z_slice = 5
U = U_pattern[z_slice,:,:]
V = V_pattern[z_slice,:,:]

# Initial arrows: the bare EOF pattern; update() rescales them frame by frame.
Q = ax.quiver(rotary.x, rotary.y, U, V, scale=1)

def update(frame):
    """Redraw the arrows for time step `frame`.

    The contribution of the mode is the complex product EOF * PC: its real
    part is the u component and its imaginary part the v component. Multiplying
    the real and imaginary EOF parts separately by the PC and keeping the real
    part would only use Re(PC) and lose the rotation carried by the PC phase.
    """
    PC_val = PC[frame].values
    mode_field = EOF[z_slice,:,:].values * PC_val
    Q.set_UVC(mode_field.real, mode_field.imag)
    return Q,

ani = animation.FuncAnimation(fig, update, frames=len(rotary.time), blit=True)
plt.show()


# Single variable EOF analysis

In [None]:
# Single-variable EOF solver on UVEL restricted to the first two vertical indices.
# Note that this is the un-renamed UVEL, not aligned_u, so its x dimension is
# still the staggered one.
solver_u = Eof(ds_select.UVEL.isel(Z=[0,1]))

In [None]:
# NOTE: rebinds `pcs`, which until here held the multivariate PCs.
pcs = solver_u.pcs()

In [None]:
# Time series of the first PC (index 0 along "mode").
pcs.isel(mode=0).plot()

In [None]:
# Spectrum of the first single-variable PC, via the project helper.
# NOTE: rebinds `u_fft` from the multivariate section.
u_fft = xr_compute_meanfft(pcs.isel(mode=0), M=1)

In [None]:
# Plot the spectrum with the project helper; the axis limits and font size are
# passed through as given.
fig,ax = plot_freq_spectrum(u_fft, 'U', depth=0, m_segm=1, y_lim_min=1e-12, x_lim_min=0.01e-4, fontsize=10)

In [None]:
# Band-pass the first PC between two cut-off periods given in hours and
# converted to seconds for the helper.
# NOTE(review): dt=3600 presumably is the sampling interval in seconds (hourly output) — confirm.
cutoff1_hr = 92
cutoff2_hr = 39
useiche = filter_signal_xarray(pcs.isel(mode=0), btype='bandpass', time_dim='time', dt=3600, period_cutoff_low=(cutoff1_hr*3600), period_cutoff_high=(cutoff2_hr*3600), order=5)

In [None]:
# Band-passed PC time series.
useiche.plot()

In [None]:
# NOTE: rebinds `eofs`, which until here held the multivariate EOF patterns.
eofs = solver_u.eofs()

In [None]:
# Zero-based mode index for the plot cell below.
mode=0

In [None]:
# PC time series (left) and EOF pattern at the first vertical index (right) of
# the current mode.
fig,axs = plt.subplots(1,2, figsize=(20,7))
pcs.isel(mode=mode).plot(ax=axs[0])
eofs.isel(mode=mode, Z=0).plot(ax=axs[1])
# Re-running this cell steps to the next mode (the cell is therefore not idempotent).
mode += 1

In [None]:
# Field reconstructed by the single-variable solver with argument 1.
# NOTE(review): per the eofs documentation an integer selects the leading modes,
# i.e. only the first mode here — confirm.
reconstructed_data = solver_u.reconstructedField(1)

In [None]:
# Compare, at one grid point, the raw U velocity with the mode-0 contribution
# (band-passed and unfiltered).
# UVEL - and the EOFs / reconstruction computed from it by solver_u - is defined
# on the staggered XG dimension (it is only renamed to XC for `aligned_u`), so
# the point has to be selected with XG; selecting with XC fails because that
# dimension does not exist on these arrays.
point = dict(Z=0, XG=150, YC=80)

plt.figure(figsize=(20,10))
ds_select.UVEL.isel(**point).plot(color='r', label='U')
(eofs.isel(mode=0, **point) * useiche).plot(color='green', label='filtered eof mode 0')
reconstructed_data.isel(**point).plot(color='blue', label='eof mode 0')
plt.legend()
plt.show()