In [1]:
import os
import tensorflow as tf

from heartkit.tasks import TaskFactory
from typing import Type, TypeVar
from argdantic import ArgField, ArgParser
from pydantic import BaseModel
from heartkit.utils import env_flag, set_random_seed, setup_logger

from heartkit.tasks.AFIB_Ident.utils import (
    create_model,
    load_datasets,
    load_test_datasets,
    load_train_datasets,
    prepare,
)

from heartkit.defines import (
    HKDemoParams
)
from heartkit.tasks.AFIB_Ident.defines import (
    get_class_mapping,
    get_class_names,
    get_class_shape,
    get_classes,
    get_feat_shape,
)

cli = ArgParser()
B = TypeVar("B", bound=BaseModel)


def parse_content(cls: Type[B], content: str) -> B:
    """Parse file or raw content into Pydantic model.

    Args:
        cls (B): Pydantic model subclasss
        content (str): File path or raw content

    Returns:
        B: Pydantic model subclass instance
    """
    if os.path.isfile(content):
        with open(content, "r", encoding="utf-8") as f:
            content = f.read()

    return cls.model_validate_json(json_data=content)


config = 'configs/arrhythmia-100class-2.json'
params = parse_content(HKDemoParams, config)


params.seed = set_random_seed(params.seed)
params.data_parallelism = 8

class_names = get_class_names(params.num_classes)
class_map = get_class_mapping(params.num_classes)
input_spec = (
    tf.TensorSpec(shape=get_feat_shape(params.frame_size), dtype=tf.float32),
    tf.TensorSpec(shape=get_class_shape(params.frame_size, params.num_classes), dtype=tf.int32),
)

datasets = load_datasets(
    ds_path=params.ds_path,
    frame_size=params.frame_size,
    sampling_rate=params.sampling_rate,
    class_map=class_map,
    spec=input_spec,
    datasets=params.datasets,
)

# this is where they get the test signal and the label
test_x, test_y = load_test_datasets(datasets=datasets, params=params)

2024-03-15 16:22:22.065647: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-15 16:22:22.069261: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-15 16:22:22.106207: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-15 16:22:22.106235: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-15 16:22:22.107529: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

In [None]:
patient_ids = datasets[0].get_test_patient_ids()
signal_label = next(datasets[0].signal_label_generator(datasets[0].uniform_patient_generator(patient_ids=patient_ids, repeat=False)))
x = signal_label[0]
x.shape[0]

400

In [3]:
import numpy as np

patient_ids = datasets[0].get_test_patient_ids()
pat_gen = datasets[0].uniform_patient_generator(patient_ids=patient_ids, repeat=False)
first_pat = next(pat_gen)

segment = first_pat[1][np.random.choice(list(first_pat[1].keys()))]
# rlabels are the rhythm type
segment["rlabels"]

<HDF5 dataset "rlabels": shape (20, 2), type "<i4">

In [None]:
rlabels = segment["rlabels"][:]
rlabels
xs, xe, xl = rlabels[0::2, 0], rlabels[1::2, 0], rlabels[0::2, 1]

In [12]:
from enum import IntEnum
from heartkit.defines import (
    HKDemoParams, HeartBeat, HeartRate, HeartRhythm, HeartSegment
)

from heartkit.tasks.AFIB_Ident.defines import (
    get_class_mapping,
    get_class_names,
    get_class_shape,
    get_classes,
    get_feat_shape,
)

class IcentiaRhythm(IntEnum):
    """Icentia rhythm labels"""
    noise = 0
    normal = 1
    afib = 2
    aflut = 3
    end = 4

HeartRhythmMap = {
    IcentiaRhythm.noise: HeartRhythm.noise,
    IcentiaRhythm.normal: HeartRhythm.normal,
    IcentiaRhythm.afib: HeartRhythm.afib,
    IcentiaRhythm.aflut: HeartRhythm.aflut,
    IcentiaRhythm.end: HeartRhythm.noise,
}


tgt_labels = list(set(class_map.values()))
seg_map: list[str] = list(segment.keys()) # blabel, data, rlabel
pt_tgt_seg_map = [[] for _ in tgt_labels]
pt_tgt_seg_map

class_map = get_class_mapping(2)
tgt_map = {k: class_map.get(v, -1) for (k, v) in HeartRhythmMap.items()}
input_size = 400
# Grab rhythm labels
rlabels = segment["rlabels"][:]

# Skip if no rhythm labels
if not rlabels.shape[0]:
    print("No rlabel")
rlabels = rlabels[np.where(rlabels[:, 1] != IcentiaRhythm.noise.value)[0]]
if not rlabels.shape[0]:
    print("Only noise")

# Unpack start, end, and label
xs, xe, xl = rlabels[0::2, 0], rlabels[1::2, 0], rlabels[0::2, 1]

print(xs, xe)
# Map labels to target labels
xl = np.vectorize(tgt_map.get, otypes=[int])(xl)

# Capture segment, start, and end for each target label
for tgt_idx, tgt_class in enumerate(tgt_labels):
    idxs = np.where((xe - xs >= input_size) & (xl == tgt_class))
    seg_vals = np.vstack((0 * np.ones_like(idxs), xs[idxs], xe[idxs])).T
    pt_tgt_seg_map[tgt_idx] += seg_vals.tolist()
# END FOR


[    120  358662  367313  384920  832166  872517  929152  942302  980922
 1004133] [ 357063  365728  383353  830248  870657  927440  940660  961648  998137
 1048429]


In [14]:
num_classes = 2
samples_per_patient = [25, 200][0]
num_per_tgt = int(max(1, samples_per_patient / num_classes))
samples_per_tgt = num_per_tgt * [num_classes]

In [15]:
import random

pt_tgt_seg_map = [np.array(b) for b in pt_tgt_seg_map]

# Grab target segments
seg_samples: list[tuple[int, int, int, int]] = []
for tgt_idx, tgt_class in enumerate(tgt_labels):
    tgt_segments = pt_tgt_seg_map[tgt_idx]
    if not tgt_segments.shape[0]:
        continue
    tgt_seg_indices: list[int] = random.choices(
        np.arange(tgt_segments.shape[0]),
        weights=tgt_segments[:, 2] - tgt_segments[:, 1],
        k=samples_per_tgt[tgt_idx],
    )
    for tgt_seg_idx in tgt_seg_indices:
        seg_idx, rhy_start, rhy_end = tgt_segments[tgt_seg_idx]
        frame_start = np.random.randint(rhy_start, rhy_end - input_size + 1)
        frame_end = frame_start + input_size
        seg_samples.append((seg_idx, frame_start, frame_end, tgt_class))

In [16]:
seg_samples

[(0, 1022, 1422, 0), (0, 86916, 87316, 0)]

In [17]:
tgt_labels

[0, 1]

In [None]:
random.shuffle(seg_samples)

# Yield selected samples for patient
for seg_idx, frame_start, frame_end, label in seg_samples:
    x: npt.NDArray = segments[seg_map[seg_idx]]["data"][frame_start:frame_end].astype(np.float32)
    if self.sampling_rate != self.target_rate:
        x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0)
    yield x, label

In [None]:
# Group patient rhythms by type (segment, start, stop, delta)
for _, segments in patient_generator:
    # This maps segment index to segment key
    seg_map: list[str] = list(segments.keys())

    pt_tgt_seg_map = [[] for _ in tgt_labels]
    for seg_idx, seg_key in enumerate(seg_map):
        # Grab rhythm labels
        rlabels = segments[seg_key]["rlabels"][:]
        # Unpack start, end, and label
        xs, xe, xl = rlabels[0::2, 0], rlabels[1::2, 0], rlabels[0::2, 1]

        # Capture segment, start, and end for each target label
        for tgt_idx, tgt_class in enumerate(tgt_labels):
            idxs = np.where((xe - xs >= input_size) & (xl == tgt_class))
            seg_vals = np.vstack((seg_idx * np.ones_like(idxs), xs[idxs], xe[idxs])).T
            pt_tgt_seg_map[tgt_idx] += seg_vals.tolist()
        # END FOR
    # END FOR
    pt_tgt_seg_map = [np.array(b) for b in pt_tgt_seg_map]

    # Grab target segments
    seg_samples: list[tuple[int, int, int, int]] = []
    for tgt_idx, tgt_class in enumerate(tgt_labels):
        tgt_segments = pt_tgt_seg_map[tgt_idx]
        if not tgt_segments.shape[0]:
            continue
        tgt_seg_indices: list[int] = random.choices(
            np.arange(tgt_segments.shape[0]),
            weights=tgt_segments[:, 2] - tgt_segments[:, 1],
            k=samples_per_tgt[tgt_idx],
        )
        for tgt_seg_idx in tgt_seg_indices:
            seg_idx, rhy_start, rhy_end = tgt_segments[tgt_seg_idx]
            frame_start = np.random.randint(rhy_start, rhy_end - input_size + 1)
            frame_end = frame_start + input_size
            seg_samples.append((seg_idx, frame_start, frame_end, tgt_class))
        # END FOR
    # END FOR

    # Shuffle segments
    random.shuffle(seg_samples)

    # Yield selected samples for patient
    for seg_idx, frame_start, frame_end, label in seg_samples:
        x: npt.NDArray = segments[seg_map[seg_idx]]["data"][frame_start:frame_end].astype(np.float32)
        if self.sampling_rate != self.target_rate:
            x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0)
        yield x, label
    # END FOR
# END FOR