
VisionTextDualEncoder #13511

Merged

51 commits merged on Nov 30, 2021

Commits (51)
b6c40ca
init vision_text_dual_encoder
patil-suraj Sep 10, 2021
b08ea2f
fix merge
patil-suraj Nov 16, 2021
0949605
remove extra heads
patil-suraj Nov 16, 2021
15f7ca1
fix tests
patil-suraj Nov 16, 2021
af649c2
remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP
patil-suraj Nov 16, 2021
4a65a44
remove archive map
patil-suraj Nov 16, 2021
01a05c1
fix imports
patil-suraj Nov 16, 2021
dafe6e4
fix more imports
patil-suraj Nov 16, 2021
a8bf9a2
fix init
patil-suraj Nov 16, 2021
7ae6b92
delete tokenizers
patil-suraj Nov 16, 2021
26716a8
fix imports
patil-suraj Nov 16, 2021
5a75ca3
clean
patil-suraj Nov 16, 2021
fd06544
support clip's vision model
patil-suraj Nov 16, 2021
271863b
handle None config
patil-suraj Nov 16, 2021
f9d217e
begin tests
patil-suraj Nov 16, 2021
1731a80
more test and few fixes
patil-suraj Nov 18, 2021
e150b5d
warn about newly init weights
patil-suraj Nov 18, 2021
bbb4d47
more tests
patil-suraj Nov 18, 2021
9297052
add loss to model
patil-suraj Nov 18, 2021
e90f5e7
remove extra classes from doc
patil-suraj Nov 18, 2021
d5d4de2
add processor
patil-suraj Nov 18, 2021
fb7d63d
doc and small fixes
patil-suraj Nov 18, 2021
8b95eca
add start docstr
patil-suraj Nov 18, 2021
4fc6be0
update flax model
patil-suraj Nov 18, 2021
3179dbc
flax tests
patil-suraj Nov 18, 2021
7184ad0
more flax tests
patil-suraj Nov 18, 2021
e19b9cf
doc
patil-suraj Nov 18, 2021
38ada68
quality
patil-suraj Nov 18, 2021
9af7274
doc and quality
patil-suraj Nov 18, 2021
be1fb39
fix doc
patil-suraj Nov 18, 2021
4ad23ac
doc
patil-suraj Nov 18, 2021
45f15b7
remove comments
patil-suraj Nov 18, 2021
636c8bf
update warning
patil-suraj Nov 18, 2021
1f7a875
quality
patil-suraj Nov 19, 2021
9a6234b
fix docs
patil-suraj Nov 19, 2021
7c72bb6
Apply suggestions from code review
patil-suraj Nov 19, 2021
38736d1
replace asserts, fix imports
patil-suraj Nov 19, 2021
a14042f
update imports
patil-suraj Nov 19, 2021
58f9514
fix import
patil-suraj Nov 19, 2021
c33911d
address some review comments
patil-suraj Nov 19, 2021
f7dddaa
fix check
patil-suraj Nov 19, 2021
6a3e943
reduce tolerance
patil-suraj Nov 19, 2021
9dba37d
fix test
patil-suraj Nov 19, 2021
8e00565
add flax integration test
patil-suraj Nov 19, 2021
09eca4e
Apply suggestions from code review
patil-suraj Nov 30, 2021
2a6e420
address Sylvain's comments
patil-suraj Nov 30, 2021
ac84a85
fix style
patil-suraj Nov 30, 2021
0ee6603
add pt_flax_equivalence test in PT tests
patil-suraj Nov 30, 2021
7a8eb2f
add pt integration test
patil-suraj Nov 30, 2021
89743d4
update test
patil-suraj Nov 30, 2021
e6dd8ec
use pre-trained checkpoint in examples
patil-suraj Nov 30, 2021
3 changes: 3 additions & 0 deletions docs/source/index.rst
@@ -511,6 +511,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
Contributor:

Let's also update the README, no?

Contributor:

I believe @sgugger was against this.

Collaborator:

Yes, we removed generic classes like the Vision Encoder Decoder, especially since they did not have a research
article coming with them. This one has an article as an example, so you can add it if you really want to.

Contributor (Author):

Since there is no official research implementation or pre-trained checkpoint, I would also prefer not to add it.

Contributor:

OK for me.

+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
@@ -686,6 +688,7 @@ Flax), PyTorch, and/or TensorFlow.
model_doc/unispeech
model_doc/unispeech_sat
model_doc/visionencoderdecoder
model_doc/vision_text_dual_encoder
model_doc/vit
model_doc/visual_bert
model_doc/wav2vec2
56 changes: 56 additions & 0 deletions docs/source/model_doc/vision_text_dual_encoder.rst
@@ -0,0 +1,56 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

VisionTextDualEncoder
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with
any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT
<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT
<bert>`). Two projection layers are added on top of the vision and text encoders to project the output embeddings
to a shared latent space. The projection layers are randomly initialized, so the model should be fine-tuned on a
downstream task. This model can be used to align the vision-text embeddings using CLIP-like contrastive image-text
training and can then be used for zero-shot vision tasks such as image classification or retrieval.

In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields a significant
improvement on new zero-shot vision tasks such as image classification or retrieval.
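A minimal usage sketch, assuming the ``google/vit-base-patch16-224`` and ``bert-base-uncased`` checkpoints (the
projection layers start randomly initialized, so the similarity scores below are only meaningful after fine-tuning)::

    >>> import requests
    >>> from PIL import Image
    >>> from transformers import (
    ...     BertTokenizer,
    ...     ViTFeatureExtractor,
    ...     VisionTextDualEncoderModel,
    ...     VisionTextDualEncoderProcessor,
    ... )

    >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
    >>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)
    >>> model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    ...     "google/vit-base-patch16-224", "bert-base-uncased"
    ... )

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
    ... )
    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # image-text similarity scores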

VisionTextDualEncoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderConfig
:members:


VisionTextDualEncoderProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderProcessor
:members:


VisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderModel
:members: forward


FlaxVisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionTextDualEncoderModel
:members: __call__
7 changes: 7 additions & 0 deletions src/transformers/__init__.py
@@ -297,6 +297,7 @@
"UniSpeechSatConfig",
],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.wav2vec2": [
@@ -1307,6 +1308,7 @@
]
)
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
_import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"])
_import_structure["models.visual_bert"].extend(
[
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1896,6 +1898,7 @@
)

# Flax models structure

_import_structure["models.bart"].extend(
[
"FlaxBartForConditionalGeneration",
@@ -2013,6 +2016,7 @@
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
@@ -2253,6 +2257,7 @@
from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .models.wav2vec2 import (
@@ -3096,6 +3101,7 @@
UniSpeechSatPreTrainedModel,
)
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
from .models.vision_text_dual_encoder import VisionTextDualEncoderModel
from .models.visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
@@ -3676,6 +3682,7 @@
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -101,6 +101,7 @@
unispeech,
unispeech_sat,
vision_encoder_decoder,
vision_text_dual_encoder,
visual_bert,
vit,
wav2vec2,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -36,6 +36,7 @@
("trocr", "TrOCRConfig"),
("fnet", "FNetConfig"),
("segformer", "SegformerConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
@@ -192,6 +193,7 @@
("trocr", "TrOCR"),
("fnet", "FNet"),
("segformer", "SegFormer"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("gptj", "GPT-J"),
("beit", "BEiT"),
("rembert", "RemBERT"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -32,6 +32,7 @@
("qdqbert", "QDQBertModel"),
("fnet", "FNetModel"),
("segformer", "SegformerModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -29,6 +29,7 @@
[
# Base model mapping
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -41,6 +41,7 @@
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
]
)

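With the Auto mappings above in place, the new model type resolves through the generic Auto classes. A short sketch,
where ``./vit-bert`` is a hypothetical local directory produced by the model's ``save_pretrained``:

    from transformers import AutoConfig, AutoModel, AutoProcessor

    # model_type "vision-text-dual-encoder" in config.json routes to the new classes
    config = AutoConfig.from_pretrained("./vit-bert")  # VisionTextDualEncoderConfig
    model = AutoModel.from_pretrained("./vit-bert")  # VisionTextDualEncoderModel
    processor = AutoProcessor.from_pretrained("./vit-bert")  # VisionTextDualEncoderProcessor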
52 changes: 52 additions & 0 deletions src/transformers/models/vision_text_dual_encoder/__init__.py
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
"configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
"processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}


if is_torch_available():
_import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]


if is_flax_available():
_import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]


if TYPE_CHECKING:
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

if is_torch_available():
from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

if is_flax_available():
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
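The lazy-module pattern above defers the framework-specific imports until first attribute access. A small sketch of
the intended effect (a hypothetical session, assuming both PyTorch and Flax are installed):

    import transformers

    # Neither the torch- nor the flax-backed module has been imported yet;
    # attribute access resolves each class through _LazyModule on demand.
    pt_cls = transformers.VisionTextDualEncoderModel
    flax_cls = transformers.FlaxVisionTextDualEncoderModel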
129 changes: 129 additions & 0 deletions src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VisionTextDualEncoder model configuration """

import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..clip.configuration_clip import CLIPVisionConfig


logger = logging.get_logger(__name__)


class VisionTextDualEncoderConfig(PretrainedConfig):
r"""
    :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate a
    :class:`~transformers.VisionTextDualEncoderModel` according to the specified arguments, defining the text model
    and vision model configs.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config (:obj:`dict`):
            Dictionary of configuration options that defines the text model config.
        vision_config (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
            The initial value of the `logit_scale` parameter. The default is used as per the original CLIP
            implementation.
        kwargs (`optional`):
            Dictionary of keyword arguments.

Examples::

>>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

>>> # Initializing a BERT and ViT configuration
>>> config_vision = ViTConfig()
>>> config_text = BertConfig()

>>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

>>> # Initializing a BERT and ViT model
>>> model = VisionTextDualEncoderModel(config=config)

>>> # Accessing the model configuration
>>> config_vision = model.config.vision_config
>>> config_text = model.config.text_config

>>> # Saving the model, including its configuration
>>> model.save_pretrained('my-model')

>>> # loading model and config from pretrained folder
>>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert')
>>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config)
"""

model_type = "vision-text-dual-encoder"
is_composition = True

def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
super().__init__(**kwargs)

if "vision_config" not in kwargs:
raise ValueError("`vision_config` can not be `None`.")

if "text_config" not in kwargs:
raise ValueError("`text_config` can not be `None`.")

vision_config = kwargs.pop("vision_config")
text_config = kwargs.pop("text_config")

vision_model_type = vision_config.pop("model_type")
text_model_type = text_config.pop("model_type")

        # CLIP's top-level config nests its vision config, so unwrap it here; a bare
        # `clip_vision_model` config has no Auto mapping and is instantiated directly.
        if vision_model_type == "clip":
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

self.text_config = AutoConfig.for_model(text_model_type, **text_config)

self.projection_dim = projection_dim
self.logit_scale_init_value = logit_scale_init_value

@classmethod
def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
r"""
Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from text model configuration and
vision model configuration.

Returns:
:class:`VisionTextDualEncoderConfig`: An instance of a configuration object
"""

return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
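A quick sketch of the nesting ``to_dict`` produces (illustrative, using default sub-configs):

    from transformers import BertConfig, ViTConfig, VisionTextDualEncoderConfig

    # Compose the dual-encoder config from two default sub-configs.
    config = VisionTextDualEncoderConfig.from_vision_text_configs(ViTConfig(), BertConfig())

    d = config.to_dict()
    # The composite model type is recorded alongside the fully nested sub-configs.
    assert d["model_type"] == "vision-text-dual-encoder"
    assert d["vision_config"]["model_type"] == "vit"
    assert d["text_config"]["model_type"] == "bert"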