In [None]:
!pip install -q wfdb
!pip install -q transformers

In [None]:
import wfdb
import numpy as np
import pandas as pd
import scipy as sp
from scipy import io as sio
from scipy import signal as sps
import matplotlib.pyplot as plt


from bisect import bisect
from collections import defaultdict
import pickle
import json

# 1. Feature Engineering
The dataset is open-source and can be downloaded from PhysioNet: https://physionet.org/content/mitdb/1.0.0/. 

This pre-processing method is inspired by this paper (https://arxiv.org/abs/2207.07089), which uses the wfdb library, a library specifically for physiological waveform data (https://wfdb.readthedocs.io/en/latest/wfdb.html). $\textbf{UNLIKE the paper above, we do not store the data on disk but read it directly into RAM, which saves time}$. For our project, we tailor the code to produce data with three columns:

- patient_id
- patient's heartbeat (one beat-length segment cut out of the ECG signal)
- doctor's annotation of this heartbeat (normal or abnormal)

## 1.1 Data Read-in and Visualization

In [None]:
# The 48 record numbers of the MIT-BIH Arrhythmia Database (PhysioNet directory `mitdb`).
# They are listed inline so the notebook does not need a local copy of the RECORDS file.
patient_ids = [
    100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
    111, 112, 113, 114, 115, 116, 117, 118, 119,
    121, 122, 123, 124,
    200, 201, 202, 203, 205, 207, 208, 209, 210,
    212, 213, 214, 215, 217, 219,
    220, 221, 222, 223, 228,
    230, 231, 232, 233, 234,
]

### Read the ECG signals: every record id has two ECG signals (two leads, i.e. two channels of the same recording).

In [None]:
def get_ecg_signals(patient_ids):
    """
    Download each record from PhysioNet's `mitdb` directory and split it into its two channels.

    Returns
    -------
    (lead0, lead1): two dicts keyed by record id; the first holds column 0 of the record's
    signal matrix, the second holds column 1.
    """
    # rdsamp returns (signal matrix, metadata); only the matrix is needed here.
    records = {id_: wfdb.io.rdsamp(str(id_), pn_dir='mitdb')[0] for id_ in patient_ids}
    first_channel = {id_: matrix[:, 0] for id_, matrix in records.items()}
    second_channel = {id_: matrix[:, 1] for id_, matrix in records.items()}
    return first_channel, second_channel

In [None]:
# Download every record once; both dicts are keyed by record id.
lead0, lead1 = get_ecg_signals(patient_ids=patient_ids)

In [None]:
def get_ecg_info(patient_ids, resolution=2**11):
    """
    Fetch the metadata dict of the first record and add the signal resolution to it.

    Only `patient_ids[0]` is read; its metadata is used as representative of all records.
    NOTE(review): this assumes every record shares the same sampling rate and units.

    Parameters
    ----------
    patient_ids: sequence
        Record ids; only the first element is used.
    resolution: int, default=2**11
        Number of possible signal values; stored under info["resolution"].

    Returns
    -------
    dict
        The metadata dict returned by wfdb.io.rdsamp, plus a "resolution" entry.
    """
    # rdsamp also downloads the samples, which are discarded here.
    _, info = wfdb.io.rdsamp(str(patient_ids[0]), pn_dir='mitdb')
    info["resolution"] = resolution
    return info

In [None]:
# Metadata of the first record (its "fs" entry is used below to convert seconds to samples).
ecg_info = get_ecg_info(patient_ids=patient_ids)
ecg_info

### Read in Annotations dataset. 

Each annotation marks a position (sample index) in the recording with a symbol: normal beats are marked with N, while other symbols denote abnormal beats or non-beat events.

In [None]:
def get_paced_patients(patient_ids):
    """
    Return, as a numpy array, the ids of records whose 'atr' annotations contain the paced-beat symbol "/".
    """
    def has_paced_beat(record_id):
        symbols = wfdb.rdann(str(record_id), pn_dir='mitdb', extension='atr').symbol
        return "/" in np.unique(symbols)

    return np.array([record_id for record_id in patient_ids if has_paced_beat(record_id)])

In [None]:
# Records that contain paced ("/") beats.
paced_patients = get_paced_patients(patient_ids=patient_ids)
paced_patients

In [None]:
def get_all_beat_labels(patient_ids):
    """
    Collect every distinct annotation symbol used across the given records.

    Returns a sorted numpy array of unique symbols.
    """
    seen_symbols = set()
    for record_id in patient_ids:
        seen_symbols.update(wfdb.rdann(str(record_id), pn_dir='mitdb', extension='atr').symbol)
    return np.unique(list(seen_symbols))

In [None]:
# Every annotation symbol that appears in the database.
all_beat_labels = get_all_beat_labels(patient_ids=patient_ids)
all_beat_labels

In [None]:
def get_rpeaks_and_labels(patient_ids):
    """
    Read the 'atr' annotations of every record.

    Returns
    -------
    (rpeaks, labels): two dicts keyed by record id. rpeaks holds the annotation sample indices;
    labels holds the matching annotation symbols as a numpy array. Every annotation is kept here;
    non-beat symbols are only filtered out later.
    """
    annotations = {id_: wfdb.rdann(str(id_), pn_dir='mitdb', extension='atr') for id_ in patient_ids}
    sample_positions = {id_: ann.sample for id_, ann in annotations.items()}
    symbols = {id_: np.array(ann.symbol) for id_, ann in annotations.items()}
    return sample_positions, symbols

In [None]:
# Annotation positions and symbols, keyed by record id.
rpeaks, labels = get_rpeaks_and_labels(patient_ids=patient_ids)

### Visualize to see heart rate and heart beat.

In [None]:
# Plot the first `secs` seconds of one record's second channel (lead1) and mark its annotation positions.
patient_id = 203
secs = 20
# int() so slicing also works if the header stores fs as a float.
samps = int(secs * ecg_info["fs"])
upto = bisect(rpeaks[patient_id], samps)
signal, peaks = lead1[patient_id], rpeaks[patient_id]
# Markers are drawn 2 samples before each annotation position (offset kept from the original plot).
# They are clipped at 0 so an annotation at sample 0 or 1 cannot wrap around to the end of the signal.
marker_idx = np.clip(peaks[:upto] - 2, 0, None)

fig, ax = plt.subplots(figsize=(15, 8), dpi=100)
ax.plot(signal[:samps], linewidth=0.1)
ax.plot(marker_idx, signal[marker_idx], marker="x", linestyle="")
ax.set(title=f"Record {patient_id}, lead1: first {secs} s with annotations", xlabel="sample index", ylabel="amplitude")
plt.show()

## 1.2 Split the ECG signal into individual heartbeats to match the annotations

### Separate ECG signal into its beats and annotate them.
- **Beat classes are different from beat labels: each beat label (annotation symbol) is mapped to exactly one of the five classes N, S, V, F, Q.**
https://physionet.org/physiobank/database/html/mitdbdir/intro.htm
https://archive.physionet.org/physiobank/database/html/mitdbdir/tables.htm


In [None]:
def get_normal_beat_labels():
    """
    Annotation symbols treated as healthy/normal beats; see wfdb.Annotation for their meanings.

    Class N groups the symbols N, L, R, e, j.
    """
    return np.array(list("NLRej"))

def get_abnormal_beat_labels():
    """
    Annotation symbols treated as arrhythmia/abnormal beats; see wfdb.Annotation for their meanings.

    Grouped by class: S <- {S, A, J, a}, V <- {V, E}, F <- {F}, Q <- {Q}.
    """
    s_class = list("SAJa")
    v_class = list("VE")
    return np.array(s_class + v_class + ["F", "Q"])

def get_beat_class(label):
    """
    Map one annotation symbol to its beat class.

    N <- {N, L, R, e, j}; S <- {S, A, J, a}; V <- {V, E}; F <- {F}; Q <- {Q}.
    Any other symbol yields None.
    """
    class_members = (
        ("N", ["N", "L", "R", "e", "j"]),
        ("S", ["S", "A", "J", "a"]),
        ("V", ["V", "E"]),
    )
    for beat_class, members in class_members:
        if label in members:
            return beat_class
    # F and Q form their own classes, so the symbol itself is the class.
    if label in ("F", "Q"):
        return label
    return None

In [None]:
def get_beats(patient_ids, signals, rpeaks, labels, beat_trio=False, centered=False, lr_offset=0.1, matlab=False, beat_length=128):
    """
    For each patient:
    Converts its ECG signal to an array of valid beats. Each rpeak with a valid label becomes a beat of
    `beat_length` samples by resampling (Fourier-Domain); the beat is then linearly detrended and scaled to unit L2 norm.
    Converts its labels to an array of valid labels, where a valid label is defined in the functions get_normal_beat_labels() and get_abnormal_beat_labels().
    Converts its valid labels to an array of classes, where each valid label is one of 5 classes, (N, S, V, F, Q).
    The first and last valid rpeak of each patient lack a neighbour on one side and are skipped.

    Parameters
    ----------
    patient_ids: iterable
        Ids used as keys into `signals`, `rpeaks` and `labels`.

    signals: dict
        id -> 1-D signal array.

    rpeaks: dict
        id -> array of annotation sample indices.

    labels: dict
        id -> array of annotation symbols aligned with `rpeaks[id]`.

    beat_trio: bool, default=False
        If True, generate beats as trios (the window extends past both neighbouring peaks).

    centered: bool, default=False
        Whether the generated beats have their peaks centered.

    lr_offset: float, default=0.1, range=[0, 1]
        A beat is extracted by finding the beats before and after it, and then offsetting by some samples. This parameter
        controls how many samples are offset. If the lower beat is L and the current beat is C, we offset by
        `lr_offset * abs(L - C)` samples.

    matlab: bool, default=False
        If True, dictionary keys become strings so the dictionary can be saved as a .mat file.

    beat_length: int, default=128
        Number of samples each beat is resampled to.

    Returns
    -------
    dict
        key -> {"beats": (n_beats, beat_length) array, "class": (n_beats,) str array, "label": (n_beats,) str array}.
        A patient with fewer than 3 valid rpeaks gets empty arrays instead of raising.
    """
    get_key_name = lambda patient_id: f"patient_{patient_id}" if matlab else patient_id

    # Annotation symbols that count as beats; every other symbol is dropped.
    # It is loop-invariant, so it is computed once rather than per patient.
    valid_symbols = np.concatenate((get_normal_beat_labels(), get_abnormal_beat_labels()))

    beat_data = {}
    for patient_id in patient_ids:
        key_name = get_key_name(patient_id)
        signal = signals[patient_id]

        # Filter out rpeaks that do not correspond to a valid label.
        patient_labels = np.asarray(labels[patient_id])
        valid_idx = np.flatnonzero(np.isin(patient_labels, valid_symbols))
        valid_rpeaks = np.asarray(rpeaks[patient_id])[valid_idx]
        valid_labels = patient_labels[valid_idx]

        beats, classes, beat_labels = [], [], []
        for i in range(1, len(valid_rpeaks) - 1):
            lpeak, cpeak, upeak = (int(peak) for peak in valid_rpeaks[i - 1:i + 2])

            if beat_trio:
                # Widen the window beyond both neighbouring peaks.
                lpeak = int(lpeak - (lr_offset * abs(cpeak - lpeak)))
                upeak = int(upeak + (lr_offset * abs(cpeak - upeak)))
            else:
                # Shrink the window towards the current peak.
                lpeak = int(lpeak + (lr_offset * abs(cpeak - lpeak)))
                upeak = int(upeak - (lr_offset * abs(cpeak - upeak)))
            # A negative start index would make the slice wrap around to the end of the signal.
            lpeak = max(lpeak, 0)

            if centered:
                # Take the same number of samples on both sides of the peak.
                diff = min(abs(lpeak - cpeak), abs(upeak - cpeak))
                beat = signal[cpeak - diff:cpeak + diff + 1]
            else:
                beat = signal[lpeak:upeak]

            # Resample in the frequency domain instead of in the time domain (resample_poly).
            beat = sps.resample(beat, beat_length)

            # Detrend the beat and normalize it. A flat beat (zero norm) stays all-zero instead of becoming NaN.
            beat = sps.detrend(beat)
            norm = np.linalg.norm(beat, ord=2)
            if norm > 0:
                beat = beat / norm

            label = valid_labels[i]
            beats.append(beat)
            classes.append(get_beat_class(label))
            beat_labels.append(label)

        beat_data[key_name] = {
            # np.stack fails on an empty list, so records with too few peaks get a (0, beat_length) array.
            "beats": np.stack(beats) if beats else np.empty((0, beat_length)),
            "class": np.array(classes, dtype=str),
            "label": np.array(beat_labels, dtype=str),
        }

    return beat_data

In [None]:
# Beats from the first channel (lead0): peak-to-peak windows shrunk by 10% on each side, not centered.
beat_data = get_beats(
    patient_ids=patient_ids,
    signals=lead0,
    rpeaks=rpeaks,
    labels=labels,
    beat_trio=False,
    centered=False,
    lr_offset=0.1,
)

In [None]:
# Overlay all normal-class ("N") beats of record 232.
normal_beats_232 = beat_data[232]["beats"][beat_data[232]["class"] == "N"]
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(normal_beats_232.T, "C2")
ax.set(
    title=f"Record 232: {len(normal_beats_232)} normal-class beats",
    xlabel="resampled sample index",
    ylabel="normalized amplitude",
)
plt.show()

In [None]:
# Overlay the first 50 beats of three records, with one legend entry per record so the colours can be told apart.
fig, ax = plt.subplots(figsize=(20, 5))
for color, record_id in zip(["C0", "C1", "C2"], [100, 203, 232]):
    record_lines = ax.plot(beat_data[record_id]["beats"][0:50].T, color)
    if record_lines:
        record_lines[0].set_label(f"record {record_id}")
ax.set(
    title="First 50 beats of records 100, 203 and 232",
    xlabel="resampled sample index",
    ylabel="normalized amplitude",
)
ax.legend()
plt.show()

# 2. Model Training


## 2.1 PyTorch with BERT

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
class HeartRateDataset(Dataset):
    """
    Torch dataset that feeds beats to BERT as text.

    Each element of `data` is passed directly to a BERT tokenizer (in this notebook: a comma-joined
    string of beat samples) and padded/truncated to `max_length` tokens.

    NOTE(review): a float rendered with str() usually splits into several word pieces, so a
    128-sample beat may exceed `max_length` tokens and its tail is then silently truncated.
    Verify by checking `item["attention_mask"].sum()` against `max_length`.
    """

    def __init__(self, data, labels, tokenizer_name="bert-base-uncased", max_length=512):
        """
        Parameters
        ----------
        data: sequence of str
            One text per sample.
        labels: sequence
            Class ids convertible with int(), aligned with `data`.
        tokenizer_name: str, default="bert-base-uncased"
            Name passed to BertTokenizer.from_pretrained.
        max_length: int, default=512
            Every sample is padded/truncated to this many tokens.

        Raises
        ------
        ValueError
            If `data` and `labels` have different lengths.
        """
        if len(data) != len(labels):
            raise ValueError(f"data and labels differ in length: {len(data)} != {len(labels)}")
        self.data = data
        self.labels = labels
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample `idx`; return 1-D input_ids/attention_mask tensors and a scalar label tensor."""
        beats = self.data[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            beats,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a batch dimension of 1; [0] removes it so the DataLoader can batch items.
        return {
            "input_ids": encoding["input_ids"][0],
            "attention_mask": encoding["attention_mask"][0],
            "label": torch.tensor(int(label)),
        }

In [None]:
class HeartRateClassifier(nn.Module):
    """
    Pretrained BERT encoder followed by dropout and a linear layer on BERT's `pooler_output`.

    Parameters
    ----------
    num_labels: int
        Number of output logits.
    model_name: str, default="bert-base-uncased"
        Checkpoint passed to BertModel.from_pretrained.
    dropout: float, default=0.1
        Dropout probability applied to the pooled output.
    """

    def __init__(self, num_labels, model_name="bert-base-uncased", dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder config instead of hard-coding 768, so other checkpoints also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized logits of shape (batch, num_labels)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs["pooler_output"])
        return self.linear(pooled_output)

In [None]:
def train(model, dataloader, optimizer, loss_fn, device):
    """
    Run one training epoch.

    Each batch must be a dict with "input_ids", "attention_mask" and "label" tensors.

    Returns
    -------
    (total_loss, acc)
        total_loss is the sum of the per-batch values returned by loss_fn (divide by len(dataloader)
        for the mean). acc is the fraction of samples in dataloader.dataset predicted correctly
        (0.0 for an empty dataset).
    """
    model.train()
    total_loss, total_correct, num_batches = 0.0, 0, 0

    progress_bar = tqdm(dataloader, desc="Training", unit="batch")

    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = loss_fn(outputs, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        _, preds = torch.max(outputs, dim=1)
        total_correct += torch.sum(preds == labels).item()
        num_batches += 1
        # Running mean batch loss. The previous display divided by the number of correct predictions.
        progress_bar.set_postfix({"Loss": total_loss / num_batches})

    acc = total_correct / max(len(dataloader.dataset), 1)
    return total_loss, acc

In [None]:
def evaluate(model, dataloader, loss_fn, device):
    """
    Evaluate the model on `dataloader` without updating its parameters.

    Each batch must be a dict with "input_ids", "attention_mask" and "label" tensors.

    Returns
    -------
    (total_loss, acc)
        total_loss is the sum of the per-batch values returned by loss_fn (divide by len(dataloader)
        for the mean). acc is the fraction of samples in dataloader.dataset predicted correctly
        (0.0 for an empty dataset).
    """
    model.eval()
    total_loss, total_correct, num_batches = 0.0, 0, 0

    progress_bar = tqdm(dataloader, desc="Evaluating", unit="batch")

    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            total_correct += torch.sum(preds == labels).item()
            num_batches += 1
            # Running mean batch loss. The previous display divided by the number of correct predictions.
            progress_bar.set_postfix({"Loss": total_loss / num_batches})

    acc = total_correct / max(len(dataloader.dataset), 1)
    return total_loss, acc

In [None]:
def save_checkpoint(model, optimizer, epoch, filename):
    """
    Write the training state to `filename` with torch.save.

    The saved object is a dict with keys "epoch", "model_state_dict" and "optimizer_state_dict".
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
        ),
        filename,
    )

### Load Data

In [None]:
# # Load the dataset from csv file
# data = pd.read_csv('data.csv')
# data.drop('label', axis=1, inplace=True)
# data.rename(columns={'class': 'label'}, inplace=True)
# data['label'] = data['label'].apply(lambda x: 0 if x == 'N' else 1)

# #Split the data into training and testing sets
# train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# #Convert numerical data to strings
# train_beats = [','.join(map(str, np.fromstring(x, sep=','))) for x in train_data['beat'].tolist()]
# train_labels = train_data['label'].tolist()

# test_beats = [','.join(map(str, np.fromstring(x, sep=','))) for x in test_data['beat'].tolist()]
# test_labels = test_data['label'].tolist()

In [None]:
# One frame per record, concatenated once at the end. Growing a DataFrame with concat inside the loop
# copies all earlier rows on every iteration.
patient_frames = []
for record_id in patient_ids:
    record_frame = pd.DataFrame(
        zip(beat_data[record_id]['beats'], beat_data[record_id]['class']), columns=["beat", "label"]
    )
    record_frame["id"] = record_id
    patient_frames.append(record_frame)
data = pd.concat(patient_frames, ignore_index=True)
assert len(data) == sum(len(beat_data[record_id]['class']) for record_id in patient_ids)

In [None]:
# Binary target: 0 = normal class "N", 1 = any arrhythmia class (S, V, F, Q).
# Already-encoded 0/1 values map to themselves, so re-running this cell is harmless.
# (The old lambda turned every label into 1 on a second run, because 0 != 'N'.)
BINARY_TARGET = {"N": 0, "S": 1, "V": 1, "F": 1, "Q": 1, 0: 0, 1: 1}
data['label'] = data['label'].map(BINARY_TARGET)
assert data['label'].notna().all(), "unexpected beat class in data['label']"
data['label'] = data['label'].astype(int)

In [None]:
# Split the beats into training and testing sets.
# stratify keeps the normal/abnormal ratio the same in both sets, which matters when the classes are imbalanced.
# NOTE(review): the split is per beat, so beats from the same record end up in both sets
# ("intra-patient" evaluation). Scores are likely optimistic compared with splitting by data["id"].
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42, stratify=data['label'])

### Fine-tune a pre-trained BERT model

In [None]:
# BERT consumes text, so each beat becomes one comma-separated string of its sample values.
train_beats, test_beats = (
    [",".join(str(value) for value in beat) for beat in split["beat"]]
    for split in (train_data, test_data)
)
train_labels, test_labels = (split["label"].tolist() for split in (train_data, test_data))

In [None]:
# Seed torch so the classifier head's initialization and the shuffled batch order are reproducible.
torch.manual_seed(42)

# Create datasets (tokenization happens lazily in __getitem__).
train_dataset = HeartRateDataset(train_beats, train_labels)
test_dataset = HeartRateDataset(test_beats, test_labels)

# Create data loaders.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Initialize the model, loss function, and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HeartRateClassifier(num_labels=2).to(device)
loss_fn = nn.CrossEntropyLoss()
# torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0 matches the
# transformers default; torch's own default would be 0.01.
optimizer = optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)

# Train and evaluate the model with checkpoint saving.
epochs = 10
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)


In [None]:
for epoch in range(epochs):
    epoch_number = epoch + 1  # 1-based, for display and file names
    train_loss, train_acc = train(model, train_loader, optimizer, loss_fn, device)
    eval_loss, eval_acc = evaluate(model, test_loader, loss_fn, device)
    # train/evaluate return summed batch losses; dividing by the batch count gives the mean.
    print(
        f"Epoch {epoch_number}, Training Loss: {train_loss / len(train_loader)}, "
        f"Training Accuracy: {train_acc:.2f}, Validation Loss: {eval_loss / len(test_loader)}, "
        f"Validation Accuracy: {eval_acc:.2f}"
    )

    # One checkpoint per epoch; the stored "epoch" field is the 0-based loop index.
    checkpoint_file = f"{checkpoint_dir}/model_epoch_{epoch_number}.pt"
    save_checkpoint(model, optimizer, epoch, checkpoint_file)

## 2.2 Train with Random Forest

In [None]:
# The random forest uses the raw numeric beats (one array per row), not the BERT text encoding.
train_beats_numerical, train_labels = (train_data[column].to_list() for column in ("beat", "label"))
test_beats_numerical, test_labels = (test_data[column].to_list() for column in ("beat", "label"))

In [None]:
from sklearn.ensemble import RandomForestClassifier

# 100 trees with a fixed seed so the fitted forest and its predictions are reproducible.
rf = RandomForestClassifier(random_state=42, n_estimators=100)
# fit() returns the estimator itself, so prediction can be chained.
y_pred = rf.fit(train_beats_numerical, train_labels).predict(test_beats_numerical)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Precision, recall and F1 are macro-averaged: the unweighted mean of the per-class scores.
metric_report = [
    ("Accuracy:", accuracy_score(test_labels, y_pred)),
    ("Precision:", precision_score(test_labels, y_pred, average='macro')),
    ("Recall:", recall_score(test_labels, y_pred, average='macro')),
    ("F1 score:", f1_score(test_labels, y_pred, average='macro')),
]
for metric_name, metric_value in metric_report:
    print(metric_name, metric_value)