In [68]:
import os
import torch
from torch.utils.data import Dataset as TorchDataset
import pandas as pd
import numpy as np
from sklearn import preprocessing
import random
import mne

from scipy.signal import butter, lfilter
from scipy.signal import iirnotch, butter, filtfilt
from sklearn.preprocessing import MinMaxScaler

## Dataset

In [75]:
class EEGMotorMovementDataset(TorchDataset):
    """
    Basic EMG-EPN612 Dataset: loads data from files
    """
    #max-length 664 before
    def __init__(self, grouped_df, fs=160, window_size=0.25, overlap=0.0, max_samples=656, min_samples=476, window_mode="rms", scaling=True):
        """
        @param grouped_df: dataset df grouped by File Path
        @param fs: frequency sample rate in Hz
        @param window_size: length of window in ms
        @param overlap: overlap of windows in ms
        @param max_samples: max sample length in dataset (599 for EMG-EPN612)
        """  
        self.fs = fs
        self.window_mode = window_mode
        self.window_samples = int(window_size * fs)
        self.step_samples = int(self.window_samples * (1 - overlap))
        self.max_samples = max_samples
        self.min_samples = min_samples
        self.max_windows = (self.max_samples - self.window_samples) // self.step_samples + 1
        self.scaling = scaling

        print("max samples", self.max_samples)
        print("max windows", self.max_windows)
        labels = []
        locations = []
        start_indices = []
        end_indices = []

        all_data = []
        for folder_path, group in grouped_df:
            for label in group["Label"].unique():
                label_group = group[group["Label"] == label]
                # sampled = label_group.sample(num_reps, replace=False, random_state=42)
                # Extend the lists with the sampled data
                labels.extend(label_group["Label"].values)
                locations.extend(label_group["File_Path"].values)
                start_indices.extend(label_group["Start_Index"].values)
                end_indices.extend(label_group["End_Index"].values)
                if scaling:
                    for _, row in label_group.iterrows():
                        file_path = row["File_Path"]
                        start_idx = row["Start_Index"]
                        end_idx = row["End_Index"]
                        eeg_data = self.get_eeg_data(file_path, start_idx, end_idx, str(file_path) + '.csv', True)
                        # print("EEG DATA SHAPE")
                        # print(eeg_data.shape) 
                        all_data.append(eeg_data)
        if scaling:
            print("ALL DATA SHAPE")
            print(len(all_data)) 
            all_data = np.stack(all_data)  
            print(all_data.shape) #(num items, num channels, samples)
            reshaped_x = all_data.reshape(all_data.shape[0], all_data.shape[1] * all_data.shape[2])
            print(reshaped_x.shape)
            self.scaler = MinMaxScaler()
            self.scaler.fit(reshaped_x)  
            print(f"Scaler fitted: Min={self.scaler.data_min_}, Max={self.scaler.data_max_}")
            del all_data

        labels = np.array(labels)
        locations = np.array(locations)
        start_indices = np.array(start_indices)
        end_indices = np.array(end_indices)

        # Encode labels into integer values, print unique counts
        le = preprocessing.LabelEncoder()
        labels = torch.from_numpy(le.fit_transform(labels.reshape(-1)))
        self.print_unique_labels(labels)   

        # save all data into flat array for later use
        self.all_gestures = self.create_flat_array(labels, locations, start_indices, end_indices)
        print(self.all_gestures.shape)


    def print_unique_labels(self, labels):
        unique_labels, counts = torch.unique(labels, return_counts=True)
        print("Unique labels and counts:")
        for label, count in zip(unique_labels.tolist(), counts.tolist()):
            print(f"Label {label}: {count} occurrences")

    def create_flat_array(self, labels, locations, start_indices, end_indices):
        flat_array = []
        label_idx = 0
        for file_path in locations:
            flat_array.append([file_path, labels[label_idx], start_indices[label_idx], end_indices[label_idx]])
            label_idx += 1
        return np.array(flat_array, dtype=object)    

    def get_eeg_data(self, file_path, start_idx, end_idx, output_file, padding=False):
        # convert file path to h5 instead of json
        raw = mne.io.read_raw_edf(file_path, preload=False, verbose=False)
        raw_data_segment, times = raw[:, start_idx:end_idx]

        
        #eeg_data = self.normalize_and_group_channels(raw_data_segment, raw.info['ch_names'])
        eeg_data = np.array(raw_data_segment)
        # print(raw.info['ch_names'])
        #eeg_data = self.combine_channels(eeg_data, raw.info['ch_names'])
        # print(eeg_data.shape)
        # print(eeg_data.dtype)
        #eeg_data = eeg_data.T
        #print(eeg_data.shape)
        # header = ",".join([f"Channel{i+1}" for i in range(eeg_data.shape[0])])
        # np.savetxt(output_file, eeg_data, delimiter=',', header=header, comments='')
        def apply_notch_filter(data, freq, fs, quality_factor=30):
            notch_freq = freq / (fs / 2)  # Normalized frequency
            b, a = iirnotch(notch_freq, quality_factor)
            return lfilter(b, a, data, axis=1)

        # Define Butterworth band-pass filter
        def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
            nyquist = 0.5 * fs
            low = lowcut / nyquist
            high = highcut / nyquist
            b, a = butter(order, [low, high], btype='band')
            return lfilter(b, a, data, axis=1)

        eeg_data = apply_notch_filter(eeg_data, freq=60, fs=160)

        # Apply Butterworth band-pass filter
        eeg_data = butter_bandpass_filter(eeg_data, lowcut=2, highcut=60, fs=160)


        if padding:
            num_samples = eeg_data.shape[1]
            pad_size = self.max_samples - num_samples
            if pad_size > 0:
                padding = np.zeros((eeg_data.shape[0], pad_size))
                eeg_data = np.hstack((eeg_data, padding))
            eeg_data = eeg_data[:, :self.max_samples]
        return eeg_data 
    
    def get_high_pass_filtered_data(self, emg_data):
        num_samples = emg_data.shape[1]
        pad_size = self.max_samples - num_samples
        if pad_size > 0:
            padding = np.zeros((emg_data.shape[0], pad_size))
            emg_data = np.hstack((emg_data, padding))
        # emg_data = self.normalize_eeg(emg_data, local_min, local_max, -100.0, 100.0)
        # Ensure emg_data has contiguous memory layout
        emg_data = emg_data.copy()
        # for channel_idx in range(emg_data.shape[0]):
        #     emg_data[channel_idx] = self.highpass_filter(
        #         emg_data[channel_idx], 20, self.fs, 5
        #     )
        # return windows tensor of consistent shape and true number of windows (without padding) 
        # print(emg_data.shape)
        return torch.tensor(emg_data, dtype=torch.float32), num_samples

    def __getitem__(self, index):
        # Step 1: Fetch file path, gesture index, label and location
        file_path, label, start_idx, end_idx = self.all_gestures[index]
        # Step 2: Load EMG data from h5 file
        emg_data = self.get_eeg_data(file_path, start_idx, end_idx, str(index) + '.csv', True)
        # if self.scaling:
        #     eeg_data_scaled_flat = self.scaler.transform(emg_data.reshape(1, -1))
        #     # print(eeg_data_scaled_flat.shape)
        #     eeg_data_scaled = eeg_data_scaled_flat.reshape(emg_data.shape)  # Shape: (channels, samples)
        #     emg_data = eeg_data_scaled
        # Step 3: Create RMS windows and get original length


        # print("highpass")
        windows, original_length = self.get_high_pass_filtered_data(emg_data)


        return windows, original_length, label

    def __len__(self):
        return len(self.all_gestures)

## Load Dataset

In [76]:
from torch.utils.data import DataLoader

DATASET_DIR = 'raw_data/eeg-motor-movement/'
dataset_config = {
    "meta_csv": os.path.join(DATASET_DIR, "eeg-motor-movement-metadata.csv"),
}

meta_csv = dataset_config['meta_csv']
df = pd.read_csv(meta_csv)
df['Folder_Path'] = df['File_Path'].apply(lambda x: os.path.dirname(x))
grouped = df.groupby('Folder_Path')
group_keys = list(grouped.groups.keys())  # Get the group keys (folder paths)

ds_testing = EEGMotorMovementDataset(grouped)

test_dl = DataLoader(dataset=ds_testing,
                         batch_size=32)

max samples 656
max windows 16
ALL DATA SHAPE
6489
(6489, 64, 656)
(6489, 41984)
Scaler fitted: Min=[-9.36448330e-05 -3.11927451e-04 -3.30637775e-04 ... -1.58537296e-04
 -1.33784116e-04 -1.99436679e-04], Max=[0.00012621 0.00041959 0.00043925 ... 0.00019207 0.00021472 0.00024047]
Unique labels and counts:
Label 0: 2163 occurrences
Label 1: 2163 occurrences
Label 2: 2163 occurrences
(6489, 4)


## Channel selection

In [77]:
from scipy.stats import entropy

def compute_channel_entropy(data_loader):
    channel_entropies = []
    total_trials = 0
    
    # Iterate through the DataLoader
    for windows, original_length, label in data_loader:
        # Assuming `windows` has shape (batch_size, channels, samples)
        print(windows.shape)
        for trial in windows:  # Iterate over trials in the batch
            trial_entropies = []
            for channel_data in trial:  # Iterate over channels in the trial
                # Compute entropy for the channel
                hist, _ = np.histogram(channel_data.numpy(), bins=50, density=True)
                channel_entropy = entropy(hist)
                trial_entropies.append(channel_entropy)
            
            # Accumulate trial entropies
            if len(channel_entropies) == 0:
                channel_entropies = np.array(trial_entropies)
            else:
                channel_entropies += np.array(trial_entropies)
            total_trials += 1
    
    # Average the entropy values across all trials
    average_entropies = channel_entropies / total_trials
    return average_entropies

In [78]:
average_entropies = compute_channel_entropy(test_dl)

torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([32, 64, 656])
torch.Size([

In [79]:
channel_indices = np.argsort(average_entropies)[::-1]  # Descending order
top_8_channels = channel_indices[:8]  # Select top 8
print("Top 8 Channels (Indices):", top_8_channels)

Top 8 Channels (Indices): [10 17 18 11 16  9 50 51]


[17 14 19 50 12 15 20 52] --> scaled, rest included, fist runs
[51 52 18 20 50 58 17 53] --> unscaled, rest included, fist runs
[17 50 14 57 13  8 53 63] --> scaled, rest excluded, fist runs
[18 20 52 51 17 50 58 53] --> unscaled, rest excluded, fist runs
[53 63 11 20 16 54 50 13] --> scaled, rest excluded, imagined runs
[18 20 17 52 51 19 16 53] --> unscaled, rest excluded, imagined runs
[16 17 50 63 15 13 11 12] --> scaled, rest included, imagined runs
[41 63 13 40 51 54 61 14] --> scaled, rest excludedm real runs
[52 51 20 18 50 58 53 57] --> unscaled, rest excluded, real runs
[13 54 17 52 19 41  0 51] --> scaled, rest included, real runs


[ 7  9 26 23 62 17 21 43] --> scaled, rest included, imagined runs, filtered
[10 17 18 11 16  9 50 51] --> unscaled, rest included, imagined runs, filtered

In [80]:
average_entropies

array([3.36381565, 3.38061149, 3.39254282, 3.40340247, 3.393695  ,
       3.38290881, 3.36813049, 3.38165874, 3.39606927, 3.40766729,
       3.41598811, 3.41025329, 3.40360521, 3.38955349, 3.38283575,
       3.39965336, 3.40877822, 3.41181552, 3.41134504, 3.40272002,
       3.3979638 , 3.2116247 , 3.20088124, 3.20579235, 3.24310857,
       3.27684753, 3.3051548 , 3.28284909, 3.24451821, 3.29922404,
       3.32200044, 3.34409029, 3.35282105, 3.36277782, 3.34605726,
       3.34038397, 3.31804241, 3.28901222, 3.35298151, 3.35222523,
       3.36760301, 3.38404772, 3.32696348, 3.34909866, 3.35757745,
       3.38550132, 3.34955128, 3.37439484, 3.3891997 , 3.39881043,
       3.40595092, 3.40492887, 3.40171006, 3.39132499, 3.37721402,
       3.35582865, 3.37796192, 3.39540177, 3.39131519, 3.37176698,
       3.35024617, 3.35338831, 3.36084565, 3.28205678])