In [1]:
from SciServer import CasJobs, Files, Authentication
import sys
import os.path
import statistics
from matplotlib import pyplot as plt
import pandas
import numpy as np
import mne
from sklearn.utils import shuffle, class_weight #pip install --user sklearn

# The notebook is not started from the project directory, so the project root has to be
# put on sys.path before any project-local package can be imported.
project_path = '/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/'
if sys.path.count(project_path) == 0:
    sys.path.append(project_path)

# "arl-eegmodels" contains a hyphen, so it cannot be named in a plain import statement;
# importlib resolves it from the string instead.
import importlib
EEGModels = importlib.import_module("arl-eegmodels.EEGModels")

# Location of the previously trained weights that later cells load
saved_weights_filename = "40epoch_NoClassWeight.h5"
saved_weights_dir = "/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/saved_weights"
saved_weights_file = os.path.join(saved_weights_dir, saved_weights_filename)

In [None]:
class MNEDataWrapper:
    """Query recorded EEG sessions from CasJobs and cut them into MNE epochs.

    Every session loaded through ``loadSession`` is kept in ``self.sessions``
    under the session ID it was requested with, as
    ``[target_epochs, distract_epochs]``.
    """

    def loadSession(self, session_ID):
        """Load one session, epoch it around every stimulus and store the result.

        Args:
            session_ID: ID of the session in the ``session_eeg`` and
                ``stim_timestamps`` tables. Must be convertible with ``int()``.

        Raises:
            ValueError, TypeError: if ``session_ID`` cannot be converted to an integer.

        Side effects:
            Sets ``self.sessions[session_ID]`` to ``[target_epochs, distract_epochs]``
            (both resampled to ``self.resample_rate``).
        """
        print("Querying session {0} from CasJobs".format(session_ID))

        # The ID is formatted straight into the SQL text below; coercing it to int
        # guarantees that nothing but a number can end up inside the queries.
        query_ID = int(session_ID)

        # NOTE(review): the montage is passed positionally, as in the original code;
        # confirm this still matches create_info() in the installed MNE version.
        info = mne.create_info(self.channel_names, self.recorded_sample_freq, self.channel_types, self.montage)
        info['description'] = '16 channel EEG sessionID {0}'.format(session_ID)

        raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(query_ID)
        raw_df = CasJobs.executeQuery(sql=raw_query, context=self.casjobs_context)

        # One sequence of readings per channel, in channel_names order
        raw_data = [raw_df[name].values for name in self.channel_names]
        custom_raw = mne.io.RawArray(raw_data, info)

        # we do this query to get the data reading index at which the stims appear.  IE, instead of
        # saying stim X was presented at time Y (as it is in the raw data), we want to
        # say stim X appeared at data reading index Z
        stim_index_query = '''
            with stim_timestamps_index(index_value, timestamp) as (
            select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
            where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
            group by stim_timestamps.timestamp
            )

            select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
            where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
            order by stim_timestamps_index.index_value'''.format(query_ID)

        stim_index_df = CasJobs.executeQuery(sql=stim_index_query, context=self.casjobs_context)

        events = self._stim_events(stim_index_df['index_value'].values, stim_index_df['stim_ID'].values)

        epochs = mne.Epochs(raw=custom_raw, events=events, event_id=self.event_id_dict, tmin=self.epoch_tmin, tmax=self.epoch_tmax)

        # Downsample to 128Hz, which is the input sampling rate EEGNet is setup for,
        # then split the epochs into targets and distractors by event name.
        epochs.load_data()
        epochs_resampled = epochs.copy().resample(self.resample_rate, npad='auto')

        target_epochs = epochs_resampled[self.target_epoch_names]
        distract_epochs = epochs_resampled[self.distract_epoch_names]

        self.sessions[session_ID] = [target_epochs, distract_epochs]

    @staticmethod
    def _stim_events(stim_ind, stim_ID):
        """Build the (n_events, 3) integer event array: [reading index + 1, 0, stim ID]."""
        # mne.Epochs documents `events` as an integer array of shape (n_events, 3); a plain
        # list of rows (possibly holding float/object values from the query) is not that.
        stim_ind = np.asarray(stim_ind, dtype=int)
        stim_ID = np.asarray(stim_ID, dtype=int)
        return np.column_stack((stim_ind + 1, np.zeros_like(stim_ind), stim_ID))

    def __init__(self):
        """Set the recording/epoching configuration and start with no loaded sessions."""

        # MNE-specific information
        self.channel_names = ['F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
        self.channel_types = ['eeg'] * len(self.channel_names)
        self.recorded_sample_freq = 512
        self.montage = 'standard_1005'
        self.target_epoch_names = ['t_01', 't_02', 't_03', 't_04']
        self.distract_epoch_names = ['d_01', 'd_02', 'd_03', 'd_04', 'd_05', 'd_06', 'd_07', 'd_08', 'd_09', 'd_10']
        # event name -> integer stim ID as stored in the stim_timestamps table
        self.event_id_dict = dict(t_04=0, t_03=1, t_02=2, t_01=3, d_10=4, d_09=5, d_08=6, d_07=7, d_06=8, d_05=9, d_04=10, d_03=11, d_02=12, d_01=13)
        self.epoch_tmin = 0
        self.epoch_tmax = 1
        self.resample_rate = 128 #desired sample freq in Hz for EEGNet input

        self.casjobs_context = "MyDB"

        # session_ID -> [target_epochs, distract_epochs]
        self.sessions = {}
        

In [None]:
# Pull sessions 2 and 3 into the wrapper's session cache
MNEDataWrap = MNEDataWrapper()
MNEDataWrap.loadSession(session_ID=2)
MNEDataWrap.loadSession(session_ID=3)


In [6]:
class EEGNetWrapper:
    """Bundle an EEGNet model with session loading, training and evaluation.

    Training and evaluation epochs are accumulated as arrays with one extra
    leading singleton axis per epoch (``(n_epochs, 1, channels, samples)`` when
    ``get_data()`` yields ``(n_epochs, channels, samples)``), next to one-hot
    class rows: ``[1, 0]`` for target epochs and ``[0, 1]`` for distractors.
    """

    def loadSession(self, session_ID):
        """Query one session from CasJobs and return its resampled epochs.

        Args:
            session_ID: ID of the session in the ``session_eeg`` and
                ``stim_timestamps`` tables. Must be convertible with ``int()``.

        Returns:
            ``(target_epochs, distract_epochs)`` selected from the epochs after
            resampling to ``self.resample_rate``.

        Raises:
            ValueError, TypeError: if ``session_ID`` cannot be converted to an integer.
        """
        print("Querying session {0} from CasJobs".format(session_ID))

        # The ID is formatted straight into the SQL text below; coercing it to int
        # guarantees that nothing but a number can end up inside the queries.
        query_ID = int(session_ID)

        # NOTE(review): the montage is passed positionally, as in the original code;
        # confirm this still matches create_info() in the installed MNE version.
        info = mne.create_info(self.channel_names, self.recorded_sample_freq, self.channel_types, self.montage)
        info['description'] = '16 channel EEG sessionID {0}'.format(session_ID)

        raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(query_ID)
        raw_df = CasJobs.executeQuery(sql=raw_query, context=self.casjobs_context)

        # One sequence of readings per channel, in channel_names order
        raw_data = [raw_df[name].values for name in self.channel_names]
        custom_raw = mne.io.RawArray(raw_data, info)

        # we do this query to get the data reading index at which the stims appear.  IE, instead of
        # saying stim X was presented at time Y (as it is in the raw data), we want to
        # say stim X appeared at data reading index Z
        stim_index_query = '''
            with stim_timestamps_index(index_value, timestamp) as (
            select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
            where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
            group by stim_timestamps.timestamp
            )

            select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
            where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
            order by stim_timestamps_index.index_value'''.format(query_ID)

        stim_index_df = CasJobs.executeQuery(sql=stim_index_query, context=self.casjobs_context)

        events = self._stim_events(stim_index_df['index_value'].values, stim_index_df['stim_ID'].values)

        epochs = mne.Epochs(raw=custom_raw, events=events, event_id=self.event_id_dict, tmin=self.epoch_tmin, tmax=self.epoch_tmax)

        # Downsample to 128Hz, which is the input sampling rate EEGNet is setup for,
        # then split the epochs into targets and distractors by event name.
        epochs.load_data()
        epochs_resampled = epochs.copy().resample(self.resample_rate, npad='auto')

        target_epochs = epochs_resampled[self.target_epoch_names]
        distract_epochs = epochs_resampled[self.distract_epoch_names]

        return target_epochs, distract_epochs

    @staticmethod
    def _stim_events(stim_ind, stim_ID):
        """Build the (n_events, 3) integer event array: [reading index + 1, 0, stim ID]."""
        # mne.Epochs documents `events` as an integer array of shape (n_events, 3); a plain
        # list of rows (possibly holding float/object values from the query) is not that.
        stim_ind = np.asarray(stim_ind, dtype=int)
        stim_ID = np.asarray(stim_ID, dtype=int)
        return np.column_stack((stim_ind + 1, np.zeros_like(stim_ind), stim_ID))

    @staticmethod
    def _stack_session(existing_data, existing_class, target_data, distract_data):
        """Append one session's epochs and one-hot labels to the accumulated arrays.

        Targets are appended before distractors. Each group is concatenated in one
        step instead of one ``np.append`` per epoch, and an empty group is skipped
        rather than indexed.

        Returns:
            ``(data, classes)``; the inputs are returned unchanged when there is
            nothing to append.
        """
        data_parts, class_parts = [], []
        if len(existing_data) > 0:
            data_parts.append(np.asarray(existing_data))
            class_parts.append(np.asarray(existing_class))

        for group, one_hot in ((target_data, [1, 0]), (distract_data, [0, 1])):
            group = np.asarray(group)
            if len(group) == 0:
                continue
            # assumes get_data() yields (n_epochs, channels, samples); the new axis gives
            # every epoch the leading singleton dimension the accumulated arrays use
            data_parts.append(group[:, np.newaxis])
            class_parts.append(np.tile(one_hot, (len(group), 1)))

        if not data_parts:
            return existing_data, existing_class
        return np.concatenate(data_parts, axis=0), np.concatenate(class_parts, axis=0)

    def _update_class_weights(self):
        """Recompute balanced class weights from the one-hot training labels and print them."""
        labels = np.argmax(self.training_class, axis=1)
        # Keyword arguments: newer scikit-learn releases make `classes` and `y` keyword-only.
        # The classes are taken from the labels themselves, not from the one-hot matrix values.
        self.class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                               classes=np.unique(labels),
                                                               y=labels)

        print("Updated class weights to {0}".format(self.class_weights))

    def _class_weight_dict(self):
        """Return the class weights as the {class index: weight} dict Keras documents."""
        if isinstance(self.class_weights, dict):
            return self.class_weights
        labels = np.unique(np.argmax(self.training_class, axis=1))
        return {int(label): float(weight) for label, weight in zip(labels, self.class_weights)}

    def loadTrainingData(self, session_ID):
        """Load a session, add its epochs to the training set and refresh the class weights."""
        target_epochs, distract_epochs = self.loadSession(session_ID)

        # get_data(): len(target_epochs) epochs of 16 channels x 128 readings
        self.training_data, self.training_class = self._stack_session(
            self.training_data, self.training_class,
            target_epochs.get_data(), distract_epochs.get_data())

        self._update_class_weights()

    def loadEvaluationData(self, session_ID):
        """Load a session and add its epochs to the evaluation set."""
        target_epochs, distract_epochs = self.loadSession(session_ID)

        self.eval_data, self.eval_class = self._stack_session(
            self.eval_data, self.eval_class,
            target_epochs.get_data(), distract_epochs.get_data())

    def loadTrainingDataFile(self, trainingDataPath, trainingClassPath):
        """Replace the training set with arrays read via np.load and refresh the class weights."""
        self.training_data = np.load(trainingDataPath)
        self.training_class = np.load(trainingClassPath)

        self._update_class_weights()

    def loadEvaluationDataFile(self, evalDataPath, evalClassPath):
        """Replace the evaluation set with arrays read via np.load."""
        self.eval_data = np.load(evalDataPath)
        self.eval_class = np.load(evalClassPath)

    def load_saved_model(self, model_file_path):
        """Load previously saved weights into the model."""
        self.model.load_weights(filepath=model_file_path)

    def fit(self, training_iterations):
        """Fit the model on the training set for ``training_iterations`` epochs; result kept in ``self.fitted``."""
        self.fitted = self.model.fit(x=self.training_data, y=self.training_class, epochs=training_iterations) #validation_split=.2

    def fit_with_class_weights(self, training_iterations):
        """Same as ``fit`` but weights the loss per class using ``self.class_weights``."""
        # Keras documents class_weight as a dict keyed by class index, so the weight
        # array is converted here instead of being passed through as an ndarray.
        self.fitted = self.model.fit(x=self.training_data, y=self.training_class, epochs=training_iterations, class_weight=self._class_weight_dict()) #validation_split=.2

    def predict(self, data_to_predict):
        """Print the model output for ``data_to_predict``."""
        print(self.model.predict(x=data_to_predict))

    def evaluate(self):
        """Print the true class and prediction for every evaluation epoch, then the confusion counts.

        An epoch counts as predicted-target when the first output is strictly
        greater than the second, and as predicted-distractor when it is strictly
        smaller; a tie is counted as an error for either true class.
        """
        true_pos_count = 0
        false_pos_count = 0
        true_neg_count = 0
        false_neg_count = 0

        n_epochs = len(self.eval_data)
        if n_epochs > 0:
            # One batched forward pass instead of one predict() call per epoch; each epoch
            # is still expanded to 4 dimensions exactly as it was for the per-epoch calls.
            batch = np.concatenate([np.array(epoch, ndmin=4) for epoch in self.eval_data], axis=0)
            predictions = self.model.predict(x=batch)
        else:
            predictions = []

        for i in range(n_epochs):
            scores = predictions[i]
            print("True Class: {0}".format(self.eval_class[i]))
            print("Prediction: {0}".format(np.array(scores, ndmin=2)))
            if self.eval_class[i][0] == 1: #positive
                if scores[0] > scores[1]: #True Positive
                    true_pos_count += 1
                else:
                    false_neg_count += 1
            else: #negative
                if scores[0] < scores[1]: #True Negative
                    true_neg_count += 1
                else:
                    false_pos_count += 1

        result = '''True Positives: {0}, True Negatives: {1}, False Positives: {2}, False Negatives: {3}'''.format(true_pos_count, true_neg_count, false_pos_count, false_neg_count)
        print(result)

    def __init__(self):
        """Build and compile the EEGNet model and start with empty training/evaluation sets."""

        self.model = EEGModels.EEGNet(nb_classes = 2, Chans=16, Samples=128)
        self.model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics=['accuracy'])

        # MNE-specific information
        self.channel_names = ['F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
        self.channel_types = ['eeg'] * len(self.channel_names)
        self.recorded_sample_freq = 512
        self.montage = 'standard_1005'
        self.target_epoch_names = ['t_01', 't_02', 't_03', 't_04']
        self.distract_epoch_names = ['d_01', 'd_02', 'd_03', 'd_04', 'd_05', 'd_06', 'd_07', 'd_08', 'd_09', 'd_10']
        # event name -> integer stim ID as stored in the stim_timestamps table
        self.event_id_dict = dict(t_04=0, t_03=1, t_02=2, t_01=3, d_10=4, d_09=5, d_08=6, d_07=7, d_06=8, d_05=9, d_04=10, d_03=11, d_02=12, d_01=13)
        self.epoch_tmin = 0
        self.epoch_tmax = 1
        self.resample_rate = 128 #desired sample freq in Hz for EEGNet input

        self.casjobs_context = "MyDB"

        # Filled by the load* methods; empty lists mean "nothing loaded yet"
        self.training_data = []
        self.training_class = []
        self.class_weights = []
        self.eval_data = []
        self.eval_class = []
        
        

In [None]:
#Evaluate previously trained weights instead of training again
EEGNetWrap = EEGNetWrapper()
EEGNetWrap.load_saved_model(model_file_path=saved_weights_file)
EEGNetWrap.loadEvaluationData(session_ID=19)



In [7]:
#Train a fresh model from the training/evaluation arrays stored as .npy files
EEGNetWrap = EEGNetWrapper()

EEGNetWrap.loadTrainingDataFile(trainingDataPath="training_data.npy", trainingClassPath="training_class.npy")
EEGNetWrap.loadEvaluationDataFile(evalDataPath="eval_data.npy", evalClassPath="eval_class.npy")

EEGNetWrap.fit_with_class_weights(40)


Updated class weights to [5.93639576 0.54598635]
Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


In [None]:
#Train a fresh model on sessions queried from CasJobs, then load session 19 for evaluation
EEGNetWrap = EEGNetWrapper()
for i in range(2, 19):  # sessions 0 and 1 are incomplete, so start at 2
    EEGNetWrap.loadTrainingData(session_ID=i)

EEGNetWrap.fit_with_class_weights(40)
EEGNetWrap.loadEvaluationData(session_ID=19)


In [8]:
# Prints the true class and the prediction for every evaluation epoch, followed by the
# true/false positive/negative counts
EEGNetWrap.evaluate()

True Class: [1 0]
Prediction: [[0.12386142 0.87613857]]
True Class: [1 0]
Prediction: [[0.38450643 0.6154936 ]]
True Class: [1 0]
Prediction: [[0.59664494 0.40335503]]
True Class: [1 0]
Prediction: [[0.35003835 0.64996165]]
True Class: [1 0]
Prediction: [[0.18326011 0.8167399 ]]
True Class: [1 0]
Prediction: [[0.96499    0.03501003]]
True Class: [1 0]
Prediction: [[0.17032318 0.82967687]]
True Class: [1 0]
Prediction: [[0.9438042  0.05619577]]
True Class: [1 0]
Prediction: [[0.21848415 0.78151584]]
True Class: [1 0]
Prediction: [[0.9753294  0.02467068]]
True Class: [1 0]
Prediction: [[0.91655177 0.08344821]]
True Class: [1 0]
Prediction: [[0.6169643 0.3830357]]
True Class: [1 0]
Prediction: [[0.81130034 0.18869968]]
True Class: [1 0]
Prediction: [[0.9158159  0.08418412]]
True Class: [1 0]
Prediction: [[0.94048005 0.05951995]]
True Class: [1 0]
Prediction: [[0.7300297 0.2699703]]
True Class: [1 0]
Prediction: [[0.93731064 0.06268933]]
True Class: [1 0]
Prediction: [[0.8187545  0.1812455

Prediction: [[0.03315176 0.9668482 ]]
True Class: [0 1]
Prediction: [[2.6132583e-04 9.9973863e-01]]
True Class: [0 1]
Prediction: [[0.00345761 0.9965424 ]]
True Class: [0 1]
Prediction: [[0.00174878 0.99825114]]
True Class: [0 1]
Prediction: [[0.00137775 0.9986223 ]]
True Class: [0 1]
Prediction: [[0.00128849 0.99871147]]
True Class: [0 1]
Prediction: [[0.26940876 0.73059124]]
True Class: [0 1]
Prediction: [[3.493136e-05 9.999651e-01]]
True Class: [0 1]
Prediction: [[6.3869270e-04 9.9936134e-01]]
True Class: [0 1]
Prediction: [[1.04829385e-04 9.99895215e-01]]
True Class: [0 1]
Prediction: [[9.834621e-04 9.990165e-01]]
True Class: [0 1]
Prediction: [[0.06998211 0.9300179 ]]
True Class: [0 1]
Prediction: [[4.0287466e-04 9.9959713e-01]]
True Class: [0 1]
Prediction: [[8.087634e-04 9.991912e-01]]
True Class: [0 1]
Prediction: [[0.0108462 0.9891538]]
True Class: [0 1]
Prediction: [[5.1371695e-04 9.9948621e-01]]
True Class: [0 1]
Prediction: [[0.00294281 0.9970572 ]]
True Class: [0 1]
Predic

Prediction: [[0.01919338 0.98080665]]
True Class: [0 1]
Prediction: [[0.00549869 0.99450135]]
True Class: [0 1]
Prediction: [[8.0152776e-04 9.9919850e-01]]
True Class: [0 1]
Prediction: [[0.00597132 0.9940287 ]]
True Class: [0 1]
Prediction: [[7.702172e-04 9.992298e-01]]
True Class: [0 1]
Prediction: [[0.00819967 0.99180037]]
True Class: [0 1]
Prediction: [[4.370760e-04 9.995629e-01]]
True Class: [0 1]
Prediction: [[0.00546577 0.99453425]]
True Class: [0 1]
Prediction: [[0.00126722 0.99873275]]
True Class: [0 1]
Prediction: [[2.5861018e-05 9.9997413e-01]]
True Class: [0 1]
Prediction: [[5.8052281e-04 9.9941945e-01]]
True Class: [0 1]
Prediction: [[0.14148308 0.8585169 ]]
True Class: [0 1]
Prediction: [[9.1559370e-05 9.9990845e-01]]
True Class: [0 1]
Prediction: [[8.198067e-05 9.999180e-01]]
True Class: [0 1]
Prediction: [[0.00498596 0.995014  ]]
True Class: [0 1]
Prediction: [[0.00706824 0.9929317 ]]
True Class: [0 1]
Prediction: [[6.1955396e-04 9.9938047e-01]]
True Class: [0 1]
Predic

In [10]:
#Save the trained weights next to the other saved weight files
to_save_file = os.path.join(saved_weights_dir, "40Epoch_ClassWeights.h5")
EEGNetWrap.model.save_weights(to_save_file, overwrite=True, save_format='h5')

In [None]:
#EEGNetWrapWeights = EEGNetWrapper()
#EEGNetWrapWeights.load_saved_model(saved_weights_file)