In [1]:
import os
import torch
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader
from scipy.stats import entropy
from glob import glob
import numpy as np
import pyedflib
from helpers.utils import preprocess_data

# Root folder of the EEG Motor Movement/Imagery recordings (one S### sub-folder per subject).
DATASET_DIR = 'raw_data/eeg-motor-movement/'
# Paths to auxiliary dataset files.
dataset_config = {"meta_csv": os.path.join(DATASET_DIR, "eeg-motor-movement-metadata.csv")}

## EEG Dataset

In [2]:
def load_data(nr_of_subj=109, trial_type=1, chunk_data=True, chunks=8, base_folder=DATASET_DIR, sample_rate=160,
              samples=640, cpu_format=False, preprocessing=False, hp_freq=0.5, bp_low=2, bp_high=60, notch=False,
              hp_filter=False, bp_filter=False, artifact_removal=False, rms_feature=False):
    """Load imagined-movement events from the EEG Motor Movement/Imagery EDF recordings.

    Subject folders ``S###`` under ``base_folder`` are sorted, the first ``nr_of_subj`` are
    taken, and a fixed set of badly annotated subjects is then dropped (so fewer than
    ``nr_of_subj`` subjects may be loaded). For each subject the runs ``R04``, ``R08`` and
    ``R12`` are read. From each run, the first 15 non-rest events (``T1``/``T2``) are cut
    into windows of ``samples`` samples on every channel. Window positions are found by
    accumulating the annotation durations from the start of the recording.

    Args:
        nr_of_subj: number of subject folders taken before the exclusions are applied.
        trial_type: currently ignored; the imagined runs are always used.
        chunk_data: split every event window into ``chunks`` equal sub-windows. Each
            sub-window becomes its own example and carries the event's label.
        chunks: number of sub-windows per event when ``chunk_data`` is set.
        base_folder: dataset root that contains the ``S###`` subject folders.
        sample_rate: samples per second, used to turn annotation durations into offsets.
        samples: number of samples per event window.
        cpu_format: lay examples out as (time, channels) instead of (channels, time).
        preprocessing: run ``preprocess_data`` on every channel window. ``hp_freq``,
            ``bp_low``, ``bp_high``, ``notch``, ``hp_filter``, ``bp_filter`` and
            ``artifact_removal`` are forwarded to it.
        rms_feature: replace each chunk by its RMS value. Only works with
            ``chunk_data=True`` and ``cpu_format=False``.

    Returns:
        X: stacked examples with these shapes:
            - (N, 64, int(samples / chunks)) when chunked;
            - (N, 64) for chunked RMS features;
            - (N, 64, samples) otherwise.
            The time and channel axes are swapped when ``cpu_format`` is set.
        y: integer labels of shape (N, 1): 0 for ``T1``, 1 for ``T2``.

    Raises:
        ValueError: if no event windows were found (e.g. wrong ``base_folder``), or if an
            event carries an unexpected label.
    """

    def convert_label_to_int(label):
        """Map an event annotation to its class index: 'T1' -> 0, 'T2' -> 1."""
        if label == 'T1':
            return 0
        if label == 'T2':
            return 1
        raise ValueError("Invalid label %s" % label)

    def divide_chunks(data, chunk_size):
        """Yield consecutive slices of ``data`` holding ``chunk_size`` elements each."""
        for start in range(0, len(data), chunk_size):
            yield data[start:start + chunk_size]

    def read_run(file_name):
        """Return (onsets, durations, descriptions, signal buffer) of one EDF file.

        The reader is closed even if reading fails, so handles are not leaked across
        the many runs that are loaded.
        """
        reader = pyedflib.EdfReader(file_name)
        try:
            annotations = reader.readAnnotations()
            n_signals = reader.signals_in_file
            buffers = np.zeros((n_signals, reader.getNSamples()[0]))
            for channel in range(n_signals):
                buffers[channel, :] = reader.readSignal(channel)
        finally:
            reader.close()
        return annotations[0], annotations[1], annotations[2], buffers

    # Subjects with incorrectly annotated data that are left out of the final dataset.
    excluded_subjects = {'S038', 'S088', 'S089', 'S092', 'S100', 'S104'}

    subject_dirs = glob(os.path.join(base_folder, 'S[0-9]*'))
    subject_names = sorted(d[-4:] for d in subject_dirs)[:nr_of_subj]
    # Filter rather than list.remove() inside one try/except. The removal loop stopped at the
    # first excluded subject that was absent, which silently kept any later ones.
    subject_names = [s for s in subject_names if s not in excluded_subjects]

    # Only the imagined-movement runs are loaded. The executed runs would be '03', '07' and
    # '11', but ``trial_type`` is not wired up to select them.
    file_numbers = ['04', '08', '12']
    samples_per_chunk = int(samples / chunks)
    events_per_run = 15  # non-rest events kept from every run
    n_channels = 64

    X = []
    y = []

    for subj in subject_names:
        # sorted(): glob order depends on the filesystem. A fixed order keeps downstream
        # seeded splits reproducible across machines.
        run_files = sorted(glob(os.path.join(base_folder, subj, subj + 'R*.edf')))
        run_files = [name for name in run_files if name[-6:-4] in file_numbers]

        for file_name in run_files:
            times, durations, tasks, sigbufs = read_run(file_name)

            # Pre-allocate the per-run result in the layout matching the requested mode.
            if chunk_data and not rms_feature:
                trial_data = np.zeros((events_per_run, n_channels, chunks, samples_per_chunk))
            elif chunk_data and rms_feature:
                trial_data = np.zeros((events_per_run, n_channels, chunks))
            else:
                trial_data = np.zeros((events_per_run, n_channels, samples))
            labels = []

            signal_start = 0
            k = 0  # number of non-rest events stored so far
            for i in range(len(times)):
                if k == events_per_run:
                    break

                # round() instead of int(): the float product can fall a hair below the
                # intended integer, and truncation would shift every later window.
                event_samples = int(round(sample_rate * durations[i]))

                # Rest events only advance the cursor.
                if tasks[i] == 'T0':
                    signal_start += event_samples
                    continue

                signal_end = signal_start + samples
                for j in range(len(sigbufs)):
                    channel_data = sigbufs[j][signal_start:signal_end]
                    if preprocessing:
                        channel_data = preprocess_data(channel_data, sample_rate=sample_rate, ac_freq=60,
                                                       hp_freq=hp_freq, bp_low=bp_low, bp_high=bp_high, notch=notch,
                                                       hp_filter=hp_filter, bp_filter=bp_filter,
                                                       artifact_removal=artifact_removal)
                    if chunk_data:
                        channel_data = list(divide_chunks(channel_data, samples_per_chunk))
                    if rms_feature:
                        channel_data = np.sqrt(np.mean(np.square(channel_data), axis=1))
                    trial_data[k][j] = channel_data

                # One label per example: chunked mode yields `chunks` examples per event.
                label = convert_label_to_int(tasks[i])
                labels.extend([label] * (chunks if chunk_data else 1))

                signal_start += event_samples
                k += 1

            # Keep only the filled rows, so X stays aligned with y when a run has fewer than
            # `events_per_run` non-rest events. The unused rows are all zeros and have no labels.
            trial_data = trial_data[:k]
            y.extend(labels)

            if cpu_format:
                if chunk_data:
                    # (k, 64, chunks, spc) -> (k, 64, spc, chunks) -> (k, chunks, spc, 64) -> (k*chunks, spc, 64)
                    X.extend(trial_data.swapaxes(2, 3).swapaxes(1, 3).reshape((-1, samples_per_chunk, n_channels)))
                else:
                    # (k, 64, samples) -> (k, samples, 64)
                    X.extend(trial_data.swapaxes(1, 2))
            elif chunk_data and not rms_feature:
                # (k, 64, chunks, spc) -> (k, chunks, 64, spc) -> (k*chunks, 64, spc)
                X.extend(trial_data.swapaxes(1, 2).reshape((-1, n_channels, samples_per_chunk)))
            elif chunk_data and rms_feature:
                # (k, 64, chunks) -> (k, chunks, 64) -> (k*chunks, 64)
                X.extend(trial_data.swapaxes(1, 2).reshape((-1, n_channels)))
            else:
                # (k, 64, samples)
                X.extend(trial_data)

    if not X:
        raise ValueError("No event windows found under %r" % base_folder)

    X = np.stack(X)
    y = np.array(y).reshape((-1, 1))

    return X, y

class SequenceDataset(TorchDataset):
    """Map-style dataset yielding ``(x, length, label)`` triples for EEG examples.

    Each item of ``X`` is converted to a float32 tensor. Each item of ``y`` is converted to a
    long tensor and squeezed to a scalar. With ``rms_feature`` the example is replaced by RMS
    values over consecutive windows (see ``get_windows_rms``).

    Args:
        X: indexable collection of examples; assumed (N, channels, samples) -- TODO confirm
            for every caller.
        y: indexable labels, e.g. shape (N, 1) as produced by ``load_data``.
        rms_feature: if True, return windowed RMS features instead of the raw signal.
        window_samples: samples per RMS window (default 4).
        num_windows: number of RMS windows taken from the start of each example (default
            160, i.e. 640 samples with the default window size).
    """

    def __init__(self, X, y, rms_feature=False, window_samples=4, num_windows=160):
        self.X = X
        self.y = y
        self.rms_feature = rms_feature
        self.window_samples = window_samples
        self.num_windows = num_windows

    def get_windows_rms(self, eeg_data):
        """Compute the RMS of each channel over consecutive, non-overlapping windows.

        Args:
            eeg_data: 2-D tensor (channels, samples) with at least
                ``num_windows * window_samples`` samples. Trailing samples are ignored.

        Returns:
            (features, num_windows): a float32 tensor of shape (num_windows, channels)
            and the number of windows.

        Raises:
            ValueError: if the example is too short for the requested windows. Previously
                this silently produced NaN or partial windows.
        """
        data = eeg_data.cpu().numpy()  # convert once instead of once per window
        n_channels, n_samples = data.shape
        needed = self.num_windows * self.window_samples
        if n_samples < needed:
            raise ValueError("Example has %d samples but %d windows of %d samples need %d"
                             % (n_samples, self.num_windows, self.window_samples, needed))
        windows = data[:, :needed].reshape(n_channels, self.num_windows, self.window_samples)
        rms = np.sqrt(np.mean(np.square(windows), axis=2))  # (channels, num_windows)
        return torch.tensor(rms.T, dtype=torch.float32), self.num_windows

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(x, length, label)`` for example ``idx``.

        ``x`` is (channels, samples) for raw examples, with ``length = samples``. For RMS
        features it is (channels, num_windows), with ``length = num_windows``.
        """
        x_data = torch.tensor(self.X[idx], dtype=torch.float32)
        # Make sure y is a single integer (no extra dim).
        y_data = torch.tensor(self.y[idx], dtype=torch.long).squeeze()

        if self.rms_feature:
            features, length = self.get_windows_rms(x_data)
            x_data = features.T  # back to (channels, num_windows)
        else:
            length = x_data.shape[1]

        return x_data, length, y_data

def get_train_val_test_split(train_ratio=0.8, val_ratio=0.1, shuffle=True, random_seed=42, rms_feature=False,
                             return_splits=False):
    """Load the EEG data with ``load_data`` and wrap it in ``SequenceDataset`` objects.

    ``load_data`` is called with fixed settings:
    - without ``rms_feature``: 8 chunks per event, with notch, band-pass and artifact-removal
      preprocessing;
    - with ``rms_feature``: unchunked and unpreprocessed windows.

    Args:
        train_ratio: fraction of examples in the training split.
        val_ratio: fraction of examples in the validation split; the remainder is test.
        shuffle: shuffle the example indices before splitting.
        random_seed: seed for the shuffle, or None for an unseeded shuffle. A private RNG is
            used, so NumPy's global random state is left untouched.
        rms_feature: build RMS-window datasets (forwarded to ``SequenceDataset``).
        return_splits: if True, return the train/val/test datasets instead of the full one.

    Returns:
        The full ``SequenceDataset`` (default), or
        ``(train_dataset, val_dataset, test_dataset)`` when ``return_splits`` is True.

    Raises:
        ValueError: with ``return_splits``, if the ratios are outside [0, 1] or sum to more than 1.

    Note:
        The split is per example, not per subject. Chunks cut from the same event can end
        up in different splits.
    """
    # Validate before the (slow) data load; only relevant when splits are requested.
    if return_splits and not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1
                              and train_ratio + val_ratio <= 1 + 1e-9):
        raise ValueError("train_ratio and val_ratio must lie in [0, 1] and sum to at most 1, got %r and %r"
                         % (train_ratio, val_ratio))

    X, y = load_data(nr_of_subj=109, chunk_data=(not rms_feature), chunks=8, cpu_format=False,
                     preprocessing=(not rms_feature), hp_freq=0.5, bp_low=2, bp_high=60, notch=True,
                     hp_filter=False, bp_filter=True, artifact_removal=True)

    if not return_splits:
        # Only the full dataset is needed: skip the shuffle and the fancy-indexing copies of X.
        return SequenceDataset(X, y, rms_feature)

    n_examples = len(X)
    indices = np.arange(n_examples)
    if shuffle:
        # For a given seed, a private RandomState yields the same permutation as
        # np.random.seed(seed) + np.random.shuffle, without reseeding the global RNG.
        np.random.RandomState(random_seed).shuffle(indices)

    train_end = int(n_examples * train_ratio)
    val_end = int(n_examples * (train_ratio + val_ratio))
    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]

    train_dataset = SequenceDataset(X[train_idx], y[train_idx], rms_feature)
    val_dataset = SequenceDataset(X[val_idx], y[val_idx], rms_feature)
    test_dataset = SequenceDataset(X[test_idx], y[test_idx], rms_feature)
    return train_dataset, val_dataset, test_dataset

In [9]:
# Full dataset of chunked, preprocessed signal windows (raw windows, not RMS features).
ds = get_train_val_test_split(train_ratio=0.8, val_ratio=0.1, shuffle=True, random_seed=42, rms_feature=False)

## Channel selection

In [10]:
def compute_channel_entropy(data_loader, bins=50):
    """Average the per-channel amplitude entropy over every example yielded by ``data_loader``.

    For each example and channel, the values are binned with ``np.histogram``. The Shannon
    entropy of the histogram (natural log, via ``scipy.stats.entropy``) is then computed and
    averaged over all examples.

    Args:
        data_loader: iterable of ``(windows, lengths, labels)`` batches. Iterating ``windows``
            yields examples, and iterating an example yields channel tensors. These must be on
            the CPU because ``.numpy()`` is called on them.
        bins: number of histogram bins per channel (default 50).

    Returns:
        1-D numpy array with the mean entropy of each channel.

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    entropy_sum = None
    total_trials = 0

    for windows, _lengths, _labels in data_loader:
        for trial in windows:
            # density=True only rescales equal-width bins by a constant, and entropy()
            # renormalises the histogram to sum to 1 anyway.
            trial_entropies = np.array([
                entropy(np.histogram(channel_data.numpy(), bins=bins, density=True)[0])
                for channel_data in trial
            ])
            entropy_sum = trial_entropies if entropy_sum is None else entropy_sum + trial_entropies
            total_trials += 1

    if total_trials == 0:
        raise ValueError("data_loader yielded no trials; cannot average channel entropies")

    return entropy_sum / total_trials

In [11]:
# Iterate the whole dataset in unshuffled batches of 64 to collect the entropy statistics.
full_dl = DataLoader(ds, batch_size=64)
average_entropies = compute_channel_entropy(full_dl)

In [12]:
# Mean entropy per channel (array index = channel position in the loaded data);
# shown via the cell's rich display rather than print().
average_entropies

[3.38508116 3.39829385 3.40315092 3.40862522 3.40595538 3.40300463
 3.39557294 3.38514179 3.39563531 3.40290204 3.40627457 3.40520445
 3.40234082 3.39494077 3.38023233 3.39366868 3.3988124  3.40299252
 3.40333562 3.40110929 3.39640055 3.38378504 3.38695644 3.38784327
 3.38180831 3.38911086 3.40137265 3.39417319 3.38831157 3.38081684
 3.38735194 3.39765467 3.40260854 3.4050226  3.40385471 3.40173076
 3.39372801 3.38981061 3.38024442 3.39106833 3.37158099 3.38911584
 3.35392882 3.39477189 3.36488476 3.38785306 3.36088145 3.37502396
 3.38768698 3.39382957 3.39892653 3.39970062 3.39765516 3.39266575
 3.38526637 3.36339748 3.38065944 3.3927856  3.39213817 3.38229721
 3.3639379  3.36636424 3.37393801 3.31378357]


In [13]:
# Rank channels from highest to lowest mean entropy, then keep the first eight.
channel_indices = np.flip(np.argsort(average_entropies))
top_8_channels = channel_indices[:8]
print("Top 8 Channels (Indices):", top_8_channels)

Top 8 Channels (Indices): [ 3 10  4 11 33 34 18  2]
