# k-NN Classification of Abnormal ECG Beats (MIT-BIH)

**Objective:** Classify ONLY abnormal ECG beats (AAMI classes: S, V, F, Q) using k-Nearest Neighbors.  
**Normal beats (N) are EXCLUDED entirely.**

---

## ⚠️ Diagnostic Notebook Disclaimer

This notebook is for **diagnostic/research purposes only**:
- k-NN serves as a **baseline**, not a production model
- Results inform feature engineering and class handling strategy
- Do not over-tune for the ambiguous classes (F, S)
- MIT-BIH has known limitations (small: 48 half-hour records from 47 subjects, recorded 1975-1979)

---

## Overview

| Section | Description |
|---------|-------------|
| 2. Data Loading & Filtering | Extract features, keep abnormal beats only, verify N excluded |
| 3. Record-wise Split | Record-level 80/20 split (no record appears in both sets) |
| 4. Preprocessing | StandardScaler + PCA (fit on train only) |
| 5. k-NN Modeling | Evaluate k ∈ {3, 5, 7, 9} with confidence analysis |
| 6. Results Analysis | Metrics table, class difficulty, strategy insights |
| 7. Discussion | Metric choice, F and S ambiguity, k-NN as a baseline |
| 8. Visualization | Confusion matrix, t-SNE, S/F overlap analysis |
| Summary | Findings, limitations, next steps |

---

## AAMI Abnormal Classes

| Code | Name | Description | Expected Difficulty |
|------|------|-------------|---------------------|
| **S** | Supraventricular | Atrial/junctional premature beats | Medium-Hard (RR-dependent) |
| **V** | Ventricular | Ventricular ectopic beats | Easy (distinct morphology) |
| **F** | Fusion | Fusion of normal and ventricular | Very Hard (rare, ambiguous) |
| **Q** | Unknown | Paced/unclassifiable beats | Medium (pacing artifacts) |

---

## Why k-NN as a Baseline?

- **Simple & interpretable** — easy to debug and understand
- **Distance-based** — naturally works with engineered features
- **Fast iteration** — no training time, quick experiments
- **Confidence via neighbors** — distance to neighbors indicates uncertainty

## Why Macro F1?

- **Accuracy is misleading** with imbalanced classes
- **Macro F1 treats all classes equally** — rare classes matter
- **Clinical relevance** — detecting ALL arrhythmia types is critical
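
To make the gap between the two metrics concrete, the sketch below scores a degenerate classifier that always predicts the majority class. The class counts are hypothetical and chosen only for illustration (they are not MIT-BIH statistics), and `f1_per_class` is a helper written for this example.

```python
def f1_per_class(y_true, y_pred, labels):
    """Per-class F1 computed from scratch: F1 = 2TP / (2TP + FP + FN)."""
    scores = {}
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores[c] = 2 * tp / denom if denom else 0.0
    return scores

labels = ['S', 'V', 'F', 'Q']
# Hypothetical, illustrative counts (NOT real MIT-BIH numbers)
y_true = ['V'] * 70 + ['Q'] * 20 + ['S'] * 8 + ['F'] * 2
y_pred = ['V'] * len(y_true)   # always predict the majority class

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
per_class = f1_per_class(y_true, y_pred, labels)
macro_f1 = sum(per_class.values()) / len(labels)

print(f'accuracy = {accuracy:.3f}')
print(f'macro F1 = {macro_f1:.3f}')
```

Accuracy looks respectable because V alone carries it, while Macro F1 collapses because three of the four classes score zero.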

## 0) Google Colab Setup

In [None]:
# ============================================================
# GOOGLE COLAB SETUP (skip if running locally)
# ============================================================

# Uncomment to mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')

# Install packages if needed
# !pip install -q wfdb

print('✅ Setup complete!')

## 1) Imports & Configuration

In [None]:
# ============================================================
# IMPORTS
# ============================================================

import os
import warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    confusion_matrix, ConfusionMatrixDisplay
)

import matplotlib.pyplot as plt
import seaborn as sns

# Settings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Reproducibility
SEED = 42
np.random.seed(SEED)

print('✅ All imports successful!')

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# Paths - UPDATE FOR YOUR ENVIRONMENT
# The MIT-BIH data is in the 'mit-bih-arrhythmia-database-1.0.0' subdirectory

# For Colab:
BASE_PATH = Path('/content/drive/MyDrive/ecg2.0')
DATA_PATH = BASE_PATH / 'mit-bih-arrhythmia-database-1.0.0'

# For local (uncomment if running locally):
# BASE_PATH = Path('/Volumes/Crucial X6/medical_ai/ecg2.0')
# DATA_PATH = BASE_PATH / 'mit-bih-arrhythmia-database-1.0.0'

OUTPUT_PATH = BASE_PATH / 'outputs_knn'
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Verify data path exists
if DATA_PATH.exists():
    print(f'✅ Data path found: {DATA_PATH}')
    hea_files = list(DATA_PATH.glob('*.hea'))
    print(f'   Found {len(hea_files)} .hea files')
else:
    print(f'❌ Data path NOT found: {DATA_PATH}')
    print('   Please update DATA_PATH to point to your MIT-BIH data folder')

print(f'Output path: {OUTPUT_PATH}')

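# AAMI class definitions
# NOTE: the commonly used AAMI Q grouping is '/', 'f', 'Q'. This notebook
# additionally maps '!' (ventricular flutter wave) and 'P' to Q. 'P' is not a
# WFDB beat symbol (paced beats are '/'), so that entry never matches.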
AAMI_MAP = {
    'N': 'N', 'L': 'N', 'R': 'N', 'e': 'N', 'j': 'N',  # Normal
    'A': 'S', 'a': 'S', 'J': 'S', 'S': 'S',             # Supraventricular
    'V': 'V', 'E': 'V',                                  # Ventricular
    'F': 'F',                                            # Fusion
    '/': 'Q', 'f': 'Q', '!': 'Q', 'Q': 'Q', 'P': 'Q'    # Unknown/Paced
}

# ABNORMAL CLASSES ONLY (excluding Normal)
ABNORMAL_CLASSES = ['S', 'V', 'F', 'Q']
CLASS_NAMES = {
    'S': 'Supraventricular',
    'V': 'Ventricular', 
    'F': 'Fusion',
    'Q': 'Unknown/Paced'
}

# k-NN parameters to evaluate
K_VALUES = [3, 5, 7, 9]

# PCA variance threshold
PCA_VARIANCE = 0.95

print(f'Abnormal classes: {ABNORMAL_CLASSES}')
print(f'k values to test: {K_VALUES}')

## 2) Data Loading & Filtering

Extract beat-level features from the raw MIT-BIH records and keep **abnormal beats only**.

In [None]:
# ============================================================
# DATA LOADING - Feature Extraction Function
# ============================================================

import wfdb
from scipy.stats import skew, kurtosis

# Beat extraction parameters
SAMPLES_BEFORE = 100
SAMPLES_AFTER = 150
BEAT_LENGTH = SAMPLES_BEFORE + SAMPLES_AFTER

def extract_beat_features(beat, rr_before, rr_after, local_rr_mean=None, fs=360):
    """
    Extract features from a single ECG beat.
    
    IMPORTANT: RR-heavy features are critical for detecting S (Supraventricular),
    as these beats are characterized by premature timing (short RR_before).
    
    Returns a dictionary of feature values.
    """
    features = {}
    
    # --- Statistical features ---
    features['mean'] = np.mean(beat)
    features['std'] = np.std(beat)
    features['min'] = np.min(beat)
    features['max'] = np.max(beat)
    features['ptp'] = np.ptp(beat)  # peak-to-peak
    features['median'] = np.median(beat)
    features['skewness'] = skew(beat)
    features['kurtosis'] = kurtosis(beat)
    
    # --- Energy features ---
    features['energy'] = np.sum(beat ** 2)
    features['rms'] = np.sqrt(np.mean(beat ** 2))
    
    # ============================================================
    # RR INTERVAL FEATURES (CRITICAL FOR S-CLASS DETECTION)
    # ============================================================
    # Supraventricular (S) beats are premature → short RR_before
    # These features help distinguish S from other classes
    
    features['rr_before'] = rr_before
    features['rr_after'] = rr_after
    features['rr_ratio'] = rr_before / rr_after if rr_after > 0 else 1.0
    features['rr_diff'] = rr_after - rr_before
    features['rr_mean'] = (rr_before + rr_after) / 2
    
    # Additional RR features for S detection
    features['rr_before_sq'] = rr_before ** 2  # Emphasize short intervals
    features['rr_ratio_inv'] = rr_after / rr_before if rr_before > 0 else 1.0
    features['rr_abs_diff'] = abs(rr_after - rr_before)
    features['rr_product'] = rr_before * rr_after
    
    # Prematurity index (key for S): how much shorter is this beat?
    if local_rr_mean and local_rr_mean > 0:
        features['prematurity_index'] = (local_rr_mean - rr_before) / local_rr_mean
        features['compensatory_pause'] = (rr_after - local_rr_mean) / local_rr_mean
    else:
        features['prematurity_index'] = 0.0
        features['compensatory_pause'] = 0.0
    
    # --- Morphological features ---
    r_peak_idx = SAMPLES_BEFORE
    features['r_amplitude'] = beat[r_peak_idx]
    
    # QRS approximation (central region)
    qrs_start = r_peak_idx - 20
    qrs_end = r_peak_idx + 20
    qrs_region = beat[max(0, qrs_start):min(len(beat), qrs_end)]
    features['qrs_energy'] = np.sum(qrs_region ** 2)
    features['qrs_max'] = np.max(qrs_region)
    features['qrs_min'] = np.min(qrs_region)
    features['qrs_width_approx'] = np.sum(np.abs(qrs_region) > 0.3 * np.max(np.abs(qrs_region)))
    
    # --- Derivative features ---
    deriv = np.diff(beat)
    features['deriv_max'] = np.max(deriv)
    features['deriv_min'] = np.min(deriv)
    features['deriv_std'] = np.std(deriv)
    features['deriv_abs_mean'] = np.mean(np.abs(deriv))
    
    # Zero crossings (complexity measure)
    features['zero_crossings'] = np.sum(np.diff(np.sign(beat)) != 0)
    
    # --- Segment features (for V vs F distinction) ---
    pre_qrs = beat[:qrs_start] if qrs_start > 0 else beat[:10]
    post_qrs = beat[qrs_end:] if qrs_end < len(beat) else beat[-10:]
    features['pre_qrs_mean'] = np.mean(pre_qrs)
    features['post_qrs_mean'] = np.mean(post_qrs)
    features['pre_post_ratio'] = features['pre_qrs_mean'] / (features['post_qrs_mean'] + 1e-6)
    
    return features

print('Feature extraction function defined.')
print(f'Beat length: {BEAT_LENGTH} samples ({SAMPLES_BEFORE} before, {SAMPLES_AFTER} after R-peak)')
print('\n📊 RR-heavy features added for S-class detection:')
print('   - prematurity_index: how early the beat occurs')
print('   - compensatory_pause: pause after premature beat')
print('   - rr_before_sq, rr_ratio_inv, rr_abs_diff, rr_product')

In [None]:
# ============================================================
# LOAD MIT-BIH DATA AND EXTRACT FEATURES
# ============================================================

# MIT-BIH record numbers
MIT_BIH_RECORDS = [
    100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
    111, 112, 113, 114, 115, 116, 117, 118, 119, 121,
    122, 123, 124, 200, 201, 202, 203, 205, 207, 208,
    209, 210, 212, 213, 214, 215, 217, 219, 220, 221,
    222, 223, 228, 230, 231, 232, 233, 234
]

all_features = []
all_labels = []
all_record_ids = []
skipped_normal = 0

print('Loading MIT-BIH records and extracting features...')
print(f'Data path: {DATA_PATH}')
print('=' * 60)

loaded_count = 0
for rec_num in MIT_BIH_RECORDS:
    rec_path = DATA_PATH / str(rec_num)
    
    try:
        # Load record
        record = wfdb.rdrecord(str(rec_path))
        annotation = wfdb.rdann(str(rec_path), 'atr')
        
        signal = record.p_signal[:, 0]  # Lead 0
        fs = record.fs
        
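        # Keep beat annotations only. Non-beat markers (e.g. '+' rhythm
        # changes, '~' noise) would otherwise be treated as neighbouring
        # R-peaks and corrupt the RR intervals computed below.
        beat_idx = [j for j, s in enumerate(annotation.symbol) if s in AAMI_MAP]
        r_peaks = annotation.sample[beat_idx]
        symbols = [annotation.symbol[j] for j in beat_idx]
        
        # Record-level median RR in seconds (reference for the prematurity index)
        rr_intervals = np.diff(r_peaks) / fs
        local_rr_mean = np.median(rr_intervals) if len(rr_intervals) > 0 else 0.8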
        
        rec_abnormal_count = 0
        rec_normal_skipped = 0
        
        for i, (r_peak, symbol) in enumerate(zip(r_peaks, symbols)):
            if symbol not in AAMI_MAP:
                continue
            
            aami_label = AAMI_MAP[symbol]
            
            # EXPLICITLY SKIP Normal beats - ABNORMAL ONLY
            if aami_label == 'N':
                rec_normal_skipped += 1
                skipped_normal += 1
                continue
            
            # Extract beat window
            start = r_peak - SAMPLES_BEFORE
            end = r_peak + SAMPLES_AFTER
            
            if start < 0 or end > len(signal):
                continue
            
            beat = signal[start:end]
            
            # Calculate RR intervals
            rr_before = (r_peak - r_peaks[i-1]) / fs if i > 0 else local_rr_mean
            rr_after = (r_peaks[i+1] - r_peak) / fs if i < len(r_peaks)-1 else local_rr_mean
            
            # Extract features with local RR mean for prematurity calculation
            feat = extract_beat_features(beat, rr_before, rr_after, local_rr_mean, fs)
            
            all_features.append(feat)
            all_labels.append(aami_label)
            all_record_ids.append(rec_num)
            rec_abnormal_count += 1
        
        if rec_abnormal_count > 0:
            print(f'  Record {rec_num}: {rec_abnormal_count} abnormal (skipped {rec_normal_skipped} normal)')
            loaded_count += 1
            
    except Exception as e:
        print(f'  Warning: Could not load record {rec_num}: {e}')
        continue

# Create DataFrame
df = pd.DataFrame(all_features)
df['label'] = all_labels
df['record_id'] = all_record_ids

print('\n' + '=' * 60)
print(f'✅ Loaded {len(df):,} ABNORMAL beats from {loaded_count} records')
print(f'   Skipped {skipped_normal:,} Normal (N) beats')

In [None]:
# ============================================================
# VERIFICATION: CONFIRM N IS EXCLUDED
# ============================================================

print('=' * 60)
print('SANITY CHECK: Abnormal-Only Verification')
print('=' * 60)

# Assert N is fully excluded
n_count = (df['label'] == 'N').sum()
assert n_count == 0, f'ERROR: Found {n_count} Normal beats! N should be excluded.'
print(f'\n✅ VERIFIED: 0 Normal (N) beats in dataset')
print(f'   Total beats: {len(df):,}')
print(f'   All labels: {df["label"].unique()}')

# Class distribution
print('\n' + '=' * 60)
print('CLASS DISTRIBUTION (Abnormal Beats Only)')
print('=' * 60)

class_counts = df['label'].value_counts()
total = len(df)

print(f'\nTotal abnormal beats: {total:,}\n')
print(f'{"Class":<5} {"Name":<20} {"Count":>8} {"Percentage":>10}')
print('-' * 50)

for cls in ABNORMAL_CLASSES:
    if cls in class_counts.index:
        count = class_counts[cls]
        pct = 100 * count / total
        print(f'{cls:<5} {CLASS_NAMES[cls]:<20} {count:>8,} {pct:>9.2f}%')

print('-' * 50)

# Visualize
fig, ax = plt.subplots(figsize=(8, 5))
colors = ['#e74c3c', '#3498db', '#9b59b6', '#f39c12']
bars = ax.bar(ABNORMAL_CLASSES, [class_counts.get(c, 0) for c in ABNORMAL_CLASSES], color=colors)
ax.set_xlabel('AAMI Class')
ax.set_ylabel('Number of Beats')
ax.set_title('Abnormal Beat Distribution (N EXCLUDED)', fontweight='bold')

for bar, cls in zip(bars, ABNORMAL_CLASSES):
    height = bar.get_height()
    ax.annotate(f'{height:,}\n({100*height/total:.1f}%)',
                xy=(bar.get_x() + bar.get_width()/2, height),
                ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

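# Report the imbalance from the actual counts rather than hard-coded percentages
majority_cls = class_counts.idxmax()
minority_cls = class_counts.idxmin()
print('\n⚠️ CLASS IMBALANCE NOTE:')
print(f'   Largest class:  {majority_cls} ({100 * class_counts.max() / total:.1f}%)')
print(f'   Smallest class: {minority_cls} ({100 * class_counts.min() / total:.1f}%)')
print('   This imbalance is why high accuracy ≠ good Macro F1')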

## 3) Record-Wise Train/Test Split

### ⚠️ Why Record-Wise Splitting is MANDATORY for MIT-BIH

**Beat-level splitting causes data leakage:**
- Beats from the **same patient** are highly correlated
- Same heart → same morphology, rhythm, baseline wander
- Model learns **patient-specific patterns**, not generalizable arrhythmia features
- Result: **Inflated metrics** that don't reflect real-world performance

**Record-wise splitting prevents leakage:**
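- Each record's beats go entirely to train OR test
- Model must generalize to **unseen records**
- This is the **clinically relevant** evaluation scenario
- Caveat: MIT-BIH records 201 and 202 come from the same subject, so a strict patient-level split must keep them on the same side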
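
A minimal sketch of a split that groups by subject rather than by record is shown below. `split_by_subject` and the `subject_of` mapping are hypothetical names introduced for this example; the only subject-level fact assumed is that records 201 and 202 belong to the same person, and the record list is a small illustrative subset.

```python
import random

def split_by_subject(record_ids, subject_of, test_frac=0.2, seed=42):
    """Split record IDs so that all records of one subject land on one side."""
    # A record without an entry in subject_of is treated as its own subject
    subjects = sorted({subject_of.get(r, r) for r in record_ids})
    rng = random.Random(seed)
    rng.shuffle(subjects)
    n_test = max(1, round(test_frac * len(subjects)))
    test_subjects = set(subjects[:n_test])
    train = [r for r in record_ids if subject_of.get(r, r) not in test_subjects]
    test = [r for r in record_ids if subject_of.get(r, r) in test_subjects]
    return train, test

# Small illustrative subset of record IDs
records = [100, 101, 200, 201, 202, 203, 205, 207, 208, 209]
# Map record 202 onto the subject of record 201
subject_of = {202: 201}

train_records, test_records = split_by_subject(records, subject_of)
print('train:', train_records)
print('test: ', test_records)
```

scikit-learn's `GroupShuffleSplit` offers the same grouping behaviour through its `groups` argument if a library solution is preferred.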

In [None]:
# ============================================================
# RECORD-WISE TRAIN / TEST SPLIT (80% / 20%)
# ============================================================

# Get unique records
unique_records = df['record_id'].unique()
n_records = len(unique_records)

print(f'Total unique records with abnormal beats: {n_records}')

# Split records (NOT beats) - 80/20
train_records, test_records = train_test_split(
    unique_records,
    test_size=0.2,
    random_state=SEED,
    shuffle=True
)

print(f'\nSplit ratio: 80% train / 20% test')
print(f'Train records: {len(train_records)} ({100*len(train_records)/n_records:.1f}%)')
print(f'Test records:  {len(test_records)} ({100*len(test_records)/n_records:.1f}%)')

# ============================================================
# EXPLICIT VERIFICATION: Zero overlap
# ============================================================
train_set = set(train_records)
test_set = set(test_records)
overlap = train_set & test_set

print(f'\n🔍 LEAKAGE CHECK:')
print(f'   Train record IDs: {sorted(train_records)[:5]}... (showing first 5)')
print(f'   Test record IDs:  {sorted(test_records)[:5]}... (showing first 5)')
print(f'   Overlap: {len(overlap)} records')

assert len(overlap) == 0, f'ERROR: {len(overlap)} records appear in both splits!'
print(f'\n✅ VERIFIED: Zero overlap between train and test records')
print('   No patient data leakage.')

# Split data by records
train_mask = df['record_id'].isin(train_records)
test_mask = df['record_id'].isin(test_records)

df_train = df[train_mask].copy()
df_test = df[test_mask].copy()

print(f'\nTrain beats: {len(df_train):,}')
print(f'Test beats:  {len(df_test):,}')

In [None]:
# ============================================================
# VERIFY CLASS DISTRIBUTION IN SPLITS
# ============================================================

print('\n' + '=' * 60)
print('CLASS DISTRIBUTION BY SPLIT')
print('=' * 60)

print('\n--- TRAIN SET ---')
for cls in ABNORMAL_CLASSES:
    c = (df_train['label'] == cls).sum()
    pct = 100 * c / len(df_train)
    print(f'  {cls}: {c:>6,} ({pct:>5.2f}%)')

print('\n--- TEST SET ---')
for cls in ABNORMAL_CLASSES:
    c = (df_test['label'] == cls).sum()
    pct = 100 * c / len(df_test)
    print(f'  {cls}: {c:>6,} ({pct:>5.2f}%)')

print('\n✅ Record-wise split complete. No patient data leakage.')

## 4) Feature Preprocessing

### Why Scaling is REQUIRED for k-NN

k-NN uses **Euclidean distance** — features with larger scales dominate:
- `energy` might be ~10,000
- `rr_ratio` might be ~1.0
- Without scaling, `energy` completely dominates the distance calculation

**StandardScaler** normalizes all features to mean=0, std=1.
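
The effect can be seen on a toy example. The sketch below uses five made-up beats described by `[energy, rr_ratio]` (the values are hypothetical) and checks which of two candidates is nearest to the query before and after standardization.

```python
import numpy as np

# Hypothetical beats: columns are [energy, rr_ratio]
X = np.array([
    [10000.0, 0.60],   # row 0: query beat (premature, low rr_ratio)
    [10050.0, 1.40],   # row 1: similar energy, very different timing
    [10400.0, 0.62],   # row 2: different energy, almost identical timing
    [ 9000.0, 1.00],   # rows 3-4: extra beats that set the feature spread
    [11000.0, 1.00],
])

def dists_from_query(X):
    """Euclidean distances from the query (row 0) to candidate rows 1 and 2."""
    return np.linalg.norm(X[1:3] - X[0], axis=1)

raw = dists_from_query(X)

# Standardize each column to mean 0, std 1 (what StandardScaler does)
X_scaled = (X - X.mean(axis=0)) / X.std(axis=0)
scaled = dists_from_query(X_scaled)

print('raw distances    (row1, row2):', raw)
print('scaled distances (row1, row2):', scaled)
```

On raw values the nearest neighbor is row 1, purely because its energy is close; after scaling, row 2 (the beat with matching timing) becomes the nearest.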

### Why PCA Can Help

- Removes correlated features (redundancy)
- Reduces dimensionality → faster k-NN
- Mitigates **curse of dimensionality** in high-D spaces

### ⚠️ Critical: Fit on Train ONLY

- **StandardScaler**: fit on X_train, transform both
- **PCA**: fit on scaled X_train, transform both
- Never fit on test data — that's **data leakage**!
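
One way to make the fit-on-train rule hard to violate is to wrap the steps in a scikit-learn `Pipeline`. The sketch below uses synthetic random data as a stand-in for the real feature matrices; the array names and sizes are assumptions for illustration only.

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
# Synthetic stand-in data: 200 train / 50 test rows, 10 features, 4 classes
X_tr, y_tr = rng.normal(size=(200, 10)), rng.integers(0, 4, size=200)
X_te = rng.normal(size=(50, 10))

pipe = Pipeline([
    ('scale', StandardScaler()),
    ('pca', PCA(n_components=0.95)),
    ('knn', KNeighborsClassifier(n_neighbors=5, weights='distance')),
])

pipe.fit(X_tr, y_tr)          # scaler and PCA statistics come from X_tr only
y_hat = pipe.predict(X_te)    # X_te is only transformed, never fitted

print('samples seen by scaler:', pipe.named_steps['scale'].n_samples_seen_)
print('PCA components kept:   ', pipe.named_steps['pca'].n_components_)
```

Because `fit` is only ever called with training data, the same object can be passed to cross-validation utilities without leaking test statistics into the scaler or PCA.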

In [None]:
# ============================================================
# PREPARE FEATURE MATRICES
# ============================================================

# Get feature columns (exclude label and record_id)
feature_cols = [col for col in df.columns if col not in ['label', 'record_id']]

print(f'Number of features: {len(feature_cols)}')
print(f'Features: {feature_cols[:10]}...')

# Extract X, y
X_train = df_train[feature_cols].values
y_train = df_train['label'].values

X_test = df_test[feature_cols].values
y_test = df_test['label'].values

print(f'\nX_train shape: {X_train.shape}')
print(f'X_test shape:  {X_test.shape}')

In [None]:
# ============================================================
# STANDARDSCALER (FIT ON TRAIN ONLY)
# ============================================================

scaler = StandardScaler()

# Fit ONLY on training data
X_train_scaled = scaler.fit_transform(X_train)

# Transform test data (do NOT fit)
X_test_scaled = scaler.transform(X_test)

print('✅ StandardScaler applied')
print('   - Fitted on training data only')
print('   - Test data transformed (not fitted)')
print(f'\n   Train mean after scaling: {X_train_scaled.mean():.6f}')
print(f'   Train std after scaling:  {X_train_scaled.std():.6f}')

In [None]:
# ============================================================
# PCA (FIT ON TRAIN ONLY, RETAIN ~95% VARIANCE)
# ============================================================

# Fit PCA on scaled training data ONLY
pca = PCA(n_components=PCA_VARIANCE, random_state=SEED)
X_train_pca = pca.fit_transform(X_train_scaled)

# Transform test data (do NOT fit)
X_test_pca = pca.transform(X_test_scaled)

n_components = pca.n_components_
explained_var = np.sum(pca.explained_variance_ratio_)

print('✅ PCA applied')
print('   - Fitted on training data only')
print(f'\n   Original features:    {X_train_scaled.shape[1]}')
print(f'   PCA components:       {n_components}')
print(f'   Variance explained:   {100*explained_var:.2f}%')

print(f'\n   X_train_pca shape: {X_train_pca.shape}')
print(f'   X_test_pca shape:  {X_test_pca.shape}')

# Visualize explained variance
fig, ax = plt.subplots(figsize=(10, 4))
cumsum = np.cumsum(pca.explained_variance_ratio_)
ax.bar(range(1, n_components+1), pca.explained_variance_ratio_, alpha=0.7, label='Individual')
ax.plot(range(1, n_components+1), cumsum, 'ro-', label='Cumulative')
ax.axhline(y=PCA_VARIANCE, color='g', linestyle='--', label=f'{100*PCA_VARIANCE}% threshold')
ax.set_xlabel('Principal Component')
ax.set_ylabel('Explained Variance Ratio')
ax.set_title('PCA Explained Variance', fontweight='bold')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# ENCODE LABELS
# ============================================================

le = LabelEncoder()
le.fit(ABNORMAL_CLASSES)  # NOTE: le.classes_ is sorted alphabetically (F=0, Q=1, S=2, V=3)

y_train_enc = le.transform(y_train)
y_test_enc = le.transform(y_test)

print('✅ Labels encoded')
print(f'   Classes: {le.classes_}')
print(f'   Encoding: {dict(zip(le.classes_, range(len(le.classes_))))}')

## 5) k-NN Modeling

Train k-NN classifiers with different values of k.

**Parameters:**
- k ∈ {3, 5, 7, 9}
- Distance metric: Euclidean
- Weights: distance-based (closer neighbors have more influence)
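
Distance weighting can change the outcome of a vote. The sketch below reimplements inverse-distance voting by hand for a single hypothetical neighborhood; `weighted_vote` is a helper written for this example and assumes strictly positive distances.

```python
from collections import defaultdict

def weighted_vote(neighbors):
    """neighbors: list of (distance, label). Returns (label, confidence)."""
    weights = defaultdict(float)
    for dist, label in neighbors:
        weights[label] += 1.0 / dist    # assumes dist > 0
    total = sum(weights.values())
    winner = max(weights, key=weights.get)
    return winner, weights[winner] / total

# Hypothetical k=5 neighborhood: two close F beats, three farther V beats
neighbors = [(0.5, 'F'), (0.5, 'F'), (2.0, 'V'), (2.0, 'V'), (2.0, 'V')]
label, confidence = weighted_vote(neighbors)
print(label, round(confidence, 3))
```

A uniform vote would pick V (3 votes to 2), whereas the inverse-distance vote picks F because its two neighbors are much closer. The normalized winning weight plays the role of the confidence score used in the evaluation loop.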

In [None]:
# ============================================================
# k-NN TRAINING AND EVALUATION (with Confidence Analysis)
# ============================================================

results = []

print('=' * 70)
print('k-NN CLASSIFICATION RESULTS')
print('=' * 70)

for k in K_VALUES:
    print(f'\n--- k = {k} ---')
    
    # Create and train k-NN
    knn = KNeighborsClassifier(
        n_neighbors=k,
        metric='euclidean',
        weights='distance',  # Distance-weighted voting
        n_jobs=-1
    )
    knn.fit(X_train_pca, y_train_enc)
    
    # Predict on test set
    y_pred = knn.predict(X_test_pca)
    y_proba = knn.predict_proba(X_test_pca)
    
    # Calculate confidence (max probability)
    confidence = np.max(y_proba, axis=1)
    
    # Calculate metrics
    accuracy = accuracy_score(y_test_enc, y_pred)
    macro_f1 = f1_score(y_test_enc, y_pred, average='macro')
    weighted_f1 = f1_score(y_test_enc, y_pred, average='weighted')
    per_class_f1 = f1_score(y_test_enc, y_pred, average=None, labels=range(len(ABNORMAL_CLASSES)))
    
    # ============================================================
    # CONFIDENCE / REJECT ANALYSIS
    # ============================================================
    low_conf_mask = confidence < 0.5  # Flag predictions with <50% confidence
    low_conf_count = np.sum(low_conf_mask)
    low_conf_pct = 100 * low_conf_count / len(y_pred)
    
    # Accuracy on high-confidence predictions only
    if np.sum(~low_conf_mask) > 0:
        high_conf_acc = accuracy_score(y_test_enc[~low_conf_mask], y_pred[~low_conf_mask])
    else:
        high_conf_acc = 0.0
    
    # Store results
    result = {
        'k': k,
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'y_pred': y_pred,
        'y_proba': y_proba,
        'confidence': confidence,
        'low_conf_pct': low_conf_pct,
        'high_conf_acc': high_conf_acc,
        'knn_model': knn
    }
    # per_class_f1 follows the encoder's order (le.classes_ is alphabetical),
    # not the order of ABNORMAL_CLASSES, so map scores to classes explicitly
    f1_by_class = dict(zip(le.classes_, per_class_f1))
    for cls in ABNORMAL_CLASSES:
        result[f'f1_{cls}'] = f1_by_class[cls]
    
    results.append(result)
    
    # Print summary
    print(f'  Accuracy:   {accuracy:.4f}')
    print(f'  Macro F1:   {macro_f1:.4f} (PRIMARY)')
    print('  Per-class:  ' + ', '.join(f'{cls}={f1_by_class[cls]:.3f}' for cls in ABNORMAL_CLASSES))
    print(f'  Low-conf:   {low_conf_pct:.1f}% predictions < 50% confidence')
    print(f'  High-conf accuracy: {high_conf_acc:.4f} (if rejecting low-conf)')

print('\n' + '=' * 70)

## 6) Results Analysis

Compare all k values and identify the best configuration.

In [None]:
# ============================================================
# RESULTS COMPARISON TABLE
# ============================================================

print('\n' + '=' * 100)
print('k-NN RESULTS COMPARISON')
print('=' * 100)

# Create comparison DataFrame
df_results = pd.DataFrame([{
    'k': r['k'],
    'Accuracy': f"{r['accuracy']:.4f}",
    'Macro F1': f"{r['macro_f1']:.4f}",
    'F1(S)': f"{r['f1_S']:.4f}",
    'F1(V)': f"{r['f1_V']:.4f}",
    'F1(F)': f"{r['f1_F']:.4f}",
    'F1(Q)': f"{r['f1_Q']:.4f}",
    'Low-Conf %': f"{r['low_conf_pct']:.1f}%"
} for r in results])

print('\n' + df_results.to_string(index=False))

# Find best k based on Macro F1
best_result = max(results, key=lambda x: x['macro_f1'])
best_k = best_result['k']

print(f'\n🏆 BEST k = {best_k} (Macro F1 = {best_result["macro_f1"]:.4f})')

# ============================================================
# ACCURACY vs MACRO F1 ANALYSIS
# ============================================================
print('\n' + '=' * 100)
print('WHY ACCURACY CAN BE HIGH BUT MACRO F1 LOW')
print('=' * 100)

print(f'''
Example from our results:
  Accuracy:  {best_result['accuracy']:.4f}
  Macro F1:  {best_result['macro_f1']:.4f}

A gap between the two arises because:
  • The largest classes dominate accuracy → getting them right is most of the score
  • F is rare → a low F1(F) barely moves accuracy
  • Macro F1 weights F the same as every other class → poor F performance is penalized

CLINICAL IMPLICATION:
  A model with 95% accuracy but 0% F-recall would miss ALL fusion beats.
  Macro F1 catches this; accuracy does not.
''')

In [None]:
# ============================================================
# DETAILED CLASSIFICATION REPORT (BEST k)
# ============================================================

print(f'\n' + '=' * 70)
print(f'CLASSIFICATION REPORT (k = {best_k})')
print('=' * 70)

y_pred_best = best_result['y_pred']

print(classification_report(
    y_test_enc, y_pred_best,
    labels=list(range(len(le.classes_))),
    target_names=le.classes_,  # encoder order (alphabetical) matches the integer labels
    digits=4
))

In [None]:
# ============================================================
# PER-CLASS ANALYSIS & PRODUCT STRATEGY
# ============================================================

print('\n' + '=' * 70)
print('PER-CLASS PERFORMANCE ANALYSIS')
print('=' * 70)

print(f'''
📊 CLASS-BY-CLASS BREAKDOWN:

  S (Supraventricular): F1 = {best_result['f1_S']:.4f}
    • Atrial/junctional premature beats
    • DEPENDS HEAVILY ON RR FEATURES (prematurity index)
    • Often confused with Normal (morphology similar)
    • RR-heavy features should improve this
    
  V (Ventricular): F1 = {best_result['f1_V']:.4f}
    • Ventricular ectopic beats (PVCs)
    • MOST DISTINCTIVE morphology (wide QRS, different axis)
    • Expected to have highest F1
    • Easy wins here.

  F (Fusion): F1 = {best_result['f1_F']:.4f}
    • Fusion of normal and ventricular beats
    • RAREST and HARDEST class
    • Mixed morphology → inherently ambiguous
    • Even cardiologists disagree on F labels
    • Consider: confidence thresholding or merging with V
    
  Q (Unknown/Paced): F1 = {best_result['f1_Q']:.4f}
    • Paced or unclassifiable beats
    • Distinct pacing spikes help detection
    • Moderate difficulty
''')

print('\n📈 EXPECTED DIFFICULTY RANKING (easiest → hardest):')
print('    V  >  Q  >  S  >  F')
print('    └────────────────┘')
print('    Easy      Medium    Hard')

# Verify ranking matches results
ranking = sorted(ABNORMAL_CLASSES, key=lambda c: best_result[f'f1_{c}'], reverse=True)
print(f'\n📊 ACTUAL RANKING (from our results): {" > ".join(ranking)}')

print('\n' + '=' * 70)
print('PRODUCT STRATEGY INSIGHTS')
print('=' * 70)

print(f'''
🎯 WHAT TO KEEP SEPARATE vs MERGE:

  KEEP SEPARATE (distinct clinical meaning):
    • V (Ventricular) - critical, high-risk arrhythmia
    • S (Supraventricular) - different treatment path
    
  CONSIDER MERGING:
    • F (Fusion) → merge with V for product?
      - Clinically, F is a subtype of ventricular activity
      - Very rare, hard to learn, often mislabeled
      - Merging would boost V-class performance
    
  Q (Unknown/Paced) - PRODUCT DECISION:
    • Option A: Flag separately ("paced rhythm detected")
    • Option B: Exclude from main classification
    • Depends on target use case

🔧 CONFIDENCE/REJECT STRATEGY:

  For ambiguous classes (S, F):
    • Low-confidence predictions ({best_result['low_conf_pct']:.1f}% of test set)
    • Option: Flag for human review instead of hard prediction
    • Improves precision at cost of coverage
''')

## 7) Discussion

### Why Macro F1 is the Correct Metric Here

| Metric | What it measures | Problem with imbalanced data |
|--------|------------------|-----------------------------|
| **Accuracy** | Overall correct predictions | Dominated by majority class (V) |
| **Weighted F1** | F1 weighted by class support | Still favors majority class |
| **Macro F1** | Unweighted average of per-class F1 | Treats all classes equally |

For clinical use, **we care about detecting ALL arrhythmia types**, not just the common ones.

### Why Fusion (F) is Inherently Ambiguous

1. **Definition:** F beats occur when a normal (supraventricular) impulse and a ventricular ectopic impulse activate the ventricles at nearly the same time
2. **Morphology:** Looks like a mix of N and V, with no single distinctive pattern
3. **Labeling:** Even expert cardiologists disagree on F annotations
4. **Rarity:** Under 1% of all annotated MIT-BIH beats (N included), and a small minority even among abnormal beats → insufficient training samples
5. **k-NN limitation:** Needs dense neighborhoods; F beats are sparse in feature space

**Recommendation:** Consider merging F into V for product, or flagging low-confidence F predictions for human review.
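
The review-flagging option can be sketched as a reject rule on top of class probabilities. `predict_with_reject` and the `'REVIEW'` label are hypothetical names for this example, the probability rows are made up, and the 0.5 threshold mirrors the low-confidence cutoff used in the evaluation loop.

```python
import numpy as np

def predict_with_reject(proba, classes, threshold=0.5, reject_label='REVIEW'):
    """Return the argmax class, or reject_label when max probability < threshold."""
    proba = np.asarray(proba)
    best = proba.argmax(axis=1)
    conf = proba.max(axis=1)
    labels = np.array(classes, dtype=object)[best]
    labels[conf < threshold] = reject_label
    return labels, conf

classes = ['F', 'Q', 'S', 'V']            # alphabetical, as LabelEncoder orders them
proba = [[0.10, 0.05, 0.05, 0.80],        # confident V
         [0.40, 0.05, 0.10, 0.45],        # F vs V toss-up, below the threshold
         [0.05, 0.90, 0.03, 0.02]]        # confident Q

labels, conf = predict_with_reject(proba, classes)
print(list(labels))
```

Raising the threshold sends more beats to review and usually trades coverage for precision, so the cutoff is best chosen on held-out training records rather than on the test set.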

### Why Supraventricular (S) Depends on RR Features

S beats are **premature supraventricular beats** (mostly premature atrial contractions) that arrive early:
- **Short RR_before** (premature timing)
- **Longer RR_after** (post-ectopic pause; for atrial beats it is usually shorter than the full compensatory pause typical of ventricular ectopy)
- Morphology often similar to Normal

Without RR features, S is easily confused with N. The `prematurity_index` feature specifically captures this.
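
As a worked example, the sketch below applies the same two formulas used in `extract_beat_features` to a premature beat and an on-time beat. The interval values are hypothetical and `rr_timing_features` is a helper written only for this illustration.

```python
def rr_timing_features(rr_before, rr_after, ref_rr):
    """Same formulas as prematurity_index / compensatory_pause in extract_beat_features."""
    return {
        'prematurity_index': (ref_rr - rr_before) / ref_rr,
        'compensatory_pause': (rr_after - ref_rr) / ref_rr,
    }

# Hypothetical intervals in seconds (illustrative, not taken from a record)
ref_rr = 0.80
premature = rr_timing_features(rr_before=0.52, rr_after=0.96, ref_rr=ref_rr)
on_time = rr_timing_features(rr_before=0.80, rr_after=0.80, ref_rr=ref_rr)

print('premature:', premature)
print('on time:  ', on_time)
```

The premature beat arrives 35% earlier than the reference interval and is followed by a 20% longer pause, while the on-time beat scores zero on both features.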

### k-NN as a Baseline (Not a Final Model)

**Use k-NN for:**
- Quick benchmarking
- Feature sanity checks (k-NN has no built-in feature importance; drop a feature group and compare Macro F1)
- Debugging data issues

**Don't use k-NN for:**
- Production deployment (slow inference on large datasets)
- Learning complex temporal patterns
- Handling severe class imbalance (no built-in weighting)

## 8) Visualization

In [None]:
# ============================================================
# CONFUSION MATRIX (BEST k)
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts (rows/columns follow the encoder's alphabetical order, le.classes_)
cm = confusion_matrix(y_test_enc, y_pred_best, labels=list(range(len(le.classes_))))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')
axes[0].set_title(f'Confusion Matrix (k={best_k}) - Counts', fontweight='bold')

# Normalized by true class (row); guard against classes absent from the test set
row_sums = cm.sum(axis=1, keepdims=True)
cm_norm = cm / np.maximum(row_sums, 1)
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=le.classes_, yticklabels=le.classes_, ax=axes[1])
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')
axes[1].set_title(f'Confusion Matrix (k={best_k}) - Normalized', fontweight='bold')

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'knn_confusion_matrix.png', dpi=150)
plt.show()

print('\n📊 Confusion matrix saved to outputs_knn/')

In [None]:
# ============================================================
# 2D VISUALIZATION: PCA and t-SNE
# ============================================================

from sklearn.manifold import TSNE

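# Reduce to 2D for visualization
# (fitting on the test set is acceptable here because these projections are
# used for plotting only and never feed the classifier)
pca_2d = PCA(n_components=2, random_state=SEED)
X_test_2d_pca = pca_2d.fit_transform(X_test_scaled)

# t-SNE (better for visualizing clusters, but slower)
print('Computing t-SNE (this may take a moment)...')
# The iteration count is left at its default (1000); the explicit n_iter keyword
# was renamed to max_iter in newer scikit-learn releases
tsne = TSNE(n_components=2, random_state=SEED, perplexity=30)
X_test_2d_tsne = tsne.fit_transform(X_test_scaled)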
print('t-SNE complete.')

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

colors = {'S': '#e74c3c', 'V': '#3498db', 'F': '#9b59b6', 'Q': '#f39c12'}

# PCA projection
for cls in ABNORMAL_CLASSES:
    mask = y_test == cls
    axes[0].scatter(X_test_2d_pca[mask, 0], X_test_2d_pca[mask, 1],
                    c=colors[cls], label=f'{cls} ({CLASS_NAMES[cls]})',
                    alpha=0.6, s=30)
axes[0].set_xlabel('PC1')
axes[0].set_ylabel('PC2')
axes[0].set_title('Test Data - 2D PCA', fontweight='bold')
axes[0].legend()

# t-SNE projection
for cls in ABNORMAL_CLASSES:
    mask = y_test == cls
    axes[1].scatter(X_test_2d_tsne[mask, 0], X_test_2d_tsne[mask, 1],
                    c=colors[cls], label=f'{cls}',
                    alpha=0.6, s=30)
axes[1].set_xlabel('t-SNE 1')
axes[1].set_ylabel('t-SNE 2')
axes[1].set_title('Test Data - t-SNE (better cluster viz)', fontweight='bold')
axes[1].legend()

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'knn_2d_projections.png', dpi=150)
plt.show()

print('\n📊 WHAT TO LOOK FOR (expectations, not measured results):')
print('   - V (blue) should form distinct clusters (different morphology)')
print('   - S (red) and F (purple) often overlap (similar features)')
print('   - Q (orange) may be scattered (heterogeneous class)')

In [None]:
# ============================================================
# S vs F OVERLAP ANALYSIS (Key Insight)
# ============================================================

# Highlight just S and F to see overlap
fig, ax = plt.subplots(figsize=(10, 8))

# Plot other classes faded
for cls in ['V', 'Q']:
    mask = y_test == cls
    ax.scatter(X_test_2d_tsne[mask, 0], X_test_2d_tsne[mask, 1],
               c='lightgray', alpha=0.2, s=20, label=f'{cls} (background)')

# Plot S and F prominently
s_mask = y_test == 'S'
f_mask = y_test == 'F'
ax.scatter(X_test_2d_tsne[s_mask, 0], X_test_2d_tsne[s_mask, 1],
           c='#e74c3c', alpha=0.8, s=50, label=f'S (Supraventricular) n={s_mask.sum()}')
ax.scatter(X_test_2d_tsne[f_mask, 0], X_test_2d_tsne[f_mask, 1],
           c='#9b59b6', alpha=0.8, s=50, marker='x', label=f'F (Fusion) n={f_mask.sum()}')

ax.set_xlabel('t-SNE 1')
ax.set_ylabel('t-SNE 2')
ax.set_title('S vs F Overlap Analysis (t-SNE)', fontweight='bold', fontsize=14)
ax.legend()

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'knn_sf_overlap.png', dpi=150)
plt.show()

print('\n🔍 S vs F OVERLAP INSIGHT:')
print('   If S and F points are heavily intermingled, this explains low F1 for both.')
print('   k-NN struggles when classes share the same feature space region.')
print('\n   SOLUTIONS:')
print('   1. Engineer features that separate S and F (e.g., QRS width for F)')
print('   2. Use confidence thresholding for ambiguous predictions')
print('   3. Consider merging F into V for product simplicity')
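
t-SNE layouts do not preserve distances faithfully, so a visual overlap is worth confirming numerically in the feature space the classifier actually uses. The sketch below is one possible check, not part of the original analysis. The helper `cross_class_neighbor_rate` is hypothetical, and the toy points stand in for `X_test_scaled` and `y_test`. For every point of one class, it measures what fraction of its k nearest neighbors belong to another class.

```python
import math

def cross_class_neighbor_rate(X, y, cls_a, cls_b, k=3):
    """Mean fraction of the k nearest neighbors of each cls_a point that are cls_b."""
    rates = []
    for i, (xi, yi) in enumerate(zip(X, y)):
        if yi != cls_a:
            continue
        # Distances from point i to every other point, nearest first
        dists = sorted((math.dist(xi, xj), y[j]) for j, xj in enumerate(X) if j != i)
        neighbor_labels = [label for _, label in dists[:k]]
        rates.append(neighbor_labels.count(cls_b) / k)
    return sum(rates) / len(rates) if rates else float('nan')

# Toy data (assumption): S and F interleaved, V far away
X = [(0, 0), (1, 1), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
y = ['S', 'S', 'F', 'F', 'V', 'V', 'V']

s_to_f = cross_class_neighbor_rate(X, y, 'S', 'F', k=2)
s_to_v = cross_class_neighbor_rate(X, y, 'S', 'V', k=2)
print(f'S->F: {s_to_f:.2f}, S->V: {s_to_v:.2f}')
```

A rate near the base rate of the other class suggests little local mixing; a much higher rate suggests the two classes occupy the same region. The pairwise loop is quadratic, so for real data a neighbor index would be the more practical choice.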

In [None]:
# ============================================================
# MACRO F1 vs k PLOT
# ============================================================

fig, ax = plt.subplots(figsize=(8, 5))

ks = [r['k'] for r in results]
macro_f1s = [r['macro_f1'] for r in results]
accuracies = [r['accuracy'] for r in results]

ax.plot(ks, macro_f1s, 'bo-', linewidth=2, markersize=10, label='Macro F1 (PRIMARY)')
ax.plot(ks, accuracies, 'g^--', linewidth=2, markersize=8, alpha=0.7, label='Accuracy')

# Highlight best k
best_idx = macro_f1s.index(max(macro_f1s))
ax.scatter([ks[best_idx]], [macro_f1s[best_idx]], color='red', s=200, zorder=5, label=f'Best k={ks[best_idx]}')

ax.set_xlabel('k (Number of Neighbors)', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('k-NN Performance vs. k', fontweight='bold', fontsize=14)
ax.set_xticks(ks)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_PATH / 'knn_k_selection.png', dpi=150)
plt.show()
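
The plot marks Macro F1 as the primary metric. A small worked example, using made-up labels rather than notebook results, shows why the two curves can diverge: a classifier that never predicts the rare class still scores high accuracy, while Macro F1 drops sharply. The helper names below are hypothetical.

```python
def f1_for_class(y_true, y_pred, cls):
    """F1 for one class from true/false positive and false negative counts."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_f1(y_true, y_pred, classes):
    """Unweighted mean of per-class F1, so every class counts equally."""
    return sum(f1_for_class(y_true, y_pred, c) for c in classes) / len(classes)

# Made-up labels (assumption): 8 V beats, 2 F beats, model always predicts V
y_true = ['V'] * 8 + ['F'] * 2
y_pred = ['V'] * 10

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
score = macro_f1(y_true, y_pred, ['V', 'F'])
print(f'accuracy={accuracy:.3f}, macro_f1={score:.3f}')
```

Accuracy is 8/10 here, while Macro F1 averages an F1 of 16/18 for V with an F1 of 0 for F.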

## 9) Reproducibility & Good Practices

### Why Record-Wise Splitting is Required for MIT-BIH

The MIT-BIH database contains **48 half-hour ECG recordings** from 47 patients.

If we split by beats instead of records:
- Beats from the **same patient** could appear in both train and test
- These beats are **highly correlated** (same heart, same morphology)
- Model learns patient-specific patterns, not generalizable arrhythmia features
- **Result:** Inflated metrics that don't reflect real-world performance

**Solution:** Split by `record_id`, and keep records 201 and 202 in the same split because they come from the same subject, so that no patient appears in both sets.
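
A minimal sketch of a patient-level split follows. It is not the notebook's own split code: the helper `split_by_patient`, the `RECORD_TO_PATIENT` mapping, and the short record list are illustrative assumptions. The one grouping it encodes is that MIT-BIH records 201 and 202 come from the same subject.

```python
import random

# Hypothetical mapping: only records that share a subject need an entry.
# In MIT-BIH, records 201 and 202 come from the same subject.
RECORD_TO_PATIENT = {202: 201}

def split_by_patient(record_ids, test_frac=0.2, seed=42):
    """Return (train_ids, test_ids) such that no patient appears in both."""
    patient_of = {r: RECORD_TO_PATIENT.get(r, r) for r in record_ids}
    patients = sorted(set(patient_of.values()))
    random.Random(seed).shuffle(patients)  # seeded for reproducibility
    n_test = max(1, round(test_frac * len(patients)))
    test_patients = set(patients[:n_test])
    train_ids = [r for r in record_ids if patient_of[r] not in test_patients]
    test_ids = [r for r in record_ids if patient_of[r] in test_patients]
    return train_ids, test_ids

# Illustrative subset of record IDs, not the full database
records = [100, 101, 103, 105, 200, 201, 202, 203, 205, 208]
train_ids, test_ids = split_by_patient(records)
print('train:', train_ids)
print('test: ', test_ids)
```

With scikit-learn available, `GroupShuffleSplit` offers the same guarantee when the patient ID is passed as `groups`.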

### Diagnostic Benchmarking vs. Deployment

This notebook is for **benchmarking** — understanding baseline k-NN performance.

For **clinical deployment**, you would need:
- External validation dataset (not MIT-BIH)
- Confidence calibration
- Real-time processing pipeline
- Regulatory approval (FDA, CE marking)
- Continuous monitoring and drift detection

## 10) Final Summary

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print('\n' + '=' * 70)
print('🎯 k-NN ABNORMAL ECG CLASSIFICATION - DIAGNOSTIC SUMMARY')
print('=' * 70)

print(f'''
DATASET:
  Total abnormal beats: {len(df):,}
  Train beats: {len(df_train):,} ({100*len(df_train)/len(df):.1f}%)
  Test beats:  {len(df_test):,} ({100*len(df_test)/len(df):.1f}%)
  Records (train/test): {len(train_records)}/{len(test_records)}
  Normal (N) excluded: ✓

PREPROCESSING:
  Original features: {len(feature_cols)} (including RR-heavy features for S)
  PCA components: {n_components} (explaining {100*explained_var:.1f}% variance)
  
BEST MODEL:
  k = {best_k}
  Macro F1: {best_result['macro_f1']:.4f}
  Accuracy: {best_result['accuracy']:.4f}

PER-CLASS F1 SCORES:
  S (Supraventricular): {best_result['f1_S']:.4f} {'⚠️ RR-dependent' if best_result['f1_S'] < 0.5 else ''}
  V (Ventricular):      {best_result['f1_V']:.4f} {'✓ Strong' if best_result['f1_V'] > 0.8 else ''}
  F (Fusion):           {best_result['f1_F']:.4f} {'❌ Hardest class' if best_result['f1_F'] < 0.3 else ''}
  Q (Unknown/Paced):    {best_result['f1_Q']:.4f}

CONFIDENCE ANALYSIS:
  Low-confidence predictions (<50%): {best_result['low_conf_pct']:.1f}%
  High-confidence accuracy: {best_result['high_conf_acc']:.4f}

STRENGTHS:
  • V class: distinct morphology → high F1
  • Q class: pacing artifacts → moderate F1
  • Simple, interpretable baseline

WEAKNESSES:
  • F class: rare, ambiguous → very low F1
  • S class: depends on RR features, morphology similar to N
  • k-NN doesn't handle imbalance well

KEY MIT-BIH LIMITATIONS:
  • Small dataset (48 recordings, ~100k beats)
  • F class extremely rare (<1% of all beats)
  • Annotations sometimes ambiguous (inter-rater disagreement)
  • Recordings date from 1975-1979 and differ from modern ECGs
  • Record-wise split reduces effective training data

CONCRETE NEXT STEPS:
  1. RR-feature enrichment for S (more prematurity measures)
  2. Confidence thresholding for F (flag, don't force prediction)
  3. Compare raw features vs CNN embeddings as k-NN input
  4. Decide: merge F into V for product, or keep separate?
  5. Validate on INCART or PTB-XL (external datasets)
  6. Try ensemble (k-NN + RF + XGBoost) for robustness

⚠️ REMINDER: This notebook is DIAGNOSTIC.
   k-NN is a baseline, not a production model.
   Do not oversell these results.

✅ Diagnostic Analysis Complete!
''')

In [None]:
# ============================================================
# SAVE RESULTS
# ============================================================

import json

# Save results summary
results_summary = {
    'best_k': int(best_k),
    'macro_f1': float(best_result['macro_f1']),
    'accuracy': float(best_result['accuracy']),
    'per_class_f1': {
        'S': float(best_result['f1_S']),
        'V': float(best_result['f1_V']),
        'F': float(best_result['f1_F']),
        'Q': float(best_result['f1_Q'])
    },
    'all_results': [{
        'k': int(r['k']),
        'macro_f1': float(r['macro_f1']),
        'accuracy': float(r['accuracy'])
    } for r in results],
    'preprocessing': {
        'n_features_original': len(feature_cols),
        'n_pca_components': int(n_components),
        'pca_variance_explained': float(explained_var)
    },
    'dataset': {
        'total_beats': int(len(df)),
        'train_beats': int(len(df_train)),
        'test_beats': int(len(df_test))
    }
}

with open(OUTPUT_PATH / 'knn_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f'\n✅ Results saved to {OUTPUT_PATH}/')
print('   - knn_results.json')
print('   - knn_confusion_matrix.png')
print('   - knn_2d_projections.png')
print('   - knn_sf_overlap.png')
print('   - knn_k_selection.png')