In [8]:
import os
import numpy as np
import pandas as pd
import wfdb
import scipy.io
import scipy.signal
import neurokit2 as nk
import biosppy.signals.ecg as ecg
import torch
import gzip
import pickle

# Define constants
TARGET_SAMPLING_RATE = 125  # Hz; common rate the signals are resampled to
MAX_LEN_PHYSIONET = 10 * TARGET_SAMPLING_RATE  # 10 seconds, expressed in samples at the target rate
MAX_LEN_MITBIH = 30 * TARGET_SAMPLING_RATE  # 30 seconds, expressed in samples at the target rate

# Load PhysioNet dataset
def load_physionet_data(path):
    """Load PhysioNet 2017 recordings and their integer class labels.

    Reads ``REFERENCE.csv`` (no header row; column 0 is the record name and
    column 1 the class symbol) from ``path``, then every ``.mat`` file in the
    same directory, keeping the first row of its ``val`` matrix as the signal.

    Args:
        path: Directory containing ``REFERENCE.csv`` and the ``.mat`` records.

    Returns:
        A ``(signals, labels)`` pair of equally long lists: the 1-D arrays as
        stored in the ``.mat`` files, and their classes encoded as
        N=0, A=1, O=2, ~=3. Records come back in sorted file-name order.
        Records that are absent from the reference file, or whose symbol is
        not one of the four known classes, are skipped.
    """
    label_mapping = {"N": 0, "A": 1, "O": 2, "~": 3}  # Modify as per dataset classes
    ref_df = pd.read_csv(os.path.join(path, "REFERENCE.csv"), header=None)
    ref_dict = dict(zip(ref_df[0], ref_df[1]))

    signals, labels = [], []
    # os.listdir returns entries in arbitrary order; sorting keeps the output
    # reproducible across runs and machines.
    for file in sorted(os.listdir(path)):
        # splitext only strips the trailing extension, unlike str.replace,
        # which would also remove ".mat" from the middle of a name.
        record_name, extension = os.path.splitext(file)
        if extension != ".mat":
            continue
        # Two lookups: unknown record -> None, unknown class symbol -> None.
        encoded_label = label_mapping.get(ref_dict.get(record_name))
        if encoded_label is None:
            continue
        # Only labelled records are read from disk.
        mat_data = scipy.io.loadmat(os.path.join(path, file))
        signals.append(mat_data["val"][0])  # Extract ECG lead
        labels.append(encoded_label)

    return signals, labels



# Downsampling function
def downsample_signal(signal, original_fs, target_fs=125):
    """Resample ``signal`` from ``original_fs`` Hz to ``target_fs`` Hz.

    The output length is ``len(signal) * target_fs / original_fs`` truncated
    to an integer; the resampling itself is done by ``scipy.signal.resample``.
    """
    scaled_length = len(signal) * target_fs / original_fs
    return scipy.signal.resample(signal, int(scaled_length))

# Normalize function
def normalize_signal(signal):
    """Min-max scale an ECG signal into the range [0, 1].

    Args:
        signal: 1-D array-like of samples.

    Returns:
        An array of the same length with the minimum mapped to 0 and the
        maximum to 1. A constant (flat) signal has no range to scale by and
        is returned as all zeros rather than the NaNs a 0/0 division would
        produce. An empty signal is returned as an empty float array.
    """
    values = np.asarray(signal)
    if values.size == 0:
        # np.min/np.max raise on empty input; there is nothing to scale.
        return values.astype(float)
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # Flat signal: avoid dividing zero by zero.
        return np.zeros(values.shape, dtype=float)
    return (values - lowest) / span

# R-peak detection
def detect_r_peaks(signal, sampling_rate=125):
    """Return the R-peak sample indices found by ``nk.ecg_peaks``.

    NOTE(review): no ``method`` argument is passed, so whichever detector
    ``ecg_peaks`` applies by default is used. The previous docstring named
    Pan-Tompkins; confirm against the NeuroKit2 documentation.
    """
    peak_info = nk.ecg_peaks(signal, sampling_rate=sampling_rate)[1]
    return np.array(peak_info["ECG_R_Peaks"])
# Extract T-episodes
def extract_t_episodes(signal, r_peaks, fs):
    """Return one ``(start, end)`` sample window per R-peak.

    Each window reaches half of the median R-R interval (integer division)
    to either side of its peak and is clipped to ``[0, len(signal)]``.
    Fewer than two peaks yield an empty list. ``fs`` is accepted but unused.
    """
    if len(r_peaks) < 2:
        return []
    half_window = int(np.median(np.diff(r_peaks))) // 2
    signal_length = len(signal)
    return [
        (max(0, peak - half_window), min(signal_length, peak + half_window))
        for peak in r_peaks
    ]

# Zero-padding
def pad_signal(signal, max_len=MAX_LEN_PHYSIONET):
    """Force ``signal`` to exactly ``max_len`` samples.

    Shorter signals are right-padded with zeros; signals of ``max_len`` or
    more samples are cut down to their first ``max_len`` samples.
    """
    shortfall = max_len - len(signal)
    if shortfall > 0:
        return np.pad(signal, (0, shortfall), 'constant')
    return signal[:max_len]

# Full preprocessing pipeline
def preprocess_ecg_dataset(dataset_path):
    """Turn the PhysioNet recordings under ``dataset_path`` into padded beats.

    Each recording is resampled from 300 Hz to ``TARGET_SAMPLING_RATE``,
    min-max normalised, and cut into one window per detected R-peak. Every
    window is padded or truncated to ``MAX_LEN_PHYSIONET`` samples and
    inherits the label of the recording it came from.

    Returns:
        ``(beats, labels)`` as NumPy arrays with one entry per window.
    """
    source_fs = 300  # PhysioNet signals are sampled at 300 Hz
    beats, beat_labels = [], []
    recordings, recording_labels = load_physionet_data(dataset_path)
    for recording, recording_label in zip(recordings, recording_labels):
        # Resample -> normalise -> locate R-peaks -> window around each peak.
        resampled = downsample_signal(recording, source_fs, TARGET_SAMPLING_RATE)
        scaled = normalize_signal(resampled)
        peaks = detect_r_peaks(scaled, TARGET_SAMPLING_RATE)
        windows = extract_t_episodes(scaled, peaks, TARGET_SAMPLING_RATE)
        beats.extend(pad_signal(scaled[lo:hi], MAX_LEN_PHYSIONET) for lo, hi in windows)
        # Every window of a recording carries that recording's label.
        beat_labels.extend([recording_label] * len(windows))

    return np.array(beats), np.array(beat_labels)

# Example Usage
# Runs the whole PhysioNet pipeline at module level on data/training2017.
physionet_data, physionet_labels = preprocess_ecg_dataset("data/training2017")

# Convert to PyTorch Tensors
X_physionet = torch.tensor(physionet_data, dtype=torch.float32)
y_physionet = torch.tensor(physionet_labels, dtype=torch.long)

# Save the processed data
# The (features, labels) tensor pair is pickled into a gzip-compressed file.
with gzip.open("pretraining_data.pkl.gz", "wb") as f:
    pickle.dump((X_physionet, y_physionet), f)
print(f"PhysioNet Data Shape: {X_physionet.shape}")

KeyboardInterrupt: 

In [None]:
import os
import numpy as np
import torch
import wfdb
import gzip
import pickle
import neurokit2 as nk
from scipy.signal import resample

# Constants
TARGET_SAMPLING_RATE = 125  # Hz; rate the MIT-BIH signals and annotation positions are rescaled to
MAX_LEN_MITBIH = 30 * TARGET_SAMPLING_RATE  # 30 seconds, expressed in samples at the target rate

# Beat label mapping: annotation symbol -> integer class id.
# Symbols that are not listed here are treated as unlabelled and dropped.
# NOTE(review): the ids in use are 0, 2, 3 and 4 -- no symbol maps to 1.
# Confirm the gap is intentional before sizing a classifier from these labels.
label_mapping = {
    'N': 0, 'L': 0, 'R': 0,
    'V': 3, '/': 4,
    'A': 2, 'F': 3, 'f': 4,
    'j': 2, 'a': 2, 'E': 3,
    'J': 2, 'e': 2, 'Q': 4,
    'S': 2
}

def load_mitbih_record(record_path):
    """Read one WFDB record together with its ``atr`` annotations.

    Args:
        record_path: Record path as passed to ``wfdb.rdrecord``/``wfdb.rdann``
            (the caller in this file supplies it without a file extension).

    Returns:
        ``(signal, ann_samples, ann_symbols)``: column 0 of the record's
        ``p_signal`` matrix, plus the annotation's ``sample`` and ``symbol``
        attributes.
    """
    rec = wfdb.rdrecord(record_path)
    ann = wfdb.rdann(record_path, 'atr')
    # Only the first signal column is used.
    return rec.p_signal[:, 0], ann.sample, ann.symbol

def downsample_signal(signal, original_fs=360, target_fs=125):
    """Resample ``signal`` from ``original_fs`` Hz to ``target_fs`` Hz.

    The output holds ``len(signal) * target_fs / original_fs`` samples,
    truncated to an integer.
    """
    scaled_length = len(signal) * target_fs / original_fs
    return resample(signal, int(scaled_length))

def normalize_signal(signal):
    """Min-max scale ``signal`` towards [0, 1].

    A small constant (1e-6) is added to the range so that a flat signal
    divides cleanly and comes back as zeros; as a consequence the maximum of
    a non-flat signal lands just below 1.
    """
    floor = np.min(signal)
    span = np.max(signal) - floor + 1e-6
    return (signal - floor) / span

def detect_r_peaks(signal, fs=125):
    """Return the ``"ECG_R_Peaks"`` entry produced by ``nk.ecg_peaks``.

    ``fs`` is forwarded to NeuroKit2 as the sampling rate. No ``method`` is
    passed, so the library's default detector is used.
    """
    peak_info = nk.ecg_peaks(signal, sampling_rate=fs)[1]
    return peak_info["ECG_R_Peaks"]

def extract_t_episodes(signal, r_peaks, fs):
    """Build a ``(start, end)`` sample window around every R-peak.

    The window reaches half the median R-R interval (integer division) on
    each side of the peak, clipped to the bounds of ``signal``. With fewer
    than two peaks there is no interval to measure and the result is empty.
    ``fs`` is accepted but not used.
    """
    episodes = []
    if len(r_peaks) >= 2:
        reach = int(np.median(np.diff(r_peaks))) // 2
        upper_bound = len(signal)
        for peak in r_peaks:
            episodes.append((max(0, peak - reach), min(upper_bound, peak + reach)))
    return episodes

def assign_labels_to_episodes(episodes, ann_samples, ann_symbols, label_map):
    """Label each episode with the class of its nearest annotation.

    Args:
        episodes: Iterable of ``(start, end)`` sample windows.
        ann_samples: Annotation positions, in the same sample units as the
            episodes.
        ann_symbols: Annotation symbols, parallel to ``ann_samples``.
        label_map: Mapping from annotation symbol to integer class.

    Returns:
        A list with one entry per episode: ``label_map[symbol]`` for the
        annotation closest to the episode's midpoint, or ``None`` when that
        symbol is not in ``label_map`` or when there are no annotations.
    """
    # Convert once instead of rebuilding the array for every episode.
    ann_positions = np.asarray(ann_samples)
    if ann_positions.size == 0:
        # np.argmin raises on an empty array; without annotations nothing
        # can be labelled.
        return [None for _ in episodes]

    labels = []
    for start, end in episodes:
        center = (start + end) // 2
        nearest_idx = int(np.argmin(np.abs(ann_positions - center)))
        # .get yields None for symbols missing from label_map.
        labels.append(label_map.get(ann_symbols[nearest_idx]))
    return labels

def pad_signal(signal, max_len):
    """Force ``signal`` to exactly ``max_len`` samples.

    Shorter signals are right-padded with zeros; signals of ``max_len`` or
    more samples are cut down to their first ``max_len`` samples.
    """
    shortfall = max_len - len(signal)
    if shortfall > 0:
        return np.pad(signal, (0, shortfall), 'constant')
    return signal[:max_len]

def preprocess_mitbih_dataset(dataset_dir):
    """Extract labelled, fixed-length beat windows from every record in a directory.

    For each ``.dat`` file in ``dataset_dir`` the record is loaded, resampled
    from 360 Hz to ``TARGET_SAMPLING_RATE``, min-max normalised and cut into
    one window per detected R-peak. Each window takes the class of the
    nearest annotation; windows whose annotation symbol is not in
    ``label_mapping`` are dropped. Kept windows are padded or truncated to
    ``MAX_LEN_MITBIH`` samples.

    Args:
        dataset_dir: Directory holding the WFDB record files.

    Returns:
        ``(beats, labels)`` as NumPy arrays with one entry per kept window,
        ordered by sorted record file name and then by position in the record.
    """
    # Single source for the rate used by both the signal resampling and the
    # annotation rescaling, so the two cannot drift apart.
    source_fs = 360
    all_beats, all_labels = [], []
    # os.listdir order is arbitrary; sorting makes the output reproducible.
    for file in sorted(os.listdir(dataset_dir)):
        # splitext strips only the trailing extension.
        record_name, extension = os.path.splitext(file)
        if extension != '.dat':
            continue
        signal, ann_samples, ann_symbols = load_mitbih_record(os.path.join(dataset_dir, record_name))
        # Downsample the signal and move the annotation positions onto the
        # same, coarser sample grid.
        signal = downsample_signal(signal, source_fs, TARGET_SAMPLING_RATE)
        ann_samples = (np.array(ann_samples) * (TARGET_SAMPLING_RATE / source_fs)).astype(int)
        # Normalize
        signal = normalize_signal(signal)
        # R-peak detection using NeuroKit
        r_peaks = detect_r_peaks(signal, TARGET_SAMPLING_RATE)
        # Extract T-episodes
        t_episodes = extract_t_episodes(signal, r_peaks, TARGET_SAMPLING_RATE)
        # Match to labels
        labels = assign_labels_to_episodes(t_episodes, ann_samples, ann_symbols, label_mapping)
        for (start, end), label in zip(t_episodes, labels):
            if label is None:
                # Nearest annotation is not a mapped beat class.
                continue
            all_beats.append(pad_signal(signal[start:end], MAX_LEN_MITBIH))
            all_labels.append(label)
    return np.array(all_beats), np.array(all_labels)

# Preprocess MIT-BIH
# Runs the whole MIT-BIH pipeline at module level on the directory below.
mitbih_signals, mitbih_labels = preprocess_mitbih_dataset("data/mit-bih-arrhythmia-database-1.0.0")

# Convert to PyTorch tensors
X_mitbih = torch.tensor(mitbih_signals, dtype=torch.float32)
y_mitbih = torch.tensor(mitbih_labels, dtype=torch.long)

# Save the (features, labels) tensor pair as a gzip-compressed pickle
with gzip.open("mitbih_beats.pkl.gz", "wb") as f:
    pickle.dump((X_mitbih, y_mitbih), f)

# Report the dataset size and how many beats fall into each class
print(f"MIT-BIH Beats Shape: {X_mitbih.shape}")
print(f"Unique Labels: {np.unique(mitbih_labels, return_counts=True)}")

KeyboardInterrupt: 