[shardformer] update opt model #5522

Merged 1 commit on Apr 3, 2024
75 changes: 26 additions & 49 deletions colossalai/shardformer/modeling/opt.py
@@ -1,6 +1,5 @@
import random
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
@@ -16,7 +15,7 @@
OPTModel,
)
from transformers.utils import logging

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from colossalai.pipeline.stage_manager import PipelineStageManager


@@ -25,33 +24,7 @@ class OPTPipelineForwards:
This class serves as a micro library for forward function substitution of OPT models
under pipeline setting.
"""

@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
from transformers.models.opt.modeling_opt import _make_causal_mask

combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
_dtype,
device,
past_key_values_length=past_key_values_length,
)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)

return combined_attention_mask


@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
@@ -120,6 +93,7 @@ def opt_model_forward(
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
_dtype = inputs_embeds.dtype
hidden_states = inputs_embeds

else:
if hidden_states is None:
@@ -133,17 +107,26 @@
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
if self.decoder._use_flash_attention_2:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)

causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
attention_mask, input_shape, _dtype, device, past_key_values_length
)

if stage_manager.is_first_stage():
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
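The hunk above drops the hand-rolled _prepare_decoder_attention_mask helper in favour of the mask handling that ships with recent transformers releases: when flash attention 2 is in use the 2D padding mask is passed through unchanged, otherwise _prepare_4d_causal_attention_mask builds the additive 4D causal mask. A minimal sketch of what that upstream helper returns, assuming a transformers version (roughly 4.35 or newer) that exposes it from transformers.modeling_attn_mask_utils:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size, past_key_values_length = 2, 4, 8, 0
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)
# 2D padding mask: 1 = attend, 0 = padding (second sequence has one padded token)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# Expands the 2D mask into the additive 4D causal mask of shape
# (batch, 1, tgt_seq_len, src_seq_len) expected by eager attention,
# replacing the removed _make_causal_mask / _expand_mask combination.
causal_attention_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(causal_attention_mask.shape)  # torch.Size([2, 1, 4, 4])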
@@ -202,20 +185,14 @@ def opt_model_forward(
past_key_value = past_key_values[idx] if past_key_values is not None else None

if decoder.gradient_checkpointing and decoder.training:

def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
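The final hunk of this file replaces the local create_custom_forward closure and direct torch.utils.checkpoint.checkpoint call with self._gradient_checkpointing_func, the callable that recent transformers releases attach to a model when gradient_checkpointing_enable() is used. A rough sketch of the two patterns, using a hypothetical toy layer and assuming the new callable defaults to plain torch.utils.checkpoint.checkpoint:

import torch
import torch.utils.checkpoint

# Hypothetical stand-in for an OPT decoder layer, only to illustrate the call shapes.
class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        return (self.proj(hidden_states),)

layer = ToyLayer()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)

# Old pattern: a closure is needed so the extra arguments reach the layer.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)
    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), hidden_states, None, use_reentrant=False
)

# New pattern: the layer's __call__ and all of its arguments go straight to the
# checkpointing callable (assumed here to be torch.utils.checkpoint.checkpoint,
# which is what gradient_checkpointing_enable() installs by default).
gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
out_new = gradient_checkpointing_func(
    layer.__call__, hidden_states, None, False, use_reentrant=False
)

assert torch.allclose(out_old[0], out_new[0])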
6 changes: 0 additions & 6 deletions colossalai/shardformer/policies/opt.py
@@ -24,12 +24,6 @@
class OPTPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version

assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The OPT model should run on a transformers version not greater than 4.33.0."

def config_sanity_check(self):
pass
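With the forward functions now written against the newer transformers mask and checkpointing APIs, the policy drops its old requirement that transformers be no newer than 4.33.0. If a guard is still wanted, the useful check becomes a lower bound instead; a minimal sketch, with the exact minimum version being an assumption rather than something stated in this PR:

import transformers
from packaging.version import Version

# Assumed lower bound: _prepare_4d_causal_attention_mask and
# _gradient_checkpointing_func appeared around transformers 4.35/4.36;
# adjust to whatever version the updated forwards are actually tested against.
MIN_TRANSFORMERS = Version("4.36.0")

if Version(transformers.__version__) < MIN_TRANSFORMERS:
    raise RuntimeError(
        f"transformers {transformers.__version__} is older than {MIN_TRANSFORMERS}; "
        "the updated OPT shardformer forwards expect the newer attention-mask utilities."
    )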