# In [None]:
# Compute the Empirical Wavelet Transform (EWT) of each of the 8 EMG channels and plot up to 5 IMFs per channel

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
from scipy.fft import fft, fftfreq

def empirical_wavelet_transform(signal, fs=996, n_imfs=5, peak_distance=50):
    """
    Decompose a 1-D signal into band-limited components with a simplified
    Empirical Wavelet Transform (EWT).

    The positive half of the magnitude spectrum is searched for peaks that
    rise above its mean. At most ``n_imfs`` of the most prominent peaks are
    kept, and the band edges are placed midway between consecutive kept
    peaks, so each dominant spectral mode falls inside a single band (rather
    than being split at its own peak). If fewer than two peaks are found,
    the range [0, fs/2] is split into ``n_imfs`` equal-width bands instead.

    Each band is extracted with an ideal (brick-wall) FFT mask applied to
    both positive and negative frequencies; i.e. this is a Fourier-domain
    segmentation, not the Meyer-type wavelet filter bank of the original
    EWT formulation.

    Parameters
    ----------
    signal : array_like
        1-D input signal.
    fs : float, optional
        Sampling rate in Hz (default 996).
    n_imfs : int, optional
        Maximum number of components to return (default 5).
    peak_distance : int, optional
        Minimum spacing between detected spectral peaks, in FFT bins,
        forwarded to ``scipy.signal.find_peaks(distance=...)`` (default 50).
        The bin width is ``fs / len(signal)``, so the same value means a
        different spacing in Hz for signals of different lengths.

    Returns
    -------
    list of numpy.ndarray
        At most ``n_imfs`` real-valued components, each the same length as
        ``signal``, ordered from the highest-frequency band to the lowest.
        Every FFT bin (including DC and, for even lengths, the Nyquist bin)
        belongs to exactly one band, so for a real input the components sum
        back to ``signal`` up to floating-point error.

    Raises
    ------
    ValueError
        If ``signal`` has fewer than 2 samples or ``n_imfs`` is less than 1.
    """
    signal = np.asarray(signal)
    n = len(signal)
    if n < 2:
        raise ValueError("signal must contain at least 2 samples")
    if n_imfs < 1:
        raise ValueError("n_imfs must be at least 1")

    # Compute Fourier spectrum
    fft_vals = fft(signal)
    freqs = fftfreq(n, 1/fs)
    abs_freqs = np.abs(freqs)

    # Keep only positive frequencies for peak detection
    pos_mask = freqs > 0
    pos_freq = freqs[pos_mask]
    pos_fft = np.abs(fft_vals[pos_mask])

    # A local maximum needs at least 3 samples; shorter spectra have no peaks.
    if pos_fft.size >= 3:
        peaks, _ = find_peaks(pos_fft, height=np.mean(pos_fft),
                              distance=peak_distance)
    else:
        peaks = np.array([], dtype=int)

    if len(peaks) < 2:
        # Not enough spectral structure: fall back to equal-width bands.
        interior = np.linspace(0, fs / 2, n_imfs + 1)[1:-1]
    else:
        if len(peaks) > n_imfs:
            # Keep the most prominent peaks (ties -> lower frequency first),
            # then restore ascending-frequency order.
            strongest = np.argsort(-pos_fft[peaks], kind="stable")[:n_imfs]
            peaks = np.sort(peaks[strongest])
        peak_freqs = pos_freq[peaks]
        # Band edges midway between consecutive peaks.
        interior = (peak_freqs[:-1] + peak_freqs[1:]) / 2

    # Assign every FFT bin (both signs) to exactly one band. Bands are
    # closed below / open above; the lowest band implicitly starts at 0 and
    # the highest extends to the largest |f| (the Nyquist bin included),
    # so no bin is lost to floating-point rounding of fs/2.
    band_of_bin = np.searchsorted(interior, abs_freqs, side="right")
    n_bands = len(interior) + 1

    # Extract IMFs from high to low frequency
    imfs = []
    for band in range(n_bands - 1, -1, -1):
        filtered_fft = np.where(band_of_bin == band, fft_vals, 0)
        # Inverse FFT to get IMF (imaginary part is round-off for real input)
        imfs.append(np.real(np.fft.ifft(filtered_fft)))

    return imfs

# --- Assuming preprocessed_emg is already loaded ---
# preprocessed_emg: list of per-channel EMG signals (8 channels expected).
# NOTE(review): assumes channels-first indexing, i.e. preprocessed_emg[ch] is
# one 1-D signal sampled at fs — confirm against the loading cell.
fs = 996  # Sampling rate (Hz)
n_imfs = 5  # Maximum number of IMFs to extract and plot per channel
SHOW_SPECTRA = True  # Set to False to skip the per-channel IMF spectrum figures


def plot_spectra(signal, imfs, fs, title_suffix=''):
    """
    Plot the one-sided magnitude spectrum of a signal and of each of its IMFs.

    Useful for checking that the IMFs are ordered from high to low frequency.

    Parameters
    ----------
    signal : array_like
        Original 1-D signal.
    imfs : sequence of array_like
        Components of ``signal``; each must have the same length as ``signal``.
    fs : float
        Sampling rate in Hz.
    title_suffix : str, optional
        Text appended to every subplot title (e.g. the channel label).
    """
    n = len(signal)
    half = n // 2
    freqs = fftfreq(n, 1/fs)[:half]
    n_rows = len(imfs) + 1  # original spectrum + one row per IMF

    plt.figure(figsize=(12, 8))

    # Original signal spectrum
    plt.subplot(n_rows, 1, 1)
    plt.plot(freqs, 2.0/n * np.abs(fft(signal)[:half]))
    plt.title(f'Original Signal Spectrum{title_suffix}')
    plt.grid(True)

    # IMF spectra
    for i, imf in enumerate(imfs):
        plt.subplot(n_rows, 1, i + 2)
        plt.plot(freqs, 2.0/n * np.abs(fft(imf)[:half]))
        plt.title(f'IMF {i+1} Spectrum{title_suffix}')
        plt.grid(True)

    plt.xlabel('Frequency (Hz)')
    plt.tight_layout()
    plt.show()


# Iterate over all channels and extract IMFs
for channel_idx, signal in enumerate(preprocessed_emg):
    # Time vector built per channel, so a channel whose length differs from
    # channel 0 is still plotted against its own sample times.
    time = np.arange(len(signal)) / fs

    # Decompose the channel; IMFs are returned ordered high -> low frequency
    imfs = empirical_wavelet_transform(signal, fs=fs, n_imfs=n_imfs)

    # One row for the original signal plus one per IMF actually returned,
    # capped at n_imfs (the decomposition may yield fewer bands than asked).
    shown_imfs = imfs[:n_imfs]
    num_subplots = len(shown_imfs) + 1

    # Visualization of the original signal and IMFs
    plt.figure(figsize=(15, 10))

    plt.subplot(num_subplots, 1, 1)
    plt.plot(time, signal, 'b')
    plt.title(f'Original EMG (Channel {channel_idx+1})')
    plt.ylabel('Amplitude')
    plt.grid(True)

    for i, imf in enumerate(shown_imfs):
        plt.subplot(num_subplots, 1, i + 2)
        plt.plot(time, imf, 'r')
        plt.title(f'IMF {i+1}')
        plt.ylabel('Amplitude')
        plt.grid(True)

    plt.xlabel('Time (s)')
    plt.tight_layout()
    plt.show()

    # Optional: spectra of the IMFs to verify the high-to-low frequency ordering
    if SHOW_SPECTRA:
        plot_spectra(signal, imfs, fs,
                     title_suffix=f' (Channel {channel_idx+1})')
