# Implementation

## Setting up Tensorflow
Restart the runtime, set the hardware accelerator to TPU, and then run this cell. Colab's default is TensorFlow 2.x, whose API differs slightly; I prefer working with TensorFlow 1.x.

In [None]:
# Colab-only magic: select TensorFlow 1.x for this runtime. Per the note above,
# run it in a freshly restarted runtime, before tensorflow is imported.
%tensorflow_version 1.x
import tensorflow as tf
# Confirm which TensorFlow version was actually loaded.
print(tf.__version__)

TensorFlow 1.x selected.
1.15.2


## Loading Dataset
We load the CHB-MIT scalp EEG dataset from the PhysioNet databases [1]. PhysioNet lets us copy the dataset directly from its cloud bucket.
**Interrupt the copy as soon as the 24th subject folder has been downloaded, and try not to let free disk space drop below 20 GB.**

In [None]:
# !wget -r -N -c -np https://physionet.org/files/chbmit/1.0.0/
!gsutil -m cp -r gs://chbmit-1.0.0.physionet.org DESTINATION

Checking the contents of some files.

In [None]:
# Just to see if files got loaded: print one subject's summary file.
file = '/content/DESTINATION/chbmit-1.0.0.physionet.org/chb03/chb03-summary.txt'
# `with` closes the handle as soon as the contents have been read.
with open(file, 'r') as summary:
    file_contents = summary.read()
print(file_contents)

In [None]:
# only run this for the first time.
# Installs pyedflib, which provides the `highlevel.read_edf` EDF reader used below.
!pip install pyedflib

In [None]:
# Just to check if things are syncing in.
from pyedflib import highlevel
import numpy as np
import matplotlib.pyplot as plt

# read an edf file
file2 = '/content/DESTINATION/chbmit-1.0.0.physionet.org/chb02/chb02_16+.edf'
signals, signal_headers, header = highlevel.read_edf(file2)

# FFT of every channel over the whole recording (axis 1 runs over samples).
dft = np.fft.fft(signals, axis=1)

fig, (axRaw, axDft) = plt.subplots(1, 2, figsize=(12, 4))
axRaw.plot(signals[0:3, :1000].T)
axRaw.set(title='Raw signals', xlabel='Sample', ylabel='Amplitude')
# `dft` is complex; plotting it directly silently drops the imaginary part
# (numpy ComplexWarning), so show the magnitude spectrum instead.
axDft.plot(np.abs(dft[0:3, :1000]).T)
axDft.set(title='Fourier transform', xlabel='Frequency bin', ylabel='|FFT|')
plt.show()

## Preprocessing
Basic preprocessing: obtaining the signals in the frequency domain using the FFT, and arranging the data into labels and training/validation sets. Run all of these cells; the rest of the notebook depends on them.

### First we collect the paths of all ".edf" and ".txt" files in the directory and sort them.

In [None]:
import os

path = '/content/DESTINATION'

# Walk the whole download and bucket every file by its 4-character extension.
edfFiles, txtFiles = [], []
for root, _subdirs, names in os.walk(path):
    for name in names:
        extension = name[-4:]
        if extension == '.edf':
            edfFiles.append(os.path.join(root, name))
        elif extension == '.txt':
            txtFiles.append(os.path.join(root, name))

# os.walk gives no ordering guarantee; sort so subjects/recordings are deterministic.
edfFiles.sort()
txtFiles.sort()

# List every EDF path first, then every text path.
for listing in (edfFiles, txtFiles):
    for entry in listing:
        print(entry)

### Reading EDF & TXT files and stacking them in batches. 

In [None]:
from pyedflib import highlevel
import numpy as np
import matplotlib.pyplot as plt
import re
import matplotlib.pyplot as plt

def generateLabels(edfFileName, nSeconds = 3600):
  """Build per-second labels for one recording from its subject's summary file.

  The summary is expected next to the EDF file, as '<subject>-summary.txt'
  where <subject> is the first 5 characters of the EDF file name
  (e.g. '.../chb01/chb01_03.edf' -> '.../chb01/chb01-summary.txt').

  Label values per second:
    0 -- after the end of the last seizure (the whole file if it has none),
    1 -- before / between seizures,
    2 -- inside a seizure.

  Args:
    edfFileName: path to an EDF file listed in its summary as 'File Name: <name>.edf'.
    nSeconds: number of seconds to label (default 3600, i.e. a one-hour file).

  Returns:
    float ndarray of shape (nSeconds,).

  Raises:
    ValueError: if the file is not listed in the summary, or its entry has no
      'Number of Seizures' line.
  """
  import os

  folder = os.path.dirname(edfFileName)
  baseName = os.path.basename(edfFileName)
  subject = baseName[:5]
  summaryPath = os.path.join(folder, subject + '-summary.txt')
  with open(summaryPath, 'r') as summary:
    # strip() also tolerates trailing spaces / '\r' line endings.
    lines = [line.strip() for line in summary.read().split('\n')]

  ind = lines.index('File Name: ' + baseName)

  # Locate this entry's seizure-count line by searching (stopping at the next
  # entry) instead of assuming it is exactly 3 lines down; that fixed offset
  # only holds when the 'File Start/End Time' lines are present.
  countLine = None
  for i in range(ind + 1, len(lines)):
    if lines[i].startswith('File Name:'):
      break
    if lines[i].startswith('Number of Seizures'):
      countLine = i
      break
  if countLine is None:
    raise ValueError('no seizure count for ' + baseName + ' in ' + summaryPath)

  def lastInt(line):
    # The time/count is the LAST number on the line: 'Seizure 1 Start Time: 7804
    # seconds' contains the seizure index first, which [0] would wrongly pick.
    return int(re.findall(r'\d+', line)[-1])

  seizures = lastInt(lines[countLine])
  # Start/end lines alternate directly after the count line.
  start = [lastInt(lines[countLine + 2*i + 1]) for i in range(seizures)]
  end   = [lastInt(lines[countLine + 2*i + 2]) for i in range(seizures)]

  labels = np.zeros(nSeconds)
  if seizures:
    labels[:max(end)] = 1
    for s, e in zip(start, end):
      labels[s:e] = 2
  return labels


Partitioning the file list: the first two thirds go to training and the rest to validation (shuffling is currently disabled).

In [None]:
import random

# Split the sorted recording list: first two thirds -> training, rest -> validation.
# random.shuffle(edfFiles)   # shuffling currently disabled
totalData     = len(edfFiles)
partition     = totalData * 2 // 3
edfFilesTrain = edfFiles[:partition]
edfFilesVal   = edfFiles[partition:]
trainData, valData = len(edfFilesTrain), len(edfFilesVal)

print(totalData, trainData, valData)

Frequency Domain

In [None]:
from keras.utils import to_categorical

def stackDFTTrain(nbatch = 2):
  """Endlessly yield (spectra, one-hot labels) training batches from edfFilesTrain.

  Each accepted recording (23 channels, one hour, assuming 256 samples per
  second) is cut into 1200 consecutive 3-second windows shaped (23, 256, 3):
  channel x sample-within-second x second. Every 1-second chunk is replaced by
  the magnitude of its 256-point FFT. The per-second labels from generateLabels
  are reduced to one label per window by taking the maximum, so a window
  containing any seizure second is labelled 2.

  Args:
    nbatch: recordings' worth of windows per batch (nbatch * 1200 windows).

  Yields:
    (X, Y) with X of shape (nbatch*1200, 23, 256, 3) and Y of shape (nbatch*1200, 3).

  Raises:
    ValueError: if a full pass over edfFilesTrain finds no usable recording.
  """
  secondsPerFile = 3600                 # generateLabels labels exactly one hour
  samplesPerSecond = 256
  windowsPerFile = secondsPerFile // 3  # 1200 three-second windows
  batchRows = nbatch * windowsPerFile

  bufferX = np.zeros((0, 23, samplesPerSecond, 3))
  bufferY = np.zeros((0,))

  while True:
    accepted = 0
    for f in edfFilesTrain:
      signals, _, _ = highlevel.read_edf(f)
      # Labels only cover one hour, so only 23-channel one-hour files can be aligned.
      if signals.shape != (23, secondsPerFile * samplesPerSecond):
        continue
      accepted += 1

      # (23, samples) -> (23, window, second, sample) -> (window, 23, sample, second):
      # each window is 3 CONSECUTIVE seconds of every channel.
      windows = signals.reshape(23, windowsPerFile, 3, samplesPerSecond).transpose(1, 0, 3, 2)
      # Spectrum of each 1-second chunk (axis 2 = samples); the magnitude keeps
      # features real instead of having the imaginary part silently discarded.
      features = np.abs(np.fft.fft(windows, axis=2))
      labels = generateLabels(f).reshape(windowsPerFile, 3).max(axis=1)

      bufferX = np.concatenate((bufferX, features), axis=0)
      bufferY = np.concatenate((bufferY, labels), axis=0)

      # Emit every complete batch now buffered (long buffers are fully drained).
      while bufferX.shape[0] >= batchRows:
        yield (bufferX[:batchRows], to_categorical(bufferY[:batchRows], num_classes=3))
        bufferX = bufferX[batchRows:]
        bufferY = bufferY[batchRows:]

    if accepted == 0:
      raise ValueError('no usable 23-channel one-hour EDF recordings in edfFilesTrain')
    
    

def stackDFTVal(nbatch = 1):
  """Endlessly yield (spectra, one-hot labels) validation batches from edfFilesVal.

  Same windowing as the training generator: each accepted recording (23
  channels, one hour, assuming 256 samples per second) becomes 1200 consecutive
  3-second windows shaped (23, 256, 3). Every 1-second chunk is replaced by the
  magnitude of its 256-point FFT. Window labels are the maximum of the
  per-second labels from generateLabels.

  Args:
    nbatch: recordings' worth of windows per batch (nbatch * 1200 windows).

  Yields:
    (X, Y) with X of shape (nbatch*1200, 23, 256, 3) and Y of shape (nbatch*1200, 3).

  Raises:
    ValueError: if a full pass over edfFilesVal finds no usable recording.
  """
  secondsPerFile = 3600                 # generateLabels labels exactly one hour
  samplesPerSecond = 256
  windowsPerFile = secondsPerFile // 3  # 1200 three-second windows
  batchRows = nbatch * windowsPerFile

  bufferX = np.zeros((0, 23, samplesPerSecond, 3))
  bufferY = np.zeros((0,))

  while True:
    accepted = 0
    for f in edfFilesVal:
      signals, _, _ = highlevel.read_edf(f)
      # Labels only cover one hour, so only 23-channel one-hour files can be aligned.
      if signals.shape != (23, secondsPerFile * samplesPerSecond):
        continue
      accepted += 1

      # (23, samples) -> (23, window, second, sample) -> (window, 23, sample, second)
      windows = signals.reshape(23, windowsPerFile, 3, samplesPerSecond).transpose(1, 0, 3, 2)
      # Magnitude spectrum of each 1-second chunk (axis 2 = samples).
      features = np.abs(np.fft.fft(windows, axis=2))
      labels = generateLabels(f).reshape(windowsPerFile, 3).max(axis=1)

      bufferX = np.concatenate((bufferX, features), axis=0)
      bufferY = np.concatenate((bufferY, labels), axis=0)

      # Emit every complete batch now buffered.
      while bufferX.shape[0] >= batchRows:
        yield (bufferX[:batchRows], to_categorical(bufferY[:batchRows], num_classes=3))
        bufferX = bufferX[batchRows:]
        bufferY = bufferY[batchRows:]

    if accepted == 0:
      raise ValueError('no usable 23-channel one-hour EDF recordings in edfFilesVal')


 

# for h in stackDFTVal():
#   print(len(h), h[0].shape, h[1].shape)
#   break

In [None]:
# for i in range(73):
#   for h in stackDFTTrain():
#     print(len(h), h[0].shape, h[1].shape)

Time domain

In [None]:
from keras.utils import to_categorical

def stackTimeTrain(nbatch = 2):
  """Endlessly yield (raw windows, one-hot labels) training batches from edfFilesTrain.

  Each accepted recording (23 channels, one hour, assuming 256 samples per
  second) is cut into 1200 consecutive 3-second windows shaped (23, 256, 3):
  channel x sample-within-second x second, kept in the time domain. Window
  labels are the maximum of the per-second labels from generateLabels, so a
  window containing any seizure second is labelled 2.

  Args:
    nbatch: recordings' worth of windows per batch (nbatch * 1200 windows).

  Yields:
    (X, Y) with X of shape (nbatch*1200, 23, 256, 3) and Y of shape (nbatch*1200, 3).

  Raises:
    ValueError: if a full pass over edfFilesTrain finds no usable recording.
  """
  secondsPerFile = 3600                 # generateLabels labels exactly one hour
  samplesPerSecond = 256
  windowsPerFile = secondsPerFile // 3  # 1200 three-second windows
  batchRows = nbatch * windowsPerFile

  bufferX = np.zeros((0, 23, samplesPerSecond, 3))
  bufferY = np.zeros((0,))

  while True:
    accepted = 0
    for f in edfFilesTrain:
      signals, _, _ = highlevel.read_edf(f)
      # Labels only cover one hour, so only 23-channel one-hour files can be aligned.
      if signals.shape != (23, secondsPerFile * samplesPerSecond):
        continue
      accepted += 1

      # (23, samples) -> (23, window, second, sample) -> (window, 23, sample, second):
      # each window is 3 CONSECUTIVE seconds of every channel.
      windows = signals.reshape(23, windowsPerFile, 3, samplesPerSecond).transpose(1, 0, 3, 2)
      labels = generateLabels(f).reshape(windowsPerFile, 3).max(axis=1)

      bufferX = np.concatenate((bufferX, windows), axis=0)
      bufferY = np.concatenate((bufferY, labels), axis=0)

      # Emit every complete batch now buffered.
      while bufferX.shape[0] >= batchRows:
        yield (bufferX[:batchRows], to_categorical(bufferY[:batchRows], num_classes=3))
        bufferX = bufferX[batchRows:]
        bufferY = bufferY[batchRows:]

    if accepted == 0:
      raise ValueError('no usable 23-channel one-hour EDF recordings in edfFilesTrain')
    
    

def stackTimeVal(nbatch = 1):
  """Endlessly yield (raw windows, one-hot labels) validation batches from edfFilesVal.

  Same windowing as the time-domain training generator. The files come from
  edfFilesVal: iterating edfFilesTrain here would validate on training data.

  Args:
    nbatch: recordings' worth of windows per batch (nbatch * 1200 windows).

  Yields:
    (X, Y) with X of shape (nbatch*1200, 23, 256, 3) and Y of shape (nbatch*1200, 3).

  Raises:
    ValueError: if a full pass over edfFilesVal finds no usable recording.
  """
  secondsPerFile = 3600                 # generateLabels labels exactly one hour
  samplesPerSecond = 256
  windowsPerFile = secondsPerFile // 3  # 1200 three-second windows
  batchRows = nbatch * windowsPerFile

  bufferX = np.zeros((0, 23, samplesPerSecond, 3))
  bufferY = np.zeros((0,))

  while True:
    accepted = 0
    for f in edfFilesVal:
      signals, _, _ = highlevel.read_edf(f)
      # Labels only cover one hour, so only 23-channel one-hour files can be aligned.
      if signals.shape != (23, secondsPerFile * samplesPerSecond):
        continue
      accepted += 1

      # (23, samples) -> (23, window, second, sample) -> (window, 23, sample, second)
      windows = signals.reshape(23, windowsPerFile, 3, samplesPerSecond).transpose(1, 0, 3, 2)
      labels = generateLabels(f).reshape(windowsPerFile, 3).max(axis=1)

      bufferX = np.concatenate((bufferX, windows), axis=0)
      bufferY = np.concatenate((bufferY, labels), axis=0)

      # Emit every complete batch now buffered.
      while bufferX.shape[0] >= batchRows:
        yield (bufferX[:batchRows], to_categorical(bufferY[:batchRows], num_classes=3))
        bufferX = bufferX[batchRows:]
        bufferY = bufferY[batchRows:]

    if accepted == 0:
      raise ValueError('no usable 23-channel one-hour EDF recordings in edfFilesVal')

## Callback Class
Used to stop training once the training accuracy exceeds a given threshold.

In [None]:
from keras.callbacks import Callback

# when accuracy reaches ACCURACY_THRESHOLD
ACCURACY_THRESHOLD = 0.95

class myCallback(Callback):
  """Keras callback that stops training once accuracy exceeds ACCURACY_THRESHOLD."""

  def on_epoch_end(self, epoch, logs=None):
    # None default avoids a shared mutable dict; also tolerate Keras passing None.
    logs = logs or {}
    # metrics=['acc'] logs under 'acc'; fall back to 'accuracy' for the other naming.
    acc = logs.get('acc', logs.get('accuracy'))
    # Guard: comparing None > float raises TypeError when neither key is present.
    if acc is not None and acc > ACCURACY_THRESHOLD:
      print("\nReached %2.2f%% accuracy, so stopping training!!" %(ACCURACY_THRESHOLD*100))
      self.model.stop_training = True

# Instantiate a callback object
callbacks = myCallback()

## Building a CNN for Frequency domain


Here we import the required Keras layers and use them to build the architecture.

In [None]:
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten, Dropout, AveragePooling2D
from keras.utils import to_categorical

nBatch = 2  # NOTE(review): not used by this cell; kept so the global still exists

# Input: one 3-second window, 23 channels x 256 spectrum bins x 3 seconds.
in1 = Input(shape=(23, 256, 3))
c1 = Conv2D(16, (5,5), activation='relu')(in1)
m1 = AveragePooling2D()(c1)
fl = Flatten()(m1)
# The 3 classes are mutually exclusive and trained with categorical
# cross-entropy, so the output must be a softmax distribution (not sigmoid).
o = Dense(3, activation='softmax')(fl)

model = Model(inputs=in1, outputs=o)
# summary() prints the table itself and returns None, so don't wrap it in print().
model.summary()

Now that we have constructed our model, let's train it.

In [None]:
model.compile(optimizer = 'sgd', loss = 'categorical_crossentropy', metrics=['acc'])
# Steps per epoch. The training generator batches nbatch=2 recordings' worth of
# windows and the validation generator 1, so these divisors cover only a
# fraction of the files per epoch. max(1, ...) ensures Keras never receives
# 0 steps when there are few files.
testSteps = max(1, trainData // (16*8))   # used as steps_per_epoch
valSteps = max(1, valData // 16)
history_cnn = model.fit_generator(generator = stackDFTTrain(),
                                  steps_per_epoch = testSteps,
                                  epochs = 7,
                                  validation_data = stackDFTVal(),
                                  validation_steps = valSteps)
# To stop early at ACCURACY_THRESHOLD, also pass callbacks=[callbacks].

### Results

In [None]:
import matplotlib.pyplot as pyplot

loss = history_cnn.history['loss']
val_loss = history_cnn.history['val_loss']
epochs = range(1, len(loss) + 1)

fig, ax = pyplot.subplots()
ax.grid()
ax.plot(epochs, loss, '*y-', label='Training loss')
ax.plot(epochs, val_loss, '*r-', label='Validation loss')
ax.set(title='Training and validation loss', xlabel='Epochs', ylabel='Loss')
ax.legend()
pyplot.show()

acc = history_cnn.history['acc']
val_acc = history_cnn.history['val_acc']

fig, ax = pyplot.subplots()
ax.grid()
ax.plot(epochs, acc, '*y-', label='Training Accuracy')
ax.plot(epochs, val_acc, '*r-', label='Validation Accuracy')
# This panel shows accuracy, so the y axis is labelled accordingly (not 'Loss').
ax.set(title='Training and validation Accuracies', xlabel='Epochs', ylabel='Accuracy')
ax.legend()
pyplot.show()

In [None]:
# Raw per-epoch metrics recorded by Keras for the frequency-domain model.
dict(history_cnn.history)

## CNN for Time domain

Repeating the process for the time domain with the same parameters and architecture, except that the pooling layer uses max pooling instead of average pooling.

In [None]:
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten, Dropout, AveragePooling2D
from keras.utils import to_categorical

nBatch = 2  # NOTE(review): not used by this cell; kept so the global still exists

# Input: one 3-second window, 23 channels x 256 samples x 3 seconds.
in1 = Input(shape=(23, 256, 3))
c1 = Conv2D(16, (5,5), activation='relu')(in1)
m1 = MaxPooling2D()(c1)   # the frequency-domain model uses AveragePooling2D here
fl = Flatten()(m1)
# Mutually exclusive classes with categorical cross-entropy -> softmax output.
o = Dense(3, activation='softmax')(fl)

model2 = Model(inputs=in1, outputs=o)
# summary() prints the table itself and returns None, so don't wrap it in print().
model2.summary()

In [None]:
model2.compile(optimizer = 'sgd', loss = 'categorical_crossentropy', metrics=['acc'])
# Same step budget as the frequency-domain run (the section promises the same
# parameters): divisors 16*8 for training and 16 for validation.
# max(1, ...) ensures Keras never receives 0 steps when there are few files.
testSteps = max(1, trainData // (16*8))   # used as steps_per_epoch
valSteps = max(1, valData // 16)
history_cnn2 = model2.fit_generator(generator = stackTimeTrain(),
                                    steps_per_epoch = testSteps,
                                    epochs = 7,
                                    validation_data = stackTimeVal(),
                                    validation_steps = valSteps)
# To stop early at ACCURACY_THRESHOLD, also pass callbacks=[callbacks].

### Results

In [None]:
import matplotlib.pyplot as pyplot

loss2 = history_cnn2.history['loss']
val_loss2 = history_cnn2.history['val_loss']
epochs2 = range(1, len(loss2) + 1)

fig, ax = pyplot.subplots()
ax.grid()
ax.plot(epochs2, loss2, '*y-', label='Training loss')
ax.plot(epochs2, val_loss2, '*r-', label='Validation loss')
ax.set(title='Training and validation loss', xlabel='Epochs', ylabel='Loss')
ax.legend()
pyplot.show()

acc2 = history_cnn2.history['acc']
val_acc2 = history_cnn2.history['val_acc']

fig, ax = pyplot.subplots()
ax.grid()
ax.plot(epochs2, acc2, '*y-', label='Training Accuracy')
ax.plot(epochs2, val_acc2, '*r-', label='Validation Accuracy')
# This panel shows accuracy, so the y axis is labelled accordingly (not 'Loss').
ax.set(title='Training and validation Accuracies', xlabel='Epochs', ylabel='Accuracy')
ax.legend()
pyplot.show()

# Bibliography
[1] CHB-MIT Scalp EEG Database, Retrieved from: https://physionet.org/content/chbmit/1.0.0/
