In [1]:
import os
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split

def find_data_files(data_directory):
    """Find paired WFDB header (``.hea``) and signal (``.mat``) files.

    The directory tree is walked recursively in sorted order, so results are
    deterministic. Hidden files (names starting with ``.``) are ignored and
    extensions are matched case-insensitively. Each header is paired with the
    ``.mat`` file that has the same stem in the same directory, so the two
    returned lists are index-aligned by construction; ``.mat`` files without a
    header are ignored.

    Args:
        data_directory: Root directory to search.

    Returns:
        A tuple ``(header_files, data_files)`` of full paths, where
        ``data_files[i]`` holds the signal for ``header_files[i]``.

    Raises:
        IOError: If no header files are found, or if a header has no ``.mat``
            file with the same stem next to it.
    """
    header_files = []
    data_files = []

    for dirpath, dirnames, filenames in os.walk(data_directory):
        # os.walk honours in-place edits of dirnames; sorting them fixes the
        # order in which subdirectories are visited.
        dirnames.sort()

        visible = [
            name for name in sorted(filenames)
            if not name.startswith('.') and os.path.isfile(os.path.join(dirpath, name))
        ]
        mat_by_stem = {
            os.path.splitext(name)[0]: os.path.join(dirpath, name)
            for name in visible
            if name.lower().endswith('.mat')
        }

        for name in visible:
            if not name.lower().endswith('.hea'):
                continue
            header_path = os.path.join(dirpath, name)
            stem = os.path.splitext(name)[0]
            # Pairing by stem (instead of two independent lists) guarantees
            # that signals and labels cannot drift out of alignment.
            if stem not in mat_by_stem:
                raise IOError('No .mat file found for header {}'.format(header_path))
            header_files.append(header_path)
            data_files.append(mat_by_stem[stem])

    if not header_files:
        raise IOError('No label or output files found.')
    return header_files, data_files

def load_data_crtnet(header_files, data_list):
    """Load ECG signals and build multi-hot labels over every diagnosis seen.

    Args:
        header_files: Paths of ``.hea`` header files, one per recording.
        data_list: Paths of ``.mat`` files, index-aligned with ``header_files``.

    Returns:
        labels: float ndarray of shape (num_recordings, num_classes). Columns
            follow the sorted list of all diagnosis codes found in the headers;
            entry ``[i, j]`` is 1 when recording ``i`` lists code ``j``.
        data: list of the arrays stored under ``'val'`` in each ``.mat`` file.
        index: indices of recordings with no label at all (e.g. a header
            without any ``Dx`` line), returned so the caller can drop them.
    """
    def load_mat_file(path):
        # The signal matrix is stored under the 'val' key.
        return sio.loadmat(path)['val']

    def read_diagnoses(path):
        # Union of the comma-separated codes on every '#Dx' / '# Dx' line.
        # Returning exactly one set per header (possibly empty) keeps the
        # label list index-aligned with the recordings.
        codes = set()
        with open(path, 'r') as f:
            for line in f:
                if line.startswith('#Dx') or line.startswith('# Dx'):
                    codes.update(code.strip() for code in line.split(': ')[1].split(','))
        codes.discard('')  # e.g. produced by a blank '#Dx: ' line
        return codes

    num_recordings = len(header_files)
    data = [load_mat_file(data_list[i]) for i in range(num_recordings)]
    tmp_labels = [read_diagnoses(header_files[i]) for i in range(num_recordings)]

    # Every diagnosis code seen becomes a class, in sorted order.
    classes = sorted(set().union(*tmp_labels))
    column = {code: j for j, code in enumerate(classes)}

    # Multi-hot encoding: a recording may carry several diagnoses.
    labels = np.zeros((num_recordings, len(classes)))
    for i, dxs in enumerate(tmp_labels):
        for dx in dxs:
            labels[i, column[dx]] = 1

    index = [i for i in range(num_recordings) if not labels[i].any()]

    return labels, data, index

def load_data_physionet(header_files, data_list, scored_path='dx_mapping_scored.csv'):
    """Load ECG signals and build multi-hot labels over the scored PhysioNet classes.

    Only diagnosis codes listed in the scoring table are used as classes.
    Codes within each equivalence set are merged into the set's first code,
    and a recording with no scored label is marked positive for the normal
    class. Header label lines have the form ``#Dx: label_1, label_2, ...``.

    Args:
        header_files: Paths of ``.hea`` header files, one per recording.
        data_list: Paths of ``.mat`` files, index-aligned with ``header_files``.
        scored_path: CSV whose second column lists the scored codes; its first
            row is treated as a header and skipped.

    Returns:
        labels: float ndarray (num_recordings, num_classes); columns follow the
            sorted scored codes present in the data (plus the normal class),
            after merging equivalent classes.
        data: list of the arrays stored under ``'val'`` in each ``.mat`` file.
        index: indices of recordings that had no scored diagnosis before the
            normal-class fallback was applied, returned so the caller can drop
            them.
    """
    import csv

    normal_class = '426783006'
    equivalent_classes_collection = [['713427006', '59118001'], ['284470004', '63593006'], ['427172004', '17338001']]

    with open(scored_path, 'r', newline='') as f:
        rows = list(csv.reader(f))
    scored = {row[1] for row in rows[1:] if len(row) > 1}

    def load_mat_file(path):
        # The signal matrix is stored under the 'val' key.
        return sio.loadmat(path)['val']

    def read_diagnoses(path):
        # One (possibly empty) set per header keeps labels aligned with data.
        codes = set()
        with open(path, 'r') as f:
            for line in f:
                if line.startswith('#Dx') or line.startswith('# Dx'):
                    codes.update(code.strip() for code in line.split(': ')[1].split(','))
        codes.discard('')
        return codes

    num_recordings = len(header_files)
    data = [load_mat_file(data_list[i]) for i in range(num_recordings)]
    tmp_labels = [read_diagnoses(header_files[i]) for i in range(num_recordings)]

    present = set().union(*tmp_labels) & scored
    if normal_class not in present:
        print('- The normal class {} is not one of the label classes, so it has been automatically added, but please check that you chose the correct normal class.'.format(normal_class))
    # The normal class must always have a column: it is the fallback label below.
    classes3 = sorted(present | {normal_class})
    column = {code: j for j, code in enumerate(classes3)}

    # Multi-hot encoding over scored codes only.
    labels = np.zeros((num_recordings, len(classes3)))
    for i, dxs in enumerate(tmp_labels):
        for dx in dxs & present:
            labels[i, column[dx]] = 1

    index = [i for i in range(num_recordings) if not labels[i].any()]

    # For each set of equivalent classes, keep the first one as representative;
    # it is positive if any member of the set is positive.
    remove_classes = []
    remove_indices = []
    for equivalent_classes in equivalent_classes_collection:
        members = [code for code in equivalent_classes if code in column]
        if len(members) > 1:
            member_indices = [column[code] for code in members]
            labels[:, member_indices[0]] = np.any(labels[:, member_indices], axis=1)
            remove_classes += members[1:]
            remove_indices += member_indices[1:]

    classes3 = [code for code in classes3 if code not in remove_classes]
    labels = np.delete(labels, np.asarray(remove_indices, dtype=int), axis=1)

    # Recordings that are negative for every class become positive for normal.
    normal_index = classes3.index(normal_class)
    labels[~labels.any(axis=1), normal_index] = 1

    return labels, data, index

def load_data(input_directory, adjust_classes_for_physionet, signal_length=5120, num_leads=12):
    """Load an ECG dataset, fix record lengths and split it into train/dev/test.

    Args:
        input_directory: Root directory containing paired ``.hea``/``.mat`` files.
        adjust_classes_for_physionet: If True, labels come from
            ``load_data_physionet`` (scored PhysioNet classes); otherwise from
            ``load_data_crtnet`` (every diagnosis code is a class).
        signal_length: Number of samples every record is padded/cropped to.
        num_leads: Number of leads kept from each record (the first ones).

    Returns:
        ``(train_x, train_y, dev_x, dev_y, test_x, test_y)``. Inputs have shape
        (n, signal_length, num_leads). 10% of the data is held out as test and
        10% of the remainder as dev, both with ``random_state=42``.

    Raises:
        ValueError: If a record has fewer than ``num_leads`` leads.
    """
    header_files, data_files = find_data_files(input_directory)

    if adjust_classes_for_physionet:
        labels, data, drop = load_data_physionet(header_files, data_files)
    else:
        labels, data, drop = load_data_crtnet(header_files, data_files)

    # Drop recordings the loaders flagged as unlabelled (set lookup keeps this linear).
    drop = set(drop)
    keep = [i for i in range(len(data)) if i not in drop]

    def standardise_length(signal):
        # Zero-pad short records at the end; crop long ones keeping the LAST
        # signal_length samples. Output is transposed to (time, leads).
        signal = np.asarray(signal)
        if signal.shape[0] < num_leads:
            raise ValueError('Expected at least {} leads, got {}'.format(num_leads, signal.shape[0]))
        signal = signal[:num_leads]
        n_samples = signal.shape[1]
        fixed = np.zeros((num_leads, signal_length))
        if n_samples <= signal_length:
            fixed[:, :n_samples] = signal
        else:
            fixed[:, :] = signal[:, n_samples - signal_length:]
        return fixed.T

    labels = np.stack([labels[i] for i in keep], axis=0)
    standardised_data = np.stack([standardise_length(data[i]) for i in keep], axis=0)

    train_x, test_x, train_y, test_y = train_test_split(standardised_data, labels, test_size=0.1, random_state=42)
    train_x, dev_x, train_y, dev_y = train_test_split(train_x, train_y, test_size=0.1, random_state=42)

    return train_x, train_y, dev_x, dev_y, test_x, test_y
    

In [2]:
from tensorflow.keras import layers
import keras_nlp as nlp 
import tensorflow.keras as keras

def vgg_block(input, cnn_units):
    """Apply a VGG-style stage: two Conv1D layers, then 2x max pooling.

    Both convolutions use kernel size 3, 'same' padding and ReLU activation
    with ``cnn_units`` filters; the pooling uses window 2 and 'same' padding.

    Args:
        input: Keras tensor to transform.
        cnn_units: Number of filters in each convolution.

    Returns:
        The pooled Keras tensor.
    """
    x = input
    for _ in range(2):
        x = layers.Conv1D(cnn_units, 3, padding='same', activation='relu')(x)
    return layers.MaxPooling1D(2, padding='same')(x)

def crt_net(
        cnn_units=128,
        vgg_blocks=1,
        rnn_units=64,
        transformer_encoders=4,
        att_dim=64,
        att_heads=8,
        fnn_units=64,
        num_classes=6,
        input_length=5120,
        num_leads=12,
    ):
    """Build the CRT-Net multi-label ECG classifier.

    Pipeline: ``vgg_blocks`` VGG-style Conv1D blocks -> bidirectional GRU ->
    optional Transformer encoder stack with sinusoidal position encoding ->
    global average pooling over time -> dropout -> two ReLU dense layers ->
    sigmoid output (one independent probability per class).

    Args:
        cnn_units: Filters in each Conv1D layer.
        vgg_blocks: Number of VGG blocks; each one halves the time axis.
        rnn_units: GRU units per direction. Directions are summed, so the
            sequence feature width stays ``rnn_units``.
        transformer_encoders: Number of Transformer encoder layers; 0 skips
            the position encoding and Transformer stage.
        att_dim: First positional argument of ``TransformerEncoder``.
            NOTE(review): in keras_nlp this is ``intermediate_dim`` (the
            feed-forward width) — confirm that is the intended meaning.
        att_heads: Number of attention heads.
        fnn_units: Width of the two hidden dense layers.
        num_classes: Number of sigmoid outputs.
        input_length: Samples per input record.
        num_leads: Channels per input record.

    Returns:
        An uncompiled ``keras.Model`` mapping (batch, input_length, num_leads)
        to (batch, num_classes).
    """
    inputs = layers.Input(shape=(input_length, num_leads))
    output = inputs

    for _ in range(vgg_blocks):
        output = vgg_block(output, cnn_units)

    output = layers.Bidirectional(layers.GRU(rnn_units, return_sequences=True), merge_mode='sum')(output)

    if transformer_encoders > 0:
        output = output + nlp.layers.SinePositionEncoding(max_wavelength=10000)(output)

        for _ in range(transformer_encoders):
            output = nlp.layers.TransformerEncoder(att_dim, att_heads)(output)

    # Pool over time in every configuration so the head always receives a
    # (batch, features) tensor and the model emits one prediction per record.
    output = layers.GlobalAveragePooling1D()(output)

    output = layers.Dropout(0.2)(output)
    output = layers.Dense(fnn_units, activation='relu')(output)
    output = layers.Dense(fnn_units, activation='relu')(output)

    # Sigmoid (not softmax): classes are independent, matching the multi-hot labels.
    output = layers.Dense(num_classes, activation='sigmoid')(output)
    return keras.Model(inputs, output)


Using TensorFlow backend


In [3]:
input_directory = 'training_data/cpsc_2018'
output_directory = 'output'

MAX_EPOCHS = 50
BATCH_SIZE = 12
SEED = 42

# Seed the Python, NumPy and backend RNGs so weight init and shuffling repeat across runs.
keras.utils.set_random_seed(SEED)

stopping = keras.callbacks.EarlyStopping(patience=5)

reduce_lr = keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=3,
    min_lr=0.001 * 0.001)

# Create the checkpoint directory up front so a fresh checkout can save the model.
os.makedirs(output_directory, exist_ok=True)
filepath = os.path.join(output_directory, "model.h5")

# save_freq='epoch' is the supported replacement for the deprecated `period=1`.
checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')

train_x, train_y, dev_x, dev_y, test_x, test_y = load_data(input_directory, adjust_classes_for_physionet=False)

model = crt_net(num_classes=train_y.shape[1])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

model.fit(
    train_x, train_y,
    batch_size=BATCH_SIZE,
    epochs=MAX_EPOCHS,
    validation_data=(dev_x, dev_y),
    callbacks=[checkpoint, reduce_lr, stopping])

Epoch 1/50


2024-05-03 17:23:36.848981: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


  1/357 [..............................] - ETA: 1:56:04 - loss: 0.6954 - accuracy: 0.3333

KeyboardInterrupt: 