In [1]:
import os
import wfdb
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter

def get_record_list(data_dir):
    """
    Read the RECORDS file to get the list of record numbers.
    
    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files
    
    Returns:
    list: List of record numbers as strings
    """
    records_file = os.path.join(data_dir, 'RECORDS')
    with open(records_file, 'r') as f:
        return [line.strip() for line in f]

def load_mit_bih_records(data_dir):
    """
    Load all MIT-BIH Arrhythmia Database records and annotations from the specified directory.
    
    Parameters:
    data_dir (str): Path to the directory containing the MIT-BIH database files
    
    Returns:
    dict: Dictionary containing record information, signals, and annotations
    """
    # Get record numbers from RECORDS file
    record_numbers = get_record_list(data_dir)
    
    # Dictionary to store all records
    database = {
        'records': {},
        'metadata': {
            'sampling_frequency': None,
            'total_records': 0,
            'signal_length': None
        }
    }
    
    # Load each record
    for record_name in tqdm(record_numbers, desc="Loading ECG records"):
        record_path = os.path.join(data_dir, record_name)
        
        try:
            # Read the record
            record = wfdb.rdrecord(record_path)
            
            # Read the annotations
            try:
                ann = wfdb.rdann(record_path, 'atr')
                annotations = {
                    'sample': ann.sample,
                    'symbol': ann.symbol,
                    'subtype': ann.subtype,
                    'chan': ann.chan,
                    'num': ann.num,
                    'aux_note': ann.aux_note,
                    'fs': ann.fs
                }
                
                # Calculate annotation statistics
                symbol_counts = Counter(ann.symbol)
                
            except Exception as e:
                print(f"Error loading annotations for record {record_name}: {str(e)}")
                annotations = None
                symbol_counts = None
            
            # Store record information
            database['records'][record_name] = {
                'signals': record.p_signal,
                'channels': record.sig_name,
                'units': record.units,
                'fs': record.fs,
                'baseline': record.baseline,
                'comments': record.comments,
                'annotations': annotations,
                'annotation_counts': symbol_counts
            }
            
            # Update metadata
            if database['metadata']['sampling_frequency'] is None:
                database['metadata']['sampling_frequency'] = record.fs
            if database['metadata']['signal_length'] is None:
                database['metadata']['signal_length'] = len(record.p_signal)
            database['metadata']['total_records'] += 1
            
        except Exception as e:
            print(f"Error loading record {record_name}: {str(e)}")
            continue
    
    return database

def get_record_summary(database):
    """
    Generate a summary of the loaded records including annotation statistics.
    
    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records
    
    Returns:
    pd.DataFrame: Summary statistics for each record
    """
    summaries = []
    
    for record_name, record_data in database['records'].items():
        signals = record_data['signals']
        annotations = record_data['annotations']
        
        summary = {
            'record_name': record_name,
            'duration_seconds': len(signals) / record_data['fs'],
            'num_channels': signals.shape[1],
            'mean_ch1': np.mean(signals[:, 0]),
            'std_ch1': np.std(signals[:, 0]),
            'mean_ch2': np.mean(signals[:, 1]),
            'std_ch2': np.std(signals[:, 1]),
            'fs': record_data['fs'],
            'total_annotations': len(annotations['sample']) if annotations else 0
        }
        
        # Add annotation type counts if available
        if record_data['annotation_counts']:
            for symbol, count in record_data['annotation_counts'].items():
                summary[f'annotation_{symbol}'] = count
                
        summaries.append(summary)
    
    return pd.DataFrame(summaries)

def get_annotations_as_dataframe(record_data):
    """
    Convert annotations for a single record into a pandas DataFrame.
    
    Parameters:
    record_data (dict): Record data dictionary containing annotations
    
    Returns:
    pd.DataFrame: DataFrame containing all annotations with time information
    """
    if not record_data['annotations']:
        return None
        
    ann = record_data['annotations']
    df = pd.DataFrame({
        'sample': ann['sample'],
        'time': ann['sample'] / ann['fs'],
        'symbol': ann['symbol'],
        'subtype': ann['subtype'],
        'channel': ann['chan'],
        'aux_note': ann['aux_note']
    })
    
    return df

def get_global_annotation_counts(database):
    """
    Calculate total counts for each annotation symbol across all records.
    
    Parameters:
    database (dict): The database dictionary returned by load_mit_bih_records
    
    Returns:
    pd.DataFrame: DataFrame with symbol counts, percentages, and descriptions
    """
    # Initialize a Counter for all symbols
    global_counts = Counter()
    
    # Count symbols across all records
    for record_data in database['records'].values():
        if record_data['annotation_counts']:
            global_counts.update(record_data['annotation_counts'])
    
    # Create DataFrame with counts and percentages
    total_annotations = sum(global_counts.values())
    df = pd.DataFrame([
        {
            'symbol': symbol,
            'count': count,
            'percentage': (count / total_annotations * 100),
            'description': get_symbol_description(symbol)
        }
        for symbol, count in global_counts.most_common()
    ])
    
    # Format percentage column
    df['percentage'] = df['percentage'].round(2)
    
    return df

def get_symbol_description(symbol):
    """
    Get the description for each annotation symbol.
    
    Parameters:
    symbol (str): The annotation symbol
    
    Returns:
    str: Description of the symbol
    """
    descriptions = {
        'N': 'Normal beat',
        'L': 'Left bundle branch block beat',
        'R': 'Right bundle branch block beat',
        'B': 'Bundle branch block beat (unspecified)',
        'A': 'Atrial premature beat',
        'a': 'Aberrated atrial premature beat',
        'J': 'Nodal (junctional) premature beat',
        'S': 'Supraventricular premature or ectopic beat',
        'V': 'Premature ventricular contraction',
        'r': 'R-on-T premature ventricular contraction',
        'F': 'Fusion of ventricular and normal beat',
        'e': 'Atrial escape beat',
        'j': 'Nodal (junctional) escape beat',
        'n': 'Supraventricular escape beat (atrial or nodal)',
        'E': 'Ventricular escape beat',
        '/': 'Paced beat',
        'f': 'Fusion of paced and normal beat',
        'Q': 'Unclassifiable beat',
        '?': 'Beat not classified during learning',
        '[': 'Start of ventricular flutter/fibrillation',
        ']': 'End of ventricular flutter/fibrillation',
        '!': 'Ventricular flutter wave',
        'x': 'Non-conducted P-wave (blocked APC)',
        '(': 'Waveform onset',
        ')': 'Waveform end',
        'p': 'Peak of P-wave',
        't': 'Peak of T-wave',
        'u': 'Peak of U-wave',
        '`': 'PQ junction',
        '\'': 'J-point',
        '^': 'Non-captured pacemaker artifact',
        '|': 'Isolated QRS-like artifact',
        '~': 'Change in signal quality',
        '+': 'Rhythm change',
        's': 'ST segment change',
        'T': 'T-wave change',
        '*': 'Systole',
        'D': 'Diastole',
        '=': 'Measurement annotation',
        '"': 'Comment annotation',
        '@': 'Link to external data'
    }
    return descriptions.get(symbol, 'Unknown annotation type')

if __name__ == "__main__":
    # Set data directory
    data_dir = '../data/mit-bih-arrhythmia-database-1.0.0'
    
    # Load all records
    print("Loading ECG database...")
    database = load_mit_bih_records(data_dir)
    
    # Print basic information
    print(f"\nLoaded {database['metadata']['total_records']} records from RECORDS file")
    
    # Get global annotation statistics
    print("\nGlobal Annotation Statistics:")
    annotation_stats = get_global_annotation_counts(database)
    pd.set_option('display.max_rows', None)  # Show all rows
    print(annotation_stats.to_string(index=False))

Loading ECG database...


Loading ECG records: 100%|█████████████████████████████████████████████████████████████| 48/48 [00:02<00:00, 17.53it/s]


Loaded 48 records from RECORDS file

Global Annotation Statistics:
symbol  count  percentage                                description
     N  75052       66.63                                Normal beat
     L   8075        7.17              Left bundle branch block beat
     R   7259        6.44             Right bundle branch block beat
     V   7130        6.33          Premature ventricular contraction
     /   7028        6.24                                 Paced beat
     A   2546        2.26                      Atrial premature beat
     +   1291        1.15                              Rhythm change
     f    982        0.87            Fusion of paced and normal beat
     F    803        0.71      Fusion of ventricular and normal beat
     ~    616        0.55                   Change in signal quality
     !    472        0.42                   Ventricular flutter wave
     "    437        0.39                         Comment annotation
     j    229        0.20          




In [2]:
from scipy import signal
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

class ECGSpectrogramDataset(Dataset):
    def __init__(self, data_dir, record_numbers, window_size=1024, overlap=512):
        self.data_dir = data_dir
        self.record_numbers = record_numbers
        self.window_size = window_size
        self.overlap = overlap
        self.spectrograms = []
        self.labels = []
        
        self._load_data()
        
    def _load_data(self):
        """Load ECG data and create spectrograms"""
        for record_num in tqdm(self.record_numbers, desc="Loading data"):
            # Load record and annotations
            record_path = os.path.join(self.data_dir, str(record_num))
            record = wfdb.rdrecord(record_path)
            annotations = wfdb.rdann(record_path, 'atr')
            
            # Get signal from first channel
            signal_data = record.p_signal[:, 0]
            
            # Create spectrogram
            frequencies, times, Sxx = signal.spectrogram(
                signal_data,
                fs=record.fs,
                window='hann',
                nperseg=self.window_size,
                noverlap=self.overlap,
                detrend='constant'
            )
            
            # Log scale spectrogram
            Sxx = np.log1p(Sxx)
            
            # Normalize spectrogram
            Sxx = (Sxx - Sxx.mean()) / (Sxx.std() + 1e-8)
            
            # Split spectrogram into segments and match with annotations
            segment_duration = times[1] - times[0]
            num_freq_bins = Sxx.shape[0]
            
            for i in range(len(times)):
                time_point = times[i]
                
                # Find the closest annotation
                ann_idx = np.searchsorted(annotations.sample / record.fs, time_point)
                if ann_idx < len(annotations.symbol):
                    label = annotations.symbol[ann_idx]
                    
                    # Store spectrogram segment and label
                    self.spectrograms.append(Sxx[:, i].reshape(num_freq_bins, 1))
                    self.labels.append(label)
        
        # Convert to numpy arrays
        self.spectrograms = np.array(self.spectrograms)
        
        # Encode labels
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.labels)
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Shape: [num_freq_bins, 1]
        spectrogram = torch.FloatTensor(self.spectrograms[idx])
        label = torch.LongTensor([self.labels[idx]])[0]
        return spectrogram.unsqueeze(0), label  # Add channel dimension [1, num_freq_bins, 1]

class SimpleCNN(nn.Module):
    def __init__(self, num_classes, input_channels=1):
        super(SimpleCNN, self).__init__()
        
        self.features = nn.Sequential(
            # First conv block - careful with padding to handle the 1D-like input
            nn.Conv2d(input_channels, 16, kernel_size=(7, 3), padding=(3, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d((2, 1)),
            
            # Second conv block
            nn.Conv2d(16, 32, kernel_size=(5, 3), padding=(2, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d((2, 1)),
            
            # Third conv block
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d((2, 1)),
        )
        
        # Adaptive pooling to handle variable input sizes
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 1))
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 1, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = self.classifier(x)
        return x

def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda'):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    model = model.to(device)
    best_val_acc = 0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (spectrograms, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()
            
            if batch_idx == 0:  # Print shapes for first batch to verify dimensions
                print(f"Batch shapes - Input: {spectrograms.shape}, Output: {outputs.shape}")
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for spectrograms, labels in val_loader:
                spectrograms, labels = spectrograms.to(device), labels.to(device)
                outputs = model(spectrograms)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        # Print epoch statistics
        print(f'\nEpoch {epoch+1}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}, '
              f'Train Acc: {100.*train_correct/train_total:.2f}%')
        print(f'Val Loss: {val_loss/len(val_loader):.4f}, '
              f'Val Acc: {100.*val_correct/val_total:.2f}%')
        
        # Save best model
        val_acc = 100. * val_correct / val_total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

if __name__ == "__main__":
    # Set data directory
    data_dir = '../data/mit-bih-arrhythmia-database-1.0.0'
    
    # Read record numbers from RECORDS file
    with open(os.path.join(data_dir, 'RECORDS'), 'r') as f:
        record_numbers = [line.strip() for line in f]
    
    # Split records into train and validation sets
    train_records, val_records = train_test_split(record_numbers, test_size=0.2, random_state=42)
    
    # Create datasets
    train_dataset = ECGSpectrogramDataset(data_dir, train_records)
    val_dataset = ECGSpectrogramDataset(data_dir, val_records)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    
    # Initialize model
    num_classes = len(train_dataset.label_encoder.classes_)
    model = SimpleCNN(num_classes)
    
    # Print model summary
    sample_data, _ = train_dataset[0]
    print(f"\nInput shape: {sample_data.shape}")
    
    # Train model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_model(model, train_loader, val_loader, num_epochs=10, device=device)

Loading data: 100%|████████████████████████████████████████████████████████████████████| 38/38 [00:05<00:00,  7.02it/s]
Loading data: 100%|████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.22it/s]


Training samples: 48184
Validation samples: 12680

Input shape: torch.Size([1, 513, 1])


Epoch 1/10:   0%|▏                                                                    | 5/1506 [00:00<01:17, 19.43it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:46<00:00, 32.45it/s]



Epoch 1:
Train Loss: 0.4385, Train Acc: 87.55%
Val Loss: 8.7253, Val Acc: 0.23%


Epoch 2/10:   0%|                                                                     | 2/1506 [00:00<01:23, 18.04it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:48<00:00, 30.78it/s]



Epoch 2:
Train Loss: 0.2660, Train Acc: 92.65%
Val Loss: 8.5691, Val Acc: 0.19%


Epoch 3/10:   0%|▏                                                                    | 3/1506 [00:00<00:54, 27.37it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:47<00:00, 31.53it/s]



Epoch 3:
Train Loss: 0.2363, Train Acc: 93.38%
Val Loss: 10.3035, Val Acc: 0.36%


Epoch 4/10:   0%|▏                                                                    | 3/1506 [00:00<00:51, 28.95it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:45<00:00, 33.22it/s]



Epoch 4:
Train Loss: 0.2175, Train Acc: 93.86%
Val Loss: 11.1067, Val Acc: 0.17%


Epoch 5/10:   0%|▏                                                                    | 4/1506 [00:00<00:44, 33.68it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 5/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:40<00:00, 37.59it/s]



Epoch 5:
Train Loss: 0.2065, Train Acc: 94.22%
Val Loss: 9.8123, Val Acc: 0.21%


Epoch 6/10:   0%|▏                                                                    | 4/1506 [00:00<00:42, 35.21it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 6/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:41<00:00, 36.17it/s]



Epoch 6:
Train Loss: 0.1978, Train Acc: 94.42%
Val Loss: 10.8939, Val Acc: 0.14%


Epoch 7/10:   0%|▏                                                                    | 4/1506 [00:00<00:42, 35.15it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 7/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:41<00:00, 35.90it/s]



Epoch 7:
Train Loss: 0.1858, Train Acc: 94.66%
Val Loss: 9.3641, Val Acc: 0.32%


Epoch 8/10:   0%|▏                                                                    | 4/1506 [00:00<00:42, 35.27it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 8/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:40<00:00, 37.30it/s]



Epoch 8:
Train Loss: 0.1805, Train Acc: 94.76%
Val Loss: 11.1980, Val Acc: 0.24%


Epoch 9/10:   0%|▏                                                                    | 4/1506 [00:00<00:41, 36.50it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 9/10: 100%|██████████████████████████████████████████████████████████████████| 1506/1506 [00:40<00:00, 37.29it/s]



Epoch 9:
Train Loss: 0.1759, Train Acc: 95.06%
Val Loss: 10.2643, Val Acc: 0.29%


Epoch 10/10:   0%|▏                                                                   | 4/1506 [00:00<00:42, 34.99it/s]

Batch shapes - Input: torch.Size([32, 1, 513, 1]), Output: torch.Size([32, 22])


Epoch 10/10: 100%|█████████████████████████████████████████████████████████████████| 1506/1506 [00:40<00:00, 36.90it/s]



Epoch 10:
Train Loss: 0.1711, Train Acc: 95.12%
Val Loss: 10.5236, Val Acc: 0.22%
