In [13]:
import os
import numpy as np
import torch
import wfdb
import gzip
import pickle
import neurokit2 as nk
from scipy.signal import resample

# Constants shared by the preprocessing pipeline below.
TARGET_SAMPLING_RATE = 125  # every record is resampled to this rate (samples per second)
MAX_LEN_BEAT = TARGET_SAMPLING_RATE * 2  # fixed beat length in samples: 2 seconds at the target rate

def parse_reason_from_header(header_file):
    """Return the 'Reason for admission' text from a header file.

    Scans ``header_file`` line by line. On the first line containing
    ``"Reason for admission:"`` it takes the text after that marker and
    strips surrounding whitespace.

    Args:
        header_file: Path to the header (``.hea``) file.

    Returns:
        The reason string, or ``"other"`` when no line contains the marker or
        the marker is followed only by whitespace.

    Raises:
        OSError: If the file cannot be opened.
    """
    reason = None
    # Explicit encoding so the result does not depend on the platform locale;
    # undecodable bytes are replaced instead of aborting the whole dataset run.
    with open(header_file, "r", encoding="utf-8", errors="replace") as fh:
        for line in fh:
            if "Reason for admission:" in line:
                reason = line.split("Reason for admission:")[-1].strip()
                break
    # Covers both "marker never seen" (None) and "marker with empty text" ("").
    return reason or "other"

def map_reason(reason):
    """Translate an admission-reason string into an integer class id.

    Matching is a case-insensitive substring search over rules checked in a
    fixed priority order; the first rule that matches wins:

    * "healthy"                                   -> 0 (N)
    * "myocardial infarction"                     -> 2 (V)
    * "cardiomyopathy", "myocarditis",
      "bundle branch", "hypertrophy"              -> 1 (S)
    * "valvular"                                  -> 3 (F)
    * anything else                               -> 4 (Unknown/Q)

    NOTE(review): the pipeline assigns this one id to every beat of a record,
    so these are record-level diagnoses labelled with MIT-BIH class letters;
    the letters are nominal only — confirm this is what downstream training
    expects.
    """
    text = reason.lower()
    ordered_rules = (
        (("healthy",), 0),  # N
        (("myocardial infarction",), 2),  # V
        (("cardiomyopathy", "myocarditis", "bundle branch", "hypertrophy"), 1),  # S
        (("valvular",), 3),  # F
    )
    for keywords, class_id in ordered_rules:
        if any(word in text for word in keywords):
            return class_id
    return 4  # Unknown/Q

def load_ptbdb_record(record_path, channel=0):
    """Load a single ECG channel from a WFDB record.

    Args:
        record_path: Record path without file extension, as accepted by
            ``wfdb.rdrecord``.
        channel: Index of the signal to return. Defaults to 0 (the first
            lead), matching the previous fixed behaviour.

    Returns:
        ``(signal, fs)``: the selected channel as a 1-D slice of
        ``record.p_signal`` and the record's ``fs`` attribute.
    """
    # Ask wfdb for just the requested channel instead of decoding every
    # signal in the record and discarding all but one.
    record = wfdb.rdrecord(record_path, channels=[channel])
    return record.p_signal[:, 0], record.fs

def downsample_signal(signal, original_fs, target_fs=125):
    """Resample ``signal`` from ``original_fs`` to ``target_fs`` via ``scipy.signal.resample``.

    The output length is ``int(len(signal) * target_fs / original_fs)``,
    i.e. the exact ratio truncated toward zero.
    """
    target_len = int(len(signal) * target_fs / original_fs)
    resampled = resample(signal, target_len)
    return resampled

def normalize_signal(signal):
    """Scale ``signal`` into roughly [0, 1] using its own min and max.

    A small epsilon (1e-6) in the denominator keeps a constant signal from
    dividing by zero; such a signal maps to all zeros.
    """
    lowest = np.min(signal)
    value_range = np.max(signal) - lowest
    return (signal - lowest) / (value_range + 1e-6)

def detect_r_peaks(signal, fs=125):
    """Return the R-peak sample indices NeuroKit2 reports for ``signal``.

    ``fs`` is forwarded to ``nk.ecg_peaks`` as ``sampling_rate``. Only the
    ``"ECG_R_Peaks"`` entry of the second element of its result is returned.

    NOTE(review): ``signal`` is passed to ``ecg_peaks`` without a preceding
    ``nk.ecg_clean`` step — confirm this is intentional.
    """
    peak_info = nk.ecg_peaks(signal, sampling_rate=fs)[1]
    return peak_info["ECG_R_Peaks"]

def extract_beats(signal, r_peaks, fs=125):
    """Cut one fixed-width window centred on each R-peak.

    The half-width is ``half = int(median(diff(r_peaks))) // 2``. Each beat
    covers samples ``[r - half, r + half)``, so every returned beat has
    length ``2 * half`` and its R-peak sits at index ``half``.

    Args:
        signal: 1-D sequence of samples.
        r_peaks: Sample indices of detected R-peaks.
        fs: Sampling rate; currently unused, kept for call-site compatibility.

    Returns:
        A list of 1-D arrays, one per R-peak, or ``[]`` when fewer than two
        peaks are given (no RR interval can be measured).
    """
    if len(r_peaks) < 2:
        return []
    half = int(np.median(np.diff(r_peaks))) // 2
    n_samples = len(signal)

    beats = []
    for r in r_peaks:
        r = int(r)
        start, end = r - half, r + half
        segment = signal[max(0, start):min(n_samples, end)]
        # Zero-fill whichever side falls outside the record so the R-peak stays
        # at index `half`. Clipping the start instead would shift the R-peak
        # toward index 0 once the caller right-pads the beat.
        left_fill = max(0, -start)
        right_fill = max(0, end - n_samples)
        if left_fill or right_fill:
            segment = np.pad(segment, (left_fill, right_fill), "constant")
        beats.append(segment)
    return beats

def pad_signal(signal, max_len):
    """Return ``signal`` at exactly ``max_len`` samples.

    Shorter inputs are zero-padded at the end; inputs of ``max_len`` or more
    samples are cut down with a slice.
    """
    shortfall = max_len - len(signal)
    if shortfall <= 0:
        return signal[:max_len]
    return np.pad(signal, (0, shortfall), 'constant')

def _first_record_path(patient_path):
    """Return the extension-less path of the alphabetically first ``.dat`` file in ``patient_path``, or ``None``."""
    dat_files = sorted(name for name in os.listdir(patient_path) if name.endswith(".dat"))
    if not dat_files:
        return None
    record_name = os.path.splitext(dat_files[0])[0]
    return os.path.join(patient_path, record_name)


def preprocess_ptbdb_beats(dataset_dir):
    """Turn every ``patient*`` directory under ``dataset_dir`` into labelled fixed-length beats.

    For each sub-directory whose name starts with ``"patient"``, only the
    alphabetically first ``.dat`` record is used. The record's ``.hea``
    header supplies the 'Reason for admission', which is mapped to one class
    id that is attached to every beat from that record. The signal returned
    by ``load_ptbdb_record`` goes through these steps:
    - resampled to ``TARGET_SAMPLING_RATE``
    - min-max normalised
    - split into beats around detected R-peaks
    - each beat padded or truncated to ``MAX_LEN_BEAT`` samples

    Args:
        dataset_dir: Root directory containing the ``patient*`` folders.

    Returns:
        ``(beats, labels)``: ``beats`` is a 2-D array of shape
        ``(n_beats, MAX_LEN_BEAT)`` and ``labels`` an ``int64`` array of shape
        ``(n_beats,)``. Both keep those shapes (with ``n_beats == 0``) when no
        beats are extracted.
    """
    all_beats, all_labels = [], []

    for patient_folder in sorted(os.listdir(dataset_dir)):
        patient_path = os.path.join(dataset_dir, patient_folder)

        # Only directories named "patient..." hold records; skip everything else.
        if not os.path.isdir(patient_path) or not patient_folder.startswith("patient"):
            continue

        record_path = _first_record_path(patient_path)
        if record_path is None:
            continue

        # One label per record, shared by all of its beats.
        label = map_reason(parse_reason_from_header(record_path + ".hea"))

        signal, fs = load_ptbdb_record(record_path)
        signal = downsample_signal(signal, fs, TARGET_SAMPLING_RATE)
        signal = normalize_signal(signal)

        r_peaks = detect_r_peaks(signal, TARGET_SAMPLING_RATE)
        for beat in extract_beats(signal, r_peaks, TARGET_SAMPLING_RATE):
            all_beats.append(pad_signal(beat, MAX_LEN_BEAT))
            all_labels.append(label)

    if not all_beats:
        # np.array([]) would be 1-D with shape (0,); keep the 2-D layout so
        # downstream code still sees MAX_LEN_BEAT columns on an empty result.
        return np.empty((0, MAX_LEN_BEAT)), np.empty(0, dtype=np.int64)
    return np.stack(all_beats), np.asarray(all_labels, dtype=np.int64)

# Build the beat-level dataset from the local PTB Diagnostic ECG Database copy.
ptbdb_beats, ptbdb_labels = preprocess_ptbdb_beats("data/ptb-diagnostic-ecg-database-1.0.0")

# Wrap the NumPy arrays as tensors: float32 inputs, int64 (long) class targets.
X_ptbdb = torch.tensor(ptbdb_beats, dtype=torch.float32)
y_ptbdb = torch.tensor(ptbdb_labels, dtype=torch.long)

# Persist both tensors together as one gzip-compressed pickle.
# NOTE: only unpickle this file from a trusted source.
dataset_payload = (X_ptbdb, y_ptbdb)
with gzip.open("ptbdb_beats.pkl.gz", "wb") as out_file:
    pickle.dump(dataset_payload, out_file)

label_summary = np.unique(ptbdb_labels, return_counts=True)
print(f"PTBDB Beats Shape: {X_ptbdb.shape}")
print(f"Unique Labels: {label_summary}")


PTBDB Beats Shape: torch.Size([38084, 250])
Unique Labels: (array([0, 1, 2, 3, 4]), array([ 7096,  5540, 21473,   504,  3471]))
