diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index d0e267eacd25..095c8c715f84 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -1,6 +1,5 @@
 import random
 from typing import List, Optional, Tuple, Union
-
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
@@ -16,7 +15,7 @@
     OPTModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
 
@@ -25,33 +24,7 @@ class OPTPipelineForwards:
     This class serves as a micro library for forward function substitution of OPT models
     under pipeline setting.
     """
-
-    @staticmethod
-    def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        from transformers.models.opt.modeling_opt import _make_causal_mask
-
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                _dtype,
-                device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
-                device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
+
     @staticmethod
     def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
         """
@@ -120,6 +93,7 @@ def opt_model_forward(
                 inputs_embeds = decoder.project_in(inputs_embeds)
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             _dtype = inputs_embeds.dtype
+            hidden_states = inputs_embeds
 
         else:
             if hidden_states is None:
@@ -133,17 +107,26 @@
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values_length + seq_length
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-        elif attention_mask.shape[1] != mask_seq_length:
-            raise ValueError(
-                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
-                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+        if self.decoder._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            attention_mask = (
+                torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+                if attention_mask is None
+                else attention_mask
+            )
+        else:
+            # 4d mask is passed through the layers
+            if attention_mask is None:
+                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+            elif attention_mask.shape[1] != mask_seq_length:
+                raise ValueError(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+                )
+            causal_attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, hidden_states, past_key_values_length
             )
-
-        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
-            attention_mask, input_shape, _dtype, device, past_key_values_length
-        )
 
         if stage_manager.is_first_stage():
             pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
@@ -202,20 +185,14 @@ def opt_model_forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if decoder.gradient_checkpointing and decoder.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     causal_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index a542808ba794..c7b9853e5c37 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -24,12 +24,6 @@
 class OPTPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The OPT model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
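
Note on the mask change above: the non-flash-attention branch now delegates causal-mask construction to _prepare_4d_causal_attention_mask from transformers.modeling_attn_mask_utils instead of the removed _prepare_decoder_attention_mask helper. Below is a minimal, standalone sketch of that call; it assumes a transformers release that ships modeling_attn_mask_utils (roughly v4.35+), and the tensor sizes are made up purely for illustration.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative sizes only: 2 sequences, 5 new tokens, 3 cached (past) tokens.
batch_size, seq_length, past_key_values_length = 2, 5, 3
inputs_embeds = torch.randn(batch_size, seq_length, 16)  # its dtype/device drive the mask's dtype/device
attention_mask = torch.ones(batch_size, past_key_values_length + seq_length)  # 2d padding mask over past + current

causal_attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(causal_attention_mask.shape)  # torch.Size([2, 1, 5, 8]) -> [bsz, 1, tgt_seq_len, src_seq_len]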