[shardformer] support pp+tp+zero1 tests (hpcaitech#4531)
* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1
flybird11111 committed Sep 3, 2023
1 parent 853ef03 commit 2a293d7
Showing 10 changed files with 101 additions and 3 deletions.
15 changes: 13 additions & 2 deletions colossalai/zero/low_level/low_level_optim.py
@@ -333,12 +333,23 @@ def backward(self, loss, retain_graph=False):
        self.zero_grad()

    def backward_by_grad(self, tensor, grad):
        # in a lower pipeline stage, the grad is transferred from the higher
        # stage, so the optimizer state has to be passed down along with it
        assert not (self._partition_grads and not self.require_grad_sync), \
            "ZeRO-2 (partition_grads) and gradient accumulation (no_sync) are not compatible"

        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

        if not self.require_grad_sync:
            return
        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()

        self.zero_grad()

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient
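The new guard and epilogue make `backward_by_grad` mirror `backward`: it fails fast when ZeRO-2 gradient partitioning is combined with gradient accumulation, then reduces and clears gradients once sync is required. Below is a minimal sketch of the pipeline-side call pattern; `stage_backward`, `optim`, `output`, and `grad_output` are hypothetical names, with `optim` assumed to be an already-built `LowLevelZeroOptimizer`:

```python
import torch

def stage_backward(optim, output: torch.Tensor, grad_output: torch.Tensor):
    """Hypothetical sketch: backward pass on a non-last pipeline stage.

    The last stage calls optim.backward(loss); every earlier stage only
    has the gradient of its own stage output, received from the next
    stage, so it goes through backward_by_grad instead.
    """
    # Equivalent in spirit to torch.autograd.backward(output, grad_output),
    # plus the ZeRO bookkeeping (grad reduction, zero_grad) shown above.
    optim.backward_by_grad(output, grad_output)
```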
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -163,6 +163,15 @@ def run_bert_test(test_config):
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
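The same entry is added to the bloom, chatglm2, gpt2, llama, opt, and t5 suites below: 2-way tensor parallelism times 2-way pipeline parallelism (4 ranks), fp16 with a static loss scale of 1, and ZeRO stage 1 sharding the optimizer states. As a rough sketch of how such a dict typically reaches the booster (the helper that actually consumes `test_config` is not shown in this diff, and popping `use_lazy_init` first is an assumption that it is a test-side flag rather than a plugin argument):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
}

# Assumed: use_lazy_init controls model construction in the test harness,
# not the plugin itself.
use_lazy_init = test_config.pop('use_lazy_init')

# tp_size * pp_size = 4 ranks; zero_stage=1 shards optimizer states across
# the data-parallel group inside each pipeline stage.
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
```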
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bloom.py
@@ -165,6 +165,16 @@ def run_bloom_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -185,6 +185,16 @@ def run_llama_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_opt.py
@@ -174,6 +174,16 @@ def run_opt_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_t5.py
@@ -170,6 +170,16 @@ def run_t5_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp16',
    'zero_stage': 1,
    'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_vit.py
@@ -176,6 +176,15 @@ def run_vit_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
11 changes: 10 additions & 1 deletion tests/test_shardformer/test_model/test_shard_whisper.py
@@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

    # check weights
    if test_config['precision'] == 'fp32':
-        atol, rtol = 5e-4, 5e-4
+        atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
    if stage_manager is None or stage_manager.is_first_stage():
@@ -195,6 +195,15 @@ def run_whisper_test(test_config):
    'precision': 'fp32',
    'initial_scale': 1,
},
{
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
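The whisper file also loosens the fp32 weight-check tolerance from 5e-4 to 1e-3 (first hunk above). A small, self-contained illustration of what that buys, assuming the test helpers compare with `torch.allclose`-style semantics:

```python
import torch

ref = torch.tensor([1.0])
sharded = torch.tensor([1.0015])  # small drift, plausible under pp+tp

# Fails under the old tolerance, passes under the new one:
# old bound: 5e-4 + 5e-4 * 1.0015 ~= 0.0010 < 0.0015
# new bound: 1e-3 + 1e-3 * 1.0015 ~= 0.0020 > 0.0015
assert not torch.allclose(ref, sharded, atol=5e-4, rtol=5e-4)
assert torch.allclose(ref, sharded, atol=1e-3, rtol=1e-3)
```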
