[Shardformer] Add parallel output for shardformer models(bloom, falcon) #5702

Merged · 8 commits · May 21, 2024
15 changes: 9 additions & 6 deletions colossalai/shardformer/layer/loss.py
@@ -22,6 +22,7 @@ def forward(
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
@@ -34,7 +35,7 @@ def forward(
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]

Returns:
@@ -86,7 +87,7 @@ def forward(
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

# calculate the loss
@@ -97,9 +98,10 @@ def forward(
loss = torch.sum(loss).div_(num_non_zero)

# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype

return loss

@@ -114,11 +116,11 @@ def backward(ctx, grad_output):
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

update = 1.0 - mask.view(-1).float()
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None, None
return grad_logits, None, None, None, None, None


def cross_entropy_1d(
@@ -127,5 +129,6 @@ def cross_entropy_1d(
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
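
For context on the change above: the new `dtype` argument threads the model's compute dtype into `DistCrossEntropy`, so the softmax probabilities saved for backward (and therefore the returned gradients) match the model's precision, while the reduction of the softmax denominator is kept in float32. The sketch below is illustrative only and not part of the diff; it mirrors the changed lines on a single device rather than the full distributed kernel:

```python
import torch

def softmax_with_mixed_precision(vocab_logits: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Toy illustration: accumulate the softmax denominator in fp32, return probs in `dtype`."""
    exp_logits = torch.exp(vocab_logits)                                  # stays in the logits' dtype
    sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)   # fp32 accumulation, as added above
    return exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)     # cast back for the backward pass

logits = torch.randn(4, 16, dtype=torch.bfloat16)
probs = softmax_with_mixed_precision(logits, dtype=torch.bfloat16)
assert probs.dtype == torch.bfloat16  # gradients derived from these probs stay in the model dtype
```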
100 changes: 95 additions & 5 deletions colossalai/shardformer/modeling/bloom.py
@@ -10,6 +10,7 @@
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
@@ -27,6 +28,8 @@
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d

logger = logging.get_logger(__name__)


@@ -354,7 +357,7 @@ def bloom_for_causal_lm_forward(
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).contiguous()

loss = None
if labels is not None:
@@ -365,10 +368,21 @@ def bloom_for_causal_lm_forward(
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels.view(-1))

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -1065,3 +1079,79 @@ def forward(
)

return forward


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import BloomForCausalLM

def forward(
self: BloomForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
past_key_values = None
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

return forward
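
The new `get_lm_forward_with_dist_cross_entropy` helper above shifts and flattens the vocab-sharded `lm_head` logits and hands them to `cross_entropy_1d` together with the full vocabulary size and the transformer dtype. The snippet below shows that call in isolation, using a single-rank gloo group so the all-reduces inside the kernel are no-ops; the sizes are toy values and the sketch is illustrative, not part of the diff:

```python
import os
import torch
import torch.distributed as dist
from colossalai.shardformer.layer import cross_entropy_1d

# Single-rank process group: the collectives inside cross_entropy_1d become no-ops.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

vocab_size = 512                                    # full (unsharded) vocabulary size, toy value
lm_logits = torch.randn(2, 8, vocab_size)           # this rank's vocabulary shard of the logits
labels = torch.randint(0, vocab_size, (2, 8))

# Shift so that tokens < n predict n, then flatten, as in the forward above.
shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.shape[-1])
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = cross_entropy_1d(
    shift_logits,
    shift_labels,
    process_group=dist.group.WORLD,   # stands in for shard_config.tensor_parallel_process_group
    vocab_size=vocab_size,            # mirrors self.lm_head.out_features
    dtype=torch.float32,              # mirrors self.transformer.dtype (bf16/fp16 in real runs)
)
print(loss)
```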
99 changes: 96 additions & 3 deletions colossalai/shardformer/modeling/falcon.py
@@ -14,6 +14,7 @@
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
@@ -31,6 +32,8 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d


def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
@@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward(
loss = None
if labels is not None:
# Shift so that tokens < n predict n
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = shift_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length),
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -747,3 +764,79 @@ def falcon_for_question_answering_forward(
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import FalconForCausalLM

def forward(
self: FalconForCausalLM,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
past_key_values = None
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
new_vocab_size = shift_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

return forward
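
The Falcon helper mirrors the BLOOM one: `get_lm_forward_with_dist_cross_entropy(shard_config)` returns an unbound forward that expects the model instance as `self`. Shardformer's policies perform the actual wiring; the hedged sketch below only illustrates the factory's calling convention and is not part of the diff:

```python
import types
from transformers import FalconForCausalLM
from colossalai.shardformer.shard import ShardConfig
from colossalai.shardformer.modeling.falcon import get_lm_forward_with_dist_cross_entropy

def patch_falcon_lm_forward(model: FalconForCausalLM, shard_config: ShardConfig) -> FalconForCausalLM:
    """Bind the distributed-cross-entropy forward to an already-sharded Falcon model.

    Assumes shard_config.tensor_parallel_process_group is initialized and the model's
    lm_head has been column-sharded so its outputs stay split along the vocab dimension.
    """
    model.forward = types.MethodType(get_lm_forward_with_dist_cross_entropy(shard_config), model)
    return model
```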
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
@@ -389,6 +389,7 @@ def gpt2_lmhead_model_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
@@ -1294,6 +1295,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)

if not return_dict:
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -332,6 +332,7 @@ def llama_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -768,6 +769,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)

if not return_dict:
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/mistral.py
@@ -281,6 +281,7 @@ def mistral_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -701,6 +702,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)

if not return_dict:
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/opt.py
@@ -348,6 +348,7 @@ def opt_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
else:
loss_fct = CrossEntropyLoss()
@@ -988,6 +989,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)

if not return_dict:
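
Across the models touched here (BLOOM, Falcon, GPT-2, LLaMA, Mistral, OPT), the sharded-loss path is guarded by `shard_config.enable_tensor_parallelism` together with `shard_config.parallel_output`; the `else` branches shown in the diffs fall back to the gathered-logits `CrossEntropyLoss`. A hedged configuration sketch follows — the field names come from the reads in this diff, but the exact `ShardConfig` constructor is an assumption and should be checked against `colossalai/shardformer/shard`:

```python
import torch.distributed as dist
from colossalai.shardformer.shard import ShardConfig

# Assumption: torch.distributed is already initialized and every rank in the default
# group participates in tensor parallelism; real setups build a dedicated TP subgroup.
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    enable_tensor_parallelism=True,
    parallel_output=True,  # keep lm_head outputs vocab-sharded and compute the loss with cross_entropy_1d
)
```

With `parallel_output=False`, the forwards keep the existing behavior of gathering logits and applying the plain `CrossEntropyLoss` branch shown above.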