In [20]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split

import os
import numpy as np
import scipy.io as sio

def read_all_labels(header_files):
    """Read the diagnosis codes listed in each header file.

    A header line starting with ``#Dx`` or ``# Dx`` is expected to carry a
    comma-separated list of codes after the first colon.

    Args:
        header_files: iterable of paths to header files.

    Returns:
        A list with exactly one ``set`` of code strings per header file, in
        the same order as ``header_files``. A header without a Dx line (or
        with an empty one) yields an empty set, so that position ``i`` of the
        result always describes ``header_files[i]``.
    """
    tmp_labels = []
    for header_file in header_files:
        dxs = set()
        with open(header_file, 'r') as f:
            for line in f:
                if line.startswith('#Dx') or line.startswith('# Dx'):
                    # partition() tolerates "#Dx:code" (no space after the colon),
                    # where split(': ')[1] would raise IndexError.
                    _, _, codes = line.partition(':')
                    # Skip empty fragments so an empty Dx list does not
                    # produce a bogus '' class.
                    dxs.update(code.strip() for code in codes.split(',') if code.strip())
        # Append once per file (not once per Dx line) to keep labels aligned
        # with the recordings.
        tmp_labels.append(dxs)
    return tmp_labels

def create_one_hot_labels(all_labels, unique_classes, num_recordings):
    """Build a multi-hot label matrix from per-recording code collections.

    Args:
        all_labels: sequence whose i-th item is the collection of codes for
            recording i.
        unique_classes: list of class codes; a code's position in this list
            is its column in the matrix.
        num_recordings: number of rows to allocate.

    Returns:
        ``(labels, index)`` where ``labels`` is a float array of shape
        ``(num_recordings, len(unique_classes))`` and ``index`` lists the row
        numbers for which none of the codes appeared in ``unique_classes``.
    """
    one_hot = np.zeros((num_recordings, len(unique_classes)))
    unmatched_rows = []
    for row, codes in enumerate(all_labels):
        matched = False
        for code in codes:
            if code in unique_classes:
                one_hot[row, unique_classes.index(code)] = 1
                matched = True
        if not matched:
            unmatched_rows.append(row)
    return one_hot, unmatched_rows

def load_mat_files(data_list):
    """Load the ``'val'`` variable from each MATLAB file in ``data_list``.

    Returns a list of the loaded values, in the same order as the paths.
    """
    recordings = []
    for mat_path in data_list:
        contents = sio.loadmat(mat_path)
        recordings.append(contents['val'])
    return recordings

def get_unique_classes(all_labels, valid_classes=None):
    """Return the sorted list of distinct codes found in ``all_labels``.

    Args:
        all_labels: sequence of code collections, one per recording.
        valid_classes: optional container; when given, only codes contained
            in it are kept. ``None`` keeps every code.

    Returns:
        A sorted list with each accepted code appearing once.
    """
    accepted = set()
    for codes in all_labels:
        for code in codes:
            if valid_classes is None or code in valid_classes:
                accepted.add(code)
    return sorted(accepted)

def find_data_files(data_directory):
    """Recursively collect matching header (.hea) and data (.mat) files.

    Hidden files (names starting with '.') are ignored and extensions are
    matched case-insensitively. A recording is returned only when both its
    header and its data file exist under the same path stem, so the two
    returned lists are always index-aligned: ``header_files[i]`` and
    ``data_files[i]`` describe the same recording. Directories and file
    names are visited in sorted order so the result does not depend on the
    filesystem's listing order.

    Args:
        data_directory: root directory to search.

    Returns:
        ``(header_files, data_files)``: two lists of paths of equal length.

    Raises:
        IOError: if no header/data pair is found.
    """
    headers_by_stem = {}
    data_by_stem = {}
    for dirpath, dirnames, filenames in os.walk(data_directory):
        # In-place sort steers os.walk's (top-down) traversal order.
        dirnames.sort()
        for f in sorted(filenames):
            if f.startswith('.'):
                continue
            file_path = os.path.join(dirpath, f)
            if not os.path.isfile(file_path):
                continue
            stem, extension = os.path.splitext(file_path)
            extension = extension.lower()
            if extension == '.hea':
                headers_by_stem.setdefault(stem, file_path)
            elif extension == '.mat':
                data_by_stem.setdefault(stem, file_path)

    # Keep only recordings that have both files; dicts preserve the sorted
    # discovery order.
    header_files = []
    data_files = []
    for stem, header_path in headers_by_stem.items():
        if stem in data_by_stem:
            header_files.append(header_path)
            data_files.append(data_by_stem[stem])

    if header_files and data_files:
        return header_files, data_files
    else:
        raise IOError('No label or data files found.')

def filter(data, labels, index):
    """Remove the items at the positions listed in ``index`` from both inputs.

    Note the asymmetry: arguments are ``(data, labels, index)`` but the
    result is ``(labels, data)``. This name also shadows the built-in
    ``filter``, and a later ``def filter(labels, data, keep_index)`` in this
    notebook rebinds it, so this version is not the one in effect after the
    whole cell has run.
    """
    kept_labels = [label for position, label in enumerate(labels) if position not in index]
    kept_data = [item for position, item in enumerate(data) if position not in index]
    return kept_labels, kept_data

def consolidate_equivalent_classes(labels, unique_classes):
    """Merge label columns of classes that are declared equivalent.

    For every hard-coded group of equivalent codes that has at least two
    members present in ``unique_classes``, the first present member becomes
    the representative: its column is set to the logical OR of all present
    members' columns, and the other members' columns and class entries are
    removed.

    Args:
        labels: 2-D array with one column per entry of ``unique_classes``.
            The representative columns are overwritten in place.
        unique_classes: list of class codes; merged-away codes are removed
            from this list in place.

    Returns:
        ``(labels, unique_classes)``: a new array without the removed
        columns, and the (mutated) class list.
    """
    equivalence_groups = (
        ('713427006', '59118001'),
        ('284470004', '63593006'),
        ('427172004', '17338001'),
    )

    dropped_classes = []
    dropped_columns = []
    for group in equivalence_groups:
        present = [code for code in group if code in unique_classes]
        if len(present) < 2:
            # Nothing to merge unless at least two members are in use.
            continue
        columns = [unique_classes.index(code) for code in present]
        keeper = columns[0]
        # Representative is positive if any member of the group is positive.
        labels[:, keeper] = np.any(labels[:, columns], axis=1)
        dropped_classes.extend(present[1:])
        dropped_columns.extend(columns[1:])

    # Column indices were computed against the unmodified class list, so the
    # list is only edited after all groups have been processed.
    for code in dropped_classes:
        unique_classes.remove(code)
    return np.delete(labels, dropped_columns, axis=1), unique_classes

def set_labels_to_normal_if_none_other(labels, unique_classes, normal_class):
    """Mark recordings that have no positive label as the normal class.

    Args:
        labels: 2-D array, one row per recording; modified in place.
        unique_classes: list of class codes giving the column order.
        normal_class: code whose column is set to 1 for rows summing to 0.

    Returns:
        The same ``labels`` array.

    Raises:
        ValueError: if ``normal_class`` is not in ``unique_classes``.
    """
    normal_column = unique_classes.index(normal_class)
    # Each ``row`` is a view into ``labels``, so the write updates it in place.
    for row in labels:
        if np.sum(row) == 0:
            row[normal_column] = 1
    return labels

def ensure_normal_class(unique_classes, normal_class):
    """Return the sorted class list, guaranteed to contain ``normal_class``.

    Accepts any iterable of class codes (list, set, tuple, ...). The previous
    implementation called ``.add`` and therefore only worked for sets, while
    ``get_unique_classes`` in this notebook returns a list. The input is not
    modified.

    Args:
        unique_classes: iterable of class codes.
        normal_class: code that must be present in the result.

    Returns:
        A new sorted list of the class codes including ``normal_class``.
        A notice is printed when the code had to be added.
    """
    classes = list(unique_classes)
    if normal_class not in classes:
        classes.append(normal_class)
        print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
    return sorted(classes)

def read_scored_classes(mapping_file='dx_mapping_scored.csv'):
    """Read the class codes from the second column of a mapping CSV.

    The first data-bearing line is treated as the header row and skipped.
    Blank lines and lines with fewer than two comma-separated fields are
    ignored, and surrounding whitespace is stripped from each code.

    Args:
        mapping_file: path of the CSV file; defaults to
            ``'dx_mapping_scored.csv'`` in the current working directory.

    Returns:
        A sorted list of the code strings.
    """
    scored = []
    with open(mapping_file, 'r') as f:
        for line in f:
            fields = line.split(',')
            if len(fields) < 2:
                # Blank or malformed line: there is no second column to read.
                continue
            scored.append(fields[1].strip())
    # Drop the header row before sorting.
    return sorted(scored[1:])

def filter(labels, data, keep_index):
    """Drop the entries at the given positions from ``labels`` and ``data``.

    Despite its name, ``keep_index`` holds the positions to REMOVE; every
    position not listed in it is kept. Note that this function shadows the
    built-in ``filter``.

    Args:
        labels: indexable sequence of labels.
        data: indexable sequence of recordings.
        keep_index: iterable of integer positions to discard from both
            sequences.

    Returns:
        ``(labels, data)`` as two new lists without the discarded positions.
    """
    # Materialise once as a set: O(1) membership tests, and one-shot
    # iterables (generators/iterators) are handled correctly.
    discard = set(keep_index)
    labels = [item for position, item in enumerate(labels) if position not in discard]
    data = [item for position, item in enumerate(data) if position not in discard]

    return labels, data

def load_data_files(header_file_list, data_file_list, adjust_classes_for_physionet):
    """Load recordings and build their multi-hot labels.

    Steps: load every data file, read the code sets from the headers,
    derive the class list (restricted to the scored classes when
    ``adjust_classes_for_physionet`` is true), one-hot encode, optionally
    merge equivalent classes, mark label-less rows as the normal class, and
    finally drop the recordings for which no code matched the class list.

    Args:
        header_file_list: header file paths.
        data_file_list: data file paths, expected in the same order.
        adjust_classes_for_physionet: when true, restrict classes to those
            returned by ``read_scored_classes`` and merge equivalent classes.

    Returns:
        ``(labels, data)`` as returned by ``filter``.
    """
    # Code this pipeline treats as the "normal" class.
    normal_class = '426783006'

    recordings = load_mat_files(data_file_list)
    all_labels = read_all_labels(header_file_list)

    scored_classes = read_scored_classes() if adjust_classes_for_physionet else None

    unique_classes = ensure_normal_class(
        get_unique_classes(all_labels, scored_classes), normal_class
    )
    # ``unmatched_rows`` are recordings with no code in ``unique_classes``.
    labels, unmatched_rows = create_one_hot_labels(
        all_labels, unique_classes, len(header_file_list)
    )

    if adjust_classes_for_physionet:
        labels, unique_classes = consolidate_equivalent_classes(labels, unique_classes)

    labels = set_labels_to_normal_if_none_other(labels, unique_classes, normal_class)
    return filter(labels, recordings, unmatched_rows)

def standardiseLength(data, length):
    """Pad or crop a 12-row recording to ``length`` samples and transpose it.

    If the first row has at most ``length`` samples, every row is copied to
    the start of a zero-filled row (zero padding at the end). Otherwise only
    the last ``length`` samples of each row are kept.

    Args:
        data: indexable with (at least) 12 rows of samples.
        length: number of samples per row in the result.

    Returns:
        A float array of shape ``(length, 12)``.
    """
    lead_count = 12
    sample_count = len(data[0])
    resized = np.zeros([lead_count, length])
    if sample_count <= length:
        for lead in range(lead_count):
            samples = data[lead]
            resized[lead][:len(samples)] = samples
    else:
        # Keep the tail of the recording.
        start = sample_count - length
        for lead in range(lead_count):
            resized[lead] = data[lead][start:]
    return resized.T

def load_data(input_directory, adjust_classes_for_physionet, target_length=5120,
              test_size=0.1, dev_size=0.1, random_state=42):
    """Load the dataset and split it into train / dev / test arrays.

    Args:
        input_directory: directory searched by ``find_data_files``.
        adjust_classes_for_physionet: forwarded to ``load_data_files``.
        target_length: number of samples every recording is padded/cropped
            to. ``crt_net`` in this notebook declares ``Input(shape=(5120, 12))``,
            so a different value requires a matching model input.
        test_size: fraction of all recordings held out as the test set.
        dev_size: fraction of the remaining recordings held out as the
            dev (validation) set.
        random_state: seed passed to both ``train_test_split`` calls.

    Returns:
        ``(train_x, train_y, dev_x, dev_y, test_x, test_y)``.

    Raises:
        ValueError: if no recording remains after label filtering.
    """
    header_file_list, data_file_list = find_data_files(input_directory)
    labels, data = load_data_files(header_file_list, data_file_list, adjust_classes_for_physionet)

    if len(data) == 0:
        # np.stack would fail with an unrelated-looking message.
        raise ValueError('No recordings left after label filtering; nothing to split.')

    standardised_data = np.stack(
        [standardiseLength(recording, target_length) for recording in data], axis=0
    )
    labels = np.stack(labels, axis=0)

    # First carve out the test set, then split the remainder into train/dev.
    train_x, test_x, train_y, test_y = train_test_split(
        standardised_data, labels, test_size=test_size, random_state=random_state)
    train_x, dev_x, train_y, dev_y = train_test_split(
        train_x, train_y, test_size=dev_size, random_state=random_state)

    return train_x, train_y, dev_x, dev_y, test_x, test_y
    

In [2]:
from tensorflow.keras import layers
import keras_nlp as nlp 
import tensorflow.keras as keras

def vgg_block(input, cnn_units):
    """Apply two same-padded Conv1D layers followed by a MaxPooling1D layer.

    Args:
        input: tensor the block is applied to.
        cnn_units: number of filters for each Conv1D layer (kernel size 3,
            ReLU activation).

    Returns:
        The output tensor of the pooling layer (pool size 2, 'same' padding).
    """
    features = input
    for _ in range(2):
        features = layers.Conv1D(cnn_units, 3, padding='same', activation='relu')(features)
    return layers.MaxPooling1D(2, padding='same')(features)

def crt_net(
        cnn_units=128,
        vgg_blocks=1,
        rnn_units=64,
        transformer_encoders=4,
        att_dim=64,
        att_heads=8,
        fnn_units=64,
        num_classes=6
    ):
    """Build the CNN -> bidirectional GRU -> Transformer classifier.

    The model takes inputs of shape ``(5120, 12)`` and ends in a sigmoid
    layer with ``num_classes`` units.

    Args:
        cnn_units: filters per Conv1D layer inside each ``vgg_block``.
        vgg_blocks: number of ``vgg_block`` stages applied to the input.
        rnn_units: units of the GRU wrapped in ``Bidirectional`` (the two
            directions are summed).
        transformer_encoders: number of ``TransformerEncoder`` layers; with
            0, the positional encoding and encoders are skipped.
        att_dim: first argument passed to each ``TransformerEncoder``.
        att_heads: second argument passed to each ``TransformerEncoder``.
        fnn_units: units of each of the two hidden Dense layers.
        num_classes: units of the sigmoid output layer.

    Returns:
        An uncompiled ``keras.Model``.
    """
    input = layers.Input(shape=(5120, 12))
    output = input

    for _ in range(vgg_blocks):
        output = vgg_block(output, cnn_units)

    output = layers.Bidirectional(layers.GRU(rnn_units, return_sequences=True), merge_mode='sum')(output)

    if transformer_encoders > 0:
        output = output + nlp.layers.SinePositionEncoding(max_wavelength=10000)(output)

        for _ in range(transformer_encoders):
            output = nlp.layers.TransformerEncoder(att_dim, att_heads)(output)

    # Pool over the time axis unconditionally. Previously this only happened
    # when transformer_encoders > 0, so with 0 encoders the Dense head was
    # applied per time step and the model output was 3-D instead of
    # (batch, num_classes).
    output = layers.GlobalAveragePooling1D()(output)

    output = layers.Dropout(0.2)(output)
    output = layers.Dense(fnn_units, activation='relu')(output)
    output = layers.Dense(fnn_units, activation='relu')(output)

    output = layers.Dense(num_classes, activation='sigmoid')(output)
    return keras.Model(input, output)


Using TensorFlow backend


In [22]:
# Training configuration.
input_directory = 'training_data'
output_directory = 'output'

MAX_EPOCHS = 50

# The checkpoint callback writes into output_directory; create it up front so
# the first save does not fail on a missing folder.
os.makedirs(output_directory, exist_ok=True)

# Stop when the monitored quantity has not improved for 5 epochs.
stopping = keras.callbacks.EarlyStopping(patience=5)

# Cut the learning rate by 10x after 3 epochs without improvement,
# but never below 1e-6.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=3,
    min_lr= 0.001*0.001)

filepath = os.path.join(output_directory, "model.h5")

# Save the full model whenever val_loss improves. `save_freq='epoch'` replaces
# the deprecated `period=1` argument (check once per epoch).
checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')

train_x, train_y, dev_x, dev_y, test_x, test_y = load_data(input_directory, adjust_classes_for_physionet=False)

# One sigmoid output per label column.
model = crt_net(num_classes=train_y.shape[1])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

model.fit(
    train_x, train_y,
    batch_size=12,
    epochs=MAX_EPOCHS,
    validation_data=(dev_x, dev_y),
    callbacks= [checkpoint, reduce_lr,stopping])

Epoch 1/50


KeyboardInterrupt: 