In [None]:
import numpy as np
from scipy import signal

# Preprocess EMG: Bandpass filtering (20-450 Hz), and normalization (min-max)
def preprocess_emg(emg_data, sampling_rate, lowcut=20.0, highcut=450.0, order=4):
    """Band-pass filter and min-max normalize multi-channel EMG.

    Each channel is filtered with a zero-phase Butterworth band-pass
    (``scipy.signal.filtfilt``) and then rescaled independently to [0, 1].

    Parameters
    ----------
    emg_data : array_like, shape (n_channels, n_samples)
        Raw EMG, one row per channel (e.g. 8 channels); time runs along the
        last axis. A 1-D array is treated as a single channel. Lists are
        accepted and converted with ``np.asarray``.
    sampling_rate : float
        Sampling frequency in Hz. Must be greater than ``2 * highcut``.
    lowcut, highcut : float, optional
        Pass-band edges in Hz (default 20-450 Hz).
    order : int, optional
        Butterworth design order passed to ``signal.butter`` (default 4).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``emg_data``; each channel spans [0, 1].

    Raises
    ------
    ValueError
        If the band edges do not satisfy
        ``0 < lowcut < highcut < sampling_rate / 2``, or (raised by
        ``filtfilt``) if the signal is too short for its edge padding.
    """
    emg = np.asarray(emg_data)

    nyquist = 0.5 * sampling_rate
    # Validate up front: otherwise signal.butter fails with an opaque
    # "critical frequencies must be 0 < Wn < 1" message.
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Band-pass edges must satisfy 0 < lowcut < highcut < Nyquist; "
            f"got lowcut={lowcut}, highcut={highcut}, Nyquist={nyquist} "
            f"(sampling_rate={sampling_rate})"
        )
    b, a = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')

    # Zero-phase filtering directly along the time (last) axis, no transpose needed.
    filtered = signal.filtfilt(b, a, emg, axis=-1)

    # Per-channel min-max scaling; the epsilon guards against a flat channel.
    ch_min = filtered.min(axis=-1, keepdims=True)
    ch_max = filtered.max(axis=-1, keepdims=True)
    return (filtered - ch_min) / (ch_max - ch_min + 1e-8)

# Smoke test on the first recording in ``emg_data`` (defined elsewhere, not
# visible here); presumably 8 channels x 992 time points -- TODO confirm.
sample_emg = np.array(emg_data[0])
# NOTE(review): this uses the number of samples as the sampling rate in Hz,
# which is only correct if each recording spans exactly one second. Replace it
# with the acquisition device's real rate if known. The default 20-450 Hz band
# in preprocess_emg requires a rate above 900 Hz.
sampling_rate = sample_emg.shape[1]
print(sampling_rate, sample_emg[:, :5])
preprocessed_emg = preprocess_emg(sample_emg, sampling_rate)

# Print results
print("Preprocessed shape:", preprocessed_emg.shape)
print(preprocessed_emg[:, :5])  # First 5 timepoints per channel
