In [7]:
import os
from typing import List, Optional

import numpy as np
import wfdb
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

DB_ALIASES = {
    "AFDB": "../mit-bih-atrial-fibrillation-database-1.0.0/files",
    "LTAFDB": "../long-term-af-database-1.0.0/files",
    "NSRDB": "../mit-bih-normal-sinus-rhythm-database-1.0.0/files",
}

directory = "./segments_bins"


def clean_annotation(annotation):
    sample, aux_note = annotation.sample, annotation.aux_note
    non_empty_indices = [i for i, note in enumerate(aux_note) if note != ""]
    clean_aux_note = [aux_note[i] for i in non_empty_indices]
    clean_sample = [sample[i] for i in non_empty_indices]
    return clean_sample, clean_aux_note


def get_ranges_afib(record_path: str, signal_len: int) -> List[List[int]]:
    annotation = wfdb.rdann(record_path, "atr")
    sample, aux_note = clean_annotation(annotation)
    ranges_interest = []
    for i, label in enumerate(aux_note):
        if label == "(AFIB":
            afib_start = sample[i]
            last_notation = len(sample) == (i + 1)
            afib_end = signal_len if last_notation else sample[i + 1] - 1
            ranges_interest.append([afib_start, afib_end])
    return ranges_interest


def cut_array(array_rri: List[int], segment_len: int) -> np.ndarray:
    num_segments = len(array_rri) // segment_len

    if num_segments <= 0:
        return np.empty((0, 0))

    segments = np.array(
        [
            array_rri[i : i + segment_len]
            for i in range(0, num_segments * segment_len, segment_len)
        ]
    )
    return segments


def get_random_samples_by_label(
    base: np.ndarray,
    label: int,
    n_samples: Optional[int] = 0,
) -> np.ndarray:
    samples = base[base[:, -1] == label]

    if n_samples:
        indexes = np.random.choice(samples.shape[0], n_samples, replace=False)
        return samples[indexes]

    return samples


def get_random_samples(base: np.ndarray, qtd_segments: int) -> np.ndarray:
    qtd_segments = int(qtd_segments / 2)
    negatives = get_random_samples_by_label(base=base, label=0, n_samples=qtd_segments)
    positives = get_random_samples_by_label(base=base, label=1, n_samples=qtd_segments)
    base_ready = np.vstack((positives, negatives))
    return base_ready


def get_record_ids(db_alias: str) -> List[str]:
    with open(f"{DB_ALIASES[db_alias]}/RECORDS") as f:
        lines = f.readlines()
        record_ids = [line.strip() for line in lines]

    if db_alias == "AFDB":
        record_ids.remove("00735")
        record_ids.remove("03665")

    return record_ids


def extract_segments_bins(params) -> np.ndarray:
    segment_size = params.get("segment_size")
    m_bins = params.get("m_bins")
    range_hist = params.get("range_hist")
    qtd_segments = params.get("qtd_segments")

    if not os.path.exists(directory):
        os.makedirs(directory)

    filename = f"mb{m_bins}_ss{segment_size}_rh{range_hist[0]}-{range_hist[1]}.npy"

    path = os.path.join(directory, filename)

    if os.path.isfile(path):
        segment_bins = np.load(path)
        return get_random_samples(segment_bins, qtd_segments)

    base = np.empty((0, m_bins + 1))

    for db_alias in DB_ALIASES.keys():
        label = 0 if db_alias == "NSRDB" else 1
        extension_signal = "atr" if db_alias == "NSRDB" else "qrs"

        record_ids = get_record_ids(db_alias)

        for record_id in tqdm(record_ids):
            record_path = os.path.join(DB_ALIASES[db_alias], record_id)

            _, ecg_metadata = wfdb.rdsamp(record_path)
            signal_len = ecg_metadata["sig_len"]

            extract_intervals = (
                get_ranges_afib(record_path, signal_len)
                if db_alias in ["AFDB", "LTAFDB"]
                else [[0, signal_len - 1]]
            )

            for start_index, end_index in extract_intervals:
                print(record_path, start_index, end_index, extension_signal)
                ann = wfdb.rdann(
                    record_path,
                    sampfrom=start_index,
                    sampto=end_index,
                    extension=extension_signal,
                )

                ann_ms = (ann.sample / ann.fs) * 1000

                rr_interval = wfdb.processing.calc_rr(ann_ms, fs=ann.fs)

                segments = cut_array(rr_interval, segment_size)

                for segment in segments:
                    hist, _ = np.histogram(a=segment, range=range_hist, bins=m_bins)
                    x = hist
                    y = label
                    row = np.append(x, y)
                    row = row.reshape((1, row.shape[0]))
                    base = np.vstack((base, row))

    num_rows = base.shape[0]
    permutation = np.random.permutation(num_rows)
    shuffled_base = base[permutation].copy()

    np.save(path, shuffled_base)

    return get_random_samples(shuffled_base, qtd_segments)

In [8]:
params = {
    "m_bins": 40,
    "qtd_segments": 1000,
    "segment_size": 50,
    "range_hist": (0, 2500),
}
base = extract_segments_bins(params)

  0%|          | 0/23 [00:01<?, ?it/s]

../mit-bih-atrial-fibrillation-database-1.0.0/files/04015 102584 119603 ('qrs',)





TypeError: can only concatenate str (not "tuple") to str