diff --git a/lhotse/dataset/sampling/__init__.py b/lhotse/dataset/sampling/__init__.py
index 2644d7fe7..c7c16293b 100644
--- a/lhotse/dataset/sampling/__init__.py
+++ b/lhotse/dataset/sampling/__init__.py
@@ -12,6 +12,7 @@
 from .simple import SimpleCutSampler
 from .stateless import StatelessSampler
 from .utils import find_pessimistic_batches, report_padding_ratio_estimate
+from .weighted_simple import WeightedSimpleCutSampler
 from .zip import ZipSampler
 
 __all__ = [
@@ -25,6 +26,7 @@
     "DynamicBucketingSampler",
     "RoundRobinSampler",
     "SimpleCutSampler",
+    "WeightedSimpleCutSampler",
     "StatelessSampler",
     "ZipSampler",
     "find_pessimistic_batches",
diff --git a/lhotse/dataset/sampling/data_source.py b/lhotse/dataset/sampling/data_source.py
index 154ae1c1f..c6ba73926 100644
--- a/lhotse/dataset/sampling/data_source.py
+++ b/lhotse/dataset/sampling/data_source.py
@@ -1,6 +1,8 @@
 import random
 from collections import deque
-from typing import Optional
+from typing import List, Optional
+
+import numpy as np
 
 from lhotse import CutSet
 from lhotse.cut import Cut
@@ -98,3 +100,76 @@ def __next__(self) -> Cut:
 
     def __len__(self) -> int:
         return len(self._shuffled_items)
+
+
+class WeightedDataSource(DataSource):
+    """
+    An iterator wrapper over CutSet that helps with the sampling process:
+    it allows for deterministic re-shuffling of elements and "returning"
+    sampled elements to be yielded again.
+
+    Every cut has a sampling weight. At the beginning of each epoch, we
+    pre-compute the sampled indexes by drawing from a multinomial distribution
+    without replacement. The data source is exhausted once the number of drawn
+    cuts reaches ``num_samples``.
+    """
+
+    def __init__(self, items: CutSet, weights: List, num_samples: int):
+        """The constructor of the weighted data source.
+
+        Args:
+            items (CutSet): The CutSet to sample from.
+            weights (List): A list of values representing the weight of each cut. All values must be positive.
+            num_samples (int): The number of samples to be drawn. Must be smaller than the total number of cuts.
+        """
+        super().__init__(items=items)
+        assert len(items) == len(weights), "The number of weights must match the number of cuts"
+        assert num_samples < len(
+            weights
+        ), "The number of samples to be drawn must be smaller than the dataset size"
+
+        # Normalize the weights so that they sum to 1.
+        weights = np.array(weights)
+        weights = weights / weights.sum()
+
+        self.weights = weights
+        self.num_samples = num_samples
+        self.sampled_indexes = None
+
+    def reset(self) -> None:
+        """Reset the iterable state of DataSource."""
+        self._iter = None
+        self.sampled_indexes = None
+        self._reusable.clear()
+        self._remaining_duration = self._total_duration
+        self.remaining_cuts = self._total_cuts
+
+    def fast_forward(self, steps: int) -> None:
+        """Advance the data source by ``steps`` steps."""
+        assert steps >= 0
+        iter(self)
+        for _ in range(steps):
+            next(self.sampled_indexes)
+
+    def __iter__(self) -> "WeightedDataSource":
+        self.reset()
+        self._iter = iter(self._shuffled_items)
+        self.sampled_indexes = np.random.choice(
+            len(self.weights),
+            self.num_samples,
+            p=self.weights,
+            replace=False,
+        )
+        self.sampled_indexes = iter(self.sampled_indexes)
+        return self
+
+    def __next__(self) -> Cut:
+        if self._reusable:
+            next_cut = self._reusable.popleft()
+        else:
+            next_cut = self._orig_items[next(self.sampled_indexes)]
+
+        if not self.is_lazy:
+            self._remaining_duration -= next_cut.duration
+            self.remaining_cuts -= 1
+        return next_cut
diff --git a/lhotse/dataset/sampling/weighted_simple.py b/lhotse/dataset/sampling/weighted_simple.py
new file mode 100644
index 000000000..7c3f76034
--- /dev/null
+++ b/lhotse/dataset/sampling/weighted_simple.py
@@ -0,0 +1,147 @@
+import warnings
+from typing import Any, Dict, List, Optional
+
+from lhotse import CutSet, Seconds
+from lhotse.dataset.sampling.base import TimeConstraint
+from lhotse.dataset.sampling.data_source import WeightedDataSource
+from lhotse.dataset.sampling.simple import SimpleCutSampler
+
+
+class WeightedSimpleCutSampler(SimpleCutSampler):
+    """
+    Samples cuts from a CutSet, where the sampling probability of each cut is given by a list.
+    To enable global sampling, cuts must be in eager mode.
+
+    When performing sampling, it avoids having duplicated cuts in the same batch.
+    The sampler terminates once the number of sampled cuts reaches :attr:`num_samples`.
+
+    When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified,
+    the batch size is dynamic.
+
+    Example usage:
+
+        >>> dataset = K2SpeechRecognitionDataset(cuts)
+        >>> weights = get_weights(cuts)
+        >>> sampler = WeightedSimpleCutSampler(cuts, weights, num_samples=100, max_duration=200.0)
+        >>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
+        >>> for epoch in range(start_epoch, n_epochs):
+        ...     sampler.set_epoch(epoch)
+        ...     train(loader)
+    """
+
+    def __init__(
+        self,
+        cuts: CutSet,
+        cuts_weight: List,
+        num_samples: int,
+        max_duration: Seconds = None,
+        max_cuts: Optional[int] = None,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+    ):
+        """
+        WeightedSimpleCutSampler's constructor.
+
+        :param cuts: the ``CutSet`` to sample data from.
+        :param cuts_weight: the weight of each cut for sampling.
+        :param num_samples: the number of samples to be drawn.
+        :param max_duration: The maximum total recording duration from ``cuts``.
+        :param max_cuts: The maximum number of cuts sampled to form a mini-batch.
+            By default, this constraint is off.
+        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
+            Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g.:
+            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
+            different order of cuts.
+        :param drop_last: When ``True``, the last batch is dropped if it's incomplete.
+        :param world_size: Total number of distributed nodes. We will try to infer it by default.
+        :param rank: Index of distributed node. We will try to infer it by default.
+        :param seed: Random seed used to consistently shuffle the dataset across different processes.
+        """
+        super().__init__(
+            cuts=cuts,
+            drop_last=drop_last,
+            shuffle=shuffle,
+            world_size=world_size,
+            rank=rank,
+            max_duration=max_duration,
+            max_cuts=max_cuts,
+            seed=seed,
+        )
+        assert not cuts.is_lazy, "This sampler does not support lazy mode!"
+        self.data_source = WeightedDataSource(
+            cuts, weights=cuts_weight, num_samples=num_samples
+        )
+
+        self.weights = cuts_weight
+        self.num_samples = num_samples
+
+    def state_dict(self) -> Dict[str, Any]:
+        """
+        Return the current state of the sampler in a state_dict.
+        Together with ``load_state_dict()``, this can be used to restore the
+        training loop's state to the one stored in the state_dict.
+        """
+        state_dict = super().state_dict()
+        state_dict.update(
+            {
+                "time_constraint": self.time_constraint.state_dict(),
+                "weights": self.weights,
+                "num_samples": self.num_samples,
+            }
+        )
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Restore the state of the sampler that is described in a state_dict.
+        This will result in the sampler yielding batches from where the previous training left it off.
+
+        .. caution::
+            The samplers are expected to be initialized with the same CutSets,
+            but this is not explicitly checked anywhere.
+
+        .. caution::
+            The input ``state_dict`` is being mutated: we remove each consumed key, and expect
+            it to be empty at the end of loading. If you don't want this behavior, pass a copy
+            inside of this function (e.g., using ``copy.deepcopy``).
+
+        .. note::
+            For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
+            handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
+        """
+        time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
+        if self.time_constraint != time_constraint:
+            warnings.warn(
+                "WeightedSimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
+                f"expected {self.time_constraint}\n"
+                f"received {time_constraint}\n"
+                f"We will overwrite the settings with the received state_dict."
+            )
+        self.time_constraint = time_constraint
+
+        super().load_state_dict(state_dict)
+
+        # Restore the data source's position by skipping the cuts already consumed in this epoch.
+        self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
+
+        self.weights = state_dict.pop("weights")
+        self.num_samples = state_dict.pop("num_samples")
+
+    def __iter__(self) -> "WeightedSimpleCutSampler":
+        """
+        Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
+        """
+        # Restored state with load_state_dict()? Skip resetting only this once.
+        if self._just_restored_state:
+            return self
+        # Why reset the current epoch?
+        # Either we are iterating the epoch for the first time and it's a no-op,
+        # or we are iterating the same epoch again, in which case setting more steps
+        # than are actually available per epoch would have broken the checkpoint restoration.
+        self.diagnostics.reset_current_epoch()
+        # Reset the state to the beginning of the epoch.
+        iter(self.data_source)
+        return self
diff --git a/test/dataset/sampling/test_sampling.py b/test/dataset/sampling/test_sampling.py
index 820cd2ad1..74736794e 100644
--- a/test/dataset/sampling/test_sampling.py
+++ b/test/dataset/sampling/test_sampling.py
@@ -25,6 +25,7 @@
     BucketingSampler,
     CutPairsSampler,
     SimpleCutSampler,
+    WeightedSimpleCutSampler,
     ZipSampler,
 )
 from lhotse.dataset.sampling.base import SamplingDiagnostics, TimeConstraint
@@ -1024,6 +1025,54 @@ def test_cut_pairs_sampler_lazy_shuffle(sampler_cls):
     assert [c.id for c in sampled_src_cuts] != [c.id for c in lazy_cuts]
 
 
+def test_weighted_sampler_num_samples():
+    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
+    weight = [random.random() for _ in range(100)]
+    num_samples = 32
+
+    sampler = WeightedSimpleCutSampler(
+        cut_set,
+        weight,
+        num_samples=num_samples,
+        max_duration=10.0,
+        drop_last=True,
+    )
+
+    sampled_cuts = []
+    num_cuts = 0
+    for batch in sampler:
+        sampled_cuts.extend(batch)
+        num_cuts += len(batch)
+
+    assert num_cuts <= num_samples
+
+
+def test_weighted_sampler_across_epochs():
+    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
+    weight = [random.random() for _ in range(100)]
+    num_samples = 32
+
+    sampler = WeightedSimpleCutSampler(
+        cut_set,
+        weight,
+        num_samples=num_samples,
+        max_duration=10.0,
+        drop_last=True,
+    )
+
+    # 1st epoch
+    sampler.set_epoch(1)
+    batch = next(iter(sampler))
+    cut_ids1 = [c.id for c in batch]
+
+    # 2nd epoch
+    sampler.set_epoch(2)
+    batch = next(iter(sampler))
+    cut_ids2 = [c.id for c in batch]
+
+    assert set(cut_ids1) != set(cut_ids2)
+
+
 @pytest.mark.parametrize("datasize", [10, 1000, 20000])
 @pytest.mark.parametrize("bufsize", [100, 1000, 10000])
 def test_streaming_shuffle(datasize, bufsize):
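
Usage sketch (not part of the diff above): the snippet below shows one way the new WeightedSimpleCutSampler could be wired into a training loop, mirroring the docstring example in weighted_simple.py. The manifest path, the duration-based weights, num_samples=1000, and the epoch count are illustrative assumptions only; any list of positive per-cut weights works, and num_samples must stay below the total number of cuts.

from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.sampling import WeightedSimpleCutSampler

# Assumption: an eager (non-lazy) CutSet stored at this illustrative path;
# the sampler asserts that lazy manifests are not used.
cuts = CutSet.from_file("cuts.jsonl.gz")

# Illustrative weighting: favor longer cuts. The weights are normalized
# internally, so only their relative magnitudes matter.
weights = [cut.duration for cut in cuts]

sampler = WeightedSimpleCutSampler(
    cuts,
    weights,
    num_samples=1000,    # draw at most 1000 cuts per epoch, without replacement
    max_duration=200.0,  # dynamic batches capped at 200 s of audio each
    drop_last=True,
)
dataset = K2SpeechRecognitionDataset(cuts)  # as in the sampler's docstring example
loader = DataLoader(dataset, sampler=sampler, batch_size=None)

for epoch in range(10):
    sampler.set_epoch(epoch)
    for batch in loader:
        ...  # training step

Because the weighted draw happens in WeightedDataSource.__iter__, each epoch re-samples a fresh subset of at most num_samples cuts according to the normalized weights.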