From 1283bf023a9ea0f3dafa7a550b3abd7c68044a31 Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Wed, 27 Mar 2024 11:27:44 +0800
Subject: [PATCH 1/2] update bloom model

---
 colossalai/shardformer/modeling/bloom.py | 39 +++++++++---------------
 colossalai/shardformer/policies/bloom.py |  4 +--
 2 files changed, 17 insertions(+), 26 deletions(-)

diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index d94c30d29e71..5c4e8d1cb703 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -21,7 +21,7 @@
     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config.enable_sequence_parallelism:
@@ -226,21 +227,15 @@
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1000,11 +995,13 @@ def forward(
 
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1016,21 +1013,15 @@
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index eddfafdcbcdc..741dea32539b 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -27,8 +27,8 @@ def __init__(self) -> None:
         from packaging.version import Version
 
         assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."
+            "4.36.0"
+        ), "The Bloom model should run on a transformers version not greater than 4.36.0."
 
     def config_sanity_check(self):
         pass

From 0562c97027284f417201e43e87913227bfdfe17d Mon Sep 17 00:00:00 2001
From: Wang Binluo <2538539015@qq.com>
Date: Wed, 27 Mar 2024 16:38:38 +0800
Subject: [PATCH 2/2] remove the version restriction

---
 colossalai/shardformer/policies/bloom.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 741dea32539b..4fb03c83051f 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -23,12 +23,6 @@
 class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.36.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.36.0."
 
     def config_sanity_check(self):
         pass
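
Note (illustration, not part of the patch): the hunks above track the Bloom implementation in newer transformers releases, where BloomModel._prepare_attn_mask was removed in favor of the shared _prepare_4d_causal_attention_mask helper and gradient_checkpointing_enable() attaches self._gradient_checkpointing_func to the model, so the hand-rolled create_custom_forward closure around torch.utils.checkpoint.checkpoint is no longer needed. A minimal standalone sketch of the new mask-building call follows; the tensor shapes are made up for illustration, and it assumes a transformers version (>= 4.35) that ships the helper.

    # Standalone sketch, assuming transformers >= 4.35.
    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_length, hidden_size = 2, 8, 64
    past_key_values_length = 0

    # 2D padding mask: 1 = token can be attended to, 0 = padding.
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
    hidden_states = torch.randn(batch_size, seq_length, hidden_size)

    # Additive 4D mask of shape (batch, 1, seq_length, seq_length + past_key_values_length):
    # 0.0 where attention is allowed, a large negative value where it is disallowed.
    causal_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=hidden_states,
        past_key_values_length=past_key_values_length,
    )

    # Bloom's attention consumes a boolean mask (True = masked), hence the extra cast in the patch.
    causal_mask = causal_mask.bool()
    print(causal_mask.shape)  # torch.Size([2, 1, 8, 8])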