Add new sampler: weighted sampler #1344
Conversation
…ght; do not allow duplicated sample within the same epoch
…ed cut in the same batch
Thanks @marcoyang1998, I appreciate your work. I think you could achieve a similar outcome by splitting the cutset into per-class subsets, and then using mux to get a cutset to be passed to any of the existing samplers. It would also work with lazy manifests and bucketing.

class_cutsets = [cuts_class0, cuts_class1, ...]
class_weights = [w_class0, w_class1, ...]
cuts = CutSet.mux(*class_cutsets, weights=class_weights)
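CutSet.mux interleaves items from several sources, picking the next source at random with probability proportional to its weight until all sources are exhausted. A self-contained sketch of that multiplexing behavior in plain Python (a rough model, not the actual lhotse implementation):

```python
import random

def mux(*sources, weights, seed=0):
    """Yield items from several iterables, choosing the next source
    at random with probability proportional to its weight, until all
    sources are exhausted (mimics CutSet.mux semantics)."""
    rng = random.Random(seed)
    iters = [iter(s) for s in sources]
    active = list(range(len(iters)))
    while active:
        # Pick one of the still-active sources by weight.
        idx = rng.choices(active, weights=[weights[i] for i in active])[0]
        try:
            yield next(iters[idx])
        except StopIteration:
            active.remove(idx)

# Example: two "classes" mixed with a 3:1 sampling ratio.
class0 = [f"c0_{i}" for i in range(6)]
class1 = [f"c1_{i}" for i in range(2)]
mixed = list(mux(class0, class1, weights=[3, 1], seed=42))
```

Because each source is drained to exhaustion, every item appears exactly once; only the interleaving order is random, which is why this composes cleanly with the existing samplers downstream.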
Hi Piotr, thanks for the …
… weighted_sampler
That's a valid point, I didn't think of this use-case. In that case LGTM, I just left one comment that may help us reduce code duplication.
from lhotse.dataset.sampling.data_source import WeightedDataSource

class WeightedSimpleCutSampler(CutSampler):
I wonder whether this class can inherit SimpleCutSampler and only override the necessary parts (e.g. using WeightedDataSource in __init__); otherwise most of this code looks very similar.
Good idea! I just changed the inheritance, please have a look
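The suggested refactor might look roughly like this: the subclass keeps all of the parent's iteration logic and only swaps in a weighted data source in __init__. This is a hypothetical sketch with simplified signatures, not the actual lhotse code:

```python
import random

class DataSource:
    """Minimal stand-in for a sampler's data source: iterates items in order."""
    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)

class WeightedDataSource(DataSource):
    """Draws num_samples items with probability proportional to their weights."""
    def __init__(self, items, weights, num_samples, seed=0):
        super().__init__(items)
        self.weights = list(weights)
        self.num_samples = num_samples
        self.rng = random.Random(seed)

    def __iter__(self):
        drawn = self.rng.choices(self.items, weights=self.weights, k=self.num_samples)
        return iter(drawn)

class SimpleCutSampler:
    def __init__(self, cuts):
        self.data_source = DataSource(cuts)

    def __iter__(self):
        return iter(self.data_source)

class WeightedSimpleCutSampler(SimpleCutSampler):
    # Override only __init__ to install the weighted source;
    # the iteration machinery is inherited unchanged.
    def __init__(self, cuts, weights, num_samples, seed=0):
        super().__init__(cuts)
        self.data_source = WeightedDataSource(cuts, weights, num_samples, seed)
```

This keeps the two samplers behaviorally identical everywhere except the sampling distribution, which is the point of the review comment above.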
    drop_last=drop_last,
    shuffle=shuffle,
    world_size=world_size,
    rank=rank,
    max_duration=max_duration,
    max_cuts=max_cuts,
    seed=seed,
)
assert cuts.is_lazy == False, "This sampler does not support lazy mode!"
assert (
    shuffle == False
Can we remove this assertion? Let's just ignore the value of shuffle in this case.
@@ -181,63 +152,3 @@ def __iter__(self) -> "WeightedSimpleCutSampler":
    self.data_source.shuffle(self.seed + self.epoch)
Is the if self.shuffle branch needed here?
I made the changes I requested myself as I'd like to release a new version and include this. Thanks @marcoyang1998!
Add a weighted sampler, where each cut's sampling probability is proportional to its weight. This is useful for unbalanced datasets, where some classes have very little data. The weight for each cut should be computed by the user and passed to the sampler. It's similar to PyTorch's WeightedRandomSampler. This sampler only works with eager manifests, since we need to perform sampling globally.
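The "no duplicated sample within the same epoch" behavior described in the commits amounts to weighted sampling without replacement. One standard way to implement that (a sketch of the general technique, not the code from this PR) is the Efraimidis-Spirakis key method: draw u ~ Uniform(0, 1) per item, rank items by u ** (1 / w), and keep the top k:

```python
import random

def weighted_sample_without_replacement(items, weights, k, seed=0):
    """Pick k distinct items; larger weights make inclusion more likely
    (Efraimidis-Spirakis reservoir keys)."""
    rng = random.Random(seed)
    # Each item gets key u**(1/w); heavier items tend to get larger keys.
    keyed = [(rng.random() ** (1.0 / w), item) for item, w in zip(items, weights)]
    keyed.sort(reverse=True)
    return [item for _, item in keyed[:k]]

# Example: upweight the first two "rare-class" cuts.
cuts = [f"cut{i}" for i in range(10)]
weights = [10.0 if i < 2 else 1.0 for i in range(10)]
epoch = weighted_sample_without_replacement(cuts, weights, k=5, seed=7)
```

Sorting all keys is why such a sampler needs global access to the dataset, which matches the PR's restriction to eager manifests.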