In [8]:
import ast
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wfdb
import spikingjelly.clock_driven as sj
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from spikingjelly.clock_driven import functional as sf
from torch.utils.data import Dataset, DataLoader

In [9]:
class PTBXL(Dataset):
    """PTB-XL ECG dataset yielding ``(signal, multi-hot diagnostic-superclass label)`` pairs.

    Rows are restricted to records with at least one diagnostic superclass. Label
    combinations that occur only once are dropped, because a stratified split needs at
    least two members per stratum. An 80/20 stratified split (``random_state=42``) on the
    label combination is then taken.

    Args:
        csv_path: path to ``ptbxl_database.csv``.
        records_path: path to the ``records100`` directory. ``scp_statements.csv`` is
            located by replacing ``'records100'`` in this path.
        sampling_rate: stored on the instance. NOTE(review): not used when loading —
            the ``filename_lr`` column is always read.
        split: ``'train'`` selects the 80% part; any other value selects the 20% part.

    ``__getitem__`` returns ``None`` for records that are skipped or cannot be read;
    pair this dataset with a collate function that drops ``None`` samples.
    """

    def __init__(self, csv_path, records_path, sampling_rate=100, split='train'):
        self.data = pd.read_csv(csv_path)
        self.records_path = records_path
        self.sampling_rate = sampling_rate

        # scp_codes holds a stringified dict literal. literal_eval parses it without
        # executing arbitrary code, unlike eval on file contents.
        self.data['scp_codes'] = self.data['scp_codes'].apply(ast.literal_eval)
        agg_df = pd.read_csv(records_path.replace('records100', 'scp_statements.csv'), index_col=0)
        agg_df = agg_df[agg_df.diagnostic == 1]

        def aggregate(y_dic):
            # Map each diagnostic SCP code present in the record to its superclass (deduplicated).
            return list(set(agg_df.loc[k].diagnostic_class for k in y_dic if k in agg_df.index))

        self.data['diagnostic_superclass'] = self.data['scp_codes'].apply(aggregate)

        # Keep only samples with at least one diagnostic label.
        self.data = self.data[self.data['diagnostic_superclass'].map(len) > 0]

        # Lists are unhashable, so value_counts / isin / stratify would fail on them.
        # Use a canonical string key for the label combination instead.
        label_key = self.data['diagnostic_superclass'].map(self._label_key)
        class_counts = label_key.value_counts()
        keep = label_key.isin(class_counts[class_counts > 1].index)
        self.data = self.data[keep]
        label_key = label_key[keep]

        # Fit the binarizer on all filtered rows BEFORE splitting, so that train and test
        # share the same label columns in the same order.
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit(self.data['diagnostic_superclass'])

        train, test = train_test_split(self.data, test_size=0.2, random_state=42, stratify=label_key)
        self.data = train if split == 'train' else test

    @staticmethod
    def _label_key(labels):
        """Order-independent, hashable key for a list of superclass labels."""
        return '|'.join(sorted(labels))

    @staticmethod
    def _resolve_record_path(records_path, filename):
        """Join ``records_path`` and ``filename`` into a wfdb record path (no extension)."""
        record_path = f"{records_path}/{filename}"
        # filename_lr already starts with "records100/", so collapse the doubled directory.
        record_path = record_path.replace('records100/records100', 'records100')
        head, sep, base = record_path.rpartition('/')
        # Only rewrite the record name itself; directory names may legitimately contain "hr".
        base = base.replace('hr', 'lr').replace('.hea', '')
        return head + sep + base

    @staticmethod
    def _normalize(signal):
        """Z-score ``signal`` using its global mean/std; a constant signal maps to zeros."""
        std = signal.std()
        # Guard against division by zero (flat-line recordings would otherwise become NaN).
        return (signal - signal.mean()) / (std if std > 0 else 1.0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        record_path = self._resolve_record_path(self.records_path, row.filename_lr)
        # NOTE(review): records whose path contains "13000" are skipped — presumably that
        # folder is missing from the local copy of the dataset; confirm.
        if '13000' in record_path:
            return None

        try:
            signal, _ = wfdb.rdsamp(record_path)
        except Exception:
            # Missing/unreadable record: return None so the collate function drops it.
            return None

        # wfdb returns (n_samples, n_channels); transpose to (n_channels, n_samples).
        signal = signal.T
        signal = self._normalize(signal)

        # Add a singleton height axis for Conv2d: (n_channels, 1, n_samples).
        signal = np.expand_dims(signal, axis=1)

        label = self.mlb.transform([row.diagnostic_superclass])[0]

        return torch.tensor(signal, dtype=torch.float32), torch.tensor(label, dtype=torch.long)



In [10]:
def collate_fn(batch):
    """Collate ``(signal, label)`` samples, dropping ``None`` entries from unreadable records.

    Returns:
        ``(data, targets)`` stacked along a new leading batch dimension, or ``(None, None)``
        when every sample in the batch was ``None`` (training/eval loops skip such batches).

    NOTE: a later cell redefines ``collate_fn``. A DataLoader uses whichever function was
    bound to this name when that DataLoader was constructed.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None, None

    # zip(*batch) always yields equally long tuples, so no size-mismatch check is needed.
    data, targets = zip(*batch)
    return torch.stack(data), torch.stack(targets)


In [11]:
import spikingjelly.clock_driven.layer as sj_layer
import torch.nn.functional as F
class SpikingCAM(nn.Module):
    """Shape-preserving convolutional block: Conv3x3 (padding 1) -> ReLU -> Conv1x1.

    Every convolution maps ``in_channels`` -> ``in_channels``, so the output has the same
    shape as the input.

    NOTE(review): despite the name, the output is returned directly (it does not gate
    ``x``), and no spiking neurons are involved.
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1),
        ]
        # Kept as a single Sequential named "attention" so state_dict keys stay stable.
        self.attention = nn.Sequential(*layers)

    def forward(self, x):
        features = self.attention(x)
        return features

class SpikingTwoStageSNN(nn.Module):
    """Two independent convolutional branches applied to the same input.

    Each branch runs: ConvBatchNorm2d(k=5, pad=2) -> ReLU -> SpikingCAM ->
    ConvBatchNorm2d(k=5, pad=2) -> ReLU -> mean over dims 2 and 3 -> Linear.
    Branch 1 (``fc1``) emits ``num_classes1`` logits (normal vs abnormal).
    Branch 2 (``fc2``) emits ``num_classes2`` logits (abnormal sub-class).

    NOTE(review): the branches use ReLU rather than spiking neurons, and branch 2 runs on
    every sample regardless of the branch-1 prediction.
    """

    def __init__(self, input_channels=12, num_classes1=2, num_classes2=4):
        super().__init__()
        # Attribute names and creation order determine state_dict keys and the
        # parameter-initialization order, so both are kept as-is.

        # Stage 1: normal vs abnormal
        self.conv1_stage1 = sj_layer.ConvBatchNorm2d(input_channels, 32, kernel_size=5, padding=2)
        self.cam1 = SpikingCAM(32)
        self.conv2_stage1 = sj_layer.ConvBatchNorm2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(64, num_classes1)

        # Stage 2: detailed abnormal classification
        self.conv1_stage2 = sj_layer.ConvBatchNorm2d(input_channels, 32, kernel_size=5, padding=2)
        self.cam2 = SpikingCAM(32)
        self.conv2_stage2 = sj_layer.ConvBatchNorm2d(32, 64, kernel_size=5, padding=2)
        self.fc2 = nn.Linear(64, num_classes2)

    def _run_stage(self, x, first_conv, cam, second_conv, head):
        """Push ``x`` through one branch and return that branch's logits."""
        feats = F.relu(first_conv(x))
        feats = cam(feats)
        feats = F.relu(second_conv(feats))
        pooled = feats.mean(dim=[2, 3])  # global average over dims 2 and 3 -> (batch, 64)
        return head(pooled)

    def forward(self, x):
        # assumes x is (batch, input_channels, 1, length), as built by the PTBXL dataset — TODO confirm
        out1 = self._run_stage(x, self.conv1_stage1, self.cam1, self.conv2_stage1, self.fc1)
        out2 = self._run_stage(x, self.conv1_stage2, self.cam2, self.conv2_stage2, self.fc2)
        return out1, out2



In [12]:
# Example helper: single-label accuracy from logits.
def calculate_accuracy(predictions, targets):
    """Return the fraction of rows whose arg-max over dim 1 equals ``targets``.

    Args:
        predictions: (N, C) score tensor.
        targets: (N,) tensor of class indices. NOTE(review): not the multi-hot labels
            produced by the PTBXL dataset.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    predicted_idx = torch.max(predictions, 1).indices
    n_correct = predicted_idx.eq(targets).sum().item()
    return n_correct / targets.size(0)

In [20]:
# Custom collate_fn that drops None samples (unreadable/skipped records).
# NOTE(review): this redefines the collate_fn from an earlier cell; keep the two in sync or delete one.
def collate_fn(batch):
    """Stack ``(signal, label)`` pairs after removing ``None`` samples.

    Returns ``(None, None)`` when the whole batch was ``None``. Without that, ``zip(*[])``
    cannot be unpacked, and the ``data is None`` guards in the training/eval loops would be
    unreachable.
    """
    kept = [x for x in batch if x is not None]
    if not kept:
        return None, None
    data, labels = zip(*kept)
    return torch.stack(data), torch.stack(labels)

# Preview one batch. NOTE: requires `train_dataset`, which is created in the dataset cell further down.
# A separate loader name is used so the `train_loader` used for training is not overwritten.
preview_loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn)

for data, labels in preview_loader:
    if data is None:
        continue  # every record in this batch was dropped
    print(data.shape, labels.shape)
    # expect: (<=16, 12, 1, 1000), (<=16, <#classes>) — dropped records shrink the batch
    break

torch.Size([15, 12, 1, 1000]) torch.Size([15, 5])


In [21]:
def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, patience, device):
    best_accuracy = 0.0
    early_stopping_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_stage1 = 0
        total = 0

        for data, targets in train_loader:
            # drop empty / bad batches
            if data is None or targets is None:
                continue

            # move to device
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            # --- forward ---
            out1, out2 = model(data)
            # out1: (batch, 2)  normal vs abnormal
            # out2: (batch, 4)  detailed abnormal

            # --- compute stage1 targets + loss ---
            # Binary target for normal vs abnormal: if sum of multi-hot labels > 0, it's abnormal
            y1 = (targets.sum(dim=1) > 0).long()  # 0 = normal, 1 = abnormal
            loss1 = criterion(out1, y1)

            # track stage1 accuracy
            preds1 = out1.argmax(dim=1)
            correct_stage1 += (preds1 == y1).sum().item()
            total += y1.size(0)

            # --- compute stage2 targets + loss (only on abnormal samples) ---
            # Stage 2: Detailed classification for abnormal cases (only those that are abnormal)
            idx_abn = (y1 == 1).nonzero(as_tuple=True)[0]  # indices of abnormal samples
            if idx_abn.numel() > 0:
                # select only abnormal samples
                targ2_multi = targets[idx_abn]  # shape (n_abn, num_classes)
                # map multi-hot to a single class 0–3: argmax over the abnormal subset
                # (assuming positions 1–4 in `targets` correspond to the 4 abnormal classes)
                y2 = targ2_multi[:, 1:].argmax(dim=1).long()  # Select class 1–4 for abnormal
                loss2 = criterion(out2[idx_abn], y2)
            else:
                loss2 = torch.tensor(0.0, device=device)

            # --- total loss & backward ---
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # end epoch training
        avg_loss = running_loss / len(train_loader)
        stage1_acc = 100.0 * correct_stage1 / total if total > 0 else 0.0
        print(f"Epoch {epoch+1}/{num_epochs} — Loss: {avg_loss:.4f}, Stage1 Acc: {stage1_acc:.2f}%")

        # --- validation & early stopping based on stage1 acc ---
        val_acc = evaluate_stage1(model, val_loader, device)  # Evaluate stage 1
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print("Early stopping triggered.")
            break

    return model

In [16]:
def evaluate_stage1(model, loader, device, normal_class_idx=None):
    """Compute stage-1 (normal vs abnormal) accuracy, in percent, over ``loader``.

    A sample is abnormal (1) if any non-'NORM' superclass is set, else normal (0).
    The previous rule ("any label set") was always 1, because every PTB-XL sample carries
    at least one superclass.

    Args:
        model: module returning ``(out1, out2)``; only ``out1`` is used.
        loader: yields ``(data, multi-hot targets)``; ``(None, None)`` batches are skipped.
        device: device batches are moved to.
        normal_class_idx: column index of 'NORM'. When ``None`` it is inferred from
            ``loader.dataset.mlb.classes_``.

    Returns:
        Accuracy in [0, 100]; 0.0 if no sample was evaluated.

    Raises:
        ValueError: if ``normal_class_idx`` is None and cannot be inferred.
    """
    if normal_class_idx is None:
        try:
            normal_class_idx = list(loader.dataset.mlb.classes_).index('NORM')
        except (AttributeError, ValueError) as exc:
            raise ValueError(
                "Cannot infer the 'NORM' column from loader.dataset.mlb; pass normal_class_idx explicitly."
            ) from exc

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, targets in loader:
            # skip any bad batches
            if data is None or targets is None:
                continue

            data, targets = data.to(device), targets.to(device)
            out1, _ = model(data)  # ignore stage-2
            # Number of positive non-normal labels per sample; > 0 means abnormal.
            abnormal_count = targets.sum(dim=1) - targets[:, normal_class_idx]
            y1 = (abnormal_count > 0).long()
            preds1 = out1.argmax(dim=1)

            correct += (preds1 == y1).sum().item()
            total += y1.size(0)

    return 100.0 * correct / total if total > 0 else 0.0


In [18]:
# Define the datasets and dataloaders.
# NOTE(review): machine-specific absolute path — point DATA_DIR at your local PTB-XL 1.0.1 copy.
DATA_DIR = r'D:\23020407 Dang Minh Nguyet\Seminar\BTL-Seminar\data\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.1'
CSV_PATH = os.path.join(DATA_DIR, 'ptbxl_database.csv')
RECORDS_PATH = os.path.join(DATA_DIR, 'records100')
BATCH_SIZE = 16

train_dataset = PTBXL(csv_path=CSV_PATH, records_path=RECORDS_PATH, split='train')
val_dataset = PTBXL(csv_path=CSV_PATH, records_path=RECORDS_PATH, split='test')

# Reshuffle training samples every epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [None]:
# Sanity check: raw multi-hot labels vs the derived stage-1 label (0 = normal, 1 = abnormal).
# "Abnormal" means any superclass other than NORM is set; "any label set" would always be 1.
label_classes = list(train_loader.dataset.mlb.classes_)
norm_idx = label_classes.index('NORM')
for data, labels in train_loader:
    if labels is None:
        continue  # every record in this batch was dropped
    print("label columns:", label_classes)
    print("raw multi-hot labels:", labels[:10])
    y1 = ((labels.sum(dim=1) - labels[:, norm_idx]) > 0).long()
    print("binary y1 labels:", y1[:10])
    break

In [22]:
# Define the model, optimizer, and loss function.
SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and per-epoch shuffling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SpikingTwoStageSNN()
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Both heads are single-label classifiers (stage 1: normal/abnormal; stage 2: one abnormal class).
criterion = nn.CrossEntropyLoss()

# Training loop with early stopping on validation stage-1 accuracy.
trained_model = train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs=50, patience=5, device=device)


Epoch 1/50 — Loss: 0.9598, Stage1 Acc: 99.81%
