From 328d73b3c72e76083ce491a9f8ebb84e21b1bad4 Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Mon, 3 Jun 2024 14:39:10 +0000 Subject: [PATCH] Fix bucketing bug --- src/fairseq2/data/data_pipeline.py | 4 +++- tests/unit/data/data_pipeline/test_bucket_by_length.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/data/data_pipeline.py b/src/fairseq2/data/data_pipeline.py index 71667b188..2447ec439 100644 --- a/src/fairseq2/data/data_pipeline.py +++ b/src/fairseq2/data/data_pipeline.py @@ -531,7 +531,7 @@ def create_bucket_sizes( bucket_size = max_num_elements - while seq_len <= max_seq_len: + while seq_len < max_seq_len: if seq_len >= min_seq_len: bucket_sizes.append((bucket_size, seq_len)) @@ -539,6 +539,8 @@ def create_bucket_sizes( seq_len = max_num_elements // bucket_size + bucket_sizes.append((bucket_size, max_seq_len)) + if num_seqs_multiple_of == 1: return bucket_sizes diff --git a/tests/unit/data/data_pipeline/test_bucket_by_length.py b/tests/unit/data/data_pipeline/test_bucket_by_length.py index 2b898e223..ef6069fc0 100644 --- a/tests/unit/data/data_pipeline/test_bucket_by_length.py +++ b/tests/unit/data/data_pipeline/test_bucket_by_length.py @@ -21,3 +21,11 @@ def test_create_bucket_sizes_with_num_seqs_multiple_of() -> None: ) assert bucket_sizes == [(8, 2), (4, 3), (4, 4), (2, 5), (2, 8)] + + +def test_create_bucket_sizes_with_max_seq_len_equals_max_num_elements() -> None: + bucket_sizes = create_bucket_sizes( + max_num_elements=16, max_seq_len=16, min_seq_len=2 + ) + + assert bucket_sizes == [(8, 2), (5, 3), (4, 4), (3, 5), (2, 8), (1, 16)]