# Abnormal ECG Beat Classification - ML Model Tester

**Standalone notebook for testing classical ML models on abnormal beats only**

## Purpose
This notebook is a **tester/sandbox** for comparing different ML approaches to classifying abnormal ECG beats (S, V, F, Q) before integrating into a full two-stage pipeline.

## Models Compared
1. **Random Forest** - Ensemble of decision trees
2. **XGBoost** - Gradient boosted trees
3. **AdaBoost** - Adaptive boosting with decision trees
4. **SVM** - Support Vector Machine (RBF kernel)
5. **Logistic Regression** - Multinomial classifier

## Key Features
- Engineered features (morphology + RR intervals)
- Patient-wise (record-level) splitting: no record contributes beats to both train and test
- **SMOTE** for class balancing on training data only (re-fitted inside each CV fold)
- Side-by-side model comparison

## AAMI Abnormal Classes
| Code | Name | Description |
|------|------|-------------|
| S | Supraventricular | Atrial/junctional ectopic beats |
| V | Ventricular | Ventricular ectopic beats |
| F | Fusion | Fusion of normal and ventricular |
| Q | Unknown | Paced/unclassifiable beats |
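
A minimal sketch of how raw MIT-BIH annotation symbols collapse into these classes. It assumes the same symbol table as `AAMI_MAP` in the configuration cell; `abnormal_aami_labels` is a hypothetical helper and the symbol stream is made up. Non-beat markers (rhythm `+`, noise `~`) are absent from the map and are dropped, as are N-class beats.

```python
from collections import Counter

# Same symbol-to-class table as AAMI_MAP in the configuration cell
AAMI_MAP = {
    'N': 'N', 'L': 'N', 'R': 'N', 'e': 'N', 'j': 'N',
    'A': 'S', 'a': 'S', 'J': 'S', 'S': 'S',
    'V': 'V', 'E': 'V',
    'F': 'F',
    '/': 'Q', 'f': 'Q', '!': 'Q', 'Q': 'Q', 'P': 'Q',
}

def abnormal_aami_labels(symbols):
    """Map raw annotation symbols to AAMI classes, dropping N beats and non-beat symbols."""
    labels = []
    for sym in symbols:
        cls = AAMI_MAP.get(sym)  # None for non-beat symbols such as '+' or '~'
        if cls is None or cls == 'N':
            continue
        labels.append(cls)
    return labels

# Toy annotation stream (illustrative, not from a real record)
symbols = ['N', 'A', '+', 'V', 'L', 'F', '/', '~', 'E']
print(abnormal_aami_labels(symbols))           # ['S', 'V', 'F', 'Q', 'V']
print(Counter(abnormal_aami_labels(symbols)))
```
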

## 0) Google Colab Setup

In [None]:
# ============================================================
# GOOGLE COLAB SETUP
# ============================================================

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install -q wfdb xgboost imbalanced-learn

print('\n✅ Colab setup complete!')

## 1) Imports & Configuration

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import os
import json
import warnings
from pathlib import Path
from collections import Counter
from time import time

import numpy as np
np.random.seed(42)
import pandas as pd

import wfdb

from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, roc_auc_score, roc_curve
)

# ML Models
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
import xgboost as xgb

# SMOTE for oversampling
from imblearn.over_sampling import SMOTE

import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

print('✅ All imports successful!')

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

SEED = 42
np.random.seed(SEED)

# Paths - UPDATE FOR YOUR DRIVE LOCATION
DATASET_PATH = Path('/content/drive/MyDrive/ecg2.0')
OUTPUT_PATH = Path('/content/drive/MyDrive/ecg2.0/outputs_ml_tester')
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Beat extraction parameters
SAMPLES_BEFORE = 100
SAMPLES_AFTER = 150
BEAT_LENGTH = SAMPLES_BEFORE + SAMPLES_AFTER

# K-Fold parameters
N_FOLDS = 5

# AAMI Mapping
AAMI_MAP = {
    'N': 'N', 'L': 'N', 'R': 'N', 'e': 'N', 'j': 'N',
    'A': 'S', 'a': 'S', 'J': 'S', 'S': 'S',
    'V': 'V', 'E': 'V',
    'F': 'F',
    '/': 'Q', 'f': 'Q', '!': 'Q', 'Q': 'Q', 'P': 'Q'
}

# Abnormal classes only (no Normal)
ABNORMAL_CLASSES = ['S', 'V', 'F', 'Q']
AAMI_NAMES = {
    'S': 'Supraventricular', 'V': 'Ventricular',
    'F': 'Fusion', 'Q': 'Unknown/Paced'
}

print(f'Dataset path: {DATASET_PATH}')
print(f'Output path: {OUTPUT_PATH}')
print(f'Abnormal classes: {ABNORMAL_CLASSES}')

## 2) Data Loading & Beat Extraction

In [None]:
# ============================================================
# DATA LOADING
# ============================================================

def find_records(dataset_path):
    """Find all valid MIT-BIH records."""
    dataset_path = Path(dataset_path)
    hea_files = list(dataset_path.rglob('*.hea'))
    records = []
    for hea_file in hea_files:
        record_path = str(hea_file.with_suffix(''))
        if hea_file.with_suffix('.dat').exists() and hea_file.with_suffix('.atr').exists():
            records.append(record_path)
    return sorted(records)

def load_record(record_path):
    """Load a single MIT-BIH record."""
    try:
        record = wfdb.rdrecord(record_path)
        annotation = wfdb.rdann(record_path, 'atr')
        return {
            'record_id': Path(record_path).stem,
            'signals': record.p_signal,
            'fs': record.fs,
            'ann_samples': annotation.sample,
            'ann_symbols': annotation.symbol
        }
    except Exception as e:
        print(f'Error loading {record_path}: {e}')
        return None

print('Loading MIT-BIH records...')
record_paths = find_records(DATASET_PATH)
print(f'Found {len(record_paths)} records')

records_data = []
for i, rp in enumerate(record_paths):
    data = load_record(rp)
    if data:
        records_data.append(data)
    if (i + 1) % 20 == 0:
        print(f'  Loaded {i + 1}/{len(record_paths)}...')

print(f'\n✅ Loaded {len(records_data)} records')

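To see exactly what `find_records` keeps, the sketch below re-creates the same `.hea`/`.dat`/`.atr` check against a throwaway directory. The record names `100` and `101` are illustrative, not read from `DATASET_PATH`. A header whose annotation file is missing is skipped silently, which is worth knowing if the record count printed above is lower than expected.

```python
import tempfile
from pathlib import Path

def find_records(dataset_path):
    """Return record paths (without suffix) that have .hea, .dat and .atr files."""
    dataset_path = Path(dataset_path)
    records = []
    for hea_file in dataset_path.rglob('*.hea'):
        if hea_file.with_suffix('.dat').exists() and hea_file.with_suffix('.atr').exists():
            records.append(str(hea_file.with_suffix('')))
    return sorted(records)

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for ext in ('.hea', '.dat', '.atr'):   # complete triplet -> kept
        (root / f'100{ext}').touch()
    for ext in ('.hea', '.dat'):           # no .atr annotation file -> skipped
        (root / f'101{ext}').touch()
    found = [Path(p).name for p in find_records(root)]
    print(found)  # ['100']
```
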
In [None]:
# ============================================================
# BEAT EXTRACTION (ABNORMAL ONLY)
# ============================================================

def extract_abnormal_beats(record_data, samples_before=100, samples_after=150, channel=0):
    """Extract only ABNORMAL beats from a record."""
    signals = record_data['signals']
    ann_samples = record_data['ann_samples']
    ann_symbols = record_data['ann_symbols']
    record_id = record_data['record_id']
    fs = record_data['fs']
    signal_length = signals.shape[0]
    beat_length = samples_before + samples_after
    
    beats, labels, record_ids = [], [], []
    rr_before_list, rr_after_list = [], []
    
    # Indices of beat annotations only (symbols in AAMI_MAP); non-beat markers
    # such as rhythm changes '+' or noise '~' must not act as RR neighbours
    beat_idx = [k for k, s in enumerate(ann_symbols) if s in AAMI_MAP]
    
    for j, k in enumerate(beat_idx):
        sample, symbol = ann_samples[k], ann_symbols[k]
        aami_class = AAMI_MAP[symbol]
        
        # Skip Normal beats - we only want abnormal
        if aami_class == 'N':
            continue
        
        start, end = sample - samples_before, sample + samples_after
        if start < 0 or end > signal_length:
            continue
        
        beat = signals[start:end, channel]
        if len(beat) != beat_length:
            continue
        
        # RR intervals (s) to the previous/next beat; 0.8 s fallback at record edges
        rr_b = (sample - ann_samples[beat_idx[j-1]]) / fs if j > 0 else 0.8
        rr_a = (ann_samples[beat_idx[j+1]] - sample) / fs if j < len(beat_idx) - 1 else 0.8
        
        beats.append(beat)
        labels.append(aami_class)
        record_ids.append(record_id)
        rr_before_list.append(rr_b)
        rr_after_list.append(rr_a)
    
    return beats, labels, record_ids, rr_before_list, rr_after_list

print('Extracting ABNORMAL beats only...')
all_beats, all_labels, all_record_ids = [], [], []
all_rr_before, all_rr_after = [], []

for i, record in enumerate(records_data):
    beats, labels, rids, rr_b, rr_a = extract_abnormal_beats(
        record, SAMPLES_BEFORE, SAMPLES_AFTER
    )
    all_beats.extend(beats)
    all_labels.extend(labels)
    all_record_ids.extend(rids)
    all_rr_before.extend(rr_b)
    all_rr_after.extend(rr_a)
    if (i + 1) % 20 == 0:
        print(f'  Processed {i + 1}/{len(records_data)}...')

X_abnormal = np.array(all_beats, dtype=np.float32)
y_abnormal = np.array(all_labels)
record_ids_abnormal = np.array(all_record_ids)
rr_before = np.array(all_rr_before, dtype=np.float32)
rr_after = np.array(all_rr_after, dtype=np.float32)

print(f'\n✅ Extracted {len(X_abnormal):,} ABNORMAL beats')
print(f'X_abnormal shape: {X_abnormal.shape}')

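The windowing and RR logic can be checked in isolation on a synthetic signal. This sketch assumes a flat signal and a hand-made annotation stream. `window_beats` is a hypothetical helper that keeps every beat class (the notebook additionally drops N). It takes RR neighbours from beat annotations only, so the `+` marker at sample 500 does not shorten the RR interval after the V beat. Beats whose 250-sample window would run past either end of the signal are dropped.

```python
import numpy as np

def window_beats(signal, ann_samples, ann_symbols, beat_symbols, fs,
                 before=100, after=150, default_rr=0.8):
    """Cut fixed windows around annotated beats and compute RR intervals in seconds."""
    beat_pos = [k for k, s in enumerate(ann_symbols) if s in beat_symbols]
    windows, rr_before, rr_after = [], [], []
    for j, k in enumerate(beat_pos):
        r = ann_samples[k]
        start, end = r - before, r + after
        if start < 0 or end > len(signal):
            continue  # window would run off the signal
        prev_r = ann_samples[beat_pos[j - 1]] if j > 0 else None
        next_r = ann_samples[beat_pos[j + 1]] if j + 1 < len(beat_pos) else None
        rr_before.append((r - prev_r) / fs if prev_r is not None else default_rr)
        rr_after.append((next_r - r) / fs if next_r is not None else default_rr)
        windows.append(signal[start:end])
    return np.array(windows), np.array(rr_before), np.array(rr_after)

fs = 360
signal = np.zeros(2000, dtype=np.float32)
ann_samples = [50, 400, 500, 760, 1100, 1950]
ann_symbols = ['N', 'V', '+', 'A', 'N', 'N']   # '+' is a non-beat rhythm marker
W, rrb, rra = window_beats(signal, ann_samples, ann_symbols, {'N', 'V', 'A'}, fs)
print(W.shape)  # (3, 250): first and last beats are too close to the edges
print(rra)      # V beat: 1.0 s to the A beat, not 0.28 s to the '+' marker
```
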
In [None]:
# ============================================================
# CLASS DISTRIBUTION
# ============================================================

print('=' * 60)
print('ABNORMAL CLASS DISTRIBUTION')
print('=' * 60)

counts = Counter(y_abnormal)
total = len(y_abnormal)

df_dist = pd.DataFrame([
    {'Class': cls, 'Name': AAMI_NAMES[cls], 'Count': counts.get(cls, 0), 
     'Percentage': f"{100*counts.get(cls,0)/total:.2f}%"}
    for cls in ABNORMAL_CLASSES
])
print(df_dist.to_string(index=False))

# Visualize
fig, ax = plt.subplots(figsize=(8, 5))
colors = ['#e74c3c', '#3498db', '#9b59b6', '#f39c12']
bars = ax.bar(df_dist['Class'], df_dist['Count'], color=colors, alpha=0.8)
for bar, count in zip(bars, df_dist['Count']):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 50,
            f'{count:,}', ha='center', va='bottom', fontsize=11, fontweight='bold')
ax.set_xlabel('Abnormal Class')
ax.set_ylabel('Count')
ax.set_title('Distribution of Abnormal ECG Beat Classes', fontweight='bold')
plt.tight_layout()
plt.show()

print('\nNote: classes are imbalanced; SMOTE (training data only) and class weighting are applied below.')

## 3) Feature Engineering

In [None]:
# ============================================================
# FEATURE ENGINEERING
# ============================================================

def extract_features(beat, rr_b, rr_a, fs=360):
    """
    Extract engineered features from a single beat.
    
    Features include:
    - Statistical: mean, std, min, max, range, median, skewness, kurtosis
    - Morphological: R-peak amplitude, QRS-window energy/extrema, pre-/post-R stats
    - RR intervals: before, after, ratio, difference, average of the two
    - Derivative, energy and zero-crossing features
    (No frequency-domain features are computed; `fs` is currently unused.)
    """
    features = {}
    
    # --- Statistical Features ---
    features['mean'] = np.mean(beat)
    features['std'] = np.std(beat)
    features['min'] = np.min(beat)
    features['max'] = np.max(beat)
    features['range'] = np.max(beat) - np.min(beat)
    features['median'] = np.median(beat)
    
    # Skewness and kurtosis
    centered = beat - np.mean(beat)
    std = np.std(beat)
    if std > 0:
        features['skewness'] = np.mean((centered / std) ** 3)
        features['kurtosis'] = np.mean((centered / std) ** 4) - 3
    else:
        features['skewness'] = 0
        features['kurtosis'] = 0
    
    # --- Morphological Features ---
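    # R-peak is at index SAMPLES_BEFORE (the annotation sample), which is not the
    # window centre because SAMPLES_BEFORE != SAMPLES_AFTER
    r_peak_idx = SAMPLES_BEFORE
    features['r_amplitude'] = beat[r_peak_idx]
    features['r_peak_pos'] = r_peak_idx  # same value for every beat, so carries no information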
    
    # Fixed QRS window: 20 samples either side of the R-peak (~111 ms at 360 Hz)
    qrs_start = max(0, r_peak_idx - 20)
    qrs_end = min(len(beat), r_peak_idx + 20)
    qrs_segment = beat[qrs_start:qrs_end]
    features['qrs_energy'] = np.sum(qrs_segment ** 2)
    features['qrs_duration'] = qrs_end - qrs_start  # window width (40 samples), not a measured QRS duration
    features['qrs_max'] = np.max(qrs_segment)
    features['qrs_min'] = np.min(qrs_segment)
    features['qrs_range'] = features['qrs_max'] - features['qrs_min']
    
    # Pre-R and Post-R segments
    pre_r = beat[:r_peak_idx]
    post_r = beat[r_peak_idx:]
    features['pre_r_mean'] = np.mean(pre_r)
    features['post_r_mean'] = np.mean(post_r)
    features['pre_r_std'] = np.std(pre_r)
    features['post_r_std'] = np.std(post_r)
    
    # --- RR Interval Features ---
    features['rr_before'] = rr_b
    features['rr_after'] = rr_a
    features['rr_ratio'] = rr_b / rr_a if rr_a > 0 else 1.0
    features['rr_diff'] = rr_b - rr_a
    features['rr_avg'] = (rr_b + rr_a) / 2
    
    # --- Derivative Features ---
    diff1 = np.diff(beat)
    features['diff_max'] = np.max(diff1)
    features['diff_min'] = np.min(diff1)
    features['diff_std'] = np.std(diff1)
    
    # --- Energy Features ---
    features['total_energy'] = np.sum(beat ** 2)
    features['normalized_energy'] = np.sum(beat ** 2) / len(beat)
    
    # --- Zero Crossings ---
    zero_crossings = np.sum(np.abs(np.diff(np.sign(beat - np.mean(beat)))) > 0)
    features['zero_crossings'] = zero_crossings
    
    return features

# Extract features for all beats
print('Extracting features from all abnormal beats...')
feature_list = []
for i in range(len(X_abnormal)):
    feats = extract_features(X_abnormal[i], rr_before[i], rr_after[i])
    feature_list.append(feats)
    if (i + 1) % 2000 == 0:
        print(f'  Processed {i + 1:,}/{len(X_abnormal):,}...')

# Convert to DataFrame
df_features = pd.DataFrame(feature_list)
feature_cols = df_features.columns.tolist()

print(f'\n✅ Extracted {len(feature_cols)} features per beat')
print(f'Feature matrix shape: {df_features.shape}')
print(f'\nFeatures: {feature_cols}')

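As a sanity check on the moment and zero-crossing features, the formulas used in `extract_features` can be applied to a signal whose answers are known. A pure sine sampled over whole periods has skewness 0, excess kurtosis -1.5, and two zero crossings per period. The 5 Hz tone and the 0.3 rad phase offset (which keeps samples away from exact zeros) are arbitrary choices for this sketch.

```python
import numpy as np

def moment_features(x):
    """Population skewness and excess kurtosis, using the same formulas as extract_features."""
    x = np.asarray(x, dtype=float)
    std = x.std()
    if std == 0:
        return 0.0, 0.0
    z = (x - x.mean()) / std
    return float(np.mean(z ** 3)), float(np.mean(z ** 4) - 3)

def zero_crossings(x):
    """Number of sign changes of the mean-removed signal (as in extract_features)."""
    x = np.asarray(x, dtype=float)
    return int(np.sum(np.abs(np.diff(np.sign(x - x.mean()))) > 0))

t = np.arange(360) / 360.0                 # 1 s at 360 Hz
sine = np.sin(2 * np.pi * 5 * t + 0.3)     # 5 whole periods, phase-shifted
skew, kurt = moment_features(sine)
print(f'skewness={skew:.6f}, excess kurtosis={kurt:.6f}')
print('zero crossings:', zero_crossings(sine))  # 10 (two per period)
```
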
## 4) Patient-Wise Data Split

In [None]:
# ============================================================
# PATIENT-WISE DATA SPLIT
# ============================================================

def patient_wise_split(X, y, record_ids, test_size=0.2, seed=42):
    """Split data ensuring no patient appears in both train and test."""
    unique_pids = np.unique(record_ids)
    np.random.seed(seed)
    np.random.shuffle(unique_pids)
    
    n_test = int(len(unique_pids) * test_size)
    test_pids = set(unique_pids[:n_test])
    train_pids = set(unique_pids[n_test:])
    
    test_mask = np.array([pid in test_pids for pid in record_ids])
    train_mask = ~test_mask
    
    return (X[train_mask], X[test_mask], y[train_mask], y[test_mask],
            record_ids[train_mask], record_ids[test_mask], train_pids, test_pids)

# Prepare feature matrix
X = df_features.values
y = y_abnormal

# Split
(X_train, X_test, y_train, y_test, 
 rids_train, rids_test, train_pids, test_pids) = patient_wise_split(
    X, y, record_ids_abnormal, test_size=0.2
)

print('=' * 60)
print('PATIENT-WISE DATA SPLIT')
print('=' * 60)
print(f'Train: {len(X_train):,} beats from {len(train_pids)} patients')
print(f'Test:  {len(X_test):,} beats from {len(test_pids)} patients')

# Class distribution in splits
print('\nTrain class distribution:')
for cls in ABNORMAL_CLASSES:
    c = np.sum(y_train == cls)
    print(f'  {cls}: {c:,} ({100*c/len(y_train):.1f}%)')

print('\nTest class distribution:')
for cls in ABNORMAL_CLASSES:
    c = np.sum(y_test == cls)
    print(f'  {cls}: {c:,} ({100*c/len(y_test):.1f}%)')

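A vectorised sketch of the same record-level split. It swaps in NumPy's `default_rng` and `np.isin` for the global seed and list comprehension used above, so it will not reproduce the notebook's exact shuffle. The property that matters is the final disjointness check: no record ID appears on both sides. The record IDs here are synthetic. Because rare classes such as paced (Q) beats come from only a handful of records, the per-class test counts printed above are worth checking after any split.

```python
import numpy as np

def record_wise_masks(record_ids, test_size=0.2, seed=42):
    """Boolean train/test masks such that every record lands entirely on one side."""
    rng = np.random.default_rng(seed)
    unique_ids = np.unique(record_ids)
    rng.shuffle(unique_ids)
    n_test = int(len(unique_ids) * test_size)
    test_mask = np.isin(record_ids, unique_ids[:n_test])
    return ~test_mask, test_mask

# Synthetic record IDs: 10 records with 3 beats each
record_ids = np.repeat([f'r{i:02d}' for i in range(10)], 3)
train_mask, test_mask = record_wise_masks(record_ids)
train_set, test_set = set(record_ids[train_mask]), set(record_ids[test_mask])
print(len(train_set), len(test_set))    # 8 2
print(train_set.isdisjoint(test_set))   # True
```
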
In [None]:
# ============================================================
# PREPROCESSING
# ============================================================

# Standardize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Encode labels
le = LabelEncoder()
le.fit(ABNORMAL_CLASSES)
y_train_enc = le.transform(y_train)
y_test_enc = le.transform(y_test)

# Compute class weights (for reference; get_models uses class_weight='balanced' directly)
class_weights = compute_class_weight('balanced', classes=np.unique(y_train_enc), y=y_train_enc)
class_weight_dict = dict(enumerate(class_weights))

print('✅ Preprocessing complete')
print(f'Classes: {le.classes_}')
print(f'Class weights: {class_weight_dict}')

# Create groups for K-Fold
pid_to_group = {pid: i for i, pid in enumerate(train_pids)}
groups_train = np.array([pid_to_group[rid] for rid in rids_train])

# ============================================================
# SMOTE FOR CLASS BALANCING (Training data only)
# ============================================================
print('\n' + '=' * 60)
print('APPLYING SMOTE FOR CLASS BALANCING')
print('=' * 60)

print('\nBefore SMOTE:')
for i, cls in enumerate(le.classes_):
    c = np.sum(y_train_enc == i)
    print(f'  {cls}: {c:,}')

# Apply SMOTE to training data
smote = SMOTE(random_state=SEED, k_neighbors=5)
X_train_smote, y_train_smote = smote.fit_resample(X_train_scaled, y_train_enc)

print('\nAfter SMOTE:')
for i, cls in enumerate(le.classes_):
    c = np.sum(y_train_smote == i)
    print(f'  {cls}: {c:,}')

print(f'\n✅ SMOTE applied: {len(X_train_scaled):,} → {len(X_train_smote):,} samples')
print('   (Test set remains unchanged)')

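SMOTE itself is simple to state. Each synthetic sample is x + λ(x_nn − x), where x is a minority sample, x_nn is one of its k nearest minority-class neighbours, and λ ~ U(0, 1). The NumPy sketch below implements only that rule for a single class, using brute-force distances. It shows why SMOTE must be fit on training rows only: every synthetic point lies on a segment between two real training points. This is not imblearn's implementation; `smote_sketch` and the toy points are hypothetical.

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic rows for ONE minority class by SMOTE-style interpolation."""
    X_min = np.asarray(X_min, dtype=float)
    if len(X_min) < 2:
        raise ValueError('need at least two minority samples to interpolate')
    rng = np.random.default_rng(seed)
    k = min(k, len(X_min) - 1)                        # cannot use more neighbours than exist
    # Brute-force pairwise distances within the minority class
    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                       # a point is not its own neighbour
    nn = np.argsort(d, axis=1)[:, :k]                 # k nearest neighbours of every row
    base = rng.integers(0, len(X_min), size=n_new)    # real sample to start from
    neigh = nn[base, rng.integers(0, k, size=n_new)]  # one random neighbour of it
    lam = rng.random((n_new, 1))                      # interpolation factor in [0, 1)
    return X_min[base] + lam * (X_min[neigh] - X_min[base])

X_min = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
synthetic = smote_sketch(X_min, n_new=6, k=2)
print(synthetic.shape)  # (6, 2); all points stay inside the unit square
```
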
## 5) Model Training & Comparison

In [None]:
# ============================================================
# MODEL DEFINITIONS
# ============================================================

def get_models(class_weight_dict, seed=42):
    """
    Return dictionary of models to compare.
    Easy to extend with more models or hyperparameters.
    
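    Note: class_weight_dict is accepted but not currently used; models that
    support it use class_weight='balanced'. After SMOTE equalizes class counts
    these weights are all 1, so they only matter when SMOTE is skipped.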
    """
    models = {
        'Random Forest': RandomForestClassifier(
            n_estimators=200,
            max_depth=20,
            min_samples_split=5,
            class_weight='balanced',
            random_state=seed,
            n_jobs=-1
        ),
        'XGBoost': xgb.XGBClassifier(
            n_estimators=200,
            max_depth=8,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=seed,
            eval_metric='mlogloss',
            n_jobs=-1
        ),
        'AdaBoost': AdaBoostClassifier(
            estimator=DecisionTreeClassifier(max_depth=6, random_state=seed),
            n_estimators=100,
            learning_rate=0.5,
            algorithm='SAMME',
            random_state=seed
        ),
        'SVM (RBF)': SVC(
            kernel='rbf',
            C=10,
            gamma='scale',
            class_weight='balanced',
            probability=True,  # For ROC-AUC
            random_state=seed
        ),
        'Logistic Regression': LogisticRegression(
            solver='lbfgs',
            max_iter=1000,
            class_weight='balanced',
            random_state=seed,
            n_jobs=-1
        )
    }
    return models

print('Model configurations:')
models = get_models(class_weight_dict)
for name, model in models.items():
    print(f'  • {name}')

In [None]:
# ============================================================
# K-FOLD CROSS-VALIDATION (with SMOTE per fold)
# ============================================================

def evaluate_model(y_true, y_pred, y_proba=None, class_labels=None):
    """Compute evaluation metrics."""
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, average='macro', labels=class_labels, zero_division=0),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted', labels=class_labels, zero_division=0),
        'per_class_f1': f1_score(y_true, y_pred, average=None, labels=class_labels, zero_division=0)
    }
    
    # ROC-AUC if probabilities available
    if y_proba is not None and len(np.unique(y_true)) > 1:
        try:
            metrics['roc_auc'] = roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
        except ValueError:  # e.g. a class absent from y_true in this fold
            metrics['roc_auc'] = np.nan
    else:
        metrics['roc_auc'] = np.nan
    
    return metrics

# K-Fold CV
sgkf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
all_labels = list(range(len(ABNORMAL_CLASSES)))

cv_results = {name: {'metrics': []} for name in get_models(class_weight_dict).keys()}

print('=' * 70)
print(f'{N_FOLDS}-FOLD CROSS-VALIDATION (with SMOTE per fold)')
print('=' * 70)

for fold, (train_idx, val_idx) in enumerate(sgkf.split(X_train_scaled, y_train_enc, groups_train)):
    print(f'\n--- FOLD {fold+1}/{N_FOLDS} ---')
    
    # Get fold data (original, not SMOTE'd)
    X_tr_fold, X_vl = X_train_scaled[train_idx], X_train_scaled[val_idx]
    y_tr_fold, y_vl = y_train_enc[train_idx], y_train_enc[val_idx]
    
    # Apply SMOTE only to training fold (validation stays original)
    smote_fold = SMOTE(random_state=SEED, k_neighbors=min(5, min(Counter(y_tr_fold).values()) - 1))
    try:
        X_tr_smote, y_tr_smote = smote_fold.fit_resample(X_tr_fold, y_tr_fold)
    except ValueError:
        # If SMOTE fails (not enough samples), use original
        X_tr_smote, y_tr_smote = X_tr_fold, y_tr_fold
    
    print(f'  Train: {len(X_tr_fold)} → {len(X_tr_smote)} (SMOTE), Val: {len(X_vl)}')
    
    for name, model in get_models(class_weight_dict).items():
        t0 = time()
        
        # Train on SMOTE-balanced data
        model.fit(X_tr_smote, y_tr_smote)
        
        # Predict on original validation data
        y_pred = model.predict(X_vl)
        y_proba = model.predict_proba(X_vl) if hasattr(model, 'predict_proba') else None
        
        # Evaluate
        metrics = evaluate_model(y_vl, y_pred, y_proba, all_labels)
        metrics['train_time'] = time() - t0
        cv_results[name]['metrics'].append(metrics)
        
        print(f'  {name:20s}: Macro F1={metrics["macro_f1"]:.4f}, Acc={metrics["accuracy"]:.4f}')

print('\n' + '=' * 70)

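The per-fold `k_neighbors` expression above follows from how SMOTE finds neighbours: each class it oversamples needs at least `k_neighbors + 1` samples. A small helper (hypothetical, not part of imblearn) makes the rule and the skip condition explicit instead of relying on the `ValueError` fallback.

```python
from collections import Counter

def safe_k_neighbors(y, default_k=5):
    """Largest usable SMOTE k_neighbors for labels y, or None if SMOTE should be skipped.

    A class with n samples supports at most n - 1 neighbours, so k is capped at
    (smallest class count - 1); below 1 there is nothing to interpolate with.
    """
    smallest = min(Counter(y).values())
    k = min(default_k, smallest - 1)
    return k if k >= 1 else None

print(safe_k_neighbors([0] * 50 + [1] * 10))  # 5
print(safe_k_neighbors([0] * 50 + [1] * 3))   # 2
print(safe_k_neighbors([0] * 50 + [1] * 1))   # None
```

In the fold loop, `SMOTE(k_neighbors=k)` would then be built only when `k` is not `None`.
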
In [None]:
# ============================================================
# CROSS-VALIDATION SUMMARY
# ============================================================

print('\n' + '=' * 70)
print('CROSS-VALIDATION RESULTS SUMMARY')
print('=' * 70)

cv_summary = []
for name, results in cv_results.items():
    metrics_list = results['metrics']
    
    row = {
        'Model': name,
        'Accuracy': f"{np.mean([m['accuracy'] for m in metrics_list]):.4f} ± {np.std([m['accuracy'] for m in metrics_list]):.4f}",
        'Macro F1': f"{np.mean([m['macro_f1'] for m in metrics_list]):.4f} ± {np.std([m['macro_f1'] for m in metrics_list]):.4f}",
        'Weighted F1': f"{np.mean([m['weighted_f1'] for m in metrics_list]):.4f} ± {np.std([m['weighted_f1'] for m in metrics_list]):.4f}",
        'ROC-AUC': f"{np.nanmean([m['roc_auc'] for m in metrics_list]):.4f}",
        'Avg Time (s)': f"{np.mean([m['train_time'] for m in metrics_list]):.2f}"
    }
    cv_summary.append(row)

df_cv = pd.DataFrame(cv_summary)
print(df_cv.to_string(index=False))

# Per-class F1 for each model
print('\n' + '-' * 70)
print('PER-CLASS F1 (Mean ± Std)')
print('-' * 70)

for name, results in cv_results.items():
    pcf1 = np.stack([m['per_class_f1'] for m in results['metrics']])
    print(f'\n{name}:')
    for i, cls in enumerate(le.classes_):
        print(f'  {cls} ({AAMI_NAMES[cls]:15s}): {np.mean(pcf1[:, i]):.4f} ¬± {np.std(pcf1[:, i]):.4f}')

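Macro F1, the main ranking metric here, is the unweighted mean of per-class F1, so the rare F and Q classes count as much as V. The NumPy-only recomputation below assumes integer-encoded labels `0..n_classes-1`, as produced by `le`, and reproduces the quantity `evaluate_model` reports. Per-class F1 is 2TP / (2TP + FP + FN), set to 0 when that denominator is 0 (matching `zero_division=0`). The toy labels are made up.

```python
import numpy as np

def per_class_f1(y_true, y_pred, n_classes):
    """Per-class F1 from a confusion matrix; undefined cases are reported as 0."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # rows = true class, columns = predicted class
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = 2 * tp + fp + fn
    return np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)

y_true = np.array([0, 0, 1, 1, 2, 3])
y_pred = np.array([0, 1, 1, 1, 2, 2])
f1 = per_class_f1(y_true, y_pred, 4)
print(f1)          # class 3 is never predicted, so its F1 is 0
print(f1.mean())   # macro F1: unweighted mean over the 4 classes (8/15)
```
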
## 6) Final Evaluation on Test Set

In [None]:
# ============================================================
# TRAIN FINAL MODELS ON SMOTE-BALANCED TRAINING SET
# ============================================================

print('Training final models on SMOTE-balanced training set...')
print(f'Training samples: {len(X_train_smote):,} (after SMOTE)')

final_models = {}

for name, model in get_models(class_weight_dict).items():
    print(f'  Training {name}...')
    model.fit(X_train_smote, y_train_smote)
    final_models[name] = model

print('\n✅ All models trained on SMOTE-balanced data!')

In [None]:
# ============================================================
# EVALUATE ON HELD-OUT TEST SET
# ============================================================

print('=' * 70)
print('TEST SET EVALUATION')
print('=' * 70)

test_results = {}

for name, model in final_models.items():
    y_pred = model.predict(X_test_scaled)
    y_proba = model.predict_proba(X_test_scaled) if hasattr(model, 'predict_proba') else None
    
    metrics = evaluate_model(y_test_enc, y_pred, y_proba, all_labels)
    metrics['y_pred'] = y_pred
    metrics['y_proba'] = y_proba
    test_results[name] = metrics

# Summary table
test_summary = []
for name, metrics in test_results.items():
    row = {
        'Model': name,
        'Accuracy': f"{metrics['accuracy']:.4f}",
        'Macro F1': f"{metrics['macro_f1']:.4f}",
        'Weighted F1': f"{metrics['weighted_f1']:.4f}",
        'ROC-AUC': f"{metrics['roc_auc']:.4f}" if not np.isnan(metrics['roc_auc']) else 'N/A'
    }
    test_summary.append(row)

df_test = pd.DataFrame(test_summary)
print('\nüìä SIDE-BY-SIDE MODEL COMPARISON')
print(df_test.to_string(index=False))

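# Select the best model by mean CV Macro F1 (not by test score, to avoid tuning on the test set)
best_name = max(cv_results, key=lambda n: np.mean([m['macro_f1'] for m in cv_results[n]['metrics']]))
best_model = (best_name, test_results[best_name])
print(f'\n🏆 Best Model (by mean CV Macro F1): {best_model[0]} (test Macro F1: {best_model[1]["macro_f1"]:.4f})')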

In [None]:
# ============================================================
# PER-CLASS F1 COMPARISON
# ============================================================

print('\n' + '=' * 70)
print('PER-CLASS F1 SCORES (Test Set)')
print('=' * 70)

# Create comparison table
pcf1_data = {'Class': [], 'Name': []}
for name in test_results.keys():
    pcf1_data[name] = []

for i, cls in enumerate(le.classes_):
    pcf1_data['Class'].append(cls)
    pcf1_data['Name'].append(AAMI_NAMES[cls])
    for name, metrics in test_results.items():
        pcf1_data[name].append(f"{metrics['per_class_f1'][i]:.4f}")

df_pcf1 = pd.DataFrame(pcf1_data)
print(df_pcf1.to_string(index=False))

## 7) Visualization

In [None]:
# ============================================================
# CONFUSION MATRICES
# ============================================================

n_models = len(test_results)
n_cols = 3
n_rows = (n_models + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
axes = axes.flatten()

for idx, (name, metrics) in enumerate(test_results.items()):
    cm = confusion_matrix(y_test_enc, metrics['y_pred'], labels=all_labels)
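    row_sums = cm.sum(axis=1, keepdims=True)
    # Row-normalize; a class with no test beats gets a zero row instead of NaN
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape), where=row_sums > 0)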
    
    sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[idx])
    axes[idx].set_title(f'{name}\nMacro F1: {metrics["macro_f1"]:.4f}', fontsize=11, fontweight='bold')
    axes[idx].set_xlabel('Predicted')
    axes[idx].set_ylabel('True')

# Hide empty subplots
for idx in range(len(test_results), len(axes)):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'ml_confusion_matrices.png', dpi=150)
plt.show()

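Each row of the normalized confusion matrix is divided by the number of true beats of that class, so the diagonal is per-class recall. The sketch below (helper name and toy matrix are illustrative) shows that normalization with a guard for a class that has no test beats, which would otherwise yield a row of NaN.

```python
import numpy as np

def row_normalize(cm):
    """Divide each row by its total; all-zero rows stay zero instead of becoming NaN."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)

cm = np.array([[8, 2, 0],
               [0, 0, 0],   # class with no beats in the test set
               [1, 1, 2]])
cm_norm = row_normalize(cm)
print(cm_norm)
print('per-class recall:', np.diag(cm_norm))  # diagonal values 0.8, 0.0, 0.5
```
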
In [None]:
# ============================================================
# MODEL COMPARISON BAR CHART
# ============================================================

fig, ax = plt.subplots(figsize=(12, 6))

model_names = list(test_results.keys())
x = np.arange(len(model_names))
width = 0.25

acc = [test_results[n]['accuracy'] for n in model_names]
macro = [test_results[n]['macro_f1'] for n in model_names]
weighted = [test_results[n]['weighted_f1'] for n in model_names]

bars1 = ax.bar(x - width, acc, width, label='Accuracy', color='#3498db', alpha=0.8)
bars2 = ax.bar(x, macro, width, label='Macro F1', color='#e74c3c', alpha=0.8)
bars3 = ax.bar(x + width, weighted, width, label='Weighted F1', color='#2ecc71', alpha=0.8)

ax.set_ylabel('Score')
ax.set_title('Model Comparison on Test Set', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(model_names, rotation=15, ha='right')
ax.legend()
ax.set_ylim(0, 1.1)
ax.grid(True, alpha=0.3, axis='y')

# Add value labels
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                f'{bar.get_height():.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'ml_model_comparison.png', dpi=150)
plt.show()

In [None]:
# ============================================================
# PER-CLASS F1 COMPARISON HEATMAP
# ============================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Build matrix
pcf1_matrix = np.array([
    test_results[name]['per_class_f1'] for name in model_names
])

sns.heatmap(pcf1_matrix, annot=True, fmt='.3f', cmap='RdYlGn',
            xticklabels=[f"{c}\n({AAMI_NAMES[c]})" for c in le.classes_],
            yticklabels=model_names, ax=ax, vmin=0, vmax=1)
ax.set_title('Per-Class F1 Scores by Model', fontsize=14, fontweight='bold')
ax.set_xlabel('Abnormal Class')
ax.set_ylabel('Model')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'ml_perclass_f1_heatmap.png', dpi=150)
plt.show()

## 8) Feature Importance (Tree-Based Models)

In [None]:
# ============================================================
# FEATURE IMPORTANCE (Tree-Based Models)
# ============================================================

tree_models = ['Random Forest', 'XGBoost', 'AdaBoost']
fig, axes = plt.subplots(1, 3, figsize=(18, 7))

for idx, name in enumerate(tree_models):
    model = final_models[name]
    importances = model.feature_importances_
    
    # Sort by importance
    indices = np.argsort(importances)[::-1][:15]  # Top 15
    top_features = [feature_cols[i] for i in indices]
    top_importances = importances[indices]
    
    # Plot
    colors = plt.cm.viridis(np.linspace(0.8, 0.2, len(top_features)))
    axes[idx].barh(range(len(top_features)), top_importances, color=colors)
    axes[idx].set_yticks(range(len(top_features)))
    axes[idx].set_yticklabels(top_features)
    axes[idx].invert_yaxis()
    axes[idx].set_xlabel('Importance')
    axes[idx].set_title(f'{name} - Top 15 Features', fontsize=12, fontweight='bold')
    axes[idx].grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'ml_feature_importance.png', dpi=150)
plt.show()

# Print top features for each tree model
for name in tree_models:
    print(f'\nTop 10 Features ({name}):')
    imp = final_models[name].feature_importances_
    for i in np.argsort(imp)[::-1][:10]:
        print(f'  {feature_cols[i]:20s}: {imp[i]:.4f}')

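The top-k selection above is a descending `argsort`. The tiny sketch below uses made-up importance values (not results from this notebook) to make the ordering explicit. Impurity-based importances such as the Random Forest's are derived from training data and can favour continuous, high-cardinality features, so they are best read as a rough ranking rather than a causal measure.

```python
import numpy as np

def top_k_features(importances, names, k=3):
    """Return (name, importance) pairs sorted by decreasing importance."""
    importances = np.asarray(importances)
    order = np.argsort(importances)[::-1][:k]
    return [(names[i], float(importances[i])) for i in order]

names = ['rr_before', 'qrs_range', 'mean', 'rr_ratio']
imp = [0.30, 0.15, 0.05, 0.50]   # illustrative values only
print(top_k_features(imp, names, k=3))
# [('rr_ratio', 0.5), ('rr_before', 0.3), ('qrs_range', 0.15)]
```
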
## 9) ROC Curves (One-vs-Rest)

In [None]:
# ============================================================
# ROC CURVES (One-vs-Rest)
# ============================================================

n_models = len(test_results)
n_cols = 3
n_rows = (n_models + n_cols - 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
axes = axes.flatten()

colors = ['#e74c3c', '#3498db', '#9b59b6', '#f39c12']

for idx, (name, metrics) in enumerate(test_results.items()):
    ax = axes[idx]
    y_proba = metrics['y_proba']
    
    if y_proba is not None:
        for i, cls in enumerate(le.classes_):
            y_true_binary = (y_test_enc == i).astype(int)
            if len(np.unique(y_true_binary)) < 2:
                continue
            
            fpr, tpr, _ = roc_curve(y_true_binary, y_proba[:, i])
            auc = roc_auc_score(y_true_binary, y_proba[:, i])
            ax.plot(fpr, tpr, color=colors[i], linewidth=2,
                    label=f'{cls} ({AAMI_NAMES[cls]}) AUC={auc:.3f}')
    
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'{name}', fontsize=12, fontweight='bold')
    ax.legend(loc='lower right', fontsize=9)
    ax.grid(True, alpha=0.3)

# Hide empty subplots
for idx in range(len(test_results), len(axes)):
    axes[idx].axis('off')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'ml_roc_curves.png', dpi=150)
plt.show()

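For one class against the rest, ROC-AUC equals the probability that a randomly chosen beat of that class scores higher than a randomly chosen beat of another class, with ties counted as one half. The pairwise NumPy sketch below computes that quantity and its one-vs-rest macro average. It is O(n_pos × n_neg), so it suits only small arrays, and it skips classes absent from `y_enc` the same way the plotting loop above does. The labels and probabilities are made up.

```python
import numpy as np

def binary_auc(y_true, scores):
    """P(random positive outscores random negative), ties counted as 0.5."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError('need both positive and negative samples')
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def ovr_macro_auc(y_enc, proba):
    """One-vs-rest macro AUC: column i of proba scores class i against all others."""
    aucs = [binary_auc(y_enc == i, proba[:, i])
            for i in range(proba.shape[1]) if 0 < np.sum(y_enc == i) < len(y_enc)]
    return float(np.mean(aucs))

y = np.array([0, 1, 2, 0, 1, 2])
proba = np.array([[0.7, 0.2, 0.1], [0.2, 0.6, 0.2], [0.1, 0.3, 0.6],
                  [0.4, 0.4, 0.2], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]])
print(binary_auc(y == 1, proba[:, 1]))  # 0.6875 (one tie at 0.3)
print(ovr_macro_auc(y, proba))
```
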
## 10) Save Results

In [None]:
# ============================================================
# SAVE RESULTS
# ============================================================

import joblib

print('Saving models and results...')

# Save models
for name, model in final_models.items():
    safe_name = name.replace(' ', '_').replace('(', '').replace(')', '').lower()
    joblib.dump(model, OUTPUT_PATH / f'model_{safe_name}.joblib')

# Save scaler and encoder
joblib.dump(scaler, OUTPUT_PATH / 'feature_scaler.joblib')
joblib.dump(le, OUTPUT_PATH / 'label_encoder.joblib')

# Save results as JSON
results_json = {
    'cv_summary': df_cv.to_dict('records'),
    'test_summary': df_test.to_dict('records'),
    'feature_cols': feature_cols,
    'classes': list(le.classes_),
    'best_model': best_model[0],
    'best_macro_f1': float(best_model[1]['macro_f1'])
}

with open(OUTPUT_PATH / 'ml_tester_results.json', 'w') as f:
    json.dump(results_json, f, indent=2)

print(f'\n✅ All results saved to: {OUTPUT_PATH}')
print(f'\nSaved files:')
for f in OUTPUT_PATH.glob('*'):
    print(f'  • {f.name}')

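`json.dump` only accepts plain Python types. The results dictionary above works because every value is a string or is cast with `float(...)`. Adding a NumPy array (for example a `per_class_f1` vector), an `np.int64`, or an `np.float32` would raise `TypeError`. One option is a `default=` hook, sketched here with illustrative values; `to_jsonable` is a hypothetical helper.

```python
import json
import numpy as np

def to_jsonable(obj):
    """json `default=` hook: convert NumPy scalars and arrays to plain Python types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

results = {
    'per_class_f1': np.array([0.9, 0.75]),   # illustrative numbers only
    'n_test': np.int64(1234),
    'macro_f1': np.float32(0.5),
}
text = json.dumps(results, default=to_jsonable)
print(text)  # {"per_class_f1": [0.9, 0.75], "n_test": 1234, "macro_f1": 0.5}
```
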
In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print('\n' + '=' * 70)
print('üéØ ML TESTER SUMMARY (with SMOTE + AdaBoost)')
print('=' * 70)

print(f'''
DATASET:
  Total abnormal beats: {len(X_abnormal):,}
  Train (original): {len(X_train):,} | Train (SMOTE): {len(X_train_smote):,}
  Test: {len(X_test):,} (unchanged)
  Features: {len(feature_cols)}
  Classes: {ABNORMAL_CLASSES}

SMOTE BALANCING:
  Applied to training data only (per-fold in CV)
  Test/validation sets remain original distribution

CROSS-VALIDATION ({N_FOLDS}-Fold):
''')
print(df_cv.to_string(index=False))

print(f'''
TEST SET RESULTS:
''')
print(df_test.to_string(index=False))

print(f'''
🏆 BEST MODEL: {best_model[0]}
   Macro F1: {best_model[1]['macro_f1']:.4f}
   Accuracy: {best_model[1]['accuracy']:.4f}

KEY FINDINGS (hard-coded text; confirm against the results above):
  • SMOTE helps balance rare classes (F, Q)
  • Tree-based models (RF, XGBoost, AdaBoost) perform consistently
  • RR interval features remain highly predictive

NEXT STEPS:
  1. Integrate best model into Stage 2 of two-stage pipeline
  2. Try hyperparameter tuning for further improvement
  3. Consider ensemble of top models

✅ ML Tester (SMOTE + AdaBoost) complete!
''')