VisionTextDualEncoder #13511
Merged — patil-suraj merged 51 commits into huggingface:master from patil-suraj:vision-text-clip on Nov 30, 2021 (+2,643 −0)
Commits (51, all by patil-suraj)
b6c40ca init vision_text_dual_encoder
b08ea2f fix merge
0949605 remove extra heads
15f7ca1 fix tests
af649c2 remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP
4a65a44 remove archive map
01a05c1 fix imports
dafe6e4 fix more imports
a8bf9a2 fix init
7ae6b92 delete tokenizers
26716a8 fix imports
5a75ca3 clean
fd06544 support clip's vision model
271863b handle None config
f9d217e begin tests
1731a80 more test and few fixes
e150b5d warn about newly init weights
bbb4d47 more tests
9297052 add loss to model
e90f5e7 remove extra classes from doc
d5d4de2 add processor
fb7d63d doc and small fixes
8b95eca add start docstr
4fc6be0 update flax model
3179dbc flax tests
7184ad0 more flax tests
e19b9cf doc
38ada68 quality
9af7274 doc and quality
be1fb39 fix doc
4ad23ac doc
45f15b7 remove comments
636c8bf update warning
1f7a875 quality
9a6234b fix docs
7c72bb6 Apply suggestions from code review
38736d1 replace asserts, fix imports
a14042f update imports
58f9514 fix import
c33911d address some review comments
f7dddaa fix check
6a3e943 reduce tolerance
9dba37d fix test
8e00565 add flax integration test
09eca4e Apply suggestions from code review
2a6e420 address Sylvain's comments
ac84a85 fix style
0ee6603 add pt_flax_equivalence test in PT tests
7a8eb2f add pt integration test
89743d4 update test
e6dd8ec use pre-trained checkpoint in examples
@@ -0,0 +1,56 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

VisionTextDualEncoder
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with
any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT
<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT
<bert>`). Two projection layers are added on top of the vision and text encoders to project their output embeddings
into a shared latent space. Because the projection layers are randomly initialized, the model should be fine-tuned on a
downstream task. This model can be used to align the vision and text embeddings with CLIP-like contrastive image-text
training and can then be applied to zero-shot vision tasks such as image classification or retrieval.

In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
leveraging pretrained (locked/frozen) image and text models for contrastive learning yields significant improvement on
new zero-shot vision tasks such as image classification or retrieval.

VisionTextDualEncoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderConfig
    :members:


VisionTextDualEncoderProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderProcessor
    :members:


VisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderModel
    :members: forward


FlaxVisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionTextDualEncoderModel
    :members: __call__
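The doc above says the model is trained with a CLIP-like contrastive image-text objective (see also the "add loss to model" commit). As a rough illustration of that objective — a minimal NumPy sketch, not the actual modeling code in this PR; the function name and fixed temperature handling are assumptions for illustration:

```python
import numpy as np

def clip_style_contrastive_loss(image_embeds, text_embeds, logit_scale=np.exp(2.6592)):
    """Symmetric contrastive loss over a batch of paired image/text embeddings.

    2.6592 is the config's logit_scale_init_value; the effective scale is its
    exponential, as in CLIP.
    """
    # L2-normalize so similarities become cosine similarities.
    image_embeds = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    text_embeds = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)

    # [batch, batch] similarity matrix: row i compares text i against every image.
    logits_per_text = logit_scale * text_embeds @ image_embeds.T

    def diagonal_cross_entropy(logits):
        # Matching pairs sit on the diagonal, so the target class for row i is i.
        shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        return -np.mean(np.diag(log_probs))

    # Average the text->image and image->text directions.
    return (diagonal_cross_entropy(logits_per_text) + diagonal_cross_entropy(logits_per_text.T)) / 2.0
```

With perfectly aligned pairs on the diagonal the loss approaches zero; shuffling one modality raises it, which is what drives the embeddings of matching pairs together.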
52 additions & 0 deletions: src/transformers/models/vision_text_dual_encoder/__init__.py

@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
    "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
    "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}


if is_torch_available():
    _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]


if is_flax_available():
    _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]


if TYPE_CHECKING:
    from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

    if is_torch_available():
        from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

    if is_flax_available():
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
129 additions & 0 deletions: src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py

@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VisionTextDualEncoder model configuration """

import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..clip.configuration_clip import CLIPVisionConfig


logger = logging.get_logger(__name__)


class VisionTextDualEncoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate a
    :class:`~transformers.VisionTextDualEncoderModel` according to the specified arguments, defining the text model
    and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the text model config.
        vision_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
            The initial value of the `logit_scale` parameter. The default is used as per the original CLIP
            implementation.
        kwargs (`optional`):
            Dictionary of keyword arguments.

    Examples::

        >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

        >>> # Initializing ViT and BERT configurations
        >>> config_vision = ViTConfig()
        >>> config_text = BertConfig()

        >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

        >>> # Initializing a ViT-BERT model
        >>> model = VisionTextDualEncoderModel(config=config)

        >>> # Accessing the model configuration
        >>> config_vision = model.config.vision_config
        >>> config_text = model.config.text_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('my-model')

        >>> # Loading the model and config from a pretrained folder
        >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert')
        >>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config)
    """

    model_type = "vision-text-dual-encoder"
    is_composition = True

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
        super().__init__(**kwargs)

        if "vision_config" not in kwargs:
            raise ValueError("`vision_config` can not be `None`.")

        if "text_config" not in kwargs:
            raise ValueError("`text_config` can not be `None`.")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        vision_model_type = vision_config.pop("model_type")
        text_model_type = text_config.pop("model_type")

        if vision_model_type == "clip":
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

    @classmethod
    def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from text model configuration and
        vision model configuration.

        Returns:
            :class:`VisionTextDualEncoderConfig`: An instance of a configuration object
        """

        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
Review discussion (on whether to add the model to the README):

- "Let's also update the README no?"
- "I believe @sgugger was against this."
- "Yes we removed generic classes like the Vision Encoder Decoder, especially since they did not have a research article coming with it. This one has an article as an example, so you can add it if you really want it."
- "Since there is no official research implementation and pre-trained checkpoint I would also prefer to not add it."
- "Ok for me"