# MIT-BIH R-Peak Detection - Multiple Methods

**Author:** Kumar Abhishek  
**Date:** October 2025

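This notebook implements **four methods** for R-peak detection:
1. Pan-Tompkins algorithm
2. Wavelet-transform detection
3. 1D CNN (deep learning)
4. LSTM (deep learning)

Methods 1–2 scan the whole signal and are scored against the reference annotations of record 100. Methods 3–4 are window classifiers: they label fixed-length ECG windows as centred on an R-peak or not. Their accuracy/precision/recall are therefore window-level scores, not directly comparable with the detector metrics.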

## Setup: Install and Import Required Packages

In [None]:
# Uncomment and run if packages are not installed
# !pip install -r requirements.txt

import numpy as np
import matplotlib.pyplot as plt
import wfdb
import scipy.signal as signal
import pywt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import pandas as pd
import warnings
# Blanket warning filter to keep the notebook output readable.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. a division by
# zero that silently yields NaN); consider narrowing it to specific categories.
warnings.filterwarnings(action='ignore')

print(f"TensorFlow version: {tf.__version__}")
print("Setup complete!")

## 1. Data Loading Functions

In [None]:
# Annotation symbols that label heartbeats in MIT-BIH annotation files.
# Other annotations ('+' rhythm change, '~' noise, '|' artifact, '"' comment, ...)
# do not sit on a QRS complex and must not be used as R-peak ground truth.
MITBIH_BEAT_SYMBOLS = frozenset('NLRBAaJSVrFejnE/fQ?')


def load_mitbih_record(record_name='100', sampfrom=0, sampto=None, channel=0,
                       beats_only=True):
    """Load one MIT-BIH record and its reference R-peak annotations.

    The record and its 'atr' annotations are read from PhysioNet
    (``pn_dir='mitdb'``) with wfdb.

    Args:
        record_name: MIT-BIH record identifier, e.g. ``'100'``.
        sampfrom: First sample to read.
        sampto: Sample to stop reading at; ``None`` reads to the end.
        channel: Column of ``p_signal`` to return (0 by default).
        beats_only: If True (default), keep only annotations whose symbol is in
            ``MITBIH_BEAT_SYMBOLS`` so rhythm/noise/comment markers are not
            counted as R-peaks.

    Returns:
        ``(ecg_signal, r_peaks, fs)``: the 1-D signal of ``channel``, the
        annotation sample indices relative to ``sampfrom`` (so they index
        ``ecg_signal`` directly), and the sampling frequency.
    """
    record = wfdb.rdrecord(record_name, pn_dir='mitdb', sampfrom=sampfrom, sampto=sampto)
    annotation = wfdb.rdann(record_name, 'atr', pn_dir='mitdb', sampfrom=sampfrom, sampto=sampto)
    ecg_signal = record.p_signal[:, channel]
    fs = record.fs

    samples = np.asarray(annotation.sample)
    if beats_only:
        is_beat = np.array([sym in MITBIH_BEAT_SYMBOLS for sym in annotation.symbol], dtype=bool)
        samples = samples[is_beat]
    # rdann reports sample numbers from the start of the record (its
    # shift_samps option defaults to False), so re-base them onto sampfrom.
    r_peaks = samples - sampfrom
    # Drop any annotation that falls outside the loaded segment.
    r_peaks = r_peaks[(r_peaks >= 0) & (r_peaks < len(ecg_signal))]
    return ecg_signal, r_peaks, fs

def load_multiple_records(record_list, duration=10000):
    """Load the beginning of several MIT-BIH records (best effort).

    Records that fail to load are reported and skipped.

    Args:
        record_list: Iterable of MIT-BIH record identifiers.
        duration: Number of samples (not seconds) to read from the start of
            each record; passed to ``load_mitbih_record`` as ``sampto``.

    Returns:
        ``(all_signals, all_peaks, fs)``: lists of signals and peak arrays for
        the records that loaded, and the sampling frequency of the last record
        that loaded, or ``None`` if none did.
    """
    all_signals = []
    all_peaks = []
    # Initialised so an empty list or all-failed loads return None instead of
    # raising UnboundLocalError at the return statement.
    fs = None
    for record in record_list:
        try:
            sig, peaks, rec_fs = load_mitbih_record(record, sampto=duration)
        except Exception as e:  # deliberate best effort: skip unreadable records
            print(f"Error loading {record}: {e}")
            continue
        fs = rec_fs
        all_signals.append(sig)
        all_peaks.append(peaks)
        print(f"Loaded record {record}: {len(sig)} samples, {len(peaks)} peaks")
    return all_signals, all_peaks, fs

## 2. Pan-Tompkins Algorithm

In [None]:
class PanTompkinsDetector:
    """QRS detector following the Pan-Tompkins processing chain.

    Band-pass filter -> first difference -> squaring -> 150 ms moving-window
    integration. Local maxima of the integrated signal above half its global
    maximum are reported, at least 200 ms apart.

    NOTE(review): the integer filter coefficients presumably come from
    Pan & Tompkins (1985), which used fs = 200 Hz; the filters are not
    rescaled for ``fs``, so at 360 Hz the pass band sits at proportionally
    higher frequencies. Confirm this is acceptable.
    """

    # Low-pass H(z) = (1 - z^-6)^2 / (1 - z^-1)^2 equals the FIR filter
    # (1 + z^-1 + ... + z^-5)^2: the triangular kernel 1, 2, ..., 6, ..., 2, 1.
    _LOWPASS_FIR = np.convolve(np.ones(6), np.ones(6))
    # High-pass H(z) = (-1 + 32 z^-16 - 32 z^-17 + z^-32) / (1 - z^-1) equals
    # 32 z^-16 - (1 + z^-1 + ... + z^-31): every tap is -1 except tap 16 (+31).
    _HIGHPASS_FIR = np.concatenate([-np.ones(16), [31.0], -np.ones(15)])

    def __init__(self, fs=360):
        """Store the sampling frequency ``fs`` (Hz) used to size the windows."""
        self.fs = fs

    def bandpass_filter(self, signal_data):
        """Apply the Pan-Tompkins low-pass then high-pass stage with zero phase.

        The recursive forms of these filters have poles on the unit circle
        that are cancelled by zeros. For such filters ``filtfilt``'s
        initial-condition step (``lfilter_zi``) must solve a singular linear
        system and raises. The exactly equivalent FIR kernels avoid that.
        """
        from scipy.signal import filtfilt
        filtered = filtfilt(self._LOWPASS_FIR, [1.0], signal_data)
        filtered = filtfilt(self._HIGHPASS_FIR, [1.0], filtered)
        return filtered

    def detect(self, ecg_signal):
        """Detect QRS complexes in a 1-D ECG signal.

        Returns:
            Integer ndarray of indices into the integrated-energy signal. That
            signal is one sample shorter than ``ecg_signal`` (``np.diff``), and
            its peaks are not refined to the exact R-wave apex.
        """
        filtered = self.bandpass_filter(ecg_signal)
        derivative = np.diff(filtered)
        squared = derivative ** 2
        window_size = max(1, int(0.150 * self.fs))  # 150 ms integration window
        window = np.ones(window_size) / window_size
        integrated = np.convolve(squared, window, mode='same')

        threshold = np.max(integrated) * 0.5
        refractory = int(0.2 * self.fs)  # minimum spacing between detections

        # Strict local maxima above the threshold, found in one vectorised pass.
        core = integrated[1:-1]
        is_candidate = (core > threshold) & (core > integrated[:-2]) & (core > integrated[2:])

        peaks = []
        for i in np.flatnonzero(is_candidate) + 1:
            if not peaks or (i - peaks[-1]) > refractory:
                peaks.append(int(i))
        # Explicit dtype so an empty result can still be used as an index array.
        return np.array(peaks, dtype=int)

## 3. Wavelet-Based Detection

In [None]:
class WaveletDetector:
    """R-peak detector built on the stationary wavelet transform (SWT).

    One SWT detail band is squared, smoothed with a 120 ms moving average and
    thresholded at ``mean + 0.5 * std``. Each accepted local maximum must be
    more than 200 ms after the previous detection. It is then moved to the
    largest raw-ECG sample within roughly +/-55 ms.

    NOTE(review): refinement uses ``argmax`` of the raw signal, so it assumes
    upright R-waves; inverted QRS complexes would be mislocated.
    """

    def __init__(self, fs=360, wavelet='db4', level=4, detail_level=None):
        """Configure the detector.

        Args:
            fs: Sampling frequency in Hz.
            wavelet: PyWavelets wavelet name passed to ``pywt.swt``.
            level: Number of SWT decomposition levels.
            detail_level: Detail band to analyse (1 = finest). ``None`` picks
                ``level - 2`` (band 2 for the default ``level=4``, as before),
                clamped to at least 1 so small ``level`` values still work.
        """
        self.fs = fs
        self.wavelet = wavelet
        self.level = level
        self.detail_level = detail_level

    def _detail_coefficients(self, ecg):
        """Return the selected SWT detail band, trimmed to ``len(ecg)``."""
        n = len(ecg)
        detail_level = self.detail_level
        if detail_level is None:
            detail_level = max(1, self.level - 2)
        if not 1 <= detail_level <= self.level:
            raise ValueError(
                f"detail_level must be between 1 and {self.level}, got {detail_level}")
        # pywt.swt needs a length that is a multiple of 2**level: pad the tail
        # by repeating the last sample, then trim the band back to n samples.
        pad = (-n) % (2 ** self.level)
        padded = np.pad(ecg, (0, pad), mode='edge') if pad else ecg
        coeffs = pywt.swt(padded, self.wavelet, level=self.level)
        # swt lists bands deepest first: [(cA_L, cD_L), ..., (cA_1, cD_1)].
        return coeffs[self.level - detail_level][1][:n]

    def detect(self, ecg_signal):
        """Return integer sample indices of detected R-peaks in ``ecg_signal``."""
        ecg = np.asarray(ecg_signal)
        squared = self._detail_coefficients(ecg) ** 2
        window_size = max(1, int(0.120 * self.fs))  # 120 ms smoothing window
        smoothed = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
        threshold = np.mean(smoothed) + 0.5 * np.std(smoothed)

        refractory = int(0.2 * self.fs)  # minimum spacing between detections
        # Half-width of the refinement window; equals 20 samples at 360 Hz.
        search_window = max(1, int(round(0.055 * self.fs)))

        core = smoothed[1:-1]
        is_candidate = (core > threshold) & (core > smoothed[:-2]) & (core > smoothed[2:])

        peaks = []
        for i in np.flatnonzero(is_candidate) + 1:
            # The refractory test compares against the previous *refined* index.
            if peaks and (i - peaks[-1]) <= refractory:
                continue
            start = max(0, i - search_window)
            end = min(len(ecg), i + search_window)
            peaks.append(int(start + np.argmax(ecg[start:end])))
        # Explicit dtype so an empty result can still be used as an index array.
        return np.array(peaks, dtype=int)

## 4. Deep Learning Models

In [None]:
def create_cnn_model(input_shape=(300, 1)):
    """Build and compile a 1-D convolutional binary classifier.

    Three Conv1D -> BatchNorm -> MaxPool(2) -> Dropout stages feed a dense
    head that ends in one sigmoid unit. In this notebook it scores whether an
    ECG window is centred on an R-peak. The model is compiled with Adam and
    binary cross-entropy and tracks accuracy, precision and recall.

    Args:
        input_shape: ``(window_length, channels)`` of a single input window.

    Returns:
        A compiled Keras ``Sequential`` model.
    """
    # (filters, kernel_size, dropout_rate) of each convolutional stage.
    conv_stages = ((32, 5, 0.2), (64, 5, 0.2), (128, 3, 0.3))

    stack = []
    for stage_idx, (n_filters, kernel, drop_rate) in enumerate(conv_stages):
        # Only the very first layer declares the input shape.
        extra = {'input_shape': input_shape} if stage_idx == 0 else {}
        stack.append(layers.Conv1D(n_filters, kernel, activation='relu', **extra))
        stack.append(layers.BatchNormalization())
        stack.append(layers.MaxPooling1D(2))
        stack.append(layers.Dropout(drop_rate))

    stack.append(layers.Flatten())
    # (units, dropout_rate) of each hidden dense layer.
    for units, drop_rate in ((128, 0.5), (64, 0.3)):
        stack.append(layers.Dense(units, activation='relu'))
        stack.append(layers.Dropout(drop_rate))
    stack.append(layers.Dense(1, activation='sigmoid'))

    net = models.Sequential(stack)
    tracked = ['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
    net.compile(optimizer='adam', loss='binary_crossentropy', metrics=tracked)
    return net

def create_lstm_model(input_shape=(300, 1)):
    """Build and compile a stacked-LSTM binary classifier.

    Three LSTM layers (64 -> 32 -> 16 units, each followed by 30% dropout) are
    followed by a 32-unit dense layer and a single sigmoid output. Only the
    last LSTM collapses the sequence. The model is compiled with Adam and
    binary cross-entropy and tracks accuracy, precision and recall.

    Args:
        input_shape: ``(timesteps, channels)`` of a single input window.

    Returns:
        A compiled Keras ``Sequential`` model.
    """
    # (units, return_sequences) of each recurrent layer.
    recurrent_specs = ((64, True), (32, True), (16, False))

    stack = []
    for pos, (units, keep_sequence) in enumerate(recurrent_specs):
        extra = {'input_shape': input_shape} if pos == 0 else {}
        stack.append(layers.LSTM(units, return_sequences=keep_sequence, **extra))
        stack.append(layers.Dropout(0.3))

    stack += [
        layers.Dense(32, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid'),
    ]

    net = models.Sequential(stack)
    tracked = ['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
    net.compile(optimizer='adam', loss='binary_crossentropy', metrics=tracked)
    return net

## 5. Data Preparation

In [None]:
def create_training_data(signals, all_peaks, window_size=150, fs=360, random_state=None):
    """Build balanced peak-centred (1) and peak-free (0) ECG windows.

    Each signal is z-score normalised. Every annotated peak whose window
    ``[peak - window_size, peak + window_size)`` fits inside the signal becomes
    a positive example. For each signal, as many negative examples as
    positives were added are drawn without replacement. They come from window
    centres more than ``0.3 * fs`` samples away from every annotated peak. If
    fewer such centres exist, all of them are used, so the function always
    terminates.

    Args:
        signals: Sequence of 1-D ECG arrays.
        all_peaks: Sequence of peak-index arrays, one per signal.
        window_size: Half-width of each window in samples.
        fs: Sampling frequency in Hz; sets the 0.3 s exclusion margin.
        random_state: Seed or ``np.random.Generator`` for negative sampling.
            ``None`` uses NumPy's global random state, so ``np.random.seed``
            still controls it.

    Returns:
        ``(X, y)``: ``X`` of shape ``(n, 2 * window_size, 1)`` and integer
        labels ``y`` of shape ``(n,)``.
    """
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    margin = int(0.3 * fs)
    X, y = [], []
    for signal_data, peaks in zip(signals, all_peaks):
        signal_data = np.asarray(signal_data, dtype=float)
        peaks = np.asarray(peaks, dtype=int)
        n = len(signal_data)
        std = np.std(signal_data)
        # Guard flat signals against a division by zero.
        signal_norm = (signal_data - np.mean(signal_data)) / (std if std > 0 else 1.0)

        n_pos = 0
        for peak in peaks:
            start, end = peak - window_size, peak + window_size
            if start >= 0 and end <= n:  # slice end is exclusive, so end == n fits
                X.append(signal_norm[start:end])
                y.append(1)
                n_pos += 1

        # Mark every centre within `margin` samples of an annotated peak.
        blocked = np.zeros(n, dtype=bool)
        for peak in peaks:
            blocked[max(0, peak - margin):max(0, peak + margin + 1)] = True
        lo, hi = window_size, n - window_size
        if hi > lo:
            candidates = np.flatnonzero(~blocked[lo:hi]) + lo
        else:
            candidates = np.empty(0, dtype=int)

        n_neg = min(n_pos, len(candidates))
        if n_neg:
            for idx in rng.choice(candidates, size=n_neg, replace=False):
                X.append(signal_norm[idx - window_size:idx + window_size])
                y.append(0)

    X = np.asarray(X, dtype=float).reshape(-1, window_size * 2, 1)
    return X, np.asarray(y, dtype=int)

## 6. Evaluation Metrics

In [None]:
def calculate_metrics(detected_peaks, true_peaks, tolerance=50):
    """Score detected peaks against reference peaks with one-to-one matching.

    Detections are processed in the given order. Each is matched to the
    nearest *still-unmatched* reference peak within ``tolerance`` samples and
    counted as a true positive, or else as a false positive. Reference peaks
    left unmatched are false negatives.

    Args:
        detected_peaks: Detected peak sample indices.
        true_peaks: Reference peak sample indices.
        tolerance: Maximum ``|detected - true|`` distance, in samples, for a match.

    Returns:
        dict with integer ``TP``, ``FP``, ``FN`` and float percentages
        ``Sensitivity``, ``Precision`` and ``F1-Score`` (0.0 when undefined).
    """
    detected = np.asarray(detected_peaks)
    truth = np.asarray(true_peaks)
    matched = np.zeros(len(truth), dtype=bool)
    TP = FP = 0
    for d in detected:
        if len(truth):
            # Matched reference peaks are masked out so none is claimed twice.
            distances = np.where(matched, np.inf, np.abs(truth - d))
            best = int(np.argmin(distances))
            if distances[best] <= tolerance:
                matched[best] = True
                TP += 1
                continue
        FP += 1
    FN = len(truth) - TP
    # 0.0 (not 0) keeps every ratio a float, so print_metrics formats it as a percentage.
    sens = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    prec = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    f1 = 2 * (prec * sens) / (prec + sens) if (prec + sens) > 0 else 0.0
    return {'TP': TP, 'FP': FP, 'FN': FN, 'Sensitivity': sens * 100,
            'Precision': prec * 100, 'F1-Score': f1 * 100}

def print_metrics(metrics, name):
    """Print ``name`` framed by 50-character ``=`` rules, then one line per metric.

    Float values are printed with two decimals and a trailing ``%``; any other
    value (e.g. the integer TP/FP/FN counts) is printed as-is.
    """
    rule = '=' * 50
    lines = ['', rule, name, rule]
    for key, value in metrics.items():
        if isinstance(value, float):
            lines.append(f"{key}: {value:.2f}%")
        else:
            lines.append(f"{key}: {value}")
    print('\n'.join(lines))

## 7. Visualization Functions

In [None]:
def plot_detection_results(ecg, detected, true, name, fs=360, dur=10):
    """Plot the first ``dur`` seconds of an ECG with reference and detected peaks.

    Reference peaks are drawn as green circles and detections as red crosses.

    Args:
        ecg: 1-D ECG signal.
        detected: Detected peak sample indices.
        true: Reference peak sample indices.
        name: Figure title.
        fs: Sampling frequency in Hz.
        dur: Seconds to show; clipped to the signal length.
    """
    ecg = np.asarray(ecg)
    # Clip so the time axis and the plotted samples always have equal length.
    samples = min(int(dur * fs), len(ecg))
    time = np.arange(samples) / fs
    # Integer casts let an empty (float64) peak array still be used as an index.
    true = np.asarray(true, dtype=int)
    detected = np.asarray(detected, dtype=int)
    true_in = true[(true >= 0) & (true < samples)]
    det_in = detected[(detected >= 0) & (detected < samples)]
    plt.figure(figsize=(15, 6))
    plt.plot(time, ecg[:samples], 'k-', linewidth=0.8, label='ECG')
    plt.scatter(true_in / fs, ecg[true_in], c='green', s=100, marker='o', label='True', zorder=5)
    plt.scatter(det_in / fs, ecg[det_in], c='red', s=50, marker='x', label='Detected', zorder=4)
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title(f'{name}', fontweight='bold')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

def plot_comparison(ecg, methods, true, fs=360, dur=5):
    """Stack a ground-truth panel above one panel per detection method.

    Args:
        ecg: 1-D ECG signal.
        methods: Mapping of method name -> detected peak sample indices.
        true: Reference peak sample indices.
        fs: Sampling frequency in Hz.
        dur: Seconds to show; clipped to the signal length.
    """
    ecg = np.asarray(ecg)
    # Clip so the time axis and the plotted samples always have equal length.
    samples = min(int(dur * fs), len(ecg))
    time = np.arange(samples) / fs
    true = np.asarray(true, dtype=int)
    true_in = true[(true >= 0) & (true < samples)]

    n_rows = len(methods) + 1
    # squeeze=False keeps an indexable array even when `methods` is empty.
    fig, axes = plt.subplots(n_rows, 1, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes[:, 0]

    axes[0].plot(time, ecg[:samples], 'k-', linewidth=0.8)
    axes[0].scatter(true_in / fs, ecg[true_in], c='green', s=100, label='True')
    axes[0].set_title('Ground Truth', fontweight='bold')
    axes[0].legend()
    axes[0].grid(alpha=0.3)

    for i, (name, det) in enumerate(methods.items(), 1):
        # Integer cast lets an empty (float64) detection array still index ecg.
        det = np.asarray(det, dtype=int)
        det_in = det[(det >= 0) & (det < samples)]
        axes[i].plot(time, ecg[:samples], 'k-', alpha=0.5, linewidth=0.8)
        axes[i].scatter(det_in / fs, ecg[det_in], c='red', s=50, marker='x', label='Detected')
        axes[i].scatter(true_in / fs, ecg[true_in], c='green', s=30, alpha=0.3)
        axes[i].set_title(name, fontweight='bold')
        axes[i].legend()
        axes[i].grid(alpha=0.3)

    axes[-1].set_xlabel('Time (s)')
    plt.tight_layout()
    plt.show()

## 8. Main Execution Pipeline

In [None]:
print("=" * 70, "MIT-BIH R-PEAK DETECTION - STARTING...", "=" * 70, sep="\n")

# Test record for the classical detectors: record 100, up to sample 650000.
print("\n[1/6] Loading test data...")
ecg_signal, true_peaks, fs = load_mitbih_record(record_name='100', sampto=650000)
print(f"Loaded: {len(ecg_signal)} samples, {len(true_peaks)} peaks, {fs}Hz")

## 9. Test Classical Methods

In [None]:
# Both detectors are scored with calculate_metrics' default tolerance of
# 50 samples (about 139 ms at 360 Hz).

# --- Pan-Tompkins -----------------------------------------------------------
print("\n[2/6] Testing Pan-Tompkins...")
pt_detector = PanTompkinsDetector(fs=fs)
pt_peaks = pt_detector.detect(ecg_signal)
pt_metrics = calculate_metrics(detected_peaks=pt_peaks, true_peaks=true_peaks)
print_metrics(metrics=pt_metrics, name="Pan-Tompkins")

# --- Wavelet transform ------------------------------------------------------
print("\n[3/6] Testing Wavelet...")
wt_detector = WaveletDetector(fs=fs)
wt_peaks = wt_detector.detect(ecg_signal)
wt_metrics = calculate_metrics(detected_peaks=wt_peaks, true_peaks=true_peaks)
print_metrics(metrics=wt_metrics, name="Wavelet Transform")

## 10. Prepare Deep Learning Data

In [None]:
print("\n[4/6] Preparing deep learning data...")
train_records = ['100', '101', '103', '105', '106', '108', '109', '111', '112', '113']
train_signals, train_peaks, _ = load_multiple_records(train_records, duration=50000)
if len(train_signals) < 2:
    raise RuntimeError("Need at least two loaded records for a record-wise train/test split")

# Split by RECORD rather than by window. Windows cut from one record overlap
# heavily: positive windows of neighbouring beats share samples, and negative
# centres only need to be 0.3*fs from a peak, so their windows share samples
# with positive ones. A random window-level split puts near-duplicates in both
# train and test and inflates the reported CNN/LSTM scores.
n_test_records = max(1, round(0.2 * len(train_signals)))  # ~20% of records held out
split_at = len(train_signals) - n_test_records
np.random.seed(42)  # reproducible negative-window sampling
X_train, y_train = create_training_data(train_signals[:split_at], train_peaks[:split_at],
                                        window_size=150, fs=fs)
X_test, y_test = create_training_data(train_signals[split_at:], train_peaks[split_at:],
                                      window_size=150, fs=fs)
X = np.concatenate([X_train, X_test])
y = np.concatenate([y_train, y_test])
print(f"Dataset: {X.shape[0]} samples ({np.sum(y)} positive, {len(y)-np.sum(y)} negative)")
print(f"Train: {len(X_train)} | Test: {len(X_test)}")

## 11. Train CNN Model

In [None]:
print("\n[5/6] Training CNN...")
cnn_model = create_cnn_model(input_shape=(X.shape[1], 1))
# The held-out split is passed as validation_data only for monitoring (no
# callbacks act on it); the scores printed below come from that same split.
history_cnn = cnn_model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    epochs=20,
    batch_size=64,
    verbose=1,
)
# evaluate() returns [loss, accuracy, precision, recall], following the
# metric order given to compile() in create_cnn_model.
cnn_results = cnn_model.evaluate(X_test, y_test, verbose=0)
print(f"\nCNN Results: Acc={cnn_results[1]*100:.2f}%, Prec={cnn_results[2]*100:.2f}%, Rec={cnn_results[3]*100:.2f}%")

## 12. Train LSTM Model

In [None]:
print("\n[6/6] Training LSTM...")
lstm_model = create_lstm_model(input_shape=(X.shape[1], 1))
# Same protocol as the CNN: the held-out split is monitored as validation data
# and then used for the final scores.
history_lstm = lstm_model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    epochs=20,
    batch_size=64,
    verbose=1,
)
# evaluate() returns [loss, accuracy, precision, recall], following the
# metric order given to compile() in create_lstm_model.
lstm_results = lstm_model.evaluate(X_test, y_test, verbose=0)
print(f"\nLSTM Results: Acc={lstm_results[1]*100:.2f}%, Prec={lstm_results[2]*100:.2f}%, Rec={lstm_results[3]*100:.2f}%")

## 13. Visualize Results

In [None]:
# One figure per classical detector over the first 10 s of the test record.
plot_detection_results(ecg_signal, pt_peaks, true_peaks, "Pan-Tompkins", fs=fs, dur=10)
plot_detection_results(ecg_signal, wt_peaks, true_peaks, "Wavelet Transform", fs=fs, dur=10)

# Stacked comparison over the first 5 s: ground truth plus one panel per method.
methods_results = {}
methods_results['Pan-Tompkins'] = pt_peaks
methods_results['Wavelet'] = wt_peaks
plot_comparison(ecg_signal, methods_results, true_peaks, fs=fs, dur=5)

## 14. Training History Plots

In [None]:
# 2x2 grid: rows = model (CNN, LSTM), columns = metric (loss, accuracy).
# Each panel overlays the training curve and the validation curve.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
for row, (model_label, hist) in enumerate((('CNN', history_cnn), ('LSTM', history_lstm))):
    for col, (key, key_title) in enumerate((('loss', 'Loss'), ('accuracy', 'Accuracy'))):
        panel = axes[row, col]
        panel.plot(hist.history[key], label='Train')
        panel.plot(hist.history['val_' + key], label='Val')
        panel.set_title(f'{model_label} - {key_title}', fontweight='bold')
        panel.legend()
        panel.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 15. Save Models

In [None]:
# NOTE(review): recent Keras versions treat HDF5 ('.h5') as a legacy save
# format and recommend '.keras'; file names are left unchanged here.
cnn_model.save('rpeak_cnn_model.h5')
lstm_model.save('rpeak_lstm_model.h5')
print("Models saved successfully!")
print("\n" + "=" * 70, "ALL PROCESSES COMPLETED!", "=" * 70, sep="\n")

## 16. Summary Report

In [None]:
# The classical rows are peak-detection scores against the record annotations.
# The CNN/LSTM rows are window-classification scores on held-out windows, so
# the two groups are not directly comparable.
print("\nFINAL SUMMARY:", "=" * 70, sep="\n")
print(
    "\nClassical Methods:",
    f"  Pan-Tompkins: F1={pt_metrics['F1-Score']:.2f}%, Sens={pt_metrics['Sensitivity']:.2f}%",
    f"  Wavelet:      F1={wt_metrics['F1-Score']:.2f}%, Sens={wt_metrics['Sensitivity']:.2f}%",
    sep="\n",
)
print(
    "\nDeep Learning Models:",
    f"  CNN:  Acc={cnn_results[1]*100:.2f}%, Prec={cnn_results[2]*100:.2f}%, Rec={cnn_results[3]*100:.2f}%",
    f"  LSTM: Acc={lstm_results[1]*100:.2f}%, Prec={lstm_results[2]*100:.2f}%, Rec={lstm_results[3]*100:.2f}%",
    sep="\n",
)
print("=" * 70)