In [4]:
import os
import numpy as np
import pandas as pd
import wfdb
import scipy.io
import scipy.signal
import neurokit2 as nk
import biosppy.signals.ecg as ecg
import torch

# Shared preprocessing configuration
# Sampling rate (Hz) that the pipeline resamples recordings to.
TARGET_SAMPLING_RATE = 125
# Fixed output lengths, in samples at the target rate.
MAX_LEN_PHYSIONET = TARGET_SAMPLING_RATE * 10  # 10 seconds
MAX_LEN_MITBIH = TARGET_SAMPLING_RATE * 30  # 30 seconds

# Load PhysioNet dataset
def load_physionet_data(path):
    """Load the PhysioNet 2017 recordings and labels found in ``path``.

    Reads ``REFERENCE.csv`` (no header row; column 0 is the record name,
    column 1 the class letter) and every ``*.mat`` file in the directory,
    keeping only the recordings that have an entry in the reference file.

    Args:
        path: Directory containing the ``.mat`` recordings and ``REFERENCE.csv``.

    Returns:
        ``(signals, labels)``: two parallel lists ordered by file name.
        ``signals[i]`` is row 0 of the ``"val"`` array of a recording and
        ``labels[i]`` is its integer class (N=0, A=1, O=2, ~=3).

    Raises:
        KeyError: If the reference file contains a label outside the mapping.
    """
    label_mapping = {"N": 0, "A": 1, "O": 2, "~": 3}  # Modify as per dataset classes
    ref_df = pd.read_csv(os.path.join(path, "REFERENCE.csv"), header=None)
    ref_dict = dict(zip(ref_df[0], ref_df[1]))

    signals, labels = [], []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # output order reproducible across runs and machines.
    for file in sorted(os.listdir(path)):
        # splitext only strips a trailing extension, unlike str.replace which
        # would also remove ".mat" from the middle of a file name.
        record_name, extension = os.path.splitext(file)
        if extension != ".mat":
            continue
        label = ref_dict.get(record_name)
        if not label:
            # No reference label for this recording: skip it before loading.
            continue
        mat_data = scipy.io.loadmat(os.path.join(path, file))
        signals.append(mat_data["val"][0])  # Extract ECG lead
        labels.append(label_mapping[label])

    return signals, labels



# Downsampling function
def downsample_signal(signal, original_fs, target_fs=125):
    """Resample ``signal`` from ``original_fs`` Hz to ``target_fs`` Hz.

    The output length is the input length scaled by ``target_fs / original_fs``
    and truncated to an integer; the resampling itself is delegated to
    ``scipy.signal.resample``.
    """
    scaled_length = len(signal) * target_fs / original_fs
    target_length = int(scaled_length)
    resampled = scipy.signal.resample(signal, target_length)
    return resampled

# Normalize function
def normalize_signal(signal):
    """Min-max scale an ECG signal into the range [0, 1].

    Args:
        signal: Non-empty sequence or array of samples.

    Returns:
        Floating-point array of the same shape in which the smallest input
        value maps to 0 and the largest to 1. A flat signal (all samples
        equal) has no range to scale by and is returned as all zeros rather
        than the NaNs a division by zero would produce.

    Raises:
        ValueError: If ``signal`` is empty (raised by ``np.min``).
    """
    values = np.asarray(signal)
    if not np.issubdtype(values.dtype, np.floating):
        # Promote integer input first so that max - min cannot overflow a
        # narrow integer type such as int16.
        values = values.astype(float)

    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span

# R-peak detection
def detect_r_peaks(signal, sampling_rate=125, method="neurokit"):
    """Locate R-peaks in an ECG signal with NeuroKit2.

    Args:
        signal: 1-D ECG signal.
        sampling_rate: Sampling rate of ``signal`` in Hz.
        method: Peak-detection algorithm forwarded to ``nk.ecg_peaks``. The
            default, ``"neurokit"``, is the library's own default, so callers
            that do not pass it get the same detector as before. Pass
            ``"pantompkins1985"`` to use the Pan-Tompkins detector.

    Returns:
        NumPy array built from the ``"ECG_R_Peaks"`` entry reported by
        ``nk.ecg_peaks`` (the R-peak sample positions).
    """
    _, peak_info = nk.ecg_peaks(signal, sampling_rate=sampling_rate, method=method)
    return np.array(peak_info["ECG_R_Peaks"])

# Beat segmentation
def extract_beats(signal, r_peaks, window_size=0.5, fs=125):
    """Cut one segment of ``signal`` around every R-peak.

    Each segment spans ``int(window_size * fs)`` samples on either side of the
    peak. Segments that would reach past the start or end of the signal are
    clipped to the signal bounds, so they can be shorter than the others.

    Returns:
        List with one slice of ``signal`` per entry of ``r_peaks``, in order.
    """
    half_width = int(window_size * fs)
    n_samples = len(signal)
    return [
        signal[max(0, center - half_width):min(n_samples, center + half_width)]
        for center in r_peaks
    ]

# Zero-padding
def pad_signal(signal, max_len=MAX_LEN_PHYSIONET):
    """Bring ``signal`` to ``max_len`` samples.

    Signals shorter than ``max_len`` get zeros appended at the end; all other
    signals are cut down to their first ``max_len`` samples.
    """
    shortfall = max_len - len(signal)
    if shortfall > 0:
        return np.pad(signal, (0, shortfall), 'constant')
    return signal[:max_len]

# Full preprocessing pipeline
def preprocess_ecg_dataset(dataset_path, original_fs=300, max_len=None):
    """Turn the PhysioNet 2017 recordings in ``dataset_path`` into labelled beats.

    Only the PhysioNet loader (``load_physionet_data``) is wired in here;
    MIT-BIH data is not handled by this function.

    Each recording is resampled to ``TARGET_SAMPLING_RATE``, min-max
    normalized, and split into beats around its detected R-peaks. Every beat
    is then padded or truncated to ``max_len`` samples and inherits the label
    of the recording it came from.

    Args:
        dataset_path: Directory with the ``.mat`` files and ``REFERENCE.csv``.
        original_fs: Sampling rate of the raw recordings in Hz.
        max_len: Length in samples of every returned beat. ``None`` (the
            default) means ``MAX_LEN_PHYSIONET``.

    Returns:
        ``(beats, labels)``: an array of shape ``(n_beats, max_len)`` and an
        array of shape ``(n_beats,)`` holding one label per beat.
    """
    if max_len is None:
        max_len = MAX_LEN_PHYSIONET

    signals, labels = load_physionet_data(dataset_path)

    processed_signals, processed_labels = [], []
    for signal, label in zip(signals, labels):
        # 1. Downsampling
        signal = downsample_signal(signal, original_fs, TARGET_SAMPLING_RATE)
        # 2. Normalization
        signal = normalize_signal(signal)
        # 3. R-peak detection
        r_peaks = detect_r_peaks(signal, TARGET_SAMPLING_RATE)
        # 4. Beat extraction. The rate is passed explicitly so the window stays
        #    correct if TARGET_SAMPLING_RATE ever differs from extract_beats'
        #    own default of 125.
        beats = extract_beats(signal, r_peaks, fs=TARGET_SAMPLING_RATE)
        # 5. Zero-padding each beat.
        # NOTE(review): with extract_beats' default window a beat holds at most
        # 2 * int(0.5 * 125) = 124 samples, so at the default max_len of 1250
        # most of every row is zero padding. Confirm this is intended; pass a
        # smaller max_len otherwise.
        padded_beats = [pad_signal(beat, max_len) for beat in beats]
        processed_signals.extend(padded_beats)  # Collect all beats
        # Assign the recording's label to every beat extracted from it.
        processed_labels.extend([label] * len(padded_beats))

    if not processed_signals:
        # Keep the documented 2-D shape even when no beat was produced.
        return np.empty((0, max_len)), np.empty((0,), dtype=int)
    return np.array(processed_signals), np.array(processed_labels)

# Run the PhysioNet pipeline on the training directory.
physionet_data, physionet_labels = preprocess_ecg_dataset("data/training2017")

# Wrap the arrays as tensors: 32-bit float features, 64-bit integer labels.
X_physionet = torch.tensor(physionet_data, dtype=torch.float)
y_physionet = torch.tensor(physionet_labels, dtype=torch.int64)

print(f"PhysioNet Data Shape: {X_physionet.shape}")

PhysioNet Data Shape: torch.Size([340517, 1250])


In [None]:
# Load MIT-BIH dataset
def load_mitbih_data(path):
    """Load every MIT-BIH record (WFDB format) found in ``path``.

    A record is discovered through its ``.dat`` file; its samples are read
    with ``wfdb.rdsamp`` and its ``atr`` annotation with ``wfdb.rdann``.

    Args:
        path: Directory containing the WFDB record files.

    Returns:
        ``(signals, labels)``: two parallel lists ordered by record name.
        ``signals[i]`` is the first channel of a record and ``labels[i]`` is
        the ``sample`` attribute of that record's annotation.
    """
    signals, labels = [], []
    # os.listdir() order is arbitrary; sort so records come back in a stable,
    # reproducible order.
    for file in sorted(os.listdir(path)):
        # splitext only strips a trailing extension, unlike str.replace which
        # would also remove ".dat" from the middle of a file name.
        record_name, extension = os.path.splitext(file)
        if extension != ".dat":
            continue
        record_path = os.path.join(path, record_name)
        samples, _fields = wfdb.rdsamp(record_path)
        annotation = wfdb.rdann(record_path, 'atr')
        signals.append(samples[:, 0])  # Extract first ECG lead
        # NOTE(review): `annotation.sample` appears to hold annotation
        # positions (sample indices), while the beat-type symbols live in
        # `annotation.symbol` - confirm that positions are what callers
        # expect under the name "labels".
        labels.append(annotation.sample)
    return signals, labels