From 0f3f045ebddc526fac7cd7b231ecfcd8342315bc Mon Sep 17 00:00:00 2001
From: tucan9389 <37643248+tucan9389@users.noreply.github.com>
Date: Tue, 7 Dec 2021 01:44:10 +0900
Subject: [PATCH] Add GPTJForQuestionAnswering (#14503)

* Add GPTJForQuestionAnswering

* Reformat for GPTJForQuestionAnswering

* Fix isort error

* make style for GPTJForQA

* Add _keys_to_ignore_on_load_missing

* Change the sequence of qa and classification

Co-authored-by: Suraj Patil
---
 docs/source/model_doc/gptj.rst                |   7 ++
 src/transformers/__init__.py                  |   2 +
 src/transformers/models/auto/modeling_auto.py |   1 +
 src/transformers/models/gptj/__init__.py      |   2 +
 src/transformers/models/gptj/modeling_gptj.py | 112 +++++++++++++++++-
 src/transformers/utils/dummy_pt_objects.py    |  12 ++
 tests/test_modeling_gptj.py                   |   7 +-
 7 files changed, 141 insertions(+), 2 deletions(-)

diff --git a/docs/source/model_doc/gptj.rst b/docs/source/model_doc/gptj.rst
index d02a6e9996fe2..b5760befa044e 100644
--- a/docs/source/model_doc/gptj.rst
+++ b/docs/source/model_doc/gptj.rst
@@ -121,6 +121,13 @@ GPTJForSequenceClassification
     :members: forward


+GPTJForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.GPTJForQuestionAnswering
+    :members: forward
+
+
 FlaxGPTJModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5a6d497879d5f..a3f03bb64ba7a 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -951,6 +951,7 @@
         [
             "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTJForCausalLM",
+            "GPTJForQuestionAnswering",
             "GPTJForSequenceClassification",
             "GPTJModel",
             "GPTJPreTrainedModel",
@@ -2833,6 +2834,7 @@
         from .models.gptj import (
             GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTJForCausalLM,
+            GPTJForQuestionAnswering,
             GPTJForSequenceClassification,
             GPTJModel,
             GPTJPreTrainedModel,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e262f70e64f00..97b1093c14e7f 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -385,6 +385,7 @@
         # Model for Question Answering mapping
         ("qdqbert", "QDQBertForQuestionAnswering"),
         ("fnet", "FNetForQuestionAnswering"),
+        ("gptj", "GPTJForQuestionAnswering"),
         ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
         ("rembert", "RemBertForQuestionAnswering"),
         ("canine", "CanineForQuestionAnswering"),
diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py
index ae0ff5c6df83f..29f1a2baa891a 100644
--- a/src/transformers/models/gptj/__init__.py
+++ b/src/transformers/models/gptj/__init__.py
@@ -28,6 +28,7 @@
     _import_structure["modeling_gptj"] = [
         "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTJForCausalLM",
+        "GPTJForQuestionAnswering",
         "GPTJForSequenceClassification",
         "GPTJModel",
         "GPTJPreTrainedModel",
@@ -48,6 +49,7 @@
         from .modeling_gptj import (
             GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTJForCausalLM,
+            GPTJForQuestionAnswering,
             GPTJForSequenceClassification,
             GPTJModel,
             GPTJPreTrainedModel,
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 60705fbb22cee..a046300636061 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -23,7 +23,12 @@
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from ...utils.model_parallel_utils import assert_device_map, get_device_map
@@ -967,3 +972,108 @@ def forward(
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
+    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    GPTJ_START_DOCSTRING,
+)
+class GPTJForQuestionAnswering(GPTJPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = GPTJModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for the position (index) of the start of the labelled span for computing the token classification
+            loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of
+            the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for the position (index) of the end of the labelled span for computing the token classification
+            loss. Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of
+            the sequence are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index b36aa7bc203b7..685b5b9944584 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -2494,6 +2494,18 @@ def forward(self, *args, **kwargs):
         requires_backends(self, ["torch"])


+class GPTJForQuestionAnswering:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    def forward(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTJForSequenceClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py
index 9bbc9b681f95e..dd743b80d76a8 100644
--- a/tests/test_modeling_gptj.py
+++ b/tests/test_modeling_gptj.py
@@ -32,6 +32,7 @@
         GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
         AutoTokenizer,
         GPTJForCausalLM,
+        GPTJForQuestionAnswering,
         GPTJForSequenceClassification,
         GPTJModel,
     )
@@ -356,7 +357,11 @@ def prepare_config_and_inputs_for_common(self):

 @require_torch
 class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     test_pruning = False
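
The snippet below is a minimal usage sketch of the new head, not part of the patch. It assumes the public EleutherAI/gpt-j-6B checkpoint and a made-up question/context pair; the qa_outputs layer is freshly initialized, so the decoded span is only meaningful after fine-tuning on extractive QA data such as SQuAD. Thanks to the new ("gptj", "GPTJForQuestionAnswering") entry in modeling_auto.py, the same weights can also be loaded through AutoModelForQuestionAnswering.

import torch
from transformers import AutoTokenizer, GPTJForQuestionAnswering

# Any GPT-J checkpoint works; the QA head on top is randomly initialized until fine-tuned.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForQuestionAnswering.from_pretrained("EleutherAI/gpt-j-6B")

question = "Who released GPT-J?"  # made-up example pair
context = "GPT-J is a 6B-parameter language model released by EleutherAI."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: take the highest-scoring start and end token indices.
start_index = outputs.start_logits.argmax(dim=-1).item()
end_index = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start_index : end_index + 1])

# Passing start/end positions returns the averaged start/end cross-entropy loss used for fine-tuning.
outputs = model(
    **inputs,
    start_positions=torch.tensor([start_index]),
    end_positions=torch.tensor([end_index]),
)
loss = outputs.loss

The head mirrors the other *ForQuestionAnswering models in the library: a single nn.Linear(config.hidden_size, config.num_labels) projects each hidden state to a start and an end logit, so config.num_labels is expected to be 2 for extractive QA fine-tuning.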