Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 7, 2024
1 parent d531ba6 commit e104b79
Showing 3 changed files with 14 additions and 16 deletions.
18 changes: 10 additions & 8 deletions colossalai/shardformer/modeling/gptj.py
@@ -22,6 +22,7 @@
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
+
 from ..layer import cross_entropy_1d

 logger = logging.get_logger(__name__)
@@ -193,9 +194,9 @@ def gptj_model_forward(

         # Ensure that attention_mask is always on the same device as hidden_states
         if attention_mask is not None:
-            if 'attention_mask' in attention_mask:
-                if isinstance(attention_mask['attention_mask'], torch.Tensor):
-                    attention_mask['attention_mask'] = attention_mask['attention_mask'].to(hidden_states.device)
+            if "attention_mask" in attention_mask:
+                if isinstance(attention_mask["attention_mask"], torch.Tensor):
+                    attention_mask["attention_mask"] = attention_mask["attention_mask"].to(hidden_states.device)
             else:
                 attention_mask = attention_mask.to(hidden_states.device)
         if isinstance(head_mask, torch.Tensor):
@@ -646,9 +647,11 @@ def forward(
         # ensure qkv have the same dtype, hidden states' dtype
         query = query.to(value.dtype)
         key = key.to(value.dtype)
-        if 'attention_mask' in attention_mask:
-            if isinstance(attention_mask['attention_mask'], torch.Tensor):
-                attn_output = ColoAttention.attention(query, key, value, attention_mask['attention_mask'], dropout_p=dropout_p)
+        if "attention_mask" in attention_mask:
+            if isinstance(attention_mask["attention_mask"], torch.Tensor):
+                attn_output = ColoAttention.attention(
+                    query, key, value, attention_mask["attention_mask"], dropout_p=dropout_p
+                )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p)
         attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
@@ -1024,7 +1027,6 @@ def custom_forward(*inputs):


 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
-
     def forward(
         self: GPTJForCausalLM,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1106,4 +1108,4 @@ def forward(
         attentions=transformer_outputs.attentions,
     )

-    return forward
\ No newline at end of file
+    return forward
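
The reformatted attention hunk above is style-only: the code dispatches on whether attention_mask arrives as a dict carrying a tensor under the key "attention_mask" or as a dict of kwargs to unpack into the attention call. Below is a minimal standalone sketch of that branching, not ColossalAI code: toy_attention and dispatch_attention are illustrative stand-ins for ColoAttention.attention, and the two nested ifs are collapsed into one condition for brevity.

# Standalone sketch, not ColossalAI code.
import torch
import torch.nn.functional as F


def toy_attention(query, key, value, attention_mask=None, dropout_p=0.0, **kwargs):
    # Stand-in for ColoAttention.attention: plain scaled-dot-product attention.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=dropout_p)


def dispatch_attention(query, key, value, attention_mask, dropout_p=0.0):
    # If the dict carries a ready-made mask tensor, pass it positionally;
    # otherwise unpack the whole kwargs dict, as the diffed code does.
    if "attention_mask" in attention_mask and isinstance(attention_mask["attention_mask"], torch.Tensor):
        return toy_attention(query, key, value, attention_mask["attention_mask"], dropout_p=dropout_p)
    return toy_attention(query, key, value, **attention_mask, dropout_p=dropout_p)


q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
print(dispatch_attention(q, k, v, {"attention_mask": None}).shape)  # torch.Size([1, 2, 4, 8])
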
2 changes: 1 addition & 1 deletion colossalai/shardformer/modeling/opt.py
@@ -988,7 +988,7 @@ def forward(
                 process_group=shard_config.tensor_parallel_process_group,
                 vocab_size=self.lm_head.out_features,
             )
-
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
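
The context lines around this whitespace fix show cross_entropy_1d being called with a tensor-parallel process group and the sharded vocab size. The sketch below is a single-process illustration of the idea, not the cross_entropy_1d implementation (which combines the terms with collective ops over the process group): when the LM head is vocab-parallel, each rank holds only a slice of the vocabulary, so the softmax normalizer and the target logit must be combined across slices.

# Single-process sketch, not the cross_entropy_1d implementation.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, num_shards = 8, 2
logits = torch.randn(3, vocab_size)            # full logits, kept whole here for reference
targets = torch.tensor([1, 5, 6])

reference = F.cross_entropy(logits, targets)   # ordinary cross entropy over the full vocab

# Pretend each chunk lives on a different tensor-parallel rank.
chunks = logits.chunk(num_shards, dim=-1)
global_max = torch.stack([c.max(dim=-1).values for c in chunks]).max(dim=0).values
sum_exp = sum((c - global_max[:, None]).exp().sum(dim=-1) for c in chunks)
log_z = global_max + sum_exp.log()             # log of the global softmax denominator

shard_width = vocab_size // num_shards
target_logit = torch.stack(
    [chunks[t // shard_width][i, t % shard_width] for i, t in enumerate(targets.tolist())]
)
sharded = (log_z - target_logit).mean()

print(reference.item(), sharded.item())        # the two losses should match closely
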
10 changes: 3 additions & 7 deletions colossalai/shardformer/policies/gptj.py
@@ -9,8 +9,8 @@
 from ..modeling.gptj import (
     GPTJPipelineForwards,
     get_gptj_flash_attention_forward,
+    get_lm_forward_with_dist_cross_entropy,
     gptj_model_forward_for_flash_attention,
-    get_lm_forward_with_dist_cross_entropy
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -275,13 +275,9 @@ def module_policy(self):
                 )
             }
         if self.shard_config.parallel_output:
-            method_replacement = {
-                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
-            }
+            method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
             self.append_or_create_method_replacement(
-                description=method_replacement,
-                policy=policy,
-                target_key=GPTJForCausalLM
+                description=method_replacement, policy=policy, target_key=GPTJForCausalLM
             )

         else:
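
The policy hunk above only reflows the method-replacement dict and the call to append_or_create_method_replacement. As a rough illustration of what such a replacement does, here is a toy sketch, not the ColossalAI policy API: TinyCausalLM and the closure below are hypothetical, standing in for GPTJForCausalLM and get_lm_forward_with_dist_cross_entropy, whose returned forward computes the loss with cross_entropy_1d when parallel_output is enabled.

# Illustrative sketch only, not the ColossalAI policy API.
from types import MethodType


class TinyCausalLM:
    def forward(self, input_ids):
        return f"stock forward on {input_ids}"


def get_forward_with_dist_cross_entropy(shard_config):
    # In the real code the returned forward computes the loss with cross_entropy_1d,
    # using shard_config.tensor_parallel_process_group and the sharded vocab size.
    def forward(self, input_ids):
        return f"parallel-output forward on {input_ids} (config={shard_config})"

    return forward


model = TinyCausalLM()
parallel_output = True
if parallel_output:
    # Bind the replacement forward to this instance, mirroring the policy's method replacement.
    model.forward = MethodType(get_forward_with_dist_cross_entropy("toy-shard-config"), model)
print(model.forward("tokens"))
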

0 comments on commit e104b79
