In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import random
from tqdm import tqdm
from scipy.signal import resample
import torch
from transformers import AutoTokenizer, AutoModel
import ast
import scipy.io as sio
from torch.utils.data import random_split
from torch.nn.functional import cosine_similarity

In [None]:
# Add a local folder to the module search path, presumably so that
# `helper_functions` (imported in the next cell) can be found.
# NOTE(review): this is a hard-coded, machine-specific absolute path.
sys.path.append('C:/Users/navme/Desktop/ECG_Project/PyFiles')

In [None]:
from helper_functions import *

# Updated PhysioNetDataset Class 

- Update the `PhysioNetDataset` class so that `__getitem__` returns only the diagnosis codes, `header_info['dx']`, instead of the full `header_info` dictionary. Then convert the Dx codes to a string that can be used as input for `TextEncoder()`.

In [None]:
class PhysioNetDataset(torch.utils.data.Dataset):
    """Index WFDB-style ECG records (.hea header + .mat signal) on disk.

    The expected layout is ``dataset_path/<database>/<sub_folder>/<record>``.
    The databases named in ``_VALIDATION_DATASETS`` form the validation split;
    every other entry of ``dataset_path`` forms the training split.

    Indexing returns ``(header_info, twelve_lead_ecg)`` where ``header_info``
    is a dict parsed from the ``.hea`` file and ``twelve_lead_ecg`` is the dict
    returned by ``scipy.io.loadmat`` with its signal arrays resampled to
    128 Hz (or ``None`` when the record has no ``.mat`` file).
    """

    # Databases reserved for validation.
    _VALIDATION_DATASETS = ('ptb', 'ptb-xl')

    # Header comment labels mapped to the ``header_info`` keys they fill.
    _COMMENT_FIELDS = {
        'Age': 'age',
        'Sex': 'sex',
        'Dx': 'dx',
        'Rx': 'rx',
        'Hx': 'hx',
        'Sx': 'sx',
    }

    def __init__(self, dataset_path, train=False):
        """Scan ``dataset_path`` and collect the record files of one split.

        Args:
            dataset_path: Root directory that contains one folder per database.
            train: ``True`` selects every database except the validation ones;
                ``False`` selects only the validation databases.
        """
        # Keep the root path exactly as given (filtering it element-wise would
        # iterate over the characters of the string).
        self.dataset_path = dataset_path
        self.train = train
        self.file_list = os.listdir(dataset_path)
        self._hea_files = []
        self._mat_files = []
        self._indices_files = []
        self._hea_files_path = []
        self._mat_files_path = []

        self.file_PATHS = []  # Directory to main database folders
        self.data_files = []  # Directory to data files

        # A database belongs to the requested split when "is a validation
        # database" and "training was requested" disagree.
        wants_training = bool(self.train)
        for entry in sorted(self.file_list):
            is_validation_db = entry in self._VALIDATION_DATASETS
            if is_validation_db != wants_training:
                self.file_PATHS.append(self._join(dataset_path, entry))

        # os.listdir order is arbitrary, so every level is sorted to make the
        # record index deterministic across runs and machines.
        for db_path in self.file_PATHS:
            if not os.path.isdir(db_path):
                continue
            for sub_name in sorted(os.listdir(db_path)):
                sub_path = self._join(db_path, sub_name)

                # index.html files are recorded separately, never scanned.
                if sub_path.endswith('index.html'):
                    self._indices_files.append(sub_path)
                    continue
                if not os.path.isdir(sub_path):
                    continue

                for file_name in sorted(os.listdir(sub_path)):
                    if file_name.endswith('.hea'):
                        file_path = self._join(sub_path, file_name)
                        self._hea_files.append(file_path)
                        self._hea_files_path.append(file_path)
                    elif file_name.endswith('.mat'):
                        file_path = self._join(sub_path, file_name)
                        self._mat_files.append(file_path)
                        self._mat_files_path.append(file_path)

    @staticmethod
    def _join(*parts):
        """Join path components and normalise separators to forward slashes."""
        return os.path.join(*parts).replace('\\', '/')

    @staticmethod
    def _parse_age(text):
        """Return the age as an int, or ``None`` when it is not a finite number."""
        try:
            return int(float(text))
        except (ValueError, OverflowError):
            # Covers 'NaN', empty strings and any other non-numeric value.
            return None

    def _paired_mat_path(self, index):
        """Return the .mat path next to header ``index``, or ``None`` if absent.

        The signal file is located by name (same stem as the header) instead
        of by list position, so a missing or extra file cannot shift the
        pairing of the other records.
        """
        stem, _ = os.path.splitext(self._hea_files[index])
        mat_path = stem + '.mat'
        return mat_path if os.path.isfile(mat_path) else None

    def _parse_header(self, lines):
        """Build the ``header_info`` dict from the lines of a .hea file."""
        record_fields = lines[0].split()
        header_info = {
            'recording_number': record_fields[0],
            'recording_file': record_fields[0] + '.mat',
            'num_leads': int(record_fields[1]),
            'sampling_frequency': int(record_fields[2]),
            'num_samples': int(record_fields[3]),
            'leads_info': [],
            'age': None,
            'sex': None,
            'dx': None,
            'rx': None,
            'hx': None,
            'sx': None,
        }

        # Comment lines: accept both "# Age: 74" and "#Age: 74".
        for line in lines:
            if not line.startswith('#'):
                continue
            label, colon, value = line[1:].partition(':')
            key = self._COMMENT_FIELDS.get(label.strip())
            if not colon or key is None:
                continue
            value = value.strip()
            if key == 'age':
                header_info['age'] = self._parse_age(value)
            elif key == 'dx':
                header_info['dx'] = [code.strip() for code in value.split(',')]
            else:
                header_info[key] = value

        # Signal lines follow the record line, one per lead.
        for line in lines[1:header_info['num_leads'] + 1]:
            fields = line.split()
            # Third field is "gain(baseline)/units"; drop the optional baseline.
            gain_text, _, units = fields[2].partition('/')
            lead_info = {
                'file': fields[0],
                'adc_gain': float(gain_text.split('(')[0]),
                'units': units or None,
                'adc_resolution': int(fields[3]),
                'adc_zero': int(fields[4]),
                'initial_value': int(fields[5]),
                'checksum': int(fields[6]),
                # Field 7 is the block size; the description (lead name)
                # comes after it. Fall back to the last field on short lines.
                'lead_name': ' '.join(fields[8:]) if len(fields) > 8 else fields[-1],
            }
            header_info['leads_info'].append(lead_info)

        return header_info

    def resample_ecg(self, data, old_freq, new_freq=128):
        """Resample ``data`` from ``old_freq`` to ``new_freq`` along its last axis.

        Args:
            data: 1-D signal, or an array whose last axis is time (e.g. the
                ``val`` matrix, which stores one row per lead).
            old_freq: Sampling frequency of ``data`` in Hz.
            new_freq: Target sampling frequency in Hz.

        Returns:
            The resampled array; every axis but the last keeps its length.
        """
        data = np.asarray(data)

        # Duration is derived from the number of time samples (last axis),
        # not from the number of leads.
        duration = data.shape[-1] / old_freq
        num_points = int(np.round(duration * new_freq))

        return resample(data, num_points, axis=-1)

    def __getitem__(self, index):
        """Return ``(header_info, twelve_lead_ecg)`` for a record or a slice.

        ``twelve_lead_ecg`` is ``None`` (and a message is printed) when the
        record has no ``.mat`` file next to its header.
        """
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self[i] for i in range(start, stop, step)]

        # 1. Get .hea file
        with open(self._hea_files[index], 'r') as f:
            lines = f.readlines()
        header_info = self._parse_header(lines)

        # 2. Get .mat file
        twelve_lead_ecg = None
        mat_file_path = self._paired_mat_path(index)
        if mat_file_path is not None:
            twelve_lead_ecg = sio.loadmat(mat_file_path)

            # Resample the ECG to 128 Hz. loadmat also returns metadata
            # entries ('__header__', '__version__', '__globals__') that are
            # not signals and must be left untouched.
            for key in list(twelve_lead_ecg):
                if key.startswith('__'):
                    continue
                twelve_lead_ecg[key] = self.resample_ecg(
                    twelve_lead_ecg[key],
                    old_freq=header_info['sampling_frequency'],
                )
        else:
            print(f"MAT file for index {index} does not exist.")

        return header_info, twelve_lead_ecg

    def plot_record(self, index):
        """Plot the first twelve rows of the record's ``val`` matrix on a 3x4 grid.

        Raises:
            FileNotFoundError: If the record has no ``.mat`` file.
        """
        mat_file_path = self._paired_mat_path(index)
        if mat_file_path is None:
            raise FileNotFoundError(f"MAT file for index {index} does not exist.")
        data = sio.loadmat(mat_file_path)
        fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

        for i, ax in enumerate(axs.flat):
            ax.plot(data['val'][i], linewidth=0.5)
            ax.set_xlabel('Sample')
            ax.set_ylabel('Amplitude')
            ax.set_title(f'Lead {i+1}')

        plt.tight_layout()
        plt.show()

    def __len__(self):
        """Return the number of header (.hea) files found for this split."""
        return len(self._hea_files)