In [2]:
data_dir = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-arrhythmia-database-1.0.0'
noise_dir = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-noise-stress-test-database-1.0.0'

Signal length: 650000


In [1]:
# Define the extraction paths
arrhythmia_extract_path = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-arrhythmia-database-1.0.0'
noise_extract_path = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-noise-stress-test-database-1.0.0'

In [12]:
import os
import wfdb
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Paths to the extracted datasets
arrhythmia_extract_path = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-arrhythmia-database-1.0.0'
noise_extract_path = 'C:\\Users\\malik\\Desktop\\Disertation\\mit-bih-noise-stress-test-database-1.0.0'

print("Datasets extracted successfully!")

# Function to load ECG signals from the MIT-BIH Arrhythmia Database
def load_arrhythmia_signals(path, record_list):
    signals = []
    for record in record_list:
        record_path = os.path.join(path, record)
        signal, _ = wfdb.rdsamp(record_path)
        signals.append(signal[:, 0])  # Assuming single lead ECG (Lead II)
    return np.array(signals)

# Function to load noise signals from the MIT-BIH Noise Stress Test Database
def load_noise_signals(path, noise_types):
    noises = []
    for noise in noise_types:
        record_path = os.path.join(path, noise)
        signal, _ = wfdb.rdsamp(record_path)
        noises.append(signal[:, 0])  # Assuming single channel noise
    return noises

# Define the record list for the Arrhythmia Database
arrhythmia_records = ['103', '105', '111', '116', '122', '205', '213', '219', '223', '230']
# Define the noise types for the Noise Stress Test Database
noise_types = ['bw', 'em', 'ma']  # Baseline Wander, Electrode Motion, Muscle Artifact

# Load the signals
arrhythmia_signals = load_arrhythmia_signals(arrhythmia_extract_path, arrhythmia_records)
noise_signals = load_noise_signals(noise_extract_path, noise_types)

print("ECG and noise signals loaded successfully!")

# Prepare noisy ECG signals by adding noise to the clean ECG signals
def add_noise_to_signals(clean_signals, noise_signals, snr_dB):
    noisy_signals = []
    for clean_signal in clean_signals:
        noise = noise_signals[np.random.choice(len(noise_signals))]
        noise = noise[:len(clean_signal)]  # Ensure the noise length matches the signal length
        noise_power = np.mean(noise ** 2)
        signal_power = np.mean(clean_signal ** 2)
        scale_factor = np.sqrt(signal_power / (noise_power * 10**(snr_dB / 10)))
        noisy_signal = clean_signal + scale_factor * noise
        noisy_signals.append(noisy_signal)
    return np.array(noisy_signals)

# Example SNR values
snr_values = [0, 1, 2, 3, 4]

# Generate noisy signals for each SNR value
noisy_signals = []
for snr in snr_values:
    noisy_signals.extend(add_noise_to_signals(arrhythmia_signals, noise_signals, snr))

noisy_signals = np.array(noisy_signals)

print("Noisy ECG signals prepared successfully!")

# Custom Dataset
class ECGDataset(Dataset):
    def __init__(self, noisy_signals, clean_signals):
        self.noisy_signals = noisy_signals
        self.clean_signals = clean_signals

    def __len__(self):
        return len(self.noisy_signals)

    def __getitem__(self, idx):
        noisy_signal = self.noisy_signals[idx]
        clean_signal = self.clean_signals[idx % len(self.clean_signals)]
        return torch.tensor(noisy_signal, dtype=torch.float32), torch.tensor(clean_signal, dtype=torch.float32)

# Generator Model: Convolutional Auto-Encoder with Skip Connections
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.encoder_conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2)    # 512 × 1 -> 512 × 16
        self.encoder_conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2)   # 512 × 16 -> 512 × 32
        self.encoder_conv3 = nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2)   # 512 × 32 -> 512 × 64
        self.encoder_conv4 = nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2)  # 512 × 64 -> 256 × 128
        self.encoder_conv5 = nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2) # 256 × 128 -> 128 × 256
        self.encoder_conv6 = nn.Conv1d(256, 512, kernel_size=5, stride=2, padding=2) # 128 × 256 -> 64 × 512
        self.encoder_conv7 = nn.Conv1d(512, 1024, kernel_size=5, stride=2, padding=2)# 64 × 512 -> 32 × 1024
        self.encoder_conv8 = nn.Conv1d(1024, 2048, kernel_size=5, stride=2, padding=2)# 32 × 1024 -> 16 × 2048
        
        self.decoder_conv1 = nn.ConvTranspose1d(2048, 1024, kernel_size=5, stride=2, padding=2, output_padding=1) # 16 × 2048 -> 32 × 1024
        self.decoder_conv2 = nn.ConvTranspose1d(1024, 512, kernel_size=5, stride=2, padding=2, output_padding=1)  # 32 × 1024 -> 64 × 512
        self.decoder_conv3 = nn.ConvTranspose1d(512, 256, kernel_size=5, stride=2, padding=2, output_padding=1)   # 64 × 512 -> 128 × 256
        self.decoder_conv4 = nn.ConvTranspose1d(256, 128, kernel_size=5, stride=2, padding=2, output_padding=1)   # 128 × 256 -> 256 × 128
        self.decoder_conv5 = nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1)    # 256 × 128 -> 512 × 64
        self.decoder_conv6 = nn.ConvTranspose1d(64, 32, kernel_size=5, stride=2, padding=2, output_padding=1)     # 512 × 64 -> 512 × 32
        self.decoder_conv7 = nn.ConvTranspose1d(32, 16, kernel_size=5, stride=2, padding=2, output_padding=1)     # 512 × 32 -> 512 × 16
        self.decoder_conv8 = nn.ConvTranspose1d(16, 1, kernel_size=5, stride=2, padding=2, output_padding=1)      # 512 × 16 -> 512 × 1
        
        self.prelu = nn.PReLU()

    def forward(self, x):
        # Encoder with skip connections
        e1 = self.prelu(self.encoder_conv1(x))
        e2 = self.prelu(self.encoder_conv2(e1))
        e3 = self.prelu(self.encoder_conv3(e2))
        e4 = self.prelu(self.encoder_conv4(e3))
        e5 = self.prelu(self.encoder_conv5(e4))
        e6 = self.prelu(self.encoder_conv6(e5))
        e7 = self.prelu(self.encoder_conv7(e6))
        e8 = self.prelu(self.encoder_conv8(e7))
        
        # Decoder with skip connections
        d1 = self.prelu(self.decoder_conv1(e8)) + e7
        d2 = self.prelu(self.decoder_conv2(d1)) + e6
        d3 = self.prelu(self.decoder_conv3(d2)) + e5
        d4 = self.prelu(self.decoder_conv4(d3)) + e4
        d5 = self.prelu(self.decoder_conv5(d4)) + e3
        d6 = self.prelu(self.decoder_conv6(d5)) + e2
        d7 = self.prelu(self.decoder_conv7(d6)) + e1
        d8 = self.prelu(self.decoder_conv8(d7))
        
        return d8

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv1d(2, 64, kernel_size=5, stride=2, padding=2)   # 512 -> 256
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2) # 256 -> 128
        self.conv3 = nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2)# 128 -> 64
        self.conv4 = nn.Conv1d(256, 512, kernel_size=5, stride=2, padding=2)# 64 -> 32
        self.conv5 = nn.Conv1d(512, 1024, kernel_size=5, stride=2, padding=2)# 32 -> 16
        self.conv6 = nn.Conv1d(1024, 2048, kernel_size=5, stride=2, padding=2)# 16 -> 8
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(2048 * 8, 1)  # Adjust the linear layer input size to match flattened output size
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.conv2(x))
        x = self.leakyrelu(self.conv3(x))
        x = self.leakyrelu(self.conv4(x))
        x = self.leakyrelu(self.conv5(x))
        x = self.leakyrelu(self.conv6(x))
        x = self.flatten(x)
        x = self.linear(x)
        x = self.sigmoid(x)
        return x

# Loss Functions
def generator_loss(disc_output, gen_output, clean_signal, lambda_dist=0.7, lambda_max=0.2):
    adv_loss = nn.BCELoss()(disc_output, torch.ones_like(disc_output))
    dist_loss = torch.mean(torch.sqrt(torch.sum((gen_output - clean_signal) ** 2, dim=1)))
    max_loss = torch.mean(torch.max(torch.abs(gen_output - clean_signal), dim=1)[0])
    return adv_loss + lambda_dist * dist_loss + lambda_max * max_loss

def discriminator_loss(disc_real_output, disc_fake_output):
    real_loss = nn.BCELoss()(disc_real_output, torch.ones_like(disc_real_output))
    fake_loss = nn.BCELoss()(disc_fake_output, torch.zeros_like(disc_fake_output))
    return real_loss + fake_loss

# Training Function
def train(generator, discriminator, dataloader, num_epochs=100, lr=0.0001):
    optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
    optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for noisy_signal, clean_signal in dataloader:
            noisy_signal, clean_signal = noisy_signal.unsqueeze(1), clean_signal.unsqueeze(1)
            
            # Train Discriminator
            optimizer_D.zero_grad()
            real_output = discriminator(torch.cat((clean_signal, noisy_signal), 1))
            fake_signal = generator(noisy_signal)
            fake_output = discriminator(torch.cat((fake_signal.detach(), noisy_signal), 1))
            d_loss = discriminator_loss(real_output, fake_output)
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_signal = generator(noisy_signal)
            fake_output = discriminator(torch.cat((fake_signal, noisy_signal), 1))
            g_loss = generator_loss(fake_output, fake_signal, clean_signal)
            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {g_loss.item()}, Discriminator Loss: {d_loss.item()}")

# Main Execution
if __name__ == "__main__":
    dataset = ECGDataset(noisy_signals, arrhythmia_signals)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    generator = Generator()
    discriminator = Discriminator()

    train(generator, discriminator, dataloader)


Datasets extracted successfully!
ECG and noise signals loaded successfully!
Noisy ECG signals prepared successfully!


RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 2662400000 bytes.