diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index d5f00fc9f565..93da71abb4a2 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention +from ..layer import ColoAttention, cross_entropy_1d logger = logging.get_logger(__name__) @@ -270,11 +270,21 @@ def mistral_for_causal_lm_forward( shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] @@ -609,3 +619,100 @@ def forward( return attn_output, None, past_key_value return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import MistralForCausalLM + + def forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 81521c30b1a2..5282e2eaac22 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -22,6 +22,8 @@ from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + logger = logging.get_logger(__name__) @@ -336,8 +338,22 @@ def opt_for_causal_lm_forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -844,3 +860,146 @@ def forward( return outputs return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + def forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 984b71646318..621982f29058 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -18,6 +18,7 @@ from ..modeling.mistral import ( MistralForwards, + get_lm_forward_with_dist_cross_entropy, get_mistral_flash_attention_forward, get_mistral_model_forward_for_flash_attn, ) @@ -275,14 +276,18 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=VocabParallelLMHead1D, - kwargs=dict( - gather_output=True, - make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, - ), + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } + if self.shard_config.parallel_output: + new_item[MistralForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } else: new_item = { MistralForCausalLM: ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9619b3d41b8a..524d2b8cd0c3 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -21,6 +21,7 @@ from ..modeling.opt import ( OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, + get_lm_forward_with_dist_cross_entropy, get_opt_decoder_forward_for_flash_attention, get_opt_flash_attention_forward, ) @@ -269,12 +270,18 @@ def module_policy(self): suffix="lm_head", target_module=VocabParallelLMHead1D, kwargs=dict( - gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, ), ), policy=policy, target_key=OPTForCausalLM, ) + if self.shard_config.parallel_output: + method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=OPTForCausalLM + ) else: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(