From 9efc79ef24c48882e6a553fe329d5755b6382112 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 30 Apr 2024 08:10:20 +0000 Subject: [PATCH 1/7] add parallel output for mistral model --- colossalai/shardformer/modeling/mistral.py | 119 ++++++++++++++++++++- colossalai/shardformer/policies/mistral.py | 14 ++- 2 files changed, 126 insertions(+), 7 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index d5f00fc9f565..642fa3b40192 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -16,7 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention +from ..layer import ColoAttention, cross_entropy_1d logger = logging.get_logger(__name__) @@ -270,11 +270,22 @@ def mistral_for_causal_lm_forward( shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) + #shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] @@ -609,3 +620,105 @@ def forward( return attn_output, None, past_key_value return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import MistralForCausalLM + + def forward( + self: MistralForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 984b71646318..362fd11e5b66 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -20,6 +20,7 @@ MistralForwards, get_mistral_flash_attention_forward, get_mistral_model_forward_for_flash_attn, + get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -275,14 +276,19 @@ def module_policy(self): SubModuleReplacementDescription( suffix="lm_head", target_module=VocabParallelLMHead1D, - kwargs=dict( - gather_output=True, - make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, - ), + kwargs={ + #gather_output=True, + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + }, ) ] ) } + if self.shard_config.parallel_output: + new_item[MistralForCausalLM].method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } else: new_item = { MistralForCausalLM: ModulePolicyDescription( From 26329163291e2f8142494bece2db5548aef469c9 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 1 May 2024 09:23:43 +0000 Subject: [PATCH 2/7] remove useless code --- colossalai/shardformer/modeling/mistral.py | 1 - colossalai/shardformer/policies/mistral.py | 1 - 2 files changed, 2 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 642fa3b40192..796aeca51a57 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -270,7 +270,6 @@ def mistral_for_causal_lm_forward( shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - #shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 362fd11e5b66..936fd2d249de 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -277,7 +277,6 @@ def module_policy(self): suffix="lm_head", target_module=VocabParallelLMHead1D, kwargs={ - #gather_output=True, "gather_output": not self.shard_config.parallel_output, "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, }, From 88f057ce7c527b20c9edd798f49b3ff92d7dc252 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 May 2024 07:03:46 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/policies/mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index 936fd2d249de..621982f29058 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -18,9 +18,9 @@ from ..modeling.mistral import ( MistralForwards, + get_lm_forward_with_dist_cross_entropy, get_mistral_flash_attention_forward, get_mistral_model_forward_for_flash_attn, - get_lm_forward_with_dist_cross_entropy, ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription From 108ddfb795c595992265b400fc30eb6c0543e73d Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 3 May 2024 08:58:00 +0000 Subject: [PATCH 4/7] add parallel_output for the opt model --- colossalai/shardformer/modeling/opt.py | 165 ++++++++++++++++++++++++- colossalai/shardformer/policies/opt.py | 13 +- 2 files changed, 174 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 81521c30b1a2..227042480bc5 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -21,7 +21,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig - +from ..layer import cross_entropy_1d logger = logging.get_logger(__name__) @@ -336,8 +336,22 @@ def opt_for_causal_lm_forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if shard_config.enable_tensor_parallelism and shard_config.parallel_output: + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + shift_labels = shift_labels.view(-1) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + else: + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -844,3 +858,148 @@ def forward( return outputs return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + def forward( + self: OPTForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + new_vocab_size = logits.shape[-1] + shift_logits = shift_logits.view(-1, new_vocab_size) + loss = cross_entropy_1d( + shift_logits, + shift_labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=self.lm_head.out_features, + ) + #loss_fct = CrossEntropyLoss() + #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return forward \ No newline at end of file diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 9619b3d41b8a..bb094d25aa9f 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -23,6 +23,7 @@ get_jit_fused_opt_decoder_layer_forward, get_opt_decoder_forward_for_flash_attention, get_opt_flash_attention_forward, + get_lm_forward_with_dist_cross_entropy ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -269,12 +270,22 @@ def module_policy(self): suffix="lm_head", target_module=VocabParallelLMHead1D, kwargs=dict( - gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by ), ), policy=policy, target_key=OPTForCausalLM, ) + if self.shard_config.parallel_output: + method_replacement = { + "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) + } + self.append_or_create_method_replacement( + description=method_replacement, + policy=policy, + target_key=OPTForCausalLM + ) else: self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( From ca56b93d8352cc493722626b9a44a8ad3d9f2b18 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 May 2024 07:07:07 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/opt.py | 10 ++++++---- colossalai/shardformer/policies/opt.py | 14 +++++--------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 227042480bc5..1cde61914f36 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -21,7 +21,9 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.shard import ShardConfig + from ..layer import cross_entropy_1d + logger = logging.get_logger(__name__) @@ -351,7 +353,7 @@ def opt_for_causal_lm_forward( loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) - + if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -987,8 +989,8 @@ def forward( process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, ) - #loss_fct = CrossEntropyLoss() - #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + # loss_fct = CrossEntropyLoss() + # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] @@ -1002,4 +1004,4 @@ def forward( attentions=outputs.attentions, ) - return forward \ No newline at end of file + return forward diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index bb094d25aa9f..524d2b8cd0c3 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -21,9 +21,9 @@ from ..modeling.opt import ( OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, + get_lm_forward_with_dist_cross_entropy, get_opt_decoder_forward_for_flash_attention, get_opt_flash_attention_forward, - get_lm_forward_with_dist_cross_entropy ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -270,21 +270,17 @@ def module_policy(self): suffix="lm_head", target_module=VocabParallelLMHead1D, kwargs=dict( - gather_output=not self.shard_config.parallel_output, - make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by + gather_output=not self.shard_config.parallel_output, + make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by, ), ), policy=policy, target_key=OPTForCausalLM, ) if self.shard_config.parallel_output: - method_replacement = { - "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) - } + method_replacement = {"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} self.append_or_create_method_replacement( - description=method_replacement, - policy=policy, - target_key=OPTForCausalLM + description=method_replacement, policy=policy, target_key=OPTForCausalLM ) else: self.append_or_create_submodule_replacement( From a8408b4d314cc4d13ee3fbbc125c166a38518d78 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 7 May 2024 07:08:00 +0000 Subject: [PATCH 6/7] remove comment code --- colossalai/shardformer/modeling/opt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 1cde61914f36..5282e2eaac22 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -989,8 +989,6 @@ def forward( process_group=shard_config.tensor_parallel_process_group, vocab_size=self.lm_head.out_features, ) - # loss_fct = CrossEntropyLoss() - # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) if not return_dict: output = (logits,) + outputs[1:] From 4e50cce26bc5d7aa6c14419c2394bcbc9cc863bf Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 7 May 2024 09:17:56 +0000 Subject: [PATCH 7/7] fix the mistral model --- colossalai/shardformer/modeling/mistral.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 796aeca51a57..93da71abb4a2 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -683,12 +683,7 @@ def forward( ) hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states) logits = logits.float() loss = None