### Import libraries

In [13]:
from glob import glob
import pandas as pd
import os
import numpy as np

import json
from scipy import signal
import re
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import QuantileTransformer
import numpy.random as npr
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

### Helper functions

In [2]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    The cut-off frequencies are expressed as fractions of the Nyquist
    frequency (fs / 2), which is the normalisation ``scipy.signal.butter``
    expects when it is not given ``fs`` itself.

    Args:
        lowcut, highcut: pass-band edges, in the same units as ``fs``.
        fs: sampling rate.
        order: filter order.

    Returns:
        (b, a): numerator and denominator polynomial coefficients.
    """
    nyquist = fs / 2
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return signal.butter(order, band_edges, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a zero-phase Butterworth filter along axis 0.

    ``filtfilt`` runs the filter forwards and backwards, so the output has no
    phase delay relative to the input. Axis 0 is treated as the time axis.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return signal.filtfilt(numerator, denominator, data, axis=0)

def marker_fix(markers, frame_rate):
    """Linearly interpolate ``frame_rate`` positions per interval between markers.

    For every consecutive pair of markers (left, right), the output holds
    ``left`` followed by ``frame_rate - 1`` evenly spaced points towards
    ``right``, each truncated toward zero with ``int()``. The final marker
    itself is not emitted. For two or more markers, the result therefore has
    ``(len(markers) - 1) * frame_rate`` entries.
    """
    frame_positions = []
    for left, right in zip(markers, markers[1:]):
        frame_positions.append(left)
        frame_positions.extend(
            int(left * (1 - step / frame_rate) + right * step / frame_rate)
            for step in range(1, frame_rate)
        )
    return np.array(frame_positions)


### Example data folder and channel subsets

In [3]:
# Folder that load_data scans recursively for .easy / .json / .info files.
test_path = 'partialdata'

In [4]:
# Channel subset with no frontal electrodes (no F*, Fp*, FC* or AF* channels).
ch_without_F = ('P7 P4 Cz Pz P3 P8 O1 O2 T8 C4 C3 T7 Oz PO4 '
                'CP6 CP2 CP1 CP5 PO3').split()

In [5]:
# Channel subset with no frontal electrodes and without the temporal T7/T8 pair.
ch_without_FandT = ('P7 P4 Cz Pz P3 P8 O1 O2 C4 C3 Oz PO4 '
                    'CP6 CP2 CP1 CP5 PO3').split()

### Preprocess data

In [6]:
# Column order of the EEG channels in the .easy files; a channel's position in
# this list is used as its integer column label when the file is read.
CHANNELS = ['P7', 'P4', 'Cz', 'Pz', 'P3', 'P8', 'O1', 'O2', 'T8', 'F8', 'C4',
            'F4', 'Fp2', 'Fz', 'C3', 'F3', 'Fp1', 'T7', 'F7', 'Oz', 'PO4', 'FC6',
            'FC2', 'AF4', 'CP6', 'CP2', 'CP1', 'CP5', 'FC1', 'FC5', 'AF3', 'PO3']

# .easy column holding event markers, and the marker values that delimit the
# part of each recording that is kept.
MARKER_COLUMN = 35
START_MARKER = 1
END_MARKER = 1200
# 0-based line of the .info file whose first decimal number is the sampling rate.
# NOTE(review): a fixed line position is fragile if the .info layout changes.
SAMPLING_RATE_LINE = 18


def _find_experiment_files(folder_path):
    """Return sorted lists of .easy, .json and .info paths found under ``folder_path``.

    The three lists are later paired by position, so directories and files are
    visited in sorted order (raw ``os.walk`` order is filesystem-dependent and
    can silently pair an EEG file with another experiment's log).

    Raises:
        ValueError: if the three file types are not found in equal numbers.
    """
    eeg_files, log_files, inf_files = [], [], []
    for root, subdirs, files in os.walk(folder_path):
        subdirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        for name in sorted(files):
            path = os.path.join(root, name)
            if name.endswith('.easy'):
                eeg_files.append(path)
            elif name.endswith('.json'):
                log_files.append(path)
            elif name.endswith('.info'):
                inf_files.append(path)
    if not len(eeg_files) == len(log_files) == len(inf_files):
        raise ValueError(
            f"Found {len(eeg_files)} .easy, {len(log_files)} .json and "
            f"{len(inf_files)} .info files under {folder_path!r}; files are paired "
            "by position, so the counts must match.")
    return eeg_files, log_files, inf_files


def _read_sampling_rate(info_path, line_no=SAMPLING_RATE_LINE):
    """Return the first ``<digits>.<digits>`` number on line ``line_no`` of a .info file."""
    with open(info_path, 'r') as inf_file:
        inf_lines = inf_file.read().split('\n')
    return float(re.findall(r"\d+\.\d+", inf_lines[line_no])[0])


def load_data(folder_path, channels=CHANNELS, shift=50, lf=10, hf=100):
    """Load, crop and band-pass filter every experiment found under ``folder_path``.

    An experiment is a tab-separated, header-less .easy EEG file plus a .json
    log and a .info metadata file; the three are paired by sorted position.

    Args:
        folder_path: directory searched recursively for experiment files.
        channels: channels to keep (must be names from ``CHANNELS``), in output order.
        shift: samples by which the EEG is rolled with ``np.roll`` before cropping.
        lf, hf: band-pass cut-offs, in the units of the sampling rate read from
            the .info file (presumably Hz).

    Returns:
        One dict per experiment with keys ``markers`` (marker sample indices
        relative to the start marker), ``eeg`` (filtered array, rows = samples,
        columns = ``channels``), ``embedding``, ``frame_rate``, ``video_length``,
        ``truncation``, ``z_speed`` and ``switch_len``.

    Raises:
        ValueError: if the file counts differ, or a recording lacks a start
            marker followed by an end marker.
    """
    eeg_files, log_files, inf_files = _find_experiment_files(folder_path)
    print(f"{len(eeg_files)} experiments selected.")
    # positions of the requested channels = integer column labels in the .easy table
    channel_idxs = [CHANNELS.index(ch) for ch in channels]

    outs = []
    for i, eeg_path in enumerate(tqdm(eeg_files)):
        df = pd.read_csv(eeg_path,
                         delimiter='\t',
                         engine='c',
                         header=None,
                         index_col=None)

        # process markers: keep the span from the last START_MARKER to the last END_MARKER
        marker_col = df.loc[df[MARKER_COLUMN] > 0, MARKER_COLUMN]
        raw_markers = marker_col.values
        raw_marker_idx = marker_col.index.values
        start_hits = np.flatnonzero(raw_markers == START_MARKER)
        end_hits = np.flatnonzero(raw_markers == END_MARKER)
        if start_hits.size == 0 or end_hits.size == 0 or end_hits[-1] < start_hits[-1]:
            raise ValueError(
                f"{eeg_path}: expected a marker {START_MARKER} followed by a marker "
                f"{END_MARKER} in column {MARKER_COLUMN}")
        start_marker_loc, end_marker_loc = start_hits[-1], end_hits[-1]
        start_marker_idx = raw_marker_idx[start_marker_loc]
        end_marker_idx = raw_marker_idx[end_marker_loc]

        # marker sample indices relative to the cropped recording
        out = {'markers': raw_marker_idx[start_marker_loc:end_marker_loc + 1] - start_marker_idx}

        # process EEG: select channels, roll, crop to [start marker, end marker].
        # np.roll with shift > 0 places original row j - shift at row j (rows wrap
        # around), so if start_marker_idx < shift, rows from the file's end appear.
        raw_eeg = np.roll(np.array(df[channel_idxs]),
                          shift=shift, axis=0)[start_marker_idx:end_marker_idx + 1]

        sampling_rate = _read_sampling_rate(inf_files[i])
        out['eeg'] = butter_bandpass_filter(raw_eeg, lf, hf, sampling_rate, order=5)

        # process embeddings
        with open(log_files[i], 'r') as log_file:
            log = json.load(log_file)
        # NOTE(review): the last 20 rows are dropped. 20 matches the frame rate
        # prepare_data passes to marker_fix, which emits no frames after the final
        # marker. Confirm this is the reason.
        out['embedding'] = np.array(log['e'])[:-20]
        out['frame_rate'] = log['f']
        out['video_length'] = log['vl']
        out['truncation'] = log['t']
        out['z_speed'] = log['z']
        out['switch_len'] = log['sl']
        outs.append(out)
    print('process completed!')
    return outs

In [7]:
# example usage: channel subset without F/Fp/FC/AF electrodes and without T7/T8
data = load_data(
    test_path,
    channels=ch_without_FandT,
    shift=40,
    lf=5,
    hf=128,
)

  0%|          | 0/3 [00:00<?, ?it/s]

3 experiments selected.


100%|██████████| 3/3 [00:14<00:00,  4.87s/it]

process completed!





In [8]:
# Cache the preprocessed recordings so the slow load_data step can be skipped.
# NOTE(review): the file name encodes s60/lf20/hf80, but `data` was built with
# shift=40, lf=5, hf=128 in the load cell; rename if those settings are kept.
import pickle

with open('prp_WFT_s60lf20hf80.pkl', mode='wb') as prp_file:
    pickle.dump(data, prp_file)

In [9]:
# Reload the cached recordings (pickle.load is only safe on trusted, locally written files).
with open('prp_WFT_s60lf20hf80.pkl', mode='rb') as prp_file:
    data = pickle.load(prp_file)

### Prepare training and test arrays for a PyTorch DataLoader

In [10]:
def prepare_data(data, eeg_len, frames_len, overlap=0, condition=None, split=(0.8, 0.2),
                 frame_rate=20, seed=None):
    """Cut experiments into EEG windows paired with resampled embeddings; split train/test.

    Args:
        data: list of experiment dicts as produced by ``load_data`` (the
            ``eeg``, ``markers`` and ``embedding`` keys are used).
        eeg_len: window length in EEG samples.
        frames_len: number of embedding rows each window's frames are resampled
            to with ``scipy.signal.resample``.
        overlap: samples shared by consecutive windows; the window start
            advances by ``eeg_len - overlap``, so it must be < ``eeg_len``.
        condition: optional ``(key, value)``; only experiments with
            ``d[key] == value`` are used.
        split: ``(train, test)`` fractions. The test set gets
            ``int(split[1] * n_windows)`` windows and the train set the rest.
        frame_rate: interpolated frames per marker interval, passed to ``marker_fix``.
        seed: optional seed for a private ``np.random.default_rng``; with
            ``None`` the global ``np.random`` state is used.

    Returns:
        ``{'train': (eeg, embedding), 'test': (eeg, embedding)}``. Each array is
        stacked along a new leading window axis; the window order is random.

    Raises:
        ValueError: if ``eeg_len <= 0``, ``overlap >= eeg_len`` (the window would
            never advance), or no window is produced.
    """
    if eeg_len <= 0 or overlap >= eeg_len:
        raise ValueError(
            f"need eeg_len > 0 and overlap < eeg_len, got eeg_len={eeg_len}, overlap={overlap}")
    step = eeg_len - overlap
    if condition is not None:
        prop, val = condition

    eeg_chunks = []
    embedding_chunks = []
    for d in tqdm(data):
        if condition is not None and d[prop] != val:
            continue
        eeg = d['eeg']
        embedding = d['embedding']
        # interpolated sample index of each frame; assumed to align row-for-row with `embedding`
        all_frame_idx = marker_fix(d['markers'], frame_rate)
        # every window start whose full window still fits in the recording
        for start in range(0, eeg.shape[0] - eeg_len + 1, step):
            end = start + eeg_len
            eeg_chunks.append(eeg[start:end])
            # frames in the same half-open span [start, end) as the EEG slice
            frames_embedding_idx = np.flatnonzero((all_frame_idx >= start) & (all_frame_idx < end))
            embedding_chunks.append(signal.resample(embedding[frames_embedding_idx], frames_len))

    if not eeg_chunks:
        raise ValueError("no EEG windows were produced; check eeg_len and condition")

    # splitting
    n_chunks = len(eeg_chunks)
    n_test = int(split[1] * n_chunks)
    rng = npr if seed is None else np.random.default_rng(seed)
    test_inds = rng.choice(n_chunks, n_test, replace=False)
    train_inds = np.setdiff1d(np.arange(n_chunks), test_inds)
    rng.shuffle(train_inds)

    # finalizing: stack once and index both splits from the same arrays
    all_eeg = np.stack(eeg_chunks)
    all_embedding = np.stack(embedding_chunks)
    return {'train': (all_eeg[train_inds], all_embedding[train_inds]),
            'test': (all_eeg[test_inds], all_embedding[test_inds])}
            

### example usage

In [11]:
# 200-sample EEG windows, 3 embedding rows per window, consecutive windows sharing 160 samples.
dataset = prepare_data(data, eeg_len=200, frames_len=3, overlap=160)

100%|██████████| 3/3 [00:20<00:00,  6.69s/it]


In [12]:
# Persist the train/test split; pickle protocol 4 can store objects larger than 4 GiB.
# NOTE(review): the name says 300/3/o350/s60lf20hf80, but the split was made with
# eeg_len=200, frames_len=3, overlap=160 on data loaded with shift=40, lf=5, hf=128.
with open('finWFT_300_3_o350_s60lf20hf80.pkl', mode='wb') as fin_file:
    pickle.dump(dataset, fin_file, protocol=4)