# MIT-BIH ECG Classification: Two-Stage Deep Learning Pipeline

**Google Colab Version**

This notebook implements a **two-stage hierarchical deep learning pipeline** for ECG beat classification:

## Pipeline Architecture:
```
Stage 1: Binary Classification (Normal vs Abnormal)
    ↓ (beats classified as Abnormal)
Stage 2: Multiclass Classification (S, V, F, Q)
```
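
The routing in the diagram can be sketched in a few lines of plain Python. This is a minimal illustration, not the notebook's implementation: `route_two_stage` and the sample values are hypothetical, and for simplicity a Stage 2 label is supplied for every beat, whereas the notebook only runs Stage 2 on beats that Stage 1 flags.

```python
def route_two_stage(stage1_probs, stage2_labels, threshold=0.5):
    """Combine both stages: beats below the threshold stay 'N',
    the rest take the Stage 2 label at the same position."""
    final = []
    for prob, label in zip(stage1_probs, stage2_labels):
        final.append(label if prob >= threshold else 'N')
    return final


# Hypothetical outputs for four beats (illustrative values only).
stage1_probs = [0.10, 0.92, 0.55, 0.30]   # Stage 1 P(abnormal)
stage2_labels = ['V', 'V', 'S', 'F']      # Stage 2 guess for each beat
final_labels = route_two_stage(stage1_probs, stage2_labels)
print(final_labels)
```

A beat that Stage 1 scores below the threshold can never receive an abnormal label, so Stage 1 recall caps the recall of every abnormal class downstream.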

## Key Features:
- **Dual-input CNN** (ECG waveform + RR intervals)
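- **Patient-wise K-Fold** cross-validation, grouped by record ID so that beats from one recording never appear in both training and validation folds
- **Stage 1**: Biased toward **high recall** of abnormal beats by boosting the abnormal class weight
- **Stage 2**: Evaluated with **Macro F1** across the abnormal classes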
- **Class weighting** for imbalanced data
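
As a rough illustration of the class weighting listed above, the "balanced" heuristic assigns each class the weight `n_samples / (n_classes * class_count)`. The sketch below reimplements that formula with the standard library only; `balanced_class_weights` and the 90/10 toy labels are hypothetical and stand in for the scikit-learn call used later in the notebook.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Return {class: n_samples / (n_classes * count[class])}."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * k) for c, k in counts.items()}


# Toy data: 90 normal beats (0) and 10 abnormal beats (1).
labels = [0] * 90 + [1] * 10
weights = balanced_class_weights(labels)
print(weights)
```

The rarer class receives the larger weight, so each class contributes roughly equally to the weighted loss.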

## AAMI Superclasses:
| Class | Type | Description |
|-------|------|-------------|
| N | Normal | Normal beats |
| S | Abnormal | Supraventricular ectopic |
| V | Abnormal | Ventricular ectopic |
| F | Abnormal | Fusion beats |
| Q | Abnormal | Unknown/Paced |

## 0) Google Colab Setup

In [None]:
# ============================================================
# GOOGLE COLAB SETUP
# ============================================================

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install -q wfdb

# Check GPU availability
import tensorflow as tf
print(f'TensorFlow version: {tf.__version__}')
print(f'GPU available: {tf.config.list_physical_devices("GPU")}')

# Set memory growth for GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f'Memory growth enabled for {len(gpus)} GPU(s)')
    except RuntimeError as e:
        print(e)

print('\n✅ Colab setup complete!')

## 1) Imports & Configuration

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import os
import json
import warnings
from pathlib import Path
from collections import Counter

# Set random seeds
import random
SEED = 42
random.seed(SEED)

import numpy as np
np.random.seed(SEED)
import pandas as pd

import wfdb

from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, roc_auc_score, roc_curve
)

import tensorflow as tf
tf.random.set_seed(SEED)

from tensorflow import keras
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.utils import to_categorical

import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

print(f'TensorFlow: {tf.__version__}')
print(f'NumPy: {np.__version__}')
print('\n✅ All imports successful!')

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# Paths - UPDATE FOR YOUR DRIVE LOCATION
DATASET_PATH = Path('/content/drive/MyDrive/ecg2.0')
OUTPUT_PATH = Path('/content/drive/MyDrive/ecg2.0/outputs_twostage')
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Beat extraction parameters
SAMPLES_BEFORE = 100
SAMPLES_AFTER = 150
BEAT_LENGTH = SAMPLES_BEFORE + SAMPLES_AFTER

# Training parameters
N_FOLDS = 5
BATCH_SIZE = 128
EPOCHS = 25
PATIENCE = 8

# AAMI Superclass Mapping
AAMI_MAP = {
    'N': 'N', 'L': 'N', 'R': 'N', 'e': 'N', 'j': 'N',
    'A': 'S', 'a': 'S', 'J': 'S', 'S': 'S',
    'V': 'V', 'E': 'V',
    'F': 'F',
    '/': 'Q', 'f': 'Q', '!': 'Q', 'Q': 'Q', 'P': 'Q'
}

AAMI_CLASSES = ['N', 'S', 'V', 'F', 'Q']
ABNORMAL_CLASSES = ['S', 'V', 'F', 'Q']  # Stage 2 classes
AAMI_NAMES = {
    'N': 'Normal', 'S': 'Supraventricular', 'V': 'Ventricular',
    'F': 'Fusion', 'Q': 'Unknown/Paced'
}

print(f'Dataset path: {DATASET_PATH}')
print(f'Output path: {OUTPUT_PATH}')
print(f'Beat length: {BEAT_LENGTH}')
print(f'K-Folds: {N_FOLDS}')
print(f'Epochs: {EPOCHS}')

## 2) Data Loading & Preprocessing

In [None]:
# ============================================================
# DATA LOADING
# ============================================================

def find_records(dataset_path):
    dataset_path = Path(dataset_path)
    hea_files = list(dataset_path.rglob('*.hea'))
    records = []
    for hea_file in hea_files:
        record_path = str(hea_file.with_suffix(''))
        if hea_file.with_suffix('.dat').exists() and hea_file.with_suffix('.atr').exists():
            records.append(record_path)
    return sorted(records)

def load_record(record_path):
    try:
        record = wfdb.rdrecord(record_path)
        annotation = wfdb.rdann(record_path, 'atr')
        return {
            'record_id': Path(record_path).stem,
            'signals': record.p_signal,
            'fs': record.fs,
            'ann_samples': annotation.sample,
            'ann_symbols': annotation.symbol
        }
    except Exception as e:
        print(f'Error loading {record_path}: {e}')
        return None

print('Loading MIT-BIH records...')
record_paths = find_records(DATASET_PATH)
print(f'Found {len(record_paths)} records')

records_data = []
for i, rp in enumerate(record_paths):
    data = load_record(rp)
    if data:
        records_data.append(data)
    if (i + 1) % 20 == 0:
        print(f'  Loaded {i + 1}/{len(record_paths)}...')

print(f'\n✅ Loaded {len(records_data)} records')

In [None]:
# ============================================================
# BEAT EXTRACTION
# ============================================================

def extract_beats(record_data, samples_before=100, samples_after=150, channel=0):
    signals = record_data['signals']
    ann_samples = record_data['ann_samples']
    ann_symbols = record_data['ann_symbols']
    record_id = record_data['record_id']
    fs = record_data['fs']
    signal_length = signals.shape[0]
    beat_length = samples_before + samples_after
    
    beats, aami_labels, record_ids = [], [], []
    rr_before_list, rr_after_list = [], []
    
    for i, (sample, symbol) in enumerate(zip(ann_samples, ann_symbols)):
        if symbol not in AAMI_MAP:
            continue
        start, end = sample - samples_before, sample + samples_after
        if start < 0 or end > signal_length:
            continue
        beat = signals[start:end, channel]
        if len(beat) != beat_length:
            continue
        
        rr_b = (ann_samples[i] - ann_samples[i-1]) / fs if i > 0 else 0.8
        rr_a = (ann_samples[i+1] - ann_samples[i]) / fs if i < len(ann_samples) - 1 else 0.8
        
        beats.append(beat)
        aami_labels.append(AAMI_MAP[symbol])
        record_ids.append(record_id)
        rr_before_list.append(rr_b)
        rr_after_list.append(rr_a)
    
    return beats, aami_labels, record_ids, rr_before_list, rr_after_list

print('Extracting beats...')
all_beats, all_labels, all_record_ids = [], [], []
all_rr_before, all_rr_after = [], []

for i, record in enumerate(records_data):
    beats, labels, rids, rr_b, rr_a = extract_beats(record, SAMPLES_BEFORE, SAMPLES_AFTER)
    all_beats.extend(beats)
    all_labels.extend(labels)
    all_record_ids.extend(rids)
    all_rr_before.extend(rr_b)
    all_rr_after.extend(rr_a)
    if (i + 1) % 20 == 0:
        print(f'  Processed {i + 1}/{len(records_data)}...')

X_beats = np.array(all_beats, dtype=np.float32)
y = np.array(all_labels)
record_ids = np.array(all_record_ids)
rr_before = np.array(all_rr_before, dtype=np.float32)
rr_after = np.array(all_rr_after, dtype=np.float32)

print(f'\n✅ Extracted {len(X_beats):,} beats')
print(f'X_beats: {X_beats.shape}')
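
`extract_beats` measures each RR interval against the neighbouring annotation whatever its symbol, so a non-beat annotation (for example the `+` rhythm-change marker) sitting between two beats shortens the interval. One possible alternative, sketched below under the assumption that only beat symbols should count, filters the annotations first. `rr_intervals` and the toy annotation values are hypothetical; the 0.8 s default mirrors the fallback used in the cell above.

```python
def rr_intervals(samples, symbols, beat_symbols, fs, default=0.8):
    """Return (rr_before, rr_after) in seconds for each beat annotation,
    ignoring annotations whose symbol is not in beat_symbols."""
    beat_pos = [s for s, sym in zip(samples, symbols) if sym in beat_symbols]
    rr_before, rr_after = [], []
    last = len(beat_pos) - 1
    for i, pos in enumerate(beat_pos):
        rr_before.append((pos - beat_pos[i - 1]) / fs if i > 0 else default)
        rr_after.append((beat_pos[i + 1] - pos) / fs if i < last else default)
    return rr_before, rr_after


# Toy annotations: the '+' marker at sample 400 is not a beat.
samples = [100, 370, 400, 730]
symbols = ['N', 'V', '+', 'N']
rr_b, rr_a = rr_intervals(samples, symbols, {'N', 'V'}, fs=360)
print(rr_b, rr_a)
```

With the marker ignored, the interval after the `V` beat spans 360 samples (1.0 s at 360 Hz) instead of the 30 samples to the marker.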

In [None]:
# ============================================================
# CLASS DISTRIBUTION & BINARY LABELS
# ============================================================

# Original distribution
print('AAMI Class Distribution:')
print('-' * 50)
counts = Counter(y)
for cls in AAMI_CLASSES:
    c = counts.get(cls, 0)
    print(f'  {cls} ({AAMI_NAMES[cls]:15s}): {c:6,} ({100*c/len(y):.2f}%)')

# Create binary labels (Normal=0, Abnormal=1)
y_binary = np.array(['Abnormal' if label != 'N' else 'Normal' for label in y])

print(f'\nBinary Distribution (Stage 1):')
print('-' * 50)
binary_counts = Counter(y_binary)
for cls in ['Normal', 'Abnormal']:
    c = binary_counts[cls]
    print(f'  {cls:10s}: {c:6,} ({100*c/len(y):.2f}%)')

# Abnormal-only distribution (Stage 2)
abnormal_mask = y != 'N'
y_abnormal = y[abnormal_mask]
print(f'\nAbnormal Classes Distribution (Stage 2):')
print('-' * 50)
abn_counts = Counter(y_abnormal)
for cls in ABNORMAL_CLASSES:
    c = abn_counts.get(cls, 0)
    print(f'  {cls} ({AAMI_NAMES[cls]:15s}): {c:6,} ({100*c/len(y_abnormal):.2f}%)')

In [None]:
# ============================================================
# PATIENT-WISE DATA SPLIT
# ============================================================

def holdout_test_set(X, rr_b, rr_a, y, y_bin, rids, test_size=0.15, seed=42):
    unique_ids = np.unique(rids)
    np.random.seed(seed)
    np.random.shuffle(unique_ids)
    
    n_test = int(len(unique_ids) * test_size)
    test_ids = set(unique_ids[:n_test])
    trainval_ids = set(unique_ids[n_test:])
    
    test_mask = np.array([rid in test_ids for rid in rids])
    trainval_mask = ~test_mask
    
    return (
        X[trainval_mask], rr_b[trainval_mask], rr_a[trainval_mask],
        y[trainval_mask], y_bin[trainval_mask], rids[trainval_mask],
        X[test_mask], rr_b[test_mask], rr_a[test_mask],
        y[test_mask], y_bin[test_mask], rids[test_mask],
        trainval_ids, test_ids
    )

(X_tv, rr_b_tv, rr_a_tv, y_tv, y_bin_tv, rids_tv,
 X_test, rr_b_test, rr_a_test, y_test, y_bin_test, rids_test,
 trainval_pids, test_pids) = holdout_test_set(
    X_beats, rr_before, rr_after, y, y_binary, record_ids
)

print('=' * 60)
print('PATIENT-WISE DATA SPLIT')
print('=' * 60)
print(f'Train+Val: {len(X_tv):,} beats from {len(trainval_pids)} patients')
print(f'Test:      {len(X_test):,} beats from {len(test_pids)} patients')

# Create patient groups for K-Fold
pid_to_group = {pid: i for i, pid in enumerate(trainval_pids)}
groups_tv = np.array([pid_to_group[rid] for rid in rids_tv])
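
A quick sanity check for a group-wise split is that no record ID appears on both sides. The sketch below shows the idea on toy data; `split_by_group` and the record IDs are hypothetical and only mirror the masking logic of `holdout_test_set`.

```python
def split_by_group(group_ids, test_groups):
    """Return (train_idx, test_idx) so that every group lands on one side only."""
    test_groups = set(test_groups)
    test_idx = [i for i, g in enumerate(group_ids) if g in test_groups]
    train_idx = [i for i, g in enumerate(group_ids) if g not in test_groups]
    return train_idx, test_idx


# Toy data: two beats from each of three records.
toy_record_ids = ['100', '100', '101', '101', '102', '102']
train_idx, test_idx = split_by_group(toy_record_ids, ['101'])

train_groups = {toy_record_ids[i] for i in train_idx}
held_out_groups = {toy_record_ids[i] for i in test_idx}
print(train_groups.isdisjoint(held_out_groups))
```

The same disjointness check can be applied to `trainval_pids` and `test_pids`, or to the groups of each cross-validation fold.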

In [None]:
# ============================================================
# DATA PREPROCESSING
# ============================================================

print('Preprocessing data...')

# Reshape ECG beats to (N, T, 1)
X_tv_ecg = X_tv.reshape(-1, BEAT_LENGTH, 1)
X_test_ecg = X_test.reshape(-1, BEAT_LENGTH, 1)

# Stack RR features (N, 2)
rr_tv = np.column_stack([rr_b_tv, rr_a_tv])
rr_test = np.column_stack([rr_b_test, rr_a_test])

# Normalize ECG (per-beat z-score)
def normalize_beats(X):
    X_norm = np.zeros_like(X)
    for i in range(len(X)):
        beat = X[i, :, 0]
        m, s = np.mean(beat), np.std(beat)
        X_norm[i, :, 0] = (beat - m) / s if s > 0 else beat - m
    return X_norm

X_tv_ecg = normalize_beats(X_tv_ecg)
X_test_ecg = normalize_beats(X_test_ecg)

# Standardize RR features
rr_scaler = StandardScaler()
rr_tv_scaled = rr_scaler.fit_transform(rr_tv)
rr_test_scaled = rr_scaler.transform(rr_test)

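# Encode binary labels (Stage 1): Normal=0, Abnormal=1.
# LabelEncoder sorts classes alphabetically (Abnormal=0, Normal=1), which would
# flip the positive class assumed by pos_label=1 and by the abnormal class-weight
# boost, so the labels are encoded explicitly instead.
y_tv_bin_enc = (y_bin_tv == 'Abnormal').astype(int)
y_test_bin_enc = (y_bin_test == 'Abnormal').astype(int)
# le_binary is kept only because later cells print and save it; its integer
# codes are the reverse of the encoding used for training.
le_binary = LabelEncoder()
le_binary.fit(['Normal', 'Abnormal'])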

# Encode multiclass labels (Stage 2)
le_multi = LabelEncoder()
le_multi.fit(ABNORMAL_CLASSES)
NUM_ABN_CLASSES = len(ABNORMAL_CLASSES)

print(f'\n✅ Preprocessing complete!')
print(f'X_tv_ecg: {X_tv_ecg.shape}')
print(f'rr_tv_scaled: {rr_tv_scaled.shape}')
print(f'Binary classes: {le_binary.classes_}')
print(f'Abnormal classes: {le_multi.classes_}')
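
scikit-learn's `LabelEncoder` stores its classes in sorted order, regardless of the order passed to `fit`. That matters wherever code assumes a particular integer for a class, such as `pos_label=1` meaning "abnormal". The sketch below mimics the ordering with `sorted()` so it runs without scikit-learn; `sorted_label_encoding` is a hypothetical helper, not part of the library.

```python
def sorted_label_encoding(class_names):
    """Mimic a sorted-class encoder: code = position in sorted order."""
    return {name: idx for idx, name in enumerate(sorted(set(class_names)))}


binary_codes = sorted_label_encoding(['Normal', 'Abnormal'])
multi_codes = sorted_label_encoding(['S', 'V', 'F', 'Q'])
print(binary_codes)
print(multi_codes)
```

Under sorted ordering, `Abnormal` maps to 0 and `Normal` to 1, and the Stage 2 codes 0 to 3 correspond to F, Q, S, V rather than S, V, F, Q.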

## 3) Stage 1: Normal vs Abnormal (Binary Classification)

In [None]:
# ============================================================
# STAGE 1 MODEL DEFINITION
# ============================================================

def build_stage1_model(ecg_length, rr_features):
    """
    Binary classifier: Normal vs Abnormal
    Optimized for HIGH RECALL of abnormal beats.
    """
    # ECG Branch
    ecg_input = Input(shape=(ecg_length, 1), name='ecg_input')
    x = layers.Conv1D(32, 5, padding='same')(ecg_input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv1D(64, 5, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv1D(128, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    # RR Branch
    rr_input = Input(shape=(rr_features,), name='rr_input')
    y = layers.Dense(32)(rr_input)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('relu')(y)
    y = layers.Dropout(0.2)(y)
    y = layers.Dense(32)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('relu')(y)
    
    # Concatenate
    combined = layers.Concatenate()([x, y])
    z = layers.Dense(64)(combined)
    z = layers.BatchNormalization()(z)
    z = layers.Activation('relu')(z)
    z = layers.Dropout(0.3)(z)
    
    # Output (sigmoid for binary)
    output = layers.Dense(1, activation='sigmoid', name='output')(z)
    
    model = Model(inputs=[ecg_input, rr_input], outputs=output, name='Stage1_Binary')
    return model

def compile_stage1(model, lr=0.001):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
    return model

# Build and show
model_s1 = build_stage1_model(BEAT_LENGTH, 2)
model_s1 = compile_stage1(model_s1)
model_s1.summary()

In [None]:
# ============================================================
# STAGE 1: K-FOLD CROSS-VALIDATION
# ============================================================

def clear_session():
    keras.backend.clear_session()
    tf.random.set_seed(SEED)
    np.random.seed(SEED)

sgkf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

s1_results = {'metrics': [], 'histories': []}

print('=' * 70)
print(f'STAGE 1: BINARY CLASSIFICATION ({N_FOLDS}-FOLD CV)')
print('=' * 70)

for fold, (train_idx, val_idx) in enumerate(sgkf.split(X_tv_ecg, y_tv_bin_enc, groups_tv)):
    print(f'\n--- FOLD {fold+1}/{N_FOLDS} ---')
    clear_session()
    
    # Split
    X_tr_ecg, X_vl_ecg = X_tv_ecg[train_idx], X_tv_ecg[val_idx]
    rr_tr, rr_vl = rr_tv_scaled[train_idx], rr_tv_scaled[val_idx]
    y_tr, y_vl = y_tv_bin_enc[train_idx], y_tv_bin_enc[val_idx]
    
    # Class weights (favor abnormal recall)
    cw = compute_class_weight('balanced', classes=np.unique(y_tr), y=y_tr)
    # Increase abnormal weight for higher recall
    cw_dict = {0: cw[0], 1: cw[1] * 1.5}  # Boost abnormal class
    
    # Build model
    model = build_stage1_model(BEAT_LENGTH, 2)
    model = compile_stage1(model)
    
    # Train
    history = model.fit(
        [X_tr_ecg, rr_tr], y_tr,
        validation_data=([X_vl_ecg, rr_vl], y_vl),
        epochs=EPOCHS, batch_size=BATCH_SIZE,
        class_weight=cw_dict,
        callbacks=[
            EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-6)
        ],
        verbose=1
    )
    
    # Evaluate
    y_pred_prob = model.predict([X_vl_ecg, rr_vl], verbose=0).flatten()
    y_pred = (y_pred_prob >= 0.5).astype(int)
    
    metrics = {
        'accuracy': accuracy_score(y_vl, y_pred),
        'precision': precision_score(y_vl, y_pred, pos_label=1),
        'recall': recall_score(y_vl, y_pred, pos_label=1),  # Abnormal recall
        'f1': f1_score(y_vl, y_pred, pos_label=1),
        'roc_auc': roc_auc_score(y_vl, y_pred_prob)
    }
    s1_results['metrics'].append(metrics)
    s1_results['histories'].append(history.history)
    
    print(f'  Acc: {metrics["accuracy"]:.4f} | Recall(Abn): {metrics["recall"]:.4f} | F1: {metrics["f1"]:.4f} | AUC: {metrics["roc_auc"]:.4f}')

# Summary
print('\n' + '=' * 70)
print('STAGE 1 CV SUMMARY')
print('=' * 70)
for m in ['accuracy', 'recall', 'f1', 'roc_auc']:
    vals = [r[m] for r in s1_results['metrics']]
    print(f'  {m:12s}: {np.mean(vals):.4f} ± {np.std(vals):.4f}')
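
The cell above thresholds Stage 1 probabilities at 0.5. If a specific abnormal recall is required, one option is to pick the threshold from validation predictions instead. The sketch below is a hypothetical helper (`threshold_for_recall`) on toy values, assuming label 1 is the abnormal class and that predictions use `prob >= threshold`.

```python
import math


def threshold_for_recall(y_true, y_prob, target_recall=0.95):
    """Highest threshold at which recall of the positive class (label 1)
    still reaches target_recall, using the rule prob >= threshold."""
    positives = sorted((p for t, p in zip(y_true, y_prob) if t == 1), reverse=True)
    if not positives:
        raise ValueError('y_true contains no positive samples')
    # Number of positives that must score at or above the threshold.
    needed = max(1, math.ceil(target_recall * len(positives) - 1e-9))
    return positives[needed - 1]


# Toy validation data (illustrative values only).
y_true = [0, 0, 1, 1, 1, 1]
y_prob = [0.10, 0.60, 0.35, 0.80, 0.90, 0.70]
thr_75 = threshold_for_recall(y_true, y_prob, target_recall=0.75)
thr_100 = threshold_for_recall(y_true, y_prob, target_recall=1.0)
print(thr_75, thr_100)
```

Lowering the threshold raises recall at the cost of more normal beats being passed to Stage 2, so the choice should be made on validation folds, not on the test set.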

## 4) Stage 2: Abnormal Classification (Multiclass)

In [None]:
# ============================================================
# STAGE 2 MODEL DEFINITION
# ============================================================

def build_stage2_model(ecg_length, rr_features, num_classes):
    """
    Multiclass classifier for abnormal classes (S, V, F, Q)
    Optimized for MACRO F1.
    """
    # ECG Branch
    ecg_input = Input(shape=(ecg_length, 1), name='ecg_input')
    x = layers.Conv1D(32, 5, padding='same')(ecg_input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv1D(64, 5, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv1D(128, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.MaxPooling1D(2)(x)
    x = layers.Dropout(0.3)(x)
    
    x = layers.Conv1D(256, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.GlobalAveragePooling1D()(x)
    
    # RR Branch
    rr_input = Input(shape=(rr_features,), name='rr_input')
    y = layers.Dense(32)(rr_input)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('relu')(y)
    y = layers.Dropout(0.2)(y)
    y = layers.Dense(64)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation('relu')(y)
    
    # Concatenate
    combined = layers.Concatenate()([x, y])
    z = layers.Dense(128)(combined)
    z = layers.BatchNormalization()(z)
    z = layers.Activation('relu')(z)
    z = layers.Dropout(0.4)(z)
    z = layers.Dense(64)(z)
    z = layers.BatchNormalization()(z)
    z = layers.Activation('relu')(z)
    z = layers.Dropout(0.3)(z)
    
    # Output (softmax for multiclass)
    output = layers.Dense(num_classes, activation='softmax', name='output')(z)
    
    model = Model(inputs=[ecg_input, rr_input], outputs=output, name='Stage2_Multiclass')
    return model

def compile_stage2(model, lr=0.001):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

model_s2 = build_stage2_model(BEAT_LENGTH, 2, NUM_ABN_CLASSES)
model_s2 = compile_stage2(model_s2)
model_s2.summary()

In [None]:
# ============================================================
# STAGE 2: K-FOLD CROSS-VALIDATION (Abnormal beats only)
# ============================================================

# Filter to abnormal beats only
abn_mask_tv = y_tv != 'N'
X_tv_abn_ecg = X_tv_ecg[abn_mask_tv]
rr_tv_abn = rr_tv_scaled[abn_mask_tv]
y_tv_abn = y_tv[abn_mask_tv]
groups_tv_abn = groups_tv[abn_mask_tv]

# Encode abnormal labels
y_tv_abn_enc = le_multi.transform(y_tv_abn)
y_tv_abn_onehot = to_categorical(y_tv_abn_enc, num_classes=NUM_ABN_CLASSES)

print(f'Stage 2 training data: {len(X_tv_abn_ecg):,} abnormal beats')
print(f'Classes: {le_multi.classes_}')

s2_results = {'metrics': [], 'histories': []}

print('\n' + '=' * 70)
print(f'STAGE 2: ABNORMAL CLASSIFICATION ({N_FOLDS}-FOLD CV)')
print('=' * 70)

for fold, (train_idx, val_idx) in enumerate(sgkf.split(X_tv_abn_ecg, y_tv_abn_enc, groups_tv_abn)):
    print(f'\n--- FOLD {fold+1}/{N_FOLDS} ---')
    clear_session()
    
    # Split
    X_tr, X_vl = X_tv_abn_ecg[train_idx], X_tv_abn_ecg[val_idx]
    rr_tr, rr_vl = rr_tv_abn[train_idx], rr_tv_abn[val_idx]
    y_tr, y_vl = y_tv_abn_onehot[train_idx], y_tv_abn_onehot[val_idx]
    y_vl_enc = y_tv_abn_enc[val_idx]
    y_tr_enc = y_tv_abn_enc[train_idx]
    
    # Class weights - handle missing classes
    unique_classes = np.unique(y_tr_enc)
    cw = compute_class_weight('balanced', classes=unique_classes, y=y_tr_enc)
    cw_dict = {c: w for c, w in zip(unique_classes, cw)}
    # Fill missing classes with weight 1.0
    for c in range(NUM_ABN_CLASSES):
        if c not in cw_dict:
            cw_dict[c] = 1.0
    
    # Build model
    model = build_stage2_model(BEAT_LENGTH, 2, NUM_ABN_CLASSES)
    model = compile_stage2(model)
    
    # Train
    history = model.fit(
        [X_tr, rr_tr], y_tr,
        validation_data=([X_vl, rr_vl], y_vl),
        epochs=EPOCHS, batch_size=BATCH_SIZE,
        class_weight=cw_dict,
        callbacks=[
            EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-6)
        ],
        verbose=1
    )
    
    # Evaluate - use explicit labels to ensure consistent array sizes
    y_pred_prob = model.predict([X_vl, rr_vl], verbose=0)
    y_pred = np.argmax(y_pred_prob, axis=1)
    
    all_labels = list(range(NUM_ABN_CLASSES))  # [0, 1, 2, 3] -> F, Q, S, V (LabelEncoder sorts classes alphabetically)
    
    metrics = {
        'accuracy': accuracy_score(y_vl_enc, y_pred),
        'macro_f1': f1_score(y_vl_enc, y_pred, average='macro', labels=all_labels, zero_division=0),
        'weighted_f1': f1_score(y_vl_enc, y_pred, average='weighted', labels=all_labels, zero_division=0),
        'per_class_f1': f1_score(y_vl_enc, y_pred, average=None, labels=all_labels, zero_division=0)
    }
    s2_results['metrics'].append(metrics)
    s2_results['histories'].append(history.history)
    
    print(f'  Acc: {metrics["accuracy"]:.4f} | Macro F1: {metrics["macro_f1"]:.4f} | Weighted F1: {metrics["weighted_f1"]:.4f}')

# Summary
print('\n' + '=' * 70)
print('STAGE 2 CV SUMMARY')
print('=' * 70)
for m in ['accuracy', 'macro_f1', 'weighted_f1']:
    vals = [r[m] for r in s2_results['metrics']]
    print(f'  {m:12s}: {np.mean(vals):.4f} ± {np.std(vals):.4f}')

print('\nPer-class F1 (mean):')
pcf1 = np.stack([r['per_class_f1'] for r in s2_results['metrics']])
for i, cls in enumerate(le_multi.classes_):
    print(f'  {cls} ({AAMI_NAMES[cls]:15s}): {np.mean(pcf1[:, i]):.4f} ± {np.std(pcf1[:, i]):.4f}')
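
Passing a fixed `labels` list with `zero_division=0` means a class that is absent from a fold still enters the macro average with an F1 of 0. The sketch below reproduces that behaviour from scratch on toy labels; `macro_f1` is a hypothetical helper written for illustration, not the scikit-learn function.

```python
def macro_f1(y_true, y_pred, labels):
    """Return (macro_f1, per_class_f1) over a fixed label list.
    A class with no true and no predicted samples scores 0."""
    scores = []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores), scores


# Toy fold in which class 'Q' never occurs (illustrative values only).
toy_true = ['S', 'S', 'V', 'V', 'V', 'F']
toy_pred = ['S', 'V', 'V', 'V', 'V', 'V']
macro, per_class = macro_f1(toy_true, toy_pred, labels=['F', 'Q', 'S', 'V'])
print(macro, per_class)
```

Because absent classes pull the average down, fold-to-fold variation in Macro F1 partly reflects which classes each validation fold happens to contain.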

## 5) Final Model Training & End-to-End Evaluation

In [None]:
# ============================================================
# FINAL MODEL TRAINING
# ============================================================

print('=' * 70)
print('FINAL MODEL TRAINING')
print('=' * 70)

# --- STAGE 1: Train on full Train+Val ---
print('\n--- Training Final Stage 1 Model ---')
clear_session()

cw_s1 = compute_class_weight('balanced', classes=np.unique(y_tv_bin_enc), y=y_tv_bin_enc)
cw_s1_dict = {0: cw_s1[0], 1: cw_s1[1] * 1.5}

final_s1 = build_stage1_model(BEAT_LENGTH, 2)
final_s1 = compile_stage1(final_s1)

h1 = final_s1.fit(
    [X_tv_ecg, rr_tv_scaled], y_tv_bin_enc,
    epochs=EPOCHS, batch_size=BATCH_SIZE,
    class_weight=cw_s1_dict,
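    # NOTE: Keras takes the last 10% of the arrays (before shuffling) as validation
    # data, so this early-stopping split is not grouped by record.
    validation_split=0.1,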
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-6)
    ],
    verbose=1
)

# --- STAGE 2: Train on full abnormal subset ---
print('\n--- Training Final Stage 2 Model ---')

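abn_classes_present = np.unique(y_tv_abn_enc)
cw_s2 = compute_class_weight('balanced', classes=abn_classes_present, y=y_tv_abn_enc)
# Key each weight by its encoded class; enumerate() would misalign the weights
# if any abnormal class were missing. Absent classes default to 1.0.
cw_s2_dict = {c: 1.0 for c in range(NUM_ABN_CLASSES)}
cw_s2_dict.update({int(c): float(w) for c, w in zip(abn_classes_present, cw_s2)})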

final_s2 = build_stage2_model(BEAT_LENGTH, 2, NUM_ABN_CLASSES)
final_s2 = compile_stage2(final_s2)

h2 = final_s2.fit(
    [X_tv_abn_ecg, rr_tv_abn], y_tv_abn_onehot,
    epochs=EPOCHS, batch_size=BATCH_SIZE,
    class_weight=cw_s2_dict,
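    # NOTE: Keras takes the last 10% of the arrays (before shuffling) as validation
    # data, so this early-stopping split is not grouped by record.
    validation_split=0.1,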
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=1e-6)
    ],
    verbose=1
)

print('\n✅ Final models trained!')

In [None]:
# ============================================================
# END-TO-END EVALUATION ON HELD-OUT TEST SET
# ============================================================

print('=' * 70)
print('END-TO-END EVALUATION ON TEST SET')
print('=' * 70)

# Stage 1: Predict Normal vs Abnormal
print('\n--- Stage 1: Binary Classification ---')
y_test_s1_prob = final_s1.predict([X_test_ecg, rr_test_scaled], verbose=0).flatten()
y_test_s1_pred = (y_test_s1_prob >= 0.5).astype(int)

s1_acc = accuracy_score(y_test_bin_enc, y_test_s1_pred)
s1_recall = recall_score(y_test_bin_enc, y_test_s1_pred, pos_label=1)
s1_prec = precision_score(y_test_bin_enc, y_test_s1_pred, pos_label=1)
s1_f1 = f1_score(y_test_bin_enc, y_test_s1_pred, pos_label=1)
s1_auc = roc_auc_score(y_test_bin_enc, y_test_s1_prob)

print(f'  Accuracy:       {s1_acc:.4f}')
print(f'  Recall (Abn):   {s1_recall:.4f}  <- KEY METRIC')
print(f'  Precision (Abn):{s1_prec:.4f}')
print(f'  F1 (Abn):       {s1_f1:.4f}')
print(f'  ROC-AUC:        {s1_auc:.4f}')

# Stage 2: Classify abnormal beats
print('\n--- Stage 2: Abnormal Classification ---')

# Get beats predicted as abnormal by Stage 1
abn_pred_mask = y_test_s1_pred == 1
X_test_abn_ecg = X_test_ecg[abn_pred_mask]
rr_test_abn = rr_test_scaled[abn_pred_mask]

# Get true labels for these beats
y_test_true_for_abn = y_test[abn_pred_mask]

# Predict Stage 2
if len(X_test_abn_ecg) > 0:
    y_test_s2_prob = final_s2.predict([X_test_abn_ecg, rr_test_abn], verbose=0)
    y_test_s2_pred = np.argmax(y_test_s2_prob, axis=1)
    y_test_s2_labels = le_multi.inverse_transform(y_test_s2_pred)
else:
    y_test_s2_labels = np.array([])

# Build final predictions (combining both stages)
y_final_pred = np.array(['N'] * len(y_test))
y_final_pred[abn_pred_mask] = y_test_s2_labels

# Final evaluation
print('\n--- End-to-End Results (All 5 Classes) ---')
final_acc = accuracy_score(y_test, y_final_pred)
final_macro_f1 = f1_score(y_test, y_final_pred, average='macro', labels=AAMI_CLASSES, zero_division=0)
final_weighted_f1 = f1_score(y_test, y_final_pred, average='weighted', labels=AAMI_CLASSES, zero_division=0)
final_per_class_f1 = f1_score(y_test, y_final_pred, average=None, labels=AAMI_CLASSES, zero_division=0)

print(f'\n  Accuracy:     {final_acc:.4f}')
print(f'  Macro F1:     {final_macro_f1:.4f}  <- PRIMARY METRIC')
print(f'  Weighted F1:  {final_weighted_f1:.4f}')

print('\n  Per-Class F1:')
for i, cls in enumerate(AAMI_CLASSES):
    print(f'    {cls} ({AAMI_NAMES[cls]:15s}): {final_per_class_f1[i]:.4f}')

print('\n' + '=' * 70)
print('CLASSIFICATION REPORT')
print('=' * 70)
print(classification_report(y_test, y_final_pred, labels=AAMI_CLASSES, digits=4))
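
For any abnormal class, end-to-end recall factors into two terms: the fraction of that class's beats that Stage 1 flags, times the fraction of those flagged beats that Stage 2 labels correctly. The sketch below checks the decomposition on toy data; `end_to_end_recall` and all values are hypothetical.

```python
def end_to_end_recall(true_labels, flagged, stage2_labels, cls):
    """Return (stage1_recall, stage2_recall_given_flagged, end_to_end_recall)
    for one class."""
    total = n_flagged = n_correct = 0
    for t, f, s in zip(true_labels, flagged, stage2_labels):
        if t != cls:
            continue
        total += 1
        if f:
            n_flagged += 1
            if s == cls:
                n_correct += 1
    s1 = n_flagged / total if total else 0.0
    s2 = n_correct / n_flagged if n_flagged else 0.0
    return s1, s2, s1 * s2


# Toy test set (illustrative values only); None = beat never reached Stage 2.
toy_true = ['V', 'V', 'V', 'V', 'N', 'S']
toy_flagged = [True, True, True, False, False, True]
toy_stage2 = ['V', 'V', 'S', None, None, 'S']
s1_rec, s2_rec, e2e_rec = end_to_end_recall(toy_true, toy_flagged, toy_stage2, 'V')
print(s1_rec, s2_rec, e2e_rec)
```

Here three of four `V` beats are flagged and two of those three are labelled `V`, giving an end-to-end recall of 0.5, which is lower than either stage's figure on its own.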

## 6) Visualization

In [None]:
# ============================================================
# TRAINING CURVES
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Stage 1 Loss
axes[0, 0].plot(h1.history['loss'], label='Train', linewidth=2)
axes[0, 0].plot(h1.history['val_loss'], label='Val', linewidth=2)
axes[0, 0].set_title('Stage 1: Loss', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Stage 1 Accuracy
axes[0, 1].plot(h1.history['accuracy'], label='Train', linewidth=2)
axes[0, 1].plot(h1.history['val_accuracy'], label='Val', linewidth=2)
axes[0, 1].set_title('Stage 1: Accuracy', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Stage 2 Loss
axes[1, 0].plot(h2.history['loss'], label='Train', linewidth=2)
axes[1, 0].plot(h2.history['val_loss'], label='Val', linewidth=2)
axes[1, 0].set_title('Stage 2: Loss', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Stage 2 Accuracy
axes[1, 1].plot(h2.history['accuracy'], label='Train', linewidth=2)
axes[1, 1].plot(h2.history['val_accuracy'], label='Val', linewidth=2)
axes[1, 1].set_title('Stage 2: Accuracy', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Accuracy')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'twostage_training_curves.png', dpi=150)
plt.show()

In [None]:
# ============================================================
# CONFUSION MATRICES
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Stage 1 CM
cm1 = confusion_matrix(y_test_bin_enc, y_test_s1_pred)
sns.heatmap(cm1, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Normal', 'Abnormal'], yticklabels=['Normal', 'Abnormal'], ax=axes[0])
axes[0].set_title(f'Stage 1: Normal vs Abnormal\nRecall(Abn): {s1_recall:.3f}', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Stage 2 CM (only beats predicted as abnormal by Stage 1)
# Filter to true abnormal beats that were correctly flagged
true_abn_mask = y_test_true_for_abn != 'N'
if np.sum(true_abn_mask) > 0:
    y_true_s2 = le_multi.transform(y_test_true_for_abn[true_abn_mask])
    y_pred_s2 = y_test_s2_pred[true_abn_mask]
    cm2 = confusion_matrix(y_true_s2, y_pred_s2, labels=range(NUM_ABN_CLASSES))
    sns.heatmap(cm2, annot=True, fmt='d', cmap='Oranges',
                xticklabels=le_multi.classes_, yticklabels=le_multi.classes_, ax=axes[1])
    axes[1].set_title('Stage 2: Abnormal Classes Only', fontsize=12, fontweight='bold')
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
else:
    axes[1].text(0.5, 0.5, 'No abnormal beats', ha='center', va='center')
    axes[1].set_title('Stage 2')

# End-to-End CM
le_all = LabelEncoder()
le_all.fit(AAMI_CLASSES)
y_true_enc = le_all.transform(y_test)
y_pred_enc = le_all.transform(y_final_pred)
cm_all = confusion_matrix(y_true_enc, y_pred_enc, labels=range(5))
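# Guard against classes absent from the test set (row sum 0) to avoid NaN cells
row_sums = cm_all.sum(axis=1, keepdims=True)
cm_all_norm = cm_all.astype('float') / np.maximum(row_sums, 1)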
sns.heatmap(cm_all_norm, annot=True, fmt='.2%', cmap='Greens',
            xticklabels=AAMI_CLASSES, yticklabels=AAMI_CLASSES, ax=axes[2])
axes[2].set_title(f'End-to-End (Normalized)\nMacro F1: {final_macro_f1:.3f}', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Predicted')
axes[2].set_ylabel('True')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'twostage_confusion_matrices.png', dpi=150)
plt.show()

In [None]:
# ============================================================
# ROC CURVE (Stage 1)
# ============================================================

fig, ax = plt.subplots(figsize=(8, 8))

fpr, tpr, _ = roc_curve(y_test_bin_enc, y_test_s1_prob)
ax.plot(fpr, tpr, linewidth=2, label=f'Stage 1 (AUC = {s1_auc:.3f})')
ax.plot([0, 1], [0, 1], 'k--', linewidth=1)
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('Stage 1: ROC Curve (Normal vs Abnormal)', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'twostage_roc_curve.png', dpi=150)
plt.show()

In [None]:
# ============================================================
# PER-CLASS F1 BAR CHART
# ============================================================

fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(AAMI_CLASSES))
colors = ['#2ecc71' if cls == 'N' else '#e74c3c' for cls in AAMI_CLASSES]

bars = ax.bar(x, final_per_class_f1, color=colors, alpha=0.8)

for bar, f1 in zip(bars, final_per_class_f1):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{f1:.3f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

ax.set_xticks(x)
ax.set_xticklabels([f'{c}\n({AAMI_NAMES[c]})' for c in AAMI_CLASSES])
ax.set_ylabel('F1 Score', fontsize=12)
ax.set_title(f'End-to-End Per-Class F1 Scores\nMacro F1: {final_macro_f1:.4f}', fontsize=14, fontweight='bold')
ax.set_ylim(0, 1.1)
ax.axhline(y=final_macro_f1, color='black', linestyle='--', linewidth=2, label=f'Macro F1 = {final_macro_f1:.3f}')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'twostage_per_class_f1.png', dpi=150)
plt.show()

## 7) Save Models & Summary

In [None]:
# ============================================================
# SAVE ARTIFACTS
# ============================================================

import joblib

print('Saving models and artifacts...')

final_s1.save(OUTPUT_PATH / 'stage1_model.keras')
final_s2.save(OUTPUT_PATH / 'stage2_model.keras')
joblib.dump(rr_scaler, OUTPUT_PATH / 'rr_scaler.joblib')
joblib.dump(le_binary, OUTPUT_PATH / 'le_binary.joblib')
joblib.dump(le_multi, OUTPUT_PATH / 'le_multi.joblib')

# Save metrics report
report = {
    'stage1_cv': {
        'accuracy': float(np.mean([r['accuracy'] for r in s1_results['metrics']])),
        'recall_abnormal': float(np.mean([r['recall'] for r in s1_results['metrics']])),
        'roc_auc': float(np.mean([r['roc_auc'] for r in s1_results['metrics']]))
    },
    'stage2_cv': {
        'accuracy': float(np.mean([r['accuracy'] for r in s2_results['metrics']])),
        'macro_f1': float(np.mean([r['macro_f1'] for r in s2_results['metrics']]))
    },
    'test_results': {
        'stage1_recall': float(s1_recall),
        'stage1_auc': float(s1_auc),
        'final_accuracy': float(final_acc),
        'final_macro_f1': float(final_macro_f1),
        'final_weighted_f1': float(final_weighted_f1),
        'per_class_f1': {c: float(f) for c, f in zip(AAMI_CLASSES, final_per_class_f1)}
    }
}

with open(OUTPUT_PATH / 'twostage_metrics.json', 'w') as f:
    json.dump(report, f, indent=2)

print('\n✅ All artifacts saved!')
print(f'   Location: {OUTPUT_PATH}')

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print('\n' + '=' * 70)
print('üéØ TWO-STAGE PIPELINE SUMMARY')
print('=' * 70)

print(f'''
ARCHITECTURE:
  Stage 1: Binary (Normal vs Abnormal) → Sigmoid
  Stage 2: Multiclass (S, V, F, Q) → Softmax
  Input: Dual-input CNN (ECG waveform + RR intervals)

CROSS-VALIDATION RESULTS ({N_FOLDS}-Fold):
  Stage 1 (Binary):
    Recall (Abnormal): {np.mean([r["recall"] for r in s1_results["metrics"]]):.4f} ± {np.std([r["recall"] for r in s1_results["metrics"]]):.4f}
    ROC-AUC:           {np.mean([r["roc_auc"] for r in s1_results["metrics"]]):.4f} ± {np.std([r["roc_auc"] for r in s1_results["metrics"]]):.4f}
  
  Stage 2 (Multiclass):
    Macro F1:          {np.mean([r["macro_f1"] for r in s2_results["metrics"]]):.4f} ± {np.std([r["macro_f1"] for r in s2_results["metrics"]]):.4f}

TEST SET RESULTS:
  Stage 1:
    Recall (Abnormal): {s1_recall:.4f}  ← Catches abnormal beats
    ROC-AUC:           {s1_auc:.4f}
  
  End-to-End (All 5 Classes):
    Accuracy:          {final_acc:.4f}
    Macro F1:          {final_macro_f1:.4f}  ← PRIMARY METRIC
    Weighted F1:       {final_weighted_f1:.4f}

PER-CLASS F1 SCORES:''')

for i, cls in enumerate(AAMI_CLASSES):
    marker = '← Normal' if cls == 'N' else ''
    print(f'    {cls} ({AAMI_NAMES[cls]:15s}): {final_per_class_f1[i]:.4f} {marker}')

print(f'''
OUTPUT FILES:
    {OUTPUT_PATH / 'stage1_model.keras'}
    {OUTPUT_PATH / 'stage2_model.keras'}
    {OUTPUT_PATH / 'twostage_metrics.json'}
    {OUTPUT_PATH / 'twostage_*.png'}

✅ Two-stage pipeline complete!
''')