Skip to content

Commit

Permalink
added GPTNeoXForTokenClassification (huggingface#23002)
Browse files Browse the repository at this point in the history
* initial commit

* added GPTNeoXForTokenClassification

* typo

* doc
fixed extra comma that turned into a tuple

* unifying variable names
fixing forward call

* classifier_dropout is in config

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored and novice03 committed Jun 23, 2023
1 parent 51e3556 commit 4b08c9c
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 5 deletions.
7 changes: 6 additions & 1 deletion docs/source/en/model_doc/gpt_neox.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ The `generate()` method can be used to generate text using GPT Neo model.
## GPTNeoXForSequenceClassification
[[autodoc]] GPTNeoXForSequenceClassification
- forward
- forward
## GPTNeoXForTokenClassification
[[autodoc]] GPTNeoXForTokenClassification
- forward
2 changes: 1 addition & 1 deletion docs/source/en/tasks/token_classification.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT NeoX](../model_doc/gpt_neox), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MRA](../model_doc/mra), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

<!--End of the generated tip-->

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,7 @@
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXForTokenClassification",
"GPTNeoXLayer",
"GPTNeoXModel",
"GPTNeoXPreTrainedModel",
Expand Down Expand Up @@ -5245,6 +5246,7 @@
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
("gpt_neox", "GPTNeoXForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/gpt_neox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXForTokenClassification",
"GPTNeoXLayer",
"GPTNeoXModel",
"GPTNeoXPreTrainedModel",
Expand Down Expand Up @@ -64,6 +65,7 @@
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class GPTNeoXConfig(PretrainedConfig):
percentage of hidden dimensions to allocate to rotary embeddings
rotary_emb_base (`int`, *optional*, defaults to 10000)
base for computing rotary embeddings frequency
classifier_dropout (`float`, *optional*, defaults to 0.1):
Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
The dropout ratio for the hidden layer.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
Expand Down Expand Up @@ -95,6 +99,7 @@ def __init__(
hidden_act="gelu",
rotary_pct=0.25,
rotary_emb_base=10000,
classifier_dropout=0.1,
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
Expand All @@ -115,6 +120,7 @@ def __init__(
self.hidden_act = hidden_act
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.classifier_dropout = classifier_dropout
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
Expand Down
84 changes: 83 additions & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_gpt_neox import GPTNeoXConfig
Expand Down Expand Up @@ -873,3 +878,80 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


class GPTNeoXForTokenClassification(GPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.gpt_neox = GPTNeoXModel(config)
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint="LarsJonasson/pythia-410m-deduped-sft-swedish",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_loss=0.25,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.gpt_neox(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)

loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3308,6 +3308,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoXForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoXLayer(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
24 changes: 22 additions & 2 deletions tests/models/gpt_neox/test_modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
if is_torch_available():
import torch

from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
from transformers import (
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXForTokenClassification,
GPTNeoXModel,
)


class GPTNeoXModelTester:
Expand Down Expand Up @@ -153,6 +158,14 @@ def create_and_check_for_sequence_classification(self, config, input_ids, input_
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_for_token_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForTokenClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = GPTNeoXForCausalLM(config=config)
Expand Down Expand Up @@ -200,13 +213,16 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GPTNeoXModel,
"text-classification": GPTNeoXForSequenceClassification,
"token-classification": GPTNeoXForTokenClassification,
"text-generation": GPTNeoXForCausalLM,
"zero-shot": GPTNeoXForSequenceClassification,
}
Expand Down Expand Up @@ -253,6 +269,10 @@ def test_model_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

def test_model_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
Expand Down

0 comments on commit 4b08c9c

Please sign in to comment.