In [1]:
import sys
import os.path
import statistics
from matplotlib import pyplot as plt
import pandas
from SciServer import CasJobs
import numpy as np
import mne

#BELOW is necessary since we are not currently running from project directory
#since we need to import libs from parent dir, need to add parent dir to path
project_path = '/home/idies/workspace/Storage/ncarey/persistent/PULSD/PsychoPy-pylsl-RSVP/'
if project_path not in sys.path:
    sys.path.append(project_path)

import importlib
EEGModels = importlib.import_module("arl-eegmodels.EEGModels")

In [2]:
# REMEMBER:     (2) 'image_data_format' = 'channels_first' in keras.json config
# If you get a negative dimension error or whatever, ensure the above and restart container


model = EEGModels.EEGNet(nb_classes = 2, Chans=16, Samples=128)
model.compile(loss = 'categorical_crossentropy', optimizer = 'adam')


In [3]:
#Data shape = (trials, kernels, channels, samples), which for the 
#        input layer, will be (trials, 1, channels, samples). 

#Lets get this data
session_ID=4
context="MyDB"

channel_names = ['F3', 'Fz', 'F4', 'T7', 'C3', 'Cz', 'C4', 'T8', 'Cp3', 'Cp4', 'P3', 'Pz', 'P4', 'PO7', 'PO8', 'Oz']
channel_types = ['eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg', 'eeg']
sfreq = 512
montage = 'standard_1005'
info = mne.create_info(channel_names, sfreq, channel_types, montage)
info['description'] = 'EEGNet test'

raw_query = "select * from session_eeg where session_ID = {0} order by timestamp".format(session_ID)
raw_df = CasJobs.executeQuery(sql=raw_query, context=context)

raw_data = []
for index in range(len(channel_names)):
    raw_data.append(raw_df[channel_names[index]].values)

custom_raw = mne.io.RawArray(raw_data, info)
print(custom_raw)


Creating RawArray with float64 data, n_channels=16, n_times=93856
    Range : 0 ... 93855 =      0.000 ...   183.311 secs
Ready.
<RawArray  |  None, n_channels x n_times : 16 x 93856 (183.3 sec), ~11.5 MB, data loaded>


In [4]:
#we do this query to get the data reading index at which the stims appear.  IE, instead of 
# saying stim X was presented at time Y (as it is in the raw data), we want to 
#say stim X appeared at data reading index Z
stim_index_query = '''
with stim_timestamps_index(index_value, timestamp) as (
select count(*), stim_timestamps.timestamp from session_eeg, stim_timestamps 
where session_eeg.session_ID = {0} and stim_timestamps.session_ID = {0} and session_eeg.timestamp < stim_timestamps.timestamp 
group by stim_timestamps.timestamp
)

select stim_timestamps_index.index_value, stim_timestamps.stim_ID from stim_timestamps_index, stim_timestamps 
where stim_timestamps.session_ID = {0} and stim_timestamps.timestamp = stim_timestamps_index.timestamp
  order by stim_timestamps_index.index_value
'''.format(session_ID)

stim_index_df = CasJobs.executeQuery(sql=stim_index_query, context=context)

In [5]:
stim_ind = stim_index_df['index_value'].values
stim_ID = stim_index_df['stim_ID'].values

events = []
for i in range(len(stim_ind)):
    events.append([stim_ind[i]+1, 0, stim_ID[i]])

In [6]:
event_id = dict(t_04=0, t_03=1, t_02=2, t_01=3, d_10=4, d_09=5, d_08=6, d_07=7, d_06=8, d_05=9, d_04=10, d_03=11, d_02=12, d_01=13)
epochs = mne.Epochs(raw=custom_raw, events=events, event_id=event_id, tmin=0, tmax=1)

400 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
0 projection items activated


In [7]:
#Now we load the epochs into their respective target and distractor arrays of epochs
# More importantly, we downsample to 128Hz, which is the input sampling rate EEGNet is setup for

t_epochs = epochs['t_01', 't_02', 't_03', 't_04']
t_epochs.load_data()
t_epochs_resampled = t_epochs.copy().resample(128, npad='auto')


d_epochs = epochs['d_01', 'd_02', 'd_03', 'd_04', 'd_05', 'd_06', 'd_07', 'd_08', 'd_09', 'd_10']
d_epochs.load_data()
d_epochs_resampled = d_epochs.copy().resample(128, npad='auto')



Loading data for 32 events and 513 original time points ...
0 bad epochs dropped
Loading data for 368 events and 513 original time points ...
0 bad epochs dropped


In [8]:
target_data = t_epochs_resampled.get_data()  #32 epochs of 16 channels x 128 readings
distract_data = d_epochs_resampled.get_data()


In [9]:
input_epochs = np.array(target_data[0], ndmin=4)
result = np.array([1,0], ndmin=2)

for i in range(1, len(target_data)):
    cur_epoch = np.array(target_data[i], ndmin=4)
    input_epochs = np.append(input_epochs, cur_epoch, axis=0)
    cur_result = np.array([1,0], ndmin=2)
    result = np.append(result, cur_result, axis=0)
#result = np.array([1,0])
#result = np.append(result, [0,1], axis=0)
result.shape
#input_epochs.shape


(32, 2)

In [10]:
fitted = model.fit(x=input_epochs, y=result, epochs=30)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
