diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index a1e85e5b90f6..85ac9eb48598 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -333,12 +333,23 @@ def backward(self, loss, retain_graph=False):
         self.zero_grad()
 
     def backward_by_grad(self, tensor, grad):
-        # in lower stage which grad is transfered by higher stage
-        # we need to pass the optim state down.
+        assert not(self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
         if self.mixed_precision_mixin is not None:
             grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
         torch.autograd.backward(tensor, grad)
 
+        if not self.require_grad_sync:
+            return
+        self._reduce_grad(self._partition_grads)
+
+        # clear reduced grads
+        if self._overlap_communication:
+            torch.cuda.synchronize()
+
+        self.zero_grad()
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 61881a1f90e7..0855e2248710 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -163,6 +163,15 @@ def run_bert_test(test_config):
         'enable_all_optimization': False,
         'use_lazy_init': False,
         'precision': 'fp32',
+    },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
         'initial_scale': 1,
     },
 ])
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index f7ab94bc9aae..c9ee690c86dc 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -165,6 +165,16 @@ def run_bloom_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_bloom_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index c5a3e68f7b55..05ca05dea4d6 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_chatglm_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 44914721c40e..563084ed0f09 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 @clear_cache_before_run()
 def run_gpt2_3d_test(test_config):
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index c9d5d3d08305..a60150e3cd72 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -185,6 +185,16 @@ def run_llama_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 8c0432b37425..25b1eefc6016 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -174,6 +174,16 @@ def run_opt_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_opt_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 29367031e820..768cae0a6734 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -170,6 +170,16 @@ def run_t5_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 4,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp16',
+        'zero_stage': 1,
+        'initial_scale': 1,
+    },
 ])
 def run_t5_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 2980c6eeafba..15db63bfd9da 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -176,6 +176,15 @@ def run_vit_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
 ])
 def run_vit_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py
index a55753018300..d0c04c98f80a 100644
--- a/tests/test_shardformer/test_model/test_shard_whisper.py
+++ b/tests/test_shardformer/test_model/test_shard_whisper.py
@@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check weights
     if test_config['precision'] == 'fp32':
-        atol, rtol = 5e-4, 5e-4
+        atol, rtol = 1e-3, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
@@ -195,6 +195,15 @@ def run_whisper_test(test_config):
         'precision': 'fp32',
         'initial_scale': 1,
     },
+    {
+        'tp_size': 2,
+        'pp_size': 2,
+        'num_microbatches': 2,
+        'enable_all_optimization': False,
+        'use_lazy_init': False,
+        'precision': 'fp32',
+        'initial_scale': 1,
+    },
 ])
 def run_whisper_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
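For reference, the mechanism backward_by_grad() builds on is torch.autograd.backward(tensor, grad): the gradient of the stage output arrives from the downstream pipeline stage and is propagated into the local parameters, instead of coming from a local loss. With this patch, when require_grad_sync is False (gradient accumulation) the method stops after that autograd call and defers reduction; when it is True, grads are reduced and then cleared, mirroring backward(). The snippet below is only a minimal, self-contained sketch of that autograd semantics with stand-in tensors; it is not the optimizer or plugin API.

import torch

# w plays the role of this stage's parameters; y is the stage output whose
# gradient is received from the next stage rather than computed from a loss.
w = torch.randn(4, requires_grad=True)
y = w * 3.0
upstream_grad = torch.ones_like(y)  # stand-in for the grad handed down via p2p

# Same call as inside backward_by_grad(): propagate the received gradient
# through the local graph, so w.grad = upstream_grad * dy/dw = 3.
torch.autograd.backward(y, upstream_grad)
print(w.grad)  # tensor([3., 3., 3., 3.])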