Add TFVisionEncoderDecoderModel (#14148)
* Start the work on TFVisionEncoderDecoderModel

* Expose TFVisionEncoderDecoderModel

* fix import

* Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_model_modules()

* reorder

* Apply the fix for checkpoint loading as in #14016

* remove attention_mask + fix VISION_DUMMY_INPUTS

* A minimal change to make TF generate() work for vision models as encoder in encoder-decoder setting

* fix wrong condition: shape_list(input_ids) == 2

* add tests

* use personal TFViTModel checkpoint (for now)

* Add equivalence tests + projection layer

* style

* make sure projection layer can run

* Add examples

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Clean comments (need to work on TODOs for PyTorch models)

* Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoderModel

* fixes

* Revert changes in PT code.

* Update tests/test_modeling_tf_vision_encoder_decoder.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add test_inference_coco_en for TF test

* fix quality

* fix name

* build doc

* add main_input_name

* Fix ckpt name in test

* fix diff between master and this PR

* fix doc

* fix style and quality

* fix missing doc

* fix labels handling

* Delete auto.rst

* Add the changes done in #14016

* fix prefix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
4 people authored Jan 10, 2022
1 parent 37bc0b4 commit b67fd79
Showing 14 changed files with 1,654 additions and 26 deletions.
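For orientation before the per-file diffs, here is a minimal usage sketch of the class this PR introduces. The checkpoint names, the dummy image tensor, and the caption text are illustrative assumptions, not artifacts of this commit; the tests below use a personal TF ViT checkpoint precisely because official TF ViT weights were not yet published.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFVisionEncoderDecoderModel

# Combine a pretrained vision encoder with a pretrained text decoder.
# NOTE: assumes TF weights are available for both checkpoints.
model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Forward pass with labels (cf. the "fix labels handling" commit above).
pixel_values = tf.random.uniform((1, 3, 224, 224))  # dummy image batch
ids = tokenizer("a photo of a cat", return_tensors="tf").input_ids
outputs = model(pixel_values=pixel_values, decoder_input_ids=ids, labels=ids)
print(outputs.loss, outputs.logits.shape)
```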
2 changes: 1 addition & 1 deletion docs/source/index.mdx
@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | | | | | |
| UniSpeech | | | | | |
| UniSpeechSat | | | | | |
-| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | | | | | |
| VisualBert | | | | | |
| ViT | | | | | |
4 changes: 4 additions & 0 deletions docs/source/model_doc/auto.mdx
@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its

[[autodoc]] TFAutoModelForQuestionAnswering

## TFAutoModelForVision2Seq

[[autodoc]] TFAutoModelForVision2Seq

## FlaxAutoModel

[[autodoc]] FlaxAutoModel
6 changes: 6 additions & 0 deletions docs/source/model_doc/vision-encoder-decoder.mdx
@@ -33,6 +33,12 @@ An example of how to use a [`VisionEncoderDecoderModel`] for inference can be se
- forward
- from_encoder_decoder_pretrained

## TFVisionEncoderDecoderModel

[[autodoc]] TFVisionEncoderDecoderModel
- call
- from_encoder_decoder_pretrained

## FlaxVisionEncoderDecoderModel

[[autodoc]] FlaxVisionEncoderDecoderModel
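An inference sketch to accompany the doc entry above, mirroring the `test_inference_coco_en` test added in this PR. The checkpoint name is an assumption based on the co-author's personal ViT+GPT-2 COCO checkpoint mentioned in the commit log, not an official release:

```python
import requests
from PIL import Image
from transformers import AutoTokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

ckpt = "ydshieh/vit-gpt2-coco-en"  # assumed checkpoint name
feature_extractor = ViTFeatureExtractor.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = TFVisionEncoderDecoderModel.from_pretrained(ckpt)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values

# This PR teaches TF generate() to route a 4D vision input to the encoder.
output_ids = model.generate(pixel_values, max_length=16)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```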
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -1487,6 +1487,7 @@
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -1500,6 +1501,7 @@
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
)
@@ -1838,6 +1840,7 @@
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
@@ -3354,6 +3357,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -3367,6 +3371,7 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -3636,6 +3641,7 @@
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
33 changes: 23 additions & 10 deletions src/transformers/generation_tf_utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -628,14 +629,18 @@ def generate(
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

# This block corresponds to the following line in `generation_utils`:
# "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
# with the following differences:
#   1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but this is not the case in TF.
#   2. There is no shape checking in PT.
# In both PT and TF, if `input_ids` is `None`, we try to create it, as is done for a text model.
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

# do not allow duplicate outputs when greedy decoding
if do_sample is False:
@@ -691,21 +696,29 @@
# get encoder and store encoder outputs
encoder = self.get_encoder()

- encoder_outputs = encoder(
-     input_ids,
-     attention_mask=attention_mask,
-     output_attentions=output_attentions,
-     output_hidden_states=output_hidden_states,
-     return_dict=return_dict_in_generate,
- )
+ encoder_kwargs = {
+     "attention_mask": attention_mask,
+     "output_attentions": output_attentions,
+     "output_hidden_states": output_hidden_states,
+     "return_dict": return_dict_in_generate,
+ }
+
+ # vision models don't use `attention_mask`.
+ signature = dict(inspect.signature(encoder.call).parameters)
+ if "attention_mask" not in signature:
+     encoder_kwargs.pop("attention_mask")
+
+ encoder_outputs = encoder(input_ids, **encoder_kwargs)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

+ # The condition `len(shape_list(input_ids)) == 2` ensures that this block only processes text inputs
+ # (vision inputs can reach this point when the model is an encoder-decoder model).
# Expand input ids if num_beams > 1 or num_return_sequences > 1
- if num_return_sequences > 1 or num_beams > 1:
+ if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
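The `inspect.signature` check added above is a general trick worth isolating: build the kwargs once, then drop any entry the callee does not declare. A standalone sketch (the function names are illustrative, not library API):

```python
import inspect

def filter_kwargs(fn, kwargs):
    """Keep only the entries of `kwargs` named in `fn`'s signature."""
    accepted = set(inspect.signature(fn).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}

def text_encoder(input_ids, attention_mask=None, output_attentions=False):
    ...

def vision_encoder(pixel_values, output_attentions=False):  # no attention_mask
    ...

common = {"attention_mask": None, "output_attentions": True}
print(filter_kwargs(text_encoder, common))    # keeps both keys
print(filter_kwargs(vision_encoder, common))  # {'output_attentions': True}
```

This avoids per-model special cases in `generate()`: an encoder whose `call` lacks `attention_mask` simply never receives it.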
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -87,6 +87,7 @@
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -100,6 +101,7 @@
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]

@@ -197,6 +199,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -210,6 +213,7 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)

15 changes: 14 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
@@ -156,6 +156,12 @@
]
)

TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)

TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
@@ -182,7 +188,6 @@
]
)


TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
@@ -327,6 +332,7 @@
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
@@ -387,6 +393,13 @@ class TFAutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")


class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING

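With the mapping registered, the new auto class resolves a `vision-encoder-decoder` config to `TFVisionEncoderDecoderModel`. A sketch, assuming a hypothetical saved checkpoint path:

```python
from transformers import TFAutoModelForVision2Seq

# "path/to/checkpoint" is a placeholder for any saved vision-encoder-decoder model.
model = TFAutoModelForVision2Seq.from_pretrained("path/to/checkpoint")
print(type(model).__name__)  # TFVisionEncoderDecoderModel
```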
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -148,10 +148,10 @@
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
r"""
- [`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
- the base model classes of the library as encoder and another one as decoder when created with the
- :meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
- :meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
+ [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+ of the base model classes of the library as encoder and another one as decoder when created with the
+ [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+ method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
@@ -233,13 +233,6 @@ def dummy_inputs(self):
# Add `decoder_input_ids` because `self.decoder` requires it.
input_ids = tf.constant(DUMMY_INPUTS)
dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
- # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
- if self.config.add_cross_attention:
-     batch_size, seq_len = input_ids.shape
-     shape = (batch_size, seq_len) + (self.config.hidden_size,)
-     h = tf.random.uniform(shape=shape)
-     dummy["encoder_hidden_states"] = h

return dummy

def get_encoder(self):
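For background on why `dummy_inputs` exists at all in the TF models (the property edited above): Keras builds weights lazily on first call, so models provide a small dummy batch to materialize their variables. A minimal sketch of the pattern, illustrative rather than the library code:

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    @property
    def dummy_inputs(self):
        # Tiny batch used only to build the weights.
        return {"inputs": tf.constant([[1.0, 2.0, 3.0]])}

    def call(self, inputs):
        return self.dense(inputs)

model = TinyModel()
model(model.dummy_inputs["inputs"])  # first call builds the weights
print([w.shape for w in model.weights])
```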
8 changes: 7 additions & 1 deletion src/transformers/models/vision_encoder_decoder/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

- from ...file_utils import _LazyModule, is_flax_available, is_torch_available
+ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {
@@ -28,6 +28,9 @@
if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

if is_tf_available():
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]

if is_flax_available():
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]

Expand All @@ -37,6 +40,9 @@
if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

if is_tf_available():
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel

if is_flax_available():
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel

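The gated `_import_structure` entry above means the TF class is exposed only when TensorFlow is installed. A minimal sketch of the end-user effect, using the `is_tf_available` helper from `file_utils`:

```python
from transformers.file_utils import is_tf_available

if is_tf_available():
    from transformers import TFVisionEncoderDecoderModel
else:
    # Without TF, transformers exposes only a placeholder that errors when used.
    TFVisionEncoderDecoderModel = None
```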