Skip to content

Commit

Permalink
[transformer] add microbatches test (pytorch#1349)
Browse files Browse the repository at this point in the history
* add test

* destroy model parallel was missing
  • Loading branch information
crcrpar committed Apr 7, 2022
1 parent 23cfb57 commit 7d90387
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
3 changes: 1 addition & 2 deletions apex/transformer/microbatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def build_num_microbatches_calculator(
global_batch_size,
batch_size_increment,
ramup_samples,
),
flush=True,
)
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
start_batch_size,
Expand Down
80 changes: 80 additions & 0 deletions tests/L0/run_transformer/test_microbatches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging
from typing import List, Optional

from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
_reconfigure_microbatch_calculator,
get_micro_batch_size,
get_num_microbatches,
get_current_global_batch_size,
update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import DistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class MicrobatchCalculatorTest(DistributedTestBase):
    """Exercise the microbatch calculator with constant and ramp-up batch sizes.

    Each test walks over the data-parallel sizes that are compatible with the
    available world size, reconfigures the global microbatch calculator for
    that size, and checks the values reported by the accessor helpers.
    """

    # Final (target) global batch size handed to the calculator.
    GLOBAL_BATCH_SIZE: int = 1024
    # Per-rank micro batch size handed to the calculator.
    MICRO_BATCH_SIZE: int = 1

    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        """Run the calculator checks for every usable data-parallel size.

        Args:
            rampup_batch_size: ``None`` (or empty) for a constant global batch
                size. Otherwise a list whose first element is the expected
                initial global batch size; the remaining elements are passed
                through to the calculator untouched (presumably the batch-size
                increment and the number of ramp-up samples -- confirm against
                ``build_num_microbatches_calculator``).
        """
        # These expectations do not depend on the data-parallel size, so they
        # are computed once instead of on every loop iteration.
        expected_micro_batch_size = self.MICRO_BATCH_SIZE
        expected_global_batch_size = (
            rampup_batch_size[0] if rampup_batch_size else self.GLOBAL_BATCH_SIZE
        )

        for data_parallel_size in range(1, self.world_size + 1):
            # Only 1 or even sizes that evenly divide the world size are tested.
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue

            with self.subTest(data_parallel_size=data_parallel_size):
                parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size_=self.world_size // data_parallel_size,
                    pipeline_model_parallel_size_=1,
                )
                # Tear down in `finally` so that a failing assertion in this
                # subtest cannot leave model-parallel state initialized when
                # the next data-parallel size is set up.
                try:
                    self.assertEqual(
                        data_parallel_size,
                        parallel_state.get_data_parallel_world_size(),
                    )

                    _reconfigure_microbatch_calculator(
                        self.rank,
                        rampup_batch_size,
                        self.GLOBAL_BATCH_SIZE,
                        self.MICRO_BATCH_SIZE,
                        data_parallel_size,
                    )

                    self.assertEqual(get_micro_batch_size(), expected_micro_batch_size)
                    # True division is intentional: a non-integral quotient
                    # must not compare equal to an integer microbatch count.
                    self.assertEqual(
                        get_num_microbatches(),
                        expected_global_batch_size
                        / expected_micro_batch_size
                        / data_parallel_size,
                    )
                    current_global_batch_size = get_current_global_batch_size()
                    self.assertEqual(current_global_batch_size, expected_global_batch_size)

                    # Make sure the global batch size equals the final global
                    # batch size after a certain number of updates.
                    if rampup_batch_size:
                        update_num_microbatches(current_global_batch_size)
                        for _ in range(100):
                            current_global_batch_size = get_current_global_batch_size()
                            update_num_microbatches(current_global_batch_size)
                        self.assertEqual(
                            get_current_global_batch_size(),
                            self.GLOBAL_BATCH_SIZE,
                        )
                finally:
                    parallel_state.destroy_model_parallel()

    def test_constant_microbatch_calculator(self):
        """Constant global batch size: no ramp-up schedule is supplied."""
        self._test(rampup_batch_size=None)

    def test_dynamic_microbatch_calculator(self):
        """Ramp-up schedule starting from a global batch size of 256."""
        self._test(rampup_batch_size=[256, 128, 500])


# Script entry point: when this file is executed directly, hand control to the
# test runner imported from torch's internal `common_utils` module.
if __name__ == "__main__":
    common_utils.run_tests()

0 comments on commit 7d90387

Please sign in to comment.