In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total memory reported by psutil, expressed in decimal gigabytes (1e9 bytes).
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# Runtimes with at least 20 GB are reported as "high-RAM".
print('You are using a high-RAM runtime!' if ram_gb >= 20 else 'Not using a high-RAM runtime')

In [None]:
!pip install mne
!pip install tensorflow_addons

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler 
import mne
from mne import io
import tensorflow as tf
import keras
from keras.layers import *
from keras.models import *
from tensorflow.keras import regularizers
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
from keras.callbacks import CSVLogger
import tensorflow_addons as tfa
import tensorflow_addons
from keras import backend as K

In [None]:
# Band-pass filtering of the raw signal.
# The original header notes mention a 10/47 Hz band ("previous") and a 20/47 Hz band ("NEW");
# the code below applies 10-47 Hz -- TODO confirm which band is intended.

def filter_sequence(sequence, patient_data):
  """Band-pass filter one recording at the sampling rate of its patient.

  Background notes kept from the original author:
    * ISEK recommendations for surface EMG: high pass with 5 Hz cut off, low pass with 500 Hz cutoff.
      https://www1.udel.edu/biology/rosewc/kaap686/notes/EMG%20analysis.pdf
    * 30 Hz was reported as the optimal high-pass cutoff frequency in the paper
      "High-pass filtering to remove electrocardiographic interference from torso EMG recordings".

  Args:
    sequence: signal samples, passed to ``mne.filter.filter_data`` as ``data``.
    patient_data: recording file name; its first 8 characters (e.g. ``'train025'``)
      select the sampling rate.

  Returns:
    The output of ``mne.filter.filter_data`` (IIR band-pass, 10-47 Hz).

  Raises:
    ValueError: if the recording prefix is not in either sampling-rate table.
      (Previously this case fell through and failed with an UnboundLocalError.)
  """
  lower_sampling = ['train025', 'train127', 'train163', 'train198', 'train203']
  lower_freq = 250

  higher_sampling = ['train177', 'train226', 'train178', 'train307', 'train256', 'train275',
                     'train291', 'train276', 'train353', 'train349', 'train357']
  higher_freq = 256

  truncated_name = patient_data[:8]

  if truncated_name in lower_sampling:
    sampling_rate = lower_freq
  elif truncated_name in higher_sampling:
    sampling_rate = higher_freq
  else:
    raise ValueError(
        "Unknown sampling rate for recording '{}': prefix '{}' is not listed in "
        "lower_sampling or higher_sampling".format(patient_data, truncated_name))

  # iir_params=None lets MNE pick its default IIR design.
  filtered_sequence = mne.filter.filter_data(data = sequence, sfreq = sampling_rate, l_freq = 10, h_freq = 47,
                                             method='iir', iir_params = None, verbose = 0)

  return filtered_sequence


In [None]:
# Event-level evaluation of a prediction sequence:
#   FAR         - false alarms per hour: predicted seizure events that overlap no annotated seizure.
#   Sensitivity - fraction of annotated seizure events overlapped by at least one predicted event.
# (The function name says "Specificity" for historical reasons; the value returned is sensitivity.)

def calculate_FAR_Specificity(original_seq, predicted_seq, seconds_per_segment=4):
  """Compute the false alarm rate (per hour) and the event-level sensitivity.

  Events are runs of consecutive 1s as reported by ``find_seizures``. A predicted
  event matches an annotated one when their index ranges overlap at all.

  Args:
    original_seq: per-segment ground-truth labels (0/1).
    predicted_seq: per-segment predicted labels (0/1).
    seconds_per_segment: duration represented by one element of the sequences.
      The original author's table: 8 for 2000-sample segments, 4 for 1000-sample
      segments, 2 for 500-sample segments. The notebook slices 1000-sample
      windows, hence the default of 4 (the previous hard-coded value was 8).

  Returns:
    (far, sensitivity). ``sensitivity`` is 1 when there are no annotated
    seizures; ``far`` is 0.0 for an empty recording.
  """
  _, org_start, org_end = find_seizures(original_seq, len(original_seq))
  _, pred_start, pred_end = find_seizures(predicted_seq, len(predicted_seq))

  real_events = list(zip(org_start, org_end))
  predicted_events = list(zip(pred_start, pred_end))

  def _overlaps(first, second):
    # Two index ranges overlap when each one starts no later than the other ends.
    # This covers all four cases of the original condition (inside, enclosing,
    # partial overlap on either side).
    return first[0] <= second[1] and first[1] >= second[0]

  total_seizure_count_real = len(real_events)

  # Count each annotated seizure at most once, however many predictions hit it,
  # so the sensitivity can never exceed 1.
  total_seizure_count_predicted = sum(
      1 for real in real_events
      if any(_overlaps(pred, real) for pred in predicted_events))

  # A false alarm is a predicted event that touches no annotated seizure.
  far_count = sum(
      1 for pred in predicted_events
      if not any(_overlaps(pred, real) for real in real_events))

  print("Total seizures predicted:", total_seizure_count_predicted)
  print("Total seizures real:", total_seizure_count_real)

  if total_seizure_count_real == 0:
    sensitivity = 1
  else:
    sensitivity = total_seizure_count_predicted/total_seizure_count_real

  # Recording length in hours: segments * seconds per segment / 3600.
  seconds = len(original_seq) * seconds_per_segment
  hours = seconds/3600
  print("Recording length: {} hours".format(hours))

  # Guard against an empty recording instead of dividing by zero.
  far = far_count/hours if hours > 0 else 0.0

  print("The prediction sensitivity: ", sensitivity)
  print("The prediction False Alarm Rate FA/h: ", far )

  return far, sensitivity

In [None]:
def eliminate_isolates(sequence, threshold):
  """Zero out runs of consecutive 1s that are shorter than ``threshold``.

  Args:
    sequence: mutable 0/1 sequence (list or array). It is modified IN PLACE.
    threshold: minimum run length (in elements) a positive run needs to survive.

  Returns:
    The same ``sequence`` object, after pruning.
  """
  seq_length = len(sequence)
  seq_count, seq_start, seq_end = find_seizures(sequence, seq_length)
  final_seq = sequence  # alias, not a copy: callers see the change on their own object

  for run_length, run_start, run_end in zip(seq_count, seq_start, seq_end):
    if run_length < threshold:
      # Clamp the upper bound in case the reported end index of a run that
      # reaches the end of the sequence is one past the last element
      # (this used to raise IndexError).
      for position in range(run_start, min(run_end + 1, seq_length)):
        final_seq[position] = 0

  return final_seq

In [None]:
def predictions_post_processing(real, predictions, seizure_count, start_seizure, stop_seizure, max_distance_filter, threshold):
  """Merge nearby predicted events, then drop runs that are too short.

  Args:
    real: not used by this function.
    predictions: mutable 0/1 prediction sequence; it is modified in place.
    seizure_count, start_seizure, stop_seizure: run lengths, start indices and
      end indices of the predicted events (as returned by ``find_seizures``).
    max_distance_filter: two consecutive events are joined when the distance
      from the end of one to the start of the next is below this value.
    threshold: forwarded to ``eliminate_isolates`` as the minimum run length.

  Returns:
    The result of ``eliminate_isolates`` on the gap-filled predictions.
  """
  processed_predictions = predictions  # same object as the caller's sequence

  # Step 1: fill the gap between each pair of neighbouring events with 1s
  # when the events are closer than max_distance_filter.
  for event in range(len(seizure_count) - 1):
    gap_begin = stop_seizure[event]
    gap_finish = start_seizure[event + 1]
    if gap_finish - gap_begin >= max_distance_filter:
      continue
    for position in range(gap_begin, gap_finish):
      processed_predictions[position] = 1

  # Step 2: prune isolated positives that are shorter than the threshold.
  return eliminate_isolates(processed_predictions, threshold)

In [None]:
def pad_seizure_sites(sequence, pad):
  """Extend every run of 1s by ``pad`` elements on both sides.

  Padding is clamped to the bounds of the sequence, so events close to either
  edge are padded as far as possible. (Previously such events received no
  padding at all on that side, and the trailing pad was one element shorter
  than the leading pad.)

  Args:
    sequence: mutable 0/1 sequence (list or array). It is modified IN PLACE.
    pad: number of elements to set to 1 before and after each event.

  Returns:
    The same ``sequence`` object, after padding.
  """
  seq_length = len(sequence)
  # Event boundaries are taken from the sequence as it is before any padding.
  seq_count, seq_start, seq_end = find_seizures(sequence, seq_length)

  for run_start, run_end in zip(seq_start, seq_end):
    # Leading pad: the `pad` elements before the event, clamped at index 0.
    for position in range(max(run_start - pad, 0), run_start):
      sequence[position] = 1

    # Trailing pad: the `pad` elements after the event, clamped at the end.
    for position in range(run_end + 1, min(run_end + 1 + pad, seq_length)):
      sequence[position] = 1

  final_sequence = sequence

  return final_sequence

In [None]:
def find_seizures(binary_predictions, max_len):
  """Locate runs of consecutive 1s in a binary sequence.

  Args:
    binary_predictions: indexable sequence of labels; an element counts as
      "seizure" only when it equals 1.
    max_len: number of leading elements to scan (normally ``len(binary_predictions)``).

  Returns:
    (counts, index_at_seizure_start, index_at_seizure_end): three parallel
    lists with, for each run, its length, the index of its first element and
    the index of its last element (inclusive). A run that reaches the end of
    the scanned range ends at ``max_len - 1``. (Previously it was reported as
    ``max_len``, one past the last element and inconsistent with every other run.)
  """
  counts = []
  index_at_seizure_start = []
  index_at_seizure_end = []

  position = 0
  while position < max_len:
    if binary_predictions[position] != 1:
      position += 1
      continue

    # Extend the run while the next element is still a 1. Any other value
    # ends the run (the old loop never terminated on values other than 0/1).
    run_end = position
    while run_end + 1 < max_len and binary_predictions[run_end + 1] == 1:
      run_end += 1

    counts.append(run_end - position + 1)
    index_at_seizure_start.append(position)
    index_at_seizure_end.append(run_end)

    position = run_end + 1

  return counts, index_at_seizure_start, index_at_seizure_end

In [None]:
def vectorized_stride(array, clearing_time_index, max_time, sub_window_size,
                         stride_size):
    """Slice ``array`` into fixed-size windows using one fancy-indexing operation.

    Window ``k`` covers the indices
    ``first_index + k * stride_size ... first_index + k * stride_size + sub_window_size - 1``
    where ``first_index = clearing_time_index + 2 - sub_window_size`` and
    ``k * stride_size`` runs over ``0, stride_size, ...`` up to and including ``max_time``.

    Note: the indices are not bounds-checked here. A negative ``first_index``
    makes NumPy wrap around to the end of ``array``, and indices at or beyond
    ``len(array)`` raise IndexError.

    Returns:
        Array of shape ``(number_of_windows, sub_window_size)``.
    """
    first_index = clearing_time_index - sub_window_size + 2

    # Row vector: positions inside one window.
    within_window = np.arange(sub_window_size)[np.newaxis, :]
    # Column vector: offset of each window, i.e. [0, V, 2V, ...].
    window_offsets = np.arange(max_time + 1, step=stride_size)[:, np.newaxis]

    # Broadcasting (1, W) + (N, 1) yields the full (N, W) index matrix.
    return array[first_index + within_window + window_offsets]

In [None]:
def segmentor_test(subject):
  """Load the recordings of the held-out patient and cut them into test windows.

  The recordings are concatenated, standardized as one signal (zero mean, unit
  standard deviation) and sliced into non-overlapping 1000-sample windows.

  NOTE(review): unlike ``segmentor``, no band-pass filter is applied here (the
  ``filter_sequence`` call was commented out in the original) -- confirm that
  this train/test difference is intended.
  NOTE(review): because recordings are concatenated before slicing, a window
  may straddle the boundary between two recordings.

  Args:
    subject: list of recording file names of one patient.

  Returns:
    (x_test_sliced, y_data, y_test_sliced):
      x_test_sliced - windows of the standardized signal, one row per window;
      y_data        - int label per window, 1 if any sample in the window is labelled non-zero;
      y_test_sliced - the per-sample labels of each window.
  """
  window_size = 1000
  stride = 1000

  signal_parts = []
  label_parts = []
  for recording in subject:
    print("Current subject:", recording)
    subj_array = np.load('/content/drive/MyDrive/Colab Notebooks/patient_arrays/{}'.format(recording))
    # Column 0 holds the signal, column 1 the per-sample label.
    signal_parts.append(subj_array[:, 0])
    label_parts.append(subj_array[:, 1])

  # Concatenate once instead of growing the arrays inside the loop.
  testX = np.concatenate(signal_parts, axis = 0) if signal_parts else np.empty(shape=[0])
  testY = np.concatenate(label_parts, axis = 0) if label_parts else np.empty(shape=[0])

  test_data_std = (testX - testX.mean()) / (testX.std())

  # vectorized_stride places the first window at clearing_time_index + 2 - sub_window_size.
  # The previous value (1) put it at -997, so the first window wrapped around to
  # the END of the recording and the last one could run out of bounds. Start at
  # index 0 and stop at the last window that fits completely.
  first_window_anchor = window_size - 2
  last_window_start = np.size(test_data_std) - window_size

  x_test_sliced = vectorized_stride(test_data_std, first_window_anchor, max_time=last_window_start,
                                    sub_window_size=window_size, stride_size=stride)
  y_test_sliced = vectorized_stride(testY, first_window_anchor, max_time=last_window_start,
                                    sub_window_size=window_size, stride_size=stride)

  # Collapse each label window to a single value: 1 if the window contains any seizure sample.
  y_data = np.any(y_test_sliced, axis=1)

  return x_test_sliced, y_data.astype(int), y_test_sliced

In [None]:
def segmentor(sequence, name):
  """Filter, standardize and slice one training recording into overlapping windows.

  Args:
    sequence: 2-D array of one recording; column 0 is the signal and column 1
      the per-sample label.
    name: recording file name, used by ``filter_sequence`` to pick the sampling rate.

  Returns:
    (x_train_sliced, y_data, y_train_sliced):
      x_train_sliced - 1000-sample windows of the standardized signal, taken every 250 samples;
      y_data         - int label per window, 1 if any sample in the window is labelled non-zero;
      y_train_sliced - the per-sample labels of each window.
  """
  window_size = 1000
  stride = 250

  train_data_X = filter_sequence(sequence[:, 0], name)
  label_data = sequence[:, 1]

  # Standardize the recording to zero mean and unit standard deviation.
  train_data_std = (train_data_X - train_data_X.mean()) / (train_data_X.std())

  # vectorized_stride places the first window at clearing_time_index + 2 - sub_window_size.
  # The previous value (1) put it at -997, so the first windows wrapped around to
  # the END of the recording and the last one could run out of bounds. Start at
  # index 0 and stop at the last window that fits completely.
  first_window_anchor = window_size - 2
  last_window_start = np.size(train_data_std) - window_size

  x_train_sliced = vectorized_stride(train_data_std, first_window_anchor, max_time=last_window_start,
                                     sub_window_size=window_size, stride_size=stride)
  y_train_sliced = vectorized_stride(label_data, first_window_anchor, max_time=last_window_start,
                                     sub_window_size=window_size, stride_size=stride)

  # Collapse each label window to a single value: 1 if the window contains any seizure sample.
  y_data = np.any(y_train_sliced, axis=1)

  return x_train_sliced, y_data.astype(int), y_train_sliced

In [None]:
def train_set_prep(subject_list, ros_x=0, rus_x=0):
  """Build the training set from a list of recordings, optionally re-balancing classes.

  Each recording is segmented with ``segmentor``; the windows of all recordings
  are stacked and a trailing channel axis is added for the Conv1D input.

  Args:
    subject_list: recording file names to load from the patient_arrays folder.
    ros_x: over-sampling ratio (minority / majority after over-sampling),
      forwarded to ``RandomOverSampler(sampling_strategy=...)``. 0 skips the step.
    rus_x: under-sampling ratio (minority / majority after under-sampling),
      forwarded to ``RandomUnderSampler(sampling_strategy=...)``. 0 skips the step.

    The ratios used to be hard-coded to 0. imbalanced-learn documents a float
    ``sampling_strategy`` as a ratio in (0, 1], so passing 0 is expected to make
    ``fit_resample`` fail; a ratio of 0 is now treated as "do not resample".
    TODO: set the ratios that were used for the reported experiments.

  Returns:
    (X, Y): X with shape (n_windows, 1000, 1) and Y with one label per window.
  """
  # Samplers are only created for the steps that are enabled.
  ros = RandomOverSampler(sampling_strategy=ros_x, random_state=None) if ros_x else None
  rus = RandomUnderSampler(sampling_strategy=rus_x, random_state=None) if rus_x else None

  # The empty placeholders keep the shape and dtype of the result stable even
  # when subject_list is empty.
  x_parts = [np.empty(shape=[0, 1000])]
  y_parts = [np.empty(shape=[0])]

  for recording in subject_list:
    subj_array = np.load('/content/drive/MyDrive/Colab Notebooks/patient_arrays/{}'.format(recording))
    print("Current subject:", recording)
    x_subject, y_subject, _ = segmentor(subj_array, recording)

    # Re-balance per recording: over-sample first, then under-sample.
    if ros is not None:
      x_subject, y_subject = ros.fit_resample(x_subject, y_subject)
    if rus is not None:
      x_subject, y_subject = rus.fit_resample(x_subject, y_subject)

    x_parts.append(x_subject)
    y_parts.append(y_subject)

  # Concatenate once instead of growing the arrays inside the loop.
  X = np.concatenate(x_parts, axis = 0)
  Y = np.concatenate(y_parts, axis = 0)

  # Check sample balance
  print("label counts seizure:", np.count_nonzero(Y == 1))      
  print("label counts non-seizure:", np.count_nonzero(Y == 0))      

  rows_train, columns_train = np.shape(X)
  X = X.reshape(rows_train, columns_train, 1)

  return X, Y


In [None]:
# Recordings grouped by patient: one list per patient, one entry per recording file.
a = ['train025_r2.npy']
b = ['train163_r1.npy']
c = ['train177_r4.npy']
d = ['train178_r4.npy']
e = ['train198_r2.npy']
f = ['train203_r8.npy']
g = ['train226_r6.npy']
h = ['train256_r12.npy', 'train256_r14.npy']
i = ['train275_r29.npy', 'train275_r31.npy']
j = ['train276_r5.npy', 'train276_r8.npy', 'train276_r39.npy']
k = [
    'train291_r15.npy', 'train291_r16.npy', 'train291_r21.npy',
    'train291_r23.npy', 'train291_r25.npy', 'train291_r26.npy',
]
l = ['train307_r1.npy']
m = ['train349_r3.npy']
n = ['train353_r1.npy', 'train353_r6.npy']
o = ['train357_r45.npy', 'train357_r58.npy']

# One entry per patient; each entry is the held-out set of one cross-validation fold.
all_targets = [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o]

# Flat list of every recording, in patient order.
all_subjects = [recording for patient_recordings in all_targets for recording in patient_recordings]

In [None]:
# Preview the leave-one-patient-out splits: for each patient, list the recordings
# that remain for training once that patient's recordings are held out.
# The training list is built without mutating any shared list, and the loop
# variable no longer overwrites the patient lists named `i` and `j`.
for current_target in all_targets:
  held_out = set(current_target)
  training_subjects = [recording
                       for patient_recordings in all_targets
                       for recording in patient_recordings
                       if recording not in held_out]

  print('training subjects: ', training_subjects)
  

training subjects:  ['train163_r1.npy', 'train177_r4.npy', 'train178_r4.npy', 'train198_r2.npy', 'train203_r8.npy', 'train226_r6.npy', 'train256_r12.npy', 'train256_r14.npy', 'train275_r29.npy', 'train275_r31.npy', 'train276_r5.npy', 'train276_r8.npy', 'train276_r39.npy', 'train291_r15.npy', 'train291_r16.npy', 'train291_r21.npy', 'train291_r23.npy', 'train291_r25.npy', 'train291_r26.npy', 'train307_r1.npy', 'train349_r3.npy', 'train353_r1.npy', 'train353_r6.npy', 'train357_r45.npy', 'train357_r58.npy']
training subjects:  ['train025_r2.npy', 'train177_r4.npy', 'train178_r4.npy', 'train198_r2.npy', 'train203_r8.npy', 'train226_r6.npy', 'train256_r12.npy', 'train256_r14.npy', 'train275_r29.npy', 'train275_r31.npy', 'train276_r5.npy', 'train276_r8.npy', 'train276_r39.npy', 'train291_r15.npy', 'train291_r16.npy', 'train291_r21.npy', 'train291_r23.npy', 'train291_r25.npy', 'train291_r26.npy', 'train307_r1.npy', 'train349_r3.npy', 'train353_r1.npy', 'train353_r6.npy', 'train357_r45.npy', 't

Code


Architecture assembly

In [None]:
# Focal loss from TensorFlow Addons with its default parameters.
# NOTE(review): this module-level `fl` is not referenced by the cells visible here
# (focal_loss_custom builds its own instance) -- confirm whether it is still needed.
fl = tfa.losses.SigmoidFocalCrossEntropy()

In [None]:
def focal_loss_custom(alpha, gamma):
   """Create a Keras-compatible focal loss with the given ``alpha`` and ``gamma``.

   Args:
      alpha: forwarded to ``tfa.losses.SigmoidFocalCrossEntropy(alpha=...)``.
      gamma: forwarded to ``tfa.losses.SigmoidFocalCrossEntropy(gamma=...)``.

   Returns:
      A function ``binary_focal_loss(y_true, y_pred)`` suitable for ``model.compile(loss=...)``.
   """
   def binary_focal_loss(y_true, y_pred):
      # The unused K.ones_like(y_true) tensor that was created here on every
      # call has been removed; it never took part in the loss.
      fl = tfa.losses.SigmoidFocalCrossEntropy(alpha=alpha, gamma=gamma)
      return fl(y_true, y_pred)
   return binary_focal_loss

In [None]:
# The Hybrid 1D CNN-LSTM Model
# Input: windows of 1000 samples with a single channel.
# Three Conv1D -> BatchNorm -> ReLU -> MaxPool blocks extract local features,
# two stacked LSTM layers summarize the resulting sequence, and a dense head
# outputs one sigmoid probability per window.
model = keras.models.Sequential([
    # Convolution block 1
    Conv1D(filters=32, kernel_size=40, strides=1, input_shape=(1000, 1)),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling1D(pool_size=5, strides=5),

    # Convolution block 2
    Conv1D(filters=32, kernel_size=20, strides=1),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling1D(pool_size=5, strides=5),

    # Convolution block 3. The convolution is linear like in the other blocks:
    # it used to carry activation='relu' as well, which applied ReLU twice
    # (before and after the batch normalization).
    Conv1D(filters=64, kernel_size=5, strides=1),
    BatchNormalization(),
    Activation('relu'),
    MaxPooling1D(pool_size=5, strides=1),

    # Recurrent part: the first LSTM returns the full sequence for the second one.
    LSTM(64, return_sequences=True),
    LSTM(64),

    # Classification head
    Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
    Dense(1, activation='sigmoid'),
])

# Keras documents `metrics` as a list, so the metric name is wrapped in one.
model.compile(optimizer='adam', loss=focal_loss_custom(alpha=0.2, gamma=2.0), metrics=["binary_accuracy"])

In [None]:
# Print the layer-by-layer overview of the model defined above.
model.summary()

In [None]:
# Recordings grouped by patient: one list per patient, one entry per recording file.
a = ['train025_r2.npy'] # [X]
b = ['train163_r1.npy']
c = ['train177_r4.npy']
d = ['train178_r4.npy']
e = ['train198_r2.npy']
f = ['train203_r8.npy']
g = ['train226_r6.npy']
h = ['train256_r12.npy', 'train256_r14.npy']
i = ['train275_r29.npy', 'train275_r31.npy']
j = ['train276_r5.npy', 'train276_r8.npy', 'train276_r39.npy']
k = [
    'train291_r15.npy', 'train291_r16.npy', 'train291_r21.npy',
    'train291_r23.npy', 'train291_r25.npy', 'train291_r26.npy',
]
l = ['train307_r1.npy']
m = ['train349_r3.npy']
n = ['train353_r1.npy', 'train353_r6.npy']
o = ['train357_r45.npy', 'train357_r58.npy']

# One entry per patient; each entry is the held-out set of one cross-validation fold.
all_targets = [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o]

# Flat list of every recording, in patient order.
all_subjects = [recording for patient_recordings in all_targets for recording in patient_recordings]

In [None]:
# Preview the leave-one-patient-out splits: for each patient, list the recordings
# that remain for training once that patient's recordings are held out.
# The training list is built without mutating any shared list, and the loop
# variable no longer overwrites the patient lists named `i` and `j`.
for current_target in all_targets:
  held_out = set(current_target)
  training_subjects = [recording
                       for patient_recordings in all_targets
                       for recording in patient_recordings
                       if recording not in held_out]

  print('training subjects: ', training_subjects)
  

Leave-one-patient-out cross-validation scheme. In each iteration, one patient's recordings are held out and the remaining recordings are used to train the model; the held-out recordings serve as the validation set.

In [None]:
# Leave-one-patient-out training and evaluation.
#
# Changes compared with the previous version of this cell:
#   * every fold trains a freshly initialized copy of the network; before, the
#     same `model` object was fitted in every fold, so its weights had already
#     been trained on the patient that is held out in the current fold;
#   * the training list is built without mutating a shared list;
#   * loop variables and constants no longer overwrite the patient lists named
#     `i`, `j`, `m` and `n`;
#   * each fold draws into its own figure.
#
# NOTE(review): the held-out patient is also used as validation data for early
# stopping and checkpoint selection, so the reported scores are optimistic.
# NOTE(review): EarlyStopping uses patience=50 with epochs=50 and can therefore
# never stop a run early.

# Loss the template model was compiled with; reused for every fold.
loocv_loss = model.loss

for fold_index, current_target in enumerate(all_targets):

  held_out = set(current_target)
  training_subjects = [recording
                       for patient_recordings in all_targets
                       for recording in patient_recordings
                       if recording not in held_out]

  print("Current target {}: {} ".format(fold_index, current_target))
  print("Training subjects {}: {}".format(fold_index, training_subjects))

  # Held-out patient: windows plus a trailing channel axis for the Conv1D input.
  test_X, test_Y, _, = segmentor_test(current_target)
  rows_test, columns_test = np.shape(test_X)
  test_X = test_X.reshape(rows_test, columns_test, 1)

  train_X, train_Y = train_set_prep(training_subjects)

  # File names look like 'train025_r2.npy'; characters 5..7 are the patient number.
  patient_number = current_target[0][5:8]
  # 227 repeat 3 SEEMS PROMISING!
  # 227 repeat 6_larger -- detected but high false rate.
  model_name = 'patient'+patient_number+'_227_REPEAT_9_larger'

  print("------------------------------------------------{}------------------------------------------------".format(model_name))

  # The previous slicing dropped the last digit for single-digit recording numbers.
  patient_ID = 'Patient '+patient_number

  # Per-fold training log.
  log_path_name = model_name+'_training.log'
  csv_logger = CSVLogger('/content/drive/MyDrive/Colab Notebooks/patient_MODELS_226/GPU_Traininglogs/{}'.format(log_path_name), separator=',', append=False)

  optimal_checkpoint_model = model_name+'optimal_checkpoint'
  es = EarlyStopping(monitor='val_binary_accuracy', mode='max', verbose=1, patience=50)

  # Keep the best model by validation accuracy and, separately, by validation loss.
  mc = ModelCheckpoint('/content/drive/MyDrive/Colab Notebooks/patient_MODELS_226/GPU_Checkpoints/{}'.format(optimal_checkpoint_model),
                      monitor='val_binary_accuracy', mode='max', verbose=1, save_best_only=True)

  mc2 = ModelCheckpoint('/content/drive/MyDrive/Colab Notebooks/patient_MODELS_226/GPU_Checkpoints_loss/{}'.format(optimal_checkpoint_model),
                      monitor='val_loss', mode='min', verbose=1, save_best_only=True)

  # clone_model rebuilds the architecture with newly created (untrained) layers,
  # so nothing learned in an earlier fold carries over. `model` is rebound so
  # that later cells still find the most recently trained network under that name.
  model = keras.models.clone_model(model)
  model.compile(optimizer='adam', loss=loocv_loss, metrics=["binary_accuracy"])

  history = model.fit(train_X, train_Y, epochs = 50, batch_size = 2048, validation_data=(test_X, test_Y), callbacks=[csv_logger, es, mc, mc2], verbose = 1)
  print("MODEL NAME: ", model_name)
  model.save('/content/drive/MyDrive/Colab Notebooks/patient_MODELS_226/{}'.format(model_name))

  # Raw probabilities and their thresholded version (>= 0.8 counts as seizure).
  test_y_pred = model.predict(test_X)
  predictions = np.where(np.array(test_y_pred) >= 0.8,1, 0).tolist()

  plt.figure()
  plt.plot(test_y_pred)
  plt.plot(predictions)
  plt.title(model_name)
  plt.show()

  # model.predict returns one row per window; flatten to a plain list of ints.
  flat_predictions = [int(item) for sublist in predictions for item in sublist]

  print("INITIAL FAR, SPEC,... ")
  far, spec = calculate_FAR_Specificity(test_Y, flat_predictions)

  print("Post_processed FAR, SPEC, ... ")
  list_len = len(flat_predictions)
  counts, start, end = find_seizures(flat_predictions, list_len)
  count_real, start_real, end_real = find_seizures(test_Y, list_len)
  # Join events closer than 5 segments, then drop events shorter than 2 segments.
  # Note: this modifies flat_predictions in place.
  final_pred= predictions_post_processing(test_Y, flat_predictions, counts, start, end, 5, 2)
  count_final, start_final, end_final = find_seizures(final_pred, list_len)
  far, spec = calculate_FAR_Specificity(test_Y, final_pred)

  print("------------------------------------------END CYCLE--------------------------------------------------------")


In [None]:
# Cleans the seizure sites based on observations in the recordings: the annotations were made on the EEG and do not exactly match the onset of the seizure in the EMG
# FUNCTION NOT USED IN FINAL MODEL TRAINING
def clean_seizure_sites(sequenceX, labelsY, patient_rec):

  length_rec = len(sequenceX)
  length_labels = len(labelsY)
  print(length_rec)

  if length_rec == length_labels:
    pass
    # print('MATCHING SEQ LENGTH CHECK GOOD')
  else:
    pass
    # print('SEQUENC LEN NOT MATCHING')


  if patient_rec == 'train025_r2.npy':
    TRIM = 11000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)
        
        break


  elif patient_rec == 'train163_r1.npy':
    # patient 163_r1
    # TRIM 7500
    TRIM = 7500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] == 0 and labelsY[i] == 1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break
    
  elif patient_rec == 'train177_r4.npy':
    # patient 177_r4
    # TRIM 20000
    TRIM = 20000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break
  
  elif patient_rec == 'train178_r4.npy':

    # patient 178_r4 
    # TRIM 8000
    TRIM = 8000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break
  
  elif patient_rec == 'train198_r2.npy':
    # patient 198_r2
    # TRIM 7500
    TRIM = 7500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train203_r8.npy':
    # patient 203_r8
    # TRIM 5000 from both 1 and 2
    TRIM = 5000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train226_r6.npy':
    # patient 226_r6
    # TRIM 1000
    TRIM = 1000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train256_r12.npy':
    # patient 256_r12
    # TRIM 7500
    TRIM = 7500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train256_r14.npy':
    # patient 256_r14
    # TRIM 7500
    TRIM = 7500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train275_r29.npy':
    # patient 275_r29 --> INCLUDE
    # TRIM = 5000
    TRIM = 5000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train275_r31.npy':
    # patient 275_r31 --> INCLUDE
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break



  elif patient_rec == 'train276_r5.npy':
    # patietn 276_r5
    # TRIM 25000
    TRIM = 25000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train276_r8.npy':
    # patient 276_r8
    # TRIM 10000
    TRIM = 10000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train276_r39.npy':

    # patient 276_r39
    # TRIM 3000
    TRIM = 3000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train291_r15.npy':
    # patient 291_r15
    # TRIM 14000
    TRIM = 14000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break
  
  elif patient_rec == 'train291_r16.npy':
    # patient 291_r16
    # TRIM 12000
    TRIM = 12000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train291_r21.npy':
    # patient 291_r21
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train291_r23.npy':

    # patient 291_r23
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train291_r25.npy':
    # patient 291_r25
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train291_r26.npy':
    # patient 291_r26
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train307_r1.npy':
    # patient 307_r1
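    # TRIM 6000 first
    # TRIM 2500 second
    # NOTE(review): only the 6000-sample trim is applied below; the loop breaks
    # after the first 0 -> 1 label transition, so the 2500 "second" trim is never
    # performed here -- confirm whether that is intended.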
    TRIM = 6000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train349_r3.npy':
    # patient 349_r3
    # TRIM 7500
    TRIM = 7500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train353_r1.npy':
    # patient 353_r1
    # TRIM 12500
    TRIM = 12500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train353_r6.npy':

    # patient 353_r6
    # TRIM 6000
    TRIM = 6000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break


  elif patient_rec == 'train357_r45.npy':
    # patient 357_r45
    # TRIM 2500
    TRIM = 2500
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break

  elif patient_rec == 'train357_r58.npy':

    # patient 357_r58
    # TRIM 5000
    TRIM = 5000
    for i in range(0, length_labels):
      if labelsY[i] == 0:
        pass
      elif labelsY[i-1] ==0 and labelsY[i] ==1:
        print('Removing the first {} sample points from the seizure site'.format(TRIM))
        start_delete_row = i
        end_delete_row = i+TRIM
        deletion_range = np.arange(start_delete_row, end_delete_row, 1)
        labelsY = np.delete(labelsY, deletion_range, axis = 0)
        sequenceX = np.delete(sequenceX, deletion_range, axis = 0)

        break
  
  return sequenceX, labelsY