In [14]:
import os
from typing import List
from tqdm import tqdm
import numpy as np
import wfdb
from operator import attrgetter

# Short database alias -> local folder holding that database's files.
# Each folder is expected to contain a RECORDS index file, which
# extract_rri reads to enumerate the records.
# Insertion order matters: the export loop iterates the keys in this order.
DB_ALIASES = dict(
    AFDB="./mit-bih-atrial-fibrillation-database-1.0.0/files",
    LTAFDB="./long-term-af-database-1.0.0/files",
    NSRDB="./mit-bih-normal-sinus-rhythm-database-1.0.0/files",
)

def get_ranges_interest(
    record_path, code
) -> List[List[int]]:
    """
    Locate the sample ranges of a record that carry a given annotation label.

    The record's ``atr`` annotation file is read. Every annotation whose
    ``aux_note`` equals ``code`` opens a range that starts at that annotation's
    sample index and ends one sample before the next annotation, or at the
    signal length when it is the last annotation in the file.

    Args:
    - record_path (str): Path of the WFDB record, without file extension.
    - code (str): The ``aux_note`` label to search for (e.g. ``'(AFIB'``).
      The comparison is an exact string match.

    Returns:
    - ranges_interest (List[List[int]]): One ``[start, end]`` pair of sample
      indices per matching annotation, in annotation order. Empty when no
      annotation matches.
    """
    # Only the signal length is needed, so read the header rather than loading
    # every sample of the recording.
    signal_len = wfdb.rdheader(record_path).sig_len
    if signal_len is None:
        # The header does not state the length; fall back to reading the signal.
        _, ecg_metadata = wfdb.rdsamp(record_path)
        signal_len = ecg_metadata['sig_len']

    annotation = wfdb.rdann(record_path, "atr")
    sample, aux_note = annotation.sample, annotation.aux_note
    last_index = len(sample) - 1

    ranges_interest = []
    for i, label in enumerate(aux_note):
        if label != code:
            continue
        # The range ends one sample before the next annotation; the final
        # annotation's range runs to the end of the signal.
        range_end = signal_len if i == last_index else sample[i + 1] - 1
        ranges_interest.append([sample[i], range_end])
    return ranges_interest


def extract_rri(db: str) -> np.ndarray:
    """Extract RR intervals from the ECG recordings of a PhysioNet database.

    Record names are read from the database's ``RECORDS`` file. For each
    record, the ranges returned by ``get_ranges_interest`` are looked up
    (label ``'(AFIB'`` for AFDB and LTAFDB, ``'(N'`` for any other alias), the
    ``qrs`` annotations inside each range are read, and the RR intervals
    computed from them are appended to the output.

    Args:
        db (str): Alias of the database; must be a key of ``DB_ALIASES``.

    Returns:
        np.ndarray: A 1-dimensional integer array with the RR intervals of all
        ranges of all records concatenated (empty if nothing matched).
        NOTE(review): no unit conversion is applied here; ``calc_rr`` is called
        with its default ``rr_units``, so the values are presumably in samples
        rather than milliseconds -- confirm the unit expected downstream.

    Raises:
        KeyError: If ``db`` is not a key of ``DB_ALIASES``.
    """
    # A plain `import wfdb` is not guaranteed to load the processing
    # subpackage, so import it explicitly before use.
    from wfdb import processing as wfdb_processing

    path_db = DB_ALIASES[db]

    with open(os.path.join(path_db, "RECORDS")) as f:
        # Skip blank lines so they do not turn into empty record names.
        record_ids = [line.strip() for line in f if line.strip()]

    if db == "AFDB":
        # These two records are skipped -- presumably because they ship
        # without signal files; verify against the database description.
        excluded = {"00735", "03665"}
        record_ids = [rid for rid in record_ids if rid not in excluded]

    code = '(AFIB' if db in ("AFDB", "LTAFDB") else '(N'

    # Seed with an empty float array so the final concatenate also works when
    # no interval was found, and so the dtype handling matches a float
    # accumulator. Collecting chunks avoids re-copying the output every step.
    rr_chunks = [np.array([])]

    for record_id in tqdm(record_ids):
        record_path = os.path.join(path_db, record_id)
        for start_index, end_index in get_ranges_interest(record_path, code):
            ann = wfdb.rdann(record_path, sampfrom=start_index, sampto=end_index, extension='qrs')
            rr_chunks.append(wfdb_processing.calc_rr(ann.sample, fs=ann.fs))

    return np.concatenate(rr_chunks).astype(int)
    

In [10]:
# Export the extracted RR intervals as one .npy file per database alias.
directory = "./output"
# exist_ok makes the call a no-op when the folder is already there, so no
# separate existence check is needed.
os.makedirs(directory, exist_ok=True)

for db_alias in DB_ALIASES.keys():
    rri_output = extract_rri(db_alias)
    np.save(os.path.join(directory, f"{db_alias}.npy"), rri_output)
    # NOTE(review): this break stops the loop after the first alias, so only
    # one database is exported per run -- confirm whether that is intended.
    break

100%|██████████| 23/23 [01:05<00:00,  2.83s/it]
