diff --git a/lazy_dataset/core.py b/lazy_dataset/core.py index ed8a59f..046e46b 100644 --- a/lazy_dataset/core.py +++ b/lazy_dataset/core.py @@ -3351,7 +3351,13 @@ def is_completed(self): def assess(self, example): seq_len = self.len_key(example) - return self.lower_bound <= seq_len <= self.upper_bound + return ( + (self.lower_bound <= seq_len <= self.upper_bound) + and ( + (self.max_total_size is None) + or ((len(self.data) + 1) * max(self.max_len, seq_len) <= self.max_total_size) + ) + ) def _append(self, example): super()._append(example) diff --git a/tests/test_bucket.py b/tests/test_bucket.py index b457f8f..c8d5764 100644 --- a/tests/test_bucket.py +++ b/tests/test_bucket.py @@ -19,3 +19,16 @@ def test_bucket(): assert dynamic_batched_buckets == [ [10, 5], [7, 8], [1, 2], [4, 3], [6, 9], [20], [1] ] + + +def test_max_total_size(): + examples = [6, 7, 9, 5, 6, 3, 7, 4] + examples = {str(j): i for j, i in enumerate(examples)} + ds = lazy_dataset.new(examples) + + dynamic_batched_buckets = list(ds.batch_dynamic_time_series_bucket( + batch_size=3, len_key=lambda x: x, max_padding_rate=0.9, max_total_size=21, + )) + assert dynamic_batched_buckets == [ + [6, 7, 5], [9, 6], [3, 7, 4] + ]