In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Feature Extraction in ECG Data

Feature extraction is a critical step in analyzing ECG data, especially for tasks like Atrial Fibrillation (AF) classification. The ECG signal is rich with information, but this information needs to be extracted into a format that can be used for machine learning models. 

## Key Features for AF Classification

### R-R Interval Features
R-R intervals, the periods between consecutive R-peaks in the ECG signal, are fundamental in assessing heart rhythm. Key features extracted from R-R intervals include:
- **RR_mean**: The average time between R-peaks, providing a basic measure of heart rate.
- **RR_std**: The standard deviation of R-R intervals, indicating the variability in heart rate, which is significant in AF detection.
- **Irregularity_index**: A measure of rhythm irregularity, calculated as the fraction of successive R-R interval differences whose magnitude exceeds 50 ms. This index is particularly relevant for AF, whose hallmark is an irregularly irregular heart rhythm.

### Frequency Domain Features
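These features are derived from the power spectral density (PSD) of the ECG signal. Note that the autonomic interpretations below come from frequency-domain HRV, where the PSD is computed on the R-R interval series; on the raw ECG PSD used here, the LF and HF bands are only a rough proxy. The features include: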
- **LF (Low Frequency Power)**: Represents a mix of sympathetic and parasympathetic nervous system activity.
- **HF (High Frequency Power)**: More closely related to parasympathetic activity.
- **LF/HF Ratio**: Provides insights into the autonomic balance or stress levels, which can be altered in AF.

### Statistical Features
Simple statistical measures of the ECG signal can also provide valuable information:
- **Skewness**: Indicates the asymmetry of the ECG signal distribution. An abnormal skewness could suggest alterations in the ECG waveform.
- **Kurtosis**: Measures the 'tailedness' of the signal distribution. Extreme values might indicate anomalies in the ECG waveform.

### Heart Rate Variability (HRV) Features

HRV quantifies the variation in the time intervals between consecutive heartbeats and is widely used in cardiac health analysis.

- **MeanNN / SDNN**: The mean and standard deviation of the NN (normal-to-normal) intervals, i.e. the intervals between consecutive R-peaks of normal heartbeats.
- **RMSSD**: The root mean square of successive differences between NN intervals.
- **pNN50**: The percentage of successive NN intervals that differ by more than 50 ms.

In [18]:
import neurokit2 as nk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm


In [19]:
from src.data.load_data import load_dataset, load_ecg, parse_header
from src.ecg.viz import plot_ecg
from src.config import TRAINING_DIR

# Load the raw training records from TRAINING_DIR (8528 records in the run recorded below).
dataset = load_dataset(TRAINING_DIR)

  0%|          | 0/8528 [00:00<?, ?it/s]

100%|██████████| 8528/8528 [00:00<00:00, 491589.35it/s]


In [40]:
from src.preprocessing.ecg_preprocessing import ECGPreprocessor

# Window the raw records. NOTE(review): the units of window_size/overlap_size are
# presumably samples — confirm in ECGPreprocessor. The 8528 records end up as 17613
# items downstream, so records are presumably split into several overlapping windows.
ecg_preprocessor = ECGPreprocessor(overlap_size=1000, window_size=5000)
processed_dataset = ecg_preprocessor.transform(dataset)

Preprocessing ECG signals:   0%|          | 0/8528 [00:00<?, ?it/s]

Preprocessing ECG signals: 100%|██████████| 8528/8528 [00:13<00:00, 640.14it/s]


In [45]:
import numpy as np
import scipy.signal as sig
from scipy import stats

def extract_features(ecg_signal, fs):
    """Compute a flat dict of hand-crafted features for one ECG window.

    Parameters
    ----------
    ecg_signal : array-like
        ECG samples of a single window.
    fs : int or float
        Sampling rate in Hz.

    Returns
    -------
    dict
        The first row of ``nk.hrv_time``'s output, merged with:
        ``RR_mean`` / ``RR_std`` of the R-R intervals (ms);
        ``Irregularity_index``, the fraction of successive R-R differences
        larger than 50 ms (NaN when there are fewer than two R-R intervals);
        ``LF`` / ``HF``, the power of the ECG PSD in 0.04-0.15 Hz and
        0.15-0.4 Hz; ``LF_HF_ratio`` (NaN when HF is ~0); and the
        ``Skewness`` / ``Kurtosis`` of the signal amplitudes.

    Errors raised by ``nk.ecg_process`` / ``nk.hrv_time`` propagate to the
    caller.
    """
    # R-peak detection.
    _, info = nk.ecg_process(ecg_signal, sampling_rate=fs)
    r_peaks = info['ECG_R_Peaks']

    # Time-domain HRV features: first row of the DataFrame returned by NeuroKit.
    hrv_features = nk.hrv_time(r_peaks, sampling_rate=fs).to_dict('records')[0]

    # R-R intervals, converted from samples to milliseconds.
    rri = np.diff(r_peaks) / fs * 1000
    abs_successive_diffs = np.abs(np.diff(rri))
    rr_features = {
        'RR_mean': np.mean(rri),
        'RR_std': np.std(rri),
        # There are len(rri) - 1 successive differences, so average over them
        # instead of dividing the count by len(rri).
        'Irregularity_index': (
            np.mean(abs_successive_diffs > 50) if abs_successive_diffs.size else np.nan
        ),
    }

    # Frequency-domain features from the PSD of the ECG signal itself.
    # Welch's default segment (256 samples) only resolves fs/256 Hz (about
    # 1.2 Hz for a 300 Hz recording), so no bin would fall below 0.4 Hz and both
    # bands would be empty; use the whole window as a single segment instead.
    f, Pxx = sig.welch(ecg_signal, fs=fs, nperseg=len(ecg_signal))
    freq_step = f[1] - f[0] if f.size > 1 else 0.0
    # Rectangle-rule band power (sum of PSD bins times bin width) gives power in
    # physical units and stays non-zero for single-bin bands; LF is half-open
    # so the 0.15 Hz bin is not counted in both bands.
    lf = np.sum(Pxx[(f >= 0.04) & (f < 0.15)]) * freq_step  # Low frequency power
    hf = np.sum(Pxx[(f >= 0.15) & (f <= 0.4)]) * freq_step   # High frequency power
    freq_features = {
        'LF': lf,
        'HF': hf,
        'LF_HF_ratio': lf / hf if hf > 1e-10 else np.nan
    }

    # Statistical features of the amplitude distribution.
    stat_features = {
        'Skewness': stats.skew(ecg_signal),
        'Kurtosis': stats.kurtosis(ecg_signal)
    }

    # Combine all features into one flat dict.
    features = {
        **hrv_features,
        **rr_features,
        **freq_features,
        **stat_features
    }

    return features

In [46]:
def build_features(dataset):
    """Run ``extract_features`` over every record and collect a feature table.

    Parameters
    ----------
    dataset : sized iterable of dict-like records
        Each record provides ``'ecg_signal'``, ``'patient_id'``, ``'label'``
        and ``'hea_file'``; the sampling rate is read as ``'sample_rate'``
        from ``parse_header(record['hea_file'])``.

    Returns
    -------
    pandas.DataFrame
        One row per successfully processed record: the extracted features
        plus a ``'label'`` column.

    Records that raise ``ZeroDivisionError`` or one of the known
    ``ValueError`` messages below are skipped; a single ``RuntimeWarning``
    reports how many were skipped. Any other exception plots the offending
    signal and is re-raised.
    """
    import warnings

    # ValueError messages seen from the feature-extraction stack on windows it
    # cannot analyse; such windows are skipped instead of aborting the run.
    skippable_value_errors = {
        "NeuroKit error: the window cannot contain more data points than the time series. Decrease 'scale'.",
        "cannot convert float NaN to integer",
    }

    def _show_failed_signal(ecg_signal, patient_id, label):
        # Plot the window that broke extraction so it can be inspected.
        plt.plot(ecg_signal)
        plt.title(f'Patient {patient_id} ECG {label}')
        plt.show()

    features = []
    n_skipped = 0

    for data_dict in tqdm(dataset, desc='Extracting features', total=len(dataset)):
        ecg_signal = data_dict['ecg_signal']
        patient_id = data_dict['patient_id']
        label = data_dict['label']

        header = parse_header(data_dict['hea_file'])
        sampling_rate = header['sample_rate']

        try:
            feature_dict = extract_features(ecg_signal, sampling_rate)
        except Exception as exc:
            skippable = isinstance(exc, ZeroDivisionError) or (
                isinstance(exc, ValueError) and str(exc) in skippable_value_errors
            )
            if skippable:
                n_skipped += 1
                continue
            _show_failed_signal(ecg_signal, patient_id, label)
            raise

        feature_dict['label'] = label
        features.append(feature_dict)

    if n_skipped:
        # Make dropped records visible instead of silently shrinking the table.
        warnings.warn(
            f"build_features skipped {n_skipped} of {len(dataset)} records "
            "whose features could not be extracted.",
            RuntimeWarning,
            stacklevel=2,
        )

    return pd.DataFrame(features)

In [48]:
# Quick sanity run on only the first 100 preprocessed items before the full pipeline.
features_df = build_features(processed_dataset[0:100])

Extracting features:   0%|          | 0/100 [00:00<?, ?it/s]

Extracting features: 100%|██████████| 100/100 [00:08<00:00, 12.18it/s]


In [53]:
# Missing values per feature column.
# NOTE(review): the saved output lists no LF / HF / LF_HF_ratio columns although
# extract_features returns them — it looks stale; re-run to refresh.
features_df.isnull().sum()

HRV_MeanNN            0
HRV_SDNN              0
HRV_RMSSD             0
HRV_SDSD              0
HRV_CVNN              0
HRV_CVSD              0
HRV_MedianNN          0
HRV_MadNN             0
HRV_MCVNN             0
HRV_IQRNN             0
HRV_SDRMSSD           0
HRV_Prc20NN           0
HRV_Prc80NN           0
HRV_pNN50             0
HRV_pNN20             0
HRV_MinNN             0
HRV_MaxNN             0
HRV_HTI               0
HRV_TINN              0
RR_mean               0
RR_std                0
Irregularity_index    0
Skewness              0
Kurtosis              0
label                 0
dtype: int64

In [63]:
# Transformer class to extract features from ECG signals
import numpy as np
import scipy.signal as sig
from scipy import stats
import neurokit2 as nk
from sklearn.base import BaseEstimator, TransformerMixin


class ECGFeatureExtractor(BaseEstimator, TransformerMixin):
    """scikit-learn transformer that turns ECG records into a feature DataFrame.

    ``transform`` takes a sized iterable of dict-like records with the keys
    ``'ecg_signal'``, ``'patient_id'``, ``'label'`` and ``'hea_file'`` (the
    sampling rate is read as ``'sample_rate'`` from the parsed header). It
    returns one row per record whose features could be extracted.

    NOTE(review): the output keeps the ``'label'`` column next to the features;
    drop it before fitting an estimator on this frame.
    """

    # Columns removed from the output. They are dropped with errors='ignore'
    # because not every batch necessarily contains all of them.
    _DROPPED_COLUMNS = (
        'HRV_SDANN1', 'HRV_SDNNI1', 'HRV_SDANN2', 'HRV_SDNNI2',
        'HRV_SDANN5', 'HRV_SDNNI5', 'LF', 'HF', 'LF_HF_ratio',
    )

    # ValueError messages seen from the feature-extraction stack on windows it
    # cannot analyse; such windows are skipped instead of aborting the run.
    _SKIPPABLE_VALUE_ERRORS = frozenset({
        "NeuroKit error: the window cannot contain more data points than the time series. Decrease 'scale'.",
        "cannot convert float NaN to integer",
    })

    def __init__(self):
        # Stateless transformer: there are no hyperparameters to store.
        pass

    def fit(self, X, y=None):
        """Do nothing and return ``self``; extraction learns no state."""
        return self

    def transform(self, X):
        """Return the feature DataFrame for the records in ``X``."""
        return self.build_features(X)

    def fit_transform(self, X, y=None):
        """Equivalent to ``fit(X, y).transform(X)``; ``y`` is ignored."""
        return self.fit(X, y).transform(X)

    @staticmethod
    def _show_failed_signal(ecg_signal, patient_id, label):
        """Plot a window whose feature extraction failed, for inspection."""
        plt.plot(ecg_signal)
        plt.title(f'Patient {patient_id} ECG {label}')
        plt.show()

    @staticmethod
    def build_features(dataset):
        """Extract features for every record and return them as a DataFrame.

        Records that raise ``ZeroDivisionError`` or one of the known
        ``ValueError`` messages are skipped, and a single ``RuntimeWarning``
        reports how many. Any other exception plots the offending signal and
        is re-raised. Columns listed in ``_DROPPED_COLUMNS`` are removed from
        the result.
        """
        import warnings

        features = []
        n_skipped = 0

        for data_dict in tqdm(dataset, desc='Extracting features', total=len(dataset)):
            ecg_signal = data_dict['ecg_signal']
            patient_id = data_dict['patient_id']
            label = data_dict['label']

            header = parse_header(data_dict['hea_file'])
            sampling_rate = header['sample_rate']

            try:
                # Use this class's own extractor (not a same-named global
                # function) so the transformer is self-contained.
                feature_dict = ECGFeatureExtractor.extract_features(ecg_signal, sampling_rate)
            except Exception as exc:
                skippable = isinstance(exc, ZeroDivisionError) or (
                    isinstance(exc, ValueError)
                    and str(exc) in ECGFeatureExtractor._SKIPPABLE_VALUE_ERRORS
                )
                if skippable:
                    n_skipped += 1
                    continue
                ECGFeatureExtractor._show_failed_signal(ecg_signal, patient_id, label)
                raise

            feature_dict['label'] = label
            features.append(feature_dict)

        if n_skipped:
            # Make dropped records visible instead of silently shrinking the table.
            warnings.warn(
                f"ECGFeatureExtractor skipped {n_skipped} of {len(dataset)} records "
                "whose features could not be extracted.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Drop from the frame built here and return that same frame; columns
        # that were never produced are ignored instead of raising KeyError.
        return pd.DataFrame(features).drop(
            columns=list(ECGFeatureExtractor._DROPPED_COLUMNS), errors='ignore'
        )

    @staticmethod
    def extract_features(ecg_signal, fs):
        """Compute a flat dict of hand-crafted features for one ECG window.

        Parameters
        ----------
        ecg_signal : array-like
            ECG samples of a single window.
        fs : int or float
            Sampling rate in Hz.

        Returns
        -------
        dict
            The first row of ``nk.hrv_time``'s output, merged with
            ``RR_mean`` / ``RR_std`` (ms), ``Irregularity_index`` (fraction of
            successive R-R differences above 50 ms; NaN with fewer than two
            R-R intervals), ECG-PSD band powers ``LF`` (0.04-0.15 Hz) and
            ``HF`` (0.15-0.4 Hz), ``LF_HF_ratio`` (NaN when HF is ~0), and the
            ``Skewness`` / ``Kurtosis`` of the amplitudes.

        Errors raised by ``nk.ecg_process`` / ``nk.hrv_time`` propagate.
        """
        # R-peak detection.
        _, info = nk.ecg_process(ecg_signal, sampling_rate=fs)
        r_peaks = info['ECG_R_Peaks']

        # Time-domain HRV features: first row of the DataFrame NeuroKit returns.
        hrv_features = nk.hrv_time(r_peaks, sampling_rate=fs).to_dict('records')[0]

        # R-R intervals, converted from samples to milliseconds.
        rri = np.diff(r_peaks) / fs * 1000
        abs_successive_diffs = np.abs(np.diff(rri))
        rr_features = {
            'RR_mean': np.mean(rri),
            'RR_std': np.std(rri),
            # There are len(rri) - 1 successive differences, so average over
            # them instead of dividing the count by len(rri).
            'Irregularity_index': (
                np.mean(abs_successive_diffs > 50) if abs_successive_diffs.size else np.nan
            ),
        }

        # Frequency-domain features from the PSD of the ECG signal itself.
        # Welch's default 256-sample segment resolves only fs/256 Hz, leaving
        # both sub-0.4 Hz bands empty; use the whole window as one segment.
        f, Pxx = sig.welch(ecg_signal, fs=fs, nperseg=len(ecg_signal))
        freq_step = f[1] - f[0] if f.size > 1 else 0.0
        # Rectangle-rule band power (sum of bins times bin width); LF is
        # half-open so the 0.15 Hz bin is not counted in both bands.
        lf = np.sum(Pxx[(f >= 0.04) & (f < 0.15)]) * freq_step  # Low frequency power
        hf = np.sum(Pxx[(f >= 0.15) & (f <= 0.4)]) * freq_step   # High frequency power
        freq_features = {
            'LF': lf,
            'HF': hf,
            'LF_HF_ratio': lf / hf if hf > 1e-10 else np.nan
        }

        # Statistical features of the amplitude distribution.
        stat_features = {
            'Skewness': stats.skew(ecg_signal),
            'Kurtosis': stats.kurtosis(ecg_signal)
        }

        # Combine all features into one flat dict.
        features = {
            **hrv_features,
            **rr_features,
            **freq_features,
            **stat_features
        }

        return features

In [58]:
from sklearn.pipeline import Pipeline
from src.preprocessing.ecg_preprocessing import ECGPreprocessor

                                     
# Chain windowing and feature extraction so the same steps can be re-applied to new records.
pipeline = Pipeline(steps=[
    ('preprocess', ECGPreprocessor(overlap_size=1000, window_size=5000)),
    ('extract_features', ECGFeatureExtractor()),
])

# NOTE: slow — the recorded run took ~24 min for 17,613 items; consider caching the result.
features_df = pipeline.fit_transform(dataset)

Preprocessing ECG signals:   0%|          | 0/8528 [00:00<?, ?it/s]

Preprocessing ECG signals: 100%|██████████| 8528/8528 [00:13<00:00, 628.79it/s]
Extracting features: 100%|██████████| 17613/17613 [23:44<00:00, 12.37it/s]
