[shardformer] update bloom model (#5518)
* update bloom model

* remove the version restriction
wangbluo committed Apr 18, 2024
1 parent e2ff589 commit 7227e81
Showing 2 changed files with 15 additions and 30 deletions.
39 changes: 15 additions & 24 deletions colossalai/shardformer/modeling/bloom.py
@@ -21,7 +21,7 @@
     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
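
Note: the hunk above replaces Bloom's removed `self._prepare_attn_mask` helper with the shared `_prepare_4d_causal_attention_mask` utility from `transformers.modeling_attn_mask_utils`. That utility returns an additive float mask (0.0 where attention is allowed, dtype-min where it is blocked), while this Bloom code path still expects a boolean "is masked" tensor, hence the added `causal_mask = causal_mask.bool()`. A minimal standalone sketch of the call, with illustrative shapes that are not taken from the commit:

```python
# Minimal sketch (not from the commit): how the new mask helper is typically called.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 8, 64
past_key_values_length = 0

attention_mask = torch.ones(batch_size, seq_length + past_key_values_length)  # 1 = token is visible
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)              # stands in for inputs_embeds

# Returns an additive float mask of shape [batch_size, 1, seq_length, seq_length + past],
# with 0.0 at attendable positions and dtype-min at masked ones.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_key_values_length,
)

# Bloom's attention here expects a boolean mask, so the additive mask is
# converted: nonzero (dtype-min, i.e. blocked) -> True.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 8, 8])
```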
@@ -227,21 +228,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
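
Note: the hunk above also drops the hand-rolled `create_custom_forward` wrapper around `torch.utils.checkpoint.checkpoint` in favour of `self._gradient_checkpointing_func`, the hook that recent `transformers` releases attach to the model when gradient checkpointing is enabled, with `use_cache` and `output_attentions` now passed positionally. A rough sketch of the equivalence, using a stand-in `TinyBlock` module (an assumption, not Bloom's real decoder layer):

```python
# Rough sketch of old vs. new checkpointing style; TinyBlock is a hypothetical stand-in.
import functools
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        return self.linear(hidden_states).relu()


block = TinyBlock()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)


# Old style: wrap the module so the extra flags are closed over, then checkpoint the wrapper.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, use_cache=False, output_attentions=False)

    return custom_forward


out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, use_reentrant=False
)

# New style: newer transformers versions install `self._gradient_checkpointing_func`
# (roughly a functools.partial over torch.utils.checkpoint.checkpoint) when
# gradient_checkpointing_enable() is called; flags are forwarded positionally.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)
out_new = _gradient_checkpointing_func(block.__call__, hidden_states, False, False)

assert torch.allclose(out_old, out_new)
```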
@@ -1002,11 +997,13 @@ def forward(

         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1015,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
6 changes: 0 additions & 6 deletions colossalai/shardformer/policies/bloom.py
@@ -24,12 +24,6 @@
 class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."

     def config_sanity_check(self):
         pass
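
Note: the removed assertion pinned Bloom support to transformers <= 4.33.0; it is dropped because the updated forwards now rely on helpers that only exist in newer releases (`_prepare_4d_causal_attention_mask`, `self._gradient_checkpointing_func`). Purely as an illustration of the opposite guard one might write instead (not part of this commit, and the version number is an assumption):

```python
# Illustrative only: the commit removes the upper-bound check entirely rather than
# replacing it. "4.36.0" below is an assumed minimum, not taken from the diff.
import transformers
from packaging.version import Version

assert Version(transformers.__version__) >= Version(
    "4.36.0"
), "The updated Bloom forwards assume a recent transformers release."
```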