diff --git a/apex/transformer/microbatches.py b/apex/transformer/microbatches.py
index f51cc4f8afcd..69673bc1423c 100644
--- a/apex/transformer/microbatches.py
+++ b/apex/transformer/microbatches.py
@@ -60,8 +60,7 @@ def build_num_microbatches_calculator(
                     global_batch_size,
                     batch_size_increment,
                     ramup_samples,
-                ),
-                flush=True,
+                )
             )
         num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
             start_batch_size,
diff --git a/tests/L0/run_transformer/test_microbatches.py b/tests/L0/run_transformer/test_microbatches.py
new file mode 100644
index 000000000000..fcbef68298ab
--- /dev/null
+++ b/tests/L0/run_transformer/test_microbatches.py
@@ -0,0 +1,80 @@
+import logging
+from typing import List, Optional
+
+from torch.testing._internal import common_utils
+
+logging.getLogger("torch").setLevel(logging.WARNING)
+
+from apex.transformer import parallel_state
+from apex.transformer.pipeline_parallel.utils import (
+    _reconfigure_microbatch_calculator,
+    get_micro_batch_size,
+    get_num_microbatches,
+    get_current_global_batch_size,
+    update_num_microbatches,
+)
+from apex.transformer.testing.distributed_test_base import DistributedTestBase
+
+logging.getLogger("apex").setLevel(logging.WARNING)
+
+
+class MicrobatchCalculatorTest(DistributedTestBase):
+
+    GLOBAL_BATCH_SIZE: int = 1024
+    MICRO_BATCH_SIZE: int = 1
+
+    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
+        for data_parallel_size in range(1, self.world_size + 1):
+
+            expected_global_batch_size = MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE
+            expected_micro_batch_size = MicrobatchCalculatorTest.MICRO_BATCH_SIZE
+            if rampup_batch_size:
+                expected_global_batch_size = rampup_batch_size[0]
+                num_consumed_samples = 0
+                step_of_global_batch_size = rampup_batch_size[1]
+                threshold = rampup_batch_size[2]
+
+            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
+                continue
+            if self.world_size % data_parallel_size != 0:
+                continue
+            with self.subTest(data_parallel_size=data_parallel_size):
+                parallel_state.initialize_model_parallel(
+                    tensor_model_parallel_size_=self.world_size // data_parallel_size,
+                    pipeline_model_parallel_size_=1,
+                )
+                self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size())
+
+                _reconfigure_microbatch_calculator(
+                    self.rank,
+                    rampup_batch_size,
+                    MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE,
+                    MicrobatchCalculatorTest.MICRO_BATCH_SIZE,
+                    data_parallel_size,
+                )
+
+                self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
+                self.assertEqual(get_num_microbatches(), expected_global_batch_size / expected_micro_batch_size / data_parallel_size)
+                current_global_batch_size = get_current_global_batch_size()
+                self.assertEqual(current_global_batch_size, expected_global_batch_size)
+
+                # Make sure `global_batch_size` equals the final global batch size
+                # after a certain number of updates.
+                if rampup_batch_size:
+                    update_num_microbatches(current_global_batch_size)
+                    for i in range(100):
+                        current_global_batch_size = get_current_global_batch_size()
+                        update_num_microbatches(current_global_batch_size)
+                    current_global_batch_size = get_current_global_batch_size()
+                    self.assertEqual(get_current_global_batch_size(), MicrobatchCalculatorTest.GLOBAL_BATCH_SIZE)
+                parallel_state.destroy_model_parallel()
+
+    def test_constant_microbatch_calculator(self):
+        self._test(rampup_batch_size=None)
+
+    def test_dynamic_microbatch_calculator(self):
+        self._test(rampup_batch_size=[256, 128, 500])
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
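
Note on the dynamic test: `test_dynamic_microbatch_calculator` passes `rampup_batch_size=[256, 128, 500]`. Following the `[start batch size, batch size increment, ramp-up samples]` interpretation suggested by the `build_num_microbatches_calculator` hunk above, the global batch size should climb from 256 toward the final `GLOBAL_BATCH_SIZE` of 1024 in increments of 128 as samples are consumed, which is why the test loop repeatedly feeds the current global batch size back into `update_num_microbatches`. The standalone sketch below only illustrates that schedule; it assumes Megatron-LM-style ramp-up semantics, and `global_batch_size_at` is a hypothetical helper for this note, not an apex API.

    # Minimal sketch of the assumed ramp-up schedule for
    # rampup_batch_size=[256, 128, 500] with a final global batch size of 1024.
    START, INCREMENT, RAMPUP_SAMPLES = 256, 128, 500
    FINAL_GLOBAL_BATCH_SIZE = 1024

    def global_batch_size_at(consumed_samples: int) -> int:
        """Global batch size after `consumed_samples` samples, stepping by
        INCREMENT until FINAL_GLOBAL_BATCH_SIZE is reached."""
        steps = (FINAL_GLOBAL_BATCH_SIZE - START) // INCREMENT  # 6 increments
        samples_per_step = RAMPUP_SAMPLES / steps               # ~83.3 samples each
        completed = int(consumed_samples / samples_per_step)
        return min(START + completed * INCREMENT, FINAL_GLOBAL_BATCH_SIZE)

    if __name__ == "__main__":
        consumed = 0
        while True:
            gbs = global_batch_size_at(consumed)
            print(f"consumed={consumed:4d} -> global batch size {gbs}")
            if gbs == FINAL_GLOBAL_BATCH_SIZE:
                break
            consumed += gbs  # consume one global batch at the current size

Mirroring the test, feeding each current global batch size back in reaches the final 1024 well within the 100-iteration bound the test allows.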