In [None]:
import os
from typing import List

import numpy as np
import wfdb
from tqdm import tqdm
from wfdb import processing

# Maps each database alias to the local directory that holds its RECORDS
# index file and record files (paths relative to the working directory).
DB_ALIASES = dict(
    AFDB="./mit-bih-atrial-fibrillation-database-1.0.0/files",
    LTAFDB="./long-term-af-database-1.0.0/files",
    NSRDB="./mit-bih-normal-sinus-rhythm-database-1.0.0/files",
)


def clean_annotations(annotations):
    """Drop annotations whose auxiliary note is empty.

    Args:
        annotations: Either an object exposing ``sample`` and ``aux_note``
            attributes (such as the object returned by ``wfdb.rdann``) or a
            mapping with ``"sample"`` and ``"aux_note"`` keys.

    Returns:
        A ``(clean_aux_note, clean_sample)`` tuple of lists. Note the order:
        notes first, sample positions second. Both lists keep only the entries
        whose aux note is not ``""``, in their original order.
    """
    # wfdb annotation objects expose their fields as attributes, not keys;
    # fall back to subscripting so plain mappings keep working.
    if hasattr(annotations, "sample") and hasattr(annotations, "aux_note"):
        sample, aux_note = annotations.sample, annotations.aux_note
    else:
        sample, aux_note = annotations["sample"], annotations["aux_note"]

    kept = [(note, pos) for pos, note in zip(sample, aux_note) if note != ""]
    clean_aux_note = [note for note, _ in kept]
    clean_sample = [pos for _, pos in kept]
    return clean_aux_note, clean_sample


def get_ranges_afib(record_path: str, signal_len: int) -> List[List[int]]:
    """Return the sample ranges annotated as atrial fibrillation in a record.

    Reads the record's ``atr`` annotation file. Only annotations with a
    non-empty aux note are kept. Every annotation whose aux note is
    ``"(AFIB"`` opens a range.

    Args:
        record_path: Path to the record, without file extension.
        signal_len: Number of samples in the record's signal.

    Returns:
        A list of ``[start, end]`` pairs of inclusive sample indices. A range
        ends one sample before the next non-empty aux-note annotation, or at
        the last sample of the signal if it is the final annotation.
    """
    annotations = wfdb.rdann(record_path, "atr")
    # clean_annotations returns (aux_note, sample), in that order.
    aux_note, sample = clean_annotations(annotations)

    ranges_interest = []
    for i, label in enumerate(aux_note):
        if label != "(AFIB":
            continue
        afib_start = sample[i]
        is_last_annotation = i + 1 == len(sample)
        # Inclusive end index, consistent with the `sample[i + 1] - 1` case.
        afib_end = signal_len - 1 if is_last_annotation else sample[i + 1] - 1
        ranges_interest.append([afib_start, afib_end])
    return ranges_interest


def extract_rri(record_ids: List[str], db_alias: str) -> np.ndarray:
    """Extract RR intervals, in milliseconds, for the given records.

    For ``AFDB`` and ``LTAFDB``, only the atrial-fibrillation ranges returned
    by ``get_ranges_afib`` are used. For any other alias, the whole signal
    is used. Beat positions are read from the ``atr`` annotation file for
    ``NSRDB`` and from the ``qrs`` file otherwise.

    Args:
        record_ids: Record names as listed in the database's RECORDS file.
        db_alias: Key of ``DB_ALIASES``; selects the directory and the
            extraction rules.

    Returns:
        A 1-D integer array with the RR intervals of all records and ranges,
        concatenated in order. Values are truncated to whole milliseconds.

    Raises:
        KeyError: If ``db_alias`` is not in ``DB_ALIASES``.
    """
    rri_chunks = []
    beat_extension = "atr" if db_alias == "NSRDB" else "qrs"
    use_afib_ranges = db_alias in ("AFDB", "LTAFDB")

    for record_id in tqdm(record_ids, desc=db_alias):
        record_path = os.path.join(DB_ALIASES[db_alias], record_id)

        # Only the signal length is needed; reading the header avoids loading
        # the entire signal into memory as rdsamp would.
        signal_len = wfdb.rdheader(record_path).sig_len

        if use_afib_ranges:
            extract_intervals = get_ranges_afib(record_path, signal_len)
        else:
            extract_intervals = [[0, signal_len - 1]]

        for start_index, end_index in extract_intervals:
            ann = wfdb.rdann(
                record_path,
                sampfrom=start_index,
                sampto=end_index,
                extension=beat_extension)
            ann_ms = (ann.sample / ann.fs) * 1000
            # With calc_rr's default units no conversion is applied, so the
            # intervals keep the millisecond units of ann_ms.
            rri_chunks.append(processing.calc_rr(ann_ms, fs=ann.fs))

    # Concatenate once at the end instead of re-copying the array every loop.
    if not rri_chunks:
        return np.array([], dtype=int)
    return np.concatenate(rri_chunks).astype(int)


def get_record_ids(db_alias: str) -> List[str]:
    """Read the record names listed in a database's RECORDS file.

    Args:
        db_alias: Key of ``DB_ALIASES`` whose directory contains RECORDS.

    Returns:
        Record names in file order, stripped of whitespace, with blank lines
        skipped. For ``AFDB``, records ``00735`` and ``03665`` are excluded.

    Raises:
        KeyError: If ``db_alias`` is not in ``DB_ALIASES``.
        OSError: If the RECORDS file cannot be opened.
    """
    records_file = os.path.join(DB_ALIASES[db_alias], "RECORDS")
    with open(records_file) as f:
        record_ids = [line.strip() for line in f if line.strip()]

    if db_alias == "AFDB":
        # NOTE(review): presumably excluded because these records have no
        # usable signal files — confirm. Filtering (rather than list.remove)
        # does not raise if an id is absent from RECORDS.
        excluded = {"00735", "03665"}
        record_ids = [rid for rid in record_ids if rid not in excluded]

    return record_ids

In [None]:
# Extract RR intervals for every configured database and save one .npy per alias.
directory = "./output"
os.makedirs(directory, exist_ok=True)

for db_alias in DB_ALIASES:
    record_ids = get_record_ids(db_alias)

    # extract_rri needs the alias to pick the data directory, the annotation
    # extension and whether to restrict extraction to AF ranges.
    rri_output = extract_rri(record_ids, db_alias)

    np.save(os.path.join(directory, f"{db_alias}.npy"), rri_output)