Add new sampler: weighted sampler #1344

Merged · 17 commits · Jun 5, 2024
2 changes: 2 additions & 0 deletions lhotse/dataset/sampling/__init__.py
@@ -12,6 +12,7 @@
from .simple import SimpleCutSampler
from .stateless import StatelessSampler
from .utils import find_pessimistic_batches, report_padding_ratio_estimate
from .weighted_simple import WeightedSimpleCutSampler
from .zip import ZipSampler

__all__ = [
@@ -25,6 +26,7 @@
"DynamicBucketingSampler",
"RoundRobinSampler",
"SimpleCutSampler",
"WeightedSimpleCutSampler",
"StatelessSampler",
"ZipSampler",
"find_pessimistic_batches",
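With these two additions the new sampler is exported next to the existing ones. A minimal import check (a sketch, assuming a lhotse build that includes this PR):

from lhotse.dataset.sampling import SimpleCutSampler, WeightedSimpleCutSampler

# The new sampler specializes SimpleCutSampler (see weighted_simple.py below).
assert issubclass(WeightedSimpleCutSampler, SimpleCutSampler)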
77 changes: 76 additions & 1 deletion lhotse/dataset/sampling/data_source.py
@@ -1,6 +1,8 @@
import random
from collections import deque
from typing import List, Optional  # was: from typing import Optional

import numpy as np

from lhotse import CutSet
from lhotse.cut import Cut
@@ -98,3 +100,76 @@

def __len__(self) -> int:
return len(self._shuffled_items)


class WeightedDataSource(DataSource):
"""
An iterator wrapper over CutSet that helps with the sampling process:
it allows for deterministic re-shuffling of elements and "returning"
sampled elements to be yielded again.

Every cut has a sampling weight. At the beginning of each epoch, we
pre-compute the indexes by sampling from a multinomial distribution
without replacement. The data source is exhausted once the number of
drawn cuts reaches ``num_samples``.
"""

def __init__(self, items: CutSet, weights: List, num_samples: int):
"""The constructor of the weighted data source

Args:
items (CutSet): The CutSet to sample from.
weights (List): A list of values representing the weight of each cut. All values must be positive.
num_samples (int): The number of samples to be drawn. Must be smaller than the total number of cuts.
"""
super().__init__(items=items)
assert len(items) == len(weights), "The number of weights must match the number of cuts"
assert num_samples < len(
weights
), "The number of samples to be drawn must be smaller than the dataset size"

# normalize the weights so they sum to one
weights = np.array(weights)
weights = weights / weights.sum()

self.weights = weights
self.num_samples = num_samples
self.sampled_indexes = None

def reset(self) -> None:
"""Reset the iterable state of DataSource."""
self._iter = None
self.sampled_indexes = None
self._reusable.clear()
self._remaining_duration = self._total_duration
self.remaining_cuts = self._total_cuts

def fast_forward(self, steps: int) -> None:
"""Advance the data source by ``steps`` amount of steps."""
assert steps >= 0
iter(self)
for _ in range(steps):
next(self.sampled_indexes)

def __iter__(self) -> "WeightedDataSource":
self.reset()
self._iter = iter(self._shuffled_items)
self.sampled_indexes = np.random.choice(
len(self.weights),
self.num_samples,
p=self.weights,
replace=False,
)
self.sampled_indexes = iter(self.sampled_indexes)
return self

def __next__(self) -> Cut:
if self._reusable:
next_cut = self._reusable.popleft()
else:
next_cut = self._orig_items[next(self.sampled_indexes)]

if not self.is_lazy:
self._remaining_duration -= next_cut.duration
self.remaining_cuts -= 1
return next_cut
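For intuition, the epoch-level draw in ``__iter__`` reduces to a single ``np.random.choice`` call over the normalized weights. A standalone sketch with toy weights (illustration only, not part of the PR):

import numpy as np

# Draw 4 distinct indexes out of 6, biased by the normalized weights;
# this mirrors the call WeightedDataSource.__iter__ makes with its own weights.
np.random.seed(0)
weights = np.array([5.0, 1.0, 1.0, 1.0, 1.0, 1.0])
weights = weights / weights.sum()  # normalized exactly as in the constructor
sampled = np.random.choice(len(weights), 4, p=weights, replace=False)
print(sampled)  # e.g. [0 5 2 1]; the heavily weighted index 0 shows up in most draws

Because ``replace=False``, each index is drawn at most once per epoch, so a cut can never repeat within the pre-computed sample.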
147 changes: 147 additions & 0 deletions lhotse/dataset/sampling/weighted_simple.py
@@ -0,0 +1,147 @@
import warnings
from typing import Any, Dict, List, Optional

from lhotse import CutSet, Seconds
from lhotse.dataset.sampling.base import TimeConstraint
from lhotse.dataset.sampling.data_source import WeightedDataSource
from lhotse.dataset.sampling.simple import SimpleCutSampler


class WeightedSimpleCutSampler(SimpleCutSampler):
"""
Samples cuts from a CutSet, where the sampling probability of each cut is given by a list of weights.
To enable global sampling, cuts must be in eager mode.

When performing sampling, it avoids duplicate cuts in the same batch.
The sampler terminates once the number of sampled cuts reaches :attr:`num_samples`.

When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified,
the batch size is dynamic.

Example usage:

>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> weights = get_weights(cuts)
>>> sampler = WeightedSimpleCutSampler(cuts, weights, num_samples=100, max_duration=200.0)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
... sampler.set_epoch(epoch)
... train(loader)
"""

def __init__(
self,
cuts: CutSet,
cuts_weight: List,
num_samples: int,
max_duration: Optional[Seconds] = None,
max_cuts: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
"""
WeightedSimpleCutSampler's constructor

:param cuts: the ``CutSet`` to sample data from.
:param cuts_weight: the weight of each cut for sampling.
:param num_samples: the number of samples to be drawn.
:param max_duration: The maximum total recording duration from ``cuts``.
:param max_cuts: The maximum number of cuts sampled to form a mini-batch.
By default, this constraint is off.
:param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
`for epoch in range(10): for batch in dataset: ...` as every epoch will see a
different cuts order.
:param drop_last: When ``True``, the last batch is dropped if it's incomplete.
:param world_size: Total number of distributed nodes. We will try to infer it by default.
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
"""
super().__init__(
cuts=cuts,
drop_last=drop_last,
shuffle=shuffle,
world_size=world_size,
rank=rank,
max_duration=max_duration,
max_cuts=max_cuts,
seed=seed,
)
assert not cuts.is_lazy, "This sampler does not support lazy mode!"
self.data_source = WeightedDataSource(
cuts, weights=cuts_weight, num_samples=num_samples
)

self.weights = cuts_weight
self.num_samples = num_samples

def state_dict(self) -> Dict[str, Any]:
"""
Return the current state of the sampler in a state_dict.
Together with ``load_state_dict()``, this can be used to restore the
training loop's state to the one stored in the state_dict.
"""
state_dict = super().state_dict()
state_dict.update(
{
"time_constraint": self.time_constraint.state_dict(),
"weights": self.weights,
"num_samples": self.num_samples,
}
)
return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the state of the sampler that is described in a state_dict.
This will result in the sampler yielding batches from where the previous training left it off.

.. caution::
The samplers are expected to be initialized with the same CutSets,
but this is not explicitly checked anywhere.

.. caution::
The input ``state_dict`` is being mutated: we remove each consumed key, and expect
it to be empty at the end of loading. If you don't want this behavior, pass a copy
inside of this function (e.g., using ``copy.deepcopy``).

.. note::
For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
"""
time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
if self.time_constraint != time_constraint:
warnings.warn(
"SimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
f"expected {self.time_constraint}\n"
f"received {time_constraint}\n"
f"We will overwrite the settings with the received state_dict."
)
self.time_constraint = time_constraint

super().load_state_dict(state_dict)

# Restore the data source's state
self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)

self.weights = state_dict.pop("weights")
self.num_samples = state_dict.pop("num_samples")

def __iter__(self) -> "WeightedSimpleCutSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
"""
# Restored state with load_state_dict()? Skip resetting only this once.
if self._just_restored_state:
return self

# Why reset the current epoch?
# Either we are iterating the epoch for the first time and it's a no-op,
# or we are iterating the same epoch again, in which case setting more steps
# than are actually available per epoch would have broken the checkpoint restoration.
self.diagnostics.reset_current_epoch()
# Reset the state to the beginning of the epoch.
iter(self.data_source)
return self
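The state-saving methods above can be exercised end to end. A sketch of checkpoint-style save and resume, assuming an eager cuts manifest on disk ("cuts.jsonl.gz" is a placeholder path) and a hypothetical get_weights helper (any callable returning one positive weight per cut):

from lhotse import CutSet
from lhotse.dataset.sampling import WeightedSimpleCutSampler

def get_weights(cuts: CutSet) -> list:
    # Hypothetical weighting scheme: bias sampling toward longer cuts.
    return [cut.duration for cut in cuts]

cuts = CutSet.from_file("cuts.jsonl.gz").to_eager()  # the sampler asserts eager mode
sampler = WeightedSimpleCutSampler(
    cuts, get_weights(cuts), num_samples=1000, max_duration=200.0
)

# Consume one batch, then snapshot the sampler's state.
it = iter(sampler)
first_batch = next(it)
ckpt = sampler.state_dict()

# A fresh sampler with the same configuration resumes from the snapshot.
# Note the caution above: load_state_dict() consumes the dict it is given.
restored = WeightedSimpleCutSampler(
    cuts, get_weights(cuts), num_samples=1000, max_duration=200.0
)
restored.load_state_dict(ckpt)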
49 changes: 49 additions & 0 deletions test/dataset/sampling/test_sampling.py
@@ -25,6 +25,7 @@
BucketingSampler,
CutPairsSampler,
SimpleCutSampler,
WeightedSimpleCutSampler,
ZipSampler,
)
from lhotse.dataset.sampling.base import SamplingDiagnostics, TimeConstraint
@@ -1024,6 +1025,54 @@ def test_cut_pairs_sampler_lazy_shuffle(sampler_cls):
assert [c.id for c in sampled_src_cuts] != [c.id for c in lazy_cuts]


def test_weighted_sampler_num_samples():
cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
weight = [random.random() for _ in range(100)]
num_samples = 32

sampler = WeightedSimpleCutSampler(
cut_set,
weight,
num_samples=num_samples,
max_duration=10.0,
drop_last=True,
)

sampled_cuts = []
num_cuts = 0
for batch in sampler:
sampled_cuts.extend(batch)
num_cuts += len(batch)

assert num_cuts <= num_samples


def test_weighted_sampler_across_epochs():
cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
weight = [random.random() for _ in range(100)]
num_samples = 32

sampler = WeightedSimpleCutSampler(
cut_set,
weight,
num_samples=num_samples,
max_duration=10.0,
drop_last=True,
)

# 1st epoch
sampler.set_epoch(1)
batch = next(iter(sampler))
cut_ids1 = [c.id for c in batch]

# 2nd epoch
sampler.set_epoch(2)
batch = next(iter(sampler))
cut_ids2 = [c.id for c in batch]

assert set(cut_ids1) != set(cut_ids2)


@pytest.mark.parametrize("datasize", [10, 1000, 20000])
@pytest.mark.parametrize("bufsize", [100, 1000, 10000])
def test_streaming_shuffle(datasize, bufsize):