# Main Thesis Topic: “Zero-shot classification of ECG signals using CLIP-like model”.

**For example: Train on PTB-XL:**

- Text Encoder: ClinicalBERT (encodes the diagnoses of each ECG signal into text embeddings)
- Image Encoder: 1D-CNN (used to encode ECG signal to obtain signal embeddings)

- Experiment A): Baseline: We can take only the name of the class. For example, take “Myocardial Infarction” as a text. We should exclude some classes from training and after training is completed, the CLIP-like model can be tested on these excluded classes.
    - Next, we get embeddings of text from ClinicalBERT and train the ECG encoder with contrastive loss.

- Experiment B): Same as Experiment A but instead of testing on the same dataset/classes, we would test on other datasets containing different classes.

**Evaluation metrics:**
- Main: AUC-ROC, average_precision_score
- Optional: Specificity, Sensitivity, F1-score

**Outcome:**
- It’s possible to train CLIP-like models with a frozen (i.e. unchanged, not fine-tuned for downstream tasks) text encoder
- Training ECG encoders that are viable for representing different domains (within ECG modality) and previously unseen classes.
- Training a CLIP-like model on ECGs has little novelty.

## Experiment A: PTB & PTB-XL ONLY

-  Baseline: We can take only the name of the class. For example, take “Myocardial Infarction” as a text. We should exclude some classes from training and after training is completed, the CLIP-like model can be tested on these excluded classes.
- Next, we get embeddings of text from ClinicalBERT and train the ECG encoder with contrastive loss.

In [8]:
import os
import sys
import random
import ast
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.signal import resample
import scipy.io as sio
from scipy import signal
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

import torch
from torch.utils.data import random_split
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [3]:
# Make the project's helper modules importable. The membership check keeps the
# entry unique when this cell is re-run in the same kernel.
_helpers_dir = 'C:/Users/navme/Desktop/ECG_Project/PyFiles'
if _helpers_dir not in sys.path:
    sys.path.append(_helpers_dir)

In [4]:
from helper_functions import *

Using the ```PhysioNet_PATH```, we can create separate datasets for training, testing & validation.

# Stage 1: Data Preprocessing

- train_set (train & validation data)
- test_set (test data)

First, let's load the SNOMED-CT mappings:

In [5]:
# Combined SNOMED-CT code table: one row per diagnosis (Dx, SNOMEDCTCode, Abbreviation)
# plus per-database record counts.
smowmed_mappings_path = convert_to_forward_slashes(
    r'C:\Users\navme\Desktop\ECG_Project\Data\SNOWMED-CT Codes\combined_mappings.csv'
)

smowmed_mappings = pd.read_csv(smowmed_mappings_path)
smowmed_mappings.head(2)  # preview the first two rows

Unnamed: 0,Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total,Notes
0,atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,1780,0,5255,
1,atrial flutter,164890007,AFL,0,54,0,1,73,186,445,7615,8374,


In [6]:
# SNOMED-CT code -> diagnosis name lookup, e.g. 164889003 -> 'atrial fibrillation'.
# `codes` is a one-column (Dx) frame indexed by SNOMEDCTCode.
codes = smowmed_mappings.loc[:, ['Dx', 'SNOMEDCTCode']].set_index('SNOMEDCTCode')
codes_dict = codes['Dx'].to_dict()

In [8]:
# Peek at the first five (code, diagnosis) pairs
[pair for _, pair in zip(range(5), codes_dict.items())]

[(164889003, 'atrial fibrillation'),
 (164890007, 'atrial flutter'),
 (6374002, 'bundle branch block'),
 (426627000, 'bradycardia'),
 (733534002, 'complete left bundle branch block')]

# Updated PhysioNetDataset Class 

- Update the ```PhysioNetDataset``` class so that, instead of the full header_info, it returns only header_info['Dx']. The Dx codes are then converted to diagnosis strings that serve as input to TextEncoder().

In [7]:
class PhysioNetDataset(torch.utils.data.Dataset):
    """PhysioNet/CinC 2021 challenge records as ``(diagnoses, twelve_lead_ecg)`` pairs.

    The challenge ``training`` folder holds one directory per source database, each
    split into sub-folders containing ``<record>.hea`` / ``<record>.mat`` pairs.

    Args:
        dataset_path: Path to the challenge ``training`` folder.
        train: falsy -> load only the PTB and PTB-XL databases (held out for
            Experiment A); truthy -> load every other database.

    Each item is ``(dx_modalities, twelve_lead_ecg)``:
        * ``dx_modalities`` - diagnosis names parsed from the header's ``# Dx:`` line
          (codes missing from ``codes_dict`` are kept as code strings), or ``None`` when
          the header has no such line.
        * ``twelve_lead_ecg`` - the ``val`` array of the record's own ``.mat`` file,
          resampled along axis 1 to 1280 samples, or ``None`` when that file is missing.
    """

    # Database folders selected when train is falsy and skipped when it is truthy.
    VALIDATION_DATASETS = ('ptb', 'ptb-xl')

    def __init__(self, dataset_path, train=False):
        self.dataset_path = dataset_path  # root folder, kept as given
        self.train = train
        self.file_list = os.listdir(dataset_path)
        self._hea_files = []
        self._mat_files = []
        self._indices_files = []
        self._hea_files_path = []  # same content as _hea_files (kept for existing callers)
        self._mat_files_path = []  # same content as _mat_files (kept for existing callers)

        self.file_PATHS = []  # Directory to main database folders
        self.data_files = []  # Directory to data files (not populated here)

        # Sorted traversal makes the record order - and therefore any seeded split -
        # independent of the platform's directory-listing order.
        for name in sorted(self.file_list):
            is_validation_db = name in self.VALIDATION_DATASETS
            wanted = (not is_validation_db) if self.train else is_validation_db
            if wanted:
                self.file_PATHS.append(self._join(dataset_path, name))

        for path in self.file_PATHS:
            if os.path.isdir(path):
                self._collect_records(path)

    @staticmethod
    def _join(*parts):
        """Join path parts and normalise separators to forward slashes."""
        return os.path.join(*parts).replace('\\', '/')

    def _collect_records(self, database_path):
        """Register every .hea/.mat file found one sub-folder below ``database_path``."""
        for sub_folder in sorted(os.listdir(database_path)):
            sub_folder_path = self._join(database_path, sub_folder)

            # index.html files are remembered but otherwise ignored
            if sub_folder_path.endswith('index.html'):
                self._indices_files.append(sub_folder_path)
                continue
            if not os.path.isdir(sub_folder_path):
                continue

            for file in sorted(os.listdir(sub_folder_path)):
                file_path = self._join(sub_folder_path, file)
                if file.endswith('.hea'):
                    self._hea_files.append(file_path)
                    self._hea_files_path.append(file_path)
                elif file.endswith('.mat'):
                    self._mat_files.append(file_path)
                    self._mat_files_path.append(file_path)

    def _mat_path_for(self, index):
        """Return the .mat path that belongs to header ``index`` (same record name).

        Deriving it from the header - rather than indexing ``_mat_files`` in parallel -
        keeps every signal paired with its own labels even when some .mat is missing.
        """
        return os.path.splitext(self._hea_files[index])[0] + '.mat'

    @staticmethod
    def _read_dx(hea_file_path):
        """Parse the ``# Dx:`` header line into diagnosis names (``None`` if absent)."""
        dx_modalities = None
        with open(hea_file_path, 'r') as f:
            for line in f:
                if line.startswith('# Dx:'):
                    dx_codes = line.split(':')[1].strip().split(',')
                    dx_modalities = [codes_dict.get(int(code.strip()), code.strip()) for code in dx_codes]
        return dx_modalities

    def resample_ecg(self, ecg, new_length=1280):
        """Resample ``ecg`` along axis 1 (time) to ``new_length`` samples (Fourier method)."""
        return resample(ecg, new_length, axis=1)

    def __getitem__(self, index):
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self[i] for i in range(start, stop, step)]

        # 1. Diagnoses from the .hea file. An out-of-range index raises IndexError here,
        #    which is what ends plain `for record in dataset` iteration.
        hea_file_path = self._hea_files[index]
        dx_modalities = self._read_dx(hea_file_path)

        # 2. Signal from the .mat file with the same record name
        twelve_lead_ecg = None
        mat_file_path = self._mat_path_for(index)
        if os.path.exists(mat_file_path):
            mat_data = sio.loadmat(mat_file_path)
            twelve_lead_ecg = self.resample_ecg(mat_data['val'])
        else:
            print(f"MAT file for index {index} does not exist.")

        # Return list of diagnoses and the np array of the 12-lead ECG
        return dx_modalities, twelve_lead_ecg

    def plot_record(self, index):
        """Plot the 12 leads of record ``index`` (raw, not resampled) on a 3x4 grid."""
        data = sio.loadmat(self._mat_path_for(index))
        fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

        for i, ax in enumerate(axs.flat):
            ax.plot(data['val'][i], linewidth=0.5)
            ax.set_xlabel('Sample')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'Lead {i+1}')

        plt.tight_layout()
        plt.show()

    def __len__(self):
        return len(self._hea_files)

In [9]:
# Root of the PhysioNet/CinC 2021 challenge "training" folder (one sub-folder per database)
PhysioNet_PATH = '/'.join([
    'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge',
    'physionet.org/files/challenge-2021/1.0.3',
    'training',
])
PhysioNet_PATH

'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge/physionet.org/files/challenge-2021/1.0.3/training'

We have to set `train = False` to load the `PTB` and `PTB-XL` datasets.

In [10]:
# train=False selects only the PTB and PTB-XL databases
expA_dataset = PhysioNetDataset(dataset_path=PhysioNet_PATH, train=False)
len(expA_dataset)  # expected 22353 records; NOTE(review): the recorded output is 22352 - check for a missing header

22352

In [11]:
first_record = expA_dataset[0]
first_record[0]  # diagnosis names of the first PTB/PTB-XL record

['myocardial infarction']

In [12]:
first_record = expA_dataset[0]
first_record[1]  # resampled 12-lead ECG array of the first PTB/PTB-XL record

array([[ -45.93443974, -257.76735841, -204.38884819, ...,  207.06125242,
         229.91846005,  241.06771077],
       [   2.39774601, -258.27448051, -213.23767774, ...,  236.74978869,
         163.10069387,  266.09279668],
       [  48.77130997,   -0.55561737,   -8.93666839, ...,   30.38791226,
         -66.5753722 ,   25.43929824],
       ...,
       [   6.93659186,  136.56480991,   83.44371554, ..., -109.00235325,
         -82.10284084,  -65.32595022],
       [  40.00654719,  231.07221173,  161.53286103, ..., -171.11330299,
        -169.96584404, -174.2212556 ],
       [  18.56012414,  231.41053703,  158.34203376, ..., -215.72758534,
        -208.84775154, -226.52649213]])

### Step 3: Define the classes to exclude

In [13]:
# Diagnoses held out of training; records carrying any of them form the zero-shot test set
classes_to_exclude = list((
    'myocardial infarction',
    'left bundle branch block',
    't wave abnormal',
    'sinus bradycardia',
    'inferior ischaemia',
))

In [14]:
# Create functions to exclude/include records containing certain classes
def exclude_classes(record, classes_to_exclude):
    """Return True when ``record`` carries none of ``classes_to_exclude``.

    ``record`` is a ``(diagnoses, ecg)`` pair as produced by ``PhysioNetDataset``.
    A record whose diagnosis list is ``None`` (header without a Dx line) is treated
    as having no diagnoses instead of raising ``TypeError``.
    """
    diagnoses = record[0] or ()
    return not any(cls in diagnoses for cls in classes_to_exclude)


def include_classes(record, classes_to_exclude):
    """Return True when ``record`` carries at least one of ``classes_to_exclude``.

    Always the exact complement of ``exclude_classes``, so the two filters partition
    a dataset into training and held-out records.
    """
    return not exclude_classes(record, classes_to_exclude)

In [15]:
# Keep only records that carry none of the held-out classes (the training pool)
filtered_dataset = list(filter(lambda record: exclude_classes(record, classes_to_exclude), expA_dataset))

MAT file for index 22349 does not exist.
MAT file for index 22350 does not exist.
MAT file for index 22351 does not exist.


In [16]:
sample_record = filtered_dataset[1000]
sample_record[0]  # diagnoses of an arbitrary record that survived the exclusion filter

['abnormal QRS',
 'left axis deviation',
 'sinus rhythm',
 'supraventricular premature beats']

The filtered training records (```filtered_dataset```) are split into ```current_train``` (85%) and ```current_val``` (15%).

In [19]:
# Seed PyTorch's global RNG so the split (and any later weight initialisation) is reproducible
torch.manual_seed(0)

# 85 % / 15 % split of the filtered records; the rounding remainder goes to validation
length = len(filtered_dataset)
train_length = int(0.85 * length)
val_length = length - train_length
current_train, current_val = random_split(filtered_dataset, lengths=[train_length, val_length])

In [20]:
tuple(map(len, (current_train, current_val)))  # (train size, validation size)

(11534, 2036)

In [21]:
# Sanity check: the two splits together cover every filtered record
assert sum(map(len, (current_train, current_val))) == len(filtered_dataset)

In [22]:
# Diagnoses of the first training record
first_train_dx = current_train[0][0]
first_train_dx

['sinus rhythm']

In [23]:
# 12-lead ECG (np array) of the first training record
first_train_ecg = current_train[0][1]
first_train_ecg

array([[ -12.75095365,    8.26146432,   25.21036212, ...,  -46.78649712,
         -42.33615862,  -49.77185234],
       [ -36.08955916,  -91.16647942,  -56.88710675, ...,   37.87149406,
          30.50770108,   43.76508402],
       [ -23.33349795,  -99.43040908,  -82.11161079, ...,   84.65291698,
          72.849304  ,   93.53135053],
       ...,
       [ -78.20672582, -123.04287893, -110.69050357, ...,  -11.37945503,
         -20.05333544,   -6.15773454],
       [ -70.24501569,  -98.42570387, -102.21803626, ...,  -22.16972905,
         -28.9483897 ,  -18.20635664],
       [ -60.17367165,  -78.23841155,  -73.83297861, ...,  -33.70700198,
         -36.87081883,  -31.60160604]])

In [24]:
# Records carrying at least one held-out class form the zero-shot test set
test_dataset = []
for record in expA_dataset:
    if include_classes(record, classes_to_exclude):
        test_dataset.append(record)
len(test_dataset)

MAT file for index 22349 does not exist.
MAT file for index 22350 does not exist.
MAT file for index 22351 does not exist.


8782

In [25]:
# Sanity check: training pool and test set partition the PTB/PTB-XL records
assert len(filtered_dataset) + len(test_dataset) == len(expA_dataset)

# Stage 2: ECG Classification Model Pipeline

Now that our data is preprocessed, we can begin working on the Model Pipeline itself. The ECG Classification Model Pipeline will consist of four components:

1. `TextEncoder()` class

2. `ECGEncoder()` class

3. `InstanceSelector()` class

4. `CLIPModel()` class

An overview and outline of each of these components can be found below in their respective subsections.

In [29]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## TextEncoder()

Create a class, ```TextEncoder()``` that is used to convert the description of the (dx_modality) diagnosis class into embeddings using the ClinicalBERT model.

- The input should be a single string per ECG signal: its diagnoses (dx_modality) concatenated with commas or blank spaces.

In [27]:
class TextEncoder(nn.Module):
    """Frozen Bio_ClinicalBERT wrapper that embeds an ECG's diagnosis list as one vector.

    The diagnoses are joined with ``', '``, tokenised, passed through ClinicalBERT
    without gradients, and the last hidden states are mean-pooled over the token axis,
    giving a ``(1, hidden_size)`` tensor (``(1, 768)`` for Bio_ClinicalBERT).

    Args:
        device: torch device the BERT model and its inputs are placed on.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(self.device)
        # The text encoder is used frozen: no trainable weights and no dropout noise.
        self.model.requires_grad_(False)
        self.model.eval()
        # Embedding width, used to size projection layers stacked on this encoder.
        self.out_features = self.model.config.hidden_size

    def train(self, mode=True):
        """Set train/eval mode on this module while keeping the BERT model in eval mode.

        Without this override, calling ``train()`` on a parent module would re-enable
        BERT's dropout and make the (frozen) text embeddings stochastic.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def _as_list(text):
        """Interpret ``text`` as a list literal when possible, else as a single diagnosis."""
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            # Plain text such as "myocardial infarction" is not a Python literal.
            return [text]
        if isinstance(parsed, (list, tuple)):
            return [str(item) for item in parsed]
        return [text]

    def encode(self, text_list):
        """Return the mean-pooled ClinicalBERT embedding of ``text_list``.

        Args:
            text_list: list of diagnosis strings, a string representation of such a list
                (e.g. ``"['sinus rhythm']"``), or a single diagnosis string.

        Returns:
            Tensor of shape ``(1, hidden_size)``.
        """
        if isinstance(text_list, str):
            text_list = self._as_list(text_list)
        # One input string per ECG: its diagnoses separated by commas
        text = ', '.join(text_list)
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            embeddings = self.model(**inputs).last_hidden_state
        # A single sequence has no padding, so a plain mean over tokens is a valid pooling.
        return torch.mean(embeddings, dim=1)

In [30]:
text_encoder = TextEncoder(device).to(device)

# Example: embed the diagnosis list of the first training record
first_train_dx = current_train[0][0]
embeddings = text_encoder.encode(first_train_dx)

In [31]:
print(embeddings.size(), type(embeddings), sep='\n')  # shape, then class

torch.Size([1, 768])
<class 'torch.Tensor'>


## ECGEncoder()

- Input is ECG signal, output will be embeddings of ECG signal
- This is going to be model in model.py
- Model weights are updated iteratively
- optimizer = torch.optim.Adam(clip_model.ECGEncoder.parameters())

**Current Tasks:**

- Change the output size of fully connected layer 2 to 768 (out_features)
- Alternatively look into getting rid of the second fully connected layer. 
- Shifting from 128 --> 768 can cause loss of information. 
- Update all layers after layer 7 as follows:  
    - ```self.conv7 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1)```
    - ```self.bn7 = nn.BatchNorm1d(512)```


No. CNN layers: 4
No. CNN kernels: [32, 128, 64, 64] (each layer, accordingly)
CNN kernel size: [7, 5, 7, 5] (each layer, accordingly)
No. dense layers: 2
No. dense layer units: [156, 140] (each layer, accordingly)
Batch size: 128
Learning rate: 0.0092
Activation function: [Gelu, Selu, Selu, Selu] (each layer, accordingly)
Optimizer: Adam
	

Extrapolating the parameters of the 4-layer CNN above to a 10-layer CNN:

First, let's extrapolate the parameters for the CNN layers:

- No. CNN layers: 10
- No. CNN kernels: [32, 128, 64, 64, 32, 128, 64, 64, 32, 128] (each layer, accordingly)
- CNN kernel size: [7, 5, 7, 5, 7, 5, 7, 5, 7, 5] (each layer, accordingly)
- Activation function: [Gelu, Selu, Selu, Selu, Gelu, Selu, Selu, Selu, Gelu, Selu] (each layer, accordingly)

For the dense layers, you can keep them as they are:

- No. dense layers: 2
- No. dense layer units: [156, 140] (each layer, accordingly)

The batch size, learning rate, and optimizer can also remain the same:

- Batch size: 128
- Learning rate: 0.0092
- Optimizer: Adam

In [19]:
class ECGEncoder(nn.Module):
    """Ten-block 1D-CNN that maps a 12-lead ECG to one embedding vector.

    Each block is Conv1d(kernel 3, padding 1) -> BatchNorm1d -> ReLU -> AvgPool1d(2), so
    the time axis is floor-halved ten times. The flattened features pass through one
    Linear layer followed by ReLU.

    Args:
        embed_dim: size of the output embedding (768 matches Bio_ClinicalBERT).
        input_length: samples per lead the encoder will receive; it fixes the input
            size of ``fc1``. Defaults to 1280, the dataset's resampling length.

    Input ``x`` is ``(batch, 12, input_length)``; output is ``(batch, embed_dim)``.

    Raises:
        ValueError: if ``input_length`` is too short to survive ten poolings.
    """

    # Input lead count followed by the output width of each of the 10 blocks.
    _CHANNELS = (12, 16, 32, 64, 128, 256, 512, 512, 512, 512, 512)

    def __init__(self, embed_dim=768, input_length=1280):
        super(ECGEncoder, self).__init__()
        num_blocks = len(self._CHANNELS) - 1
        # Layers keep their attribute names (conv1, bn1, relu1, pool1, ...) so
        # state_dict keys are unchanged.
        for i in range(1, num_blocks + 1):
            in_ch, out_ch = self._CHANNELS[i - 1], self._CHANNELS[i]
            setattr(self, f'conv{i}', nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1))
            setattr(self, f'bn{i}', nn.BatchNorm1d(out_ch))
            setattr(self, f'relu{i}', nn.ReLU())
            setattr(self, f'pool{i}', nn.AvgPool1d(kernel_size=2, stride=2))
        self._num_blocks = num_blocks

        # Repeated floor-halving equals one floor division by 2**num_blocks.
        pooled_length = input_length // 2 ** num_blocks
        if pooled_length < 1:
            raise ValueError(f'input_length must be at least {2 ** num_blocks}, got {input_length}')

        self.embed_dim = embed_dim
        self.out_features = embed_dim
        # Built here rather than lazily in forward(), so it is registered before an
        # optimizer is created from model.parameters().
        self.fc1 = nn.Linear(self._CHANNELS[-1] * pooled_length, embed_dim)
        self.relu11 = nn.ReLU()

    def forward(self, x):
        for i in range(1, self._num_blocks + 1):
            x = getattr(self, f'conv{i}')(x)
            x = getattr(self, f'bn{i}')(x)
            x = getattr(self, f'relu{i}')(x)
            x = getattr(self, f'pool{i}')(x)

        # Flatten the output of the convolutional layers
        x = x.view(x.size(0), -1)
        return self.relu11(self.fc1(x))

In [40]:
# NOTE(review): this in-notebook install failed with "[WinError 5] Access is denied";
# pip's own message suggests retrying with the `--user` option or checking permissions.
pip install --upgrade torch

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/98/b7/699b44593fa0d7373191226cb0c2ac3182b5aa0fb1fbb26a87589f655585/torch-2.1.1-cp38-cp38-win_amd64.whl.metadata
  Downloading torch-2.1.1-cp38-cp38-win_amd64.whl.metadata (26 kB)
Downloading torch-2.1.1-cp38-cp38-win_amd64.whl (192.3 MB)
   ---------------------------------------- 0.0/192.3 MB ? eta -:--:--
   ---------------------------------------- 0.2/192.3 MB 6.7 MB/s eta 0:00:29
   ---------------------------------------- 0.6/192.3 MB 5.9 MB/s eta 0:00:33
   ---------------------------------------- 1.5/192.3 MB 12.0 MB/s eta 0:00:16
    --------------------------------------- 3.2/192.3 MB 18.6 MB/s eta 0:00:11
   - -------------------------------------- 4.9/192.3 MB 20.9 MB/s eta 0:00:09
   - -------------------------------------- 6.5/192.3 MB 23.1 MB/s eta 0:00:09
   - -------------------------------------- 8.2/192.3 MB 24.9 MB/s eta 0:00:08
   -- ---------------------------

ERROR: Could not install packages due to an OSError: [WinError 5] Access is denied: 'C:\\Users\\navme\\AppData\\Local\\Programs\\Python\\Python38\\Lib\\site-packages\\~orch\\lib\\asmjit.dll'
Consider using the `--user` option or check the permissions.


[notice] A new release of pip is available: 23.2.1 -> 23.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [45]:
import torch.nn as nn

class ECGEncoder(nn.Module):
    """Ten-layer 1D-CNN ECG encoder with a three-layer dense head producing 768-d vectors.

    The conv stack follows the tuned 4-layer pattern (channels 32/128/64/64, kernels
    7/5/7/5, activations GELU/SELU/SELU/SELU) repeated over 10 layers. Each layer is
    Conv1d('same' padding) -> BatchNorm1d -> activation -> AvgPool1d(2). The dense head
    is 156 -> 140 units with ReLU, then a final Linear to 768 matching the text encoder.

    Args:
        input_length: samples per lead the encoder will receive; it fixes the input
            size of ``fc1``. Defaults to 1280, the dataset's resampling length.

    Input ``x`` is ``(batch, 12, input_length)``; output is ``(batch, 768)``.

    Raises:
        ValueError: if ``input_length`` is too short to survive ten poolings.
    """

    # (out_channels, kernel_size, activation) for conv layers 1..10
    _CONV_SPEC = (
        (32, 7, nn.GELU), (128, 5, nn.SELU), (64, 7, nn.SELU), (64, 5, nn.SELU),
        (32, 7, nn.GELU), (128, 5, nn.SELU), (64, 7, nn.SELU), (64, 5, nn.SELU),
        (32, 7, nn.GELU), (128, 5, nn.SELU),
    )

    def __init__(self, input_length=1280):
        super(ECGEncoder, self).__init__()

        # Define CNN layers; attribute names (conv{i}, bn{i}, act{i}, pool{i}) keep
        # the state_dict keys unchanged.
        in_ch = 12
        for i, (out_ch, kernel, activation) in enumerate(self._CONV_SPEC, start=1):
            setattr(self, f'conv{i}', nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
                                                kernel_size=kernel, stride=1, padding=kernel // 2))
            setattr(self, f'bn{i}', nn.BatchNorm1d(out_ch))
            setattr(self, f'act{i}', activation())
            setattr(self, f'pool{i}', nn.AvgPool1d(kernel_size=2, stride=2))
            in_ch = out_ch
        self._num_conv = len(self._CONV_SPEC)

        # Repeated floor-halving equals one floor division by 2**num_layers.
        pooled_length = input_length // 2 ** self._num_conv
        if pooled_length < 1:
            raise ValueError(f'input_length must be at least {2 ** self._num_conv}, got {input_length}')

        # Define dense layers; fc1's input size follows from the conv stack output.
        self.fc1 = nn.Linear(in_ch * pooled_length, 156)
        self.act11 = nn.ReLU()
        self.fc2 = nn.Linear(156, 140)
        self.act12 = nn.ReLU()

        # Final layer to match the output size with the text encoder
        self.fc3 = nn.Linear(140, 768)
        # Embedding width, used to size projection layers stacked on this encoder.
        self.out_features = self.fc3.out_features

    def forward(self, x):
        # Pass through CNN layers
        for i in range(1, self._num_conv + 1):
            x = getattr(self, f'conv{i}')(x)
            x = getattr(self, f'bn{i}')(x)
            x = getattr(self, f'act{i}')(x)
            x = getattr(self, f'pool{i}')(x)

        # Flatten the output of the convolutional layers
        x = x.view(x.size(0), -1)

        # Dense head, then the (activation-free) projection to 768 dimensions
        x = self.act11(self.fc1(x))
        x = self.act12(self.fc2(x))
        return self.fc3(x)

In [54]:
# Instantiate the model and push one training ECG through it as a shape check
model = ECGEncoder()

# numpy (leads, samples) array -> float32 tensor with a leading batch dimension of 1
input_data = torch.from_numpy(current_train[0][1]).float().unsqueeze(0)

output = model(input_data)
print(output)

tensor([[-3.0817e-02, -8.9155e-02, -2.5868e-03, -1.3259e-02,  5.3304e-02,
          3.8886e-02, -6.3233e-02,  5.7398e-02, -3.1007e-02,  4.0455e-02,
         -5.2742e-02,  6.4431e-02, -8.4995e-02, -5.8559e-02,  1.4057e-02,
         -9.4203e-02, -3.3626e-02,  6.7180e-02,  7.8477e-02,  9.4205e-02,
          2.2929e-03,  8.0179e-02,  6.5239e-02,  6.9107e-02,  9.3195e-02,
          4.6036e-02, -6.1541e-02, -8.0510e-02, -8.0613e-02, -1.1387e-04,
         -4.0552e-02,  5.9260e-02,  7.9667e-02, -5.3490e-02,  5.3286e-02,
         -5.8330e-02, -2.4385e-02, -5.9090e-02,  2.1884e-02,  3.4458e-02,
         -2.7244e-02,  5.7623e-02, -6.3806e-02, -9.2269e-02,  7.6795e-02,
         -6.6634e-02, -4.8256e-03,  8.8345e-03, -4.9382e-02,  5.6963e-02,
          5.7996e-02, -2.6941e-02, -5.7656e-02,  7.1037e-02,  4.9748e-02,
         -1.1936e-01,  1.2834e-01, -2.0449e-02, -3.6006e-02, -7.8571e-02,
         -1.6021e-02, -5.9872e-02, -7.3079e-02, -1.2094e-02, -6.4410e-02,
         -3.8267e-02,  5.3923e-02,  1.

In [48]:
# Evaluation mode: BatchNorm layers use their running statistics
model.eval()
output = model(input_data)
for value in (type(output), output.size()):
    print(value)

<class 'torch.Tensor'>
torch.Size([1, 768])


In [49]:
def count_parameters(model):
    """Return the number of scalar weights in ``model`` that will receive gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [50]:
num_params = count_parameters(model)  # trainable parameters only
print('The model has', num_params, 'trainable parameters')

The model has 401048 trainable parameters


## CLIPModel

The final component of the Model Pipeline is a `CLIPModel` class, which takes a `TextEncoder` and an `ECGEncoder` and trains the final model with a contrastive loss.

```
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
Figure 3. Numpy-like pseudocode for the core of an implementation of CLIP
```

In [24]:
class CLIPModel(nn.Module):
    """CLIP-style contrastive model pairing ECG signals with their diagnosis text.

    Args:
        text_encoder: object with an ``encode(text)`` method returning a ``(1, d_t)``
            tensor (e.g. ``TextEncoder``).
        ecg_encoder: module mapping a batch of ECGs to ``(n, d_i)`` features.
        temperature: logits are ``cosine_similarity / temperature``.
        embed_dim: size of the joint embedding space. Defaults to ``d_i``; when
            ``d_i == d_t`` this gives the square projections ``d -> d``.
        ecg_dim, text_dim: encoder feature sizes ``d_i`` / ``d_t``. They default to the
            encoder's ``out_features`` attribute, or 768 when it has none (768 is both
            the Bio_ClinicalBERT embedding size and the ECGEncoder output size).
    """

    def __init__(self, text_encoder, ecg_encoder, temperature, embed_dim=None, ecg_dim=None, text_dim=None):
        super(CLIPModel, self).__init__()
        self.text_encoder = text_encoder
        self.ecg_encoder = ecg_encoder
        self.temperature = temperature

        if ecg_dim is None:
            ecg_dim = getattr(ecg_encoder, 'out_features', 768)
        if text_dim is None:
            text_dim = getattr(text_encoder, 'out_features', 768)
        if embed_dim is None:
            embed_dim = ecg_dim

        # Learned projections of each modality into the shared embedding space
        self.W_i = nn.Linear(ecg_dim, embed_dim)
        self.W_t = nn.Linear(text_dim, embed_dim)

    def forward(self, ecgs, texts):
        """Return the symmetric contrastive (cross-entropy) loss for aligned pairs.

        ``ecgs[k]`` and ``texts[k]`` form the positive pair for row/column ``k``;
        every other item in the batch acts as a negative.
        """
        # Extract feature representations of each modality
        I_f = self.ecg_encoder(ecgs)  # [n, d_i]
        # Text features may live on another device than the ECG encoder; align them.
        T_f = torch.stack([self.text_encoder.encode(text).squeeze(0) for text in texts]).to(I_f.device)  # [n, d_t]

        # Joint multimodal embedding [n, d_e]
        I_e = F.normalize(self.W_i(I_f), dim=1)
        T_e = F.normalize(self.W_t(T_f), dim=1)

        # Scaled pairwise cosine similarities [n, n]
        logits = torch.matmul(I_e, T_e.t()) / self.temperature

        # NOTE(review): records with identical diagnosis text (e.g. several
        # 'sinus rhythm' ECGs in one batch) are still treated as negatives here.
        labels = torch.arange(len(ecgs), device=logits.device)
        loss_i = F.cross_entropy(logits, labels)
        loss_t = F.cross_entropy(logits.t(), labels)
        return (loss_i + loss_t) / 2

Let's break down this `CLIPModel` class:

- `__init__` method: This is the constructor for the `CLIPModel` class. It initializes the text and ECG encoders, the temperature for the softmax function, and two linear layers (`W_i` and `W_t`). The linear layers are used to transform the embeddings produced by the encoders to a common embedding space. The `embed_dim` parameter determines the dimensionality of this common space.

- `forward` method: This is the method that is called when you pass data through the model.

    - `I_f = self.ecg_encoder(ecgs)`: This line passes the ECG data through the ECG encoder to get a feature representation (`I_f`).

    - `T_f = torch.stack([self.text_encoder.encode(text).squeeze() for text in texts])`: This line encodes each text in the `texts` list using the text encoder, squeezes the output to remove any unnecessary dimensions, and stacks the results into a tensor (`T_f`).

    - `I_e = F.normalize(self.W_i(I_f), dim=1)` and `T_e = F.normalize(self.W_t(T_f), dim=1)`: These lines pass the feature representations through the linear layers (`W_i` and `W_t`), and then normalize the output. The result (`I_e` and `T_e`) are the embeddings in the common space.

    - `logits = torch.matmul(I_e, T_e.t()) / self.temperature`: This line computes the cosine similarity between the ECG and text embeddings, scales it by the temperature, and stores the result in `logits`.

    - The last few lines compute the cross-entropy loss between the logits and the labels (which are just the indices of the data points). The loss is computed twice: once for the ECG-to-text direction (`loss_i`), and once for the text-to-ECG direction (`loss_t`). The final loss is the average of these two losses.

In summary, `W_i` and `W_t` are used to transform the feature representations produced by the encoders into a common embedding space. The model then computes the similarity between the ECG and text embeddings in this common space, and uses this to compute the loss.

In [32]:
# Instantiate the encoders on the selected device (CUDA if available, else CPU)
text_encoder = TextEncoder(device).to(device)
ecg_encoder = ECGEncoder().to(device)

# Combine them; CLIPModel takes (text_encoder, ecg_encoder, temperature).
# Moving the whole model also moves the W_i / W_t projection layers.
CLIP_model = CLIPModel(text_encoder, ecg_encoder, temperature=1).to(device)

In [None]:
def resample_ecg(ecg, new_length=1280, axis=1):
    """Resample an ECG to ``new_length`` samples along ``axis``.

    Args:
        ecg: array-like ECG. With the default ``axis=1`` the samples run along the
            second axis (e.g. a ``(leads, time)`` array -- confirm for your data).
        new_length: number of samples the output has along ``axis``.
        axis: axis to resample. Defaults to 1, which was the only behaviour before
            this parameter existed.

    Returns:
        numpy.ndarray with ``new_length`` samples along ``axis``; all other axes
        keep their size.
    """
    # scipy.signal.resample is Fourier-based, so it treats the signal as periodic;
    # edge effects can appear if the recording's ends differ strongly.
    return signal.resample(ecg, new_length, axis=axis)

In [None]:
# Draw a random subset of the training set for a quick forward-pass sanity check.
# Capped at the dataset size so random.sample cannot raise ValueError on a small split.
num_samples = min(750, len(current_train))
indices = random.sample(range(len(current_train)), num_samples)

# Items are indexed as (text, ecg): split the sampled items into parallel lists
texts = [current_train[i][0] for i in indices]
ecgs = [current_train[i][1] for i in indices]

# Stack the per-sample numpy arrays into one float32 batch tensor
# (torch.stack requires every ECG to have the same shape).
ecgs_tensor = torch.stack([torch.from_numpy(ecg).float() for ecg in ecgs])

In [None]:
# Sanity check: batch size of the stacked tensor, number of sampled ECGs, and the
# first-axis length of the first ECG (len() of a numpy array is its first-axis size)
ecgs_tensor.size(0), len(ecgs), len(ecgs[0])

In [None]:
# Display the ECG array of the first training item (element 1 of each item) to inspect its values/shape
current_train[0][1]

In [None]:
# Pass the tensor and texts through the model
loss = CLIP_model(ecgs_tensor, texts)

# Print the loss
print(loss)

# Model Training

## DataLoaders + Setup

In [1]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

c:\Users\navme\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
c:\Users\navme\AppData\Local\Programs\Python\Python38\lib\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll


In [None]:
# Build both encoders and the CLIP wrapper on the same `device` (previously cuda was
# hard-coded here even though `device` was passed to TextEncoder).
text_encoder = TextEncoder(device).to(device)
ecg_encoder = ECGEncoder().to(device)

# Move the whole CLIP model (including its W_i / W_t projection layers) now, before the
# optimizer is constructed from its parameters, so every parameter lives on `device`
# from the start instead of being moved later inside train_model.
CLIP_model = CLIPModel(text_encoder, ecg_encoder, embed_dim=512, temperature=1).to(device)

In [None]:
# Optimisation hyper-parameters for the CLIP training loop
num_epochs = 3
lr = 1e-3            # Adam learning rate
weight_decay = 1e-2  # Adam's L2 penalty (added to the gradient, not decoupled as in AdamW)

optimizer = torch.optim.Adam(
    CLIP_model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

## Training CLIP model on ```current_train```

In [None]:
def collate_fn(batch):
    """Collate ``(text, ecg)`` items into ``(list_of_texts, padded_ecg_tensor)``.

    Each ECG numpy array is wrapped with ``torch.from_numpy`` (dtype kept) and the
    results are padded with zeros along their first dimension to the longest item,
    giving a ``(batch, max_len, ...)`` tensor. Dimensions after the first must match.
    """
    texts, ecg_tensors = [], []
    for sample in batch:
        texts.append(sample[0])
        ecg_tensors.append(torch.from_numpy(sample[1]))
    return texts, pad_sequence(ecg_tensors, batch_first=True)

# Training loader: reshuffle every epoch so the contrastive loss sees different in-batch
# negatives on each pass; collate_fn keeps texts as a list and batches the ECG arrays.
current_train_loader = DataLoader(current_train, batch_size=64, shuffle=True, collate_fn=collate_fn)

In [None]:
def train_model(model, optimizer, num_epochs, device, train_loader=None):
    """Train a CLIP-style model whose forward pass returns its contrastive loss.

    Args:
        model: the model to train; ``model(ecgs, texts)`` must return a scalar loss
            tensor (as ``CLIP_model`` does in this notebook).
        optimizer: optimizer constructed over ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.
        device: device the model and each ECG batch are moved to.
        train_loader: sized iterable of ``(texts, ecgs)`` batches. Defaults to the
            notebook-level ``current_train_loader`` (the previous hard-wired behaviour).

    Returns:
        ``(model, losses)`` where ``losses[e]`` is the mean per-batch loss of epoch
        ``e + 1`` (``nan`` if the loader yielded no batches).
    """
    if train_loader is None:
        train_loader = current_train_loader

    model.to(device)
    losses = []

    for epoch in range(num_epochs):
        model.train()  # set the model to training mode
        running_loss = 0.0
        num_batches = 0
        pbar = tqdm(enumerate(train_loader, 0), total=len(train_loader), leave=False)

        for i, data in pbar:
            # texts is the list built by collate_fn and goes straight to the text
            # encoder, so it is not moved; only the ECG tensor goes to the device.
            texts = data[0]
            ecgs = data[1].float().to(device)

            optimizer.zero_grad()
            loss = model(ecgs, texts)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            pbar.set_description(f"Epoch {epoch + 1} Loss: {running_loss / num_batches:.4f}")

        # Each loss is already a per-batch mean, so average over batches (not samples);
        # this matches the running value shown in the progress bar.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        losses.append(epoch_loss)
        print(f'Epoch {epoch + 1}, Loss: {epoch_loss}')

    return model, losses

In [None]:
# Keep the trained model and the per-epoch mean losses (e.g. to plot the training curve)
CLIP_model, losses = train_model(CLIP_model, optimizer, num_epochs, device)

## Evaluating CLIP model on ```current_val```

## Testing Trained CLIP model on ```test_set```

In [None]:
def predict_diagnosis(CLIP_model, ecg_sample, all_diagnoses):
    """Zero-shot diagnosis: return the candidate text whose embedding best matches an ECG.

    Follows the same path as ``CLIPModel.forward``. Both modalities are projected into the
    shared space with ``W_i`` / ``W_t`` and L2-normalised, so the score is a cosine
    similarity.

    Args:
        CLIP_model: trained model exposing ``ecg_encoder``, ``text_encoder.encode``,
            ``W_i`` and ``W_t``.
        ecg_sample: one unbatched ECG shaped like a single dataset item
            (``current_train[i][1]``), as a numpy array or tensor. A batch dimension of
            size 1 is added here, matching the ``(batch, ...)`` tensors that
            ``collate_fn`` builds for training.
        all_diagnoses: non-empty sequence of candidate diagnosis texts (these may
            include classes excluded from training).

    Returns:
        The element of ``all_diagnoses`` with the highest cosine similarity.

    Raises:
        ValueError: if ``all_diagnoses`` is empty.
    """
    if len(all_diagnoses) == 0:
        raise ValueError("all_diagnoses must contain at least one candidate diagnosis")

    # Put the ECG on whatever device the model's weights live on
    first_param = next(CLIP_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = CLIP_model.training
    CLIP_model.eval()  # inference mode (matters if the encoders contain dropout/batch-norm)
    try:
        with torch.no_grad():
            ecg = torch.as_tensor(ecg_sample, dtype=torch.float32, device=device).unsqueeze(0)
            ecg_emb = torch.nn.functional.normalize(
                CLIP_model.W_i(CLIP_model.ecg_encoder(ecg)), dim=1
            )  # (1, embed_dim)

            # Same per-text encode + squeeze + stack as the training forward pass
            text_feats = torch.stack(
                [CLIP_model.text_encoder.encode(d).squeeze() for d in all_diagnoses]
            ).to(device)
            text_emb = torch.nn.functional.normalize(CLIP_model.W_t(text_feats), dim=1)

            # (1, embed_dim) @ (embed_dim, n_diagnoses) -> one similarity per diagnosis
            similarities = (ecg_emb @ text_emb.t()).squeeze(0)
    finally:
        CLIP_model.train(was_training)  # restore the caller's train/eval mode

    return all_diagnoses[int(torch.argmax(similarities))]