<a href="https://colab.research.google.com/github/mewanDimalsha/e19-4yp-LowComplexity-Algorithms-For-EnergyEfficient-Arrhythmia-Classification-In-Wearable-Devices/blob/main/Copy_of_SNN_revised2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
from google.colab import drive
drive.mount('/content/drive')

# Install necessary libraries
!pip install wfdb neurokit2 imblearn snntorch torch numpy scipy matplotlib seaborn --quiet

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
import os
import wfdb

# Define data directory in Google Drive
data_dir = '/content/drive/MyDrive/ecg_snn_project/data/mitdb'
# Trigger the download on a missing record file rather than a missing
# directory. Otherwise an existing-but-empty directory (e.g. left behind by an
# interrupted earlier download) would never be populated. '100.dat' is the
# same presence check used before the pipeline runs.
if not os.path.exists(os.path.join(data_dir, '100.dat')):
    os.makedirs(data_dir, exist_ok=True)
    wfdb.dl_database('mitdb', data_dir)  # Downloads MIT-BIH dataset

In [12]:
# Function to load an ECG record
def load_ecg(record_id, data_dir, channel=0):
    """Read one WFDB record and its 'atr' annotations.

    Args:
        record_id: record name (e.g. '100'). Files are looked up as
            ``data_dir/record_id.*``.
        data_dir: directory holding the WFDB files.
        channel: column of ``record.p_signal`` to return (default 0, the
            first stored channel).

    Returns:
        (ecg, annotation_samples, fs, annotation): the selected channel of
        ``record.p_signal``, ``annotation.sample``, ``record.fs`` and the
        full annotation object from ``wfdb.rdann(..., 'atr')``.
    """
    record_path = os.path.join(data_dir, str(record_id))
    record = wfdb.rdrecord(record_path)
    annotation = wfdb.rdann(record_path, 'atr')
    # Named `ecg` so it does not shadow the scipy.signal module name.
    ecg = record.p_signal[:, channel]
    return ecg, annotation.sample, record.fs, annotation

In [13]:
import numpy as np
from scipy import signal
from imblearn.over_sampling import SMOTE

# Denoising
def bandpass_filter(signal_arr, fs, lowcut=0.5, highcut=40, order=2):
    """Zero-phase Butterworth band-pass filter.

    Args:
        signal_arr: 1-D signal.
        fs: sampling frequency in Hz.
        lowcut, highcut: pass-band edges in Hz. Both must lie strictly
            between 0 and fs / 2, otherwise ``signal.butter`` raises.
        order: Butterworth order passed to ``signal.butter`` (default 2).

    Returns:
        The filtered signal, same length as ``signal_arr``.

    NOTE(review): the (b, a) transfer-function form can become numerically
    ill-conditioned for high orders with very low normalised cut-offs. If
    ``order`` is raised substantially, consider the SOS form.
    """
    nyquist = fs / 2
    band = [lowcut / nyquist, highcut / nyquist]
    b, a = signal.butter(order, band, btype='band')
    # filtfilt runs the filter forward and backward, so the output has no
    # phase shift and annotated R-peak positions stay aligned with it.
    return signal.filtfilt(b, a, signal_arr)

def notch_filter(signal_arr, fs, freq=50, Q=30):
    """Suppress a narrow band around ``freq`` Hz with a zero-phase IIR notch.

    Args:
        signal_arr: 1-D signal.
        fs: sampling frequency in Hz.
        freq: notch centre frequency in Hz.
        Q: quality factor of ``signal.iirnotch``; larger values give a
            narrower notch.

    NOTE(review): 50 Hz is the European mains frequency. MIT-BIH is generally
    described as US-recorded (60 Hz mains). The 0.5-40 Hz band-pass applied
    before this step already partly attenuates both; confirm which
    interference frequency this step is meant to target.
    """
    numerator, denominator = signal.iirnotch(w0=freq, Q=Q, fs=fs)
    # Forward-backward filtering: no phase shift is introduced.
    return signal.filtfilt(numerator, denominator, signal_arr)

def remove_baseline(signal_arr, fs, window_size=0.2):
    """Subtract a Savitzky-Golay trend estimate from the signal.

    Args:
        signal_arr: 1-D signal.
        fs: sampling frequency in Hz.
        window_size: trend window in seconds, converted to
            ``int(window_size * fs)`` samples.

    Returns:
        ``signal_arr`` minus its quadratic (polyorder=2) Savitzky-Golay smooth.

    NOTE(review): e.g. fs=360 Hz gives a 72-sample (even) window, which older
    SciPy releases reject (they required an odd window_length). A 0.2 s window
    may also be short enough for the smooth to follow part of the beat shape
    rather than only baseline wander. Confirm both are intended.
    """
    window_length = int(window_size * fs)
    trend = signal.savgol_filter(signal_arr, window_length=window_length, polyorder=2)
    return signal_arr - trend

In [14]:
#segmentation
def extract_heartbeats(signal_arr, fs, annotation_rpeaks, fixed_length=250):
    """Cut a fixed-length window around each annotated R-peak.

    Each window holds exactly ``fixed_length`` samples, starting
    ``fixed_length // 2`` samples before the peak. Peaks whose window would
    run past either end of ``signal_arr`` are skipped.

    Args:
        signal_arr: 1-D signal to segment.
        fs: sampling frequency. Not used here; kept for interface compatibility.
        annotation_rpeaks: sample indices of the annotated peaks.
        fixed_length: samples per window (odd values are honoured exactly).

    Returns:
        (beats, valid_rpeaks): ``beats`` has shape (n_valid, fixed_length),
        also when no peak qualifies. ``valid_rpeaks`` holds the peak behind
        each row, in input order.
    """
    signal_arr = np.asarray(signal_arr)
    rpeaks = np.asarray(annotation_rpeaks, dtype=np.int64).ravel()
    starts = rpeaks - fixed_length // 2
    # The end is start + fixed_length (not rpeak + half) so an odd
    # fixed_length still yields fixed_length samples instead of one fewer.
    inside = (starts >= 0) & (starts + fixed_length <= len(signal_arr))
    valid_rpeaks = rpeaks[inside]
    # One row of sample indices per kept peak -> (n_valid, fixed_length).
    window_idx = starts[inside][:, np.newaxis] + np.arange(fixed_length)
    beats = signal_arr[window_idx]
    return beats, valid_rpeaks

In [15]:
# Class Balancing
def balance_classes(X, y, random_state=42):
    """Oversample the minority class(es) of (X, y) with SMOTE.

    Args:
        X: samples. Inputs with at most 2 dimensions are handed to SMOTE
            unchanged. Higher-rank arrays, such as (n, L, 1) beats, are
            flattened to (n, features) for SMOTE and reshaped back afterwards.
        y: class label per sample.
        random_state: seed passed to SMOTE (default 42, as before).

    Returns:
        (X_balanced, y_balanced) from ``SMOTE.fit_resample``, with
        ``X_balanced`` restored to X's per-sample shape.
    """
    smote = SMOTE(random_state=random_state)
    if np.ndim(X) <= 2:
        return smote.fit_resample(X, y)
    X = np.asarray(X)
    sample_shape = X.shape[1:]
    # SMOTE only accepts 2-D feature matrices.
    X_flat, y_balanced = smote.fit_resample(X.reshape(X.shape[0], -1), y)
    return X_flat.reshape((-1,) + sample_shape), y_balanced

In [16]:
# Normalization
def normalize_beats(beats):
    """Min-max scale each beat (row along axis 1) into [0, 1].

    A constant beat maps to all zeros. The 1e-8 in the denominator avoids
    division by zero.
    """
    lo = beats.min(axis=1, keepdims=True)
    span = beats.max(axis=1, keepdims=True) - lo
    shifted = beats - lo
    return shifted / (span + 1e-8)

In [17]:
# Label Creation
# Binary label map: class id -> annotation symbols assigned to that class.
# Symbols not listed here get no label (get_class_from_symbol returns None).
# NOTE(review): in the standard WFDB annotation codes 'u' marks a U-wave peak
# (not a beat), 'P' is not a standard code, and 'Q' (unclassifiable beat) is
# absent. Confirm this symbol set is intended.
AAMI_classes = {
    0: ['N', 'L', 'R', 'e', 'j'],  # Normal
    1: [                            # Non-Normal
        'A', 'a', 'J', 'S',
        'V', 'E',
        'F',
        'P', '/', 'f', 'u',
    ],
}

def get_class_from_symbol(symbol):
    """Return the AAMI_classes key whose symbol list contains ``symbol``, else None.

    Keys are tried in dict order, so a symbol listed under several classes
    resolves to the first of them.
    """
    matches = (class_id for class_id, symbols in AAMI_classes.items() if symbol in symbols)
    return next(matches, None)

def create_labels(rpeaks, annotation, symbol_to_class=None):
    """Return the class label of every peak whose annotation maps to a class.

    Args:
        rpeaks: sample indices to label. Each is matched by equality against
            ``annotation.sample``.
        annotation: object with parallel ``sample`` and ``symbol`` sequences
            (e.g. the annotation returned by load_ecg).
        symbol_to_class: optional callable mapping a symbol to a class id or
            None. Defaults to get_class_from_symbol.

    Returns:
        1-D array of labels in rpeak order. Peaks with no matching sample,
        or whose symbol maps to None, are skipped. The result can therefore
        be shorter than ``rpeaks``, and is then not index-aligned with it.
    """
    classify = symbol_to_class if symbol_to_class is not None else get_class_from_symbol

    # Build sample -> symbol once, with the first annotation at a sample
    # winning (as np.where(...)[0][0] did). This replaces a full scan of
    # annotation.sample for every peak.
    first_symbol = {}
    for sample, symbol in zip(annotation.sample, annotation.symbol):
        first_symbol.setdefault(sample, symbol)

    labels = []
    for rpeak in rpeaks:
        if rpeak not in first_symbol:
            continue
        label = classify(first_symbol[rpeak])
        if label is not None:
            labels.append(label)
    return np.array(labels)

In [19]:
# SNN definition
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
class SNN(nn.Module):
    """Convolutional spiking network for single-channel beat windows.

    Per time step: conv1 -> pool1 -> lif1 -> conv2 -> pool2 -> lif2 ->
    flatten -> fc1 -> lif3 -> fc2 -> lif4. The same input is presented at
    each of ``num_steps`` steps. The forward pass returns lif4's output
    spikes summed over all steps, and these counts are used as class scores.

    Args:
        num_inputs: samples per input window. The two MaxPool1d(2) stages
            halve the length twice, hence ``num_inputs // 4`` in fc1.
        num_hidden: width of the hidden fully connected spiking layer.
        num_outputs: number of classes.
        num_steps: number of simulation time steps.
        beta: membrane decay factor shared by all snn.Leaky layers.
    """

    def __init__(self, num_inputs=250, num_hidden=128, num_outputs=2, num_steps=25, beta=0.9):
        super().__init__()
        self.num_steps = num_steps
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=1, padding=2)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.pool1 = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, stride=1, padding=2)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.pool2 = nn.MaxPool1d(kernel_size=2)

        self.fc1 = nn.Linear(32 * (num_inputs // 4), num_hidden)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Simulate ``num_steps`` steps on a static input.

        Args:
            x: input batch. train_model and evaluate_model pass
               (batch, 1, num_inputs) tensors, with the channel axis before
               length as Conv1d expects.

        Returns:
            (batch, num_outputs) tensor: lif4 output spikes summed over steps.
        """
        batch_size = x.size(0)
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()

        # x is identical at every step and conv1/pool1 are stateless, so
        # compute them once instead of num_steps times.
        cur1 = self.pool1(self.conv1(x))

        spk_rec = []
        for _ in range(self.num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(self.pool2(self.conv2(spk1)), mem2)
            spk3, mem3 = self.lif3(self.fc1(spk2.view(batch_size, -1)), mem3)
            spk4, mem4 = self.lif4(self.fc2(spk3), mem4)
            spk_rec.append(spk4)

        return torch.stack(spk_rec, dim=0).sum(dim=0)

In [20]:
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix # Import the required metrics
import os
import seaborn as sns

# Training
def _evaluate_loader(model, loader, criterion):
    """Return (mean loss, accuracy) of ``model`` over every batch of ``loader``."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            # criterion returns the batch mean. Weight it by batch size so the
            # overall mean is exact when the last batch is smaller.
            running_loss += criterion(outputs, labels).item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / total, correct / total


def _plot_training_curves(history, path='training_metrics_partial.png'):
    """Save the train-only loss/accuracy curves in ``history`` to ``path``."""
    fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(12, 8))
    ax_loss.plot(history['train_loss'], label='Train Loss', color='#1f77b4')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training Loss')
    ax_loss.legend()
    ax_acc.plot(history['train_acc'], label='Train Accuracy', color='#1f77b4')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Training Accuracy')
    ax_acc.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train_model(X_train, y_train, X_val, y_val, X_test, y_test, batch_size=64, num_epochs=10, device='cuda'):
    """Train an SNN, then replay each epoch's weights on the validation and test sets.

    Each epoch's state_dict is saved into a private temporary directory. It is
    removed on exit, also when an exception is raised. The checkpoints are
    reloaded afterwards to record per-epoch validation and test loss/accuracy.
    No checkpoint is selected: the returned model holds the final epoch's
    weights.

    Args:
        X_train, X_val, X_test: beats shaped (N, L, 1), permuted to (N, 1, L).
        y_train, y_val, y_test: integer class labels.
        batch_size: mini-batch size for all three DataLoaders.
        num_epochs: number of training epochs.
        device: torch device. All tensors are moved there up front.

    Returns:
        (model, history). history maps 'train_loss', 'train_acc', 'val_loss',
        'val_acc', 'test_loss' and 'test_acc' to per-epoch lists.

    Side effects:
        Writes 'training_metrics_partial.png' and calls plot_metrics(history).
    """
    import tempfile  # only needed for the scratch checkpoint directory

    model = SNN(num_inputs=X_train.shape[1]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def make_loader(X, y, shuffle):
        # (N, L, 1) -> (N, 1, L): Conv1d expects the channel axis before length.
        X_tensor = torch.FloatTensor(X).permute(0, 2, 1).to(device)
        y_tensor = torch.LongTensor(y).to(device)
        return DataLoader(TensorDataset(X_tensor, y_tensor), batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(X_train, y_train, shuffle=True)
    val_loader = make_loader(X_val, y_val, shuffle=False)
    test_loader = make_loader(X_test, y_test, shuffle=False)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'test_loss': [], 'test_acc': []
    }

    # A fresh temporary directory cannot collide with stale files from an
    # earlier run, and is cleaned up even if training raises.
    with tempfile.TemporaryDirectory(prefix='snn_checkpoints_') as checkpoint_dir:
        checkpoint_paths = []

        print("Training phase...")
        for epoch in range(num_epochs):
            model.train()
            running_train_loss = 0.0
            correct_train = 0
            total_train = 0
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_train_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

            train_loss = running_train_loss / total_train
            train_acc = correct_train / total_train
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)

            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt')
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_paths.append(checkpoint_path)

            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        _plot_training_curves(history)

        print("\nEvaluating saved models on validation set...")
        for epoch, checkpoint_path in enumerate(checkpoint_paths):
            model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
            val_loss, val_acc = _evaluate_loader(model, val_loader, criterion)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            print(f"Validation Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        print("\nEvaluating saved models on test set...")
        for epoch, checkpoint_path in enumerate(checkpoint_paths):
            model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
            test_loss, test_acc = _evaluate_loader(model, test_loader, criterion)
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_acc)
            print(f"Test Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    plot_metrics(history)
    return model, history

def evaluate_model(model, X_val, y_val, X_test, y_test, device='cuda', batch_size=1024):
    """Report loss, accuracy, precision, recall, F1 and the confusion matrix per split.

    Runs on the validation and test sets. Class 1 (Non-Normal) is the positive
    class for precision/recall/F1. For each split the metrics are printed, and
    '<split>_metrics.png' (bar chart) and '<split>_confusion_matrix.png'
    (heatmap) are written.

    Args:
        model: trained network mapping (batch, 1, L) inputs to (batch, 2) scores.
        X_val, X_test: beats shaped (N, L, 1).
        y_val, y_test: integer labels in {0, 1}.
        device: device the model lives on.
        batch_size: inference batch size. Bounds peak memory on large splits
            instead of pushing the whole split through the model at once.
    """
    # Sum per batch, then divide by N, to get the same mean as a single
    # full-set CrossEntropyLoss call.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    def compute_metrics(X, y, dataset_name):
        # (N, L, 1) -> (N, 1, L). Kept on the CPU; moved to `device` per batch.
        X_tensor = torch.FloatTensor(X).permute(0, 2, 1)
        y_tensor = torch.LongTensor(y)
        if len(y_tensor) == 0:
            print(f"\n{dataset_name} Metrics: no samples, skipping.")
            return

        model.eval()
        total_loss = 0.0
        batch_predictions = []
        with torch.no_grad():
            for start in range(0, len(X_tensor), batch_size):
                inputs = X_tensor[start:start + batch_size].to(device)
                targets = y_tensor[start:start + batch_size].to(device)
                outputs = model(inputs)
                total_loss += criterion(outputs, targets).item()
                _, predicted = torch.max(outputs, 1)
                batch_predictions.append(predicted.cpu())

        y_np = y_tensor.numpy()
        predicted_np = torch.cat(batch_predictions).numpy()

        loss = total_loss / len(y_np)
        accuracy = float((predicted_np == y_np).mean())
        # zero_division=0 makes the "no positive predictions/labels" case explicit.
        precision = precision_score(y_np, predicted_np, average='binary', zero_division=0)
        recall = recall_score(y_np, predicted_np, average='binary', zero_division=0)
        f1 = f1_score(y_np, predicted_np, average='binary', zero_division=0)
        # labels=[0, 1] keeps the matrix 2x2 even when y and the predictions
        # contain a single class. Without it, cm[1, 0] / cm[1, 1] would raise.
        cm = confusion_matrix(y_np, predicted_np, labels=[0, 1])

        print(f"\n{dataset_name} Metrics:")
        print(f"  Loss: {loss:.4f}")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")
        print(f"  F1-Score: {f1:.4f}")
        print(f"  Confusion Matrix:")
        print(f"    True Negative (Normal correct): {cm[0,0]}")
        print(f"    False Positive (Normal as Non-Normal): {cm[0,1]}")
        print(f"    False Negative (Non-Normal as Normal): {cm[1,0]}")
        print(f"    True Positive (Non-Normal correct): {cm[1,1]}")

        # Plot precision, recall, F1-score
        metrics = {'Precision': precision, 'Recall': recall, 'F1-Score': f1}
        plt.figure(figsize=(8, 6))
        plt.bar(metrics.keys(), metrics.values(), color=['#1f77b4', '#ff7f0e', '#2ca02c'])
        plt.ylim(0, 1)
        plt.xlabel('Metrics')
        plt.ylabel('Score')
        plt.title(f'{dataset_name} Metrics: Precision, Recall, F1-Score')
        for i, v in enumerate(metrics.values()):
            plt.text(i, v + 0.02, f'{v:.4f}', ha='center')
        plt.savefig(f'{dataset_name.lower()}_metrics.png')
        plt.close()

        # Plot confusion matrix heatmap
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Normal', 'Non-Normal'],
                    yticklabels=['Normal', 'Non-Normal'])
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'{dataset_name} Confusion Matrix')
        plt.savefig(f'{dataset_name.lower()}_confusion_matrix.png')
        plt.close()

    # Compute metrics for validation and test sets
    compute_metrics(X_val, y_val, 'Validation')
    compute_metrics(X_test, y_test, 'Test')

def plot_metrics(history):
    """Save per-epoch train/validation/test loss and accuracy curves to 'metrics.png'.

    Args:
        history: dict with per-epoch lists under 'train_loss', 'val_loss',
            'test_loss', 'train_acc', 'val_acc' and 'test_acc'.
    """
    panels = (
        ('Loss', 'Training, Validation, and Test Loss', (
            ('train_loss', 'Train Loss', '#1f77b4'),
            ('val_loss', 'Validation Loss', '#ff7f0e'),
            ('test_loss', 'Test Loss', '#2ca02c'),
        )),
        ('Accuracy', 'Training, Validation, and Test Accuracy', (
            ('train_acc', 'Train Accuracy', '#1f77b4'),
            ('val_acc', 'Validation Accuracy', '#ff7f0e'),
            ('test_acc', 'Test Accuracy', '#2ca02c'),
        )),
    )
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    for ax, (ylabel, title, curves) in zip(axes, panels):
        for key, label, color in curves:
            ax.plot(history[key], label=label, color=color)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
    fig.tight_layout()
    fig.savefig('metrics.png')
    plt.close(fig)

In [21]:
def process_record(record_id, data_dir):
    """Run the preprocessing chain on one record and return its labelled beats.

    Steps: load -> band-pass -> notch -> baseline removal -> R-peak-centred
    segmentation -> per-beat min-max normalisation -> labelling via
    get_class_from_symbol. Beats whose annotation symbol maps to no class
    are dropped.

    Returns:
        (beats, labels): beats shaped (n, window_length, 1) and labels shaped
        (n,). Two empty arrays are returned when nothing usable remains.
    """
    # Named ecg_signal so it does not shadow the scipy.signal module name.
    ecg_signal, _, fs, ann = load_ecg(record_id, data_dir)
    print(f"Record {record_id}: Total annotations: {len(ann.sample)}")

    filtered = bandpass_filter(ecg_signal, fs)
    filtered = notch_filter(filtered, fs)
    filtered = remove_baseline(filtered, fs)

    beats, valid_rpeaks = extract_heartbeats(filtered, fs, ann.sample)
    print(f"Extracted {len(beats)} valid beats")

    if len(beats) == 0:
        print(f"No valid beats extracted for record {record_id}.")
        return np.array([]), np.array([])

    beats = normalize_beats(beats)

    # One pass over the annotations: sample -> symbol of the first annotation
    # at that sample. This avoids an np.where scan per beat and a second
    # labelling pass.
    symbol_at = {}
    for sample, symbol in zip(ann.sample, ann.symbol):
        symbol_at.setdefault(sample, symbol)

    candidate_labels = [get_class_from_symbol(symbol_at.get(rpeak)) for rpeak in valid_rpeaks]
    keep = np.array([label is not None for label in candidate_labels], dtype=bool)
    labels = np.array([label for label in candidate_labels if label is not None])

    if not keep.all():
        beats = beats[keep]
        print(f"After filtering for labels: {len(beats)} beats, {len(labels)} labels.")

    if len(beats) == 0:
        print(f"No beats with valid labels extracted for record {record_id}.")
        return np.array([]), np.array([])

    # Add a trailing channel axis: (n, window_length) -> (n, window_length, 1).
    beats = beats.reshape(-1, beats.shape[1], 1)
    return beats, labels

def load_dataset(record_ids, data_dir, balance_training=False):
    """Process several records and stack their beats and labels.

    Args:
        record_ids: record numbers/names. Each is passed to process_record as str.
        data_dir: directory holding the WFDB files.
        balance_training: if True and at least two classes are present,
            oversample with balance_classes (SMOTE). If SMOTE raises ValueError,
            the unbalanced data is returned.

    Returns:
        (X, y): X shaped (n, window_length, 1) and y shaped (n,). Two empty
        arrays are returned when no record yields beats.
    """
    all_beats = []
    all_labels = []
    for record_id in record_ids:
        X, y = process_record(str(record_id), data_dir)
        if X.shape[0] > 0:
            all_beats.append(X)
            all_labels.append(y)
        else:
            print(f"Skipping record {record_id} due to processing issues or no valid beats.")

    if not all_beats:
        print(f"No valid data loaded from {len(record_ids)} records.")
        return np.array([]), np.array([])

    X_all = np.concatenate(all_beats, axis=0)
    y_all = np.concatenate(all_labels, axis=0)
    # Report the records that actually contributed beats, not just those requested.
    print(f"Loaded {len(all_beats)} of {len(record_ids)} records: total samples = {X_all.shape[0]}")

    if not balance_training:
        return X_all, y_all

    unique_classes = np.unique(y_all)
    if len(unique_classes) < 2:
        print(f"Only one class ({unique_classes[0]}) in training set, skipping balancing.")
        return X_all, y_all

    try:
        # SMOTE works on 2-D feature matrices, so flatten (n, L, 1) -> (n, L).
        X_balanced, y_balanced = balance_classes(X_all.reshape(X_all.shape[0], -1), y_all)
    except ValueError as e:
        print(f"Could not balance training set due to error: {e}")
        print("Using original unbalanced training data.")
        return X_all, y_all

    print(f"Balanced training set. Original: {len(X_all)}, Balanced: {len(X_balanced)}")
    return X_balanced.reshape(-1, X_all.shape[1], 1), y_balanced


# Define datasets (record-level split: no record number appears in two lists)
# NOTE(review): the de Chazal et al. DS1 list usually paired with this DS2 also
# contains records 223 and 230, which are in none of the lists below. MIT-BIH
# documentation also states that records 201 (train) and 202 (test) come from
# the same subject. Confirm both are acceptable for this split.
DS1_train = [101, 106, 108, 109, 112, 114, 115, 116, 118, 119, 122, 124, 201, 203, 205]
DS1_val = [207, 208, 209, 215, 220]
DS2 = [100, 103, 105, 111, 113, 117, 121, 123, 200, 202, 210, 212, 213, 214, 219, 221, 222, 228, 231, 232, 233, 234]

# Seed torch (all devices) and numpy so weight initialisation and DataLoader
# shuffling repeat between runs. SMOTE is seeded separately (random_state=42).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run the pipeline
data_dir = '/content/drive/MyDrive/ecg_snn_project/data/mitdb'

if not os.path.exists(os.path.join(data_dir, '100.dat')):
    print(f"Data directory or files not found. Please ensure {data_dir} is correct and contains MIT-BIH files.")
else:
    X_train, y_train = load_dataset(DS1_train, data_dir, balance_training=True)
    X_val, y_val = load_dataset(DS1_val, data_dir, balance_training=False)
    X_test, y_test = load_dataset(DS2, data_dir, balance_training=False)

    if X_train.shape[0] > 0 and X_val.shape[0] > 0 and X_test.shape[0] > 0:
        print(f"Training samples: {X_train.shape[0]}, Validation samples: {X_val.shape[0]}, Test samples: {X_test.shape[0]}")
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        model, history = train_model(X_train, y_train, X_val, y_val, X_test, y_test, batch_size=64, num_epochs=10, device=device)
        evaluate_model(model, X_val, y_val, X_test, y_test, device=device)
    else:
        print("Insufficient data loaded for training, validation, or testing. Cannot proceed with model training.")

Record 101: Total annotations: 1874
Extracted 1872 valid beats
After filtering for labels: 1862 beats, 1862 labels.
Record 106: Total annotations: 2098
Extracted 2097 valid beats
After filtering for labels: 2027 beats, 2027 labels.
Record 108: Total annotations: 1824
Extracted 1822 valid beats
After filtering for labels: 1762 beats, 1762 labels.
Record 109: Total annotations: 2535
Extracted 2532 valid beats
After filtering for labels: 2530 beats, 2530 labels.
Record 112: Total annotations: 2550
Extracted 2547 valid beats
After filtering for labels: 2537 beats, 2537 labels.
Record 114: Total annotations: 1890
Extracted 1889 valid beats
After filtering for labels: 1879 beats, 1879 labels.
Record 115: Total annotations: 1962
Extracted 1960 valid beats
After filtering for labels: 1952 beats, 1952 labels.
Record 116: Total annotations: 2421
Extracted 2420 valid beats
After filtering for labels: 2411 beats, 2411 labels.
Record 118: Total annotations: 2301
Extracted 2299 valid beats
After fil