In [287]:
import scipy.signal
import numpy as np
from pymatreader import read_mat
import os
from matplotlib.pyplot import plot
import resampy


In [288]:
# Create a list with all the dataset files
dataset_files = []
dataset_folder = 'dataset' # dataset folder to load data
for file in os.listdir(dataset_folder):
    dataset_files.append(os.path.join(dataset_folder, file)) # append relative path

In [289]:
def preprocess(file):
    mat = read_mat(file)
    bci_exp = mat['bciexp'] # reference to the bci exp data which are the only relevant data
    labels = bci_exp['label'] # all channels

    sampling_rate = bci_exp['srate']
    downsampling_fact = 5 # downsampling factor, 250 / 5 = 50 Hz
    bandpass = np.array([1, 12.5]) # cutoff frequencies for the bandpass filter
    interval_len = .75 # seconds

    # calculate bandpass filter coefficients
    butter_b, butter_a = scipy.signal.butter(N=2, Wn=bandpass / (sampling_rate / 2), btype='bandpass')

    channels_of_interest = ['O9', 'CP1', 'CP2', 'O10', 'P7', 'P3', 'Pz',
            'P4', 'P8', 'PO7', 'PO3', 'Oz', 'PO4', 'PO8']

    # get the index of each channel in labels array
    channel_indexes = np.array([labels.index(ch) for ch in channels_of_interest])
    lmast_channel_idx = np.char.equal('LMAST', labels)

    # number of samples per analysis window
    num_samples_per_window = int(interval_len * sampling_rate / downsampling_fact) - 1

    stimuli = np.array(bci_exp['stim'], dtype=np.double)
    num_stimuli = np.sum(np.diff( np.sum(stimuli[:, :, 0], axis=0)) > 0)

    eeg_data = bci_exp['data']
    num_trials = eeg_data.shape[2]
    num_channels = len(channel_indexes)
    data = np.zeros(shape=(num_channels, num_samples_per_window, num_stimuli, num_trials))
    model = np.zeros(shape=(num_samples_per_window, num_samples_per_window, num_stimuli, num_trials))

    # for trial in range(num_trials):
    for trial in range(1):
        rdat = eeg_data[channel_indexes, :, trial] - (eeg_data[lmast_channel_idx, :, trial] / 2)
        filtfilt_signal = scipy.signal.filtfilt(butter_b, butter_a, rdat, padtype='odd',
                                                padlen=3 * (max(len(butter_b),len(butter_a)) - 1))
        resampled_signal = scipy.signal.resample_poly(filtfilt_signal, 1, downsampling_fact, axis=1)
        

In [290]:
preprocess(dataset_files[0])

-8.149351658437611
