# Imports

# In [7]:
import wfdb
import biosppy.signals.ecg as ecg
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import resample
from scipy.signal import decimate
import re
import os

def trim_signal_segments(signal, rpeaks, segment_length_sec, sampling_rate):
    """
    Cut a signal into fixed-length segments that each start shortly before an R-peak.

    Every segment starts 0.25 s before the next unprocessed R-peak (clamped to the
    start of the signal) and spans ``segment_length_sec`` seconds. If the last
    R-peak inside a segment lies within the final 0.25 s, the tail of the segment
    is zeroed from the nearest preceding near-zero sample (|value| <= 0.005) and
    that peak is dropped; the next segment then starts at the dropped peak.

    Parameters:
        signal (ndarray): The input signal (1-D).
        rpeaks (ndarray): Array containing R-peak sample indices into ``signal``.
        segment_length_sec (float): Length of the trimmed segments in seconds.
        sampling_rate (float): Sampling rate of the signal (Hz).

    Returns:
        tuple: A tuple containing two lists:
            - List of trimmed signal segments (copies; the last one may be
              shorter than the requested length if the signal ends first).
            - List of R-peak indices for each segment, relative to the start
              of that segment.
    """
    signal = np.asarray(signal)
    rpeaks = np.asarray(rpeaks)

    trimmed_signals = []
    trimmed_rpeaks = []

    # Convert durations from seconds to samples
    segment_length_samples = int(segment_length_sec * sampling_rate)
    quarter_second = int(sampling_rate / 4)

    # A peak later than this (relative to the segment start) sits in the last
    # 0.25 s of the segment. Derived from segment_length_sec so that lengths
    # other than 10 s are handled; for 10 s this equals sampling_rate * 9.75.
    late_peak_threshold = sampling_rate * (segment_length_sec - 0.25)

    peak_index = 0
    start_point = 0

    while start_point < (len(signal) - segment_length_samples):
        if peak_index >= len(rpeaks):
            break

        # Start 0.25 s before the current R-peak, but never before the signal start
        start_point = max(rpeaks[peak_index] - quarter_second, 0)
        end_point = start_point + segment_length_samples

        # Copy so that zeroing the tail does not modify the caller's signal
        current_signal_segment = signal[start_point:end_point].copy()

        # R-peaks inside the segment, re-based to the segment start. The upper
        # bound is inclusive so a peak exactly at the boundary still triggers
        # the tail trimming below.
        in_segment = (rpeaks >= start_point) & (rpeaks <= end_point)
        current_rpeaks = rpeaks[in_segment] - start_point

        # If the last R-peak is too close to the end of the segment, trim the segment
        if current_rpeaks.size > 0 and current_rpeaks[-1] > late_peak_threshold:
            # Begin the search 0.25 s before the segment end (kept inside the array)
            nominal_cut = max(
                min(segment_length_samples - quarter_second,
                    len(current_signal_segment) - 1),
                0,
            )

            # Walk backwards to a sample that is close to zero
            cut = nominal_cut
            while cut >= 0 and abs(current_signal_segment[cut]) > 0.005:
                cut -= 1

            # No near-zero sample found: fall back to the nominal cut point
            # instead of letting a negative index wrap around.
            if cut < 0:
                cut = nominal_cut

            # Set values to zero from the identified index onwards
            current_signal_segment[cut:] = 0

            # Remove the last R-peak as it may not be reliable
            current_rpeaks = current_rpeaks[:-1]

        # Store the trimmed segment and its corresponding R-peaks
        trimmed_signals.append(current_signal_segment)
        trimmed_rpeaks.append(current_rpeaks)

        # Move to the next unprocessed peak; always advance by at least one
        # so the loop cannot stall when a segment keeps no peaks.
        peak_index += max(len(current_rpeaks), 1)

    return trimmed_signals, trimmed_rpeaks

def save_ecg_recording(file_name, signals, annotations, sample_rate, write_dir, additional_info):
    """
    Save a 12-lead ECG recording in WFDB format, with optional beat annotations.

    The signal and header files are written with ``wfdb.wrsamp``; the extra
    comment lines are then appended to the resulting ``.hea`` file, and an
    ``.atr`` annotation file is written with ``wfdb.wrann`` when annotations
    are supplied. Every channel is stored in format '16' with unit 'mV', and
    every annotation is labelled 'N'.

    Args:
        file_name (str): Base name for the files to create.
        signals (list): List of 12 equally long signal arrays, one per channel,
            in the order I, II, III, AVR, AVL, AVF, V1..V6.
        annotations (list): List of annotation sample indices.
        sample_rate (int): Sampling rate of the signals.
        write_dir (str): Directory to save the files; created if missing.
        additional_info (list): List of additional information lines to append to the .hea file.

    Raises:
        ValueError: If ``signals`` does not hold exactly one array per channel name.
    """
    # Define the channel names in the desired order
    channel_names = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    # wfdb expects (samples x channels), so transpose the per-channel list
    signals = np.array(signals).T

    if signals.ndim != 2 or signals.shape[1] != len(channel_names):
        raise ValueError(
            f"expected {len(channel_names)} equally long channels, "
            f"got array of shape {signals.shape}"
        )

    # Create a list of signal formats and units
    fmt = ['16'] * len(channel_names)
    units = ['mV'] * len(channel_names)

    # Create the write directory if needed (no separate existence check, so
    # there is no window between the check and the creation)
    os.makedirs(write_dir, exist_ok=True)

    # Save the signal using the wfdb format
    wfdb.wrsamp(file_name, fs=sample_rate, units=units, sig_name=channel_names, p_signal=signals, fmt=fmt, write_dir=write_dir)

    # Append additional information to .hea file
    hea_file = os.path.join(write_dir, f"{file_name}.hea")
    with open(hea_file, 'a') as f:
        f.writelines(line + '\n' for line in additional_info)

    # Save .atr file if annotations exist
    if len(annotations) > 0:
        wfdb.wrann(file_name, extension='atr', sample=annotations, symbol=['N'] * len(annotations), write_dir=write_dir)

# Organize Files

# In [8]:
# Root folder of the PTB database files (placeholder, Windows-style path).
# The trailing separator is required: later paths in this file are built by
# plain string concatenation onto this value.
ptb_files_dir = "\\path\\ptb-diagnostic-ecg-database-1.0.0\\"

# In [10]:
import os
import shutil

def process_files(directory, search_texts):
    """
    Copy records into per-diagnosis folders based on their header comment.

    Walks ``directory`` recursively, reads every ``.hea`` file and looks for a
    ``# Reason for admission:`` line. When the text after that prefix equals
    one of ``search_texts``, the record's ``.hea``, ``.dat`` and ``.xyz``
    files are copied into ``<directory>/<text>/`` (created on demand).

    Args:
        directory (str): Root directory to scan; destination folders are
            created directly inside it.
        search_texts (list): Admission reasons to collect, matched exactly.
    """
    reason_prefix = "# Reason for admission:"

    # Destination folders live inside the scanned tree. They must not be
    # walked, otherwise a second run would try to copy each file onto itself
    # (shutil.SameFileError).
    destination_dirs = {
        os.path.normcase(os.path.abspath(os.path.join(directory, text)))
        for text in search_texts
    }

    for root, dirs, files in os.walk(directory):
        # Prune destination folders in place so os.walk does not descend into them
        dirs[:] = [
            d for d in dirs
            if os.path.normcase(os.path.abspath(os.path.join(root, d))) not in destination_dirs
        ]

        for file in files:
            if not file.endswith(".hea"):  # Process only files with .hea extension
                continue

            file_path = os.path.join(root, file)
            with open(file_path, 'r') as f:
                lines = f.readlines()

            for line in lines:
                if not line.startswith(reason_prefix):
                    continue

                # Everything after the prefix, even if it contains another ':'
                text = line[len(reason_prefix):].strip()
                if text not in search_texts:
                    continue

                # Record name without the extension (keeps any inner dots)
                file_name = os.path.splitext(file)[0]
                destination_folder = os.path.join(directory, text)
                os.makedirs(destination_folder, exist_ok=True)

                # Copy the .hea, .dat, and .xyz files to the destination directory
                for ext in [".hea", ".dat", ".xyz"]:
                    source_file = os.path.join(root, file_name + ext)
                    if not os.path.isfile(source_file):
                        # A missing companion file should not abort the whole run
                        print(f"Skipping missing file '{source_file}'.")
                        continue
                    destination_file = os.path.join(destination_folder, file_name + ext)
                    shutil.copy(source_file, destination_file)
                break  # Stop searching after finding the text

# Admission reasons to collect; each becomes a folder name under the root
search_texts = [
    "Healthy control",
    "Myocardial infarction",
    "Hypertrophy",
]

# Scan the database root (and everything below it) and sort records by reason
process_files(ptb_files_dir, search_texts)

# Iterative File Processing

# In [None]:
# Output root for the processed segments. Paths are built by string
# concatenation with Windows-style separators, so ptb_files_dir must end
# with a separator.
root_write_dir = ptb_files_dir+"ptb_filtered_splitted_100hz\\"

# Input folders, one per admission reason
dir_norm = ptb_files_dir+"Healthy control"
dir_mi = ptb_files_dir+"Myocardial infarction"
dir_hyp = ptb_files_dir+"Hypertrophy"

# Output folders, one per diagnosis label
write_dir_norm = root_write_dir+"NORM"
write_dir_mi = root_write_dir+"MI"
write_dir_hyp = root_write_dir+"HYP"

# All output folders that must exist before processing starts
directories = [write_dir_norm, write_dir_mi, write_dir_hyp]

# Helper that makes sure a directory is present on disk
def create_directory_if_not_exists(directory):
    """Create ``directory`` (including missing parents) and report it, unless the path already exists."""
    if os.path.exists(directory):
        return
    os.makedirs(directory)
    print(f"Directory '{directory}' created successfully.")

# Make sure every output folder exists before any segment is written
for directory in directories:
    create_directory_if_not_exists(directory)
    
# --------------------------------------------------------------------

def _read_age_and_sex(header_path):
    """Return ``(age, sex)`` read from the ``# age:`` / ``# sex:`` comment lines of a header file."""
    age = 0
    sex = ''
    with open(header_path, "r") as header:
        for line in header:
            if "# age: " in line:
                tokens = line.split("# age: ")[1].split()
                # An empty value keeps the default instead of raising IndexError
                if tokens:
                    age = tokens[0]
            if "# sex: " in line:
                tokens = line.split("# sex: ")[1].split()
                if tokens:
                    sex = tokens[0]

    # Abbreviate the two known values; anything else is passed through unchanged
    if sex == "male":
        sex = "M"
    elif sex == "female":
        sex = "F"
    return age, sex


def trim_and_improve_signal(root_dir, write_dir, label):
    """
    Filter, resample and segment every record in ``root_dir`` and save the segments.

    For each ``.hea`` file in ``root_dir`` the record is read with ``wfdb.rdsamp``,
    its first 12 channels are filtered with ``biosppy``'s ``ecg.ecg``, resampled to
    100 Hz and cut into 10 s segments aligned on the R-peaks detected on the first
    channel. Each full-length segment is written with ``save_ecg_recording`` as
    ``<record>_segment<k>`` together with age, sex and the given diagnosis label.

    Args:
        root_dir (str): Folder containing the input records.
        write_dir (str): Folder that receives the segment files.
        label (str): Diagnosis label stored in each segment's header.
    """
    # NOTE(review): assumes the first 12 channels are the standard leads in the
    # order used by save_ecg_recording - confirm against the record headers.
    n_leads = 12
    new_sampling_rate = 100
    segment_length_sec = 10
    segment_length_samples = int(segment_length_sec * new_sampling_rate)

    for file_name in os.listdir(root_dir):
        if not file_name.endswith(".hea"):
            continue

        #------------------------------------------------#
        # GET SIGNAL AND PREPROCESS
        #------------------------------------------------#
        base_file_name = os.path.splitext(file_name)[0]
        print(base_file_name)

        record_path = os.path.join(root_dir, base_file_name)

        signals, fields = wfdb.rdsamp(record_path)
        fs = fields['fs']

        #------------------------------------------------#
        # EXTRACT ADDITIONAL INFO
        #------------------------------------------------#
        age, sex = _read_age_and_sex(f'{record_path}.hea')

        additional_info = [
            f"# Age: {age}",
            f"# Sex: {sex}",
            f"# Diagnosis: {label}"
        ]

        #------------------------------------------------#
        # ANALYSIS OF ECG SIGNAL
        #------------------------------------------------#
        # Only the filtered trace of each lead is kept from the analysis result
        signals_filtered = [
            ecg.ecg(signal=signals[:, lead], sampling_rate=fs, show=False)['filtered']
            for lead in range(n_leads)
        ]

        #------------------------------------------------#
        # RESAMPLE SIGNAL TO THE TARGET RATE
        #------------------------------------------------#
        resize_factor = new_sampling_rate / fs
        new_size = int(len(signals_filtered[0]) * resize_factor)
        resampled_ecg_signals = [
            resample(signal_filtered, new_size) for signal_filtered in signals_filtered
        ]

        #------------------------------------------------#
        # TRIM SIGNAL TO SHORTER SIGNALS
        #------------------------------------------------#
        # R-peaks are detected once, on the first lead, and reused for all leads
        # so that every lead is cut at the same positions.
        analysis = ecg.ecg(signal=resampled_ecg_signals[0], sampling_rate=new_sampling_rate, show=False)
        rpeaks = analysis['rpeaks']

        trimmed_signals_1ch_all = []
        trimmed_signals_rpeaks_all = []
        for lead, resampled_signal in enumerate(resampled_ecg_signals):
            trimmed_signals, trimmed_signals_rpeaks = trim_signal_segments(
                resampled_signal, rpeaks, segment_length_sec, new_sampling_rate
            )
            trimmed_signals_1ch_all.append(trimmed_signals)
            if lead == 0:
                # Segment-relative R-peaks of the first lead become the annotations
                trimmed_signals_rpeaks_all = trimmed_signals_rpeaks

        #------------------------------------------------#
        # SAVE TRIMMED ECG SIGNALS TO FILE
        #------------------------------------------------#
        # zip(*...) regroups the per-lead segment lists into one tuple of
        # leads per segment.
        for slice_no, segment_leads in enumerate(zip(*trimmed_signals_1ch_all)):
            # Skip segments that are shorter than requested (the signal ended
            # first). The length of each lead is checked, not the lead count.
            if any(len(lead_segment) != segment_length_samples for lead_segment in segment_leads):
                continue

            segment_name = f'{base_file_name}_segment{slice_no+1}'
            # Keep only characters that are safe in a record name
            segment_name = re.sub(r'[^\w\d\-_]', '', segment_name)
            save_ecg_recording(
                segment_name,
                list(segment_leads),
                trimmed_signals_rpeaks_all[slice_no],
                new_sampling_rate,
                write_dir,
                additional_info,
            )

# Driver: one diagnosis group is processed per run. Only the MI group is
# active here; uncomment the other calls to process NORM and HYP as well.
#print("NORM")
#trim_and_improve_signal(dir_norm, write_dir_norm, "NORM")
#print("----------------------------------------------------------")
print("MI")
trim_and_improve_signal(dir_mi, write_dir_mi, "MI")
#print("----------------------------------------------------------")
#print("HYP")
#trim_and_improve_signal(dir_hyp, write_dir_hyp, "HYP")