Add GPTJForQuestionAnswering (#14503)
* Add GPTJForQuestionAnswering

* Reformat for GPTJForQuestionAnswering

* Fix isort error

* make style for GPTJForQA

* Add _keys_to_ignore_on_load_missing

* Change the sequence of qa and classification

Co-authored-by: Suraj Patil <surajp815@gmail.com>
tucan9389 and patil-suraj committed Dec 6, 2021
1 parent 1ccc033 commit 0f3f045
Showing 7 changed files with 141 additions and 2 deletions.
7 changes: 7 additions & 0 deletions docs/source/model_doc/gptj.rst
@@ -121,6 +121,13 @@ GPTJForSequenceClassification
:members: forward


GPTJForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForQuestionAnswering
:members: forward


FlaxGPTJModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -951,6 +951,7 @@
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForQuestionAnswering",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
@@ -2833,6 +2834,7 @@
from .models.gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,
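Both touched spots in the top-level `__init__.py` are needed: the first adds the name to the lazy `_import_structure` registry, the second to the direct import block that type checkers see. Together they make the plain public import work. A quick check, assuming torch is installed:

# After this change the head is part of the public API surface.
from transformers import GPTJForQuestionAnswering

# Resolved lazily from the model subpackage on first attribute access.
print(GPTJForQuestionAnswering.__module__)  # expected: transformers.models.gptj.modeling_gptj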
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -385,6 +385,7 @@
# Model for Question Answering mapping
("qdqbert", "QDQBertForQuestionAnswering"),
("fnet", "FNetForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
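Because the ("gptj", "GPTJForQuestionAnswering") entry lands in the question-answering mapping, the auto classes can now dispatch to the new head. A small sketch using a deliberately tiny GPTJConfig so nothing needs to be downloaded; the hyperparameter values below are illustrative, not a real checkpoint:

from transformers import AutoModelForQuestionAnswering, GPTJConfig

# Tiny illustrative config; the real GPT-J 6B uses much larger values.
config = GPTJConfig(vocab_size=1000, n_positions=128, n_embd=64, n_layer=2, n_head=4, rotary_dim=16)

# The auto class looks up "gptj" in the QA mapping and returns a GPTJForQuestionAnswering instance.
model = AutoModelForQuestionAnswering.from_config(config)
print(type(model).__name__)  # GPTJForQuestionAnswering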
2 changes: 2 additions & 0 deletions src/transformers/models/gptj/__init__.py
@@ -28,6 +28,7 @@
_import_structure["modeling_gptj"] = [
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForQuestionAnswering",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
@@ -48,6 +49,7 @@
from .modeling_gptj import (
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTJForCausalLM,
GPTJForQuestionAnswering,
GPTJForSequenceClassification,
GPTJModel,
GPTJPreTrainedModel,
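The model subpackage follows the library's lazy-import pattern: class names are registered in `_import_structure` for lazy loading at runtime and imported directly only for type checkers. A simplified stand-in for that mechanism (not the literal file contents; transformers uses its own `_LazyModule` helper), written with a PEP 562 module-level `__getattr__`:

import importlib

_import_structure = {
    "modeling_gptj": ["GPTJForCausalLM", "GPTJForQuestionAnswering", "GPTJForSequenceClassification", "GPTJModel"],
}

# Map each exported name back to the submodule that defines it.
_class_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):
    # Import the heavy torch-dependent module only when one of its classes is first requested.
    if name in _class_to_module:
        module = importlib.import_module("." + _class_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")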
112 changes: 111 additions & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -23,7 +23,12 @@

from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
@@ -967,3 +972,108 @@ def forward(
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The GPT-J Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPTJ_START_DOCSTRING,
)
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTJModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the label tensors may carry an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
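Taken together, the new head mirrors the other `*ForQuestionAnswering` models: the transformer's hidden states go through a single `nn.Linear(hidden_size, 2)`, the two output channels become start and end logits, and when labels are supplied the loss is the mean of two cross-entropies with out-of-range positions clamped to an ignored index. A hedged usage sketch: only the tokenizer is pulled from the public `EleutherAI/gpt-j-6B` repo, the tiny config sizes are illustrative, and the randomly initialized QA head makes the decoded span meaningless until the model is fine-tuned on SQuAD-style data:

import torch
from transformers import AutoTokenizer, GPTJConfig, GPTJForQuestionAnswering

# Only the tokenizer comes from the hub; the model below is a tiny random-weight stand-in.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
config = GPTJConfig(vocab_size=tokenizer.vocab_size, n_positions=256, n_embd=64, n_layer=2, n_head=4, rotary_dim=16)
model = GPTJForQuestionAnswering(config)

question, context = "Who released GPT-J?", "GPT-J was released by EleutherAI in 2021."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Pick the most likely start/end token indices and decode the span between them.
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])

# Supplying start_positions/end_positions additionally returns the averaged cross-entropy loss.
labeled = model(**inputs, start_positions=torch.tensor([start]), end_positions=torch.tensor([end]))
print(answer, labeled.loss.item())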
12 changes: 12 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -2494,6 +2494,18 @@ def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTJForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTJForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
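The dummy object keeps `from transformers import GPTJForQuestionAnswering` importable in a torch-free environment while failing loudly on first use. A rough sketch of what a `requires_backends`-style guard does; this is a simplified stand-in, not the library's actual helper, whose error message is more detailed:

import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in: raise if any required backend (e.g. "torch") cannot be found.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends, which were not found: {missing}")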
7 changes: 6 additions & 1 deletion tests/test_modeling_gptj.py
@@ -32,6 +32,7 @@
        GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
        AutoTokenizer,
        GPTJForCausalLM,
        GPTJForQuestionAnswering,
        GPTJForSequenceClassification,
        GPTJModel,
    )
@@ -356,7 +357,11 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else ()
    all_model_classes = (
        (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification, GPTJForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
    test_pruning = False
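Adding the class to `all_model_classes` pulls it into the shared `ModelTesterMixin` checks. For intuition, here is a standalone sketch of the kind of shape check those common tests perform; it is an illustrative test written for this commit description, not a method from the repository's tester:

import unittest

import torch

from transformers import GPTJConfig, GPTJForQuestionAnswering


class GPTJForQuestionAnsweringSmokeTest(unittest.TestCase):
    def test_start_end_logits_shape(self):
        # Tiny illustrative config so the check runs in seconds on CPU.
        config = GPTJConfig(vocab_size=99, n_positions=64, n_embd=32, n_layer=2, n_head=4, rotary_dim=8)
        model = GPTJForQuestionAnswering(config).eval()

        batch_size, seq_len = 2, 7
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

        with torch.no_grad():
            result = model(input_ids)

        # One start logit and one end logit per token.
        self.assertEqual(result.start_logits.shape, (batch_size, seq_len))
        self.assertEqual(result.end_logits.shape, (batch_size, seq_len))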
