## Downsampling is performed by the `rhythm_data_generator()` function in `datasets/icentia11k`

You can think of `SampleGenerator` as the downsampled counterpart of `PatientGenerator`.

**Workflow of `rhythm_data_generator` Function:**

1. **Initialize Target Labels and Mapping**
   - Extract target labels from class map.
   - Create a mapping from Icentia labels to HeartRhythmMap, then convert them to class map labels (-1 indicates not in class map).

2. **Process Patient Segments**
    - For each patient:
        * Obtain segments and metadata (input_size, sampling_rate, target_rate)

3. **Group Rhythms by Type**
    - Organize rhythm information into pt_tgt_seg_map based on segment type (start index, end index etc.)

4. **Handle Samples Per Patient Input** 
    - Calculate samples_per_tgt list based on input samples_per_patient.

5. **Shuffle Segment Information**
    - Randomly shuffle the selected samples (`seg_samples`) so that classes are interleaved before they are yielded.

6. **Yield Selected Samples for Patient**
    - For every patient:
       * Iterate through segmented information.
       * Downsample per class using samples_per_tgt.
       * Randomly pick the start index and the corresponding end index for every sample.
       * Yield x (data) and label.

In [None]:
import h5py
import numpy.typing as npt
from typing import Callable, Generator
import numpy as np
from enum import IntEnum
import random

# Generator yielding (int, h5py.Group | None) pairs.
# NOTE(review): presumably (patient id, that patient's HDF5 group) — confirm against the dataset loader.
PatientGenerator = Generator[tuple[int, h5py.Group | None], None, None]
# Generator yielding pairs of NumPy arrays; rhythm_data_generator yields (signal frame, label).
SampleGenerator = Generator[tuple[npt.NDArray, npt.NDArray], None, None]
# Callable that takes one NumPy array and returns one NumPy array.
Preprocessor = Callable[[npt.NDArray], npt.NDArray]
from heartkit.defines import (
    HKDemoParams, HeartBeat, HeartRate, HeartRhythm, HeartSegment
)
class IcentiaRhythm(IntEnum):
    """Icentia rhythm labels

    Integer codes compared against the label column of a segment's
    ``rlabels`` array and translated to HeartKit labels via ``HeartRhythmMap``.
    """
    noise = 0
    normal = 1
    afib = 2
    aflut = 3
    # NOTE(review): presumably marks the end of a rhythm episode — confirm
    # against the Icentia11k annotation format. It is mapped to noise.
    end = 4

# Translation table from Icentia rhythm codes to HeartKit `HeartRhythm` labels.
# `end` has no dedicated HeartKit label here and is folded into `noise`.
HeartRhythmMap = {
    IcentiaRhythm.noise: HeartRhythm.noise,
    IcentiaRhythm.normal: HeartRhythm.normal,
    IcentiaRhythm.afib: HeartRhythm.afib,
    IcentiaRhythm.aflut: HeartRhythm.aflut,
    IcentiaRhythm.end: HeartRhythm.noise,
}

from collections.abc import Iterable


def rhythm_data_generator(
    self,
    patient_generator: PatientGenerator,
    samples_per_patient: int | list[int] = 1,
) -> SampleGenerator:
    """Generate frames w/ rhythm labels (e.g. afib) using patient generator.

    For every patient, rhythm episodes that are at least one input frame long
    are grouped by target class. A fixed number of frames is then drawn per
    class (episodes are picked with probability proportional to their length,
    and the frame position inside the episode is uniform), the drawn frames
    are shuffled, and each one is read, optionally resampled, and yielded.

    Args:
        patient_generator (PatientGenerator): Patient generator yielding
            ``(_, segments)`` pairs, where ``segments`` maps segment keys to
            groups exposing ``"rlabels"`` and ``"data"``.
        samples_per_patient (int | list[int], optional): # samples per patient.
            An int is split evenly across classes (at least 1 per class).
            A list gives the number of samples per class, indexed in the
            order of the sorted class-map labels. Defaults to 1.

    Returns:
        SampleGenerator: Sample generator

    Yields:
        tuple: ``(x, label)`` where ``x`` is a float32 frame and ``label`` is
        the class-map label of the rhythm the frame was drawn from.
    """
    # Target labels. Sorted so that position i of `samples_per_tgt` always
    # refers to the same class (set iteration order is not guaranteed).
    tgt_labels = sorted(set(self.class_map.values()))

    # Convert Icentia labels -> HK labels -> class map labels (-1 indicates not in class map)
    tgt_map = {k: self.class_map.get(v, -1) for (k, v) in HeartRhythmMap.items()}
    num_classes = len(tgt_labels)

    # If samples_per_patient is a list, then it must be the same length as nclasses
    if isinstance(samples_per_patient, Iterable):
        # e.g. [25, 200] -> 25 frames for the first class, 200 for the second.
        # Materialize so that any iterable can be indexed by class position.
        samples_per_tgt = list(samples_per_patient)
    else:
        # Every class shares the same count: one entry per class.
        num_per_tgt = int(max(1, samples_per_patient / num_classes))
        samples_per_tgt = num_classes * [num_per_tgt]

    # Number of raw samples needed to obtain `frame_size` samples at the target rate
    input_size = int(np.round((self.sampling_rate / self.target_rate) * self.frame_size))

    # Group patient rhythms by type (segment, start, stop)
    for _, segments in patient_generator:
        seg_map: list[str] = list(segments.keys())

        pt_tgt_seg_map = [[] for _ in tgt_labels]
        for seg_idx, seg_key in enumerate(seg_map):
            # Grab rhythm labels
            rlabels = segments[seg_key]["rlabels"][:]
            # Skip if no rhythm labels
            if not rlabels.shape[0]:
                continue
            rlabels = rlabels[np.where(rlabels[:, 1] != IcentiaRhythm.noise.value)[0]]
            # Skip if only noise
            if not rlabels.shape[0]:
                continue
            # Even rows hold episode starts (index + label), odd rows the matching ends
            # NOTE(review): assumes the remaining rows strictly alternate start/end — confirm.
            xs, xe, xl = rlabels[0::2, 0], rlabels[1::2, 0], rlabels[0::2, 1]

            # Map labels to target labels
            xl = np.vectorize(tgt_map.get, otypes=[int])(xl)

            for tgt_idx, tgt_class in enumerate(tgt_labels):
                # Episodes of this class that can hold at least one full frame
                idxs = np.where((xe - xs >= input_size) & (xl == tgt_class))[0]
                # One (seg_idx, start, end) row per qualifying episode
                seg_vals = np.column_stack((np.full_like(idxs, seg_idx), xs[idxs], xe[idxs]))
                pt_tgt_seg_map[tgt_idx] += seg_vals.tolist()
            # END FOR
        # END FOR
        pt_tgt_seg_map = [np.array(b) for b in pt_tgt_seg_map]

        # Downsample: draw a fixed number of frames per class
        seg_samples: list[tuple[int, int, int, int]] = []
        for tgt_idx, tgt_class in enumerate(tgt_labels):
            tgt_samples = pt_tgt_seg_map[tgt_idx]
            if not tgt_samples.shape[0]:
                continue
            # Longer episodes are proportionally more likely to be picked
            episode_lengths = tgt_samples[:, 2] - tgt_samples[:, 1]
            tgt_seg_indices: list[int] = random.choices(
                range(tgt_samples.shape[0]),
                weights=episode_lengths.tolist(),
                k=samples_per_tgt[tgt_idx],
            )
            for tgt_seg_idx in tgt_seg_indices:
                seg_idx, rhy_start, rhy_end = tgt_samples[tgt_seg_idx]
                # Upper bound is exclusive, so the frame always ends at or before rhy_end
                frame_start = np.random.randint(rhy_start, rhy_end - input_size + 1)
                frame_end = frame_start + input_size
                seg_samples.append((seg_idx, frame_start, frame_end, tgt_class))
            # END FOR
        # END FOR

        # Shuffle segments so classes are interleaved
        random.shuffle(seg_samples)

        # Yield selected samples for patient
        for seg_idx, frame_start, frame_end, label in seg_samples:
            x: npt.NDArray = segments[seg_map[seg_idx]]["data"][frame_start:frame_end].astype(np.float32)
            if self.sampling_rate != self.target_rate:
                # NOTE(review): `pk` is not imported in the visible part of this
                # file (presumably physiokit) — it must be imported for this branch to run.
                x = pk.signal.resample_signal(x, self.sampling_rate, self.target_rate, axis=0)
            yield x, label
        # END FOR
    # END FOR

## The key downsampling step is shown below

In [None]:
def downsamples_per_segment(
    pt_tgt_seg_map: list[list[tuple[int, int, int]]],
    samples_per_tgt: list[int],  # samples_per_patient
    tgt_labels: list[str],
    input_size: int,
) -> list[tuple[int, int, int, str]]:
    """Draw a fixed number of fixed-length frames per target class.

    For each class, rhythm episodes are picked with replacement, with
    probability proportional to their length, and a frame of ``input_size``
    samples is placed uniformly at random inside each picked episode.

    Args:
        pt_tgt_seg_map: One entry per class (same order as ``tgt_labels``),
            each a list or array of ``(seg_idx, start, end)`` rows. Every
            episode must satisfy ``end - start >= input_size``.
        samples_per_tgt: Number of frames to draw for each class.
        tgt_labels: Class labels; the label is attached to every drawn frame.
        input_size: Frame length in samples.

    Returns:
        List of ``(seg_idx, frame_start, frame_end, label)`` tuples. Classes
        without any episode contribute no frames.
    """
    seg_samples: list[tuple[int, int, int, str]] = []

    for tgt_idx, tgt_class in enumerate(tgt_labels):  # downsample per class
        # Accept plain lists as well as arrays; an empty class becomes shape (0,)
        tgt_segments = np.asarray(pt_tgt_seg_map[tgt_idx])
        if not tgt_segments.shape[0]:
            continue

        # Longer episodes are proportionally more likely to be picked
        episode_lengths = tgt_segments[:, 2] - tgt_segments[:, 1]
        tgt_seg_indices: list[int] = random.choices(
            range(tgt_segments.shape[0]),
            weights=episode_lengths.tolist(),
            k=samples_per_tgt[tgt_idx],
        )

        # Must run inside the class loop so every class is sampled with its own indices
        for tgt_seg_idx in tgt_seg_indices:
            seg_idx, rhy_start, rhy_end = tgt_segments[tgt_seg_idx]
            # Upper bound is exclusive, so the frame always ends at or before rhy_end
            frame_start = int(np.random.randint(rhy_start, rhy_end - input_size + 1))
            frame_end = frame_start + input_size
            seg_samples.append((int(seg_idx), frame_start, frame_end, tgt_class))

    return seg_samples