# Main Thesis Topic: “Zero-shot classification of ECG signals using CLIP-like model”.

**For example: Train on PBT-XL:**

- Text Encoder: ClinicalBERT (trained on diagnoses of ECG signal to obtain corresponding embeddings)
- Image Encoder: 1D-CNN (used to encode ECG signal to obtain signal embeddings)

- Experiment A): Baseline: We can take only the name of the class. For example, take “Myocardial Infarction” as a text. We should exclude some classes from training and after training is completed, the CLIP-like model can be tested on these excluded classes.
    - Next, we get embeddings of text from ClinicalBERT and train the ECG encoder with contrastive loss.

- Experiment B): Same as Experiment A but instead of testing on the same dataset/classes, we would test on other datasets containing different classes.

**Evaluation metrics:**
- Main: AUC-ROC, average_precison_score,
- Optional: Specificity, Sensitivity, F1-score

**Outcome:**
- It’s possible to train CLIP-like models with freezed (or unchanged/not fine tuned for downstream tasks) text encoder
- Training ECG encoders that are viable for representing different domains (within ECG modality) and previously unseen classes.
- Training a CLIP-like model on ECGs has little novelty.

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import random
from tqdm import tqdm
from scipy.signal import resample
import torch
from transformers import AutoTokenizer, AutoModel
import ast
import scipy.io as sio
from torch.utils.data import random_split
from torch.nn.functional import cosine_similarity

c:\Users\navme\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
c:\Users\navme\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll


In [2]:
sys.path.append('C:/Users/navme/Desktop/ECG_Project/PyFiles')

In [3]:
from helper_functions import *

Using the ```PhysioNet_PATH```, we can create separate datasets for training, testing & validation.

# Stage 1: Data Preprocessing

- train_set (train & validation data)
- test_set (test data)

First, let's load the SNOWMED-CT mappings:

In [5]:
smowmed_mappings_path = convert_to_forward_slashes(r'C:\Users\navme\Desktop\ECG_Project\Data\SNOWMED-CT Codes\combined_mappings.csv')

# Load the SNOMED-CT mappings
smowmed_mappings = pd.read_csv(smowmed_mappings_path)
smowmed_mappings.head(2)

Unnamed: 0,Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total,Notes
0,atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,1780,0,5255,
1,atrial flutter,164890007,AFL,0,54,0,1,73,186,445,7615,8374,


In [6]:
# Select the 'Dx' and 'SNOMEDCTCode' columns
codes = smowmed_mappings[['Dx', 'SNOMEDCTCode']]

# Set 'SNOWMEDCTCode' as the index
codes.set_index('SNOMEDCTCode', inplace=True)

# Convert the DataFrame into a dictionary
codes_dict = codes['Dx'].to_dict()

In [12]:
list(codes_dict.items())[:5]

[(164889003, 'atrial fibrillation'),
 (164890007, 'atrial flutter'),
 (6374002, 'bundle branch block'),
 (426627000, 'bradycardia'),
 (733534002, 'complete left bundle branch block')]

# Updated PhysioNetDataset Class 

- Update the ```PhysioNetDataset``` class such that instead of header_info --> return header_info['Dx']. Then convert the Dx code to string input for TextEncoder(). 

In [69]:
class PhysioNetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, train=False):
        self.dataset_path = dataset_path
        self.dataset_path = [path for path in self.dataset_path if "index.html" not in path]
        self.train = train
        self.file_list = os.listdir(dataset_path)
        self._hea_files = []
        self._mat_files = []
        self._indices_files = []
        self._hea_files_path = []
        self._mat_files_path = []

        self.file_PATHS = []  # Directory to main database folders
        self.data_files = []  # Directory to data files

        # Validation Case: PTB Databases only
        if self.train == False:
            validation_datasets = ['ptb', 'ptb-xl']
            for file in os.listdir(dataset_path):
                if file in validation_datasets:
                    file_path = os.path.join(dataset_path, file)
                    file_path = file_path.replace('\\', '/')
                    self.file_PATHS.append(file_path)

        # Training Case: All Databases excluding PTB
        else:
            validation_datasets = ['ptb', 'ptb-xl']
            for file in os.listdir(dataset_path):
                if file not in validation_datasets:
                    file_path = os.path.join(dataset_path, file)
                    file_path = file_path.replace('\\', '/')
                    self.file_PATHS.append(file_path)

        for path in self.file_PATHS:
            if os.path.isdir(path):
                for sub_folder in os.listdir(path):
                    sub_folder_path = os.path.join(path, sub_folder)
                    sub_folder_path = sub_folder_path.replace('\\', '/')
                    
                    # Ignore index.html files
                    if sub_folder_path.endswith('index.html'):
                        self._indices_files.append(sub_folder_path)
                    else:
                        if os.path.isdir(sub_folder_path):
                            for file in os.listdir(sub_folder_path):
                                # Get all .hea files
                                if file.endswith('.hea'):
                                    file_path = os.path.join(sub_folder_path, file)
                                    file_path = file_path.replace('\\', '/')
                                    self._hea_files.append(file_path)
                                    self._hea_files_path.append(file_path)
                                # Get all .mat files
                                elif file.endswith('.mat'):
                                    file_path = os.path.join(sub_folder_path, file)
                                    file_path = file_path.replace('\\', '/')
                                    self._mat_files.append(file_path)
                                    self._mat_files_path.append(file_path)

    def resample_ecg(self, data, old_freq, new_freq=128):
        # Calculate the duration of the signal
        duration = len(data) / old_freq

        # Calculate the number of points in the resampled signal
        num_points = int(np.round(duration * new_freq))

        # Resample the signal
        resampled_data = resample(data, num_points)

        return resampled_data

    def __getitem__(self, index):
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self[i] for i in range(start, stop, step)]
        # 1. Get .hea file
        hea_file_path = self._hea_files[index]
        with open(hea_file_path, 'r') as f:
            lines = f.readlines()
            
        # Parse header information
        # Initialize header information
        header_info = {
            'recording_number': lines[0].split()[0],
            'recording_file': lines[0].split()[0] + '.mat',
            'num_leads': int(lines[0].split()[1]),
            'sampling_frequency': int(lines[0].split()[2]),
            'num_samples': int(lines[0].split()[3]),
            'leads_info': [],
            'age': None,
            'sex': None,
            'dx': None,
            'rx': None,
            'hx': None,
            'sx': None,
        }

        # Parse header information
        for line in lines:
            if line.startswith('# Age:'):
                age_str = line.split(':')[1].strip()
                header_info['age'] = int(age_str) if age_str != 'NaN' else None
            elif line.startswith('# Sex:'):
                header_info['sex'] = line.split(':')[1].strip()
            elif line.startswith('# Dx:'):
                dx_codes = line.split(':')[1].strip().split(',')
                dx_modalities = [codes_dict.get(int(code.strip()), code.strip()) for code in dx_codes]
                header_info['dx'] = [codes_dict.get(int(code.strip()), code.strip()) for code in dx_codes]
            elif line.startswith('# Rx:'):
                header_info['rx'] = line.split(':')[1].strip()
            elif line.startswith('# Hx:'):
                header_info['hx'] = line.split(':')[1].strip()
            elif line.startswith('# Sx:'):
                header_info['sx'] = line.split(':')[1].strip()

        for line in lines[1:header_info['num_leads']+1]:
            adc_gain = line.split()[2].split('/')[0]
            adc_gain = float(adc_gain.replace('(0)', ''))  # Remove '(0)' and convert to float
            lead_info = {
                'file': line.split()[0],
                'adc_gain': adc_gain,
                'units': line.split()[2].split('/')[1],
                'adc_resolution': int(line.split()[3]),
                'adc_zero': int(line.split()[4]),
                'initial_value': int(line.split()[5]),
                'checksum': int(line.split()[6]),
                'lead_name': line.split()[7],
            }
            header_info['leads_info'].append(lead_info)

        # 2. Get .mat file
        twelve_lead_ecg = None
        if index < len(self._mat_files):
            mat_file_path = self._mat_files[index]
            twelve_lead_ecg = sio.loadmat(mat_file_path)
            
            # Resample the ECG to 128 Hz
            for lead in twelve_lead_ecg:
                twelve_lead_ecg[lead] = self.resample_ecg(twelve_lead_ecg[lead], old_freq=header_info['sampling_frequency'])
        else:
            print(f"MAT file for index {index} does not exist.")
        
        # Return list of diagnoses and the np array of the 12-lead ECG
        return dx_modalities, twelve_lead_ecg['val']

    def plot_record(self, index):
        mat_file_path = self._mat_files[index]
        data = sio.loadmat(mat_file_path)
        fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

        for i, ax in enumerate(axs.flat):
            ax.plot(data['val'][i], linewidth=0.5)
            ax.set_xlabel('Sample')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'Lead {i+1}')

        plt.tight_layout()
        plt.show()

    def __len__(self):
        return len(self._hea_files)

In [70]:
# Path to training folder within PhysioNet dataset
PhysioNet_PATH = f'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge/physionet.org/files/challenge-2021/1.0.3/training'
PhysioNet_PATH

'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge/physionet.org/files/challenge-2021/1.0.3/training'

In [71]:
train_set = PhysioNetDataset(PhysioNet_PATH, train=True)
test_set = PhysioNetDataset(PhysioNet_PATH, train=False)
len(train_set), len(test_set)

(65900, 22352)

The ```train_set``` can be split into ```current_train``` (85%) and ```current_val``` (15%).

In [74]:
# Set the seed for the random number generator
torch.manual_seed(0)

# Get the length of the train_set
length = len(train_set)

# Calculate the lengths of the splits
train_length = int(0.85 * length)
val_length = length - train_length

# Split the dataset
current_train, current_val = random_split(train_set, [train_length, val_length])

In [75]:
len(current_train), len(current_val)

(56015, 9885)

In [76]:
# Diagnoses
current_train[0][0]

['sinus tachycardia',
 't wave inversion',
 't wave abnormal',
 'prolonged qt interval']

In [77]:
# 12-lead ECG (np array)
current_train[0][1]

array([[-25.84721708, -25.84721708, -25.84721708, ...,  22.05545073,
         23.6104719 ,  20.12243803],
       [ 15.36559584,  15.36559584,  15.36559584, ..., -13.64070169,
        -14.25074403, -15.14070169],
       [ 29.23162124,  29.23162124,  29.23162124, ..., 107.83525095,
        109.39027212, 110.26826366]])

# Stage 2: ECG Classification Model Pipeline

Now that our data is preprocessed, we can begin working on the Model Pipeline itself. The ECG Classification Model Pipeline will consist of three components:

1. `TextEncoder()` class

2. `ECGEncoder()` class

3. `InstanceSelector()` class

4. `CLIPModel()` class

An overview and outline of each of these components can be found below in their respective subsections.

## TextEncoder()

Create a class, ```TextEncoder()``` that is used to convert the description of the (dx_modality) diagnosis class into embeddings using the ClinicalBERT model.

- Input should be a concatenated using comma or blank space string of diagnoses/dx_modality per ECG signal.
- Use processed CSV files (dx_modality vs dx_modality, age, etc together)
- Frozen weights (since it's already pretrained)

In [None]:
class TextEncoder:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    def encode(self, text_list):
        # Check if text_list is a string representation of a list
        if isinstance(text_list, str):
            text_list = ast.literal_eval(text_list)
        # Convert list of strings to a single string
        text = ', '.join(text_list)
        # Tokenize text
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        # Get embeddings from ClinicalBERT model
        with torch.no_grad():
            embeddings = self.model(**inputs).last_hidden_state
        # Average the embeddings to get single vector per each input
        embeddings = torch.mean(embeddings, dim=1)
        return embeddings