In [1]:
import random
from collections.abc import Iterable

import numpy as np
import numpy.typing as npt
import tensorflow as tf
from heartkit.datasets.defines import PatientGenerator, Preprocessor, SampleGenerator

def default_preprocess(x: npt.NDArray) -> npt.NDArray:
    """Return *x* untouched.

    Serves as the no-op fallback when a caller does not supply its own
    preprocessing function.
    """
    unchanged = x
    return unchanged


def _dataset_sample_generator(
    self,
    patient_ids: npt.NDArray,
    samples_per_patient: int | list[int] = 100,
    repeat: bool = True,
    preprocess: Preprocessor | None = None,
) -> SampleGenerator:
    """Internal sample generator for task.

    Builds a patient generator over ``patient_ids``, feeds it to
    ``self.task_data_generator`` and lazily post-processes every sample:
    the features (item 0) are preprocessed and reshaped to
    ``self.spec[0].shape``, and the label (item 1) is one-hot encoded when
    the class map has more than one distinct class, otherwise passed through.

    Args:
        patient_ids (npt.NDArray): Patient IDs
        samples_per_patient (int | list[int], optional): Samples per patient. Defaults to 100.
        repeat (bool, optional): Repeat. Defaults to True.
        preprocess (Preprocessor | None, optional): Function applied to each feature
            array before reshaping. Defaults to None (identity preprocessing).

    Returns:
        SampleGenerator: Task sample generator yielding ``(features, label)`` pairs.
    """
    patient_generator = self.uniform_patient_generator(patient_ids, repeat=repeat)
    data_generator = self.task_data_generator(
        patient_generator,
        samples_per_patient=samples_per_patient,
    )
    preprocess_fn = preprocess if preprocess is not None else default_preprocess
    num_classes = len(set(self.class_map.values()))
    feat_shape = tuple(self.spec[0].shape)

    def _transform(sample):
        """Preprocess/reshape the features and encode the label of one sample."""
        x, y = sample[0], sample[1]
        features = preprocess_fn(x).reshape(feat_shape)
        # With a single class there is nothing to one-hot encode, so the raw
        # label is passed through. (Previously the features were returned here
        # by mistake.)
        label = y if num_classes <= 1 else tf.one_hot(y, num_classes)
        return features, label

    # map() keeps the pipeline lazy: samples are transformed only when consumed.
    return map(_transform, data_generator)

2024-03-14 20:04:44.544139: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-14 20:04:45.316739: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-03-14 20:04:46.709498: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 20:04:46.709550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 20:04:46.922623: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to

In [None]:
def rhythm_data_generator(
    self,
    patient_generator: PatientGenerator,
    samples_per_patient: int | list[int] = 1,
) -> SampleGenerator:
    """Generate frames w/ rhythm labels (e.g. afib) using patient generator.

    For every patient, rhythm episodes (noise excluded) that are at least one
    frame long are grouped by target class. For each class, episodes are
    drawn with probability proportional to their length, a random frame of
    ``input_size`` samples is cut from each drawn episode, and the frames are
    yielded in shuffled order.

    Args:
        patient_generator (PatientGenerator): Patient Generator yielding 2-tuples
            whose second item maps segment keys to segments exposing
            ``"rlabels"`` and ``"data"``; the first item is ignored.
        samples_per_patient (int | list[int], optional): # samples per patient.
            If an int, it is split evenly across the target classes (at least 1
            per class). If a list, entry ``i`` is the number of samples for the
            i-th target label in ascending label order; it must provide one
            entry per class. Defaults to 1.

    Yields:
        tuple[npt.NDArray, int]: Frame (cast to float32, then resampled to
        ``self.target_rate`` when it differs from ``self.sampling_rate``) and
        its class-map label.
    """
    # Target labels in ascending order, so a list-valued samples_per_patient
    # lines up with a well-defined class order.
    tgt_labels = sorted(set(self.class_map.values()))
    num_classes = len(tgt_labels)

    # Convert Icentia labels -> HK labels -> class map labels (-1 indicates not in class map)
    tgt_map = {k: self.class_map.get(v, -1) for (k, v) in HeartRhythmMap.items()}
    # Labels absent from HeartRhythmMap also map to -1 instead of failing on int(None).
    to_target = np.vectorize(lambda lbl: tgt_map.get(lbl, -1), otypes=[int])

    # If samples_per_patient is a list, then it must be the same length as num_classes
    if isinstance(samples_per_patient, Iterable):
        samples_per_tgt = list(samples_per_patient)
    else:
        num_per_tgt = int(max(1, samples_per_patient / num_classes))
        # One entry (num_per_tgt samples) for each target class.
        samples_per_tgt = [num_per_tgt] * num_classes

    # Frame length in source samples; frames are resampled to target_rate on yield.
    input_size = int(np.round((self.sampling_rate / self.target_rate) * self.frame_size))

    # Group patient rhythms by type (segment, start, stop)
    for _, segments in patient_generator:
        # This maps segment index to segment key
        seg_map: list[str] = list(segments.keys())

        # Per target class: rows of [segment index, rhythm start, rhythm end]
        pt_tgt_rows: list[list[list[int]]] = [[] for _ in tgt_labels]
        for seg_idx, seg_key in enumerate(seg_map):
            rlabels = segments[seg_key]["rlabels"][:]

            # Skip if no rhythm labels, or only noise labels
            if not rlabels.shape[0]:
                continue
            rlabels = rlabels[np.where(rlabels[:, 1] != IcentiaRhythm.noise.value)[0]]
            if not rlabels.shape[0]:
                continue

            # Even rows carry an episode's start position and label, odd rows its end position.
            xs, xe, xl = rlabels[0::2, 0], rlabels[1::2, 0], rlabels[0::2, 1]
            xl = to_target(xl)

            # Keep episodes long enough to hold a full frame, bucketed by target class.
            for tgt_idx, tgt_class in enumerate(tgt_labels):
                idxs = np.where((xe - xs >= input_size) & (xl == tgt_class))[0]
                seg_vals = np.column_stack((np.full(idxs.shape[0], seg_idx), xs[idxs], xe[idxs]))
                pt_tgt_rows[tgt_idx] += seg_vals.tolist()
            # END FOR
        # END FOR
        pt_tgt_seg_map = [np.array(rows) for rows in pt_tgt_rows]

        # Grab target segments
        seg_samples: list[tuple[int, int, int, int]] = []
        for tgt_idx, tgt_class in enumerate(tgt_labels):
            tgt_segments = pt_tgt_seg_map[tgt_idx]
            if not tgt_segments.shape[0]:
                continue
            # Longer episodes are proportionally more likely to be drawn.
            tgt_seg_indices: list[int] = random.choices(
                np.arange(tgt_segments.shape[0]),
                weights=tgt_segments[:, 2] - tgt_segments[:, 1],
                k=samples_per_tgt[tgt_idx],
            )
            for tgt_seg_idx in tgt_seg_indices:
                seg_idx, rhy_start, rhy_end = tgt_segments[tgt_seg_idx]
                # Episodes were filtered to length >= input_size, so this range is non-empty.
                frame_start = np.random.randint(rhy_start, rhy_end - input_size + 1)
                frame_end = frame_start + input_size
                seg_samples.append((seg_idx, frame_start, frame_end, tgt_class))
            # END FOR
        # END FOR

        # Shuffle so classes are interleaved rather than emitted in label order
        random.shuffle(seg_samples)

        # Yield selected samples for patient
        for seg_idx, frame_start, frame_end, label in seg_samples:
            x: npt.NDArray = segments[seg_map[seg_idx]]["data"][frame_start:frame_end].astype(np.float32)
            if self.sampling_rate != self.target_rate:
                x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0)
            yield x, label
        # END FOR
    # END FOR

## Compare with the upstream implementation in heartkit/datasets/dataset.py