In [5]:
from SciServer import CasJobs, Files, Authentication
import sys
import os.path
import statistics
from matplotlib import pyplot as plt
import pandas
import numpy as np
import mne
from sklearn.utils import shuffle, class_weight #pip install --user sklearn

# This notebook is not run from the project directory, so the project root is
# appended to sys.path (only if it is not already there) to make the
# project-local packages importable.
project_path = '/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/'
if project_path not in sys.path:
    sys.path.append(project_path)

# The package directory name "arl-eegmodels" contains a hyphen, so it cannot be
# written in a plain `import` statement; importlib.import_module accepts the
# dotted path as a string instead.
import importlib
EEGModels = importlib.import_module("arl-eegmodels.EEGModels")

# Full path of the weights file that later cells load from and save to.
saved_weights_dir = "/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/saved_weights"
saved_weights_filename = "40epoch_NoClassWeight.h5"
saved_weights_file = os.path.join(saved_weights_dir, saved_weights_filename)

In [2]:
class MNEDataWrapper:
    """Load EEG sessions from CasJobs and keep them as resampled MNE epochs.

    After ``loadSession(session_ID)`` the entry ``self.sessions[session_ID]``
    holds ``[target_epochs, distract_epochs]`` for that session.
    """

    def __init__(self):
        """Set the recording/epoching configuration and an empty session cache."""
        # MNE-specific information
        self.channel_names = [
            'F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8',
            'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz',
        ]
        self.channel_types = ['eeg'] * 16
        self.recorded_sample_freq = 512
        self.montage = 'standard_1005'
        self.target_epoch_names = ['t_01', 't_02', 't_03', 't_04']
        self.distract_epoch_names = [
            'd_01', 'd_02', 'd_03', 'd_04', 'd_05',
            'd_06', 'd_07', 'd_08', 'd_09', 'd_10',
        ]
        # Event name -> integer stim ID, as passed to mne.Epochs(event_id=...)
        self.event_id_dict = {
            't_04': 0, 't_03': 1, 't_02': 2, 't_01': 3,
            'd_10': 4, 'd_09': 5, 'd_08': 6, 'd_07': 7, 'd_06': 8,
            'd_05': 9, 'd_04': 10, 'd_03': 11, 'd_02': 12, 'd_01': 13,
        }
        self.epoch_tmin = 0
        self.epoch_tmax = 1
        self.resample_rate = 128  # desired sample freq in Hz for EEGNet input

        self.casjobs_context = "MyDB"

        # session_ID -> [target_epochs, distract_epochs]
        self.sessions = {}

    def loadSession(self, session_ID):
        """Query one session from CasJobs, epoch it, resample it, and cache it.

        Parameters
        ----------
        session_ID
            Value substituted into the ``session_ID = ...`` filters of both SQL
            queries and used as the key into ``self.sessions``.  It is inserted
            into the SQL text directly, so it must be a trusted value.

        Returns
        -------
        None
            The result is stored in ``self.sessions[session_ID]`` as
            ``[target_epochs, distract_epochs]``.
        """
        print("Querying session {0} from CasJobs".format(session_ID))

        info = mne.create_info(self.channel_names, self.recorded_sample_freq, self.channel_types, self.montage)
        info['description'] = '16 channel EEG sessionID {0}'.format(session_ID)

        # All readings of the session in timestamp order, one column per channel.
        raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(session_ID)
        eeg_frame = CasJobs.executeQuery(sql=raw_query, context=self.casjobs_context)

        channel_rows = [eeg_frame[name].values for name in self.channel_names]
        custom_raw = mne.io.RawArray(channel_rows, info)

        # Stimuli are recorded by timestamp; MNE wants them by sample index.
        # For every stim timestamp the query counts the EEG readings of this
        # session that are strictly earlier, and pairs that count with the stim ID.
        stim_index_query = '''
            with stim_timestamps_index(index_value, timestamp) as (
            select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
            where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
            group by stim_timestamps.timestamp
            )

            select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
            where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
            order by stim_timestamps_index.index_value'''.format(session_ID)

        stim_frame = CasJobs.executeQuery(sql=stim_index_query, context=self.casjobs_context)
        reading_counts = stim_frame['index_value'].values
        stim_codes = stim_frame['stim_ID'].values

        # One [sample, 0, stim ID] row per stimulus; the sample is the count of
        # earlier readings plus one.
        events = [
            [count + 1, 0, code]
            for count, code in zip(reading_counts, stim_codes)
        ]

        epochs = mne.Epochs(
            raw=custom_raw,
            events=events,
            event_id=self.event_id_dict,
            tmin=self.epoch_tmin,
            tmax=self.epoch_tmax,
        )

        # Resample to the rate EEGNet is set up for, then split the epochs
        # into the target group and the distractor group by event name.
        epochs.load_data()
        resampled = epochs.copy().resample(self.resample_rate, npad='auto')

        self.sessions[session_ID] = [
            resampled[self.target_epoch_names],
            resampled[self.distract_epoch_names],
        ]
        

In [None]:
# Cache sessions 2 and 3 as MNE epochs in MNEDataWrap.sessions.
MNEDataWrap = MNEDataWrapper()
for session_id in (2, 3):
    MNEDataWrap.loadSession(session_id)


In [3]:
class EEGNetWrapper:
    """EEGNet classifier plus the CasJobs/MNE plumbing that feeds it labelled epochs.

    Sessions are queried from CasJobs, cut into epochs with MNE, resampled, and
    split into target and distractor epochs.  Loaded epochs are accumulated as
    arrays of shape ``(n_epochs, 1, n_channels, n_samples)`` with one-hot
    labels: ``[1, 0]`` for a target epoch and ``[0, 1]`` for a distractor epoch.
    """

    def __init__(self):
        """Build and compile the model and set the session-loading configuration."""
        self.model = EEGModels.EEGNet(nb_classes = 2, Chans=16, Samples=128)
        self.model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics=['accuracy'])

        # MNE-specific information
        self.channel_names = ['F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
        self.channel_types = ['eeg'] * len(self.channel_names)
        self.recorded_sample_freq = 512
        self.montage = 'standard_1005'
        self.target_epoch_names = ['t_01', 't_02', 't_03', 't_04']
        self.distract_epoch_names = ['d_01', 'd_02', 'd_03', 'd_04', 'd_05', 'd_06', 'd_07', 'd_08', 'd_09', 'd_10']
        # Event name -> integer stim ID, as passed to mne.Epochs(event_id=...)
        self.event_id_dict = dict(t_04=0, t_03=1, t_02=2, t_01=3, d_10=4, d_09=5, d_08=6, d_07=7, d_06=8, d_05=9, d_04=10, d_03=11, d_02=12, d_01=13)
        self.epoch_tmin = 0
        self.epoch_tmax = 1
        self.resample_rate = 128 #desired sample freq in Hz for EEGNet input

        self.casjobs_context = "MyDB"

        # Empty lists until the first session is loaded; numpy arrays afterwards.
        self.training_data = []
        self.training_class = []

        self.eval_data = []
        self.eval_class = []

    def loadSession(self, session_ID):
        """Query one session from CasJobs and return its resampled epochs.

        Parameters
        ----------
        session_ID
            Value substituted into the ``session_ID = ...`` filters of both SQL
            queries.  It is inserted into the SQL text directly, so it must be
            a trusted value.

        Returns
        -------
        tuple
            ``(target_epochs, distract_epochs)``: the resampled epochs selected
            by ``self.target_epoch_names`` and ``self.distract_epoch_names``.
        """
        print("Querying session {0} from CasJobs".format(session_ID))

        info = mne.create_info(self.channel_names, self.recorded_sample_freq, self.channel_types, self.montage)
        info['description'] = '16 channel EEG sessionID {0}'.format(session_ID)

        # All readings of the session in timestamp order, one column per channel.
        raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(session_ID)
        raw_df = CasJobs.executeQuery(sql=raw_query, context=self.casjobs_context)

        raw_data = [raw_df[name].values for name in self.channel_names]
        custom_raw = mne.io.RawArray(raw_data, info)

        # Stimuli are recorded by timestamp; MNE wants them by sample index.
        # For every stim timestamp the query counts the EEG readings of this
        # session that are strictly earlier, and pairs that count with the stim ID.
        stim_index_query = '''
            with stim_timestamps_index(index_value, timestamp) as (
            select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
            where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
            group by stim_timestamps.timestamp
            )

            select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
            where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
            order by stim_timestamps_index.index_value'''.format(session_ID)

        stim_index_df = CasJobs.executeQuery(sql=stim_index_query, context=self.casjobs_context)

        stim_ind = stim_index_df['index_value'].values
        stim_ID = stim_index_df['stim_ID'].values

        # One [sample, 0, stim ID] row per stimulus.
        # NOTE(review): the count of strictly-earlier readings is already the
        # zero-based index of the first reading at/after the stimulus, so the
        # +1 places the event one sample later -- confirm this is intended.
        events = [[sample + 1, 0, stim] for sample, stim in zip(stim_ind, stim_ID)]

        epochs = mne.Epochs(raw=custom_raw, events=events, event_id=self.event_id_dict, tmin=self.epoch_tmin, tmax=self.epoch_tmax)

        # Resample to the rate EEGNet is set up for, then split the epochs
        # into the target group and the distractor group by event name.
        epochs.load_data()
        epochs_resampled = epochs.copy().resample(self.resample_rate, npad='auto')

        target_epochs = epochs_resampled[self.target_epoch_names]
        distract_epochs = epochs_resampled[self.distract_epoch_names]

        return target_epochs, distract_epochs

    def _append_session(self, data, labels, session_ID):
        """Return ``(data, labels)`` extended with every epoch of one session.

        Target epochs are added first with label ``[1, 0]``, then distractor
        epochs with label ``[0, 1]``.  ``data`` and ``labels`` may be the empty
        lists set by ``__init__`` or arrays produced by an earlier call.  A
        session with no target epochs, or no distractor epochs, contributes
        only the group that is present; a session with no epochs at all leaves
        ``data`` and ``labels`` unchanged.
        """
        target_epochs, distract_epochs = self.loadSession(session_ID)
        target_data = target_epochs.get_data()
        distract_data = distract_epochs.get_data()

        # Drop empty groups so a missing class cannot break the concatenation.
        groups = [group for group in (target_data, distract_data) if len(group) > 0]
        if not groups:
            print("Session {0} produced no epochs; nothing was added".format(session_ID))
            return data, labels

        # A single concatenate per session replaces the per-epoch np.append
        # calls, which copied the whole array for every epoch.  The inserted
        # axis gives (n_epochs, 1, n_channels, n_samples).
        session_data = np.concatenate(groups, axis=0)[:, np.newaxis]
        session_labels = np.array([[1, 0]] * len(target_data) + [[0, 1]] * len(distract_data))

        if len(data) == 0:
            return session_data, session_labels
        return (np.concatenate([data, session_data], axis=0),
                np.concatenate([labels, session_labels], axis=0))

    def loadTrainingData(self, session_ID):
        """Load a session and append its epochs and labels to the training set."""
        self.training_data, self.training_class = self._append_session(
            self.training_data, self.training_class, session_ID)

    def loadEvaluationData(self, session_ID):
        """Load a session and append its epochs and labels to the evaluation set."""
        self.eval_data, self.eval_class = self._append_session(
            self.eval_data, self.eval_class, session_ID)

    def load_saved_model(self, model_file_path):
        """Load model weights from ``model_file_path`` into ``self.model``."""
        self.model.load_weights(filepath=model_file_path)

    def fit(self, training_iterations):
        """Fit the model on the loaded training set for ``training_iterations`` epochs.

        The object returned by ``model.fit`` is kept in ``self.fitted``.
        """
        self.fitted = self.model.fit(x=self.training_data, y=self.training_class, epochs=training_iterations) #validation_split=.2

    def predict(self, data_to_predict):
        """Print the model's prediction for ``data_to_predict``."""
        print(self.model.predict(x=data_to_predict))

    def evaluate(self):
        """Predict every evaluation epoch one at a time and print the confusion counts.

        An epoch labelled ``[1, 0]`` is a positive.  A positive counts as a true
        positive only when the first output is strictly greater than the second;
        a negative counts as a true negative only when the first output is
        strictly smaller.  Ties therefore count as errors for either class.
        """
        true_pos_count = 0
        false_pos_count = 0
        true_neg_count = 0
        false_neg_count = 0

        for epoch, true_class in zip(self.eval_data, self.eval_class):
            print("True Class: {0}".format(true_class))
            prediction = self.model.predict(x=np.array(epoch, ndmin=4))
            print("Prediction: {0}".format(prediction))
            target_score, distract_score = prediction[0][0], prediction[0][1]

            if true_class[0] == 1: #positive
                if target_score > distract_score:
                    true_pos_count += 1
                else:
                    false_neg_count += 1
            else: #negative
                if target_score < distract_score:
                    true_neg_count += 1
                else:
                    false_pos_count += 1

        result = '''True Positives: {0}, True Negatives: {1}, False Positives: {2}, False Negatives: {3}'''.format(true_pos_count, true_neg_count, false_pos_count, false_neg_count)
        print(result)
        
        

In [7]:
# Evaluation-only path: build a fresh wrapper, restore previously saved weights
# from saved_weights_file, then load session 19 as the evaluation set.
# NOTE(review): a later cell rebinds EEGNetWrap to a newly trained wrapper, so
# which model evaluate() sees depends on which of the two cells ran last.
EEGNetWrap = EEGNetWrapper()
EEGNetWrap.load_saved_model(saved_weights_file)
EEGNetWrap.loadEvaluationData(19)



Querying session 19 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=92672
    Range : 0 ... 92671 =      0.000 ...   180.998 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 400 events and 513 original time points ...
0 bad epochs dropped


In [3]:
# Train a new model from scratch on sessions 2 through 18 (sessions 0 and 1 are
# incomplete, so they are skipped), then load session 19 as the evaluation set.
EEGNetWrap = EEGNetWrapper()
for session_id in range(2, 19):
    EEGNetWrap.loadTrainingData(session_id)

EEGNetWrap.fit(training_iterations=40)
EEGNetWrap.loadEvaluationData(19)


Querying session 2 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=81968
    Range : 0 ... 81967 =      0.000 ...   160.092 secs
Ready.
320 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 320 events and 513 original time points ...
0 bad epochs dropped
Querying session 3 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=93728
    Range : 0 ... 93727 =      0.000 ...   183.061 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 400 events and 513 original time points ...
0 bad epochs dropped
Querying session 4 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=93856
    Range : 0 ... 93855 =      0.000 ...   183.311 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated

In [8]:
# Print the true class and prediction for every loaded evaluation epoch,
# followed by the true/false positive/negative counts.
EEGNetWrap.evaluate()

True Class: [1 0]
Prediction: [[0.20218392 0.7978161 ]]
True Class: [1 0]
Prediction: [[0.65143794 0.34856206]]
True Class: [1 0]
Prediction: [[0.74664026 0.25335968]]
True Class: [1 0]
Prediction: [[0.32848957 0.67151046]]
True Class: [1 0]
Prediction: [[0.09994167 0.9000584 ]]
True Class: [1 0]
Prediction: [[0.9950696  0.00493042]]
True Class: [1 0]
Prediction: [[0.7340464  0.26595363]]
True Class: [1 0]
Prediction: [[0.9823909  0.01760914]]
True Class: [1 0]
Prediction: [[0.13196944 0.86803055]]
True Class: [1 0]
Prediction: [[0.9864336  0.01356633]]
True Class: [1 0]
Prediction: [[0.95843256 0.04156746]]
True Class: [1 0]
Prediction: [[0.89912915 0.10087088]]
True Class: [1 0]
Prediction: [[0.32657707 0.67342293]]
True Class: [1 0]
Prediction: [[0.9288347  0.07116531]]
True Class: [1 0]
Prediction: [[0.91529405 0.0847059 ]]
True Class: [1 0]
Prediction: [[0.667781   0.33221903]]
True Class: [1 0]
Prediction: [[0.9712711 0.0287289]]
True Class: [1 0]
Prediction: [[0.9573087  0.04269

Prediction: [[0.00108892 0.99891114]]
True Class: [0 1]
Prediction: [[0.03025448 0.9697455 ]]
True Class: [0 1]
Prediction: [[2.061697e-04 9.997938e-01]]
True Class: [0 1]
Prediction: [[0.00132777 0.99867225]]
True Class: [0 1]
Prediction: [[0.00202657 0.9979734 ]]
True Class: [0 1]
Prediction: [[0.0064565 0.9935435]]
True Class: [0 1]
Prediction: [[0.00261776 0.9973822 ]]
True Class: [0 1]
Prediction: [[0.09965518 0.9003448 ]]
True Class: [0 1]
Prediction: [[1.7001499e-04 9.9983001e-01]]
True Class: [0 1]
Prediction: [[8.3801727e-04 9.9916196e-01]]
True Class: [0 1]
Prediction: [[0.00271626 0.9972837 ]]
True Class: [0 1]
Prediction: [[0.00105753 0.99894243]]
True Class: [0 1]
Prediction: [[0.06155161 0.93844837]]
True Class: [0 1]
Prediction: [[5.425512e-04 9.994574e-01]]
True Class: [0 1]
Prediction: [[0.00143473 0.99856526]]
True Class: [0 1]
Prediction: [[0.01886921 0.9811307 ]]
True Class: [0 1]
Prediction: [[0.00168377 0.9983163 ]]
True Class: [0 1]
Prediction: [[0.00782532 0.992

Prediction: [[9.1452741e-05 9.9990857e-01]]
True Class: [0 1]
Prediction: [[6.8110065e-04 9.9931896e-01]]
True Class: [0 1]
Prediction: [[0.0045997  0.99540037]]
True Class: [0 1]
Prediction: [[0.00297414 0.9970259 ]]
True Class: [0 1]
Prediction: [[0.00697911 0.99302095]]
True Class: [0 1]
Prediction: [[0.00482283 0.99517715]]
True Class: [0 1]
Prediction: [[0.00185989 0.9981401 ]]
True Class: [0 1]
Prediction: [[0.001045 0.998955]]
True Class: [0 1]
Prediction: [[6.4864791e-05 9.9993515e-01]]
True Class: [0 1]
Prediction: [[0.00186732 0.99813265]]
True Class: [0 1]
Prediction: [[0.00624567 0.9937543 ]]
True Class: [0 1]
Prediction: [[3.8212354e-04 9.9961782e-01]]
True Class: [0 1]
Prediction: [[0.0786156 0.9213844]]
True Class: [0 1]
Prediction: [[0.01120272 0.98879725]]
True Class: [0 1]
Prediction: [[0.00267165 0.9973284 ]]
True Class: [0 1]
Prediction: [[4.6766203e-04 9.9953234e-01]]
True Class: [0 1]
Prediction: [[0.0228414  0.97715867]]
True Class: [0 1]
Prediction: [[0.002848 0

In [8]:
# Save the current model's weights to saved_weights_file in HDF5 format,
# overwriting any file already at that path.
EEGNetWrap.model.save_weights(filepath=saved_weights_file, overwrite=True, save_format='h5')

In [6]:
# Balanced class weights for the training labels.
# The labels are one-hot rows ([1, 0] = target, [0, 1] = distractor), so the
# argmax of each row is that epoch's integer class index.
y_ints = np.argmax(EEGNetWrap.training_class, axis=1)

# `classes` and `y` are passed by keyword because current scikit-learn releases
# no longer accept them positionally.  The class list is taken from the integer
# labels themselves; np.unique over the one-hot matrix only happened to give the
# same [0, 1] because there are exactly two classes.
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(y_ints),
                                                  y=y_ints)
class_weights

array([5.93639576, 0.54598635])