In [5]:
import os
import wfdb
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

def get_record_list(data_dir):
    """
    Read the RECORDS file to get the list of record numbers.
    
    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files
    
    Returns:
    list: List of record numbers as strings
    """
    records_file = os.path.join(data_dir, 'RECORDS')
    with open(records_file, 'r') as f:
        return [line.strip() for line in f]

def load_mit_bih_records(data_dir):
    """
    Load all MIT-BIH Arrhythmia Database records and annotations from the specified directory.
    
    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files
    
    Returns:
    dict: Dictionary containing record information, signals, and annotations
    """
    # Get record numbers from RECORDS file
    record_numbers = get_record_list(data_dir)
    
    # Dictionary to store all records
    database = {
        'records': {},
        'metadata': {
            'sampling_frequency': None,
            'total_records': 0,
            'signal_length': None
        }
    }
    
    # Load each record
    for record_name in tqdm(record_numbers, desc="Loading ECG records"):
        record_path = os.path.join(data_dir, record_name)
        
        try:
            # Read the record
            record = wfdb.rdrecord(record_path)
            
            # Read the annotations
            try:
                ann = wfdb.rdann(record_path, 'atr')
                annotations = {
                    'sample': ann.sample,
                    'symbol': ann.symbol,
                    'subtype': ann.subtype,
                    'chan': ann.chan,
                    'num': ann.num,
                    'aux_note': ann.aux_note,
                    'fs': ann.fs
                }
                
                # Calculate annotation statistics
                symbol_counts = Counter(ann.symbol)
                
            except Exception as e:
                print(f"Error loading annotations for record {record_name}: {str(e)}")
                annotations = None
                symbol_counts = None
            
            # Store record information
            database['records'][record_name] = {
                'signals': record.p_signal,
                'channels': record.sig_name,
                'units': record.units,
                'fs': record.fs,
                'baseline': record.baseline,
                'comments': record.comments,
                'annotations': annotations,
                'annotation_counts': symbol_counts
            }
            
            # Update metadata
            if database['metadata']['sampling_frequency'] is None:
                database['metadata']['sampling_frequency'] = record.fs
            if database['metadata']['signal_length'] is None:
                database['metadata']['signal_length'] = len(record.p_signal)
            database['metadata']['total_records'] += 1
            
        except Exception as e:
            print(f"Error loading record {record_name}: {str(e)}")
            continue
    
    return database

def get_record_summary(database):
    """
    Generate a summary of the loaded records including annotation statistics.
    
    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records
    
    Returns:
    pd.DataFrame: Summary statistics for each record
    """
    summaries = []
    
    for record_name, record_data in database['records'].items():
        signals = record_data['signals']
        annotations = record_data['annotations']
        
        summary = {
            'record_name': record_name,
            'duration_seconds': len(signals) / record_data['fs'],
            'num_channels': signals.shape[1],
            'mean_ch1': np.mean(signals[:, 0]),
            'std_ch1': np.std(signals[:, 0]),
            'mean_ch2': np.mean(signals[:, 1]),
            'std_ch2': np.std(signals[:, 1]),
            'fs': record_data['fs'],
            'total_annotations': len(annotations['sample']) if annotations else 0
        }
        
        # Add annotation type counts if available
        if record_data['annotation_counts']:
            for symbol, count in record_data['annotation_counts'].items():
                summary[f'annotation_{symbol}'] = count
                
        summaries.append(summary)
    
    return pd.DataFrame(summaries)

def get_annotations_as_dataframe(record_data):
    """
    Convert annotations for a single record into a pandas DataFrame.
    
    Parameters:
    record_data (dict): Record data dictionary containing annotations
    
    Returns:
    pd.DataFrame: DataFrame containing all annotations with time information
    """
    if not record_data['annotations']:
        return None
        
    ann = record_data['annotations']
    df = pd.DataFrame({
        'sample': ann['sample'],
        'time': ann['sample'] / ann['fs'],
        'symbol': ann['symbol'],
        'subtype': ann['subtype'],
        'channel': ann['chan'],
        'aux_note': ann['aux_note']
    })
    
    return df

def get_global_annotation_counts(database):
    """
    Calculate total counts for each annotation symbol across all records.
    
    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records
    
    Returns:
    pd.DataFrame: DataFrame with symbol counts, percentages, and descriptions
    """
    # Initialize a Counter for all symbols
    global_counts = Counter()
    
    # Count symbols across all records
    for record_data in database['records'].values():
        if record_data['annotation_counts']:
            global_counts.update(record_data['annotation_counts'])
    
    # Create DataFrame with counts and percentages
    total_annotations = sum(global_counts.values())
    df = pd.DataFrame([
        {
            'symbol': symbol,
            'count': count,
            'percentage': (count / total_annotations * 100),
            'description': get_symbol_description(symbol)
        }
        for symbol, count in global_counts.most_common()
    ])
    
    # Format percentage column
    df['percentage'] = df['percentage'].round(2)
    
    return df

def get_symbol_description(symbol):
    """
    Get the description for each annotation symbol.
    
    Parameters:
    symbol (str): The annotation symbol
    
    Returns:
    str: Description of the symbol
    """
    descriptions = {
        'N': 'Normal beat',
        'L': 'Left bundle branch block beat',
        'R': 'Right bundle branch block beat',
        'B': 'Bundle branch block beat (unspecified)',
        'A': 'Atrial premature beat',
        'a': 'Aberrated atrial premature beat',
        'J': 'Nodal (junctional) premature beat',
        'S': 'Supraventricular premature or ectopic beat',
        'V': 'Premature ventricular contraction',
        'r': 'R-on-T premature ventricular contraction',
        'F': 'Fusion of ventricular and normal beat',
        'e': 'Atrial escape beat',
        'j': 'Nodal (junctional) escape beat',
        'n': 'Supraventricular escape beat (atrial or nodal)',
        'E': 'Ventricular escape beat',
        '/': 'Paced beat',
        'f': 'Fusion of paced and normal beat',
        'Q': 'Unclassifiable beat',
        '?': 'Beat not classified during learning',
        '[': 'Start of ventricular flutter/fibrillation',
        ']': 'End of ventricular flutter/fibrillation',
        '!': 'Ventricular flutter wave',
        'x': 'Non-conducted P-wave (blocked APC)',
        '(': 'Waveform onset',
        ')': 'Waveform end',
        'p': 'Peak of P-wave',
        't': 'Peak of T-wave',
        'u': 'Peak of U-wave',
        '`': 'PQ junction',
        '\'': 'J-point',
        '^': 'Non-captured pacemaker artifact',
        '|': 'Isolated QRS-like artifact',
        '~': 'Change in signal quality',
        '+': 'Rhythm change',
        's': 'ST segment change',
        'T': 'T-wave change',
        '*': 'Systole',
        'D': 'Diastole',
        '=': 'Measurement annotation',
        '"': 'Comment annotation',
        '@': 'Link to external data'
    }
    return descriptions.get(symbol, 'Unknown annotation type')

if __name__ == "__main__":
    # Set data directory
    data_dir = '../data/mit-bih-arrhythmia-database-1.0.0'
    
    # Load all records
    print("Loading ECG database...")
    database = load_mit_bih_records(data_dir)
    
    # Print basic information
    print(f"\nLoaded {database['metadata']['total_records']} records from RECORDS file")
    
    # Get global annotation statistics
    print("\nGlobal Annotation Statistics:")
    annotation_stats = get_global_annotation_counts(database)
    pd.set_option('display.max_rows', None)  # Show all rows
    print(annotation_stats.to_string(index=False))

Loading ECG database...


Loading ECG records: 100%|█████████████████████████████████████████████████████████████| 48/48 [00:07<00:00,  6.56it/s]


Loaded 48 records from RECORDS file

Global Annotation Statistics:
symbol  count  percentage                                description
     N  75052       66.63                                Normal beat
     L   8075        7.17              Left bundle branch block beat
     R   7259        6.44             Right bundle branch block beat
     V   7130        6.33          Premature ventricular contraction
     /   7028        6.24                                 Paced beat
     A   2546        2.26                      Atrial premature beat
     +   1291        1.15                              Rhythm change
     f    982        0.87            Fusion of paced and normal beat
     F    803        0.71      Fusion of ventricular and normal beat
     ~    616        0.55                   Change in signal quality
     !    472        0.42                   Ventricular flutter wave
     "    437        0.39                         Comment annotation
     j    229        0.20          




In [7]:
from scipy import signal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

class ECGSpectrogramDataset(Dataset):
    def __init__(self, data_dir, record_numbers, window_size=1024, overlap=512):
        self.data_dir = data_dir
        self.record_numbers = record_numbers
        self.window_size = window_size
        self.overlap = overlap
        self.spectrograms = []
        self.labels = []
        
        self._load_data()
        
    def _load_data(self):
        """Load ECG data and create spectrograms"""
        for record_num in tqdm(self.record_numbers, desc="Loading data"):
            # Load record and annotations
            record_path = os.path.join(self.data_dir, str(record_num))
            record = wfdb.rdrecord(record_path)
            annotations = wfdb.rdann(record_path, 'atr')
            
            # Get signal from first channel
            signal_data = record.p_signal[:, 0]
            
            # Create spectrogram
            frequencies, times, Sxx = signal.spectrogram(
                signal_data,
                fs=record.fs,
                window='hann',
                nperseg=self.window_size,
                noverlap=self.overlap,
                detrend='constant'
            )
            
            # Log scale spectrogram
            Sxx = np.log1p(Sxx)
            
            # Normalize spectrogram
            Sxx = (Sxx - Sxx.mean()) / (Sxx.std() + 1e-8)
            
            # Split spectrogram into segments and match with annotations
            segment_duration = times[1] - times[0]
            num_freq_bins = Sxx.shape[0]
            
            for i in range(len(times)):
                time_point = times[i]
                
                # Find the closest annotation
                ann_idx = np.searchsorted(annotations.sample / record.fs, time_point)
                if ann_idx < len(annotations.symbol):
                    label = annotations.symbol[ann_idx]
                    
                    # Store spectrogram segment and label
                    self.spectrograms.append(Sxx[:, i].reshape(num_freq_bins, 1))
                    self.labels.append(label)
        
        # Convert to numpy arrays
        self.spectrograms = np.array(self.spectrograms)
        
        # Encode labels
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.labels)
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Shape: [num_freq_bins, 1]
        spectrogram = torch.FloatTensor(self.spectrograms[idx])
        label = torch.LongTensor([self.labels[idx]])[0]
        return spectrogram.unsqueeze(0), label  # Add channel dimension [1, num_freq_bins, 1]

class SimpleCNN(nn.Module):
    def __init__(self, num_classes, input_channels=1):
        super(SimpleCNN, self).__init__()
        
        self.features = nn.Sequential(
            # First conv block - careful with padding to handle the 1D-like input
            nn.Conv2d(input_channels, 16, kernel_size=(7, 3), padding=(3, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d((2, 1)),
            
            # Second conv block
            nn.Conv2d(16, 32, kernel_size=(5, 3), padding=(2, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d((2, 1)),
            
            # Third conv block
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d((2, 1)),
        )
        
        # Adaptive pooling to handle variable input sizes
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 1))
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 1, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = self.classifier(x)
        return x

def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda'):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    model = model.to(device)
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (spectrograms, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
            
            if batch_idx == 0:  # Print shapes for first batch to verify dimensions
                print(f"Batch shapes - Input: {spectrograms.shape}, Output: {outputs.shape}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for spectrograms, labels in val_loader:
                spectrograms, labels = spectrograms.to(device), labels.to(device)
                outputs = model(spectrograms)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        # Print epoch statistics
        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Train Acc: {100.*train_correct/train_total:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Acc: {100.*val_correct/val_total:.2f}%')
        
        # Save best model
        val_acc = 100. * val_correct / val_total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

if __name__ == "__main__":
    # Set data directory
    data_dir = '../data/mit-bih-arrhythmia-database-1.0.0'
    
    # Read record numbers from RECORDS file
    with open(os.path.join(data_dir, 'RECORDS'), 'r') as f:
        record_numbers = [line.strip() for line in f]

    # Count the frequency of annotations in each record
    record_frequencies = {}
    for record_num in record_numbers:
        record_path = os.path.join(data_dir, str(record_num))
        try:
            annotations = wfdb.rdann(record_path, 'atr')
            record_frequencies[record_num] = len(annotations.symbol)  # Count of annotations
        except:
            continue  # Skip if annotation file is missing

    # Select top 5 most frequent records
    top_records = [record for record, _ in Counter(record_frequencies).most_common(5)]

    # Split records into train and validation sets
    train_records, val_records = train_test_split(top_records, test_size=0.2, random_state=42)
        
    # Create datasets
    train_dataset = ECGSpectrogramDataset(data_dir, train_records)
    val_dataset = ECGSpectrogramDataset(data_dir, val_records)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    
    # Initialize model
    num_classes = len(train_dataset.label_encoder.classes_)
    model = SimpleCNN(num_classes)
    
    # Print model summary
    sample_data, _ = train_dataset[0]
    print(f"\nInput shape: {sample_data.shape}")
    
    # Train model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_model(model, train_loader, val_loader, num_epochs=10, device=device)

Loading data: 100%|████████████████████████████████████████████████████████████████████| 38/38 [00:11<00:00,  3.27it/s]
Loading data: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.95it/s]


Training samples: 48184
Validation samples: 12680

Input shape: torch.Size([1, 513, 1])


Epoch 1/10:   0%|▏                                                                    | 3/1506 [00:00<01:56, 12.93it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [01:32<00:00, 16.27it/s]



Epoch 1:
Train Loss: 0.4309, Train Acc: 87.76%
Val Loss: 8.5396, Val Acc: 0.09%


Epoch 2/10:   0%|                                                                     | 2/1506 [00:00<01:32, 16.19it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [01:31<00:00, 16.51it/s]



Epoch 2:
Train Loss: 0.2620, Train Acc: 92.73%
Val Loss: 10.7438, Val Acc: 0.32%


Epoch 3/10:   0%|                                                                     | 2/1506 [00:00<01:37, 15.49it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [01:32<00:00, 16.36it/s]



Epoch 3:
Train Loss: 0.2353, Train Acc: 93.38%
Val Loss: 11.0026, Val Acc: 0.11%


Epoch 4/10:   0%|                                                                     | 2/1506 [00:00<01:30, 16.59it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 4/10:   4%|██▍                                                                 | 55/1506 [00:03<01:26, 16.84it/s]


KeyboardInterrupt: 

In [8]:
from scipy import signal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class EnhancedECGSpectrogramDataset(Dataset):
    def __init__(self, data_dir, record_numbers, window_size=2048, overlap=1024):
        self.data_dir = data_dir
        self.record_numbers = record_numbers
        self.window_size = window_size
        self.overlap = overlap
        self.spectrograms = []
        self.labels = []
        
    def _load_data(self):
        for record_num in tqdm(self.record_numbers, desc="Loading data"):
            record_path = os.path.join(self.data_dir, str(record_num))
            record = wfdb.rdrecord(record_path)
            annotations = wfdb.rdann(record_path, 'atr')
            
            # Use both channels instead of just first channel
            for channel in range(record.p_signal.shape[1]):
                signal_data = record.p_signal[:, channel]
                
                # Apply bandpass filter to remove noise
                nyquist = record.fs / 2
                low = 0.5 / nyquist
                high = 50.0 / nyquist
                b, a = signal.butter(4, [low, high], btype='band')
                filtered_signal = signal.filtfilt(b, a, signal_data)
                
                # Create spectrogram with modified parameters
                frequencies, times, Sxx = signal.spectrogram(
                    filtered_signal,
                    fs=record.fs,
                    window='hamming',  # Changed from hann
                    nperseg=self.window_size,
                    noverlap=self.overlap,
                    detrend='linear',  # Changed from constant
                    scaling='density'  # Added scaling parameter
                )
                
                # Apply more sophisticated normalization
                Sxx = np.log1p(Sxx)  # Log scale
                Sxx = (Sxx - np.percentile(Sxx, 1)) / (np.percentile(Sxx, 99) - np.percentile(Sxx, 1))
                Sxx = np.clip(Sxx, 0, 1)  # Clip outliers
                
                # Include context windows
                context_size = 3
                for i in range(context_size, len(times) - context_size):
                    context_window = Sxx[:, i-context_size:i+context_size+1]
                    
                    # Find the closest annotation
                    time_point = times[i]
                    ann_idx = np.searchsorted(annotations.sample / record.fs, time_point)
                    if ann_idx < len(annotations.symbol):
                        label = annotations.symbol[ann_idx]
                        self.spectrograms.append(context_window)
                        self.labels.append(label)

class ImprovedCNN(nn.Module):
    def __init__(self, num_classes, input_channels=1):
        super(ImprovedCNN, self).__init__()
        
        # Residual block definition
        def residual_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels)
            )
        
        self.features = nn.ModuleList([
            # Initial convolution
            nn.Sequential(
                nn.Conv2d(input_channels, 32, kernel_size=7, padding=3),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(2)
            ),
            
            # Residual blocks with increasing channels
            residual_block(32, 32),
            residual_block(32, 64),
            residual_block(64, 128),
            
            # SE block for channel attention
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(128, 64, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=1),
                nn.Sigmoid()
            )
        ])
        
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        # Apply feature extraction with residual connections
        identity = x
        for i, layer in enumerate(self.features[:-1]):  # Exclude SE block
            x = layer(x)
            if i > 0 and i % 2 == 0:  # Add residual every 2 blocks
                if identity.shape == x.shape:
                    x += identity
                identity = x
        
        # Apply SE block
        se_weights = self.features[-1](x)
        x = x * se_weights
        
        # Global pooling and classification
        x = self.global_pool(x)
        x = self.classifier(x)
        return x

def train_model_with_scheduler(model, train_loader, val_loader, num_epochs=20, device='cuda'):
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Added label smoothing
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    
    # Cosine annealing scheduler with warm restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=2, eta_min=1e-6
    )
    
    # Initialize mixup
    alpha = 0.2
    
    model = model.to(device)
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        
        for spectrograms, labels in train_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            
            # Apply mixup
            if np.random.random() > 0.5:
                lam = np.random.beta(alpha, alpha)
                index = torch.randperm(spectrograms.size(0)).to(device)
                mixed_spectrograms = lam * spectrograms + (1 - lam) * spectrograms[index]
                
                outputs = model(mixed_spectrograms)
                loss = lam * criterion(outputs, labels) + (1 - lam) * criterion(outputs, labels[index])
            else:
                outputs = model(spectrograms)
                loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            train_loss += loss.item()
        
        scheduler.step()
        
        # Validation logic remains the same...
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for spectrograms, labels in val_loader:
                spectrograms, labels = spectrograms.to(device), labels.to(device)
                outputs = model(spectrograms)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        # Print epoch statistics
        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Train Acc: {100.*train_correct/train_total:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Acc: {100.*val_correct/val_total:.2f}%')
        
        # Save best model
        val_acc = 100. * val_correct / val_total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

if __name__ == "__main__":
    # Set data directory
    data_dir = '../data/mit-bih-arrhythmia-database-1.0.0'
    
    # Read record numbers from RECORDS file
    with open(os.path.join(data_dir, 'RECORDS'), 'r') as f:
        record_numbers = [line.strip() for line in f]

    # Count the frequency of annotations in each record
    record_frequencies = {}
    for record_num in record_numbers:
        record_path = os.path.join(data_dir, str(record_num))
        try:
            annotations = wfdb.rdann(record_path, 'atr')
            record_frequencies[record_num] = len(annotations.symbol)  # Count of annotations
        except:
            continue  # Skip if annotation file is missing

    # Select top 5 most frequent records
    top_records = [record for record, _ in Counter(record_frequencies).most_common(5)]

    # Split records into train and validation sets
    train_records, val_records = train_test_split(top_records, test_size=0.2, random_state=42)
        
    # Create datasets
    train_dataset = ECGSpectrogramDataset(data_dir, train_records)
    val_dataset = ECGSpectrogramDataset(data_dir, val_records)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    
    # Initialize model
    num_classes = len(train_dataset.label_encoder.classes_)
    model = SimpleCNN(num_classes)
    
    # Print model summary
    sample_data, _ = train_dataset[0]
    print(f"\nInput shape: {sample_data.shape}")
    
    # Train model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_model(model, train_loader, val_loader, num_epochs=10, device=device)

Loading data: 100%|██████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.67it/s]
Loading data: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.31it/s]


Training samples: 5072
Validation samples: 1268

Input shape: torch.Size([1, 513, 1])


Epoch 1/10:   1%|▉                                                                     | 2/159 [00:00<00:11, 14.21it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.67it/s]



Epoch 1:
Train Loss: 0.3674, Train Acc: 90.20%
Val Loss: 1.8043, Val Acc: 81.39%


Epoch 2/10:   1%|▉                                                                     | 2/159 [00:00<00:08, 17.62it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.36it/s]



Epoch 2:
Train Loss: 0.2421, Train Acc: 92.59%
Val Loss: 1.6483, Val Acc: 81.55%


Epoch 3/10:   1%|▉                                                                     | 2/159 [00:00<00:09, 16.33it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.45it/s]



Epoch 3:
Train Loss: 0.2208, Train Acc: 93.77%
Val Loss: 1.6065, Val Acc: 81.47%


Epoch 4/10:   1%|▉                                                                     | 2/159 [00:00<00:10, 15.49it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 17.16it/s]



Epoch 4:
Train Loss: 0.1994, Train Acc: 93.95%
Val Loss: 1.8520, Val Acc: 81.47%


Epoch 5/10:   1%|▉                                                                     | 2/159 [00:00<00:09, 16.46it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.65it/s]



Epoch 5:
Train Loss: 0.1997, Train Acc: 94.34%
Val Loss: 1.4827, Val Acc: 79.97%


Epoch 6/10:   1%|▉                                                                     | 2/159 [00:00<00:08, 17.64it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.76it/s]



Epoch 6:
Train Loss: 0.1892, Train Acc: 94.72%
Val Loss: 1.7665, Val Acc: 80.76%


Epoch 7/10:   1%|▉                                                                     | 2/159 [00:00<00:10, 15.57it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.71it/s]



Epoch 7:
Train Loss: 0.1821, Train Acc: 94.72%
Val Loss: 1.2518, Val Acc: 71.69%


Epoch 8/10:   1%|▉                                                                     | 2/159 [00:00<00:08, 18.40it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.50it/s]



Epoch 8:
Train Loss: 0.1659, Train Acc: 94.85%
Val Loss: 1.4601, Val Acc: 78.63%


Epoch 9/10:   1%|▉                                                                     | 2/159 [00:00<00:09, 16.55it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 17.02it/s]



Epoch 9:
Train Loss: 0.1680, Train Acc: 95.11%
Val Loss: 1.8065, Val Acc: 80.28%


Epoch 10/10:   1%|▊                                                                    | 2/159 [00:00<00:09, 16.82it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 9])


Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████| 159/159 [00:09<00:00, 16.74it/s]



Epoch 10:
Train Loss: 0.1569, Train Acc: 95.70%
Val Loss: 1.7002, Val Acc: 74.05%


In [9]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

def analyze_model_predictions(model, data_loader, device, label_encoder):
    """
    Analyze model predictions and generate detailed diagnostics.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    incorrect_samples = []
    
    with torch.no_grad():
        for batch_idx, (spectrograms, labels) in enumerate(tqdm(data_loader, desc="Analyzing predictions")):
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            probabilities = torch.softmax(outputs, dim=1)
            
            preds = outputs.argmax(dim=1)
            
            # Store predictions and labels
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())
            
            # Store incorrect predictions for analysis
            incorrect_mask = preds != labels
            if incorrect_mask.any():
                incorrect_indices = torch.where(incorrect_mask)[0]
                for idx in incorrect_indices:
                    incorrect_samples.append({
                        'true_label': labels[idx].item(),
                        'predicted_label': preds[idx].item(),
                        'confidence': probabilities[idx][preds[idx]].item(),
                        'batch_idx': batch_idx,
                        'sample_idx': idx.item()
                    })
    
    return analyze_results(all_preds, all_labels, all_probs, incorrect_samples, label_encoder)

def analyze_results(all_preds, all_labels, all_probs, incorrect_samples, label_encoder):
    """
    Generate comprehensive analysis of model predictions.
    """
    # Convert numerical labels to original classes
    class_names = label_encoder.classes_
    true_classes = [class_names[label] for label in all_labels]
    pred_classes = [class_names[pred] for pred in all_preds]
    
    # 1. Generate confusion matrix
    cm = confusion_matrix(true_classes, pred_classes)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    # 2. Generate classification report
    report = classification_report(true_classes, pred_classes, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    
    # 3. Analyze prediction confidence
    confidence_analysis = {
        'correct_conf': [],
        'incorrect_conf': []
    }
    
    for i in range(len(all_preds)):
        conf = all_probs[i][all_preds[i]]
        if all_preds[i] == all_labels[i]:
            confidence_analysis['correct_conf'].append(conf)
        else:
            confidence_analysis['incorrect_conf'].append(conf)
    
    # Plot confidence distributions
    plt.figure(figsize=(10, 6))
    plt.hist(confidence_analysis['correct_conf'], alpha=0.5, label='Correct', bins=50)
    plt.hist(confidence_analysis['incorrect_conf'], alpha=0.5, label='Incorrect', bins=50)
    plt.title('Prediction Confidence Distribution')
    plt.xlabel('Confidence')
    plt.ylabel('Count')
    plt.legend()
    plt.tight_layout()
    plt.savefig('confidence_distribution.png')
    plt.close()
    
    # 4. Analyze most common confusion pairs
    confusion_pairs = []
    for sample in incorrect_samples:
        true_class = class_names[sample['true_label']]
        pred_class = class_names[sample['predicted_label']]
        confusion_pairs.append((true_class, pred_class))
    
    confusion_counts = pd.DataFrame(
        pd.Series(confusion_pairs).value_counts(),
        columns=['count']
    ).reset_index()
    confusion_counts.columns = ['True Class', 'Predicted Class', 'Count']
    
    # 5. Generate summary statistics
    summary_stats = {
        'Overall Accuracy': report['accuracy'],
        'Most Confused Pairs': confusion_counts.head(5).to_dict('records'),
        'Per Class Performance': {
            class_name: {
                'precision': report[class_name]['precision'],
                'recall': report[class_name]['recall'],
                'f1-score': report[class_name]['f1-score'],
                'support': report[class_name]['support']
            }
            for class_name in class_names
        },
        'Average Confidence': {
            'Correct Predictions': np.mean(confidence_analysis['correct_conf']),
            'Incorrect Predictions': np.mean(confidence_analysis['incorrect_conf'])
        }
    }
    
    return {
        'confusion_matrix': cm,
        'classification_report': report_df,
        'confidence_analysis': confidence_analysis,
        'confusion_pairs': confusion_counts,
        'summary_stats': summary_stats,
        'incorrect_samples': incorrect_samples
    }

def print_analysis_results(results):
    """
    Print formatted analysis results.
    """
    print("\n=== ECG Classification Model Analysis ===\n")
    
    # Overall Performance
    print("Overall Performance:")
    print(f"Accuracy: {results['summary_stats']['Overall Accuracy']:.4f}")
    print("\n=== Per-Class Performance ===")
    print(results['classification_report'])
    
    # Most Common Confusion Pairs
    print("\nTop 5 Most Common Confusion Pairs:")
    for pair in results['summary_stats']['Most Confused Pairs']:
        print(f"True: {pair['True Class']} → Predicted: {pair['Predicted Class']} "
              f"(Count: {pair['Count']})")
    
    # Confidence Analysis
    print("\nConfidence Analysis:")
    print(f"Average confidence for correct predictions: "
          f"{results['summary_stats']['Average Confidence']['Correct Predictions']:.4f}")
    print(f"Average confidence for incorrect predictions: "
          f"{results['summary_stats']['Average Confidence']['Incorrect Predictions']:.4f}")
    
    # Examples of High-Confidence Mistakes
    print("\nHigh-Confidence Mistakes:")
    high_conf_mistakes = sorted(
        [x for x in results['incorrect_samples'] if x['confidence'] > 0.8],
        key=lambda x: x['confidence'],
        reverse=True
    )[:5]
    
    for mistake in high_conf_mistakes:
        print(f"True: {mistake['true_label']} → Predicted: {mistake['predicted_label']} "
              f"(Confidence: {mistake['confidence']:.4f})")

# Usage example
if __name__ == "__main__":
    # Assuming model, data_loader, and label_encoder are defined
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    results = analyze_model_predictions(model, val_loader, device, train_dataset.label_encoder)
    print_analysis_results(results)

Analyzing predictions: 100%|███████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 37.08it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


ValueError: Length mismatch: Expected axis has 2 elements, new values have 3 elements