In [20]:
import os
import wfdb
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

def get_record_list(data_dir):
    """
    Read the RECORDS file to get the list of record names.

    Each non-empty line of ``<data_dir>/RECORDS`` is treated as one record
    name. Blank or whitespace-only lines (for example a trailing empty line
    at the end of the file) are skipped so they never turn into an empty
    record name.

    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files

    Returns:
    list: Record names as strings, in the order they appear in the file

    Raises:
    OSError: If the RECORDS file cannot be opened (e.g. it does not exist)
    """
    records_file = os.path.join(data_dir, 'RECORDS')
    with open(records_file, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        # Drop empty entries instead of returning '' as a record name
        return [name for name in stripped_lines if name]

def load_mit_bih_records(data_dir):
    """
    Load every record listed in RECORDS, together with its 'atr' annotations.

    A record that cannot be read is reported on stdout and skipped. A record
    whose annotation file cannot be read is still stored, with its
    'annotations' and 'annotation_counts' entries set to None.

    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files

    Returns:
    dict: {'records': {name: record info}, 'metadata': {...}}. The metadata
          sampling frequency and signal length come from the first record that
          loaded successfully; 'total_records' counts successful loads only.
    """
    # Names of the annotation attributes copied into the per-record dict
    annotation_fields = ('sample', 'symbol', 'subtype', 'chan', 'num', 'aux_note', 'fs')

    names = get_record_list(data_dir)

    db = {
        'records': {},
        'metadata': {
            'sampling_frequency': None,
            'total_records': 0,
            'signal_length': None
        }
    }
    meta = db['metadata']

    for record_name in tqdm(names, desc="Loading ECG records"):
        path = os.path.join(data_dir, record_name)

        try:
            rec = wfdb.rdrecord(path)

            # Annotations are optional: a failure here must not drop the record
            try:
                atr = wfdb.rdann(path, 'atr')
                annotations = {field: getattr(atr, field) for field in annotation_fields}
                symbol_counts = Counter(atr.symbol)
            except Exception as e:
                print(f"Error loading annotations for record {record_name}: {str(e)}")
                annotations = None
                symbol_counts = None

            entry = {
                'signals': rec.p_signal,
                'channels': rec.sig_name,
                'units': rec.units,
                'fs': rec.fs,
                'baseline': rec.baseline,
                'comments': rec.comments,
                'annotations': annotations,
                'annotation_counts': symbol_counts
            }
            db['records'][record_name] = entry

            # Metadata is filled in once, from the first record that loads
            if meta['sampling_frequency'] is None:
                meta['sampling_frequency'] = rec.fs
            if meta['signal_length'] is None:
                meta['signal_length'] = len(rec.p_signal)
            meta['total_records'] += 1

        except Exception as e:
            print(f"Error loading record {record_name}: {str(e)}")
            continue

    return db

def get_record_summary(database):
    """
    Generate a summary of the loaded records including annotation statistics.

    For every record the summary holds its duration, channel count, sampling
    frequency, total number of annotations and the mean / standard deviation
    of each channel (columns ``mean_ch1``, ``std_ch1``, ``mean_ch2``, ...).
    Records may have any number of channels; a 1-D signal array is treated as
    a single channel. Columns that do not apply to a record (e.g. ``mean_ch2``
    for a single-channel record, or an annotation symbol it never uses) are
    NaN in the resulting frame.

    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records

    Returns:
    pd.DataFrame: Summary statistics, one row per record
    """
    summaries = []

    for record_name, record_data in database['records'].items():
        signals = record_data['signals']
        annotations = record_data['annotations']

        # Normalise a 1-D signal to (samples, 1) so channel indexing is uniform
        if signals.ndim == 1:
            signals = signals[:, np.newaxis]
        num_channels = signals.shape[1]

        summary = {
            'record_name': record_name,
            'duration_seconds': len(signals) / record_data['fs'],
            'num_channels': num_channels,
        }

        # Per-channel statistics; channel numbering in the keys is 1-based
        for ch in range(num_channels):
            summary[f'mean_ch{ch + 1}'] = np.mean(signals[:, ch])
            summary[f'std_ch{ch + 1}'] = np.std(signals[:, ch])

        summary['fs'] = record_data['fs']
        summary['total_annotations'] = len(annotations['sample']) if annotations else 0

        # Add annotation type counts if available
        if record_data['annotation_counts']:
            for symbol, count in record_data['annotation_counts'].items():
                summary[f'annotation_{symbol}'] = count

        summaries.append(summary)

    return pd.DataFrame(summaries)

def get_annotations_as_dataframe(record_data):
    """
    Build a per-annotation table for one record.

    Parameters:
    record_data (dict): Record data dictionary; its 'annotations' entry holds
                        the annotation fields, or is empty/None when the
                        record has no annotations

    Returns:
    pd.DataFrame: One row per annotation with columns sample, time, symbol,
                  subtype, channel and aux_note, where time is the sample
                  index divided by the annotation sampling frequency.
                  None when the record carries no annotations.
    """
    annotations = record_data['annotations']
    if not annotations:
        return None

    sample_index = annotations['sample']
    columns = {
        'sample': sample_index,
        # Sample index divided by the sampling frequency
        'time': sample_index / annotations['fs'],
        'symbol': annotations['symbol'],
        'subtype': annotations['subtype'],
        'channel': annotations['chan'],
        'aux_note': annotations['aux_note'],
    }
    return pd.DataFrame(columns)

def get_global_annotation_counts(database):
    """
    Calculate total counts for each annotation symbol across all records.

    Records whose 'annotation_counts' entry is empty or None are ignored.
    When no record contributes any annotation, an empty DataFrame with the
    expected columns is returned instead of raising.

    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records

    Returns:
    pd.DataFrame: Columns 'symbol', 'count', 'percentage' (rounded to two
                  decimals) and 'description', ordered from most to least
                  frequent symbol
    """
    # Accumulate the per-record symbol counters into one
    global_counts = Counter()
    for record_data in database['records'].values():
        if record_data['annotation_counts']:
            global_counts.update(record_data['annotation_counts'])

    # total_annotations is only used as a divisor when at least one symbol
    # exists, so it can never be zero inside the comprehension
    total_annotations = sum(global_counts.values())
    rows = [
        {
            'symbol': symbol,
            'count': count,
            'percentage': (count / total_annotations * 100),
            'description': get_symbol_description(symbol)
        }
        for symbol, count in global_counts.most_common()
    ]

    # Explicit columns keep the frame well-formed even when there are no rows
    df = pd.DataFrame(rows, columns=['symbol', 'count', 'percentage', 'description'])

    # Format percentage column (nothing to round in an empty frame)
    if not df.empty:
        df['percentage'] = df['percentage'].round(2)

    return df

def get_symbol_description(symbol):
    """
    Translate an annotation symbol into a human-readable description.

    Parameters:
    symbol (str): The annotation symbol

    Returns:
    str: Description of the symbol, or 'Unknown annotation type' when the
         symbol is not in the table
    """
    # (symbol, description) pairs, grouped by the kind of event they mark
    beat_symbols = (
        ('N', 'Normal beat'),
        ('L', 'Left bundle branch block beat'),
        ('R', 'Right bundle branch block beat'),
        ('B', 'Bundle branch block beat (unspecified)'),
        ('A', 'Atrial premature beat'),
        ('a', 'Aberrated atrial premature beat'),
        ('J', 'Nodal (junctional) premature beat'),
        ('S', 'Supraventricular premature or ectopic beat'),
        ('V', 'Premature ventricular contraction'),
        ('r', 'R-on-T premature ventricular contraction'),
        ('F', 'Fusion of ventricular and normal beat'),
        ('e', 'Atrial escape beat'),
        ('j', 'Nodal (junctional) escape beat'),
        ('n', 'Supraventricular escape beat (atrial or nodal)'),
        ('E', 'Ventricular escape beat'),
        ('/', 'Paced beat'),
        ('f', 'Fusion of paced and normal beat'),
        ('Q', 'Unclassifiable beat'),
        ('?', 'Beat not classified during learning'),
    )
    other_symbols = (
        ('[', 'Start of ventricular flutter/fibrillation'),
        (']', 'End of ventricular flutter/fibrillation'),
        ('!', 'Ventricular flutter wave'),
        ('x', 'Non-conducted P-wave (blocked APC)'),
        ('(', 'Waveform onset'),
        (')', 'Waveform end'),
        ('p', 'Peak of P-wave'),
        ('t', 'Peak of T-wave'),
        ('u', 'Peak of U-wave'),
        ('`', 'PQ junction'),
        ("'", 'J-point'),
        ('^', 'Non-captured pacemaker artifact'),
        ('|', 'Isolated QRS-like artifact'),
        ('~', 'Change in signal quality'),
        ('+', 'Rhythm change'),
        ('s', 'ST segment change'),
        ('T', 'T-wave change'),
        ('*', 'Systole'),
        ('D', 'Diastole'),
        ('=', 'Measurement annotation'),
        ('"', 'Comment annotation'),
        ('@', 'Link to external data'),
    )

    lookup = dict(beat_symbols + other_symbols)
    try:
        return lookup[symbol]
    except KeyError:
        return 'Unknown annotation type'

if __name__ == "__main__":
    # Entry point for this cell: load every record listed in RECORDS and print
    # the database-wide annotation symbol statistics.

    # Set data directory (a relative path, so it is resolved against the
    # current working directory)
    data_dir = 'src/data/mit-bih-arrhythmia-database-1.0.0'
    
    # Load all records
    print("Loading ECG database...")
    database = load_mit_bih_records(data_dir)
    
    # Print basic information; total_records only counts records that loaded
    # without error
    print(f"\nLoaded {database['metadata']['total_records']} records from RECORDS file")
    
    # Get global annotation statistics (one row per annotation symbol)
    print("\nGlobal Annotation Statistics:")
    annotation_stats = get_global_annotation_counts(database)
    pd.set_option('display.max_rows', None)  # Show all rows
    print(annotation_stats.to_string(index=False))

Loading ECG database...


Loading ECG records: 100%|█████████████████████████████████████████████████████████████| 48/48 [00:05<00:00,  8.47it/s]



Loaded 48 records from RECORDS file

Global Annotation Statistics:
symbol  count  percentage                                description
     N  75052       66.63                                Normal beat
     L   8075        7.17              Left bundle branch block beat
     R   7259        6.44             Right bundle branch block beat
     V   7130        6.33          Premature ventricular contraction
     /   7028        6.24                                 Paced beat
     A   2546        2.26                      Atrial premature beat
     +   1291        1.15                              Rhythm change
     f    982        0.87            Fusion of paced and normal beat
     F    803        0.71      Fusion of ventricular and normal beat
     ~    616        0.55                   Change in signal quality
     !    472        0.42                   Ventricular flutter wave
     "    437        0.39                         Comment annotation
     j    229        0.20          

In [21]:
from scipy import signal
import os
import numpy as np
import pandas as pd
from collections import Counter
import wfdb
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

def prepare_data(data_dir, top_n_labels=5, max_samples_per_class=1000):
    """
    Prepare class-balanced train/validation datasets of spectrogram columns.

    Steps:
      1. Count annotation symbols over all records and keep the
         ``top_n_labels`` most frequent ones.
      2. For every record, compute a log-scaled spectrogram of the first
         channel, normalised with the mean/std of that record's spectrogram.
         Each spectrogram column is labelled with the first annotation located
         at or after the column's segment time; columns after the last
         annotation, or whose label is not one of the kept labels, are dropped.
      3. Randomly keep at most ``max_samples_per_class`` columns per label.
      4. Split 80/20 into train and validation sets, stratified by label.

    Records that cannot be read are reported on stdout and skipped. Labels are
    processed in sorted order during balancing so that, for a fixed NumPy
    random seed, the sampling does not depend on set iteration order.

    Parameters:
    data_dir (str): Directory containing the RECORDS file and the record files
    top_n_labels (int): Number of most frequent annotation symbols to keep
    max_samples_per_class (int): Upper bound on samples kept per label

    Returns:
    tuple: (train_dataset, val_dataset), both ECGDataset instances

    Raises:
    ValueError: If no sample could be collected for any of the kept labels
    """
    # Read record names, ignoring blank lines so '' is never used as a name
    with open(os.path.join(data_dir, 'RECORDS'), 'r') as f:
        record_numbers = [line.strip() for line in f if line.strip()]

    # First pass: collect all labels to find the most common ones
    print("Analyzing label distribution...")
    label_counter = Counter()
    for record_num in tqdm(record_numbers):
        record_path = os.path.join(data_dir, str(record_num))
        try:
            annotations = wfdb.rdann(record_path, 'atr')
        except Exception as e:
            # Report instead of silently skipping; KeyboardInterrupt still propagates
            print(f"Error reading annotations for record {record_num}: {str(e)}")
            continue
        label_counter.update(annotations.symbol)

    # Get top N most common labels
    top_labels = label_counter.most_common(top_n_labels)
    common_labels = {label for label, _ in top_labels}
    print(f"\nMost common labels: {common_labels}")
    print("Label frequencies:", top_labels)

    # Second pass: collect samples for common labels
    all_samples = []
    all_labels = []
    print("\nCollecting samples...")
    for record_num in tqdm(record_numbers):
        record_path = os.path.join(data_dir, str(record_num))
        try:
            record = wfdb.rdrecord(record_path)
            annotations = wfdb.rdann(record_path, 'atr')

            # Get signal from first channel
            signal_data = record.p_signal[:, 0]

            # Create spectrogram
            _, times, Sxx = signal.spectrogram(
                signal_data,
                fs=record.fs,
                window='hann',
                nperseg=1024,
                noverlap=512,
                detrend='constant'
            )

            # Log scale and normalize spectrogram
            Sxx = np.log1p(Sxx)
            Sxx = (Sxx - Sxx.mean()) / (Sxx.std() + 1e-8)

            # Annotation times in seconds, computed once per record, then the
            # index of the first annotation at/after each spectrogram time
            ann_times = annotations.sample / record.fs
            ann_indices = np.searchsorted(ann_times, times)
            symbols = annotations.symbol

            for col, ann_idx in enumerate(ann_indices):
                if ann_idx >= len(symbols):
                    # Column lies after the last annotation: no label available
                    continue
                label = symbols[ann_idx]
                if label in common_labels:
                    all_samples.append(Sxx[:, col].reshape(Sxx.shape[0], 1))
                    all_labels.append(label)

        except Exception as e:
            print(f"Error processing record {record_num}: {str(e)}")
            continue

    if not all_samples:
        raise ValueError(
            f"No samples could be collected from the records in {data_dir!r}"
        )

    # Convert to numpy arrays
    all_samples = np.array(all_samples)
    all_labels = np.array(all_labels)

    # Balance classes (sorted for a deterministic processing order)
    balanced_samples = []
    balanced_labels = []

    for label in sorted(common_labels):
        label_samples = all_samples[all_labels == label]
        if len(label_samples) == 0:
            continue

        # Limit samples per class
        n_samples = min(len(label_samples), max_samples_per_class)
        indices = np.random.choice(len(label_samples), n_samples, replace=False)

        balanced_samples.append(label_samples[indices])
        balanced_labels.extend([label] * n_samples)

    balanced_samples = np.concatenate(balanced_samples, axis=0)
    balanced_labels = np.array(balanced_labels)

    # Split into train and validation
    indices = np.arange(len(balanced_labels))
    train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=balanced_labels)

    # Create datasets
    train_dataset = ECGDataset(balanced_samples[train_idx], balanced_labels[train_idx])
    val_dataset = ECGDataset(balanced_samples[val_idx], balanced_labels[val_idx])

    return train_dataset, val_dataset

class ECGDataset(Dataset):
    """
    Dataset pairing sample arrays with integer-encoded labels.

    The raw labels are encoded with a LabelEncoder fitted on the labels passed
    to this instance; the encoder is kept in ``label_encoder`` so callers can
    map indices back to the original symbols. The label distribution is
    printed when the dataset is constructed.
    """

    def __init__(self, samples, labels):
        self.samples = samples
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(labels)

        # Report how many samples each (decoded) label has
        print("\nDataset label distribution:")
        distribution = Counter(self.labels)
        for encoded, n_items in distribution.items():
            decoded = self.label_encoder.inverse_transform([encoded])[0]
            print(f"Label {decoded}: {n_items} samples")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # float32 sample with a leading channel axis, and a 0-dim int64 label
        features = torch.tensor(self.samples[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features.unsqueeze(0), target


class SimpleCNN(nn.Module):
    """
    Small convolutional classifier for (channels, height, width) inputs.

    Three Conv2d -> ReLU -> BatchNorm2d -> MaxPool2d((2, 1)) stages are
    followed by an adaptive average pool to (8, 1) and a two-layer fully
    connected head that emits ``num_classes`` logits. Pooling only halves the
    height axis, so inputs whose width is 1 keep that width.
    """

    def __init__(self, num_classes, input_channels=1):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding) per conv stage;
        # the padding keeps height and width unchanged through each convolution
        stage_specs = (
            (input_channels, 16, (7, 3), (3, 1)),
            (16, 32, (5, 3), (2, 1)),
            (32, 64, (3, 3), (1, 1)),
        )
        stages = []
        for in_ch, out_ch, kernel, pad in stage_specs:
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, padding=pad),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
                nn.MaxPool2d((2, 1)),
            ])
        self.features = nn.Sequential(*stages)

        # Fixed-size output regardless of the input height
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 1))

        flat_features = 64 * 8 * 1
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        pooled = self.adaptive_pool(self.features(x))
        return self.classifier(pooled)

def compute_class_weights(labels, num_classes=None):
    """
    Compute class weights based on the inverse frequency of each class.

    The weight of a class with ``count`` samples is
    ``total_samples / (number_of_present_classes * count)``. The returned
    tensor is indexed by class id, so it can be passed as ``weight`` to
    ``nn.CrossEntropyLoss``. Class ids that never occur in ``labels`` get a
    weight of 0.0 instead of raising.

    Parameters:
    labels (iterable): Integer class ids (Python ints, NumPy ints, ...)
    num_classes (int, optional): Length of the returned tensor. Defaults to
        ``max(labels) + 1`` (0 when ``labels`` is empty).

    Returns:
    torch.Tensor: Float tensor of shape (num_classes,)

    Raises:
    ValueError: If a label is negative or not smaller than ``num_classes``
    """
    # int() normalises NumPy scalars so equal ids always share one counter key
    class_counts = Counter(int(label) for label in labels)
    total_samples = sum(class_counts.values())

    if num_classes is None:
        num_classes = max(class_counts) + 1 if class_counts else 0

    out_of_range = [cls for cls in class_counts if cls < 0 or cls >= num_classes]
    if out_of_range:
        raise ValueError(
            f"Labels {sorted(out_of_range)} are outside the range [0, {num_classes})"
        )

    num_present = len(class_counts)
    class_weights = torch.zeros(num_classes, dtype=torch.float)
    for cls, count in class_counts.items():
        class_weights[cls] = total_samples / (num_present * count)
    return class_weights


def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda'):
    """
    Train ``model`` with a class-weighted cross-entropy loss.

    Class weights are derived from the labels of ``train_loader.dataset``.
    The model is optimised with Adam (lr=0.001). After every epoch the mean
    batch loss and the accuracy on both loaders are printed, and the model's
    state_dict is written to 'best_model.pth' whenever the validation accuracy
    exceeds the best value seen so far.

    Parameters:
    model (nn.Module): Network to train; it is moved to ``device``
    train_loader (DataLoader): Yields (inputs, labels) training batches
    val_loader (DataLoader): Yields (inputs, labels) validation batches
    num_epochs (int): Number of passes over the training data
    device (str or torch.device): Device used for the model and the batches

    Returns:
    None
    """
    def _count_hits(outputs, labels):
        # (number of correct arg-max predictions, batch size)
        predicted = outputs.max(1)[1]
        return predicted.eq(labels).sum().item(), labels.size(0)

    # Class weights come from every label in the training dataset
    targets = [label.item() for _, label in train_loader.dataset]
    class_weights = compute_class_weights(targets).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = model.to(device)
    best_val_acc = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = train_correct = train_total = 0

        for spectrograms, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            hits, seen = _count_hits(outputs, labels)
            train_correct += hits
            train_total += seen

        # ---- validation pass (no gradient tracking) ----
        model.eval()
        val_loss = val_correct = val_total = 0

        with torch.no_grad():
            for spectrograms, labels in val_loader:
                spectrograms = spectrograms.to(device)
                labels = labels.to(device)
                outputs = model(spectrograms)

                val_loss += criterion(outputs, labels).item()
                hits, seen = _count_hits(outputs, labels)
                val_correct += hits
                val_total += seen

        # Per-epoch report
        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {100.*train_correct/train_total:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {100.*val_correct/val_total:.2f}%')

        # Checkpoint only on a strict improvement of the validation accuracy
        val_acc = 100. * val_correct / val_total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')


if __name__ == "__main__":
    # Entry point for this cell: build the balanced datasets, create the CNN
    # and train it.

    # Set data directory (a relative path, so it is resolved against the
    # current working directory)
    data_dir = 'src/data/mit-bih-arrhythmia-database-1.0.0'
    
    # Prepare datasets with top 5 most common labels
    train_dataset, val_dataset = prepare_data(data_dir, top_n_labels=5)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create data loaders (only the training loader is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    
    # Initialize model; the number of outputs is the number of classes the
    # training dataset's label encoder learned
    num_classes = len(train_dataset.label_encoder.classes_)
    model = SimpleCNN(num_classes)
    
    # Print the shape of a single input sample (not a full model summary)
    sample_data, _ = train_dataset[0]
    print(f"\nInput shape: {sample_data.shape}")
    
    # Train model on the GPU when one is available, otherwise on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_model(model, train_loader, val_loader, num_epochs=10, device=device)

Analyzing label distribution...


100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:03<00:00, 15.24it/s]



Most common labels: {'R', 'V', 'N', '/', 'L'}
Label frequencies: [('N', 75052), ('L', 8075), ('R', 7259), ('V', 7130), ('/', 7028)]

Collecting samples...


100%|██████████████████████████████████████████████████████████████████████████████████| 48/48 [00:11<00:00,  4.20it/s]



Dataset label distribution:
Label /: 800 samples
Label L: 800 samples
Label N: 800 samples
Label V: 800 samples
Label R: 800 samples

Dataset label distribution:
Label L: 200 samples
Label R: 200 samples
Label V: 200 samples
Label /: 200 samples
Label N: 200 samples
Training samples: 4000
Validation samples: 1000

Input shape: torch.Size([1, 513, 1])


Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.43it/s]



Epoch 1:
Train Loss: 0.7110, Train Acc: 72.97%
Val Loss: 0.3538, Val Acc: 88.90%


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.58it/s]



Epoch 2:
Train Loss: 0.3308, Train Acc: 88.70%
Val Loss: 0.3983, Val Acc: 87.10%


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 18.17it/s]



Epoch 3:
Train Loss: 0.2631, Train Acc: 90.60%
Val Loss: 0.2169, Val Acc: 92.00%


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 18.08it/s]



Epoch 4:
Train Loss: 0.2104, Train Acc: 93.12%
Val Loss: 0.2322, Val Acc: 91.80%


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.56it/s]



Epoch 5:
Train Loss: 0.1887, Train Acc: 93.50%
Val Loss: 0.1829, Val Acc: 93.50%


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.79it/s]



Epoch 6:
Train Loss: 0.1764, Train Acc: 93.83%
Val Loss: 0.2825, Val Acc: 91.50%


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.51it/s]



Epoch 7:
Train Loss: 0.1388, Train Acc: 95.10%
Val Loss: 0.2093, Val Acc: 93.40%


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.67it/s]



Epoch 8:
Train Loss: 0.1261, Train Acc: 95.55%
Val Loss: 0.1772, Val Acc: 93.90%


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████| 125/125 [00:06<00:00, 17.95it/s]



Epoch 9:
Train Loss: 0.1302, Train Acc: 95.38%
Val Loss: 0.2541, Val Acc: 91.60%


Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████| 125/125 [00:07<00:00, 17.63it/s]



Epoch 10:
Train Loss: 0.1210, Train Acc: 95.62%
Val Loss: 0.2066, Val Acc: 94.40%


In [22]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

def analyze_model_predictions(model, data_loader, device, label_encoder):
    """
    Run the model over a data loader and analyse its predictions.

    Predictions, true labels and softmax probabilities are gathered for every
    sample. Each misclassified sample is additionally recorded with its true
    and predicted class index, the probability assigned to the predicted
    class, the batch index and its position inside that batch. Everything is
    then handed to analyze_results, whose return value is returned unchanged.
    """
    model.eval()
    pred_log, label_log, prob_log = [], [], []
    mistakes = []

    with torch.no_grad():
        batches = tqdm(data_loader, desc="Analyzing predictions")
        for batch_idx, (spectrograms, labels) in enumerate(batches):
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)

            logits = model(spectrograms)
            probabilities = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)

            # Accumulate per-sample results on the CPU
            pred_log.extend(preds.cpu().numpy())
            label_log.extend(labels.cpu().numpy())
            prob_log.extend(probabilities.cpu().numpy())

            # Positions inside this batch where the prediction is wrong
            # (an empty tensor when the whole batch is correct)
            wrong_positions = torch.where(preds != labels)[0]
            for pos in wrong_positions:
                mistakes.append({
                    'true_label': labels[pos].item(),
                    'predicted_label': preds[pos].item(),
                    'confidence': probabilities[pos][preds[pos]].item(),
                    'batch_idx': batch_idx,
                    'sample_idx': pos.item()
                })

    return analyze_results(pred_log, label_log, prob_log, mistakes, label_encoder)

def analyze_results(all_preds, all_labels, all_probs, incorrect_samples, label_encoder):
    """
    Generate comprehensive analysis of model predictions.

    Side effects: writes 'confusion_matrix.png' and
    'confidence_distribution.png' to the current working directory.

    Parameters:
    all_preds (sequence): Predicted class index for every sample
    all_labels (sequence): True class index for every sample
    all_probs (sequence): Per-sample class probability vectors
    incorrect_samples (list): Dicts with 'true_label' and 'predicted_label'
        class indices (and further details) for each misclassified sample
    label_encoder: Fitted encoder whose ``classes_`` maps index -> class name

    Returns:
    dict: 'confusion_matrix' (rows/columns ordered like ``classes_``),
          'classification_report' (DataFrame), 'confidence_analysis',
          'confusion_pairs' (DataFrame), 'summary_stats',
          'incorrect_samples' and 'class_names' (index -> name list)
    """
    # Convert numerical labels to original classes and ensure they're regular Python strings
    class_names = [str(name) for name in label_encoder.classes_]
    true_classes = [class_names[label] for label in all_labels]
    pred_classes = [class_names[pred] for pred in all_preds]

    # 1. Generate confusion matrix
    # labels= pins the matrix to one row/column per known class, so it always
    # matches the tick labels even when a class is absent from this data
    cm = confusion_matrix(true_classes, pred_classes, labels=class_names)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

    # 2. Generate classification report
    report = classification_report(true_classes, pred_classes, output_dict=True)
    report_df = pd.DataFrame(report).transpose()

    # 3. Analyze prediction confidence (probability given to the predicted class)
    confidence_analysis = {
        'correct_conf': [],
        'incorrect_conf': []
    }

    for i in range(len(all_preds)):
        conf = all_probs[i][all_preds[i]]
        if all_preds[i] == all_labels[i]:
            confidence_analysis['correct_conf'].append(conf)
        else:
            confidence_analysis['incorrect_conf'].append(conf)

    # Plot confidence distributions
    plt.figure(figsize=(10, 6))
    plt.hist(confidence_analysis['correct_conf'], alpha=0.5, label='Correct', bins=50)
    plt.hist(confidence_analysis['incorrect_conf'], alpha=0.5, label='Incorrect', bins=50)
    plt.title('Prediction Confidence Distribution')
    plt.xlabel('Confidence')
    plt.ylabel('Count')
    plt.legend()
    plt.tight_layout()
    plt.savefig('confidence_distribution.png')
    plt.close()

    # 4. Analyze most common confusion pairs
    confusion_pairs = []
    for sample in incorrect_samples:
        true_class = class_names[sample['true_label']]
        pred_class = class_names[sample['predicted_label']]
        confusion_pairs.append((true_class, pred_class))

    # Handle confusion pairs
    if confusion_pairs:  # Only process if there are incorrect predictions
        confusion_series = pd.Series(confusion_pairs)
        confusion_values = confusion_series.value_counts()
        confusion_counts = pd.DataFrame([
            {'True Class': str(true), 'Predicted Class': str(pred), 'Count': count}
            for (true, pred), count in confusion_values.items()
        ])
    else:
        confusion_counts = pd.DataFrame(columns=['True Class', 'Predicted Class', 'Count'])

    # 5. Generate summary statistics
    # Classes missing from the report are simply left out
    class_performance = {}
    for class_name in class_names:
        str_class_name = str(class_name)
        if str_class_name in report:
            class_performance[str_class_name] = {
                'precision': report[str_class_name]['precision'],
                'recall': report[str_class_name]['recall'],
                'f1-score': report[str_class_name]['f1-score'],
                'support': report[str_class_name]['support']
            }

    summary_stats = {
        'Overall Accuracy': report['accuracy'],
        'Most Confused Pairs': confusion_counts.head(5).to_dict('records'),
        'Per Class Performance': class_performance,
        'Average Confidence': {
            'Correct Predictions': np.mean(confidence_analysis['correct_conf']) if confidence_analysis['correct_conf'] else 0.0,
            'Incorrect Predictions': np.mean(confidence_analysis['incorrect_conf']) if confidence_analysis['incorrect_conf'] else 0.0
        }
    }

    return {
        'confusion_matrix': cm,
        'classification_report': report_df,
        'confidence_analysis': confidence_analysis,
        'confusion_pairs': confusion_counts,
        'summary_stats': summary_stats,
        'incorrect_samples': incorrect_samples,
        # Index -> class name mapping for decoding the integer labels above
        'class_names': class_names
    }

def print_analysis_results(results):
    """
    Print formatted analysis results.

    Prints the overall accuracy, the per-class report, the most common
    confusion pairs, the average confidences and up to five misclassified
    samples whose confidence exceeds 0.8, most confident first.

    Parameters:
    results (dict): Analysis dictionary with the keys 'summary_stats',
        'classification_report' and 'incorrect_samples'. If it also has a
        'class_names' list (index -> class name), that list is used to decode
        the integer labels of the mistakes; otherwise the row labels of the
        classification report are used.
    """
    summary = results['summary_stats']
    report = results['classification_report']

    # Decode class indices through an explicit mapping when one is provided.
    # The report's row order only matches the class indices if every class
    # appears in the report, so it is used purely as a fallback.
    class_names = results.get('class_names')
    if class_names is None:
        class_names = list(report.index)

    print("\n=== ECG Classification Model Analysis ===\n")

    # Overall Performance
    print("Overall Performance:")
    print(f"Accuracy: {summary['Overall Accuracy']:.4f}")
    print("\n=== Per-Class Performance ===")
    print(report)

    # Most Common Confusion Pairs
    print("\nTop 5 Most Common Confusion Pairs:")
    for pair in summary['Most Confused Pairs']:
        print(f"True: {pair['True Class']} → Predicted: {pair['Predicted Class']} "
              f"(Count: {pair['Count']})")

    # Confidence Analysis
    print("\nConfidence Analysis:")
    print(f"Average confidence for correct predictions: "
          f"{summary['Average Confidence']['Correct Predictions']:.4f}")
    print(f"Average confidence for incorrect predictions: "
          f"{summary['Average Confidence']['Incorrect Predictions']:.4f}")

    # Examples of High-Confidence Mistakes
    print("\nHigh-Confidence Mistakes:")
    high_conf_mistakes = sorted(
        [x for x in results['incorrect_samples'] if x['confidence'] > 0.8],
        key=lambda x: x['confidence'],
        reverse=True
    )[:5]

    for mistake in high_conf_mistakes:
        true_label = class_names[mistake['true_label']]
        pred_label = class_names[mistake['predicted_label']]
        print(f"True: {true_label} → Predicted: {pred_label} "
              f"(Confidence: {mistake['confidence']:.4f})")

# Usage example
if __name__ == "__main__":
    # This cell does not define `model`, `val_loader` or `train_dataset`; it
    # relies on them already existing in the notebook's global namespace
    # (they are created by the training cell).
    # NOTE(review): `model` is whatever state training ended in; the
    # 'best_model.pth' checkpoint is not loaded here — confirm this is intended.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Evaluate on the validation loader, decoding class indices with the
    # training dataset's label encoder
    results = analyze_model_predictions(model, val_loader, device, train_dataset.label_encoder)
    print_analysis_results(results)

Analyzing predictions: 100%|███████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 41.51it/s]



=== ECG Classification Model Analysis ===

Overall Performance:
Accuracy: 0.9440

=== Per-Class Performance ===
              precision  recall  f1-score   support
/              0.980392   1.000  0.990099   200.000
L              0.979275   0.945  0.961832   200.000
N              0.912371   0.885  0.898477   200.000
R              0.964286   0.945  0.954545   200.000
V              0.887324   0.945  0.915254   200.000
accuracy       0.944000   0.944  0.944000     0.944
macro avg      0.944730   0.944  0.944042  1000.000
weighted avg   0.944730   0.944  0.944042  1000.000

Top 5 Most Common Confusion Pairs:
True: N → Predicted: V (Count: 14)
True: V → Predicted: N (Count: 10)
True: N → Predicted: R (Count: 6)
True: R → Predicted: V (Count: 6)
True: R → Predicted: N (Count: 5)

Confidence Analysis:
Average confidence for correct predictions: 0.9751
Average confidence for incorrect predictions: 0.7657

High-Confidence Mistakes:
True: V → Predicted: N (Confidence: 1.0000)
True: V → Pred