diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index fe70376e144d..2b2bf89a06e2 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -21,7 +21,7 @@
     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
@@ -227,21 +228,15 @@
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1002,11 +997,13 @@ def forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
 
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1015,15 @@
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 953592abc16a..4894bda35bfc 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -24,12 +24,6 @@ class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass