In [1]:
from SciServer import CasJobs, Files, Authentication
import sys
import os.path
import statistics
from matplotlib import pyplot as plt
import pandas
import numpy as np
import mne
from sklearn.utils import shuffle #pip install --user sklearn


#BELOW is necessary since we are not currently running from project directory
#since we need to import libs from parent dir, need to add parent dir to path
project_path = '/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/'
if project_path not in sys.path:
    sys.path.append(project_path)

import importlib
EEGModels = importlib.import_module("arl-eegmodels.EEGModels")

In [25]:
class EEGNetWrapper:
    
    def loadSession(self, session_ID):
        print("Querying session {0} from CasJobs".format(session_ID))
        
        info = mne.create_info(self.channel_names, self.recorded_sample_freq, self.channel_types, self.montage)
        info['description'] = '16 channel EEG sessionID {0}'.format(session_ID)
        
        raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(session_ID)
        raw_df = CasJobs.executeQuery(sql=raw_query, context=self.casjobs_context)

        raw_data = []
        for index in range(len(self.channel_names)):
            raw_data.append(raw_df[self.channel_names[index]].values)

        custom_raw = mne.io.RawArray(raw_data, info)
        
        # we do this query to get the data reading index at which the stims appear.  IE, instead of 
        # saying stim X was presented at time Y (as it is in the raw data), we want to 
        # say stim X appeared at data reading index Z
        stim_index_query = '''
            with stim_timestamps_index(index_value, timestamp) as (
            select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
            where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
            group by stim_timestamps.timestamp
            )

            select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
            where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
            order by stim_timestamps_index.index_value'''.format(session_ID)

        stim_index_df = CasJobs.executeQuery(sql=stim_index_query, context=self.casjobs_context)

        stim_ind = stim_index_df['index_value'].values
        stim_ID = stim_index_df['stim_ID'].values

        events = []
        for i in range(len(stim_ind)):
            events.append([stim_ind[i]+1, 0, stim_ID[i]])
        
        epochs = mne.Epochs(raw=custom_raw, events=events, event_id=self.event_id_dict, tmin=self.epoch_tmin, tmax=self.epoch_tmax)

        # Now we load the epochs into their respective target and distractor arrays of epochs
        # More importantly, we downsample to 128Hz, which is the input sampling rate EEGNet is setup for
        
        #Downsample to 128Hz
        
        epochs.load_data()
        epochs_resampled = epochs.copy().resample(self.resample_rate, npad='auto')

        target_epochs = epochs_resampled[self.target_epoch_names]
        distract_epochs = epochs_resampled[self.distract_epoch_names]
        
        return target_epochs, distract_epochs

            
    def loadTrainingData(self, session_ID):
        target_epochs, distract_epochs = self.loadSession(session_ID)
        
        target_data = target_epochs.get_data()  # len(target_epochs) epochs of 16 channels x 128 readings
        distract_data = distract_epochs.get_data()

        
        if len(self.training_data) == 0:
            self.training_data = np.array(target_data[0], ndmin=4)
            self.training_class = np.array([1,0], ndmin=2)
        else:
            self.training_data = np.append(self.training_data, np.array(target_data[0], ndmin=4), axis=0)
            cur_class = np.array([1,0], ndmin=2)
            self.training_class = np.append(self.training_class, cur_class, axis=0)

        for i in range(1, len(target_data)):
            cur_epoch = np.array(target_data[i], ndmin=4)
            self.training_data = np.append(self.training_data, cur_epoch, axis=0)
            cur_class = np.array([1,0], ndmin=2)
            self.training_class = np.append(self.training_class, cur_class, axis=0)
    
        for i in range(0, len(distract_data)):
            cur_epoch = np.array(distract_data[i], ndmin=4)
            self.training_data = np.append(self.training_data, cur_epoch, axis=0)
            cur_class = np.array([0,1], ndmin=2)
            self.training_class = np.append(self.training_class, cur_class, axis=0)
        
    def loadEvaluationData(self, session_ID):
        target_epochs, distract_epochs = self.loadSession(session_ID)
        
        target_data = target_epochs.get_data()  # len(target_epochs) epochs of 16 channels x 128 readings
        distract_data = distract_epochs.get_data()

        if len(self.eval_data) == 0:
            self.eval_data = np.array(target_data[0], ndmin=4)
            self.eval_class = np.array([1,0], ndmin=2)
        else:
            self.eval_data = np.append(self.eval_data, np.array(target_data[0], ndmin=4), axis=0)
            cur_class = np.array([1,0], ndmin=2)
            self.eval_class = np.append(self.eval_class, cur_class, axis=0)

        for i in range(1, len(target_data)):
            cur_epoch = np.array(target_data[i], ndmin=4)
            self.eval_data = np.append(self.eval_data, cur_epoch, axis=0)
            cur_class = np.array([1,0], ndmin=2)
            self.eval_class = np.append(self.eval_class, cur_class, axis=0)
    
        for i in range(0, len(distract_data)):
            cur_epoch = np.array(distract_data[i], ndmin=4)
            self.eval_data = np.append(self.eval_data, cur_epoch, axis=0)
            cur_class = np.array([0,1], ndmin=2)
            self.eval_class = np.append(self.eval_class, cur_class, axis=0)
            
            
    def fit(self, training_iterations):
        self.fitted = self.model.fit(x=self.training_data, y=self.training_class, epochs=training_iterations) #validation_split=.2

    def predict(self, data_to_predict):
        print(self.model.predict(x=data_to_predict))
        
    def evaluate(self):
        pos_count = 0
        neg_count = 0
        true_pos_count = 0
        false_pos_count = 0
        true_neg_count = 0
        false_neg_count = 0
        
        for i in range(0, len(self.eval_data)):
            #self.predict(np.array(self.eval_data[i], ndmin=4))
            print("True Class: {0}".format(self.eval_class[i]))
            prediction = self.model.predict(x=np.array(self.eval_data[i], ndmin=4))
            print("Prediction: {0}".format(prediction))
            if self.eval_class[i][0] == 1: #positive
                pos_count = pos_count + 1
                if prediction[0][0] > prediction[0][1]: #True Positive
                    true_pos_count = true_pos_count + 1
                else:
                    false_neg_count = false_neg_count + 1
            
            else: #negative
                neg_count = neg_count + 1
                if prediction[0][0] < prediction[0][1]: #True Negative
                    true_neg_count = true_neg_count + 1
                else:
                    false_pos_count = false_pos_count + 1
            
        result = '''True Positives: {0}, True Negatives: {1}, False Positives: {2}, False Negatives: {3}'''.format(true_pos_count, true_neg_count, false_pos_count, false_neg_count)
        print(result)
        
            
    
    def __init__(self):
        
        self.model = EEGModels.EEGNet(nb_classes = 2, Chans=16, Samples=128)
        self.model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics=['accuracy'])
        
        # MNE-specific information
        self.channel_names = ['F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
        self.channel_types = ['eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg','eeg']
        self.recorded_sample_freq = 512
        self.montage = 'standard_1005'
        self.target_epoch_names = ['t_01', 't_02', 't_03', 't_04']
        self.distract_epoch_names = ['d_01', 'd_02', 'd_03', 'd_04', 'd_05', 'd_06', 'd_07', 'd_08', 'd_09', 'd_10']
        self.event_id_dict = dict(t_04=0, t_03=1, t_02=2, t_01=3, d_10=4, d_09=5, d_08=6, d_07=7, d_06=8, d_05=9, d_04=10, d_03=11, d_02=12, d_01=13)
        self.epoch_tmin = 0
        self.epoch_tmax = 1
        self.resample_rate = 128 #desired sample freq in Hz for EEGNet input 
        
        self.casjobs_context = "MyDB"
                
        self.training_data = []
        self.training_class = []
        
        self.eval_data = []
        self.eval_class = []
        
        

In [26]:
EEGNetWrap = EEGNetWrapper()
for i in range(2, 19):  #Need to skip sessions 0, 1 as they are incomplete
    EEGNetWrap.loadTrainingData(i)  
    
EEGNetWrap.fit(training_iterations=30)
EEGNetWrap.loadEvaluationData(19)


Querying session 2 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=81968
    Range : 0 ... 81967 =      0.000 ...   160.092 secs
Ready.
320 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 320 events and 513 original time points ...
0 bad epochs dropped
Querying session 3 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=93728
    Range : 0 ... 93727 =      0.000 ...   183.061 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 400 events and 513 original time points ...
0 bad epochs dropped
Querying session 4 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=93856
    Range : 0 ... 93855 =      0.000 ...   183.311 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated

Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
Querying session 19 from CasJobs
Creating RawArray with float64 data, n_channels=16, n_times=92672
    Range : 0 ... 92671 =      0.000 ...   180.998 secs
Ready.
400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated
Loading data for 400 events and 513 original time points ...
0 bad epochs dropped


In [27]:
EEGNetWrap.evaluate()

True Class: [1 0]
Prediction: [[0.15294406 0.8470559 ]]
True Class: [1 0]
Prediction: [[0.5175245 0.4824755]]
True Class: [1 0]
Prediction: [[0.6166121 0.3833879]]
True Class: [1 0]
Prediction: [[0.28656048 0.71343946]]
True Class: [1 0]
Prediction: [[0.28004044 0.7199596 ]]
True Class: [1 0]
Prediction: [[0.9513215  0.04867853]]
True Class: [1 0]
Prediction: [[0.45058697 0.5494131 ]]
True Class: [1 0]
Prediction: [[0.92401445 0.07598552]]
True Class: [1 0]
Prediction: [[0.12993388 0.87006605]]
True Class: [1 0]
Prediction: [[0.9594573  0.04054275]]
True Class: [1 0]
Prediction: [[0.93278795 0.06721205]]
True Class: [1 0]
Prediction: [[0.67130375 0.32869628]]
True Class: [1 0]
Prediction: [[0.7053491  0.29465094]]
True Class: [1 0]
Prediction: [[0.86414844 0.13585153]]
True Class: [1 0]
Prediction: [[0.7823204  0.21767959]]
True Class: [1 0]
Prediction: [[0.4125423 0.5874577]]
True Class: [1 0]
Prediction: [[0.8773711  0.12262891]]
True Class: [1 0]
Prediction: [[0.78081906 0.21918094]

Prediction: [[6.2196283e-04 9.9937797e-01]]
True Class: [0 1]
Prediction: [[0.00733258 0.9926675 ]]
True Class: [0 1]
Prediction: [[1.8012139e-04 9.9981993e-01]]
True Class: [0 1]
Prediction: [[0.00141001 0.99858993]]
True Class: [0 1]
Prediction: [[8.9662225e-04 9.9910337e-01]]
True Class: [0 1]
Prediction: [[0.01936118 0.9806388 ]]
True Class: [0 1]
Prediction: [[0.03401517 0.96598476]]
True Class: [0 1]
Prediction: [[0.0108811  0.98911893]]
True Class: [0 1]
Prediction: [[0.00211077 0.9978892 ]]
True Class: [0 1]
Prediction: [[0.02877778 0.9712223 ]]
True Class: [0 1]
Prediction: [[0.00754818 0.99245185]]
True Class: [0 1]
Prediction: [[0.00148975 0.9985102 ]]
True Class: [0 1]
Prediction: [[0.01780681 0.98219323]]
True Class: [0 1]
Prediction: [[0.00284243 0.9971576 ]]
True Class: [0 1]
Prediction: [[0.01228696 0.987713  ]]
True Class: [0 1]
Prediction: [[0.00170033 0.99829966]]
True Class: [0 1]
Prediction: [[0.04167102 0.95832896]]
True Class: [0 1]
Prediction: [[0.00850801 0.991

Prediction: [[0.00413924 0.99586076]]
True Class: [0 1]
Prediction: [[0.00272848 0.9972715 ]]
True Class: [0 1]
Prediction: [[0.04098612 0.95901394]]
True Class: [0 1]
Prediction: [[0.00213579 0.9978642 ]]
True Class: [0 1]
Prediction: [[0.01140239 0.98859763]]
True Class: [0 1]
Prediction: [[6.439286e-04 9.993561e-01]]
True Class: [0 1]
Prediction: [[3.9264435e-04 9.9960738e-01]]
True Class: [0 1]
Prediction: [[0.02804354 0.97195643]]
True Class: [0 1]
Prediction: [[7.768476e-04 9.992231e-01]]
True Class: [0 1]
Prediction: [[0.08335675 0.9166432 ]]
True Class: [0 1]
Prediction: [[0.09159979 0.9084002 ]]
True Class: [0 1]
Prediction: [[0.08711988 0.9128801 ]]
True Class: [0 1]
Prediction: [[0.00323362 0.9967663 ]]
True Class: [0 1]
Prediction: [[0.01039461 0.9896054 ]]
True Class: [0 1]
Prediction: [[0.07163727 0.9283628 ]]
True Class: [0 1]
Prediction: [[0.02607418 0.9739258 ]]
True Class: [0 1]
Prediction: [[4.3344364e-04 9.9956650e-01]]
True Class: [0 1]
Prediction: [[0.00203722 0.9

In [29]:
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(EEGNetWrap.training_class),
                                                 EEGNetWrap.training_class)



TypeError: unhashable type: 'numpy.ndarray'

In [36]:
y_ints = [y.argmax() for y in EEGNetWrap.training_class]
class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(EEGNetWrap.training_class),
                                                 y_ints)
class_weights

array([5.93639576, 0.54598635])