In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import wfdb
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Define the ECGDataset class
class ECGDataset(Dataset):
    def __init__(self, raw_signals, noisy_signals):
        self.raw_signals = raw_signals
        self.noisy_signals = noisy_signals

    def __len__(self):
        return len(self.raw_signals)

    def __getitem__(self, idx):
        raw_signal = self.raw_signals[idx]
        noisy_signal = self.noisy_signals[idx]
        return torch.tensor(raw_signal, dtype=torch.float32), torch.tensor(noisy_signal, dtype=torch.float32)

# Define the add_noise function
def add_noise(signal, noise, snr):
    signal_power = np.mean(signal ** 2)
    noise_power = np.mean(noise ** 2)
    factor = (signal_power / noise_power) / (10 ** (snr / 10))
    noisy_signal = signal + noise * np.sqrt(factor)
    return noisy_signal

# Define the load_mit_bih_data function
def load_mit_bih_data(records, noise_types, snr_levels, target_length=650000):
    raw_signals_dict = {noise_type: [] for noise_type in noise_types}
    noisy_signals_dict = {noise_type: {snr: [] for snr in snr_levels} for noise_type in noise_types}
    
    for record in records:
        raw_record = wfdb.rdrecord(f'M:/Dissertation/New folder/mit-bih-arrhythmia-database-1.0.0/{record}')
        raw_signal = raw_record.p_signal[:, 0]  # Use the first channel for simplicity
        
        for noise_type in noise_types:
            noise_record = wfdb.rdrecord(f'M:/Dissertation/New folder/mit-bih-noise-stress-test-database-1.0.0/{noise_type}')
            noise_signal = noise_record.p_signal[:, 0]
            
            min_length = min(len(raw_signal), len(noise_signal), target_length)
            raw_signal_cut = raw_signal[:min_length]
            noise_signal_cut = noise_signal[:min_length]
            
            if min_length < target_length:
                raw_signal_cut = np.pad(raw_signal_cut, (0, target_length - min_length), 'constant')
                noise_signal_cut = np.pad(noise_signal_cut, (0, target_length - min_length), 'constant')
            
            raw_signals_dict[noise_type].append(raw_signal_cut)
            
            for snr in snr_levels:
                noisy_signal = add_noise(raw_signal_cut, noise_signal_cut, snr)
                noisy_signals_dict[noise_type][snr].append(noisy_signal)
    
    return raw_signals_dict, noisy_signals_dict

# Select records and noise types for the experiment
records = ['103', '105', '111', '116', '122', '205', '213', '219', '223', '230']
noise_types = ['bw', 'em', 'ma']
combined_noise_types = ['em+bw', 'ma+bw', 'ma+em', 'ma+em+bw']
snr_levels = [0, 1, 2, 3, 4, 5]
target_length = 649984

raw_signals_dict, noisy_signals_dict = load_mit_bih_data(records, noise_types, snr_levels, target_length)

# For combined noise types, combine the corresponding noises
for combined_noise in combined_noise_types:
    components = combined_noise.split('+')
    combined_raw_signals = []
    combined_noisy_signals = {snr: [] for snr in snr_levels}
    
    for i in range(len(records)):
        combined_signal = np.zeros(target_length)
        for component in components:
            combined_signal += raw_signals_dict[component][i] / len(components)
        combined_raw_signals.append(combined_signal)
        
        for snr in snr_levels:
            combined_noise_signal = np.zeros(target_length)
            for component in components:
                combined_noise_signal += np.array(noisy_signals_dict[component][snr][i]) / len(components)
            combined_noisy_signals[snr].append(combined_noise_signal)
    
    raw_signals_dict[combined_noise] = combined_raw_signals
    noisy_signals_dict[combined_noise] = combined_noisy_signals

# Create datasets and dataloaders
datasets = {}
for noise_type in noise_types + combined_noise_types:
    for snr in snr_levels:
        raw_signals_train, raw_signals_test, noisy_signals_train, noisy_signals_test = train_test_split(
            raw_signals_dict[noise_type], noisy_signals_dict[noise_type][snr], test_size=0.2, random_state=42)
        
        train_dataset = ECGDataset(raw_signals_train, noisy_signals_train)
        test_dataset = ECGDataset(raw_signals_test, noisy_signals_test)
        
        datasets[(noise_type, snr, 'train')] = train_dataset
        datasets[(noise_type, snr, 'test')] = test_dataset

dataloaders = {key: DataLoader(dataset, batch_size=256, shuffle=True) for key, dataset in datasets.items()}

# Define the Generator class with and without input variables z
class GeneratorWithZ(nn.Module):
    def __init__(self):
        super(GeneratorWithZ, self).__init__()
        self.encoder = nn.ModuleList([
            nn.Conv1d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(1024, 2048, kernel_size=4, stride=2, padding=1),
        ])
        self.z_layer = nn.Linear(10, 2048)  # Assuming z is a vector of size 10
        self.decoder = nn.ModuleList([
            nn.ConvTranspose1d(4096, 1024, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(16, 1, kernel_size=4, stride=2, padding=1),
        ])
        self.prelu = nn.PReLU()
        self.skip_connections = nn.ModuleList([
            nn.Conv1d(16, 16, kernel_size=1),
            nn.Conv1d(32, 32, kernel_size=1),
            nn.Conv1d(64, 64, kernel_size=1),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.Conv1d(256, 256, kernel_size=1),
            nn.Conv1d(512, 512, kernel_size=1),
            nn.Conv1d(1024, 1024, kernel_size=1),
        ])

    def forward(self, x, z):
        encodings = []
        for layer in self.encoder:
            x = layer(x)
            x = self.prelu(x)
            encodings.append(x)
        
        z = self.prelu(self.z_layer(z)).unsqueeze(2).repeat(1, 1, x.size(2))
        x = torch.cat((x, z), dim=1)
        
        for i, layer in enumerate(self.decoder):
            x = layer(x)
            x = self.prelu(x)
            if i < len(self.skip_connections):
                x += self.skip_connections[-i-1](encodings[-i-1])
        return x

class GeneratorWithoutZ(nn.Module):
    def __init__(self):
        super(GeneratorWithoutZ, self).__init__()
        self.encoder = nn.ModuleList([
            nn.Conv1d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(1024, 2048, kernel_size=4, stride=2, padding=1),
        ])
        self.decoder = nn.ModuleList([
            nn.ConvTranspose1d(2048, 1024, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(1024, 512, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(16, 1, kernel_size=4, stride=2, padding=1),
        ])
        self.prelu = nn.PReLU()
        self.skip_connections = nn.ModuleList([
            nn.Conv1d(16, 16, kernel_size=1),
            nn.Conv1d(32, 32, kernel_size=1),
            nn.Conv1d(64, 64, kernel_size=1),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.Conv1d(256, 256, kernel_size=1),
            nn.Conv1d(512, 512, kernel_size=1),
            nn.Conv1d(1024, 1024, kernel_size=1),
        ])

    def forward(self, x):
        encodings = []
        for layer in self.encoder:
            x = layer(x)
            x = self.prelu(x)
            encodings.append(x)
        
        for i, layer in enumerate(self.decoder):
            x = layer(x)
            x = self.prelu(x)
            if i < len(self.skip_connections):
                x += self.skip_connections[-i-1](encodings[-i-1])
        return x

# Define the Discriminator class
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv1d(2, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(512, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(x.size(0), -1)

# Define the calculate_snr function
def calculate_snr(original, denoised):
    noise = original - denoised
    snr = 10 * np.log10(np.sum(original ** 2) / np.sum(noise ** 2))
    return snr

def calculate_rmse(original, denoised):
    mse = np.mean((original - denoised) ** 2)
    rmse = np.sqrt(mse)
    return rmse

def train(generator, discriminator, dataloaders, num_epochs=100, lr=0.0001, use_z=False):
    criterion = nn.BCELoss()
    optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=lr)
    optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=lr)
    
    results = {'Noise_Type': [], 'SNR_Level': [], 'Epoch': [], 'Batch': [], 'D_loss': [], 'G_loss': [], 'SNR': [], 'RMSE': []}
    
    for (noise_type, snr, phase), dataloader in dataloaders.items():
        for epoch in range(num_epochs):
            for i, (raw_signals, noisy_signals) in enumerate(dataloader):
                batch_size = raw_signals.size(0)
                
                # Ensure the signals have the same length
                min_length = min(raw_signals.shape[-1], noisy_signals.shape[-1])
                raw_signals = raw_signals[:, :min_length]
                noisy_signals = noisy_signals[:, :min_length]

                # Denoise the noisy signals
                noisy_signals = noisy_signals.unsqueeze(1)  # Add channel dimension
                raw_signals = raw_signals.unsqueeze(1)
                
                # Create random z
                z = torch.randn(batch_size, 10) if use_z else None

                # Train Generator
                optimizer_G.zero_grad()
                if use_z:
                    gen_signals = generator(noisy_signals, z)
                else:
                    gen_signals = generator(noisy_signals)
                
                # Update valid and fake labels to match the discriminator output size
                disc_output_size = discriminator(torch.cat((gen_signals, noisy_signals), 1)).size()
                valid = torch.ones(disc_output_size).to(gen_signals.device)
                fake = torch.zeros(disc_output_size).to(gen_signals.device)
                
                g_loss = criterion(discriminator(torch.cat((gen_signals, noisy_signals), 1)), valid)
                g_loss.backward()
                optimizer_G.step()

                # Train Discriminator
                optimizer_D.zero_grad()
                real_loss = criterion(discriminator(torch.cat((raw_signals, noisy_signals), 1)), valid)
                fake_loss = criterion(discriminator(torch.cat((gen_signals.detach(), noisy_signals), 1)), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()
                
                # Calculate SNR and RMSE
                snr_value = calculate_snr(raw_signals.squeeze().cpu().numpy(), gen_signals.squeeze().cpu().detach().numpy())
                rmse_value = calculate_rmse(raw_signals.squeeze().cpu().numpy(), gen_signals.squeeze().cpu().detach().numpy())

                # Store results
                results['Noise_Type'].append(noise_type)
                results['SNR_Level'].append(snr)
                results['Epoch'].append(epoch + 1)
                results['Batch'].append(i + 1)
                results['D_loss'].append(d_loss.item())
                results['G_loss'].append(g_loss.item())
                results['SNR'].append(snr_value)
                results['RMSE'].append(rmse_value)
                
                print(f"[{noise_type} SNR {snr}] [Epoch {epoch + 1}/{num_epochs}] [Batch {i + 1}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}] [SNR: {snr_value}] [RMSE: {rmse_value}]")

    return results

# Initialize models
generator_with_z = GeneratorWithZ()
generator_without_z = GeneratorWithoutZ()
discriminator = Discriminator()

# Train the models and collect results
results_with_z = train(generator_with_z, discriminator, dataloaders, num_epochs=100, lr=0.0001, use_z=True)
results_without_z = train(generator_without_z, discriminator, dataloaders, num_epochs=100, lr=0.0001, use_z=False)

# Create DataFrames to display results
df_results_with_z = pd.DataFrame(results_with_z)
df_results_without_z = pd.DataFrame(results_without_z)

# Plotting the results
avg_snr_with_z = df_results_with_z.groupby('SNR_Level')['SNR'].mean()
avg_rmse_with_z = df_results_with_z.groupby('SNR_Level')['RMSE'].mean()
avg_snr_without_z = df_results_without_z.groupby('SNR_Level')['SNR'].mean()
avg_rmse_without_z = df_results_without_z.groupby('SNR_Level')['RMSE'].mean()

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(snr_levels, avg_snr_with_z, label='With z')
plt.plot(snr_levels, avg_snr_without_z, label='Without z')
plt.xlabel('SNR (dB)')
plt.ylabel('Average SNR (dB)')
plt.title('SNR vs Average SNR')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(snr_levels, avg_rmse_with_z, label='With z')
plt.plot(snr_levels, avg_rmse_without_z, label='Without z')
plt.xlabel('SNR (dB)')
plt.ylabel('Average RMSE')
plt.title('SNR vs Average RMSE')
plt.legend()

plt.tight_layout()
plt.show()


RuntimeError: Given groups=1, weight of size [1024, 1024, 1], expected input[8, 2048, 2539] to have 1024 channels, but got 2048 channels instead