In [1]:
import os
from operator import attrgetter
from typing import List
from tqdm import tqdm
import neurokit2 as nk
import numpy as np
import wfdb

DB_ALIASES = {
    "AFIB": "./mit-bih-atrial-fibrillation-database-1.0.0/files",
    "LTAFDB": "./long-term-af-database-1.0.0/files",
    "NSRDB": "./mit-bih-normal-sinus-rhythm-database-1.0.0/files",
}

def get_records_id(path: str) -> List[str]:
    """method to get all available record ids from AFDB

    Args:
        path (str): the path of MIT-BIH AFDB files

    Returns:
        List[str]: a list of string containing all records id's
    """
    with open(f"{path}/RECORDS") as f:
        lines = f.readlines()
    return list(map(lambda line: line.strip(), lines))


def extract_rri_signal(ecg_r_peaks: np.ndarray, signal_lead_size: int) -> List[float]:
    """method to get IRR intervals from signal lead and founded R-peaks indices

    Args:
        ecg_r_peaks (np.ndarray): numpy array containing found R-peaks indices
        signal_lead_size (int): a size of signal recording coming from an ECG lead

    Returns:
        List[float]: a list of floats, is the RRIs extracted from signal
    """
    rri_signal = []
    for i in range(ecg_r_peaks.size - 1):
        rri_beat = int(ecg_r_peaks[i + 1] - ecg_r_peaks[i])
        rri_signal.append(rri_beat)
    return rri_signal


def get_intervals_afib(
    sample: List[int], aux_note: List[str], signal_len: int
) -> List[List[int]]:
    """
    Get the intervals of atrial fibrillation (AFIB) from a list of sample values and corresponding annotations.

    Args:
    - sample (List[int]): A list of ECG sample values.
    - aux_note (List[str]): A list of annotation labels for each sample.

    Returns:
    - afib_intervals (List[List[int]]): A list of start and end indices for each interval of AFIB.
    """
    afib_intervals = []
    for i, label in enumerate(aux_note):
        if label == "(AFIB":
            afib_start = sample[i]
            last_notation = len(sample) == (i + 1)
            afib_end = signal_len if last_notation else sample[i + 1] - 1
            afib_intervals.append([afib_start, afib_end])
    return afib_intervals


def resample_ms(rri_signal: List[float], freq: float) -> List[float]:
    """Resample an RRI signal to a specific frequency in milliseconds.

    Args:
        rri_signal (List[float]): List of RRI values to be resampled.
        freq (float): Frequency in Hz to which the RRI signal will be resampled.

    Returns:
        List[float]: List of resampled RRI values in milliseconds.
    """
    MILLISECONDS = 1000
    return [(MILLISECONDS / freq) * rri for rri in rri_signal]


def extract_rri(db: str) -> np.ndarray:
    """Extract RRI signals from ECG recordings in a PhysioNet database, 
    resample the signals to a specific frequency in milliseconds, and 
    concatenate them into a single numpy array.

    Args:
        db (str): The name of the PhysioNet database to extract RRI 
        signals from. Valid values are "LTAFDB", "AFIB", and "NSRDB".

    Returns:
        np.ndarray: A 1-dimensional numpy array containing the 
        concatenated RRI signals in milliseconds.
    """
    path_db = DB_ALIASES[db]

    record_ids = get_records_id(path_db)

    if db == "AFIB":
        record_ids.remove("00735")
        record_ids.remove("03665")

    rri_output = np.array([])

    print(path_db)
    for record_id in tqdm(record_ids):
        record_path = os.path.join(path_db, record_id)

        ecg_signal, ecg_metadata = wfdb.rdsamp(record_path)
        signal_len = ecg_metadata['sig_len']
        sampling_rate = ecg_metadata['fs']

        sample, aux_note = attrgetter("sample", "aux_note")(
            wfdb.rdann(record_path, "atr")
        )

        extract_intervals = list()
        if db in ["AFIB", "LTAFB"]:
            extract_intervals = get_intervals_afib(sample, aux_note, signal_len)
        else:
            extract_intervals = [[0, signal_len - 1]]

        lead_signal = ecg_signal[:, 1]

        for start_index, end_index in extract_intervals:
            signal = lead_signal[start_index:end_index]
            _, rpeaks = nk.ecg_peaks(signal, sampling_rate=sampling_rate)
            rri_signal = extract_rri_signal(rpeaks["ECG_R_Peaks"], signal_len)
            rri_signal_ms = resample_ms(rri_signal, sampling_rate)
            rri_output = np.concatenate((rri_output, rri_signal_ms))

    globals().update(locals())
    print(f'results: {rri_output.shape[0]} RRIs')
    print('\n\n---\n\n')

    return rri_output

def save_output(result: np.ndarray, db: str) -> None:
    """Save a numpy array to a file in the "./output" directory with the 
    RRI extracted and ".npy" extension.

    Args:
        result (np.ndarray): A numpy array containing the output to be saved.
        db (str): The name of the PhysioNet database to extract RRI 
        signals from. Valid values are "LTAFDB", "AFIB", and "NSRDB".

    Returns:
        None
    """
    directory = "./output"
    if not os.path.exists(directory):
        os.makedirs(directory)
    np.save(os.path.join(directory, f"{db}.npy"), result)

In [2]:
for db_alias in DB_ALIASES.keys():
    rri_output = extract_rri(db_alias)
    save_output(rri_output, db_alias)

./mit-bih-atrial-fibrillation-database-1.0.0/files


100%|██████████| 23/23 [00:34<00:00,  1.48s/it]


results: 512257 RRIs


---


./long-term-af-database-1.0.0/files


100%|██████████| 84/84 [06:13<00:00,  4.45s/it]


results: 8692884 RRIs


---


./mit-bih-normal-sinus-rhythm-database-1.0.0/files


100%|██████████| 18/18 [01:26<00:00,  4.79s/it]

results: 2050611 RRIs


---





