In [2]:
import os
from typing import List
import numpy as np
import wfdb
from wfdb import processing
import numpy as np

from tqdm import tqdm
import warnings
import numpy as np
from scipy.signal import filtfilt
from scipy import signal
import itertools

warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

DB_ALIASES = {
    "AFDB": "../mit-bih-atrial-fibrillation-database-1.0.0/files",
    "LTAFDB": "../long-term-af-database-1.0.0/files",
    "NSRDB": "../mit-bih-normal-sinus-rhythm-database-1.0.0/files",
}


def clean_annotation(annotation):
    sample, aux_note = annotation.sample, annotation.aux_note
    non_empty_indices = [i for i, note in enumerate(aux_note) if note != ""]
    clean_aux_note = [aux_note[i] for i in non_empty_indices]
    clean_sample = [sample[i] for i in non_empty_indices]
    return clean_sample, clean_aux_note


def get_ranges_afib(record_path: str, signal_len: int) -> List[List[int]]:
    annotation = wfdb.rdann(record_path, "atr")
    sample, aux_note = clean_annotation(annotation)
    ranges_interest = []
    for i, label in enumerate(aux_note):
        if label == "(AFIB":
            afib_start = sample[i]
            last_notation = len(sample) == (i + 1)
            afib_end = signal_len if last_notation else sample[i + 1] - 1
            ranges_interest.append([afib_start, afib_end])
    return ranges_interest


def cut_array(array_rri: List[int], segment_len: int) -> np.ndarray:
    num_segments = len(array_rri) // segment_len

    if num_segments <= 0:
        return np.empty((0, 0))

    segments = np.array(
        [
            array_rri[i : i + segment_len]
            for i in range(0, num_segments * segment_len, segment_len)
        ]
    )
    return segments


def get_record_ids(db_alias: str) -> List[str]:
    with open(f"{DB_ALIASES[db_alias]}/RECORDS") as f:
        lines = f.readlines()
        record_ids = [line.strip() for line in lines]

    if db_alias == "AFDB":
        record_ids.remove("00735")
        record_ids.remove("03665")

    return record_ids


def butter_bandpass(lowcut, highcut, fs, order=5):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = signal.butter(order, [low, high], btype="band")
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = filtfilt(b, a, data)
    return y


def filter_signal(ecg1):
    lowcut = 0.05
    highcut = 40.0
    fs = 250
    order = 4

    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist

    b, a = signal.butter(order, [low, high], btype="band")

    filtered_signal = signal.lfilter(b, a, ecg1)

    return filtered_signal

In [5]:
segment_size = 50
m_bins = 50
range_hist = (0, 2500)

base_pos = np.empty((0, 2, m_bins))
base_neg = np.empty((0, 2, m_bins))

for db_alias in DB_ALIASES.keys():
    label = 0 if db_alias == "NSRDB" else 1
    extension_signal = "atr" if db_alias == "NSRDB" else "qrs"

    record_ids = get_record_ids(db_alias)

    for record_id in tqdm(record_ids):
        record_path = os.path.join(DB_ALIASES[db_alias], record_id)

        _, ecg_metadata = wfdb.rdsamp(record_path)
        signal_len = ecg_metadata["sig_len"]

        extract_intervals = (
            get_ranges_afib(record_path, signal_len)
            if db_alias in ["AFDB", "LTAFDB"]
            else [[0, signal_len - 1]]
        )

        rec = wfdb.rdrecord(record_name=record_path)

        for start_index, end_index in extract_intervals:
            ann = wfdb.rdann(
                record_path,
                sampfrom=start_index,
                sampto=end_index,
                extension=extension_signal,
            )
            
            ann_ms = (ann.sample / ann.fs) * 1000

            rr_interval = wfdb.processing.calc_rr(ann_ms, fs=ann.fs)

            num_segments = len(rr_interval) // segment_size
            if num_segments <= 0:
                pass

            last_segment = num_segments * segment_size

            rr_histogram = np.array([])
            
            num_segments = len(ann.sample) // segment_size

            if num_segments <= 0:
                pass

            last_segment = num_segments * segment_size

            tq_histogram = np.array([])

            for i in range(0, last_segment, segment_size):
                start_segment_rr = i
                end_segment_rr = i + segment_size
                rr_segment = rr_interval[start_segment_rr:end_segment_rr]

                # Calc Histogram RR
                rr_histogram, _ = np.histogram(a=rr_segment, range=range_hist, bins=m_bins)
                
                tq_segments = []

                r_peaks_segment = [i for i in ann.sample][i : i + segment_size]

                r_peaks_size = len(r_peaks_segment)

                for j, r_peak in enumerate(r_peaks_segment):
                    if j in [0, r_peaks_size - 1]:
                        continue

                    tq_start = r_peak + int(0.1 * ann.fs)
                    tq_end = tq_start + int(0.25 * ann.fs)

                    if tq_end > r_peaks_segment[j + 1]:
                        continue

                    tq_segment = rec.p_signal[:, 0][tq_start:tq_end]

                    tq_segment = filter_signal(tq_segment)

                    tq_segments.append(tq_segment)

                tq_segments = list(itertools.chain.from_iterable(tq_segments))

                tq_histogram, _ = np.histogram(a=tq_segments, range=(-2, 2), bins=50)

                if rr_histogram.shape[0] == 0 or tq_histogram.shape[0] == 0:
                    continue

                sample = np.vstack((rr_histogram, tq_histogram))

                sample = sample.reshape((1, sample.shape[0], sample.shape[1]))

                if label:
                    base_pos = np.vstack((base_pos, sample))
                else:
                    base_neg = np.vstack((base_neg, sample))

directory = "./samples"
np.save(os.path.join(directory, "base_pos"), base_pos)
np.save(os.path.join(directory, "base_neg"), base_neg)

  0%|          | 0/23 [00:00<?, ?it/s]

100%|██████████| 23/23 [05:14<00:00, 13.69s/it]
100%|██████████| 84/84 [1:35:05<00:00, 67.92s/it]   
100%|██████████| 18/18 [12:14<00:00, 40.81s/it]


In [7]:
base_pos.shape

(70269, 2, 50)