In [None]:
import os
import numpy as np
import pandas as pd
import wfdb
import scipy.io
import scipy.signal
import neurokit2 as nk
import biosppy.signals.ecg as ecg
import torch
import gzip
import pickle

# Define constants (all lengths are in samples at TARGET_SAMPLING_RATE).
TARGET_SAMPLING_RATE = 125  # Hz; every record is resampled to this rate
MAX_LEN_PHYSIONET = TARGET_SAMPLING_RATE * 10  # 10 s worth of samples
# NOTE(review): not used in this cell; the MIT-BIH cell redefines it to 2 s.
MAX_LEN_MITBIH = TARGET_SAMPLING_RATE * 30  # 30 s worth of samples

# Load PhysioNet dataset
def load_physionet_data(path):
    """Load PhysioNet 2017 ``.mat`` records and their labels from ``REFERENCE.csv``.

    ``REFERENCE.csv`` (no header row) maps a record name in column 0 to a class
    code in column 1. Every ``*.mat`` file in ``path`` that has a reference entry
    is loaded. Records are visited in sorted filename order so the output order
    is reproducible (``os.listdir`` order is arbitrary).

    Args:
        path: Directory holding the ``.mat`` records and ``REFERENCE.csv``.

    Returns:
        ``(signals, labels)``: two parallel lists. ``signals[i]`` is row 0 of the
        record's ``val`` matrix. ``labels[i]`` is its integer class
        (``N``->0, ``A``->1, ``O``->2, ``~``->3).

    Raises:
        KeyError: if a referenced record has a class code outside that mapping.
    """
    ref_df = pd.read_csv(os.path.join(path, "REFERENCE.csv"), header=None)
    ref_dict = dict(zip(ref_df[0], ref_df[1]))
    label_mapping = {"N": 0, "A": 1, "O": 2, "~": 3}  # Modify as per dataset classes

    signals, labels = [], []
    for file in sorted(os.listdir(path)):
        record_name, ext = os.path.splitext(file)
        if ext != ".mat":
            continue
        label = ref_dict.get(record_name)
        if label is None:
            # No reference label for this record: skip it without loading the file.
            continue
        mat_data = scipy.io.loadmat(os.path.join(path, file))
        signals.append(mat_data["val"][0])  # first row of the 'val' matrix
        labels.append(label_mapping[label])

    return signals, labels



# Downsampling function
def downsample_signal(signal, original_fs, target_fs=125):
    """Resample ``signal`` from ``original_fs`` Hz to ``target_fs`` Hz.

    The output length is ``len(signal) * target_fs / original_fs``, truncated
    to an int. The resampling itself is delegated to ``scipy.signal.resample``.
    """
    n_out = int(len(signal) * target_fs / original_fs)
    resampled = scipy.signal.resample(signal, num=n_out)
    return resampled

# Normalize function
def normalize_signal(signal):
    """Min-max scale an ECG signal into the range [0, 1].

    The input is converted to float64 first. With integer input (e.g. raw
    int16 samples), ``signal - min`` and ``max - min`` could otherwise overflow
    and wrap around.

    Args:
        signal: 1-D array-like of samples.

    Returns:
        float64 ndarray of the same length. A constant signal (zero range) maps
        to all zeros instead of the NaNs that a 0/0 division would produce.

    Raises:
        ValueError: if ``signal`` is empty (min/max are undefined).
    """
    signal = np.asarray(signal, dtype=np.float64)
    lo = signal.min()
    span = signal.max() - lo
    if span == 0:
        return np.zeros_like(signal)
    return (signal - lo) / span

# R-peak detection
def detect_r_peaks(signal, sampling_rate=125, method="neurokit"):
    """Return the sample indices of R-peaks found by NeuroKit2's ``ecg_peaks``.

    The default ``method="neurokit"`` is NeuroKit2's own detector (its library
    default), not Pan-Tompkins. Pass ``method="pantompkins1985"`` to use
    Pan-Tompkins instead.

    Args:
        signal: 1-D ECG signal.
        sampling_rate: Sampling rate of ``signal`` in Hz.
        method: Detector name forwarded to ``nk.ecg_peaks``.

    Returns:
        ndarray of R-peak sample indices (the ``"ECG_R_Peaks"`` entry of the
        info dict returned by ``nk.ecg_peaks``).
    """
    _, info = nk.ecg_peaks(signal, sampling_rate=sampling_rate, method=method)
    return np.array(info["ECG_R_Peaks"])

# Beat segmentation
def extract_beats(signal, r_peaks, window_size=0.5, fs=125):
    """Slice a window of samples around each R-peak.

    ``window_size`` is a half-width in seconds. Each beat is
    ``signal[peak - h : peak + h]`` with ``h = int(window_size * fs)``, clipped
    to the signal bounds. Beats near either end are therefore shorter and not
    centred on their peak.

    Returns:
        A list with one slice per entry of ``r_peaks``, in the same order.
    """
    half = int(window_size * fs)
    n_samples = len(signal)
    return [signal[max(0, pk - half):min(n_samples, pk + half)] for pk in r_peaks]

# Zero-padding
def pad_signal(signal, max_len=MAX_LEN_PHYSIONET):
    """Force ``signal`` to exactly ``max_len`` samples.

    Shorter inputs are right-padded with zeros via ``np.pad``. Inputs that are
    as long as or longer than ``max_len`` are truncated by slicing.
    """
    shortfall = max_len - len(signal)
    if shortfall > 0:
        return np.pad(signal, (0, shortfall), 'constant')
    return signal[:max_len]

# Full preprocessing pipeline
def preprocess_ecg_dataset(dataset_path, original_fs=300, max_len=MAX_LEN_PHYSIONET):
    """Turn the PhysioNet 2017 records under ``dataset_path`` into per-beat samples.

    Only the PhysioNet loader is used here; MIT-BIH is handled by its own
    pipeline. For each record the function:

    1. resamples from ``original_fs`` to ``TARGET_SAMPLING_RATE``;
    2. min-max normalises;
    3. detects R-peaks;
    4. cuts a window around each peak;
    5. pads or truncates each window to ``max_len`` samples.

    Every beat inherits its record's label.

    NOTE(review): with ``extract_beats``' default half-width of
    ``int(0.5 * 125) = 62`` samples, a beat has at most 124 real samples. The
    default ``max_len`` (1250) therefore leaves most of each row as zeros. The
    default is kept so the output shape existing consumers expect is unchanged;
    pass a smaller ``max_len`` to avoid the waste.

    Args:
        dataset_path: Directory with ``REFERENCE.csv`` and the ``.mat`` records.
        original_fs: Sampling rate of the raw records in Hz (PhysioNet 2017: 300).
        max_len: Length every beat is padded/truncated to.

    Returns:
        ``(X, y)``: ``X`` is a float ndarray of shape ``(n_beats, max_len)`` and
        ``y`` is an int ndarray of shape ``(n_beats,)``. When no beats are found,
        the shapes are ``(0, max_len)`` and ``(0,)``.
    """
    signals, labels = load_physionet_data(dataset_path)
    processed_signals, processed_labels = [], []
    for signal, label in zip(signals, labels):
        signal = downsample_signal(signal, original_fs, TARGET_SAMPLING_RATE)
        signal = normalize_signal(signal)
        r_peaks = detect_r_peaks(signal, TARGET_SAMPLING_RATE)
        beats = extract_beats(signal, r_peaks, fs=TARGET_SAMPLING_RATE)
        padded_beats = [pad_signal(beat, max_len) for beat in beats]
        processed_signals.extend(padded_beats)
        # Every beat cut from a record gets that record's label.
        processed_labels.extend([label] * len(padded_beats))

    if not processed_signals:
        # np.array([]) would be 1-D; keep the 2-D (n, max_len) contract.
        return np.empty((0, max_len)), np.empty((0,), dtype=int)
    return np.array(processed_signals), np.array(processed_labels)

# Example Usage
# Run the full PhysioNet 2017 pipeline on the training directory at this
# relative path (it must contain REFERENCE.csv and the *.mat records).
physionet_data, physionet_labels = preprocess_ecg_dataset("data/training2017")

# Convert to PyTorch Tensors
# X: float32 beats, one row per beat, each padded/truncated to MAX_LEN_PHYSIONET.
# y: int64 class ids taken from load_physionet_data's mapping (0-3).
X_physionet = torch.tensor(physionet_data, dtype=torch.float32)
y_physionet = torch.tensor(physionet_labels, dtype=torch.long)

# Save the processed data as a gzip-compressed pickle of (X, y).
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
with gzip.open("pretraining_data.pkl.gz", "wb") as f:
    pickle.dump((X_physionet, y_physionet), f)
print(f"PhysioNet Data Shape: {X_physionet.shape}")


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/mpallasmichael/Desktop/ecgDiagnosis/.venv/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/mpallasmichael/Desktop/ecgDiagnosis/.venv/lib/python3.9/s

PhysioNet Data Shape: torch.Size([340517, 1250])


In [2]:
import os
import numpy as np
import torch
import wfdb
import gzip
import pickle
from scipy.signal import resample

# Constants (lengths are in samples at TARGET_SAMPLING_RATE).
TARGET_SAMPLING_RATE = 125  # Hz
# Padded length of every extracted beat: 2 s worth of samples. The actual
# extraction window (see extract_beats_and_labels) is peak +/- int(window_size / 2 * fs),
# so the tail of each row is zero padding.
MAX_LEN_MITBIH = TARGET_SAMPLING_RATE * 2

# Beat label mapping (can expand if needed): MIT-BIH annotation symbol -> class id.
# Symbols not listed here are skipped during beat extraction.
# NOTE(review): the grouping looks like an AAMI-style N/S/V/F/Q split; confirm
# the placement of '!' and '?'.
label_mapping = {
    symbol: class_id
    for class_id, symbols in enumerate((
        ('N', 'L', 'R', 'e', 'j'),  # 0: normal group
        ('A', 'a', 'J', 'S'),       # 1
        ('V', 'E', '!'),            # 2
        ('F',),                     # 3
        ('/', 'f', 'Q', '?'),       # 4: unknown/other
    ))
    for symbol in symbols
}

def load_mitbih_record(record_path):
    """Read one MIT-BIH record and its ``atr`` annotations with wfdb.

    Args:
        record_path: Record path without file extension, as built by the caller.

    Returns:
        ``(signal, samples, symbols)``:
        - ``signal`` is column 0 of the record's physical signal (first lead);
        - ``samples`` and ``symbols`` are the annotation sample indices and
          symbol codes, in wfdb's order.
    """
    rec = wfdb.rdrecord(record_path)
    ann = wfdb.rdann(record_path, 'atr')
    first_lead = rec.p_signal[:, 0]
    return first_lead, ann.sample, ann.symbol

def downsample_signal(signal, original_fs=360, target_fs=125):
    """Resample ``signal`` from ``original_fs`` Hz (default 360) to ``target_fs`` Hz.

    The output has ``int(len(signal) * target_fs / original_fs)`` samples and is
    produced by ``scipy.signal.resample``.
    """
    out_len = int(len(signal) * target_fs / original_fs)
    return resample(signal, num=out_len)

def normalize_signal(signal):
    """Min-max scale a signal into roughly [0, 1].

    The input is converted to float64 first so integer inputs cannot overflow
    in ``signal - min`` or ``max - min``. The ``+ 1e-6`` in the denominator
    maps a constant signal to zeros instead of dividing by zero. It also means
    the maximum lands just below 1.0.

    Args:
        signal: 1-D array-like of samples.

    Returns:
        float64 ndarray of the same length.
    """
    signal = np.asarray(signal, dtype=np.float64)
    lo = np.min(signal)
    return (signal - lo) / (np.max(signal) - lo + 1e-6)

def pad_signal(signal, max_len):
    """Return ``signal`` right-padded with zeros or truncated to ``max_len`` samples."""
    missing = max_len - len(signal)
    return signal[:max_len] if missing <= 0 else np.pad(signal, (0, missing), 'constant')

def extract_beats_and_labels(signal, r_peaks, ann_symbols, fs=125, window_size=1.0):
    """Cut a labelled window around every annotation with a known beat symbol.

    ``window_size`` is the total window length in seconds. Each window is
    ``signal[peak - h : peak + h]`` with ``h = int((window_size / 2) * fs)``,
    clipped to the signal bounds, then padded/truncated to ``MAX_LEN_MITBIH``.
    ``r_peaks`` and ``ann_symbols`` are parallel sequences. Annotations whose
    symbol is not a key of ``label_mapping`` are skipped.

    Returns:
        ``(beats, labels)``: lists in annotation order.
    """
    half = int((window_size / 2) * fs)
    n_samples = len(signal)
    beats, labels = [], []
    for idx, peak in enumerate(r_peaks):
        class_id = label_mapping.get(ann_symbols[idx])
        if class_id is None:
            continue  # non-beat or unmapped annotation symbol
        window = signal[max(0, peak - half):min(n_samples, peak + half)]
        beats.append(pad_signal(window, max_len=MAX_LEN_MITBIH))
        labels.append(class_id)
    return beats, labels

def preprocess_mitbih_dataset(dataset_dir, original_fs=360):
    """Build a labelled beat dataset from every MIT-BIH record in ``dataset_dir``.

    Records (``<name>.dat``) are visited in sorted filename order so the output
    order is reproducible (``os.listdir`` order is arbitrary). For each record:

    1. load lead 0 and its ``atr`` annotations;
    2. resample from ``original_fs`` to ``TARGET_SAMPLING_RATE``;
    3. rescale the annotation sample indices by the same ratio;
    4. min-max normalise;
    5. cut one window per annotation whose symbol is in ``label_mapping``.

    NOTE(review): the rescaled indices are truncated with ``astype(int)``, so a
    peak can land up to one target-rate sample early. Rounding would be more
    accurate, but it would change the existing output, so truncation is kept.

    Args:
        dataset_dir: Directory containing the MIT-BIH record files.
        original_fs: Sampling rate of the raw records in Hz. The default of 360
            is the value this pipeline always assumed.

    Returns:
        ``(X, y)``: ``X`` is a float ndarray of shape ``(n_beats, MAX_LEN_MITBIH)``
        and ``y`` is an int ndarray of shape ``(n_beats,)``. When no beats are
        found, the shapes are ``(0, MAX_LEN_MITBIH)`` and ``(0,)``.
    """
    scale = TARGET_SAMPLING_RATE / original_fs
    all_beats, all_labels = [], []
    for file in sorted(os.listdir(dataset_dir)):
        record_name, ext = os.path.splitext(file)
        if ext != '.dat':
            continue
        signal, r_peaks, ann_symbols = load_mitbih_record(os.path.join(dataset_dir, record_name))
        signal = downsample_signal(signal, original_fs=original_fs, target_fs=TARGET_SAMPLING_RATE)
        # Map annotation positions onto the resampled time axis.
        r_peaks = (np.asarray(r_peaks) * scale).astype(int)
        signal = normalize_signal(signal)
        beats, labels = extract_beats_and_labels(signal, r_peaks, ann_symbols, fs=TARGET_SAMPLING_RATE)
        all_beats.extend(beats)
        all_labels.extend(labels)

    if not all_beats:
        # Keep the 2-D (n, MAX_LEN_MITBIH) contract even when nothing was extracted.
        return np.empty((0, MAX_LEN_MITBIH)), np.empty((0,), dtype=int)
    return np.array(all_beats), np.array(all_labels)
# Preprocess MIT-BIH
# Extract one labelled, padded beat per mapped annotation from every record
# found at this relative path.
mitbih_signals, mitbih_labels = preprocess_mitbih_dataset("data/mit-bih-arrhythmia-database-1.0.0")

# Convert to PyTorch tensors
# X: float32 beats, one row per beat, each padded/truncated to MAX_LEN_MITBIH.
# y: int64 class ids from label_mapping (0-4).
X_mitbih = torch.tensor(mitbih_signals, dtype=torch.float32)
y_mitbih = torch.tensor(mitbih_labels, dtype=torch.long)

# Save as a gzip-compressed pickle of (X, y).
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
with gzip.open("mitbih_beats.pkl.gz", "wb") as f:
    pickle.dump((X_mitbih, y_mitbih), f)

# Report the dataset size and the per-class beat counts (labels are imbalanced).
print(f"MIT-BIH Beats Shape: {X_mitbih.shape}")
print(f"Unique Labels: {np.unique(mitbih_labels, return_counts=True)}")

MIT-BIH Beats Shape: torch.Size([109966, 250])
Unique Labels: (array([0, 1, 2, 3, 4]), array([90631,  2781,  7708,   803,  8043]))
