diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 389ead8dec..9b8dfae586 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -155,6 +155,15 @@ CLIPVisionModelWithProjection, ) from .models.cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel +from .models.convbert import ( + ConvBertForMaskedLM, + ConvBertForMultipleChoice, + ConvBertForQuestionAnswering, + ConvBertForSequenceClassification, + ConvBertForTokenClassification, + ConvBertLayer, + ConvBertModel, +) from .models.deberta import ( DebertaForMaskedLM, DebertaForQuestionAnswering, diff --git a/mindone/transformers/generation/utils.py b/mindone/transformers/generation/utils.py index f2019addf9..f53134527d 100644 --- a/mindone/transformers/generation/utils.py +++ b/mindone/transformers/generation/utils.py @@ -1351,7 +1351,7 @@ def compute_transition_scores( ```python >>> from transformers import GPT2Tokenizer - >>> from mindway.transformers import AutoModelForCausalLM + >>> from mindone.transformers import AutoModelForCausalLM >>> import numpy as np >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 47000df1b1..ddbf9a7ccf 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -31,6 +31,7 @@ camembert, clap, clip, + convbert, dpt, fuyu, gemma, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index bfd857e585..fd1bed30d5 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -61,6 +61,7 @@ ("helium", "HeliumConfig"), ("hiera", "HieraConfig"), ("camembert", "CamembertConfig"), + ("convbert", "ConvBertConfig"), ("idefics", "IdeficsConfig"), ("idefics2", "Idefics2Config"), ("idefics3", "Idefics3Config"), @@ -184,6 +185,7 @@ ("umt5", "UMT5"), ("wav2vec2", "Wav2Vec2"), ("whisper", "Whisper"), + ("convbert", "ConvBERT"), ("xlm-roberta", "XLM-RoBERTa"), ("xlm-roberta-xl", "XLM-RoBERTa-XL"), ("cohere2", "Cohere2"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index 9a57da0357..c1ed353310 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -39,6 +39,7 @@ ("bart", "BartModel"), ("camembert", "CamembertModel"), ("mvp", "MvpModel"), + ("convbert", "ConvBertModel"), ("bit", "BitModel"), ("blip", "BlipModel"), ("blip-2", "Blip2Model"), @@ -143,6 +144,7 @@ ("bert", "BertForMaskedLM"), ("deberta", "DebertaForMaskedLM"), ("deberta-v2", "DebertaV2ForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), ("gpt2", "GPT2LMHeadModel"), ("led", "LEDForConditionalGeneration"), ("camembert", "CamembertForMaskedLM"), @@ -285,6 +287,7 @@ ("mvp", "MvpForConditionalGeneration"), ("albert", "AlbertForMaskedLM"), ("bart", "BartForConditionalGeneration"), + ("convbert", "ConvBertForMaskedLM"), ("bert", "BertForMaskedLM"), ("roberta", "RobertaForMaskedLM"), ("camembert", "CamembertForMaskedLM"), @@ -371,6 +374,7 @@ ("llama", "LlamaForSequenceClassification"), ("persimmon", "PersimmonForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), + ("convbert", "ConvBertForSequenceClassification"), ("mt5", "MT5ForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"), ("mistral", 
"MistralForSequenceClassification"), @@ -398,6 +402,7 @@ ("deberta", "DebertaForQuestionAnswering"), ("deberta-v2", "DebertaV2ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), + ("convbert", "ConvBertForQuestionAnswering"), ("llama", "LlamaForQuestionAnswering"), ("mobilebert", "MobileBertForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"), @@ -446,6 +451,7 @@ ("qwen2", "Qwen2ForTokenClassification"), ("roberta", "RobertaForTokenClassification"), ("rembert", "RemBertForTokenClassification"), + ("convbert", "ConvBertForTokenClassification"), ("t5", "T5ForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -458,6 +464,7 @@ # Model for Multiple Choice mapping ("camembert", "CamembertForMultipleChoice"), ("albert", "AlbertForMultipleChoice"), + ("convbert", "ConvBertForMultipleChoice"), ("bert", "BertForMultipleChoice"), ("roberta", "RobertaForMultipleChoice"), ("deberta-v2", "DebertaV2ForMultipleChoice"), diff --git a/mindone/transformers/models/convbert/__init__.py b/mindone/transformers/models/convbert/__init__.py new file mode 100644 index 0000000000..2c99563208 --- /dev/null +++ b/mindone/transformers/models/convbert/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_convbert import ( + ConvBertForMaskedLM, + ConvBertForMultipleChoice, + ConvBertForQuestionAnswering, + ConvBertForSequenceClassification, + ConvBertForTokenClassification, + ConvBertLayer, + ConvBertModel, +) diff --git a/mindone/transformers/models/convbert/modeling_convbert.py b/mindone/transformers/models/convbert/modeling_convbert.py new file mode 100644 index 0000000000..19fa39f041 --- /dev/null +++ b/mindone/transformers/models/convbert/modeling_convbert.py @@ -0,0 +1,1354 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MindSpore ConvBERT model.""" + +import math +import os +from operator import attrgetter +from typing import Optional, Tuple, Union + +from transformers.models.convbert.configuration_convbert import ConvBertConfig +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) + +import mindspore as ms +from mindspore import mint, nn, ops + +from ...activations import ACT2FN, get_activation +from ...mindspore_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel, SequenceSummary + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "YituTech/conv-bert-base" +_CONFIG_FOR_DOC = "ConvBertConfig" + + +def load_tf_weights_in_convbert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a MindSpore model.""" + try: + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in MindSpore, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_data = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + tf_data[name] = array + + param_mapping = { + "embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings", + "embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings", + "embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings", + "embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma", + "embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta", + "embeddings_project.weight": "electra/embeddings_project/kernel", + "embeddings_project.bias": "electra/embeddings_project/bias", + } + if config.num_groups > 1: + group_dense_name = "g_dense" + else: + group_dense_name = "dense" + + for j in range(config.num_hidden_layers): + param_mapping[ + f"encoder.layer.{j}.attention.self.query.weight" + ] = f"electra/encoder/layer_{j}/attention/self/query/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.query.bias" + ] = f"electra/encoder/layer_{j}/attention/self/query/bias" + param_mapping[ + f"encoder.layer.{j}.attention.self.key.weight" + ] = f"electra/encoder/layer_{j}/attention/self/key/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.key.bias" + ] = f"electra/encoder/layer_{j}/attention/self/key/bias" + param_mapping[ + f"encoder.layer.{j}.attention.self.value.weight" + ] = f"electra/encoder/layer_{j}/attention/self/value/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.value.bias" + ] = f"electra/encoder/layer_{j}/attention/self/value/bias" + param_mapping[ + f"encoder.layer.{j}.attention.self.key_conv_attn_layer.depthwise.weight" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_key/depthwise_kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.key_conv_attn_layer.pointwise.weight" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_key/pointwise_kernel" + param_mapping[ + 
f"encoder.layer.{j}.attention.self.key_conv_attn_layer.bias" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_key/bias" + param_mapping[ + f"encoder.layer.{j}.attention.self.conv_kernel_layer.weight" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.conv_kernel_layer.bias" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/bias" + param_mapping[ + f"encoder.layer.{j}.attention.self.conv_out_layer.weight" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_point/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.self.conv_out_layer.bias" + ] = f"electra/encoder/layer_{j}/attention/self/conv_attn_point/bias" + param_mapping[ + f"encoder.layer.{j}.attention.output.dense.weight" + ] = f"electra/encoder/layer_{j}/attention/output/dense/kernel" + param_mapping[ + f"encoder.layer.{j}.attention.output.LayerNorm.weight" + ] = f"electra/encoder/layer_{j}/attention/output/LayerNorm/gamma" + param_mapping[ + f"encoder.layer.{j}.attention.output.dense.bias" + ] = f"electra/encoder/layer_{j}/attention/output/dense/bias" + param_mapping[ + f"encoder.layer.{j}.attention.output.LayerNorm.bias" + ] = f"electra/encoder/layer_{j}/attention/output/LayerNorm/beta" + param_mapping[ + f"encoder.layer.{j}.intermediate.dense.weight" + ] = f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/kernel" + param_mapping[ + f"encoder.layer.{j}.intermediate.dense.bias" + ] = f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/bias" + param_mapping[ + f"encoder.layer.{j}.output.dense.weight" + ] = f"electra/encoder/layer_{j}/output/{group_dense_name}/kernel" + param_mapping[ + f"encoder.layer.{j}.output.dense.bias" + ] = f"electra/encoder/layer_{j}/output/{group_dense_name}/bias" + param_mapping[ + f"encoder.layer.{j}.output.LayerNorm.weight" + ] = f"electra/encoder/layer_{j}/output/LayerNorm/gamma" + param_mapping[f"encoder.layer.{j}.output.LayerNorm.bias"] = f"electra/encoder/layer_{j}/output/LayerNorm/beta" + + for param in model.named_parameters(): + param_name = param[0] + retriever = attrgetter(param_name) + result = retriever(model) + tf_name = param_mapping[param_name] + value = ms.from_numpy(tf_data[tf_name]) + logger.info(f"TF: {tf_name}, PT: {param_name} ") + if tf_name.endswith("/kernel"): + if not tf_name.endswith("/intermediate/g_dense/kernel"): + if not tf_name.endswith("/output/g_dense/kernel"): + value = value.T + if tf_name.endswith("/depthwise_kernel"): + value = value.permute(1, 2, 0) # 2, 0, 1 + if tf_name.endswith("/pointwise_kernel"): + value = value.permute(2, 1, 0) # 2, 1, 0 + if tf_name.endswith("/conv_attn_key/bias"): + value = value.unsqueeze(-1) + result.data = value + return model + + +class ConvBertEmbeddings(nn.Cell): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = mint.nn.Embedding( + config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = mint.nn.Embedding(config.max_position_embeddings, config.embedding_size) + self.token_type_embeddings = mint.nn.Embedding(config.type_vocab_size, config.embedding_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = mint.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + # position_ids 
(1, len position emb) is contiguous in memory and exported when serialized
+        self.position_ids = mint.arange(config.max_position_embeddings).expand((1, -1))
+        self.token_type_ids = mint.zeros(self.position_ids.shape, dtype=ms.int64)
+        # self.register_buffer(
+        #     "position_ids", mint.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        # )
+        # self.register_buffer(
+        #     "token_type_ids", mint.zeros(self.position_ids.shape, dtype=ms.int64), persistent=False
+        # )
+
+    def construct(
+        self,
+        input_ids: Optional[ms.Tensor] = None,
+        token_type_ids: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        inputs_embeds: Optional[ms.Tensor] = None,
+    ) -> ms.Tensor:
+        if input_ids is not None:
+            input_shape = input_ids.shape
+        else:
+            input_shape = inputs_embeds.shape[:-1]
+
+        seq_length = input_shape[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        # Default token_type_ids to the all-zeros buffer set up in the constructor. This usually happens
+        # when it's auto-generated; the buffer lets users trace the model without passing token_type_ids
+        # (solves issue #5664).
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = mint.zeros(input_shape, dtype=ms.int64)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class ConvBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ConvBertConfig
+    load_tf_weights = load_tf_weights_in_convbert
+    base_model_prefix = "convbert"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        # Note: the modules in this file are built from `mint.nn`, so the isinstance checks must
+        # target the `mint.nn` classes (checking the legacy `nn` classes would never match here).
+        if isinstance(module, mint.nn.Linear):
+            # Slightly different from the TF version, which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, mint.nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, mint.nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class SeparableConv1D(nn.Cell):
+    """This class implements separable convolution, i.e.
a depthwise and a pointwise layer""" + + def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs): + super().__init__() + self.depthwise = nn.Conv1d( + input_filters, + input_filters, + kernel_size=kernel_size, + group=input_filters, + padding=kernel_size // 2, + has_bias=False, + pad_mode="pad", + ) + self.pointwise = nn.Conv1d( + input_filters, output_filters, kernel_size=1, has_bias=False, pad_mode="pad", padding=0 + ) + self.bias = ms.Parameter(mint.zeros((output_filters, 1))) + + self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range) + self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + x = self.depthwise(hidden_states) + x = self.pointwise(x) + x += self.bias + return x + + +class ConvBertSelfAttention(nn.Cell): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + new_num_attention_heads = config.num_attention_heads // config.head_ratio + if new_num_attention_heads < 1: + self.head_ratio = config.num_attention_heads + self.num_attention_heads = 1 + else: + self.num_attention_heads = new_num_attention_heads + self.head_ratio = config.head_ratio + + self.conv_kernel_size = config.conv_kernel_size + if config.hidden_size % self.num_attention_heads != 0: + raise ValueError("hidden_size should be divisible by num_attention_heads") + + self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2 + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = mint.nn.Linear(config.hidden_size, self.all_head_size) + self.key = mint.nn.Linear(config.hidden_size, self.all_head_size) + self.value = mint.nn.Linear(config.hidden_size, self.all_head_size) + + self.key_conv_attn_layer = SeparableConv1D( + config, config.hidden_size, self.all_head_size, self.conv_kernel_size + ) + self.conv_kernel_layer = mint.nn.Linear(self.all_head_size, self.num_attention_heads * self.conv_kernel_size) + self.conv_out_layer = mint.nn.Linear(config.hidden_size, self.all_head_size) + + self.unfold = mint.nn.Unfold( + kernel_size=[self.conv_kernel_size, 1], padding=[int((self.conv_kernel_size - 1) / 2), 0] + ) + + self.dropout = mint.nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + mixed_query_layer = self.query(hidden_states) + batch_size = hidden_states.shape[0] + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
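+        # Beyond the standard Q/K/V projections, the code below also evaluates ConvBERT's
+        # span-based dynamic convolution branch: key_conv_attn_layer produces convolution-mixed
+        # keys, conv_kernel_layer predicts a softmax kernel over a window of conv_kernel_size
+        # positions, and conv_out_layer provides the values that kernel aggregates; the branch's
+        # output is concatenated with the attention context at the end of construct().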
+ if encoder_hidden_states is not None: + mixed_key_layer = self.key(encoder_hidden_states) + mixed_value_layer = self.value(encoder_hidden_states) + else: + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2)) + mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + conv_attn_layer = ops.multiply(mixed_key_conv_attn_layer, mixed_query_layer) + + conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer) + conv_kernel_layer = mint.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1]) + conv_kernel_layer = mint.softmax(conv_kernel_layer, dim=1) + + conv_out_layer = self.conv_out_layer(hidden_states) + conv_out_layer = mint.reshape(conv_out_layer, [batch_size, -1, self.all_head_size]) + conv_out_layer = conv_out_layer.transpose(1, 2).contiguous().unsqueeze(-1) + conv_out_layer = mint.nn.functional.unfold( + conv_out_layer, + kernel_size=[self.conv_kernel_size, 1], + dilation=1, + padding=[(self.conv_kernel_size - 1) // 2, 0], + stride=1, + ) + conv_out_layer = conv_out_layer.transpose(1, 2).reshape( + batch_size, -1, self.all_head_size, self.conv_kernel_size + ) + conv_out_layer = mint.reshape(conv_out_layer, [-1, self.attention_head_size, self.conv_kernel_size]) + conv_out_layer = mint.matmul(conv_out_layer, conv_kernel_layer) + conv_out_layer = mint.reshape(conv_out_layer, [-1, self.all_head_size]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = mint.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ConvBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = mint.nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = mint.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + conv_out = mint.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size]) + context_layer = mint.cat([context_layer, conv_out], 2) + + # conv and context + new_context_layer_shape = context_layer.shape[:-2] + (self.num_attention_heads * self.attention_head_size * 2,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + +class ConvBertSelfOutput(nn.Cell): + def __init__(self, config): + super().__init__() + self.dense = mint.nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + + def construct(self, hidden_states: ms.Tensor, input_tensor: ms.Tensor) -> ms.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ConvBertAttention(nn.Cell): + def __init__(self, config): + super().__init__() + self.self = ConvBertSelfAttention(config) + self.output = ConvBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class GroupedLinearLayer(nn.Cell): + def __init__(self, input_size, output_size, num_groups): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.num_groups = num_groups + self.group_in_dim = self.input_size // self.num_groups + self.group_out_dim = self.output_size // self.num_groups + self.weight = ms.Parameter(mint.empty(self.num_groups, self.group_in_dim, self.group_out_dim)) + self.bias = ms.Parameter(mint.empty(output_size)) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + batch_size = list(hidden_states.shape)[0] + x = mint.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]) + x = 
x.permute(1, 0, 2) + x = mint.matmul(x, self.weight) + x = x.permute(1, 0, 2) + x = mint.reshape(x, [batch_size, -1, self.output_size]) + x = x + self.bias + return x + + +class ConvBertIntermediate(nn.Cell): + def __init__(self, config): + super().__init__() + if config.num_groups == 1: + self.dense = mint.nn.Linear(config.hidden_size, config.intermediate_size) + else: + self.dense = GroupedLinearLayer( + input_size=config.hidden_size, output_size=config.intermediate_size, num_groups=config.num_groups + ) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ConvBertOutput(nn.Cell): + def __init__(self, config): + super().__init__() + if config.num_groups == 1: + self.dense = mint.nn.Linear(config.intermediate_size, config.hidden_size) + else: + self.dense = GroupedLinearLayer( + input_size=config.intermediate_size, output_size=config.hidden_size, num_groups=config.num_groups + ) + self.LayerNorm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + + def construct(self, hidden_states: ms.Tensor, input_tensor: ms.Tensor) -> ms.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class ConvBertLayer(nn.Cell): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ConvBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise TypeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ConvBertAttention(config) + self.intermediate = ConvBertIntermediate(config) + self.output = ConvBertOutput(config) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]: + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + cross_attention_outputs = self.crossattention( + attention_output, + encoder_attention_mask, + head_mask, + encoder_hidden_states, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + 
outputs = (layer_output,) + outputs + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class ConvBertEncoder(nn.Cell): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.CellList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class ConvBertPredictionHeadTransform(nn.Cell): + def __init__(self, config): + super().__init__() + self.dense = mint.nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +CONVBERT_START_DOCSTRING = r""" + This model is a MindSpore [nn.Cell](https://www.mindspore.cn/docs/en/r2.5.0/api_python/nn/mindspore.nn.Cell.html) sub-class. Use + it as a regular MindSpore Module and refer to the MindSpore documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ConvBertConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CONVBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`ms.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`ms.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`ms.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`ms.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`ms.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`ms.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ConvBERT Model transformer outputting raw hidden-states without any specific head on top.", + CONVBERT_START_DOCSTRING, +) +class ConvBertModel(ConvBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.embeddings = ConvBertEmbeddings(config) + + if config.embedding_size != config.hidden_size: + self.embeddings_project = mint.nn.Linear(config.embedding_size, config.hidden_size) + + self.encoder = ConvBertEncoder(config) + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if attention_mask is None: + attention_mask = mint.ones(input_shape) + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = mint.zeros(input_shape, dtype=ms.int64) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = self.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + if hasattr(self, "embeddings_project"): + hidden_states = self.embeddings_project(hidden_states) + + hidden_states = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return hidden_states + + +class ConvBertGeneratorPredictions(nn.Cell): + """Prediction module for the generator, made up of two dense layers.""" + + def __init__(self, config): + super().__init__() + + self.activation = get_activation("gelu") + self.LayerNorm = mint.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.dense = mint.nn.Linear(config.hidden_size, config.embedding_size) + + def construct(self, generator_hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.dense(generator_hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + +@add_start_docstrings("""ConvBERT Model with a 
`language modeling` head on top.""", CONVBERT_START_DOCSTRING) +class ConvBertForMaskedLM(ConvBertPreTrainedModel): + _tied_weights_keys = ["generator.lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + self.convbert = ConvBertModel(config) + self.generator_predictions = ConvBertGeneratorPredictions(config) + + self.generator_lm_head = mint.nn.Linear(config.embedding_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.generator_lm_head + + def set_output_embeddings(self, word_embeddings): + self.generator_lm_head = word_embeddings + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + generator_hidden_states = self.convbert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict, + ) + generator_sequence_output = generator_hidden_states[0] + + prediction_scores = self.generator_predictions(generator_sequence_output) + prediction_scores = self.generator_lm_head(prediction_scores) + + loss = None + # Masked language modeling softmax layer + if labels is not None: + loss_fct = mint.nn.CrossEntropyLoss() # -100 index = padding token + loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + generator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=generator_hidden_states.hidden_states, + attentions=generator_hidden_states.attentions, + ) + + +class ConvBertClassificationHead(nn.Cell): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = mint.nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = mint.nn.Dropout(classifier_dropout) + self.out_proj = mint.nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def construct(self, hidden_states: ms.Tensor, **kwargs) -> ms.Tensor: + x = hidden_states[:, 0, :] # take token (equiv. 
to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + ConvBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + CONVBERT_START_DOCSTRING, +) +class ConvBertForSequenceClassification(ConvBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.convbert = ConvBertModel(config) + self.classifier = ConvBertClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.convbert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == ms.int64 or labels.dtype == ms.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = mint.nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = mint.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = mint.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ConvBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + CONVBERT_START_DOCSTRING, +) +class ConvBertForMultipleChoice(ConvBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.convbert = ConvBertModel(config) + self.sequence_summary = SequenceSummary(config) + self.classifier = mint.nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: + r""" + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. 
(See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.shape[-1]) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]) + if inputs_embeds is not None + else None + ) + + outputs = self.convbert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + pooled_output = self.sequence_summary(sequence_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = mint.nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ConvBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + CONVBERT_START_DOCSTRING, +) +class ConvBertForTokenClassification(ConvBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.convbert = ConvBertModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = mint.nn.Dropout(classifier_dropout) + self.classifier = mint.nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.convbert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = mint.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ConvBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + CONVBERT_START_DOCSTRING, +) +class ConvBertForQuestionAnswering(ConvBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.convbert = ConvBertModel(config) + self.qa_outputs = mint.nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + token_type_ids: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + head_mask: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + start_positions: Optional[ms.Tensor] = None, + end_positions: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.convbert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.shape) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.shape) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.shape[1] + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = mint.nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ConvBertForMaskedLM", + "ConvBertForMultipleChoice", + "ConvBertForQuestionAnswering", + "ConvBertForSequenceClassification", + "ConvBertForTokenClassification", + "ConvBertLayer", + "ConvBertModel", + "ConvBertPreTrainedModel", + "load_tf_weights_in_convbert", +] diff --git a/tests/transformers_tests/models/convbert/__init__.py b/tests/transformers_tests/models/convbert/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/convbert/test_modeling_convbert.py b/tests/transformers_tests/models/convbert/test_modeling_convbert.py new file mode 100644 index 0000000000..8ebcfff7bf --- /dev/null +++ b/tests/transformers_tests/models/convbert/test_modeling_convbert.py @@ -0,0 +1,311 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect
+
+import numpy as np
+import pytest
+import torch
+from transformers import ConvBertConfig
+
+import mindspore as ms
+
+from tests.modeling_test_utils import (
+    MS_DTYPE_MAPPING,
+    PT_DTYPE_MAPPING,
+    compute_diffs,
+    generalized_parse_args,
+    get_modules,
+)
+from tests.transformers_tests.models.modeling_common import ids_numpy
+
+# Per-dtype tolerances for the PyTorch vs. MindSpore output diffs
+DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2}
+MODES = [1]  # 1 == ms.PYNATIVE_MODE
+
+
+class ConvBertModelTester:
+    def __init__(
+        self,
+        batch_size=13,
+        seq_length=7,
+        is_training=True,
+        use_input_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+        num_labels=2,
+        num_choices=4,
+        scope=None,
+    ):
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_input_mask = use_input_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+        self.num_labels = num_labels
+        self.num_choices = num_choices
+        self.scope = scope
+
+    def random_attention_mask(self, shape, rng=None, name=None):
+        attn_mask = ids_numpy(shape, vocab_size=2, rng=rng, name=name)
+        # make sure that at least one token is attended to for each batch;
+        # we choose the 1st token so the `at least one non-zero` property still holds after a causal mask is applied
+        attn_mask[:, 0] = 1
+        return attn_mask
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_numpy([self.batch_size, self.seq_length], self.vocab_size)
+
+        input_mask = None
+        if self.use_input_mask:
+            input_mask = self.random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_numpy([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        sequence_labels = None
+        token_labels = None
+        choice_labels = None
+        if self.use_labels:
+            sequence_labels = ids_numpy([self.batch_size], self.type_sequence_label_size)
+            token_labels = ids_numpy([self.batch_size, self.seq_length], self.num_labels)
+            choice_labels = ids_numpy([self.batch_size], self.num_choices)
+
+        config = self.get_config()
+
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+    def get_config(self):
+        return ConvBertConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+
+model_tester = ConvBertModelTester()
+(
+    config,
+    input_ids,
+    token_type_ids,
+    input_mask,
+    sequence_labels,
+    token_labels,
+    choice_labels,
+) = model_tester.prepare_config_and_inputs()
+
+
+# Each case: (name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map)
+CONVBERT_CASES = [
+    [
+        "ConvBertForMaskedLM",
+        "transformers.ConvBertForMaskedLM",
+        "mindone.transformers.ConvBertForMaskedLM",
+        (config,),
+        {},
+        (input_ids,),
+        {
+            "attention_mask": input_mask,
+            "token_type_ids": token_type_ids,
+            "labels": token_labels,
+        },
+        {
+            "logits": 0,
+        },
+    ],
+    [
+        "ConvBertForMultipleChoice",
+        "transformers.ConvBertForMultipleChoice",
+        "mindone.transformers.ConvBertForMultipleChoice",
+        (config,),
+        {},
+        (np.repeat(np.expand_dims(input_ids, 1), model_tester.num_choices, 1),),
+        {
+            "attention_mask": np.repeat(np.expand_dims(input_mask, 1), model_tester.num_choices, 1),
+            "token_type_ids": np.repeat(np.expand_dims(token_type_ids, 1), model_tester.num_choices, 1),
+            "labels": choice_labels,
+        },
+        {
+            "logits": 0,
+        },
+    ],
+    [
+        "ConvBertForQuestionAnswering",
+        "transformers.ConvBertForQuestionAnswering",
+        "mindone.transformers.ConvBertForQuestionAnswering",
+        (config,),
+        {},
+        (input_ids,),
+        {
+            "attention_mask": input_mask,
+            "token_type_ids": token_type_ids,
+            "start_positions": sequence_labels,
+            "end_positions": sequence_labels,
+        },
+        {
+            "start_logits": 0,
+            "end_logits": 1,
+        },
+    ],
+    [
+        "ConvBertForSequenceClassification",
+        "transformers.ConvBertForSequenceClassification",
+        "mindone.transformers.ConvBertForSequenceClassification",
+        (config,),
+        {},
+        (input_ids,),
+        {
+            "attention_mask": input_mask,
+            "token_type_ids": token_type_ids,
+            "labels": sequence_labels,
+        },
+        {
+            "logits": 0,  # key: torch attribute, value: mindspore idx
+        },
+    ],
+    [
+        "ConvBertForTokenClassification",
+        "transformers.ConvBertForTokenClassification",
+        "mindone.transformers.ConvBertForTokenClassification",
+        (config,),
+        {},
+        (input_ids,),
+        {
+            "attention_mask": input_mask,
+            "token_type_ids": token_type_ids,
+            "labels": token_labels,
+        },
+        {
+            "logits": 0,  # key: torch attribute, value: mindspore idx
+        },
+    ],
+    [
+        "ConvBertModel",
+        "transformers.ConvBertModel",
+        "mindone.transformers.ConvBertModel",
+        (config,),
+        {},
+        (input_ids,),
+        {"attention_mask": input_mask, "token_type_ids": token_type_ids},
+        {
+            "last_hidden_state": 0,
+        },
+    ],
+]
+
+
+@pytest.mark.parametrize(
+    "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode",
+    [
+        case + [dtype, mode]
+        for case in CONVBERT_CASES
+        for dtype in DTYPE_AND_THRESHOLDS.keys()
+        for mode in MODES
+    ],
+)
+def test_named_modules(
+    name,
+    pt_module,
+    ms_module,
+    init_args,
+    init_kwargs,
+    inputs_args,
+    inputs_kwargs,
+    outputs_map,
+    dtype,
+    mode,
+):
+    ms.set_context(mode=mode, jit_syntax_level=ms.STRICT)
+
+    (
+        pt_model,
+        ms_model,
+        pt_dtype,
+        ms_dtype,
+    ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs)
+    pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args(
+        pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs
+    )
+
+    if "hidden_dtype" in inspect.signature(pt_model.forward).parameters:
+        pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]})
+        ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]})
+
+    with torch.no_grad():
+        pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)
+    ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)
+
+    if outputs_map:
+        pt_outputs_n = []
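+        # Flatten both output structures into parallel lists for comparison:
+        # PyTorch outputs are fetched by attribute name (pt_key), MindSpore
+        # outputs by tuple position (ms_idx), per the mapping declared in each case.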
+        ms_outputs_n = []
+        for pt_key, ms_idx in outputs_map.items():
+            pt_output = getattr(pt_outputs, pt_key)
+            ms_output = ms_outputs[ms_idx]
+            if isinstance(pt_output, (list, tuple)):
+                pt_outputs_n += list(pt_output)
+                ms_outputs_n += list(ms_output)
+            else:
+                pt_outputs_n.append(pt_output)
+                ms_outputs_n.append(ms_output)
+        diffs = compute_diffs(pt_outputs_n, ms_outputs_n)
+    else:
+        diffs = compute_diffs(pt_outputs, ms_outputs)
+
+    THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype]
+    assert (np.array(diffs) < THRESHOLD).all(), (
+        f"ms_dtype: {ms_dtype}, pt_dtype: {pt_dtype}, "
+        f"outputs ({np.array(diffs).tolist()}) have diffs bigger than {THRESHOLD}"
+    )
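
For reviewers who want to sanity-check the port end to end, here is a minimal smoke-test sketch. It assumes the public `YituTech/conv-bert-base` checkpoint and that `mindone.transformers.ConvBertForMaskedLM` mirrors the usual `from_pretrained`/forward flow of the other models in this repo; the snippet is illustrative and not part of this diff.

```python
# Minimal ConvBERT smoke test (a sketch, not part of the diff).
# Assumptions: the public "YituTech/conv-bert-base" checkpoint is reachable, and
# mindone's ConvBertForMaskedLM follows the HF from_pretrained/forward API.
import numpy as np
from transformers import AutoTokenizer

import mindspore as ms
from mindone.transformers import ConvBertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("YituTech/conv-bert-base")
model = ConvBertForMaskedLM.from_pretrained("YituTech/conv-bert-base")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
outputs = model(
    ms.Tensor(inputs["input_ids"]),
    attention_mask=ms.Tensor(inputs["attention_mask"]),
)
# Depending on config.use_return_dict the model returns a dataclass or a tuple
logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]

# Locate the [MASK] position and decode the top prediction
mask_pos = int(np.nonzero(inputs["input_ids"][0] == tokenizer.mask_token_id)[0][0])
pred_id = int(logits[0, mask_pos].asnumpy().argmax())
print(tokenizer.decode([pred_id]))  # plausibly "paris" for a well-converted checkpoint
```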