In [121]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split

import os
import numpy as np
import scipy.io as sio
import wfdb

def read_records(record_files):
    """Read every WFDB record and the diagnosis codes in its header comments.

    Args:
        record_files: record paths without extension (as produced by find_records).

    Returns:
        (records, labels): two lists of the same length. labels[i] is the set of
        codes listed in the ``Dx`` header comment(s) of records[i], or an empty set
        when the record has none.
    """
    records = []
    labels = []
    for record_file in record_files:
        # Read each record once and reuse the object.
        record = wfdb.rdrecord(record_file)
        dxs = set()
        if record.file_name[0].endswith('.dat'):
            # TODO work out how to deal with MIT-BIH with its different hea/atr/dat files
            # and very low samples. Split into many files?
            # Beat annotations (.atr) are not turned into labels yet, so these records
            # get an empty label set.
            pass
        else:
            for comment in record.comments:
                if comment.strip().startswith('Dx'):
                    # e.g. "Dx: 59118001,426783006" -> {"59118001", "426783006"}
                    _, _, value = comment.partition(':')
                    dxs |= {code.strip() for code in value.split(',') if code.strip()}
        # Exactly one label set per record keeps labels[i] aligned with records[i]
        # (non-Dx comments such as Age/Sex must not produce label entries).
        records.append(record)
        labels.append(dxs)
    return records, labels


def create_one_hot_labels(all_labels, unique_classes, num_recordings):
    """Encode the first `num_recordings` label sets as a multi-hot float matrix.

    Args:
        all_labels: sequence of label sets (one per recording).
        unique_classes: list of class codes; its order defines the columns.
        num_recordings: number of leading entries of `all_labels` to encode.

    Returns:
        (labels, index): `labels` is a float array of shape
        (num_recordings, len(unique_classes)); `index` lists the row numbers whose
        label set contains none of `unique_classes`.
    """
    one_hot = np.zeros((num_recordings, len(unique_classes)))
    unmatched_rows = []
    for row in range(num_recordings):
        matched_any = False
        for dx in all_labels[row]:
            if dx not in unique_classes:
                continue
            one_hot[row, unique_classes.index(dx)] = 1
            matched_any = True
        if not matched_any:
            unmatched_rows.append(row)
    return one_hot, unmatched_rows

def get_unique_classes(all_labels, valid_classes=None):
    """Return the sorted, de-duplicated class codes that occur in `all_labels`.

    Args:
        all_labels: iterable of label sets (one per recording).
        valid_classes: optional collection of allowed codes; when given, codes not
            in it are ignored.

    Returns:
        Sorted list of distinct codes.
    """
    # A set makes both the allow-list lookup and the de-duplication O(1) per code
    # (the previous list-based version was quadratic in the number of labels).
    allowed = None if valid_classes is None else set(valid_classes)
    return sorted({
        dx
        for dxs in all_labels
        for dx in dxs
        if allowed is None or dx in allowed
    })

def find_records(directory):
    """Recursively collect WFDB record paths (header path without ``.hea``).

    Hidden files (names starting with '.') are skipped and the ``.hea`` extension
    match is case-insensitive.

    Args:
        directory: root directory to search.

    Returns:
        List of record paths in a deterministic order: files sorted within each
        directory, sub-directories visited in sorted order.

    Raises:
        IOError: if no header files are found.
    """
    record_files = []
    for dirpath, dirnames, filenames in os.walk(directory):
        # Sorting in place controls os.walk's traversal order, so the returned order
        # (and anything seeded downstream, e.g. a train/test split) does not depend
        # on the filesystem's listing order.
        dirnames.sort()
        for f in sorted(filenames):
            file_path = os.path.join(dirpath, f)
            if os.path.isfile(file_path) and not f.lower().startswith('.'):
                file, ext = os.path.splitext(file_path)
                if ext.lower() == '.hea':
                    record_files.append(file)
    if not record_files:
        raise IOError('No record files found.')
    return record_files

def filter(data, labels, index):
    """Drop the entries at positions `index` from two parallel sequences.

    NOTE(review): this definition is shadowed by the later `filter` defined in this
    same cell, so it is never the one that gets called. It also returns
    `(labels, data)`, the reverse of its argument order, and its name shadows the
    builtin `filter`. Consider deleting it.

    Args:
        data: sequence of samples.
        labels: sequence of labels, parallel to `data`.
        index: positions to remove.

    Returns:
        (labels, data) as lists with the positions in `index` removed.
    """
    # A set gives O(1) membership instead of scanning `index` for every element.
    drop = set(index)
    labels = [label for i, label in enumerate(labels) if i not in drop]
    data = [datum for i, datum in enumerate(data) if i not in drop]
    return labels, data

def consolidate_equivalent_classes(labels, unique_classes, equivalent_classes_collection=None):
    """Merge each group of equivalent classes into a single representative column.

    For each group, the first member present in `unique_classes` becomes the
    representative. Its column is positive if any member's column is positive, and
    the other members' columns and class entries are removed.

    Args:
        labels: 2-D array (recordings x classes) whose columns follow `unique_classes`.
        unique_classes: list of class codes.
        equivalent_classes_collection: optional list of code groups to merge.
            Defaults to the three equivalence pairs used previously.

    Returns:
        (labels, unique_classes): new array and new list. The inputs are not modified.
    """
    if equivalent_classes_collection is None:
        equivalent_classes_collection = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    # Work on copies so the caller's array and class list are left untouched.
    labels = np.array(labels)
    unique_classes = list(unique_classes)

    remove_classes = []
    remove_indices = []
    for equivalent_classes in equivalent_classes_collection:
        present = [x for x in equivalent_classes if x in unique_classes]
        if len(present) > 1:
            indices = [unique_classes.index(x) for x in present]
            labels[:, indices[0]] = np.any(labels[:, indices], axis=1)
            remove_classes += present[1:]
            remove_indices += indices[1:]

    # Column indices above refer to the original class order, so delete the columns
    # in one call and remove the names afterwards.
    for x in remove_classes:
        unique_classes.remove(x)
    labels = np.delete(labels, remove_indices, axis=1)

    return labels, unique_classes

def set_labels_to_normal_if_none_other(labels, unique_classes, normal_class):
    """Mark the normal class positive on every row that has no positive class.

    Args:
        labels: 2-D array (recordings x classes); modified in place.
        unique_classes: list of class codes giving the column order.
        normal_class: code of the normal class (must be in `unique_classes`).

    Returns:
        The same `labels` array.
    """
    normal_column = unique_classes.index(normal_class)
    rows_without_positive = np.sum(labels, axis=1) == 0
    labels[rows_without_positive, normal_column] = 1
    return labels

def ensure_normal_class(unique_classes, normal_class):
    """Return a sorted class list that is guaranteed to contain `normal_class`.

    Args:
        unique_classes: list (or any iterable) of class codes; not modified.
        normal_class: code that must be present.

    Returns:
        New sorted list of codes including `normal_class`.
    """
    if normal_class not in unique_classes:
        # Build a new list instead of calling `.add`: get_unique_classes returns a
        # list, which has no `.add`, and the caller's object should not be mutated.
        unique_classes = list(unique_classes) + [normal_class]
        print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
    return sorted(unique_classes)

def read_scored_classes(path='dx_mapping_scored.csv'):
    """Read the scored class codes from the mapping CSV.

    The first row is treated as a header. The code is taken from the second column
    (presumably the SNOMED CT code column; verify against the file in use).

    Args:
        path: CSV file to read; defaults to 'dx_mapping_scored.csv' in the
            working directory.

    Returns:
        Sorted list of code strings.
    """
    import csv

    # csv.reader handles quoted fields containing commas; blank lines are skipped
    # (splitting them manually raised IndexError).
    with open(path, 'r', newline='') as f:
        rows = [row for row in csv.reader(f) if row]
    return sorted(row[1].strip() for row in rows[1:])

def filter(labels, records, keep_index):
    """Remove the entries at the given positions from two parallel sequences.

    Despite its name, `keep_index` holds the positions to DROP, not to keep.
    Note that this function shadows the builtin `filter` in this notebook.

    Args:
        labels: sequence of label rows (e.g. a 2-D array, iterated row by row).
        records: sequence of records, parallel to `labels`.
        keep_index: positions to remove from both sequences.

    Returns:
        (labels, records) as lists with those positions removed.
    """
    # A set gives O(1) membership instead of scanning the index list per element.
    drop = set(keep_index)
    labels = [label for i, label in enumerate(labels) if i not in drop]
    records = [record for i, record in enumerate(records) if i not in drop]

    return labels, records

def load_records(header_file_list, adjust_classes_for_physionet, normal_class=None):
    """Read records and turn their diagnoses into a multi-hot label matrix.

    Args:
        header_file_list: record paths (without extension).
        adjust_classes_for_physionet: if True, restrict classes to those in
            read_scored_classes() and merge equivalent classes.
        normal_class: optional code. When given, recordings with none of the
            known classes are labelled as this class and kept. When None, such
            recordings are dropped.

    Returns:
        (labels, records): a list of label rows and the matching list of records.

    Raises:
        ValueError: if no header files are given, or if the reader returned a
            different number of label sets than records.
    """
    if len(header_file_list) == 0:
        raise ValueError('No header files found.')

    records, all_labels = read_records(header_file_list)
    # Labels are matched to records by position, so a length mismatch would silently
    # attach diagnoses to the wrong recordings.
    if len(records) != len(all_labels):
        raise ValueError('Got {} label sets for {} records; labels and records are misaligned.'.format(
            len(all_labels), len(records)))
    num_recordings = len(records)

    scored = read_scored_classes() if adjust_classes_for_physionet else None

    unique_classes = get_unique_classes(all_labels, scored)

    if normal_class is not None:
        unique_classes = ensure_normal_class(unique_classes, normal_class)

    # Second value: positions of recordings with none of `unique_classes`.
    labels, unmatched_index = create_one_hot_labels(all_labels, unique_classes, num_recordings)

    if adjust_classes_for_physionet:
        labels, unique_classes = consolidate_equivalent_classes(labels, unique_classes)

    if normal_class is not None:
        # Unmatched recordings are now labelled as normal. Keep them, otherwise this
        # relabelling would have no effect on the returned data.
        labels = set_labels_to_normal_if_none_other(labels, unique_classes, normal_class)
        drop_index = []
    else:
        drop_index = unmatched_index

    labels, records = filter(labels, records, drop_index)

    return labels, records

def standardise_length(data, target_length):
    """Zero-pad or trim every lead to `target_length` samples, then transpose.

    Args:
        data: 2-D array indexed as data[lead][sample]. The length of lead 0 decides
            whether padding or trimming happens.
        target_length: desired number of samples.

    Returns:
        Float array of shape (target_length, number_of_leads).
    """
    n_leads = data.shape[0]
    n_samples = len(data[0])
    fitted = np.zeros([n_leads, target_length])

    if n_samples > target_length:
        # Too long: keep only the last `target_length` samples of each lead.
        start = n_samples - target_length
        for lead in range(n_leads):
            fitted[lead] = data[lead][start:]
    else:
        # Short enough: copy to the front and leave trailing zeros as padding.
        for lead in range(n_leads):
            fitted[lead][:len(data[lead])] = data[lead]

    return fitted.T

def resample(data, src_frq, trg_frq):
    """Linearly resample a multi-channel signal from `src_frq` to `trg_frq`.

    Args:
        data: 2-D array shaped (samples, channels); interpolation runs along axis 0.
        src_frq: sampling rate of `data`.
        trg_frq: desired sampling rate.

    Returns:
        `data` unchanged when the rates are equal. Otherwise a new float32 array of
        shape (int(samples * trg_frq / src_frq), channels).
    """
    if src_frq == trg_frq:
        return data

    n_src = data.shape[0]
    n_trg = int(n_src * trg_frq / src_frq)

    # Target sample k lies at time k / trg_frq, i.e. at fractional source index
    # k * src_frq / trg_frq. (np.linspace(0, n_src, n_trg) ended the grid one sample
    # past the data and used the wrong spacing.) Positions past the last source
    # sample take its value, because np.interp clamps at the ends.
    src_positions = np.arange(n_src)
    trg_positions = np.arange(n_trg) * (src_frq / trg_frq)

    resampled = np.zeros((n_trg, data.shape[1]), dtype='float32')
    for channel in range(data.shape[1]):
        resampled[:, channel] = np.interp(trg_positions, src_positions, data[:, channel])

    return resampled

def standardise_data_samples(records):
    """Resample every record to the most common rate and pad/trim it to the most common length.

    Args:
        records: objects exposing `.fs`, `.sig_len` and `.p_signal`. `p_signal` is
            assumed to be shaped (samples, leads), as for wfdb records (TODO confirm
            for other sources).

    Returns:
        List of arrays, each shaped (target_length, leads).
    """
    standardised_data = list()

    # find the most common fs
    fss = [record.fs for record in records]
    target_fs = max(set(fss), key=fss.count)

    # find the most common sig_len
    # NOTE(review): this is the most common ORIGINAL length; for records whose fs
    # differs from target_fs it is measured at a different sampling rate.
    sig_lens = [record.sig_len for record in records]
    target_length = max(set(sig_lens), key=sig_lens.count)

    for record in records:
        # resample() interpolates along axis 0, so pass the (samples, leads) signal
        # directly. The earlier code passed the transpose, which resampled across
        # leads instead of time.
        signal = resample(record.p_signal, record.fs, target_fs)
        # standardise_length() expects (leads, samples).
        standardised_data.append(standardise_length(signal.T, target_length))

    return standardised_data

def load_data(input_directory, adjust_classes_for_physionet, normal_class,
              test_size=0.1, val_size=0.1, random_state=42):
    """Load, standardise and split the dataset into train/validation/test arrays.

    Args:
        input_directory: root directory searched for ``.hea`` header files.
        adjust_classes_for_physionet: forwarded to load_records.
        normal_class: forwarded to load_records (None disables the normal fallback).
        test_size: fraction of all recordings held out for testing.
        val_size: fraction of the remaining (non-test) recordings used for validation.
        random_state: seed for both train_test_split calls.

    Returns:
        train_x, train_y, val_x, val_y, test_x, test_y. Each x array is
        (recordings, samples, leads) and each y array is (recordings, classes).

    Raises:
        ValueError: if no recordings remain after label filtering.
    """
    record_file_list = find_records(input_directory)

    labels, records = load_records(record_file_list, adjust_classes_for_physionet, normal_class=normal_class)
    if len(records) == 0:
        raise ValueError('No records left after label filtering; check the label classes.')
    samples = standardise_data_samples(records)

    labels = np.stack(labels, axis=0)
    samples = np.stack(samples, axis=0)

    # Hold out the test set first, then carve validation out of what remains.
    train_x, test_x, train_y, test_y = train_test_split(
        samples, labels, test_size=test_size, random_state=random_state)
    train_x, val_x, train_y, val_y = train_test_split(
        train_x, train_y, test_size=val_size, random_state=random_state)

    return train_x, train_y, val_x, val_y, test_x, test_y
    

In [60]:
from tensorflow.keras import layers
import keras_nlp as nlp 
import tensorflow.keras as keras

def vgg_block(input, cnn_units):
    """Two 'same'-padded ReLU Conv1D layers (kernel 3) followed by 2x max-pooling.

    Args:
        input: Keras tensor shaped (batch, time, channels).
        cnn_units: number of filters in each convolution.

    Returns:
        Keras tensor with `cnn_units` channels; the time axis is halved (rounded
        up, because of 'same' padding in the pooling).
    """
    features = input
    for _ in range(2):
        features = layers.Conv1D(cnn_units, 3, padding='same', activation='relu')(features)
    return layers.MaxPooling1D(2, padding='same')(features)

def crt_net(
        number_of_leads,
        cnn_units=128,
        vgg_blocks=1,
        rnn_units=64,
        transformer_encoders=4,
        att_dim=64,
        att_heads=8,
        fnn_units=64,
        num_classes=6
    ):
    """Build the CNN -> BiGRU -> Transformer multi-label ECG classifier.

    Args:
        number_of_leads: channels per time step; the input shape is
            (batch, time, number_of_leads) with variable time.
        cnn_units: filters per Conv1D inside each VGG block.
        vgg_blocks: number of VGG blocks (each halves the time axis).
        rnn_units: GRU units per direction. The directions are summed, so this is
            also the feature width seen by the transformer encoders.
        transformer_encoders: number of TransformerEncoder layers (0 disables the
            positional encoding and the transformer stage).
        att_dim: passed as the first argument of keras_nlp TransformerEncoder,
            i.e. its feed-forward `intermediate_dim`.
        att_heads: attention heads per encoder.
        fnn_units: width of the two hidden Dense layers.
        num_classes: number of independent sigmoid outputs.

    Returns:
        keras.Model mapping (batch, time, leads) to (batch, num_classes).
    """
    inputs = layers.Input(shape=(None, number_of_leads))
    output = inputs

    for _ in range(vgg_blocks):
        output = vgg_block(output, cnn_units)

    output = layers.Bidirectional(layers.GRU(rnn_units, return_sequences=True), merge_mode='sum')(output)

    if transformer_encoders > 0:
        output = output + nlp.layers.SinePositionEncoding(max_wavelength=10000)(output)

        for _ in range(transformer_encoders):
            output = nlp.layers.TransformerEncoder(att_dim, att_heads)(output)

    # Pool over time in every configuration. The GRU returns full sequences, so
    # without pooling (previously skipped when transformer_encoders == 0) the Dense
    # head would produce per-time-step outputs instead of one prediction per recording.
    output = layers.GlobalAveragePooling1D()(output)

    output = layers.Dropout(0.2)(output)
    output = layers.Dense(fnn_units, activation='relu')(output)
    output = layers.Dense(fnn_units, activation='relu')(output)

    output = layers.Dense(num_classes, activation='sigmoid')(output)
    return keras.Model(inputs, output)


In [124]:
import pandas as pd
import matplotlib.pyplot as plt

output_directory = 'output'

input_directory = 'training_data/cpsc_2018_subset'
adjust_classes_for_physionet = False
normal_class = None  # e.g. '426783006'

MAX_EPOCHS = 2

stopping = keras.callbacks.EarlyStopping(patience=5)

reduce_lr = keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=3,
    min_lr=0.001 * 0.001)

# NOTE(review): the recorded "Unable to create dataset (name already exists)" error
# is raised by h5py when two weights map to the same HDF5 name, which is commonly
# hit when saving models with nested keras_nlp layers to .h5. The native .keras
# format avoids this; confirm with your Keras version.
filepath = os.path.join(output_directory, "model.keras")
# checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True)

# Use the config values above rather than re-typing them here.
train_x, train_y, val_x, val_y, test_x, test_y = load_data(
    input_directory,
    adjust_classes_for_physionet=adjust_classes_for_physionet,
    normal_class=normal_class)

model = crt_net(number_of_leads=train_x.shape[2], num_classes=train_y.shape[1])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

history = model.fit(
    train_x, train_y,
    batch_size=32,
    epochs=MAX_EPOCHS,
    validation_data=(val_x, val_y),
    callbacks=[reduce_lr, stopping])

# ReduceLROnPlateau also logs the learning rate, so select exactly the four curves
# the four styles describe (train in red, validation in blue).
metric_columns = ["loss", "accuracy", "val_loss", "val_accuracy"]
pd.DataFrame(history.history)[metric_columns].plot(
    figsize=(8, 5), xlim=[0, max(MAX_EPOCHS - 1, 1)], ylim=[0, 1], grid=True, xlabel="Epoch",
    style=["r--", "r--.", "b-", "b-*"])
plt.legend(loc="lower left")
plt.show()


Epoch 1/2

ValueError: Unable to create dataset (name already exists)