In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import wfdb
from wfdb import rdrecord, processing
import pywt

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from scipy.signal import resample_poly, butter, filtfilt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# Global plot styling: use seaborn's 'whitegrid' style.
sns.set_style('whitegrid')
# Print numpy arrays with 3 decimals and suppress scientific notation for small values.
np.set_printoptions(precision=3, suppress=True)

In [2]:
def lowpass_filter(signal, cutoff, fs, order=4):
    """
    Zero-phase Butterworth low-pass filter.

    Parameters:
      signal: 1-D array-like of samples
      cutoff: cutoff frequency, in the same units as fs
      fs:     sampling frequency of `signal`
      order:  order of the Butterworth design (default 4)

    Returns:
      the filtered signal as returned by scipy.signal.filtfilt
    """
    # Digital filter design takes the cutoff as a fraction of Nyquist (fs / 2).
    wn = cutoff / (0.5 * fs)
    numerator, denominator = butter(order, wn, btype='low', analog=False)
    # filtfilt applies the filter forward and backward, so no phase shift is added.
    return filtfilt(numerator, denominator, signal)

In [4]:
def downsample_ecg(ecg_signal, old_fs=250, new_fs=128):
    """
    Resamples an ECG signal from old_fs to new_fs using polyphase resampling.

    The rate change is the rational factor new_fs / old_fs, passed to
    scipy.signal.resample_poly as up / down. No separate low-pass stage is
    run here; anti-aliasing is left to the filter that resample_poly applies
    internally.

    Parameters:
      ecg_signal: 1-D array-like of samples at old_fs
      old_fs:     original sampling frequency (integer)
      new_fs:     target sampling frequency (integer)

    Returns:
      the resampled signal as a numpy array
    """
    return resample_poly(ecg_signal, up=new_fs, down=old_fs)

In [5]:
def load_ecg_record(record_path, desired_fs=128, channel=0):
    """
    Loads one channel of an ECG record using wfdb and unifies its sampling rate.

    If the record is already at desired_fs it is returned untouched.
    If the record is at 250 Hz and desired_fs is 128 Hz it is downsampled.
    Any other combination is rejected.

    Parameters:
      record_path: record name/path without the ".hea" extension
      desired_fs:  target sampling frequency in Hz (default 128)
      channel:     index of the signal channel to read (default 0)

    Returns:
      ecg: 1-D numpy array of samples
      fs:  sampling frequency of the returned array

    Raises:
      ValueError: if the record's sampling frequency cannot be converted
                  to desired_fs by this function.
    """
    record = rdrecord(record_path, channels=[channel])
    fs_original = record.fs
    # Only one channel was requested, so flattening yields a plain 1-D signal.
    ecg = record.p_signal.flatten()

    if fs_original == desired_fs:
        # Already at the requested rate
        return ecg, fs_original

    if fs_original == 250 and desired_fs == 128:
        # Downsample from 250 to 128
        ecg_ds = downsample_ecg(ecg, old_fs=fs_original, new_fs=desired_fs)
        return ecg_ds, desired_fs

    # Report the actual rates involved instead of a fixed "128 or 250" hint,
    # which is wrong whenever desired_fs is not 128.
    raise ValueError(
        f"Unsupported sampling-rate conversion for record '{record_path}': "
        f"record is {fs_original} Hz, desired_fs is {desired_fs} Hz. "
        f"Only records already at desired_fs, or 250 Hz -> 128 Hz, are handled."
    )

In [6]:
import glob

In [7]:
# Record names are the header paths without the ".hea" extension, with forward
# slashes so the strings look the same on Windows. glob returns files in
# arbitrary order, so sort to keep the dataset row order reproducible.
nsr_files = glob.glob("./data/nsrdb/*.hea")
nsr_records = sorted(f.replace("\\", "/")[:-4] for f in nsr_files)
nsr_records

['./data/nsrdb/16265',
 './data/nsrdb/16272',
 './data/nsrdb/16273',
 './data/nsrdb/16420',
 './data/nsrdb/16483',
 './data/nsrdb/16539',
 './data/nsrdb/16773',
 './data/nsrdb/16786',
 './data/nsrdb/16795',
 './data/nsrdb/17052',
 './data/nsrdb/17453',
 './data/nsrdb/18177',
 './data/nsrdb/18184',
 './data/nsrdb/19088',
 './data/nsrdb/19090',
 './data/nsrdb/19093',
 './data/nsrdb/19140',
 './data/nsrdb/19830']

In [8]:
# scd database
sdd_files = glob.glob("./data/scddb/*.hea")
sdd_records = [f[:-4] for f in sdd_files]  # remove the ".hea" extension
sdd_records = [f.replace("\\", "/")[:-4] for f in sdd_files]
sdd_records

['./data/scddb/30',
 './data/scddb/31',
 './data/scddb/32',
 './data/scddb/33',
 './data/scddb/34',
 './data/scddb/35',
 './data/scddb/36',
 './data/scddb/37',
 './data/scddb/38',
 './data/scddb/39',
 './data/scddb/40',
 './data/scddb/41',
 './data/scddb/42',
 './data/scddb/43',
 './data/scddb/44',
 './data/scddb/45',
 './data/scddb/46',
 './data/scddb/47',
 './data/scddb/48',
 './data/scddb/49',
 './data/scddb/50',
 './data/scddb/51',
 './data/scddb/52']

In [9]:
def segment_ecg_signal(ecg, fs, segment_length_sec=300):
    """
    Splits the ECG into consecutive, non-overlapping segments of
    segment_length_sec seconds (default 300 s = 5 minutes).

    Only complete segments are returned; trailing samples that do not fill
    a whole segment are dropped.

    Parameters:
      ecg:                1-D sequence/array of samples
      fs:                 sampling frequency in Hz
      segment_length_sec: segment duration in seconds

    Returns:
      list of segments (each a slice of `ecg`); empty if the signal is
      shorter than one segment.

    Raises:
      ValueError: if segment_length_sec * fs is less than one sample.
    """
    samples_per_segment = int(segment_length_sec * fs)
    # A zero or negative segment size can never advance through the signal
    # (the original loop would spin forever), so reject it up front.
    if samples_per_segment <= 0:
        raise ValueError(
            f"segment_length_sec * fs must be at least one sample, "
            f"got {segment_length_sec} s at {fs} Hz."
        )

    n_full_segments = len(ecg) // samples_per_segment
    return [
        ecg[k * samples_per_segment:(k + 1) * samples_per_segment]
        for k in range(n_full_segments)
    ]

In [10]:
def wavelet_denoise(ecg_segment, wavelet='sym5', level=5, threshold_ratio=0.04):
    """
    Denoises a segment by wavelet decomposition and soft thresholding
    of the detail coefficients.

    Parameters:
      ecg_segment:     1-D array of samples
      wavelet:         wavelet name passed to pywt (default 'sym5')
      level:           decomposition level (default 5)
      threshold_ratio: per-band threshold as a fraction of that band's
                       largest coefficient magnitude (default 0.04)

    Returns:
      denoised signal, trimmed to the length of ecg_segment
    """
    coeffs = pywt.wavedec(ecg_segment, wavelet, level=level)

    for i in range(1, len(coeffs)):  # skip approximation coeff (i=0)
        # Scale by the largest magnitude, not the largest signed value:
        # a band whose extreme coefficient is negative would otherwise get
        # a too-small or even negative threshold.
        threshold = threshold_ratio * np.max(np.abs(coeffs[i]))
        coeffs[i] = pywt.threshold(coeffs[i], threshold, mode='soft')

    denoised = pywt.waverec(coeffs, wavelet)
    # Trim in case the reconstruction is longer than the input.
    return denoised[:len(ecg_segment)]

In [11]:
def normalize_ecg_segment(ecg_segment):
    """
    Normalized Absolute Deviation (NADev) scaling.

    Centers the segment on its median and divides by the median absolute
    deviation (MAD). A MAD of exactly zero is replaced by 1e-8 so the
    division is always defined.
    """
    center = np.median(ecg_segment)
    centered = ecg_segment - center
    scale = np.median(np.abs(centered))
    # Flat (or mostly flat) segments have MAD == 0; fall back to a tiny divisor.
    scale = scale if scale != 0 else 1e-8
    return centered / scale

In [12]:
def detect_r_peaks(ecg_segment, fs):
    """
    Finds R-peaks in a segment with wfdb's xqrs_detect.

    Returns the detected QRS indices exactly as xqrs_detect reports them;
    the detector is run with verbose=False.
    """
    return processing.xqrs_detect(ecg_segment, fs=fs, verbose=False)

def compute_rr_intervals(r_peaks, fs):
    """
    Converts R-peak sample indices into RR intervals in seconds.

    Returns an empty array when fewer than two peaks are given,
    since no interval can be formed.
    """
    if len(r_peaks) >= 2:
        # Index differences are in samples; dividing by fs gives seconds.
        return np.diff(r_peaks) / fs
    return np.array([])

In [13]:
def extract_hrv_features(rr_intervals):
    """
    Computes 8 time-domain HRV features from RR intervals given in seconds.

    Returns a numpy array ordered as:
      [MeanRR, RMSSD, pNN50, SDRR, CVRR, NN50, MinRR, MaxRR]

    With fewer than two RR intervals an all-zero vector of length 8 is
    returned instead.
    """
    if len(rr_intervals) < 2:
        return np.zeros(8)

    # Successive differences; non-empty here because at least two intervals exist.
    successive = np.diff(rr_intervals)

    avg_rr = np.mean(rr_intervals)
    std_rr = np.std(rr_intervals)
    rmssd = np.sqrt(np.mean(successive**2))

    # NN50: number of successive differences larger than 50 ms (0.05 s)
    count_over_50ms = np.sum(np.abs(successive) > 0.05)
    pct_over_50ms = (count_over_50ms / len(successive)) * 100

    # Coefficient of variation; guarded against a zero mean.
    if avg_rr != 0:
        coeff_var = std_rr / avg_rr
    else:
        coeff_var = 0

    return np.array([
        avg_rr,
        rmssd,
        pct_over_50ms,
        std_rr,
        coeff_var,
        count_over_50ms,
        np.min(rr_intervals),
        np.max(rr_intervals),
    ])

In [14]:
def build_dataset(record_paths, label):
    """
    Runs the full feature pipeline over a list of records.

    For every record: load at 128 Hz, cut into 5-minute segments, then for
    each segment denoise, normalize, detect R-peaks, derive RR intervals
    and compute the HRV feature vector.

    record_paths: list of .hea file paths (without .hea extension),
                  e.g., './SCD_data/XXXX'
    label: 0 for NSR, 1 for SCD

    Returns:
      X -> array of shape (num_segments, 8)  [8 HRV features]
      y -> array of labels, one per segment
    """
    n_features = 8  # length of the vector returned by extract_hrv_features
    X, y = [], []

    for rec_path in record_paths:
        # 1) Load ECG and unify sampling frequency at 128 Hz
        ecg, fs = load_ecg_record(rec_path, desired_fs=128)  # Will downsample if 250 Hz

        # 2) Segment (5-minute blocks)
        for seg in segment_ecg_signal(ecg, fs, segment_length_sec=300):
            # a) Denoise
            seg_denoised = wavelet_denoise(seg)
            # b) Normalize
            seg_norm = normalize_ecg_segment(seg_denoised)
            # c) R-peak detection
            r_peaks = detect_r_peaks(seg_norm, fs)
            # d) RR intervals
            rr_int = compute_rr_intervals(r_peaks, fs)
            # e) HRV features
            X.append(extract_hrv_features(rr_int))
            y.append(label)

    # reshape keeps X two-dimensional even when no segment was produced;
    # a bare np.array([]) would come back with shape (0,).
    return np.array(X).reshape(-1, n_features), np.array(y)

In [None]:
# Build the feature matrix and labels for the scddb records; label=1 is the SCD class (see build_dataset).
X_scd, y_scd = build_dataset(sdd_records, label=1)