In [1]:
import pandas as pd

# Path to the ConditionNames_SNOMED-CT.csv file, relative to the project root
# (a machine-specific absolute /workspaces/... path breaks on any other checkout).
mapping_csv_path = "ECGTeam_Data/ecg-arrhythmia/ConditionNames_SNOMED-CT.csv"

# Load the CSV into a DataFrame
mapping_df = pd.read_csv(mapping_csv_path)

# SNOMED CT codes as strings, so they compare equal to the comma-separated codes
# parsed out of each record's "Dx" header comment.
snomed_codes = mapping_df['Snomed_CT'].astype(str)
# Duplicates are dropped (first occurrence kept) so every code gets exactly one
# class index and the indices are contiguous 0..num_labels-1.
unique_codes = snomed_codes.drop_duplicates()

# NOTE: despite its name, this maps SNOMED CT code -> condition full name.
# Its key order matches `unique_codes`, so list(label_to_snomed.values())[i]
# is the name of class index i.
label_to_snomed = dict(zip(snomed_codes, mapping_df['Full Name']))

# Create a mapping from SNOMED CT codes to contiguous numerical class labels
snomed_to_label = {snomed: idx for idx, snomed in enumerate(unique_codes)}
num_labels = len(snomed_to_label)


In [2]:
# Re-read the condition table via the project-relative path.
mapping_csv_path = "ECGTeam_Data/ecg-arrhythmia/ConditionNames_SNOMED-CT.csv"
mapping_df = pd.read_csv(mapping_csv_path)

# Acronym -> full condition name (later rows win on duplicate acronyms).
acronyms = mapping_df['Acronym Name']
full_names = mapping_df['Full Name']
label_to_condition = {acronym: full_name for acronym, full_name in zip(acronyms, full_names)}

In [3]:
import os
import wfdb
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Pin the global torch and numpy RNG state up front so random draws made through
# them (e.g. classifier-head initialisation, DataLoader shuffling) repeat run to run.
torch.manual_seed(42)
np.random.seed(42)

# Bandpass filter
def bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Apply a causal Butterworth band-pass filter to ``data``.

    Args:
        data: 1-D sequence of samples.
        lowcut: lower cutoff frequency, in the same units as ``fs``.
        highcut: upper cutoff frequency, in the same units as ``fs``; must be below ``fs / 2``.
        fs: sampling frequency.
        order: Butterworth filter order (default 4).

    Returns:
        The filtered samples from ``scipy.signal.lfilter`` (single forward pass,
        so the output is phase-shifted relative to the input).
    """
    nyq = fs / 2.0
    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band")
    return lfilter(b, a, data)

# Band-pass filter parameters (cutoffs are in the same units as fs, presumably Hz).
# NOTE(review): fs is hard-coded; WFDB records also expose their own sampling rate
# (record.fs) — confirm 360 matches this dataset.
fs = 360
lowcut = 0.5
highcut = 50

num_labels = len(snomed_to_label)

# Tokenizer that turns the filtered signal text into DistilBERT input ids.
# The classifier itself is created once, after the labelled data is loaded; an
# extra from_pretrained() here was overwritten before ever being used and only
# loaded the pretrained weights twice.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Load and preprocess ECG signals
folder_path = "WFDBRecords"
signals = []
labels = []
file_limit = 100  # maximum number of .hea files to attempt
file_count = 0
max_tokens = 512  # DistilBERT's maximum input length


def _first_known_label(comments):
    """Return the class index of the first mapped SNOMED CT code in a record's "Dx" comment(s).

    The codes are taken from the text after the first ':' and split on ','; surrounding
    whitespace is ignored. Returns -1 when there is no "Dx" comment or none of its codes
    appear in ``snomed_to_label``.
    """
    for comment in comments:
        if "Dx" not in comment:
            continue
        _, _, codes = comment.partition(":")
        for code in codes.split(","):
            code = code.strip()
            if code in snomed_to_label:
                return snomed_to_label[code]
    return -1


for root, dirs, files in os.walk(folder_path):
    # Walk in sorted order so the first `file_limit` records are the same on every run.
    dirs.sort()
    for file in sorted(files):
        if file_count >= file_limit:
            break
        if not file.endswith(".hea"):
            continue
        file_count += 1
        record_name = os.path.splitext(file)[0]
        record_path = os.path.join(root, record_name)
        try:
            record = wfdb.rdrecord(record_path)
            signal = record.p_signal[:, 0]  # first channel only
            filtered_signal = bandpass_filter(signal, lowcut, highcut, fs)

            # Convert to string and tokenize. Every sample renders to at least one
            # word-piece, so the first `max_tokens` samples already fill the truncated
            # window; stringifying the whole record only builds text that is discarded.
            encoded = tokenizer(
                " ".join(map(str, filtered_signal[:max_tokens])),
                truncation=True,
                padding="max_length",
                max_length=max_tokens,
                return_tensors="pt"
            )
            label = _first_known_label(record.comments)  # -1 if no usable Dx code
        except Exception as e:
            print(f"Error processing {record_name}: {e}")
            continue
        # Append both only after everything succeeded so signals[i] always pairs with labels[i].
        signals.append(encoded["input_ids"].squeeze(0).tolist())
        labels.append(label)
    if file_count >= file_limit:
        break  # stop os.walk from descending into the rest of the tree
                
# Every loaded signal must have exactly one label at the same position; otherwise
# the index-based filtering below would silently pair signals with the wrong labels.
if len(signals) != len(labels):
    raise ValueError(
        f"signals ({len(signals)}) and labels ({len(labels)}) are misaligned; "
        "check the loading loop."
    )

# Filter out invalid labels (e.g. the -1 "no usable Dx code" marker)
valid_indices = [i for i, label in enumerate(labels) if 0 <= label < num_labels]
signals = [signals[i] for i in valid_indices]
labels = [labels[i] for i in valid_indices]

# Debugging: Check if the dataset is empty
if len(signals) == 0 or len(labels) == 0:
    raise ValueError("No valid data found after filtering. Check the label extraction logic.")

# Every row was padded/truncated to the same length by the tokenizer, so a dense
# int64 matrix works (an object array only added per-item conversion overhead).
signals = np.array(signals, dtype=np.int64)
labels = np.array(labels)

# DistilBERT classifier with one output logit per known SNOMED CT class,
# placed on the GPU when one is available.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=num_labels
)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
model.to(device)

# 60/20/20 split: hold out 20% as test, then 25% of the remaining 80% as validation.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    signals, labels, test_size=0.2, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, random_state=42
)

# PyTorch Dataset
class ECGDataset(Dataset):
    """Map-style dataset of pre-tokenised ECG sequences and their integer class labels.

    Args:
        signals: indexable collection of token-id sequences, one per record.
        labels: indexable collection of integer class indices aligned with ``signals``.
        pad_token_id: token id treated as padding when building the attention mask.
            When omitted, the global ``tokenizer.pad_token_id`` is used (original behaviour).

    Raises:
        ValueError: if ``signals`` and ``labels`` have different lengths.
    """

    def __init__(self, signals, labels, pad_token_id=None):
        if len(signals) != len(labels):
            raise ValueError(
                f"signals ({len(signals)}) and labels ({len(labels)}) must have the same length"
            )
        self.signals = signals
        self.labels = labels
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` (1 = real token, 0 = padding) and ``label``."""
        # Resolved per item so constructing a dataset does not require the global tokenizer.
        pad_id = tokenizer.pad_token_id if self.pad_token_id is None else self.pad_token_id
        input_ids = torch.tensor(np.asarray(self.signals[idx], dtype=np.int64), dtype=torch.long)
        attention_mask = (input_ids != pad_id).long()
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": label
        }

# Create dataset and dataloader (only the training loader is shuffled)
batch_size = 2
train_dataset = ECGDataset(X_train, y_train)
val_dataset = ECGDataset(X_val, y_val)
test_dataset = ECGDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Training setup. `device` was chosen and `model` moved onto it when the model was
# created above, so that is not repeated here.
optimizer = AdamW(model.parameters(), lr=5e-5)
# NOTE: the training loop passes labels= to the model and uses outputs.loss, so this
# standalone loss function is not currently used.
loss_fn = CrossEntropyLoss()

# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    # --- train ---
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Named batch_labels so the global `labels` array from the loading step is not overwritten.
        batch_labels = batch["label"].to(device)
        # Out-of-range targets would otherwise fail deep inside the loss computation.
        if batch_labels.max() >= num_labels or batch_labels.min() < 0:
            raise ValueError(f"Invalid label in batch: {batch_labels}")
        # Passing labels= makes the model compute the classification loss itself.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # --- validate ---
    model.eval()
    total_val_loss = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            batch_labels = batch["label"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=batch_labels)
            loss = outputs.loss
            preds = torch.argmax(outputs.logits, dim=1)
            total_val_loss += loss.item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    print(f"Epoch {epoch + 1}")
    # max(..., 1) guards against ZeroDivisionError when a loader is empty.
    print(f"Train Loss: {total_train_loss / max(len(train_loader), 1):.4f}")
    print(f"Val Loss: {total_val_loss / max(len(val_loader), 1):.4f}")
    print("Val Metrics:")
    print(classification_report(all_labels, all_preds))

# Final Evaluation on the held-out test split
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Named batch_labels so the global `labels` array from the loading step is not overwritten.
        batch_labels = batch["label"].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

print("Test Set Evaluation:")
print(classification_report(all_labels, all_preds))

# Map predicted / true class indices to condition names. `label_to_snomed` maps
# SNOMED CT code -> full name, so these lists hold names (the variable names are
# kept for compatibility). Built once instead of re-listing the dict per element.
class_names = list(label_to_snomed.values())
predicted_snomed = [class_names[pred] for pred in all_preds]
actual_snomed = [class_names[label] for label in all_labels]

# Compare predictions with actual Dx values
correct_matches = sum(pred == actual for pred, actual in zip(predicted_snomed, actual_snomed))

# Calculate accuracy of matching SNOMED CT codes (nan rather than a crash on an empty test set)
accuracy = correct_matches / len(all_labels) * 100 if all_labels else float("nan")
print(f"SNOMED CT Matching Accuracy: {accuracy:.2f}%")

# Print up to `max_mismatch_examples` mismatches for debugging. The limit counts
# mismatches actually printed, not the sample index.
max_mismatch_examples = 5
shown_mismatches = 0
for i, (pred, actual) in enumerate(zip(predicted_snomed, actual_snomed)):
    if pred != actual:
        print(f"Mismatch {i + 1}: Predicted = {pred}, Actual = {actual}")
        shown_mismatches += 1
        if shown_mismatches >= max_mismatch_examples:
            break

  from .autonotebook import tqdm as notebook_tqdm
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


: 

In [None]:
# Function to process and predict diagnoses for a given subdirectory
def predict_diagnoses(subdirectory, file_limit=100):
    """Load WFDB records under ``subdirectory``, classify them with the trained model and print results.

    Each record is preprocessed like the training data: first channel, band-pass filtered via the
    global ``bandpass_filter`` with the global ``lowcut``/``highcut``/``fs``, rendered as text and
    tokenised to 512 ids. Each record's ground-truth label is the first mapped SNOMED CT code in its
    "Dx" comment (-1 if none).

    Args:
        subdirectory: directory searched recursively for ``.hea`` header files.
        file_limit: maximum number of successfully loaded records to classify.

    Side effects:
        Prints per-file load errors, a classification report plus summary metrics (only when every
        loaded record has a mapped label), and one prediction line per record. Returns ``None``.
    """
    # These metrics are not imported at the top of the notebook.
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    max_length = 512

    def _record_label(comments):
        # First mapped SNOMED CT code in the "Dx" comment(s) -> class index, else -1.
        for comment in comments:
            if "Dx" in comment:
                for code in comment.partition(":")[2].split(","):
                    code = code.strip()
                    if code in snomed_to_label:
                        return snomed_to_label[code]
        return -1

    signals = []
    true_labels = []  # exactly one entry per loaded record
    file_names = []

    # Process files in the subdirectory (sorted for a reproducible subset)
    for root, dirs, files in os.walk(subdirectory):
        dirs.sort()
        for file in sorted(files):
            if len(signals) >= file_limit:
                break
            if not file.endswith(".hea"):
                continue
            record_name = os.path.splitext(file)[0]
            record_path = os.path.join(root, record_name)
            try:
                record = wfdb.rdrecord(record_path)
                signal = record.p_signal[:, 0]
                filtered_signal = bandpass_filter(signal, lowcut, highcut, fs)
                # Each sample yields at least one word-piece, so the first `max_length`
                # samples already fill the truncated window.
                encoded = tokenizer(
                    " ".join(map(str, filtered_signal[:max_length])),
                    truncation=True,
                    padding="max_length",
                    max_length=max_length,
                    return_tensors="pt"
                )
                label = _record_label(record.comments)
            except Exception as e:
                print(f"Error processing {record_name}: {e}")
                continue
            # Append together so signals, labels and file names stay aligned.
            signals.append(encoded["input_ids"].squeeze(0).tolist())
            true_labels.append(label)
            file_names.append(record_name)
        if len(signals) >= file_limit:
            break

    if not signals:
        print("\nNo records could be loaded; nothing to predict.")
        return

    # Convert to tensors
    input_ids_all = torch.tensor(signals, dtype=torch.long)
    attention_masks = (input_ids_all != tokenizer.pad_token_id).long()

    # Predict using the model, one record at a time
    model.eval()
    predictions = []
    with torch.no_grad():
        for input_ids, attention_mask in zip(input_ids_all, attention_masks):
            outputs = model(
                input_ids=input_ids.unsqueeze(0).to(device),
                attention_mask=attention_mask.unsqueeze(0).to(device),
            )
            predictions.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())

    # Map predictions back to SNOMED CT codes via the inverse of snomed_to_label
    label_to_code = {idx: code for code, idx in snomed_to_label.items()}
    predicted_snomed = [label_to_code[int(pred)] for pred in predictions]

    # Print classification report if ground truth labels are available
    if all(label != -1 for label in true_labels):
        # Pass every class id explicitly so target_names lines up with the class indices
        # even when only some classes occur in this subset.
        class_ids = sorted(label_to_code)
        target_names = [label_to_snomed[label_to_code[i]] for i in class_ids]
        print("\nClassification Report:")
        print(classification_report(true_labels, predictions, labels=class_ids,
                                    target_names=target_names, zero_division=0))
        print(f"Accuracy: {accuracy_score(true_labels, predictions):.2f}")
        print(f"Precision: {precision_score(true_labels, predictions, average='weighted', zero_division=0):.2f}")
        print(f"Recall: {recall_score(true_labels, predictions, average='weighted', zero_division=0):.2f}")
        print(f"F1 Score: {f1_score(true_labels, predictions, average='weighted', zero_division=0):.2f}")
    else:
        print("\nGround truth labels are not available for all files. Only predictions are shown.")

    # Print predictions
    print("\nPredictions:")
    for file_name, snomed in zip(file_names, predicted_snomed):
        print(f"{file_name}: {snomed} ({label_to_snomed[snomed]})")

# User input for subdirectory. Surrounding whitespace from pasted paths is stripped,
# and the path is validated up front: os.walk silently yields nothing for a missing
# directory, which would otherwise surface later as a confusing downstream error.
subdirectory = input("Enter the subdirectory path containing .hea files: ").strip()
if not os.path.isdir(subdirectory):
    raise FileNotFoundError(f"Not a directory: {subdirectory!r}")
predict_diagnoses(subdirectory, file_limit=250)