# In [None]:
import wfdb
import numpy as np
import pywt
from scipy.signal import find_peaks
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
from tensorflow.keras.layers import Input, Conv1D, Conv1DTranspose, LeakyReLU, PReLU, Add, BatchNormalization, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop
import os


def load_mit_bih_data(record_path, load_annotation=True):
    """Read a WFDB record and, optionally, its 'atr' annotation file.

    Args:
        record_path: Path to the record, without file extension, as accepted by
            ``wfdb.rdrecord`` / ``wfdb.rdann``.
        load_annotation: When True, also read the ``'atr'`` annotations.

    Returns:
        A ``(signal, annotation)`` pair. ``signal`` is column 0 of the record's
        ``p_signal`` (the first channel); ``annotation`` is the ``wfdb.rdann``
        result, or ``None`` when ``load_annotation`` is False.
    """
    first_channel = wfdb.rdrecord(record_path).p_signal[:, 0]
    annotation = wfdb.rdann(record_path, 'atr') if load_annotation else None
    return first_channel, annotation


def add_noise(ecg_signal, noise_signal, snr):
    """Add ``noise_signal`` to ``ecg_signal``, scaled to reach a target SNR.

    The noise is multiplied by
    ``sqrt(P_signal / (P_noise * 10 ** (snr / 10)))``, where ``P`` is the mean
    squared value. The resulting signal-to-noise ratio is ``snr`` dB.

    Args:
        ecg_signal: Clean signal, array-like of any shape.
        noise_signal: Noise to add. When it holds the same number of elements
            as ``ecg_signal`` but a different shape (e.g. ``(N, 1)`` vs
            ``(N,)``), it is reshaped to match. Otherwise normal NumPy
            broadcasting applies, e.g. a ``(1, L, 1)`` template against a
            ``(B, L, 1)`` batch.
        snr: Target signal-to-noise ratio in decibels.

    Returns:
        The noisy signal as a NumPy array. If the noise has zero power, the
        clean signal is returned unchanged (as a new array).
    """
    ecg = np.asarray(ecg_signal)
    noise = np.asarray(noise_signal)
    # A (N,) signal plus an (N, 1) noise column would otherwise broadcast into an
    # (N, N) matrix instead of adding sample-by-sample.
    if noise.shape != ecg.shape and noise.size == ecg.size:
        noise = noise.reshape(ecg.shape)

    signal_power = np.mean(ecg ** 2)
    noise_power = np.mean(noise ** 2)
    if noise_power == 0:
        # Nothing to scale; avoids sqrt(x / 0) = inf and then inf * 0 = nan.
        return ecg + noise

    scale_factor = np.sqrt(signal_power / (noise_power * 10 ** (snr / 10.0)))
    return ecg + scale_factor * noise

# Pan-Tompkins Algorithm for QRS Detection
def pan_tompkins_qrs(ecg_signal, fs=360, *, integration_window=5, refractory_s=0.6):
    """Detect QRS complexes with a simplified derivative-energy detector.

    Pipeline: first difference -> square -> centred moving sum over
    ``integration_window`` samples -> ``scipy.signal.find_peaks`` with a
    minimum peak spacing of ``refractory_s`` seconds. Unlike the published
    Pan-Tompkins algorithm, there is no band-pass filter and no adaptive
    threshold.

    Args:
        ecg_signal: Array-like ECG samples; flattened to 1-D.
        fs: Sampling frequency in Hz. It is used only to convert
            ``refractory_s`` to samples.
        integration_window: Length, in samples, of the moving-sum window.
        refractory_s: Minimum time between detected peaks, in seconds. With
            the 0.6 s default, beats closer than 0.6 s apart (rates above
            ~100 bpm) cannot all be detected.

    Returns:
        1-D integer array of peak indices into the differenced signal (index
        ``i`` corresponds to ``ecg[i + 1] - ecg[i]``). Empty when the input has
        fewer than two samples.
    """
    signal = np.asarray(ecg_signal).ravel()
    if signal.size < 2:
        return np.array([], dtype=np.intp)

    energy = np.diff(signal) ** 2
    kernel = np.ones(integration_window)
    # Centred moving sum, sliced from the full convolution so the output always
    # has len(energy) samples. np.convolve(mode='same') returns
    # max(len(a), len(v)) samples, which overruns for very short inputs.
    offset = (len(kernel) - 1) // 2
    integrated = np.convolve(energy, kernel)[offset:offset + len(energy)]

    # find_peaks requires distance >= 1.
    min_distance = max(1, int(fs * refractory_s))
    peaks, _ = find_peaks(integrated, distance=min_distance)
    return peaks

# Segment Heartbeats based on QRS detection
def segment_heartbeats(ecg_signal, qrs_peaks, window_size=128):
    """Cut a fixed-length window around each detected QRS peak.

    Each window spans ``[peak - window_size // 2, peak - window_size // 2 +
    window_size)``. Parts of the window that fall outside the signal are
    zero-filled on the side where they fall. This keeps every peak at index
    ``window_size // 2``, including peaks near either end of the signal.

    Args:
        ecg_signal: 1-D array-like signal (flattened if not 1-D).
        qrs_peaks: Iterable of integer sample indices into ``ecg_signal``.
        window_size: Number of samples per extracted beat.

    Returns:
        Array of shape ``(len(qrs_peaks), window_size)`` with the signal's
        dtype. For an empty peak list the shape is ``(0, window_size)``.
    """
    signal = np.asarray(ecg_signal).ravel()
    half = window_size // 2
    beats = np.zeros((len(qrs_peaks), window_size), dtype=signal.dtype)

    for row, peak in enumerate(qrs_peaks):
        start = int(peak) - half
        end = start + window_size
        # Clip to the available samples; the uncovered part of the row stays zero.
        src_start = max(start, 0)
        src_end = min(end, len(signal))
        if src_end > src_start:
            beats[row, src_start - start:src_end - start] = signal[src_start:src_end]
    return beats

# Wavelet feature extraction: fixed-length approximation coefficients per beat
def wavelet_decomposition(signal, wavelet='db4', level=4, target_length=512):
    """Return the deepest-level approximation coefficients, fitted to a fixed length.

    ``pywt.wavedec`` is run with the given ``wavelet`` and ``level``. Only the
    approximation coefficients (first element of the result) are kept; the
    detail coefficients are discarded. The approximation is truncated, or
    right-padded with zeros, to exactly ``target_length`` values.

    NOTE(review): for short inputs such as 128-sample beats, the level-4
    approximation is likely far shorter than 512, so most of the returned
    vector would be zero padding. Confirm that this is intended.

    Returns:
        Array of shape ``(target_length, 1)``.
    """
    approximation = pywt.wavedec(signal, wavelet, level=level)[0]
    shortfall = target_length - len(approximation)
    if shortfall > 0:
        approximation = np.pad(approximation, (0, shortfall), 'constant')
    else:
        approximation = approximation[:target_length]
    return approximation.reshape(target_length, 1)


# Build Generator Model
def build_generator(input_shape):
    """Build the denoising generator: a 1-D convolutional encoder/decoder.

    Encoder:
    - One stride-1 ``Conv1D`` (64 filters).
    - Seven stride-2 ``Conv1D`` layers (128, 256, 512, 512, 256, 128, 64 filters).
    - Kernel size 15 throughout, each layer followed by ``PReLU``.

    Decoder:
    - Seven stride-2 ``Conv1DTranspose`` layers (64, 128, 256, 512, 256, 128, 64 filters).
    - After each one, the encoder activation of the same length is projected by
      a 1x1 ``Conv1D`` to the same channel count.
    - That projection is added in (additive skip), then ``PReLU`` is applied.

    Args:
        input_shape: ``(length, channels)`` of one input segment. The encoder
            halves the length seven times and the decoder doubles it seven
            times, so ``length`` should be a multiple of 128 for the skip
            additions to line up.

    Returns:
        An uncompiled Keras ``Model``. Its output has the input's length, one
        channel, and a ``tanh`` activation, so values lie in [-1, 1].
        NOTE(review): confirm the training targets are scaled to that range.
    """
    input_layer = Input(shape=input_shape)

    # Encoder: remember every activation so the decoder can reuse it as a skip.
    h = Conv1D(64, kernel_size=15, strides=1, padding='same')(input_layer)
    h = PReLU()(h)
    encoder_outputs = [h]
    for n_filters in (128, 256, 512, 512, 256, 128, 64):
        h = Conv1D(n_filters, kernel_size=15, strides=2, padding='same')(h)
        h = PReLU()(h)
        encoder_outputs.append(h)

    # Decoder: skip sources are the encoder outputs deepest-first, excluding the
    # bottleneck itself (which is already `h`).
    skip_sources = encoder_outputs[-2::-1]
    for n_filters, skip in zip((64, 128, 256, 512, 256, 128, 64), skip_sources):
        h = Conv1DTranspose(n_filters, kernel_size=15, strides=2, padding='same')(h)
        projected = Conv1D(n_filters, kernel_size=1, strides=1, padding='same')(skip)
        h = Add()([h, projected])
        h = PReLU()(h)

    outputs = Conv1DTranspose(1, kernel_size=15, strides=1, padding='same', activation='tanh')(h)
    return Model(input_layer, outputs)

# Build Discriminator Model
def build_discriminator(input_shape=(512, 1)):
    """Build the discriminator that scores a segment as real (1) or generated (0).

    Four stages follow the input. Each stage is a stride-2 ``Conv1D`` with
    kernel size 3 (64, 128, 256, then 512 filters), then ``LeakyReLU`` with
    slope 0.2, then ``BatchNormalization``. The result is flattened and fed to a
    single sigmoid unit.

    Args:
        input_shape: ``(length, channels)`` of one input segment.

    Returns:
        An uncompiled Keras ``Model`` that outputs one probability per sample.
    """
    inputs = Input(shape=input_shape)

    features = inputs
    for n_filters in (64, 128, 256, 512):
        features = Conv1D(n_filters, kernel_size=3, strides=2, padding='same')(features)
        features = LeakyReLU(negative_slope=0.2)(features)
        features = BatchNormalization()(features)

    flattened = Flatten()(features)
    probability = Dense(1, activation='sigmoid')(flattened)
    return Model(inputs=inputs, outputs=probability)

# Compile the GAN
def compile_gan(generator, discriminator):
    """Compile the discriminator, then stack generator -> discriminator into a GAN.

    Steps:
    1. Compile the discriminator (RMSprop, lr=2e-4, binary cross-entropy,
       accuracy metric).
    2. Set ``discriminator.trainable = False``, so that only the generator's
       weights are meant to be updated when the combined model is trained.
    3. Build and compile the combined model (RMSprop, lr=1e-4, binary
       cross-entropy).

    NOTE(review): this freeze-after-compile pattern depends on the Keras
    version. In some versions, the already-compiled discriminator also stops
    updating when trained on its own. Confirm that the discriminator's weights
    actually change during ``train_on_batch``.

    Args:
        generator: Keras model mapping a segment to a segment of the same shape.
        discriminator: Keras model mapping a segment to a probability.

    Returns:
        The compiled combined ``Model``. Its input shape is the generator's
        input shape without the batch dimension.
    """
    discriminator.compile(optimizer=RMSprop(learning_rate=0.0002), loss='binary_crossentropy', metrics=['accuracy'])
    discriminator.trainable = False

    # input_shape[0] is the batch dimension; keep all remaining axes, whatever the rank.
    gan_input = Input(shape=tuple(generator.input_shape[1:]))
    gan_output = discriminator(generator(gan_input))

    gan = Model(gan_input, gan_output)
    gan.compile(optimizer=RMSprop(learning_rate=0.0001), loss='binary_crossentropy')
    return gan

# Shape of one generator/discriminator input segment: (samples, channels).
input_shape = (512, 1)

# The networks are built once at import time; main() refers to these module-level names.
generator = build_generator(input_shape)
discriminator = build_discriminator(input_shape)
gan = compile_gan(generator, discriminator)

# Print layer-by-layer summaries of both networks.
for _model in (generator, discriminator):
    _model.summary()

# Function to denoise the signals using the trained generator model
def denoise_signals(generator_model, signals):
    """Run ``signals`` through ``generator_model`` and return its predictions.

    This is a thin wrapper around ``generator_model.predict``. The input is
    passed through unchanged, and the return value is whatever ``predict``
    returns.
    """
    return generator_model.predict(signals)

# Classification with SVM
def classify_heartbeats(features, labels, class_names=('N', 'V', 'A', 'L')):
    """Train an RBF-kernel SVM on an 80/20 split and evaluate it on the test part.

    Args:
        features: Array of shape ``(n_samples, ...)``. Each sample is flattened
            to one feature vector, e.g. ``(n, 512, 1)`` becomes ``(n, 512)``.
        labels: Sequence of ``n_samples`` class labels.
        class_names: Classes to list in the report, in report order. They are
            passed to ``classification_report`` as both ``labels`` and
            ``target_names``, so each row name matches the class it describes.
            Classes missing from the test split get zero scores instead of
            raising.

    Returns:
        ``(accuracy, y_pred, class_report)``:
        - ``accuracy``: test-set accuracy.
        - ``y_pred``: predicted test labels.
        - ``class_report``: the text report.
    """
    class_names = list(class_names)
    features_flattened = np.asarray(features).reshape(len(features), -1)

    X_train, X_test, y_train, y_test = train_test_split(features_flattened, labels, test_size=0.2, random_state=42)
    svm_model = SVC(kernel='rbf', class_weight='balanced')
    svm_model.fit(X_train, y_train)
    y_pred = svm_model.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    # Without `labels=`, sklearn orders classes alphabetically (A, L, N, V), which
    # would pair the names in `target_names` with the wrong classes.
    class_report = classification_report(
        y_test,
        y_pred,
        labels=class_names,
        target_names=class_names,
        zero_division=0,
    )

    return accuracy, y_pred, class_report


# Main Process
# Noise mixtures applied to every ECG chunk, in the order their beats are appended
# to the dataset. Each entry is (label, names of the NSTDB noise records to sum).
_NOISE_COMBINATIONS = (
    ('EM', ('em',)),
    ('MA', ('ma',)),
    ('BW', ('bw',)),
    ('EM+MA', ('em', 'ma')),
    ('EM+BW', ('em', 'bw')),
    ('MA+BW', ('ma', 'bw')),
    ('EM+MA+BW', ('em', 'ma', 'bw')),
)

# Annotation symbols kept as labels; every other symbol is mapped to 'N'.
_BEAT_CLASSES = ('N', 'V', 'A', 'L')


def _load_noises(nstdb_dir, noise_types):
    """Load the first channel of each noise record as a 1-D array, keyed by record name."""
    noises = {}
    for noise_type in noise_types:
        noise_signal, _ = load_mit_bih_data(os.path.join(nstdb_dir, noise_type), load_annotation=False)
        # Keep noise 1-D: an (N, 1) column would broadcast against a 1-D ECG chunk
        # into an (N, N) matrix inside add_noise.
        noises[noise_type] = np.ravel(noise_signal)
    return noises


def _label_for_sample(annotation, sample):
    """Return the symbol of the first annotation at or after ``sample``.

    Falls back to 'N' when there is no later annotation or its symbol is not in
    _BEAT_CLASSES.
    """
    idx = np.searchsorted(annotation.sample, sample)
    if idx < len(annotation.symbol) and annotation.symbol[idx] in _BEAT_CLASSES:
        return annotation.symbol[idx]
    return 'N'


def _build_dataset(mitdb_dir, records, noises, chunk_size=512):
    """Build the beat feature and label lists for all records.

    For each record, every non-overlapping chunk is corrupted with each noise
    combination at 0 dB SNR. Beats are detected in the noisy chunk and
    segmented, and each beat is turned into wavelet features.

    Returns:
        ``(features, labels)`` as two parallel lists.
    """
    heartbeats, labels = [], []
    for record in records:
        ecg_signal, annotation = load_mit_bih_data(os.path.join(mitdb_dir, record))

        for i in range(len(ecg_signal) // chunk_size):
            start_idx = i * chunk_size
            end_idx = start_idx + chunk_size
            ecg_chunk = ecg_signal[start_idx:end_idx]

            for _, components in _NOISE_COMBINATIONS:
                noise = sum(noises[name][start_idx:end_idx] for name in components)
                noisy_signal = add_noise(ecg_chunk, noise, snr=0).flatten()
                qrs_peaks = pan_tompkins_qrs(noisy_signal)
                beats = segment_heartbeats(noisy_signal, qrs_peaks)

                heartbeats.extend(wavelet_decomposition(beat) for beat in beats)
                # Peaks are chunk-relative; shift them to record-level sample indices.
                labels.extend(_label_for_sample(annotation, start_idx + peak) for peak in qrs_peaks)
    return heartbeats, labels


def _train_gan(heartbeats, noises, epochs, batch_size, segment_length=512):
    """Alternate discriminator and generator updates on random mini-batches.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models.
    NOTE(review): the "real" samples are features of beats taken from
    already-noisy signals (see _build_dataset). Confirm that this is the
    intended clean reference.
    """
    # One composite noise window shaped (1, L, 1), so that it broadcasts over the
    # batch axis of the (batch, L, 1) heartbeats.
    noise_combination = (
        noises['em'][:segment_length] + noises['ma'][:segment_length] + noises['bw'][:segment_length]
    ).reshape(1, segment_length, 1)

    # Discriminator targets.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        idx = np.random.randint(0, heartbeats.shape[0], batch_size)
        real_heartbeats = heartbeats[idx].reshape(batch_size, segment_length, 1)

        # Noisy inputs for the generator, and its "denoised" outputs.
        noisy_real_heartbeats = add_noise(real_heartbeats, noise_combination, snr=0)
        fake_heartbeats = generator.predict(noisy_real_heartbeats)

        d_loss_real = discriminator.train_on_batch(real_heartbeats, valid)
        d_loss_fake = discriminator.train_on_batch(fake_heartbeats, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator update: push the discriminator towards labelling fakes as valid.
        g_loss = gan.train_on_batch(noisy_real_heartbeats, valid)

        if epoch % 100 == 0:
            print(f"{epoch}/{epochs} [D loss: {d_loss[0]}, acc: {100 * d_loss[1]}] [G loss: {g_loss}]")


def main(mitdb_dir='C:\\Users\\malik\\Desktop\\Disertation\\New folder\\mit-bih-arrhythmia-database-1.0.0',
         nstdb_dir='C:\\Users\\malik\\Desktop\\Disertation\\New folder\\mit-bih-noise-stress-test-database-1.0.0'):
    """Run the whole pipeline: dataset build, GAN training, denoising, SVM classification.

    Args:
        mitdb_dir: Directory containing the MIT-BIH Arrhythmia records.
        nstdb_dir: Directory containing the 'em', 'ma' and 'bw' noise records.

    Side effects:
        - Prints progress and results.
        - Saves the generator to ``cae_cgan_generator.h5``.

    Raises:
        ValueError: If no heartbeats could be extracted.
    """
    mit_bih_records = ['103', '105', '111', '116', '122', '205', '213', '219', '223', '230']
    noise_types = ['em', 'ma', 'bw']

    # Noise records are the same for every ECG record, so load them once.
    noises = _load_noises(nstdb_dir, noise_types)
    heartbeats, labels = _build_dataset(mitdb_dir, mit_bih_records, noises)

    all_heartbeats = np.array(heartbeats)
    all_labels = np.array(labels)

    print("Shape of all_heartbeats:", all_heartbeats.shape)
    if all_heartbeats.shape[0] == 0:
        raise ValueError("No heartbeats were extracted; check the database paths.")

    # Number of epochs and batch size for GAN training.
    epochs = 1
    batch_size = 64
    _train_gan(all_heartbeats, noises, epochs, batch_size)

    generator.save('cae_cgan_generator.h5')

    # Denoise every beat with the trained generator, then classify the results.
    denoised_heartbeats = denoise_signals(generator, all_heartbeats)
    accuracy, predictions, class_report = classify_heartbeats(denoised_heartbeats, all_labels)

    print(f"Overall Classification Accuracy: {accuracy * 100:.2f}%")
    print("Classification Report:\n", class_report)

if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not when imported.
    main()
