diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 8ae06d9cfeefe..76d64d1e6c390 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -775,9 +775,9 @@ def __iter__(self):
     def __len__(self):
         # Will raise an error if the underlying dataset is not sized.
         if self.drop_last:
-            return len(self.dataset) // self.num_processes
+            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
         else:
-            return math.ceil(len(self.dataset) / self.num_processes)
+            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size


 # In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py
index 80096742868a5..b6314818f3f77 100644
--- a/tests/test_trainer_utils.py
+++ b/tests/test_trainer_utils.py
@@ -355,6 +355,34 @@ def test_iterable_dataset_shard(self):
         self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
         self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)

+    def test_iterable_dataset_shard_with_length(self):
+        sampler_shards = [
+            IterableDatasetShard(list(range(100)), batch_size=4, drop_last=True, num_processes=2, process_index=i)
+            for i in range(2)
+        ]
+
+        # Build expected shards: each process will have batches of size 4 until there are not enough elements to
+        # form two full batches (so we stop at 96 = (100 // (4 * 2)) * (4 * 2)).
+        expected_shards = [[], []]
+        current_shard = 0
+        for i in range(0, 96, 4):
+            expected_shards[current_shard].extend(list(range(i, i + 4)))
+            current_shard = 1 - current_shard
+
+        self.assertListEqual([list(shard) for shard in sampler_shards], expected_shards)
+        self.assertListEqual([len(shard) for shard in sampler_shards], [len(shard) for shard in expected_shards])
+
+        sampler_shards = [
+            IterableDatasetShard(list(range(100)), batch_size=4, drop_last=False, num_processes=2, process_index=i)
+            for i in range(2)
+        ]
+        # When drop_last=False, we get the two last full batches by looping back to the beginning.
+        expected_shards[0].extend(list(range(96, 100)))
+        expected_shards[1].extend(list(range(0, 4)))
+
+        self.assertListEqual([list(shard) for shard in sampler_shards], expected_shards)
+        self.assertListEqual([len(shard) for shard in sampler_shards], [len(shard) for shard in expected_shards])
+
     def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
         shards = [
             ShardSampler(
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 2d2a9af25cc86..bd55c43b36d6b 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -281,6 +281,7 @@ def create_reverse_dependency_map():
         "test_trainer_distributed.py",
         "test_trainer_tpu.py",
     ],
+    "trainer_pt_utils.py": "test_trainer_utils.py",
     "utils/versions.py": "test_versions_utils.py",
 }
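
As a sanity check on the corrected formula, here is a minimal standalone sketch (the shard_length helper is hypothetical, not part of the library) that reproduces the per-process lengths the new test expects: with 100 samples, batch_size=4 and 2 processes, each shard reports 48 elements when drop_last=True and 52 when drop_last=False, whereas the old formula would have reported 50 in both cases.

import math


def shard_length(dataset_len, batch_size, num_processes, drop_last):
    # One "step" of the sharded iterator consumes batch_size * num_processes samples,
    # and every process yields exactly batch_size of them per step.
    if drop_last:
        # An incomplete final step is dropped entirely.
        return (dataset_len // (batch_size * num_processes)) * batch_size
    # Otherwise the final step is completed by looping back to the start of the dataset.
    return math.ceil(dataset_len / (batch_size * num_processes)) * batch_size


# Values from test_iterable_dataset_shard_with_length: 100 samples, batch_size=4, 2 processes.
assert shard_length(100, 4, 2, drop_last=True) == 48   # 12 full batches of 4 per process
assert shard_length(100, 4, 2, drop_last=False) == 52  # 13 batches, the last one wraps around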