# PhysioNet 2021 Challenge

The training data contains twelve-lead ECGs. The validation and test data contains twelve-lead, six-lead, four-lead, three-lead, and two-lead ECGs:

1. Twelve leads: I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6
2. Six leads: I, II, III, aVR, aVL, aVF
3. Four leads: I, II, III, V2
4. Three leads: I, II, V2
5. Two leads: I, II

Each ECG recording has one or more labels that describe cardiac abnormalities (and/or a normal sinus rhythm).

The Challenge data include annotated twelve-lead ECG recordings from six sources in four countries across three continents. These databases include over 100,000 twelve-lead ECG recordings with over 88,000 ECGs shared publicly as training data.

For example, a header file A0001.hea may have the following contents:

```
    A0001 12 500 7500
    A0001.mat 16+24 1000/mV 16 0 28 -1716 0 I
    A0001.mat 16+24 1000/mV 16 0 7 2029 0 II
    A0001.mat 16+24 1000/mV 16 0 -21 3745 0 III
    A0001.mat 16+24 1000/mV 16 0 -17 3680 0 aVR
    A0001.mat 16+24 1000/mV 16 0 24 -2664 0 aVL
    A0001.mat 16+24 1000/mV 16 0 -7 -1499 0 aVF
    A0001.mat 16+24 1000/mV 16 0 -290 390 0 V1
    A0001.mat 16+24 1000/mV 16 0 -204 157 0 V2
    A0001.mat 16+24 1000/mV 16 0 -96 -2555 0 V3
    A0001.mat 16+24 1000/mV 16 0 -112 49 0 V4
    A0001.mat 16+24 1000/mV 16 0 -596 -321 0 V5
    A0001.mat 16+24 1000/mV 16 0 -16 -3112 0 V6
    #Age: 74
    #Sex: Male
    #Dx: 426783006
    #Rx: Unknown
    #Hx: Unknown
    #Sx: Unknown
```

From the first line of the file:
- We see that the recording number is A0001, and the recording file is A0001.mat. 
- The recording has 12 leads, each recorded at a 500 Hz sampling frequency, and contains 7500 samples. 
- From the next 12 lines of the file (one for each lead), we see that each signal:
    - Was written at 16 bits with an offset of 24 bits
    - The floating point number (analog-to-digital converter (ADC) units per physical unit) is 1000/mV 
    - The resolution of the analog-to-digital converter (ADC) used to digitize the signal is 16 bits, and the baseline value corresponding to 0 physical units is 0. 
    - The first value of the signal (-1716, etc.), the checksum (0, etc.), and the lead name (I, etc.) are the last three entries of each of these lines. 
- From the final 6 lines, we see that the patient is:
    - A 74-year-old male 
    - With a diagnosis (Dx) of 426783006, which is the **SNOMED-CT code** for sinus rhythm. 
    - The medical prescription (Rx), history (Hx), and symptom or surgery (Sx) are unknown. 

- Please visit WFDB header format for more information on the header file and variables.

In [135]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
from scipy.signal import resample
import torch
from transformers import AutoTokenizer, AutoModel
import ast
import scipy.io as sio

In [2]:
sys.path.append('C:/Users/navme/Desktop/ECG_Project/PyFiles')

In [3]:
from helper_functions import *

In [4]:
from dataset import PhysioNetDataset

In [5]:
PhysioNet_PATH = f'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge/physionet.org/files/challenge-2021/1.0.3/training'
PhysioNet_PATH

'C:/Users/navme/Desktop/ECG_Thesis_Local/PhysioNet-2021-Challenge/physionet.org/files/challenge-2021/1.0.3/training'

## Train/Val/Test PhysioNet Datasets 

In [6]:
# Train
train_set = PhysioNetDataset(PhysioNet_PATH, train=True)

# Val
val_set = PhysioNetDataset(PhysioNet_PATH, train=False)

In [None]:
# Example of ECG header data
train_set[0][0]

In [None]:
# Example of ECG signal
train_set[0][1]

## Loading Processed Patient Data

### PhysioNet 2021

In [60]:
processed_train_df_path = r'C:\Users\navme\Desktop\ECG_Project\Data\PhysioNet\processed_train_set_records.csv'
processed_val_df_path = r'C:\Users\navme\Desktop\ECG_Project\Data\PhysioNet\processed_val_set_records.csv'

# Fix URL formatting
processed_train_df_path = convert_to_forward_slashes(processed_train_df_path)
processed_val_df_path = convert_to_forward_slashes(processed_val_df_path)

In [61]:
processed_train_df = pd.read_csv(processed_train_df_path)
processed_val_df = pd.read_csv(processed_val_df_path)

In [None]:
processed_train_df.head()

### CODE-15

In [62]:
CODE15_df_path = r'C:\Users\navme\Desktop\ECG_Project\Data\CODE-15\exams.csv'

# Fix URL formatting
CODE15_df_path = convert_to_forward_slashes(CODE15_df_path)

In [63]:
CODE15_df = pd.read_csv(CODE15_df_path)

In [64]:
CODE15_df.head(5)

Unnamed: 0,exam_id,age,is_male,nn_predicted_age,1dAVb,RBBB,LBBB,SB,ST,AF,patient_id,death,timey,normal_ecg,trace_file
0,1169160,38,True,40.160484,False,False,False,False,False,False,523632,False,2.098628,True,exams_part13.hdf5
1,2873686,73,True,67.05944,False,False,False,False,False,False,1724173,False,6.657529,False,exams_part13.hdf5
2,168405,67,True,79.62174,False,False,False,False,False,True,51421,False,4.282188,False,exams_part13.hdf5
3,271011,41,True,69.75026,False,False,False,False,False,False,1737282,False,4.038353,True,exams_part13.hdf5
4,384368,73,True,78.87346,False,False,False,False,False,False,331652,False,3.786298,False,exams_part13.hdf5


## TextEncoder()

Create a class, ```TextEncoder()``` that is used to convert the description of the (dx_modality) diagnosis class into embeddings using the ClinicalBERT model.

- Input should be a concatenated using comma or blank space string of diagnoses/dx_modality per ECG signal.
- Use processed CSV files (dx_modality vs dx_modality, age, etc together)
- Frozen weights (since it's already pretrained)

## PhysioNet Data

### Case 1: dx_modality only

In [71]:
class TextEncoder:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    def encode(self, text_list):
        # Check if text_list is a string representation of a list
        if isinstance(text_list, str):
            text_list = ast.literal_eval(text_list)
        # Convert list of strings to a single string
        text = ', '.join(text_list)
        # Tokenize text
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        # Get embeddings from ClinicalBERT model
        with torch.no_grad():
            embeddings = self.model(**inputs).last_hidden_state
        # Average the embeddings to get single vector per each input
        embeddings = torch.mean(embeddings, dim=1)
        return embeddings

In [72]:
if isinstance(processed_train_df['dx_modality'][4], str):
    print('yes')
else:
    print('no')

yes


In [None]:
# Example of TextEncoder
encoder = TextEncoder()
embeddings = encoder.encode(processed_train_df['dx_modality'][0])

In [None]:
# Check size of the embeddings
print(embeddings.size())

### Case 2: dx_modality plus age, sex, etc

In [None]:
class TextEncoder:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    def encode(self, series):
        text = f"{series['age']}, {series['sex']}, {series['dx_modality']}"
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            embeddings = self.model(**inputs).last_hidden_state
        embeddings = torch.mean(embeddings, dim=1)
        return embeddings

In [None]:
encoder = TextEncoder()
embeddings = encoder.encode(processed_train_df.iloc[0])

In [None]:
print(embeddings.size())

### CODE-15 Data

## ECGEncoder() 

- Input is ECG signal, output will be embeddings of ECG signal
- This is going to be model in model.py 
- Model weights are updated iteratively
- optimizer = torch.optim.Adam(clip_model.ECGEncoder.parameters())

In [7]:
from model import OneDimCNN

In [73]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class OneDimCNN(nn.Module):
    def __init__(self, num_classes):
        super(OneDimCNN, self).__init__()

        # Layer 1
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)

        # Layer 2
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)

        # Layer 3
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        # Layer 4
        self.conv4 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm1d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        # Fully Connected Layer 1
        self.fc1 = nn.Linear(79872, 128)
        self.relu5 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)

        # Fully Connected Layer 2
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Layer 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        # Layer 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Layer 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        # Layer 4
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.pool4(x)

        # Flatten the tensor
        x = x.view(x.size(0), -1)
        # print(x.shape)  # Add this line

        # Fully Connected Layer 1
        x = self.fc1(x)
        x = self.relu5(x)
        x = self.dropout1(x)

        # Fully Connected Layer 2
        x = self.fc2(x)

        return x

In [94]:
class ECGEncoder(OneDimCNN):
    def __init__(self, num_classes):
        super(ECGEncoder, self).__init__(num_classes)
        self.fc3 = nn.Linear(126, 768)  # New linear layer

    def encode(self, signal):
        signal = torch.tensor(signal, dtype=torch.float).unsqueeze(0)
        embedding = self.forward(signal)
        return self.fc3(embedding)  # Apply the new linear layer

In [None]:
type(train_set[0][1]['val'])

In [95]:
train_set[0][1]['val']

array([[408.24601882, 408.24601882, 408.24601882, ..., -83.34581329,
        -74.965045  , -63.10339951],
       [-92.07603073, -92.07603073, -92.07603073, ...,  57.20010276,
         54.51591647,  58.88514819],
       [225.08001192, 225.08001192, 225.08001192, ...,  93.39571052,
         97.44912853, 117.96825132]])

In [98]:
# Define the number of classes
num_classes = 126  # Replace with the actual number of classes

# Create an instance of the model
ecg_encoder = ECGEncoder(num_classes)

# Convert the numpy array to a PyTorch tensor
input_data = torch.from_numpy(train_set[60000][1]['val']).float()

# Add an extra dimension for the batch size
input_data = input_data.unsqueeze(0)

# Convert the model's weights to Float
ecg_encoder = ecg_encoder.float()

# Pass the data through the model
output = ecg_encoder(input_data)

print(output)

torch.Size([1, 79872])
tensor([[-1.1187e-01,  2.3097e-01, -1.3781e-01, -1.5903e-01,  3.5233e-01,
          9.1446e-02,  2.3288e-01,  3.6046e-01, -4.6508e-02, -3.3814e-01,
         -6.0480e-01,  1.3641e-01,  2.7077e-01,  2.8557e-02,  4.2393e-01,
         -3.3274e-01,  4.7674e-02, -4.7889e-02, -1.2955e-01, -2.0932e-01,
         -6.5117e-02,  4.4930e-01, -3.0702e-01,  1.4360e-04, -1.7969e-01,
         -6.3248e-02,  1.7239e-01, -1.8230e-02,  7.2395e-02, -3.5511e-01,
          2.4709e-01, -4.2690e-02, -2.2465e-01, -2.5672e-01, -2.0102e-01,
          9.7384e-02,  9.3277e-02,  9.4923e-02, -4.2708e-01,  2.2382e-01,
          1.6069e-01, -1.4382e-01, -1.5387e-01, -1.6791e-01, -1.8936e-01,
         -2.4707e-01, -1.4342e-02, -6.3892e-02, -2.4525e-01,  2.0876e-01,
         -2.9727e-01, -3.7843e-01, -1.7891e-01,  2.3789e-01, -3.1652e-01,
         -3.5186e-01,  1.1949e-01,  2.4470e-01,  3.9233e-01,  3.5625e-01,
          2.4722e-01,  1.1365e-01, -1.3010e-01, -2.0656e-01, -4.8631e-01,
         -2.702

In [102]:
# Convert the model's weights to Float
ecg_encoder = ecg_encoder.float()

# Set the model in evaluation mode
ecg_encoder.eval()

# Pass the data through the model
output = ecg_encoder(input_data)

print(output)

torch.Size([1, 79872])
tensor([[-1.3415e-01,  4.2046e-02, -2.3089e-02, -9.6867e-02,  5.6141e-02,
         -3.6813e-02,  2.3188e-02,  5.6476e-02, -6.9348e-03, -1.4045e-01,
         -1.4861e-01,  5.0855e-02,  1.2150e-01, -3.3249e-02, -5.2528e-02,
         -4.0292e-02, -9.3183e-02,  9.1527e-02, -5.8979e-02, -7.0683e-03,
         -3.0763e-02,  4.0828e-02, -7.4873e-02, -3.9866e-02, -6.1620e-02,
          7.2322e-02, -1.0775e-02, -5.0234e-02,  8.3719e-03, -3.5085e-02,
          5.0136e-03,  7.3070e-02, -3.1885e-02,  5.0698e-02, -4.5149e-02,
         -7.1235e-02,  9.1999e-02,  8.6871e-02, -4.2471e-02,  1.1170e-01,
          6.7944e-02, -3.8681e-02, -8.1415e-02, -6.4703e-02, -7.4312e-02,
          5.1688e-02,  4.9098e-02,  7.3638e-02, -4.9833e-02,  7.8834e-02,
         -6.6390e-02, -2.9900e-02,  4.8231e-03,  7.2578e-02,  7.4812e-03,
         -2.7810e-03,  2.1301e-02,  2.5543e-02,  7.5328e-02,  5.3012e-02,
          5.9585e-02,  2.1722e-02,  5.6241e-02, -6.9900e-02, -1.2700e-01,
         -7.779

## Update TripletLoss() such that:

- positive_instances are where the ECG embedding and dx_modality embedding align (from the same file/reading)
- negative_instances are where these two embeddings do not align
- filter out text embeddings that are the same or equal to the positive_instances

In [66]:
processed_train_df.head(1)

Unnamed: 0,recording_number,recording_file,num_leads,sampling_frequency,num_samples,age,sex,dx,rx,hx,...,lead_10_lead_name,lead_11_file,lead_11_adc_gain,lead_11_units,lead_11_adc_resolution,lead_11_adc_zero,lead_11_initial_value,lead_11_checksum,lead_11_lead_name,dx_modality
0,JS00001,JS00001.mat,12,500,5000,85.0,Male,"['164889003', '59118001', '164934002']",Unknown,Unknown,...,0,JS00001.mat,1000.0,mV,16,0,527,32579,0,"['atrial fibrillation', 'right bundle branch b..."
1,JS00002,JS00002.mat,12,500,5000,59.0,Female,"['426177001', '164934002']",Unknown,Unknown,...,0,JS00002.mat,1000.0,mV,16,0,0,31542,0,"['sinus bradycardia', 't wave abnormal']"


In [68]:
train_set[0][0]

{'recording_number': 'JS00001',
 'recording_file': 'JS00001.mat',
 'num_leads': 12,
 'sampling_frequency': 500,
 'num_samples': 5000,
 'leads_info': [{'file': 'JS00001.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': -254,
   'checksum': 21756,
   'lead_name': '0'},
  {'file': 'JS00001.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': 264,
   'checksum': -599,
   'lead_name': '0'},
  {'file': 'JS00001.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': 517,
   'checksum': -22376,
   'lead_name': '0'},
  {'file': 'JS00001.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': -5,
   'checksum': 28232,
   'lead_name': '0'},
  {'file': 'JS00001.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': -386,


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch 

class TripletLoss(torch.nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        positive_distance = F.pairwise_distance(anchor, positive, keepdim=True)
        negative_distance = F.pairwise_distance(anchor, negative, keepdim=True)

        triplet_loss = torch.mean(torch.clamp(positive_distance - negative_distance + self.margin, min=0.0))

        return triplet_loss

Firstly, I have a function that calulcates ClinicalBERT embedding of ECG header, class TextEncoder. 

I use this class to get ClinicalBERT embeddings of the dx_modality column. 

Next, I Have another class, ECGEncoder, which uses moden OneDimCNN to get ECG embeddings from a Physionet object. 

train_set[0][1]['val'] --> RETURNS THE ECG SIGNAL

train_set[0][1] --< RETURSN THE ECG HEADER. 

I need to write a function that does the following: 

- It finds positive_instances where the ECG embedding (train_set[0][1]['val']) and dx_modality (processed_train_df['dx_modality'][0])) align (from the same file/reading)
- negative_instances are where these two embeddings do not align

FINALLY, - filter out text embeddings that are the same or equal to the positive_instances


In [103]:
class InstanceSelector:
    def __init__(self, train_set, processed_train_df, text_encoder, ecg_encoder):
        self.train_set = train_set
        self.processed_train_df = processed_train_df
        self.text_encoder = text_encoder
        self.ecg_encoder = ecg_encoder

    def get_positive_instances(self):
        positive_instances = []
        for i in range(len(self.train_set)):
            # Generate ECG embedding for the current instance in the training set
            ecg_embedding = self.ecg_encoder.encode(self.train_set[i][1]['val'])
            # Generate dx_modality embedding for the current instance in the processed DataFrame
            dx_modality_embedding = self.text_encoder.encode(self.processed_train_df['dx_modality'][i])
            # If the ECG embedding and dx_modality embedding are equal, append them as a positive instance
            if torch.all(torch.eq(ecg_embedding, dx_modality_embedding)):
                positive_instances.append((ecg_embedding, dx_modality_embedding))
        return positive_instances

    def get_negative_instances(self):
        negative_instances = []
        # Get the positive instances
        positive_instances = self.get_positive_instances()
        for i in range(len(self.train_set)):
            # Generate ECG embedding for the current instance in the training set
            ecg_embedding = self.ecg_encoder.encode(self.train_set[i][1]['val'])
            for j in range(len(self.processed_train_df)):
                # Only consider dx_modality embeddings that are not at the same index as the current ECG embedding
                if i != j:
                    # Generate dx_modality embedding for the current instance in the processed DataFrame
                    dx_modality_embedding = self.text_encoder.encode(self.processed_train_df['dx_modality'][j])
                    # If the ECG embedding does not match any of the positive instance embeddings, append it as a negative instance
                    if not any(torch.all(torch.eq(ecg_embedding, pos[1])) for pos in positive_instances):
                        negative_instances.append((ecg_embedding, dx_modality_embedding))
        return negative_instances

In [104]:
text_encoder = TextEncoder()
ecg_encoder = ECGEncoder(num_classes=126)  # Assuming you have this class defined

In [105]:
instance_selector = InstanceSelector(train_set, processed_train_df, text_encoder, ecg_encoder)

In [None]:
positive_instances = instance_selector.get_positive_instances()
negative_instances = instance_selector.get_negative_instances()

## CLIP Model

In [113]:
class CLIPModel(nn.Module):
    def __init__(self, train_set, processed_train_df):
        super(CLIPModel, self).__init__()
        self.text_encoder = TextEncoder()  # Initialize TextEncoder
        self.ecg_encoder = ECGEncoder(num_classes=126)  # Initialize ECGEncoder
        self.instance_selector = InstanceSelector(train_set, processed_train_df, self.text_encoder, self.ecg_encoder)

    def forward(self, ecg_signal, dx_modality):
        ecg_embedding = self.ecg_encoder.encode(ecg_signal)
        dx_modality_embedding = self.text_encoder.encode(dx_modality)
        # Compute similarity between embeddings, e.g., using cosine similarity
        similarity = F.cosine_similarity(ecg_embedding, dx_modality_embedding)
        return similarity

In [126]:
class CLIPModel(nn.Module):
    def __init__(self, train_set, processed_train_df):
        super(CLIPModel, self).__init__()
        self.ecg_encoder = ECGEncoder(num_classes=126)  # Initialize ECGEncoder
        self.text_encoder = TextEncoder()  # Initialize TextEncoder
        self.instance_selector = InstanceSelector(train_set, processed_train_df, self.text_encoder, self.ecg_encoder)

    def forward(self, ecgs, diagnoses):
        ecgs_embeddings = self.ecg_encoder(ecgs)
        diagnoses_embeddings = self.text_encoder.encode(diagnoses)
        positive_instances = self.instance_selector.get_positive_instances()
        negative_instances = self.instance_selector.get_negative_instances()
        # Compute loss based on whether the pair of embeddings is a positive or negative instance
        loss = sum(F.cosine_similarity(ecgs_embeddings[i], diagnoses_embeddings[i]) for i in range(len(ecgs)) if (ecgs_embeddings[i], diagnoses_embeddings[i]) in positive_instances) \
             - sum(F.cosine_similarity(ecgs_embeddings[i], diagnoses_embeddings[i]) for i in range(len(ecgs)) if (ecgs_embeddings[i], diagnoses_embeddings[i]) in negative_instances)
        return loss

In [136]:

class PhysioNetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path, train=False):
        self.dataset_path = dataset_path
        self.dataset_path = [path for path in self.dataset_path if "index.html" not in path]
        self.train = train
        self.file_list = os.listdir(dataset_path)
        self._hea_files = []
        self._mat_files = []
        self._indices_files = []
        self._hea_files_path = []
        self._mat_files_path = []

        self.file_PATHS = []  # Directory to main database folders
        self.data_files = []  # Directory to data files

        # Validation Case: PTB Databases only
        if self.train == False:
            validation_datasets = ['ptb', 'ptb-xl']
            for file in os.listdir(dataset_path):
                if file in validation_datasets:
                    file_path = os.path.join(dataset_path, file)
                    file_path = file_path.replace('\\', '/')
                    self.file_PATHS.append(file_path)

        # Training Case: All Databases excluding PTB
        else:
            validation_datasets = ['ptb', 'ptb-xl']
            for file in os.listdir(dataset_path):
                if file not in validation_datasets:
                    file_path = os.path.join(dataset_path, file)
                    file_path = file_path.replace('\\', '/')
                    self.file_PATHS.append(file_path)

        for path in self.file_PATHS:
            if os.path.isdir(path):
                for sub_folder in os.listdir(path):
                    sub_folder_path = os.path.join(path, sub_folder)
                    sub_folder_path = sub_folder_path.replace('\\', '/')
                    
                    # Ignore index.html files
                    if sub_folder_path.endswith('index.html'):
                        self._indices_files.append(sub_folder_path)
                    else:
                        if os.path.isdir(sub_folder_path):
                            for file in os.listdir(sub_folder_path):
                                # Get all .hea files
                                if file.endswith('.hea'):
                                    file_path = os.path.join(sub_folder_path, file)
                                    file_path = file_path.replace('\\', '/')
                                    self._hea_files.append(file_path)
                                    self._hea_files_path.append(file_path)
                                # Get all .mat files
                                elif file.endswith('.mat'):
                                    file_path = os.path.join(sub_folder_path, file)
                                    file_path = file_path.replace('\\', '/')
                                    self._mat_files.append(file_path)
                                    self._mat_files_path.append(file_path)

    def resample_ecg(self, data, old_freq, new_freq=128):
        # Calculate the duration of the signal
        duration = len(data) / old_freq

        # Calculate the number of points in the resampled signal
        num_points = int(np.round(duration * new_freq))

        # Resample the signal
        resampled_data = resample(data, num_points)

        return resampled_data

    def __getitem__(self, index):
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self[i] for i in range(start, stop, step)]
        # 1. Get .hea file
        hea_file_path = self._hea_files[index]
        with open(hea_file_path, 'r') as f:
            lines = f.readlines()
            
        # Parse header information
        # Initialize header information
        header_info = {
            'recording_number': lines[0].split()[0],
            'recording_file': lines[0].split()[0] + '.mat',
            'num_leads': int(lines[0].split()[1]),
            'sampling_frequency': int(lines[0].split()[2]),
            'num_samples': int(lines[0].split()[3]),
            'leads_info': [],
            'age': None,
            'sex': None,
            'dx': None,
            'rx': None,
            'hx': None,
            'sx': None,
        }

        # Parse header information
        for line in lines:
            if line.startswith('# Age:'):
                age_str = line.split(':')[1].strip()
                header_info['age'] = int(age_str) if age_str != 'NaN' else None
            elif line.startswith('# Sex:'):
                header_info['sex'] = line.split(':')[1].strip()
            elif line.startswith('# Dx:'):
                header_info['dx'] = line.split(':')[1].strip().split(',')
            elif line.startswith('# Rx:'):
                header_info['rx'] = line.split(':')[1].strip()
            elif line.startswith('# Hx:'):
                header_info['hx'] = line.split(':')[1].strip()
            elif line.startswith('# Sx:'):
                header_info['sx'] = line.split(':')[1].strip()

        for line in lines[1:header_info['num_leads']+1]:
            adc_gain = line.split()[2].split('/')[0]
            adc_gain = float(adc_gain.replace('(0)', ''))  # Remove '(0)' and convert to float
            lead_info = {
                'file': line.split()[0],
                'adc_gain': adc_gain,
                'units': line.split()[2].split('/')[1],
                'adc_resolution': int(line.split()[3]),
                'adc_zero': int(line.split()[4]),
                'initial_value': int(line.split()[5]),
                'checksum': int(line.split()[6]),
                'lead_name': line.split()[7],
            }
            header_info['leads_info'].append(lead_info)

        # 2. Get .mat file
        twelve_lead_ecg = None
        if index < len(self._mat_files):
            mat_file_path = self._mat_files[index]
            twelve_lead_ecg = sio.loadmat(mat_file_path)
            
            # Resample the ECG to 128 Hz
            for lead in twelve_lead_ecg:
                twelve_lead_ecg[lead] = self.resample_ecg(twelve_lead_ecg[lead], old_freq=header_info['sampling_frequency'])
        else:
            print(f"MAT file for index {index} does not exist.")
        
        return header_info, twelve_lead_ecg

    def plot_record(self, index):
        mat_file_path = self._mat_files[index]
        data = sio.loadmat(mat_file_path)
        fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

        for i, ax in enumerate(axs.flat):
            ax.plot(data['val'][i], linewidth=0.5)
            ax.set_xlabel('Sample')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'Lead {i+1}')

        plt.tight_layout()
        plt.show()

    def __len__(self):
        return len(self._hea_files)

In [137]:
train_set = PhysioNetDataset(PhysioNet_PATH, train=True)

In [138]:
small_train_set = train_set[:100]

In [143]:
small_train_set[-1][0]

{'recording_number': 'JS00104',
 'recording_file': 'JS00104.mat',
 'num_leads': 12,
 'sampling_frequency': 500,
 'num_samples': 5000,
 'leads_info': [{'file': 'JS00104.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': -142,
   'checksum': 16634,
   'lead_name': '0'},
  {'file': 'JS00104.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': 39,
   'checksum': 12464,
   'lead_name': '0'},
  {'file': 'JS00104.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': 181,
   'checksum': -4171,
   'lead_name': '0'},
  {'file': 'JS00104.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': 54,
   'checksum': -8466,
   'lead_name': '0'},
  {'file': 'JS00104.mat',
   'adc_gain': 1000.0,
   'units': 'mV',
   'adc_resolution': 16,
   'adc_zero': 0,
   'initial_value': -161,
 

In [145]:
processed_train_df.iloc[99]

recording_number                       JS00104
recording_file                     JS00104.mat
num_leads                                   12
sampling_frequency                         500
num_samples                               5000
                                 ...          
lead_11_adc_zero                             0
lead_11_initial_value                       44
lead_11_checksum                         27373
lead_11_lead_name                            0
dx_modality              ['sinus bradycardia']
Name: 99, Length: 108, dtype: object

In [127]:
# Initialize model, criterion, and optimizer
model = CLIPModel(train_set, processed_train_df)
optimizer = torch.optim.Adam(model.ecg_encoder.parameters())

In [128]:
# Initialize model
model = CLIPModel(train_set, processed_train_df)
optimizer = torch.optim.Adam(model.ecg_encoder.parameters())

In [129]:
from tqdm import tqdm

# Initialize a list to store the loss at each step
losses = []

# Training loop
for epoch in range(2):
    # Add a progress bar for the inner loop
    for i in tqdm(range(len(train_set)), desc=f"Training epoch {epoch+1}/{2}"):
        # Get ECGs and diagnoses from training set
        ecgs = train_set[i][1]['val']
        diagnoses = processed_train_df['dx_modality'][i]

        # Convert ECGs to tensor and add a dimension for batch size
        ecgs = torch.from_numpy(ecgs).float().unsqueeze(0)
        
        # Forward pass
        loss = model(ecgs, diagnoses)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save the loss to a variable
        losses.append(loss.item())

    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

Training epoch 1/2:   0%|          | 0/65900 [00:00<?, ?it/s]

torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size([1, 79872])
torch.Size(

Training epoch 1/2:   0%|          | 0/65900 [00:15<?, ?it/s]

torch.Size([1, 79872])





KeyboardInterrupt: 

In [None]:
# Training loop
for epoch in range(num_epochs):
    for i in range(len(train_set)):
        # Get ECG signal and dx_modality from training set
        ecg_signal = train_set[i][1]['val']
        dx_modality = processed_train_df['dx_modality'][i]
        # Get target label from instance selector
        target = 1 if (ecg_signal, dx_modality) in model.instance_selector.get_positive_instances() else 0
        target = torch.tensor([target], dtype=torch.float)

        # Forward pass
        outputs = model(ecg_signal, dx_modality)
        loss = criterion(outputs, target)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))