Skip to content

Commit

Permalink
Rename ImageGPT (#14526)
Browse files Browse the repository at this point in the history
* Rename

* Add MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
  • Loading branch information
NielsRogge committed Nov 29, 2021
1 parent 4ee0b75 commit 25156eb
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/source/model_doc/imagegpt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ ImageGPTModel
:members: forward


ImageGPTForCausalLM
ImageGPTForCausalImageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ImageGPTForCausalLM
.. autoclass:: transformers.ImageGPTForCausalImageModeling
:members: forward


Expand Down
6 changes: 4 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@
_import_structure["models.auto"].extend(
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
Expand Down Expand Up @@ -977,7 +978,7 @@
_import_structure["models.imagegpt"].extend(
[
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ImageGPTForCausalLM",
"ImageGPTForCausalImageModeling",
"ImageGPTForImageClassification",
"ImageGPTModel",
"ImageGPTPreTrainedModel",
Expand Down Expand Up @@ -2521,6 +2522,7 @@
)
from .models.auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
Expand Down Expand Up @@ -2823,7 +2825,7 @@
)
from .models.imagegpt import (
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
ImageGPTForCausalLM,
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,
ImageGPTModel,
ImageGPTPreTrainedModel,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
if is_torch_available():
_import_structure["modeling_auto"] = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
Expand Down Expand Up @@ -137,6 +138,7 @@
if is_torch_available():
from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("imagegpt", "ImageGPTForCausalLM"),
("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"),
Expand Down Expand Up @@ -199,7 +198,6 @@
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("imagegpt", "ImageGPTForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("gptj", "GPTJForCausalLM"),
Expand Down Expand Up @@ -233,6 +231,13 @@
]
)

# Names table for the causal image modeling auto-mapping: each entry pairs a
# model-type key with the name of the model class that implements the head.
# "imagegpt" is registered here rather than in the LM-head / causal-LM tables.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
# Model for Causal Image Modeling mapping
[
("imagegpt", "ImageGPTForCausalImageModeling"),
]
)

MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
Expand Down Expand Up @@ -524,6 +529,9 @@
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
# Auto-mapping object for causal image modeling, built from the shared config
# names and MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES, following the same
# construction pattern as the neighbouring MODEL_FOR_*_MAPPING objects.
# NOTE(review): presumably resolves config classes to model classes lazily —
# confirm against the _LazyAutoMapping implementation.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/imagegpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
if is_torch_available():
_import_structure["modeling_imagegpt"] = [
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ImageGPTForCausalLM",
"ImageGPTForCausalImageModeling",
"ImageGPTForImageClassification",
"ImageGPTModel",
"ImageGPTPreTrainedModel",
Expand All @@ -48,7 +48,7 @@
if is_torch_available():
from .modeling_imagegpt import (
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
ImageGPTForCausalLM,
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,
ImageGPTModel,
ImageGPTPreTrainedModel,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/imagegpt/modeling_imagegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def custom_forward(*inputs):
""",
IMAGEGPT_START_DOCSTRING,
)
class ImageGPTForCausalLM(ImageGPTPreTrainedModel):
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

def __init__(self, config):
Expand Down Expand Up @@ -958,13 +958,13 @@ def forward(
Examples::
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalLM
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling
>>> import torch
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> feature_extractor = ImageGPTFeatureExtractor.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausalLM.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausalImageModeling.from_pretrained('openai/imagegpt-small')
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model.to(device)
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None


MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = None


MODEL_FOR_CAUSAL_LM_MAPPING = None


Expand Down Expand Up @@ -2661,7 +2664,7 @@ def forward(self, *args, **kwargs):
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST = None


class ImageGPTForCausalLM:
class ImageGPTForCausalImageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

Expand Down
2 changes: 2 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
Expand Down Expand Up @@ -150,6 +151,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
elif model_class in [
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_modeling_imagegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from transformers import (
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
ImageGPTForCausalLM,
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,
ImageGPTModel,
)
Expand Down Expand Up @@ -207,14 +207,14 @@ def create_and_check_imagegpt_model(self, config, pixel_values, input_mask, head
self.parent.assertEqual(len(result.past_key_values), config.n_layer)

def create_and_check_lm_head_model(self, config, pixel_values, input_mask, head_mask, token_type_ids, *args):
model = ImageGPTForCausalLM(config)
model = ImageGPTForCausalImageModeling(config)
model.to(torch_device)
model.eval()

labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1)
result = model(pixel_values, token_type_ids=token_type_ids, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
# ImageGPTForCausalLM doesn't have tied input- and output embeddings
# ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size - 1))

def create_and_check_imagegpt_for_image_classification(
Expand Down Expand Up @@ -255,9 +255,9 @@ def prepare_config_and_inputs_for_common(self):
class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

all_model_classes = (
(ImageGPTForCausalLM, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else ()
(ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else ()
)
all_generative_model_classes = (ImageGPTForCausalLM,) if is_torch_available() else ()
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
test_missing_keys = False
input_name = "pixel_values"

Expand All @@ -273,7 +273,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

return inputs_dict

# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalLM doesn't have tied input- and output embeddings
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size - 1)
self.assertIsInstance(scores, tuple)
Expand Down Expand Up @@ -519,7 +519,7 @@ def default_feature_extractor(self):

@slow
def test_inference_causal_lm_head(self):
model = ImageGPTForCausalLM.from_pretrained("openai/imagegpt-small").to(torch_device)
model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small").to(torch_device)

feature_extractor = self.default_feature_extractor
image = prepare_img()
Expand Down

0 comments on commit 25156eb

Please sign in to comment.