In [1]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pyedflib
import mne

import random 
import re
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping


# Root of the local TUH seizure database (v1.5.1): the `_DOCS/` spreadsheets
# and `edf/` recordings used below are read from under this directory.
DB_PATH = './db/v1.5.1/'

# Seed every RNG the notebook draws from (window sampling uses `random`,
# file shuffling uses numpy, data shuffling / weight init use TensorFlow) so
# re-running the notebook draws the same random numbers.
# NOTE: iteration order of `list(set(...))` over strings additionally depends
# on PYTHONHASHSEED and is not controlled by these seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
# Lookup table of TUH event class codes, indexed by the class code
seiz_types_path = DB_PATH + '_DOCS/seizures_types_v02.xlsx'
seiz_types = (
    pd.read_excel(seiz_types_path)
    .set_index('Class Code')
)
display(seiz_types)

Unnamed: 0_level_0,Class No.,Event Name,Signs,Locality,Description
Class Code,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
,0,No Event,,,An unclassified event
SPSW,1,Spike/Sharp and Wave,Electrographic,All,"Spike and wave/complexes , sharp and wave/comp..."
GPED,2,Generalized Periodic Epileptiform Discharges,Electrographic,Generalized,Diffused periodic discharges
PLED,3,Periodic Lateralized Epileptiform Discharges,Electrographic,Hemispheric/Focal,Focal periodic discharges
EYBL,4,Eye blink,Clinical & Electrographic,Focal,"A specific type of sharp, high amplitude eye m..."
ARTF,5,Artifacts (All),Clinical & Electrographic,All,"Any non-brain activity electrical signal, such..."
BCKG,6,Background,Electrographic,,Baseline/non-interesting events
SEIZ,7,Seizure,Clinical &| Electrographic,All,Common seizure class which can include all typ...
FNSZ,8,Focal Non-Specific Seizure,Electrographic,Hemispheric/Focal,Focal seizures which cannot be specified with ...
GNSZ,9,Generalized Non-Specific Seizure,Electrographic,Generalized,Generalized seizures which cannot be further c...


In [3]:
seiz_info_path = DB_PATH + '_DOCS/seizures_v34r.xlsx'
train_info = pd.read_excel(seiz_info_path, 'train')

# Seizure-type frequency table embedded in the 'train' sheet
# (positional rows 1-11, columns 26-29), annotated with the type descriptions.
train_seiz_type = (
    train_info.iloc[1:12, 26:30]
    .set_axis(['Class Code', 'Events', 'Freq.', 'Cum.'], axis=1)
    .set_index('Class Code')
)
train_seiz_type.join(seiz_types)

Unnamed: 0_level_0,Events,Freq.,Cum.,Class No.,Event Name,Signs,Locality,Description
Class Code,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
FNSZ,1536,0.648101,0.648101,8.0,Focal Non-Specific Seizure,Electrographic,Hemispheric/Focal,Focal seizures which cannot be specified with ...
GNSZ,408,0.172152,0.820253,9.0,Generalized Non-Specific Seizure,Electrographic,Generalized,Generalized seizures which cannot be further c...
SPSZ,49,0.0206751,0.840928,10.0,Simple Partial Seizure,Clinical & Electrographic,All,Partial seizures during consciousness; Type sp...
CPSZ,277,0.116878,0.957806,11.0,Complex Partial Seizure,Clinical & Electrographic,All,Partial Seizures during unconsciousness; Type ...
ABSZ,50,0.021097,0.978903,12.0,Absence Seizure,Clinical & Electrographic,Generalized,Absence Discharges observed on EEG; patient lo...
TNSZ,18,0.00759494,0.986498,13.0,Tonic Seizure,Clinical & Electrographic,All,Stiffening of body during seizure (EEG effects...
CNSZ,0,0.0,0.986498,14.0,Clonic Seizure,Clinical & Electrographic,All,Jerking/shivering of body during seizure
TCSZ,30,0.0126582,0.999156,15.0,Tonic Clonic Seizure,Clinical & Electrographic,All,At first stiffening and then jerking of body (...
ATSZ,0,0.0,0.999156,16.0,Atonic Seizure,Clinical & Electrographic,,Sudden loss of muscle tone
MYSZ,2,0.000843882,1.0,17.0,Myoclonic Seizure,Clinical & Electrographic,,Myoclonous jerks of limbs


In [4]:
# ----------------
# Descriptive Keys
# ----------------
# Key table embedded in the 'train' sheet (positional rows 24-42, cols 16-20);
# the 'Rooms' and spare 'REMOVE' columns are not needed.
train_type_key = (
    train_info.iloc[24:43, 16:21]
    .set_axis(['EEG Type', 'EEG SubType', 'Rooms', 'REMOVE', 'Description'], axis=1)
    .drop(columns=['Rooms', 'REMOVE'])
)
# a blank 'EEG Type' cell means "same type as the row above"
train_type_key['EEG Type'] = train_type_key['EEG Type'].ffill()
train_type_key = train_type_key.set_index('EEG Type')

# ------------
# Type Summary
# ------------
train_type_summary = (
    train_info.iloc[1:7, 16:20]
    .set_axis(['EEG Type', 'Sessions', 'Freq.', 'Cum.'], axis=1)
    .set_index('EEG Type')
)

# key rows with any missing field (apparently the per-type description rows),
# minus the last one, supply each type's 'Description'
desc = train_type_key[train_type_key.isnull().any(axis=1)].iloc[:-1]
train_type_summary = train_type_summary.join(desc).drop(columns='EEG SubType')

train_type_summary[['Description', 'Sessions', 'Freq.', 'Cum.']]

Unnamed: 0_level_0,Description,Sessions,Freq.,Cum.
EEG Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
EMU,,162,0.136709,0.136709
ICU,Intensive Care Unit,438,0.36962,0.506329
Inpatient,Inpatient But Not ICU,350,0.295359,0.801688
Outpatient,Routine EEGs,193,0.162869,0.964557
Unknown,EEG Report Is Not Informative,42,0.035443,1.0
Total:,,1185,1.0,


In [5]:
# Per-file table from the 'train' sheet. `.copy()` makes it an independent
# frame, so the column assignments below do not operate on a slice of
# `train_info` (which raises pandas' SettingWithCopyWarning).
file_info = train_info.iloc[1:6101, 1:15].copy()
# cleaner column names
file_info.columns = ['File No.', 'Patient', 'Session', 'File',
                     'EEG Type', 'EEG SubType', 'LTM or Routine',
                     'Normal/Abnormal', 'No. Seizures File',
                     'No. Seizures/Session', 'Filename', 'Seizure Start',
                     'Seizure Stop', 'Seizure Type']

# The spreadsheet leaves a cell blank when it repeats the value above, so
# forward-fill every column except the last four (filename, seizure start,
# seizure stop and seizure type), which are genuinely per-row.
for col_name in file_info.columns[:-4]:
    file_info[col_name] = file_info[col_name].ffill()

# patient ID is an integer rather than float
file_info['Patient'] = file_info['Patient'].astype(int)

file_info.head()

Unnamed: 0,File No.,Patient,Session,File,EEG Type,EEG SubType,LTM or Routine,Normal/Abnormal,No. Seizures File,No. Seizures/Session,Filename,Seizure Start,Seizure Stop,Seizure Type
1,1.0,77,s003,t000,Outpatient,Outpatient,Routine,Abnormal,0.0,12.0,./train/01_tcp_ar/000/00000077/s003_2010_01_21...,,,
2,2.0,254,s005,t000,Outpatient,Outpatient,Routine,Abnormal,0.0,0.0,./train/01_tcp_ar/002/00000254/s005_2010_11_15...,,,
3,3.0,254,s006,t001,Outpatient,Outpatient,Routine,Abnormal,0.0,0.0,./train/01_tcp_ar/002/00000254/s006_2011_07_01...,,,
4,4.0,254,s007,t000,Inpatient,General,Routine,Abnormal,0.0,0.0,./train/01_tcp_ar/002/00000254/s007_2013_03_25...,,,
5,5.0,272,s007,t000,ICU,BURN,LTM,Abnormal,0.0,0.0,./train/01_tcp_ar/002/00000272/s007_2003_07_03...,,,


In [6]:
# our example events file picked from the events filename
# Worked example: the 21st listed file that contains at least one seizure
has_seizure = file_info['No. Seizures File'] > 0
SEIZURE_EVENTS_FILE = file_info.loc[has_seizure, 'Filename'].iloc[20]

# Its directory under DB_PATH/edf/ (drop the leading '.' and the file name)
path_parts = SEIZURE_EVENTS_FILE.split('/')
example_file_dir = DB_PATH + 'edf/' + '/'.join(path_parts[1:-1])

for line in (example_file_dir, SEIZURE_EVENTS_FILE):
    print(line)

./db/v1.5.1/edf/train/01_tcp_ar/008/00000883/s002_2010_09_01
./train/01_tcp_ar/008/00000883/s002_2010_09_01/00000883_s002_t000.tse


In [7]:
def data_load(data_file, selected_channels=None):
    """Load one EDF recording into a time-indexed DataFrame.

    Parameters
    ----------
    data_file : str
        Path to the ``.edf`` file.
    selected_channels : list of str, optional
        Signal labels to load, in column order. ``None`` or an empty list
        loads every signal in the file.

    Returns
    -------
    (pandas.DataFrame, sampling frequency)
        One column per selected channel (``columns.name == 'Channel'``),
        indexed by ``'Time'`` in seconds, computed from the first signal's
        sampling frequency. A channel that is absent from the file, or whose
        length differs from the first signal's, is filled with NaN.
        If the file cannot be opened, returns ``(pd.DataFrame(), None)``.
    """
    try:
        # the context manager releases the EDF handle even if reading fails
        with pyedflib.EdfReader(data_file) as f:
            channel_names = f.getSignalLabels()
            channel_freq = f.getSampleFrequencies()

            if not selected_channels:
                selected_channels = channel_names

            # the first signal's length defines the array height
            sigbufs = np.zeros((f.getNSamples()[0], len(selected_channels)))
            for i, channel in enumerate(selected_channels):
                try:
                    sigbufs[:, i] = f.readSignal(channel_names.index(channel))
                except ValueError:
                    # .index() raises if the channel is missing; the assignment
                    # raises if the channel's sampling rate (hence length)
                    # differs from the first signal's. Keep it as NaN.
                    sigbufs[:, i] = np.nan
    except OSError:
        # unreadable / unopenable file (the original code's stated intent)
        return pd.DataFrame(), None

    df = pd.DataFrame(sigbufs, columns=selected_channels)

    # time stamp of every sample, in seconds
    df['Time'] = np.linspace(0, len(df) / channel_freq[0], len(df), endpoint=False)
    df = df.set_index('Time')
    df.columns.name = 'Channel'

    return df, channel_freq[0]

# Load the example recording: swap the trailing 'tse' extension for 'edf'
events_basename = SEIZURE_EVENTS_FILE.split('/')[-1]
seiz_edf_name = events_basename[:-3] + 'edf'
seiz_edf_file = '/'.join([example_file_dir, seiz_edf_name])
seiz_data, seiz_freq = data_load(seiz_edf_file)

print(seiz_edf_file)
display(seiz_data.shape)

./db/v1.5.1/edf/train/01_tcp_ar/008/00000883/s002_2010_09_01/00000883_s002_t000.edf


(400250, 36)

In [8]:
# TUH class code of the seizure type the dataset is built for.
# NOTE(review): this used to say `None` selects every type, but the preictal
# cell calls TUH_code.lower(), which fails for None.
TUH_code = 'GNSZ'

def sel_file_list(set_name, seiz_type):
    """Return event-file paths listed in the seizure spreadsheet.

    Parameters
    ----------
    set_name : str
        Intended dataset split. NOTE(review): currently unused. The 'train'
        sheet is always read, so ``sel_file_list('dev_test', ...)`` returns
        the same list as ``sel_file_list('train', ...)``. Switching to
        ``set_name`` needs the real sheet name and row range of the other
        split confirmed first.
    seiz_type : str or None
        Class code as written in the spreadsheet (e.g. ``'GNSZ'``). Only rows
        whose 'Seizure Type' equals it are kept; a falsy value keeps every row.

    Returns
    -------
    list
        'Filename' values, one per matching spreadsheet row (so the same file
        can appear more than once).
    """
    seiz_info_path = DB_PATH + '_DOCS/seizures_v34r.xlsx'
    train_info = pd.read_excel(seiz_info_path, 'train')

    # Per-file table. `.copy()` makes it an independent frame, so the column
    # assignments below do not trigger pandas' SettingWithCopyWarning.
    file_info = train_info.iloc[1:6101, 1:15].copy()
    file_info.columns = ['File No.', 'Patient', 'Session', 'File',
                         'EEG Type', 'EEG SubType', 'LTM or Routine',
                         'Normal/Abnormal', 'No. Seizures File',
                         'No. Seizures/Session', 'Filename', 'Seizure Start',
                         'Seizure Stop', 'Seizure Type']

    # A blank cell means "same as the row above", except in the last four
    # (per-row) columns: filename, seizure start/stop and seizure type.
    for col_name in file_info.columns[:-4]:
        file_info[col_name] = file_info[col_name].ffill()

    # patient ID is an integer rather than float
    file_info['Patient'] = file_info['Patient'].astype(int)

    if seiz_type:
        return list(file_info.loc[file_info['Seizure Type'] == seiz_type, 'Filename'])
    return list(file_info['Filename'])

# Map lower-case class code -> class number
class_numbers = seiz_types.to_dict()['Class No.']
int_code = {}
for class_code, class_no in class_numbers.items():
    if isinstance(class_code, float):
        # the 'No Event' row has a missing (NaN) code
        continue
    int_code[class_code.lower()] = class_no

# event files containing the chosen seizure type
tuh_file_list = sel_file_list('train', TUH_code) + sel_file_list('dev_test', TUH_code)


In [9]:
# quick sanity check: number of event files, first path, code -> number map
for item in (len(tuh_file_list), tuh_file_list[0], int_code):
    print(item)

816
./train/01_tcp_ar/004/00000492/s003_2003_07_18/00000492_s003_t000.tse
{'spsw': 1, 'gped': 2, 'pled': 3, 'eybl': 4, 'artf': 5, 'bckg': 6, 'seiz': 7, 'fnsz': 8, 'gnsz': 9, 'spsz': 10, 'cpsz': 11, 'absz': 12, 'tnsz': 13, 'cnsz': 14, 'tcsz': 15, 'atsz': 16, 'mysz': 17, 'nesz': 18, 'intr': 19, 'slow': 20, 'eyem': 21, 'chew': 22, 'shiv': 23, 'musc': 24, 'elpp': 25, 'elst': 26, 'calb': 27, 'hphs': 28, 'trip': 29}


In [10]:
# get a list of the montages
montage = []
for file in tuh_file_list:
    montage.append(file.split('/')[3])
    
# count how many times the montages appear in the data
montage_counts = pd.Series(montage).value_counts()

#print(montage_counts)

# remove all files apart from those in the most common montage
#regex = re.compile(montage_counts.index[0])
#tuh_file_list = [i for i in tuh_file_list if regex.search(i)]
# remove duplicates
tuh_file_list = list(set(tuh_file_list))

print(len(tuh_file_list))


186


In [11]:
# --------------------
# GET SIMILAR CHANNELS
# --------------------
# make sure every recording is loaded with the same set of channels
all_channels = []
for events_path in tqdm(tuh_file_list, desc='Finding Channels'):
    file_ID = events_path.split('/')[-1][:-4]
    # directory of this file under DB_PATH/edf/
    pat_file_dir = 'edf/' + '/'.join(events_path.split('/')[1:-1])
    file_path = DB_PATH + pat_file_dir + '/' + file_ID + '.edf'

    with pyedflib.EdfReader(file_path) as f:
        all_channels.extend(f.getSignalLabels())

# how many times each channel label occurs across all recordings
channel_counts = pd.Series(all_channels).value_counts()
all_channels = pd.Series(all_channels)

# Keep channels that occur as often as the most common one. value_counts()
# sorts in descending order, so .iloc[0] is that maximum count. A bare
# channel_counts[0] on this string-labelled Series relies on pandas'
# deprecated positional fallback.
channel_keeps = list(channel_counts[channel_counts >= channel_counts.iloc[0]].index)

# exclude channels whose labels match these patterns (e.g. EKG, PHOTIC)
regex = re.compile('30|PHOTIC|EKG|PG')
channel_keeps = [i for i in channel_keeps if not regex.search(i)]

print(channel_keeps)

Finding Channels: 100%|██████████████████████████████████████████████████████████████| 186/186 [00:03<00:00, 52.91it/s]

['EEG F3-REF', 'EEG O2-REF', 'EEG FZ-REF', 'EEG P3-REF', 'EEG CZ-REF', 'EEG F8-REF', 'EEG T4-REF', 'EEG C3-REF', 'EEG T3-REF', 'EEG P4-REF', 'EEG C4-REF', 'EEG FP2-REF', 'EEG T6-REF', 'EEG T5-REF', 'EEG PZ-REF', 'EEG FP1-REF', 'EEG F4-REF', 'EEG O1-REF', 'EEG F7-REF']





In [12]:
# NOTE(review): ATT_START, ATT_END and chardet are not referenced anywhere
# else in this notebook.
ATT_START, ATT_END = 0, 0

import chardet
def create_events(file_name, df, code = None):
    """Build a per-sample event label Series for ``df`` from a ``.tse`` file.

    Parameters
    ----------
    file_name : str
        Path of a space-separated ``.tse`` file. Its first line is skipped
        (blank lines are ignored by ``read_csv``), and each remaining line is
        ``start end code certainty``, with start/end in the same units as
        ``df.index``.
    df : pandas.DataFrame
        Data whose (monotonic, time) index the labels are aligned to.
    code : str, optional
        If given, only events whose code equals it are written; otherwise
        every event's own code is written.

    Returns
    -------
    pandas.Series
        Object Series named ``'Events'`` on ``df.index``. It holds ``'bckg'``
        everywhere except the spans covered by the written events. Spans are
        label slices, so both endpoints are inclusive, and later events
        overwrite earlier ones.
    """
    # every sample starts as background and is overwritten by events below
    data_y = pd.Series('bckg', index=df.index, dtype=object, name='Events')

    events_tse = pd.read_csv(file_name,
                             skiprows=1,
                             sep=' ',
                             header=None,
                             names=['Start', 'End', 'Code', 'Certainty'])

    for _, row in events_tse.iterrows():
        if code is not None and row['Code'] != code:
            continue
        # explicit label-based slice on the time index (endpoints inclusive)
        data_y.loc[row['Start']:row['End']] = row['Code']

    return data_y

In [27]:
SAMPLING_FREQ = 250  # Hz; assumed sampling rate of the loaded recordings (not checked against data_load's freq)
WINDOW_LENGTH = 10  # in seconds
# Number of *samples* per window. The `_MS` suffix is historical: this is not
# milliseconds.
WINDOW_LENGTH_MS = WINDOW_LENGTH * SAMPLING_FREQ

In [42]:

# -------------------------------------------------------------------------
# Preictal windows: for each seizure file, take the segment that precedes the
# first sample labelled TUH_code and draw random fixed-length windows from it.
# -------------------------------------------------------------------------
PREICTAL_DURATION = 10  # in minutes
# raw_data's index is time in *seconds*, so the offset must be in seconds too.
# The previous `PREICTAL_DURATION * 250` was a sample count subtracted from a
# time in seconds, which always clamped the start to 0.
PREICTAL_DURATION_S = PREICTAL_DURATION * 60
N_TRAIN_WINDOWS = 15  # random windows drawn per file for training
N_TEST_WINDOWS = 1    # random windows drawn per file for testing

preictal_train = []
y_pre_train = []
preictal_test = []
y_pre_test = []

np.set_printoptions(threshold=20)

# Read files into df and create events series
for events_path in tqdm(tuh_file_list, desc='Reading files'):
    file_ID = events_path.split('/')[-1][:-4]
    # directory of this file under DB_PATH/edf/
    pat_file_dir = 'edf/' + '/'.join(events_path.split('/')[1:-1])
    file_path = DB_PATH + pat_file_dir + '/' + file_ID + '.edf'

    raw_data, freq = data_load(file_path, channel_keeps)
    if raw_data.empty:
        print('Skipped: ' + file_ID)
        continue

    raw_events = create_events(DB_PATH + pat_file_dir + '/' + file_ID + '.tse', raw_data)

    # timestamps labelled with the chosen seizure type
    seizure_index = raw_events.index[raw_events == TUH_code.lower()]
    if seizure_index.empty:
        # Without this guard, seiz_start would silently carry over from the
        # previous file (or be undefined for the first one).
        print('No ' + TUH_code + ' label: ' + file_ID)
        continue
    seiz_start = seizure_index[0]  # first moment of the seizure

    preictal_start = max(0, seiz_start - PREICTAL_DURATION_S)
    preictal_end = seiz_start

    # vectorised selection replaces a per-row iterrows() scan
    times = raw_data.index
    preictal_np = raw_data.loc[(times >= preictal_start) & (times < preictal_end)].to_numpy()

    if len(preictal_np) > WINDOW_LENGTH_MS:  # enough preictal samples for a window
        # NOTE(review): train and test windows come from the same segment and
        # may overlap, so test scores are likely optimistic.
        for n_windows, X_out, y_out in ((N_TRAIN_WINDOWS, preictal_train, y_pre_train),
                                        (N_TEST_WINDOWS, preictal_test, y_pre_test)):
            for _ in range(n_windows):
                end_index = random.randint(WINDOW_LENGTH_MS, len(preictal_np))
                window = preictal_np[end_index - WINDOW_LENGTH_MS:end_index]
                # skip windows containing NaN (channels data_load could not read)
                if not np.isnan(np.sum(window)):
                    X_out.append(window)
                    y_out.append([1])
            


Reading files: 100%|█████████████████████████████████████████████████████████████████| 186/186 [12:26<00:00,  4.01s/it]


In [14]:
# Find sessions without seizures for interictal readings

def sel_inter_file_list(set_name):
    """Return event-file paths of sessions with no seizures.

    Parameters
    ----------
    set_name : str
        Intended dataset split. NOTE(review): currently unused. The 'train'
        sheet is always read, so 'train' and 'dev_test' give the same list.
        Confirm the other split's sheet name and row range before wiring it in.

    Returns
    -------
    list
        'Filename' values of rows whose 'No. Seizures/Session' is 0.
    """
    seiz_info_path = DB_PATH + '_DOCS/seizures_v34r.xlsx'
    train_info = pd.read_excel(seiz_info_path, 'train')

    # Per-file table. `.copy()` makes it an independent frame, so the column
    # assignments below do not trigger pandas' SettingWithCopyWarning.
    file_info = train_info.iloc[1:6101, 1:15].copy()
    file_info.columns = ['File No.', 'Patient', 'Session', 'File',
                         'EEG Type', 'EEG SubType', 'LTM or Routine',
                         'Normal/Abnormal', 'No. Seizures File',
                         'No. Seizures/Session', 'Filename', 'Seizure Start',
                         'Seizure Stop', 'Seizure Type']

    # A blank cell means "same as the row above", except in the last four
    # (per-row) columns: filename, seizure start/stop and seizure type.
    for col_name in file_info.columns[:-4]:
        file_info[col_name] = file_info[col_name].ffill()

    # patient ID is an integer rather than float
    file_info['Patient'] = file_info['Patient'].astype(int)

    return list(file_info.loc[file_info['No. Seizures/Session'] == 0, 'Filename'])

# seizure-free event files from both calls, concatenated in order
inter_file_list = [*sel_inter_file_list('train'), *sel_inter_file_list('dev_test')]

In [15]:
# get a list of the montages
inter_montage = []
for file in inter_file_list:
    inter_montage.append(file.split('/')[3])
    
# count how many times the montages appear in the data
inter_montage_counts = pd.Series(inter_montage).value_counts()

#print(montage_counts)

# remove all files apart from those in the most common montage
#regex = re.compile(montage_counts.index[0])
#tuh_file_list = [i for i in tuh_file_list if regex.search(i)]
# remove duplicates
inter_file_list = list(set(inter_file_list))

print(len(inter_file_list))

3018


In [16]:
# Balance the classes: keep as many seizure-free files as there are seizure
# files. This was previously `len(inter_file_list) - 2832`, a constant that
# only happened to equal len(tuh_file_list) (186) for this spreadsheet.
INTER_SHUFFLE_SEED = 42
# sort first so the seeded shuffle does not depend on the incoming order
inter_file_list = sorted(inter_file_list)
np.random.default_rng(INTER_SHUFFLE_SEED).shuffle(inter_file_list)
inter_file_list = inter_file_list[:len(tuh_file_list)]

print(len(inter_file_list))

186


In [43]:
# -------------------------------------------------------------------------
# Interictal windows: the selected sessions have no seizures, so random
# fixed-length windows are drawn from anywhere in each recording.
# -------------------------------------------------------------------------
N_TRAIN_WINDOWS = 15  # random windows drawn per file for training
N_TEST_WINDOWS = 1    # random windows drawn per file for testing

interictal_train = []
y_inter_train = []
interictal_test = []
y_inter_test = []

np.set_printoptions(threshold=20)

for events_path in tqdm(inter_file_list, desc='Reading files'):
    file_ID = events_path.split('/')[-1][:-4]
    # directory of this file under DB_PATH/edf/
    pat_file_dir = 'edf/' + '/'.join(events_path.split('/')[1:-1])
    file_path = DB_PATH + pat_file_dir + '/' + file_ID + '.edf'

    raw_data, freq = data_load(file_path, channel_keeps)
    if raw_data.empty:
        print('Skipped: ' + file_ID)
        continue

    # The whole recording is used, so no event labels are needed.
    # to_numpy() replaces the per-row iterrows() copy that made this cell
    # take over 20 minutes.
    interictal_np = raw_data.to_numpy()

    if len(interictal_np) > WINDOW_LENGTH_MS:  # enough samples for a window
        # NOTE(review): train and test windows come from the same recordings
        # and may overlap, so test scores are likely optimistic.
        for n_windows, X_out, y_out in ((N_TRAIN_WINDOWS, interictal_train, y_inter_train),
                                        (N_TEST_WINDOWS, interictal_test, y_inter_test)):
            for _ in range(n_windows):
                end_index = random.randint(WINDOW_LENGTH_MS, len(interictal_np))
                window = interictal_np[end_index - WINDOW_LENGTH_MS:end_index]
                # skip windows containing NaN (channels data_load could not read)
                if not np.isnan(np.sum(window)):
                    X_out.append(window)
                    y_out.append([0])
                    

Reading files: 100%|█████████████████████████████████████████████████████████████████| 186/186 [22:26<00:00,  7.24s/it]


In [44]:
# Concatenate the classes: interictal (label 0) first, then preictal (label 1)
X_train, y_train = interictal_train + preictal_train, y_inter_train + y_pre_train
X_test, y_test = interictal_test + preictal_test, y_inter_test + y_pre_test

for split in (X_train, X_test):
    print(len(split))


4350
290


In [45]:
# stack the lists of windows / labels into arrays
X_train, y_train, X_test, y_test = [np.array(part) for part in (X_train, y_train, X_test, y_test)]

print(X_train.shape)
print(y_train.shape)

(4350, 500, 19)
(4350, 1)


In [46]:
# Save the built arrays
import pickle

for pkl_name, array in (('X_train.pkl', X_train), ('y_train.pkl', y_train),
                        ('X_test.pkl', X_test), ('y_test.pkl', y_test)):
    # `with` closes the file even if pickling fails part-way
    with open(pkl_name, 'wb') as output:
        pickle.dump(array, output)

In [21]:
# Load the arrays.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
import pickle

def _load_pickle(path):
    # the `with` block closes the file; no explicit close() is needed
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)

X_train = _load_pickle('X_train.pkl')
y_train = _load_pickle('y_train.pkl')
X_test = _load_pickle('X_test.pkl')
y_test = _load_pickle('y_test.pkl')
    

In [36]:
# Shuffle each split with one permutation shared by features and labels, so
# every window keeps its label. This matters for training because the arrays
# are ordered interictal-then-preictal and Keras' validation_split takes the
# final fraction of the samples.

def _shuffle_in_unison(features, labels):
    n_samples = tf.shape(features)[0]
    permutation = tf.random.shuffle(tf.range(n_samples, dtype=tf.int32))
    return tf.gather(features, permutation), tf.gather(labels, permutation)

X_train, y_train = _shuffle_in_unison(X_train, y_train)
X_test, y_test = _shuffle_in_unison(X_test, y_test)

In [37]:

# Binary classifier over one (samples, channels) window: two stacked 100-unit
# GRUs, each followed by 50% dropout, then a single sigmoid unit
# (near 0 = 'interictal', near 1 = 'preictal').
gru_units = 100
dropout_rate = 0.5

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.GRU(gru_units, activation='tanh',
                              input_shape=(WINDOW_LENGTH_MS, len(channel_keeps)),
                              return_sequences=True))
model.add(tf.keras.layers.Dropout(dropout_rate))
model.add(tf.keras.layers.GRU(gru_units, activation='tanh'))
model.add(tf.keras.layers.Dropout(dropout_rate))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
gru_4 (GRU)                  (None, 2500, 100)         36300     
_________________________________________________________________
dropout_4 (Dropout)          (None, 2500, 100)         0         
_________________________________________________________________
gru_5 (GRU)                  (None, 100)               60600     
_________________________________________________________________
dropout_5 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 101       
Total params: 97,001
Trainable params: 97,001
Non-trainable params: 0
_________________________________________________________________


In [38]:

best_model_filepath = "Best_Model.ckpt"

# Save only the weights, and only when val_loss improves
checkpoint_kwargs = dict(filepath=best_model_filepath,
                         monitor='val_loss',
                         verbose=0,
                         save_weights_only=True,
                         save_best_only=True)
callback_checkpoint = ModelCheckpoint(**checkpoint_kwargs)

# Stop after 5 epochs without any val_loss improvement
early_stopping_kwargs = dict(monitor='val_loss',
                             min_delta=0,
                             patience=5,
                             verbose=1,
                             mode='auto',
                             baseline=None,
                             restore_best_weights=False)
callback_early_stopping = EarlyStopping(**early_stopping_kwargs)


In [39]:
from tensorflow.keras.optimizers import Adam

# binary cross-entropy matches the single sigmoid output
optimizer = Adam(0.001)
model.compile(loss='binary_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy', 'AUC', 'Recall', 'Precision'])

In [40]:
EPOCHS = 20
BATCH_SIZE = 32

# callback_early_stopping was defined alongside the checkpoint but never
# passed to fit(), so training always ran all EPOCHS.
history = model.fit(X_train,
                    y_train,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    validation_split=0.06,
                    callbacks=[callback_checkpoint, callback_early_stopping],
                    verbose=1)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
 21/125 [====>.........................] - ETA: 12:23 - loss: 0.1418 - accuracy: 0.9384 - auc: 0.9917 - recall: 0.9502 - precision: 0.9109

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

def plot_progress(history_dict):
    """Plot training vs. validation curves from a Keras ``History.history``.

    Parameters
    ----------
    history_dict : dict
        Mapping of metric name -> list of per-epoch values, containing
        ``key`` / ``'val_' + key`` pairs.

    Draws one figure per training metric that has a validation counterpart.
    Every metric except 'loss' uses a fixed y-range of [0, 1.1].
    """
    # Select every training metric that has a validation series, instead of
    # assuming the first five keys are exactly the training metrics (which
    # raises KeyError on 'val_val_...' when fewer metrics are compiled).
    train_keys = [k for k in history_dict
                  if not k.startswith('val_') and 'val_' + k in history_dict]

    for key in train_keys:
        training_values = history_dict[key]
        val_values = history_dict['val_' + key]
        epochs = range(1, len(training_values) + 1)

        fig, ax = plt.subplots()
        ax.plot(epochs, training_values, 'bo', label='Training ' + key)
        ax.plot(epochs, val_values, 'b', label='Validation ' + key)
        if key != 'loss':
            ax.set_ylim([0., 1.1])
        ax.set_title('Training and Validation ' + key)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(key)
        ax.legend()
        plt.show()
        # release the figure so repeated calls do not accumulate memory
        plt.close(fig)
    
training_history = history.history
plot_progress(training_history)

In [None]:

# Test-set loss and compiled metrics, one row per metric name.
# NOTE(review): this evaluates the in-memory weights; the best checkpoint in
# Best_Model.ckpt is not loaded here.
metrics = model.evaluate(X_test, y_test)
pd.Series(metrics, index=model.metrics_names).to_frame()
