In [1]:
pip install pyxdf


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyxdf
  Downloading pyxdf-1.16.3-py2.py3-none-any.whl (15 kB)
Installing collected packages: pyxdf
Successfully installed pyxdf-1.16.3


In [2]:
import pyxdf
import numpy as np
import matplotlib.pyplot as plt

In [3]:
no_of_files=12

all_xdfs = [pyxdf.load_xdf('/content/A'+str(f)+'.xdf')[0] for f in range(1,no_of_files+1)]

In [4]:
def xdf_extract(i,eventWindow,channel_length,eeg_channel,stim_channel):
    dataset_arr=np.empty((0,eventWindow,channel_length))
    start_time=float(i[eeg_channel]['footer']['info']['first_timestamp'][0])
    end_time=float(i[eeg_channel]['footer']['info']['last_timestamp'][0])
    no_samples=float(i[eeg_channel]['footer']['info']['sample_count'][0])
    Fs=i[eeg_channel]['info']['effective_srate']
    try:
        Ts=1/Fs
    except ZeroDivisionError:
        print('ZeroDivisionError')
        return
    time_stamps=(i[eeg_channel]['time_stamps']-start_time)/(end_time-start_time)*no_samples*Ts
    time_index=(i[eeg_channel]['time_stamps']-start_time)/(end_time-start_time)*no_samples
    first_marker_time=(float(i[stim_channel]['footer']['info']['first_timestamp'][0])-start_time)/(end_time-start_time)*no_samples*Ts
    last_marker_time=(float(i[stim_channel]['footer']['info']['last_timestamp'][0])-start_time)/(end_time-start_time)*no_samples*Ts
    markers=list()
    for j in i[stim_channel]['time_series']:
        markers.append(j[0])
    markers_time=list((i[stim_channel]['time_stamps']-start_time)/(end_time-start_time)*no_samples*Ts)
    markers_index=list((i[stim_channel]['time_stamps']-start_time)/(end_time-start_time)*no_samples)

    for e in range(1,len(markers)-1):
        try:
            data=(i[eeg_channel]['time_series'][(np.where(time_stamps>=markers_time[e])),:])[:,0:eventWindow,:]#[0:3,:]#[:,0:eventWindow,:].reshape(eventWindow,8).reshape(1,eventWindow,8)
            # print(data.shape)
            dataset_arr=np.vstack((dataset_arr,data))
        except ValueError as v:
            print(data.shape,'ValueError',v)
            return
    return dataset_arr,markers[1:-1]

In [5]:
eeg_data = None
eeg_markers = None
for i,xdf in enumerate(all_xdfs[1:]):
    if xdf != None:
        if i == 0:
            eeg = xdf_extract(xdf, 250, 8, 0, 1)
            eeg_data = eeg[0]
            eeg_markers = eeg[1]
        else:
            eeg = xdf_extract(xdf, 250, 8, 0, 1)
            eeg_data = np.vstack((eeg_data,eeg[0]))
            eeg_markers = eeg_markers + eeg[1]
eeg_data = np.transpose(eeg_data, [0,2,1])
print(eeg_data.shape, len(eeg_markers))

(3945, 8, 250) 3945


In [6]:
pip install pywavelets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [7]:
import pywt as pywt

In [8]:
data = np.asarray(eeg_data, dtype=object)
labels = np.asarray(eeg_markers, dtype=object)

In [9]:
dataNew = data.reshape(3945, 250, 8)

In [10]:
dataNew.shape

(3945, 250, 8)

In [11]:
scales = range(1,250)
waveletname = 'morl'
train_size = 2500
test_size= 400
n = 500

train_data_cwt = np.ndarray(shape=(train_size, 249, 249, 8))
f = pywt.scale2frequency(waveletname, scales)/250

for ii in range(0,train_size):
    if ii % 1000 == 0:
        print(ii)
    for jj in range(0,8):
        signal = dataNew[ii, :, jj]
        coeff, freq = pywt.cwt(signal, scales, waveletname, 1)
        coeff_ = coeff[:,:249]
        train_data_cwt[jj, :, :, jj] = coeff_

test_data_cwt = np.ndarray(shape=(test_size, 249, 249, 8))

for ii in range(0,test_size):
    if ii % 100 == 0:
        print(ii)
    for jj in range(0,8):
        signal = dataNew[ii, :, jj]
        coeff, freq = pywt.cwt(signal, scales, waveletname, 1)
        coeff_ = coeff[:,:249]
        test_data_cwt[ii, :, :, jj] = coeff_

labels_train = list(map(lambda x: int(x) - 1, np.ndarray(labels.shape[0])))
labels_test = list(map(lambda x: int(x) - 1, np.ndarray(labels.shape[0])))

x_train = train_data_cwt
y_train = list(labels_train[:train_size])
x_test = test_data_cwt
y_test = list(labels_test[:test_size])

0
1000
2000
0
100
200
300


In [12]:
import keras
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.models import Sequential
from keras.callbacks import History 
history = History()
 
img_x = 249
img_y = 249
img_z = 8
input_shape = (img_x, img_y, img_z)
 
# num_classes = 6
batch_size = 8
num_classes = 7
epochs = 3
 
x_train = x_train.astype('float32')
X_train = x_train[:-n]
X_val = x_train[-n:]

x_test = x_test.astype('float32')
 
y_train = keras.utils.to_categorical(y_train, num_classes)
Y_train = y_train[:-n]
Y_val = y_train[-n:]

y_test = keras.utils.to_categorical(y_test, num_classes)

In [13]:
model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5), strides=(1, 1),
                 activation='relu',
                 input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(10, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
 

In [15]:
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])
 
 
model.fit(X_train, Y_train,
          batch_size=batch_size,
          epochs=2,
          verbose=1,
          validation_data=(x_test, y_test),
          callbacks=[history])
 
train_score = model.evaluate(x_train, y_train, verbose=0)
print('Train loss: {}, Train accuracy: {}'.format(train_score[0], train_score[1]))
test_score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss: {}, Test accuracy: {}'.format(test_score[0], test_score[1]))

Epoch 1/2
Epoch 2/2
Train loss: 0.0, Train accuracy: 1.0
Test loss: 0.0, Test accuracy: 1.0


In [17]:
pred = model.predict(X_val)



In [16]:
from sklearn.metrics import accuracy_score

In [19]:
# np.argmax(pred, axis=1), np.argmax(Y_val, axis=1)
accuracy_score(np.argmax(pred, axis=1), np.argmax(Y_val, axis=1))



1.0

In [22]:
model.save('/content/keras_CNN_model')



In [29]:
!zip -r model.zip /content/keras_CNN_model

  adding: content/keras_CNN_model/ (stored 0%)
  adding: content/keras_CNN_model/keras_metadata.pb (deflated 90%)
  adding: content/keras_CNN_model/saved_model.pb (deflated 88%)
  adding: content/keras_CNN_model/assets/ (stored 0%)
  adding: content/keras_CNN_model/variables/ (stored 0%)
  adding: content/keras_CNN_model/variables/variables.index (deflated 64%)
  adding: content/keras_CNN_model/variables/variables.data-00000-of-00001 (deflated 69%)


In [21]:
import pickle

In [23]:
filename = 'keras_CNN_picklefile'
pickle.dump(model, open(filename, "wb"))

