In [3]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
import wfdb
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix

def read_records(record_files):
    """Read WFDB records together with the diagnosis codes in their headers.

    Args:
        record_files: Iterable of WFDB record paths without extension, as
            accepted by ``wfdb.rdrecord``.

    Returns:
        A ``(records, labels)`` tuple of two lists that always have the same
        length and order as ``record_files``. ``labels[i]`` is the set of
        codes parsed from the ``Dx`` header comment(s) of ``records[i]``; it
        is an empty set when the record has no such comment or is stored as
        a ``.dat`` file.
    """
    records = []
    labels = []
    for record_file in record_files:
        # Read each record exactly once and reuse the object below.
        record = wfdb.rdrecord(record_file)
        dxs = set()

        file_names = record.file_name or []
        if file_names and file_names[0].endswith('.dat'):
            # TODO work out how to deal with MIT-BIH with its different hea/atr/dat files
            # and very low samples. Split into many files?
            # Until that is decided these records get an empty label set so
            # that `labels` stays index-aligned with `records`.
            pass
        else:
            for comment in record.comments or []:
                if comment.startswith('Dx') or comment.startswith(' Dx'):
                    # Everything after the first colon is a comma-separated
                    # list of codes; blank entries are dropped.
                    _, _, codes = comment.partition(':')
                    dxs.update(code.strip() for code in codes.split(',') if code.strip())

        # Exactly one label set per record, whatever the header contained.
        labels.append(dxs)
        records.append(record)
    return records, labels


def create_one_hot_labels(all_labels, target_classes, num_recordings):
    """Encode per-recording label collections as a multi-hot matrix.

    Args:
        all_labels: Indexable collection; ``all_labels[i]`` is an iterable of
            the class identifiers attached to recording ``i``.
        target_classes: List of class identifiers; the position of a class in
            this list is its column in the returned matrix.
        num_recordings: Number of recordings (rows) to encode.

    Returns:
        ``(labels, discard_index)`` where ``labels`` is a float array of shape
        ``(num_recordings, len(target_classes))`` and ``discard_index`` lists
        the rows whose labels contain none of ``target_classes``.
    """
    labels = np.zeros((num_recordings, len(target_classes)))
    discard_index = []

    for row in range(num_recordings):
        matched_any = False
        for dx in all_labels[row]:
            if dx not in target_classes:
                continue
            labels[row, target_classes.index(dx)] = 1
            matched_any = True

        # Remember recordings that carry none of the classes we are looking for.
        if not matched_any:
            discard_index.append(row)

    return labels, discard_index

def get_unique_classes(all_labels, valid_classes=None):
    """Return the sorted, de-duplicated class identifiers found in the labels.

    Args:
        all_labels: Sequence whose items are iterables of class identifiers.
        valid_classes: Optional container; when given, only identifiers that
            are members of it are kept.

    Returns:
        A sorted list with each retained identifier appearing once.
    """
    distinct = {
        dx
        for dxs in all_labels
        for dx in dxs
        if valid_classes is None or dx in valid_classes
    }
    return sorted(distinct)

def find_records(directory):
    """Collect the record paths for every ``.hea`` header below a directory.

    Args:
        directory: Root directory that is walked recursively.

    Returns:
        List of header file paths with the extension removed. Within each
        directory the entries follow sorted file-name order.

    Raises:
        IOError: If no header file is found.
    """
    record_files = []
    for dirpath, _, filenames in os.walk(directory):
        for name in sorted(filenames):
            # Hidden files are ignored.
            if name.lower().startswith('.'):
                continue
            full_path = os.path.join(dirpath, name)
            if not os.path.isfile(full_path):
                continue
            stem, extension = os.path.splitext(full_path)
            if extension.lower() == '.hea':
                record_files.append(stem)

    if not record_files:
        raise IOError('No record files found.')
    return record_files

def filter(data, labels, index):
    """Drop the entries at the given positions from ``labels`` and ``data``.

    NOTE(review): this name shadows the builtin ``filter`` for the rest of the
    notebook, and the return order (labels first) is the reverse of the
    parameter order.

    Args:
        data: Sequence of samples.
        labels: Sequence of labels.
        index: Container of positions to remove.

    Returns:
        ``(labels, data)`` as two new lists without the removed positions.
    """
    kept_labels = [label for position, label in enumerate(labels) if position not in index]
    kept_data = [datum for position, datum in enumerate(data) if position not in index]
    return kept_labels, kept_data

def consolidate_equivalent_classes(one_hot_encoded_labels, unique_classes):
    """Merge label columns of classes that are treated as equivalent.

    For every group of equivalent classes, the first class of the group that
    is present in ``unique_classes`` becomes the representative: its column is
    set to the logical OR of all present columns of the group, and the other
    present classes are removed.

    Args:
        one_hot_encoded_labels: 2-D array with one column per entry of
            ``unique_classes``. Representative columns are overwritten in
            place before the redundant columns are deleted.
        unique_classes: List of class identifiers; redundant classes are
            removed from this list in place.

    Returns:
        ``(labels, unique_classes)`` where ``labels`` is a new array without
        the redundant columns and ``unique_classes`` is the same list object
        that was passed in.
    """
    equivalent_classes_collection = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    redundant_classes = []
    redundant_columns = []
    for group in equivalent_classes_collection:
        present = [code for code in group if code in unique_classes]
        # Nothing to merge unless at least two members of the group occur.
        if len(present) < 2:
            continue

        columns = [unique_classes.index(code) for code in present]
        keep_column = columns[0]

        # The representative is positive if any member of the group is.
        one_hot_encoded_labels[:, keep_column] = np.any(one_hot_encoded_labels[:, columns], axis=1)
        redundant_classes.extend(present[1:])
        redundant_columns.extend(columns[1:])

    # Column indices were collected against the untouched class list, so the
    # list is only shrunk after every group has been processed.
    for code in redundant_classes:
        unique_classes.remove(code)
    one_hot_encoded_labels = np.delete(one_hot_encoded_labels, redundant_columns, axis=1)

    return one_hot_encoded_labels, unique_classes

def set_labels_to_normal_if_none_other(labels, unique_classes, normal_class):
    """Mark rows without any positive label as belonging to the normal class.

    Args:
        labels: 2-D label array with one column per entry of
            ``unique_classes``; modified in place.
        unique_classes: List of class identifiers.
        normal_class: Identifier of the normal class; must be present in
            ``unique_classes`` (``list.index`` raises ``ValueError`` otherwise).

    Returns:
        The same ``labels`` array that was passed in.
    """
    normal_column = unique_classes.index(normal_class)
    # Iterating a 2-D array yields row views, so the write below lands in `labels`.
    for row in labels:
        if np.sum(row) == 0:
            row[normal_column] = 1
    return labels

def ensure_normal_class(unique_classes, normal_class):
    """Return the sorted class list, guaranteed to contain the normal class.

    The input may be any iterable of class identifiers (a list as returned by
    ``get_unique_classes``, or a set). It is copied rather than modified, so
    a list no longer fails on a missing ``.add`` method.

    Args:
        unique_classes: Iterable of class identifiers.
        normal_class: Identifier that must be part of the result.

    Returns:
        A new sorted list of the classes including ``normal_class``.
    """
    classes = list(unique_classes)
    if normal_class not in classes:
        classes.append(normal_class)
        print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
    return sorted(classes)

def read_scored_classes(path='dx_mapping_scored.csv'):
    """Read the scored class codes from the second column of a mapping CSV.

    The file is parsed with the ``csv`` module so that quoted fields
    containing commas, blank lines and trailing newlines are handled.

    Args:
        path: Location of the CSV file. The first row is treated as a header
            and skipped. Defaults to ``'dx_mapping_scored.csv'`` in the
            current working directory.

    Returns:
        Sorted list of the non-empty values found in the second column.
    """
    import csv

    scored = []
    with open(path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            # Ignore blank or truncated rows instead of raising IndexError.
            if len(row) > 1 and row[1].strip():
                scored.append(row[1].strip())
    return sorted(scored)

def filter_out(one_hot_encoded_labels, records, discard_index):
    """Remove the discarded positions from the labels and the records.

    Args:
        one_hot_encoded_labels: Sequence of label rows.
        records: Sequence of records.
        discard_index: Container of positions to drop.

    Returns:
        ``(labels, records)`` as two new lists holding only the positions
        that are not in ``discard_index``.
    """
    def _without_discarded(items):
        # Keep every element whose position is not flagged for removal.
        return [item for position, item in enumerate(items) if position not in discard_index]

    return _without_discarded(one_hot_encoded_labels), _without_discarded(records)

def load_records(record_file_list, adjust_classes_for_physionet, normal_class):
    """Read records and build their multi-hot labels.

    Args:
        record_file_list: Record paths to read.
        adjust_classes_for_physionet: When truthy, restrict the classes to
            those returned by ``read_scored_classes`` and merge equivalent
            classes with ``consolidate_equivalent_classes``.
        normal_class: Optional class identifier. When not ``None`` it is
            added to the class list and assigned to rows that have no
            positive label.

    Returns:
        ``(one_hot_encoded_labels, records, unique_classes)`` after the
        recordings flagged by ``create_one_hot_labels`` have been removed.

    Raises:
        ValueError: If ``record_file_list`` is empty.
    """
    num_recordings = len(record_file_list)
    if num_recordings == 0:
        raise ValueError('No record files found.')

    records, all_labels = read_records(record_file_list)

    scored = read_scored_classes() if adjust_classes_for_physionet else None
    unique_classes = get_unique_classes(all_labels, scored)

    use_normal_class = normal_class is not None
    if use_normal_class:
        unique_classes = ensure_normal_class(unique_classes, normal_class)

    # The discard list is fixed here, before any normal-class fill below.
    label_matrix, discard_index = create_one_hot_labels(all_labels, unique_classes, num_recordings)

    if adjust_classes_for_physionet:
        label_matrix, unique_classes = consolidate_equivalent_classes(label_matrix, unique_classes)

    if use_normal_class:
        label_matrix = set_labels_to_normal_if_none_other(label_matrix, unique_classes, normal_class)

    kept_labels, kept_records = filter_out(label_matrix, records, discard_index)

    return kept_labels, kept_records, unique_classes

def standardise_length(data, target_length):
    """Pad or crop a multi-lead signal to a fixed number of samples.

    Args:
        data: 2-D array indexed as ``data[lead][sample]``.
        target_length: Number of samples in the result.

    Returns:
        Float array of shape ``(target_length, number_of_leads)``. Shorter
        signals are zero-padded at the end; longer signals keep only their
        last ``target_length`` samples.
    """
    number_of_leads = data.shape[0]
    current_length = len(data[0])

    fixed = np.zeros([number_of_leads, target_length])
    if current_length <= target_length:
        # Copy the signal to the front; the tail stays zero.
        fixed[:, :current_length] = data
    else:
        # Drop the leading samples so that the end of the recording is kept.
        fixed[:, :] = data[:, current_length - target_length:]
    return fixed.T

def resample(data, src_frq, trg_frq):
    """Resample a multi-channel signal by linear interpolation.

    Args:
        data: 2-D array indexed as ``data[sample, channel]``.
        src_frq: Sampling frequency of ``data``.
        trg_frq: Desired sampling frequency.

    Returns:
        ``data`` itself (unchanged, original dtype) when both frequencies are
        equal; otherwise a new ``float32`` array of shape
        ``(int(N_src * trg_frq / src_frq), channels)``.
    """
    if src_frq == trg_frq:
        return data

    N_src = data.shape[0]
    N_trg = int(N_src * trg_frq / src_frq)

    resampled = np.zeros((N_trg, data.shape[1]), dtype='float32')
    if N_src == 0 or N_trg == 0:
        # Nothing to interpolate (np.interp rejects empty sample points).
        return resampled

    source_positions = np.arange(N_src)
    # Target sample k sits at source position k * src_frq / trg_frq. The end
    # point N_src is excluded: including it stretches the spacing to
    # N_src / (N_trg - 1) and clamps the last sample beyond the signal.
    target_positions = np.linspace(0, N_src, N_trg, endpoint=False)

    for channel in range(data.shape[1]):
        resampled[:, channel] = np.interp(target_positions, source_positions, data[:, channel])

    return resampled

def standardise_data_samples(records):
    """Bring all records to a common sampling rate and length.

    The target sampling frequency and the target length are the most common
    ``fs`` and ``sig_len`` among the records.

    Args:
        records: Non-empty sequence of objects exposing ``fs``, ``sig_len``
            and ``p_signal`` (indexed as ``p_signal[sample, lead]``).

    Returns:
        List with one array per record, as produced by ``standardise_length``.

    Raises:
        ValueError: If ``records`` is empty.
    """
    if len(records) == 0:
        raise ValueError('No records to standardise.')

    # find the most common fs
    fss = [record.fs for record in records]
    target_fs = max(set(fss), key=fss.count)

    # find the most common sig_len
    sig_lens = [record.sig_len for record in records]
    target_length = max(set(sig_lens), key=sig_lens.count)

    standardised_data = []
    for record in records:
        # `resample` interpolates along axis 0, so it must receive the signal
        # as (samples, leads); passing the transpose resampled across leads.
        signal = resample(record.p_signal, record.fs, target_fs)
        # `standardise_length` expects (leads, samples) and returns the
        # padded/cropped signal transposed back.
        standardised_data.append(standardise_length(signal.T, target_length))

    return standardised_data

def load_data(input_directory, adjust_classes_for_physionet=False, normal_class=None):
    """Load every record below a directory as stacked arrays.

    Args:
        input_directory: Directory searched by ``find_records``.
        adjust_classes_for_physionet: Forwarded to ``load_records``.
        normal_class: Forwarded to ``load_records``.

    Returns:
        ``(one_hot_encoded_labels, samples, classes)`` where the labels and
        the standardised samples are each stacked along a new first axis.
    """
    record_paths = find_records(input_directory)

    label_rows, records, classes = load_records(
        record_paths, adjust_classes_for_physionet, normal_class=normal_class)
    sample_list = standardise_data_samples(records)

    stacked_labels = np.stack(label_rows, axis=0)
    stacked_samples = np.stack(sample_list, axis=0)
    return stacked_labels, stacked_samples, classes

output_directory = 'data'
input_directory = 'training_data/cpsc_2018_subset'

one_hot_encoding_labels, samples, classes = load_data(input_directory)

# np.save does not create missing directories, so make sure the target exists.
os.makedirs(output_directory, exist_ok=True)

# save the data to a file
np.save(os.path.join(output_directory, 'samples.npy'), samples)
np.save(os.path.join(output_directory, 'one_hot_encoding_labels.npy'), one_hot_encoding_labels)
np.save(os.path.join(output_directory, 'classes.npy'), classes)


In [8]:
from tensorflow.keras import layers
import keras_nlp as nlp 
import tensorflow.keras as keras


def train_and_evaluate_model(model, samples, one_hot_encoding_labels, callbacks=None, epochs=10, batch_size=64, classes=None):
    """Fit a model on an 80/20 split and plot its history and confusion matrix.

    Args:
        model: Compiled model exposing ``fit`` and ``predict``.
        samples: Input samples, split with ``train_test_split``.
        one_hot_encoding_labels: Label matrix with one column per class.
        callbacks: Optional callbacks forwarded to ``model.fit``.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        classes: Optional tick labels for the confusion matrix, one per
            output column. Defaults to ``"Class 0"``, ``"Class 1"``, ...

    Returns:
        None. Two figures are shown as a side effect.
    """
    train_x, validation_x, train_y, validation_y = train_test_split(
        samples, one_hot_encoding_labels, test_size=0.2, random_state=42)
    history = model.fit(
        train_x, train_y,
        epochs=epochs, batch_size=batch_size,
        validation_data=(validation_x, validation_y),
        callbacks=callbacks)

    # Training curves; the y axis is fixed to [0, 1], so larger values are cut off.
    pd.DataFrame(history.history).plot(
        figsize=(8, 5), xlim=[0, epochs], ylim=[0, 1], grid=True, xlabel="Epoch",
        style=["r--", "r--.", "b-", "b-*"])
    plt.legend(loc="lower left")
    plt.show()

    y_pred = model.predict(validation_x)
    num_classes = y_pred.shape[1]

    # Both predictions and targets are reduced to their highest-scoring column.
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true_classes = np.argmax(validation_y, axis=1)

    # Confusion Matrix
    # `labels` pins the matrix to one row/column per model output. Without it
    # the matrix only covers classes that occur in the validation split and no
    # longer lines up with the tick labels below.
    cm = confusion_matrix(y_true_classes, y_pred_classes, labels=np.arange(num_classes))

    if classes is None:
        classes = ["Class " + str(i) for i in range(num_classes)]

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

def vgg_block(input, cnn_units):
    """Apply a VGG-style stack of 1-D convolutions to a tensor.

    Two ``Conv1D(kernel 3) -> BatchNormalization -> Dropout(0.1)`` stages are
    followed by a strided ``Conv1D(kernel 24, strides 2)``, batch
    normalisation and a max-pool of size 2.

    Args:
        input: Input tensor.
        cnn_units: Number of filters of every convolution.

    Returns:
        The output tensor of the final pooling layer.
    """
    output = input

    # Two identical feature-extraction stages.
    for _ in range(2):
        output = layers.Conv1D(cnn_units, 3, padding='same', activation='relu')(output)
        output = layers.BatchNormalization()(output)
        output = layers.Dropout(0.1)(output)

    # Wide strided convolution followed by pooling.
    output = layers.Conv1D(cnn_units, 24, padding='same', activation='relu', strides=2)(output)
    output = layers.BatchNormalization()(output)
    return layers.MaxPooling1D(2, padding='same')(output)

def create_crt_baseline(number_of_leads=None, cnn_units=128, vgg_blocks=3, rnn_units=64,
                   transformer_encoders=4, att_dim=64, att_heads=8, fnn_units=64, num_classes=5):
    """Build the CNN -> bidirectional GRU -> Transformer classifier.

    Args:
        number_of_leads: Size of the last input axis (input shape is
            ``(None, number_of_leads)``).
        cnn_units: Number of filters in every ``vgg_block``.
        vgg_blocks: Number of ``vgg_block`` stages.
        rnn_units: Units of the bidirectional GRU (directions are summed).
        transformer_encoders: Number of ``TransformerEncoder`` layers; ``0``
            skips the positional encoding and the encoders.
        att_dim: First positional argument of ``nlp.layers.TransformerEncoder``
            (presumably its feed-forward ``intermediate_dim`` - confirm
            against the keras_nlp API).
        att_heads: Second positional argument of
            ``nlp.layers.TransformerEncoder``.
        fnn_units: Units of the two dense layers of the classification head.
        num_classes: Number of sigmoid outputs.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = layers.Input(shape=(None, number_of_leads))
    output = inputs

    for _ in range(vgg_blocks):
        output = layers.BatchNormalization()(output)
        output = vgg_block(output, cnn_units)

    output = layers.BatchNormalization()(output)
    output = layers.Bidirectional(layers.GRU(rnn_units, return_sequences=True), merge_mode='sum')(output)
    output = layers.BatchNormalization()(output)

    if transformer_encoders > 0:
        output = output + nlp.layers.SinePositionEncoding(max_wavelength=10000)(output)

        for _ in range(transformer_encoders):
            output = layers.BatchNormalization()(output)
            output = nlp.layers.TransformerEncoder(att_dim, att_heads)(output)

    # Pool over the sequence axis unconditionally. Inside the branch above,
    # a model built with transformer_encoders=0 kept its sequence axis and
    # produced one prediction per timestep instead of one per recording.
    output = layers.GlobalAveragePooling1D()(output)

    output = layers.Dropout(0.2)(output)
    output = layers.Dense(fnn_units, activation='relu')(output)
    output = layers.BatchNormalization()(output)
    output = layers.Dropout(0.2)(output)
    output = layers.Dense(fnn_units, activation='relu')(output)
    output = layers.BatchNormalization()(output)
    output = layers.Dense(num_classes, activation='sigmoid')(output)
    model = keras.models.Model(inputs, output)
    return model


import tensorflow as tf

# Report the visible GPUs and switch the first one to on-demand memory growth.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices):
    print(f'physical devices found: {physical_devices}')
    mem_growth = tf.config.experimental.get_memory_growth(physical_devices[0])
    print(f'memory growth of dev0: {mem_growth}')
    if not mem_growth:
        try:
            tf.config.experimental.set_memory_growth(physical_devices[0], True)
            print(f'memory growth of dev0: {tf.config.experimental.get_memory_growth(physical_devices[0])} (now enabled)')
        except (RuntimeError, ValueError) as error:
            # Only the documented failures are caught (runtime already
            # initialised / invalid device); anything else, including
            # KeyboardInterrupt, is allowed to propagate.
            print(f'failed to modify device (likely already initialised): {error}')
else:
    print('physical device not found')

physical device not found


In [None]:
import numpy as np
import os
from sklearn import preprocessing
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras import layers, models, callbacks, utils
import keras_nlp as nlp
from tensorflow.keras.layers import LayerNormalization, Dense, Dropout, Add, Conv1D, MaxPooling1D, BatchNormalization, ReLU, Input, MultiHeadAttention
import tensorflow as tf

# Define the VGG block
def vgg_block(input_tensor, filters):
    """Apply two conv/batch-norm/ReLU stages followed by max-pooling.

    NOTE(review): this redefines ``vgg_block`` from an earlier cell with a
    different architecture; once this cell runs, every later call to
    ``vgg_block`` in the notebook resolves to this version.
    NOTE(review): each convolution already applies ``activation='relu'`` and
    is followed by a second ``ReLU`` after batch normalisation - confirm
    that the double activation is intended.

    Args:
        input_tensor: Input tensor.
        filters: Number of filters of both convolutions.

    Returns:
        The output tensor of the pooling layer.
    """
    x = input_tensor
    for _ in range(2):
        x = Conv1D(filters, 3, padding='same', activation='relu')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)
    return MaxPooling1D(2, padding='same')(x)

# Define the Transformer encoder block
def transformer_encoder_block(inputs, att_dim, att_heads, dropout_rate, ff_dim):
    """Apply one self-attention + feed-forward encoder block.

    Args:
        inputs: Input tensor; its last dimension (``inputs.shape[-1]``) sets
            the width of the feed-forward output projection.
        att_dim: ``key_dim`` of the multi-head attention layer.
        att_heads: Number of attention heads.
        dropout_rate: Dropout rate applied after the attention and after the
            feed-forward projection.
        ff_dim: Units of the hidden feed-forward layer.

    Returns:
        The layer-normalised output tensor of the block.
    """
    # Self-attention on the normalised input, with a residual to the raw input.
    normed_inputs = LayerNormalization()(inputs)
    self_attention = MultiHeadAttention(num_heads=att_heads, key_dim=att_dim)
    attended = self_attention(normed_inputs, normed_inputs)
    attended = Dropout(dropout_rate)(attended)
    residual = Add()([inputs, attended])
    residual = LayerNormalization()(residual)

    # Position-wise feed-forward network projected back to the input width.
    hidden = Dense(ff_dim, activation='relu')(residual)
    projected = Dense(inputs.shape[-1])(hidden)
    projected = Dropout(dropout_rate)(projected)
    merged = Add()([residual, projected])
    return LayerNormalization()(merged)

# Define the Bottleneck block
def bottleneck_block(x, in_channels, out_channels, kernel_size, stride, downsample, use_bn, use_do, dropout_rate=0.5):
    """Apply a two-convolution residual block.

    Args:
        x: Input tensor.
        in_channels: Number of channels of ``x``; used only to decide whether
            the shortcut needs a 1x1 projection.
        out_channels: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of the first convolution (the second uses stride 1).
        downsample: When truthy, max-pool the shortcut by ``stride``.
        use_bn: When truthy, add batch normalisation after each convolution.
        use_do: When truthy, add dropout after each activation.
        dropout_rate: Rate of those dropout layers (default 0.5, the value
            that was previously hard-coded).

    Returns:
        The sum of the convolution path and the shortcut.
    """
    identity = x

    x = Conv1D(filters=out_channels, kernel_size=kernel_size, strides=stride, padding='same')(x)
    if use_bn:
        x = BatchNormalization()(x)
    x = ReLU()(x)
    if use_do:
        x = Dropout(dropout_rate)(x)

    x = Conv1D(filters=out_channels, kernel_size=kernel_size, strides=1, padding='same')(x)
    if use_bn:
        x = BatchNormalization()(x)
    x = ReLU()(x)
    if use_do:
        x = Dropout(dropout_rate)(x)

    # The first convolution shortens the sequence whenever stride != 1, so the
    # shortcut must be shortened too even if `downsample` was left False;
    # otherwise the two branches cannot be added.
    if downsample or stride != 1:
        identity = MaxPooling1D(pool_size=stride, padding='same')(identity)

    # Match the channel count of the shortcut with a 1x1 convolution.
    if out_channels != in_channels:
        identity = Conv1D(filters=out_channels, kernel_size=1, padding='same')(identity)

    return Add()([x, identity])

# Model creation function
def create_crtnet_bottleneck(number_of_leads=1,
                   cnn_units=128,
                   vgg_blocks=1,
                   rnn_units=64,
                   transformer_encoders=4,
                   att_dim=64,
                   att_heads=8,
                   ff_dim=64,
                   dropout_rate=0.1,
                   num_classes=5):
    """Build the CNN -> bottleneck -> bidirectional GRU -> Transformer classifier.

    Args:
        number_of_leads: Size of the last input axis (input shape is
            ``(None, number_of_leads)``).
        cnn_units: Filters of the VGG blocks and of the bottleneck block.
        vgg_blocks: Number of ``vgg_block`` stages.
        rnn_units: Units of the bidirectional GRU (directions are summed).
        transformer_encoders: Number of ``transformer_encoder_block`` stages.
        att_dim: Attention key dimension of each encoder block.
        att_heads: Attention heads of each encoder block.
        ff_dim: Hidden units of the encoder feed-forward layers and of the
            two dense layers of the classification head.
        dropout_rate: Dropout rate inside the encoder blocks.
        num_classes: Number of sigmoid outputs.

    Returns:
        An uncompiled Keras ``Model``.
    """
    model_input = Input(shape=(None, number_of_leads))
    x = model_input

    for _ in range(vgg_blocks):
        x = vgg_block(x, cnn_units)

    # Without any VGG block the tensor still has `number_of_leads` channels;
    # telling the bottleneck it has `cnn_units` would skip the 1x1 shortcut
    # projection and make the residual addition fail.
    bottleneck_in_channels = cnn_units if vgg_blocks > 0 else number_of_leads

    # Add bottleneck blocks
    x = bottleneck_block(x, in_channels=bottleneck_in_channels, out_channels=cnn_units, kernel_size=3, stride=2, downsample=True, use_bn=True, use_do=True)

    x = layers.Bidirectional(layers.GRU(rnn_units, return_sequences=True), merge_mode='sum')(x)

    for _ in range(transformer_encoders):
        x = transformer_encoder_block(x, att_dim, att_heads, dropout_rate, ff_dim)

    # Classification head: pool over the sequence axis, then dense layers.
    x = layers.GlobalAveragePooling1D()(x)
    x = Dropout(0.2)(x)
    x = Dense(ff_dim, activation='relu')(x)
    x = Dense(ff_dim, activation='relu')(x)
    x = Dense(num_classes, activation='sigmoid')(x)

    model = models.Model(inputs=model_input, outputs=x)
    return model

In [10]:
import tensorflow.keras as keras
import os
import tensorflow as tf
import datetime
import numpy as np
    
# Timestamp of this run (not used further in this cell).
current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Callbacks: halve the learning rate after 4 stagnant epochs of val_loss and
# stop (restoring the best weights) after 10.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=4,
    min_lr=0.00001,
)
stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# load the data from the file
samples = np.load(os.path.join("data", 'samples.npy'))
one_hot_encoding_labels = np.load(os.path.join("data", 'one_hot_encoding_labels.npy'))
classes = np.load(os.path.join("data", 'classes.npy'))

# Model dimensions follow the stored arrays: last sample axis -> leads,
# label columns -> classes.
model = create_crt_baseline(
    number_of_leads=samples.shape[2],
    num_classes=one_hot_encoding_labels.shape[1],
)

initial_learning_rate = 0.0001
optimizer = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

train_and_evaluate_model(
    model,
    samples=samples,
    one_hot_encoding_labels=one_hot_encoding_labels,
    callbacks=[reduce_lr, stopping],
    epochs=1,
    batch_size=64,
    classes=classes,
)



KeyboardInterrupt: 