From 4b7b83a0860e24eb17b57221a3be09ca83e7393c Mon Sep 17 00:00:00 2001 From: linsj20 Date: Wed, 20 Mar 2024 15:45:25 +0800 Subject: [PATCH] sequence parallel: inside text split (#6) --- .../booster/plugin/hybrid_parallel_plugin.py | 6 +---- colossalai/shardformer/layer/_operation.py | 12 ++++----- colossalai/shardformer/modeling/gpt2.py | 19 +++++--------- colossalai/shardformer/modeling/llama.py | 26 ++++++++----------- colossalai/shardformer/policies/gpt2.py | 4 ++- colossalai/shardformer/policies/llama.py | 2 +- colossalai/shardformer/shard/shard_config.py | 1 + colossalai/zero/low_level/low_level_optim.py | 5 +--- tests/test_shardformer/test_model/_utils.py | 13 +--------- 9 files changed, 31 insertions(+), 57 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f0029ab5dd6e..5c1c363d2593 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -660,10 +660,6 @@ def __init__( self.dp_pg = dp_process_group self.tp_pg = tp_process_group self.pp_pg = pp_process_group - self.use_all_to_all_sequence_parallel = ( - self.model.shard_config.enable_sequence_parallelism - and self.model.shard_config.sequence_parallelism_mode == "all_to_all" - ) if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( @@ -684,7 +680,6 @@ def __init__( cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, - enable_sequence_parallel=self.use_all_to_all_sequence_parallel, ) def sync_dp_grads(self): @@ -1098,6 +1093,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, sequence_parallelism_mode=sequence_parallelism_mode, enable_sequence_overlap=enable_sequence_overlap, + zero_stage=zero_stage, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7c9e71318cc9..4dfa5a4131b5 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -756,7 +756,7 @@ def backward(ctx, grad_output): grad_output = grad_output * dist.get_world_size(ctx.process_group) elif ctx.grad_scale == "down": grad_output = grad_output / dist.get_world_size(ctx.process_group) - return _gather(grad_output, ctx.dim, ctx.process_group), None, None + return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None class _ReduceForward(torch.autograd.Function): @@ -819,7 +819,7 @@ def backward(ctx, grad_output): grad_output = grad_output * dist.get_world_size(ctx.process_group) elif ctx.grad_scale == "down": grad_output = grad_output / dist.get_world_size(ctx.process_group) - return _split(grad_output, ctx.dim, ctx.process_group), None, None + return _split(grad_output, ctx.dim, ctx.process_group), None, None, None class _AllToAll(torch.autograd.Function): @@ -1020,12 +1020,12 @@ def matmul_gather_forward_reducescatter_backward( ) -def gather_forward_split_backward(input_, dim, process_group): - return _GatherForwardSplitBackward.apply(input_, dim, process_group) +def gather_forward_split_backward(input_, dim, process_group, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale) -def split_forward_gather_backward(input_, dim, process_group): - return _SplitForwardGatherBackward.apply(input_, dim, process_group) +def split_forward_gather_backward(input_, dim, process_group, grad_scale=None): + return 
_SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale) def reduce_forward(input_, process_group): diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 47c8a10a2c0a..5342274d7937 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -853,8 +853,6 @@ def forward( # use variable seq_len to replace input_shape[-1] seq_len = input_shape[-1] - if sp_mode in ["ring", "all_to_all"]: - seq_len *= sp_size if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, seq_len) @@ -866,8 +864,6 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) - if sp_mode in ["ring", "all_to_all"]: - past_length *= sp_size if position_ids is None: position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0).view(-1, seq_len) @@ -876,9 +872,6 @@ def forward( if sp_mode in ["ring", "all_to_all"]: position_ids = torch.chunk(position_ids.clone(), sp_size, dim=1)[dist.get_rank(sp_group)] - if sp_mode in ["ring", "all_to_all"]: - attention_mask = _gather(attention_mask, 1, sp_group) - # GPT2Attention mask. if attention_mask is not None: if batch_size <= 0: @@ -917,12 +910,12 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - if sp_mode in ["ring"]: - input_ids = _gather(input_ids, 1, sp_group) - inputs_embeds = self.wte(input_ids) - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - else: - inputs_embeds = self.wte(input_ids) + inputs_embeds = self.wte(input_ids) + if sp_mode == "ring": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') + position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 6ad3f4befd14..d9ec4a3474fd 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -694,7 +694,7 @@ def forward( return forward -def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group): +def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, zero_stage=0): logger = logging.get_logger(__name__) # Copied from transformers.models.bart.modeling_bart._make_causal_mask @@ -804,10 +804,6 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # sp: modify seq_length when using sequence parallel - if sp_mode in ["ring", "all_to_all"]: - seq_length *= sp_size - seq_length_with_past = seq_length past_key_values_length = 0 @@ -827,13 +823,12 @@ def forward( position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: - if sp_mode == "ring": - input_ids = _gather(input_ids, 1, sp_group) - inputs_embeds = self.embed_tokens(input_ids) - input_ids = input_ids.chunk(sp_size, dim=1)[torch.distributed.get_rank(sp_group)] - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) - else: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = 
split_forward_gather_backward(inputs_embeds, 1, sp_group, 'down') # TODO use_distributed_mask use_distributed_mask = True if sp_mode in ["ring", "all_to_all"] else False @@ -864,8 +859,6 @@ def forward( attention_mask = _gather(attention_mask, 1, sp_group) hidden_states = inputs_embeds - if sp_mode == "split_gather": - hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: @@ -922,7 +915,10 @@ def custom_forward(*inputs): hidden_states = self.norm(hidden_states) # Todo: Maybe this line can be optimized - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") + if sp_mode == "ring" or sp_mode == "split_gather" or (sp_mode == "all_to_all" and zero_stage == 0): + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all" and zero_stage in [1, 2]: + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale="up") # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index dcfe5889f36b..d1437430a320 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -109,7 +109,9 @@ def module_policy(self): SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, + kwargs={ + "seq_parallel_mode": sp_mode, + }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1e6e568bba21..36d4d4f31383 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -121,7 +121,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) self.append_or_create_method_replacement( description={ - "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group), + "forward": get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group, self.shard_config.zero_stage), }, policy=policy, target_key=LlamaModel, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index eef1092280c4..982b6edf7e0e 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -37,6 +37,7 @@ class ShardConfig: enable_jit_fused: bool = False enable_sequence_parallelism: bool = False sequence_parallelism_mode: str = None + zero_stage: int = 0 enable_sequence_overlap: bool = False extra_kwargs: Dict[str, Any] = field(default_factory=dict) # pipeline_parallel_size: int diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 70f366dd716e..bbbaf13b53ef 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -77,10 +77,8 @@ def __init__( forced_dtype: Optional[torch.dtype] = None, moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights - enable_sequence_parallel: bool = False, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) - self._enable_sequence_parallel = enable_sequence_parallel self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() @@ -300,8 +298,7 @@ def 
_run_reduction(self): if self.moe_extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() - if not self._enable_sequence_parallel: - flat_grads /= self._world_size + flat_grads /= self._world_size else: # record moe and non moe param moe_list = [] diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a5b47e3e52fb..ff08fed3b90f 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -159,18 +159,7 @@ def _criterion(outputs, inputs): shard_test_data = {} for k, v in data.items(): - if k not in ["input_ids", "attention_mask"]: - shard_test_data[k] = data[k].clone() - else: - # todo: check the correctness of using dim=-1: to be compatible with date_gen_for_double_heads() - shard_test_data[k] = ( - torch.chunk(data[k].clone(), booster.plugin.shard_config.sequence_parallel_size, dim=-1)[ - dist.get_rank(booster.plugin.shard_config.sequence_parallel_process_group) - ] - if booster.plugin.shard_config.enable_sequence_parallelism - and booster.plugin.shard_config.sequence_parallelism_mode in ["ring", "all_to_all"] - else data[k].clone() - ) + shard_test_data[k] = data[k].clone() unshard_test_data = {} for k, v in data.items(): unshard_test_data[k] = data[k].clone()
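
A minimal, self-contained sketch (not the ColossalAI implementation) of the pattern this patch switches the GPT-2 and LLaMA forwards to: the embedding runs on the full sequence, the activation is then chunked along the sequence dimension so each rank keeps seq_len / sp_size tokens, and the backward pass all-gathers the gradient shards, optionally rescaling them via the new grad_scale argument. The toy model, port, and worker function below are hypothetical illustrations; only standard torch APIs are used, and the trailing Nones in backward mirror the arity fix made to _SplitForwardGatherBackward in this patch.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, process_group, grad_scale=None):
        ctx.dim, ctx.process_group, ctx.grad_scale = dim, process_group, grad_scale
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        # keep only this rank's slice of the sequence dimension
        return input_.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.process_group)
        if ctx.grad_scale == "up":
            grad_output = grad_output * world_size
        elif ctx.grad_scale == "down":
            grad_output = grad_output / world_size
        # gather the per-rank gradient shards back into the full sequence gradient
        shards = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(shards, grad_output.contiguous(), group=ctx.process_group)
        # backward must return one value per forward argument, hence the three
        # trailing Nones (the same arity fix applied in _operation.py above)
        return torch.cat(shards, dim=ctx.dim), None, None, None


def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


def _worker(rank, world_size):
    # hypothetical two-rank CPU setup; address/port are placeholders
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            init_method="tcp://127.0.0.1:29511")
    group = dist.group.WORLD
    torch.manual_seed(0)
    embed = torch.nn.Embedding(100, 8)
    input_ids = torch.randint(0, 100, (2, 16))   # every rank holds the full sequence
    hidden = embed(input_ids)                     # embed before splitting ("inside text split")
    hidden = split_forward_gather_backward(hidden, 1, group, grad_scale="down")
    assert hidden.shape[1] == 16 // world_size    # each rank keeps seq_len / sp_size tokens
    hidden.sum().backward()                       # gradient is gathered and scaled down
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)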