### Import Data & Process

1. Import MIT-BIH Arrhythmia ECG records and annotations
2. Import MIT-BIH Atrial Fibrillation ECG records and annotations
3. Resample all data to the same frequency (360Hz)
4. Take each patient record and split it into 30 sec. samples
    - 10 hr. records from AFIB ECGs -> 1200 samples (27600 samples from 23 total ECGs)
    - .5 hr. records from Arrhythmia ECGs -> 60 samples (2880 samples from 48 total ECGs)
    
total ECG samples = 30480

using 30420, 30s ECG samples (60 are bad) 

total only AFIB samples = 11234 (1% of dataset, close to real world proportion = .5%)

total containing AFIB samples = 11296 = 37.13% of dataset

total normal samples = 17888 = 58.80% of dataset

other samples = 1236 = 4.06% of dataset

In [4]:
# arrhythmia ECG records

import wfdb

record_ids = [
    '100','101','102','103','104','105','106','107','108',
    '109','111','112','113','114','115','116','117','118',
    '119','121','122','123','124','200','201','202','203',
    '205','207','208','209','210','212','213','214','215',
    '217','219','220','221','222','223','228','230','231',
    '232','233','234']

# records = []
# for rec in record_ids:
#     record = wfdb.rdrecord(rec, pn_dir='mitdb')
#     records.append(record)

# to cut arrhythmia pieces
num_samps = 10800
arrhythmia_pieces = []
arrhythmia_ann_pieces = []
arrhythmia_afib_pieces = []
arrhythmia_normal_pieces = []
arrhythmia_other_pieces = []
arrhythmia_afib_and_other_pieces = []
arrhythmia_bad_pieces = 0

bad_ann = { 'qq\x00', 'U\x00', 'M\x00', 'MISSB\x00', 'P\x00', 'PE\x00', 'T\x00', 'TS\x00' }

for rec_id in record_ids:
    # read rec and ann
    rec = wfdb.rdrecord('mit-bih-arrhythmia-database-1.0.0/'+rec_id)
    ann = wfdb.rdann('mit-bih-arrhythmia-database-1.0.0/'+rec_id, extension='atr')
    start = 0
    end = num_samps
    start_ann = None
    while end <= 648000:
        # create the annotations array for the 30 second piece
        arrhythmia_ann_piece = []
        has_afib = (start_ann == '(AFIB')
        has_other = (start_ann not in bad_ann and start_ann != '(AFIB' and start_ann != '(N\x00')
        is_bad = (start_ann in bad_ann)
        if start_ann != None:
            arrhythmia_ann_piece.append([0, start_ann])
        for i in range(0, ann.__dict__['ann_len']):
            ann_symbol = ann.__dict__['aux_note'][i]
            if start <= ann.__dict__['sample'][i] and ann.__dict__['sample'][i] < end and ann_symbol != '':
                arrhythmia_ann_piece.append([
                    ann.__dict__['sample'][i] - start,
                    ann_symbol
                ])
                # find bad pieces
                if ann_symbol in bad_ann:
                    is_bad = True
                # find pieces with afib anns
                elif ann_symbol == '(AFIB':
                    has_afib = True
                # find pieces with other anns
                elif ann_symbol != '(AFIB' and ann_symbol != '(N\x00':
                    has_other = True
                start_ann = ann_symbol
        # identify afib pieces
        piece_index = len(arrhythmia_ann_pieces)
        # identify bad pieces
        if is_bad:
            arrhythmia_bad_pieces += 1
        else:
            # cut out the 30 second piece
            arrhythmia_pieces.append(rec.__dict__['p_signal'][start:end])
            arrhythmia_ann_pieces.append(arrhythmia_ann_piece)
            # identify afib pieces
            if has_afib:
                # identify afib and other pieces
                if has_other:
                    arrhythmia_afib_and_other_pieces.append(piece_index)
                # identify just afib pieces
                else:
                    arrhythmia_afib_pieces.append(piece_index)
            # identify no afib pieces
            else:
                # identify other arrhythmia pieces
                if has_other:
                    arrhythmia_other_pieces.append(piece_index)
                # identify normal pieces
                else:
                    arrhythmia_normal_pieces.append(piece_index)
        start = end
        end += num_samps
        
print('Loaded',len(arrhythmia_pieces),'arrhythmia pieces')
print('Loaded',len(arrhythmia_ann_pieces),'arrhythmia annotations')
print('Found',len(arrhythmia_afib_pieces),'afib pieces')
print('Found',len(arrhythmia_afib_and_other_pieces),'afib and other pieces')
print('Found',len(arrhythmia_other_pieces),'pieces with other arrhythmia')
print('Found',len(arrhythmia_normal_pieces),'normal pieces')
print('Found',arrhythmia_bad_pieces,'bad pieces')

# x_record_ids = [
#     'x_108','x_109','x_111','x_112','x_113',
#     'x_114','x_115','x_116','x_117','x_121',
#     'x_122','x_123','x_124','x_220','x_221',
#     'x_222','x_223','x_228','x_230','x_231',
#     'x_232','x_233','x_234']
# x_annotations = []
# x_records = []
# for rec in x_record_ids:
#     record = wfdb.rdrecord(rec, pn_dir='mitdb/x_mitdb')
#     x_records.append(record)
#     annotation = wfdb.rdann(rec, pn_dir='mitdb/x_mitdb', extension='atr')
#     x_annotations.append(annotation)
# print('Loaded', len(x_records), 'x record(s)')


Loaded 2820 arrhythmia pieces
Loaded 2820 arrhythmia annotations
Found 8 afib pieces
Found 25 afib and other pieces
Found 1048 pieces with other arrhythmia
Found 1739 normal pieces
Found 60 bad pieces


In [5]:
# atrial fibrillation ECG records

import wfdb.processing

# load MIT BIH AFib data & annotations
afib_record_ids = [
#     '00735', '03665', 
    '04015', '04043', '04048', '04126', 
    '04746', '04908', '04936', '05091', '05121', '05261', 
    '06426', '06453', '06995', '07162', '07859', '07879', 
    '07910', '08215', '08219', '08378', '08405', '08434', 
    '08455']

num_samps = 10800

afib_pieces = []
afib_ann_pieces = []
afib_afib_pieces = []
afib_normal_pieces = []
afib_other_pieces = []
afib_afib_and_other_pieces = []

for rec_id in afib_record_ids:
    # read rec and ann
    rec = wfdb.rdrecord('mit-bih-atrial-fibrillation-database-1.0.0/'+rec_id)
    ann = wfdb.rdann('mit-bih-atrial-fibrillation-database-1.0.0/'+rec_id, extension='atr')
    # resample rec and ann to 360 Hz
    afib_resampled, afib_annotation_resampled = wfdb.processing.resample_multichan(rec.__dict__['p_signal'], ann, 250, 360)
    ann_samples = [int(round(i*360.0/250.0)) for i in ann.__dict__['sample']]
    start = 0
    end = num_samps
    start_ann = None
    while end <= 12960000:
        # cut out the 30 second piece
        afib_pieces.append(afib_resampled[start:end])
        # create the annotations array for the 30 second piece
        ann_piece = []
        has_afib = (start_ann == '(AFIB')
        has_other = (start_ann != '(AFIB' and start_ann != '(N')
        if start_ann != None:
            ann_piece.append([0, start_ann])
        for i in range(0, ann.__dict__['ann_len']):
            ann_symbol = ann.__dict__['aux_note'][i]
            if start <= ann_samples[i] and ann_samples[i] < end:
                ann_piece.append([
                    ann_samples[i] - start,
                    ann_symbol])
                # find pieces with afib anns
                if ann_symbol == '(AFIB':
                    has_afib = True
                # find pieces with other anns
                elif ann_symbol != '(AFIB' and ann_symbol != '(N':
                    has_other = True
                start_ann = ann_symbol
        # identify afib pieces
        piece_index = len(afib_ann_pieces)
        # identify afib pieces
        if has_afib:
            # identify afib and other pieces
            if has_other:
                afib_afib_and_other_pieces.append(piece_index)
            # identify just afib pieces
            else:
                afib_afib_pieces.append(piece_index)
        # identify no afib pieces
        else:
            # identify other arrhythmia pieces
            if has_other:
                afib_other_pieces.append(piece_index)
            # identify normal pieces
            else:
                afib_normal_pieces.append(piece_index)
        afib_ann_pieces.append(ann_piece)
        start = end
        end += num_samps
        
print('Loaded',len(afib_pieces),'afib pieces')
print('Loaded',len(afib_ann_pieces),'afib annotations')
print('Found',len(afib_afib_pieces),'afib pieces')
print('Found',len(afib_afib_and_other_pieces),'afib and other pieces')
print('Found',len(afib_other_pieces),'pieces with other arrhythmia')
print('Found',len(afib_normal_pieces),'normal pieces')


Loaded 27600 afib pieces
Loaded 27600 afib annotations
Found 11226 afib pieces
Found 37 afib and other pieces
Found 188 pieces with other arrhythmia
Found 16149 normal pieces


In [6]:
# combine 30s pieces from arrhythmia and afib databases
all_pieces = arrhythmia_pieces + afib_pieces
all_ann_pieces = arrhythmia_ann_pieces + afib_ann_pieces
all_afib_pieces = arrhythmia_afib_pieces + [i + len(arrhythmia_pieces) for i in afib_afib_pieces]
all_afib_and_other_pieces = arrhythmia_afib_and_other_pieces + [i + len(arrhythmia_pieces) for i in afib_afib_and_other_pieces]
all_other_pieces = arrhythmia_other_pieces + [i + len(arrhythmia_pieces) for i in afib_other_pieces]
all_normal_pieces = arrhythmia_normal_pieces + [i + len(arrhythmia_pieces) for i in afib_normal_pieces]

# label all pieces: N = normal, A = afib, B = both afib and other, O = other
all_pieces_labels = ['N' for i in all_pieces]
for i in all_afib_pieces:
    all_pieces_labels[i] = 'A'
for i in all_afib_and_other_pieces:
    all_pieces_labels[i] = 'B'
for i in all_other_pieces:
    all_pieces_labels[i] = 'O'

# get target labels for output data (0 if not AFIB, 1 if AFIB)
all_pieces_targets = [1 if (l == 'A' or l == 'B') else 0 for l in all_pieces_labels]



In [7]:
# write all_pieces and all_pieces_targets to files

import pickle

all_pieces_file = open('all_pieces.pkl', mode='wb')
pickle.dump(all_pieces, all_pieces_file)
all_pieces_file.close()

all_pieces_targets_file = open('all_pieces_targets.pkl', mode='wb')
pickle.dump(all_pieces_targets, all_pieces_targets_file)
all_pieces_targets_file.close()

In [3]:
# read all_pieces and all_pieces_targets from files

import pickle

all_pieces_file = open('all_pieces.pkl', mode='rb')
all_pieces = pickle.load(all_pieces_file)
all_pieces_file.close()

all_pieces_targets_file = open('all_pieces_targets.pkl', mode='rb')
all_pieces_targets = pickle.load(all_pieces_targets_file)
all_pieces_targets_file.close()

In [4]:
from sklearn.model_selection import KFold

# 5-fold cross validation to get train and test sets
kfold = KFold(n_splits=5, shuffle=True)
train_sets = []
test_sets = []
for train_set, test_set in kfold.split(all_pieces, all_pieces_targets):
    train_sets.append(train_set)
    test_sets.append(test_set)

# train set 1
train_set_pieces_raw = [all_pieces[i] for i in train_sets[0]]
#input - train set 1 MLII signal only        
train_set_pieces = [[s[0] for s in piece] for piece in train_set_pieces_raw]
# output - target labels
train_set_targets = [all_pieces_targets[i] for i in train_sets[0]]

In [5]:
from keras.utils import np_utils

# one-hot encoding of train_set_targets
train_set_targets_one_hot = np_utils.to_categorical(train_set_targets, 2)
print('New train_set_targets shape: ', train_set_targets_one_hot.shape)

New train_set_targets shape:  (24336, 2)


In [None]:
from keras.models import Sequential
from keras.layers import Dense, Flatten, Dropout, Conv1D, MaxPooling1D, Reshape, GlobalAveragePooling1D
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping

# 1D CNN to identify AFIB or not AFIB
model_1 = Sequential()
model_1.add(Conv1D(filters=100, kernel_size=18, activation='relu', input_shape=(len(train_set_pieces), 10800)))
model_1.add(Conv1D(100, 5, activation='relu'))
model_1.add(MaxPooling1D(3))
model_1.add(Conv1D(160, 10, activation='relu'))
model_1.add(Conv1D(160, 10, activation='relu'))
model_1.add(GlobalAveragePooling1D())
model_1.add(Dropout(0.5))
model_1.add(Dense(activation='softmax', units=1))
print(model_1.summary())

callbacks_list = [
    ModelCheckpoint(
        filepath='best_model.{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss', save_best_only=True),
    EarlyStopping(monitor='acc', patience=1)
]
model_1.compile(loss='binary_crossentropy',
                optimizer='adam', metrics=['accuracy'])

history = model_1.fit(train_set_pieces,
                      train_set_targets_one_hot,
                      batch_size=32,
                      epochs=1,
                      callbacks=callbacks_list,
                      validation_split=0.2,
                      verbose=1)

In [None]:
history