Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer] update gptj model #5503

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions colossalai/shardformer/modeling/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,9 @@ def gptj_model_forward(
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

# position id to be assigned not just for the first stage for attn input
if position_ids is not None:
position_ids = position_ids.view(-1, seq_length)
else:
if position_ids is None:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_ids = position_ids.unsqueeze(0)
if stage_manager.is_first_stage():
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
Expand Down Expand Up @@ -172,21 +170,15 @@ def gptj_model_forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)

return custom_forward

outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
Expand Down Expand Up @@ -603,7 +595,9 @@ def forward(
value = torch.cat((past_value, value), dim=1)

if use_cache is True:
present = (key, value)
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
present = (key.to(hidden_states.dtype), value)
else:
present = None

Expand Down
Loading