In [48]:
import wfdb
import os
import numpy as np

def list_records(base_dir):
    return [f.split('.')[0] for f in os.listdir(base_dir) if f.endswith('.dat')]

def load_all_data(data_dir, noise_dir):
    data_records = list_records(data_dir)
    noise_records = list_records(noise_dir)

    all_noisy_signals = []
    all_clean_signals = []

    signal_length = None

    for record_name in data_records:
        noise_record_name = next((nr for nr in noise_records if nr.startswith(record_name[:3])), None)
        if not noise_record_name:
            continue

        record_path = os.path.join(data_dir, record_name)
        noise_record_path = os.path.join(noise_dir, noise_record_name)

        record = wfdb.rdrecord(record_path)
        noise_record = wfdb.rdrecord(noise_record_path)

        signal = record.p_signal
        noise_signal = noise_record.p_signal[:len(signal)]

        noisy_signal = signal + noise_signal

        if signal_length is None:
            signal_length = len(signal)
        elif signal_length != len(signal):
            raise ValueError("Signal lengths are not consistent across the dataset.")

        all_noisy_signals.append(noisy_signal)
        all_clean_signals.append(signal)

    return np.array(all_noisy_signals), np.array(all_clean_signals), signal_length

data_dir = 'M:\Dissertation\mit-bih-arrhythmia-database-1.0.0'
noise_dir = 'M:\Dissertation\mit-bih-noise-stress-test-database-1.0.0'

noisy_signals, clean_signals, signal_length = load_all_data(data_dir, noise_dir)

# Check the determined signal length
print(f"Signal length: {signal_length}")


Signal length: 650000


In [49]:
import tensorflow as tf
def build_generator(input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Encoder
    e1 = tf.keras.layers.Conv1D(512, 1, padding='same', activation='linear')(inputs)
    e1 = tf.keras.layers.PReLU()(e1)
    e2 = tf.keras.layers.Conv1D(512, 16, padding='same', activation='linear')(e1)
    e2 = tf.keras.layers.PReLU()(e2)
    e3 = tf.keras.layers.Conv1D(512, 32, padding='same', activation='linear')(e2)
    e3 = tf.keras.layers.PReLU()(e3)
    e4 = tf.keras.layers.Conv1D(512, 64, padding='same', activation='linear')(e3)
    e4 = tf.keras.layers.PReLU()(e4)
    e5 = tf.keras.layers.Conv1D(256, 128, padding='same', activation='linear')(e4)
    e5 = tf.keras.layers.PReLU()(e5)
    e6 = tf.keras.layers.Conv1D(128, 256, padding='same', activation='linear')(e5)
    e6 = tf.keras.layers.PReLU()(e6)
    e7 = tf.keras.layers.Conv1D(64, 512, padding='same', activation='linear')(e6)
    e7 = tf.keras.layers.PReLU()(e7)
    e8 = tf.keras.layers.Conv1D(32, 1024, padding='same', activation='linear')(e7)
    e8 = tf.keras.layers.PReLU()(e8)

    # Decoder with skip connections
    d7 = tf.keras.layers.Conv1DTranspose(64, 512, padding='same', activation='linear')(e8)
    d7 = tf.keras.layers.PReLU()(d7)
    d7 = tf.keras.layers.Concatenate()([d7, e7])
    d6 = tf.keras.layers.Conv1DTranspose(128, 256, padding='same', activation='linear')(d7)
    d6 = tf.keras.layers.PReLU()(d6)
    d6 = tf.keras.layers.Concatenate()([d6, e6])
    d5 = tf.keras.layers.Conv1DTranspose(256, 128, padding='same', activation='linear')(d6)
    d5 = tf.keras.layers.PReLU()(d5)
    d5 = tf.keras.layers.Concatenate()([d5, e5])
    d4 = tf.keras.layers.Conv1DTranspose(512, 64, padding='same', activation='linear')(d5)
    d4 = tf.keras.layers.PReLU()(d4)
    d4 = tf.keras.layers.Concatenate()([d4, e4])
    d3 = tf.keras.layers.Conv1DTranspose(512, 32, padding='same', activation='linear')(d4)
    d3 = tf.keras.layers.PReLU()(d3)
    d3 = tf.keras.layers.Concatenate()([d3, e3])
    d2 = tf.keras.layers.Conv1DTranspose(512, 16, padding='same', activation='linear')(d3)
    d2 = tf.keras.layers.PReLU()(d2)
    d2 = tf.keras.layers.Concatenate()([d2, e2])
    d1 = tf.keras.layers.Conv1DTranspose(512, 1, padding='same', activation='linear')(d2)
    d1 = tf.keras.layers.PReLU()(d1)
    d1 = tf.keras.layers.Concatenate()([d1, e1])

    outputs = tf.keras.layers.Conv1DTranspose(1, 1, padding='same', activation='linear')(d1)

    model = tf.keras.Model(inputs, outputs, name='generator')
    return model

# Use the exact length determined from the dataset

input_shape = (signal_length, 1)
generator = build_generator(input_shape)
generator.summary()


ResourceExhaustedError: {{function_node __wrapped__Fill_device_/job:localhost/replica:0/task:0/device:CPU:0}} OOM when allocating tensor with shape[650000,512] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu [Op:Fill] name: 

In [None]:

def build_discriminator(input_shape):
    inputs = tf.keras.layers.Input(shape=input_shape)

    d = tf.keras.layers.Conv1D(64, 5, strides=2, padding='same')(inputs)
    d = tf.keras.layers.BatchNormalization()(d)
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
    d = tf.keras.layers.Conv1D(128, 5, strides=2, padding='same')(d)
    d = tf.keras.layers.BatchNormalization()(d)
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
    d = tf.keras.layers.Conv1D(256, 5, strides=2, padding='same')(d)
    d = tf.keras.layers.BatchNormalization()(d)
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
    d = tf.keras.layers.Conv1D(512, 5, strides=2, padding='same')(d)
    d = tf.keras.layers.BatchNormalization()(d)
    d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)

    d = tf.keras.layers.Flatten()(d)
    d = tf.keras.layers.Dense(1, activation='sigmoid')(d)

    model = tf.keras.Model(inputs, d, name='discriminator')
    return model

# The discriminator input should match the combined shape of noisy and clean signals
discriminator_input_shape = (signal_length, 2)
discriminator = build_discriminator(discriminator_input_shape)
discriminator.summary()

In [None]:
noisy_signals = noisy_signals.reshape(-1, signal_length, 1)
clean_signals = clean_signals.reshape(-1, signal_length, 1)


In [None]:
def create_dataset(noisy_signals, clean_signals, batch_size=32):
    dataset = tf.data.Dataset.from_tensor_slices((noisy_signals, clean_signals))
    dataset = dataset.shuffle(buffer_size=1024).batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

batch_size = 32
dataset = create_dataset(noisy_signals, clean_signals, batch_size)


In [None]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def generator_loss(fake_output, real_signals, generated_signals):
    gen_loss = tf.reduce_mean(tf.square(generated_signals - real_signals))
    return gen_loss

def discriminator_loss(real_output, fake_output):
    real_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output, from_logits=True)
    fake_loss = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_output), fake_output, from_logits=True)
    disc_loss = real_loss + fake_loss
    return disc_loss

# Training step function with GradientTape
@tf.function
def train_step(noisy_signals, real_signals):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_signals = generator(noisy_signals, training=True)
        generated_signals = tf.cast(generated_signals, tf.float32)

        # Ensure noisy_signals is also float32
        noisy_signals = tf.cast(noisy_signals, tf.float32)

        real_concat = tf.concat([real_signals, noisy_signals], axis=2)
        fake_concat = tf.concat([generated_signals, noisy_signals], axis=2)

        real_output = discriminator(real_concat, training=True)
        fake_output = discriminator(fake_concat, training=True)

        gen_loss = generator_loss(fake_output, real_signals, generated_signals)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss


In [None]:
# Training function
def train(dataset, epochs):
    for epoch in range(epochs):
        epoch_gen_loss_avg = tf.keras.metrics.Mean()
        epoch_disc_loss_avg = tf.keras.metrics.Mean()

        for noisy_signals_batch, real_signals_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_signals_batch, real_signals_batch)
            epoch_gen_loss_avg.update_state(gen_loss)
            epoch_disc_loss_avg.update_state(disc_loss)

        print(f'Epoch {epoch + 1}/{epochs}, Generator Loss: {epoch_gen_loss_avg.result()}, Discriminator Loss: {epoch_disc_loss_avg.result()}')

# Number of epochs for training
epochs = 5
train(dataset, epochs)