🚨 Add Blip2ForImageTextRetrieval #29261
Conversation
cc @NielsRogge and @younesbelkada if one of you wants to review once @jpizarrom makes the CIs go green!
Hi, what could I do to make the CIs go green? Shall I just merge upstream/main into my branch, or rebase onto it?
@jpizarrom It's preferable for you to rebase onto main. To see how to make the CIs green, you'll need to click on
Thanks for adding this! Overall looks great, just a few small comments.
Once they're addressed we can move the checkpoints to be under the Salesforce org.
@classmethod
def from_vision_qformer_configs(
    cls,
    vision_config: Blip2VisionConfig,
    qformer_config: Blip2QFormerConfig,
    **kwargs,
):
    r"""
    Instantiate a [`Blip2Config`] (or a derived class) from BLIP-2 vision and Q-Former model configurations.

    Returns:
        [`Blip2Config`]: An instance of a configuration object
    """
    return cls(
        vision_config=vision_config.to_dict(),
        qformer_config=qformer_config.to_dict(),
        **kwargs,
    )
I don't think it's necessary to add a separate method here. We can just make text_config optional in from_vision_qformer_text_configs.
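A rough sketch of what that could look like (the Optional default and the None handling are assumptions, not the merged implementation):

@classmethod
def from_vision_qformer_text_configs(
    cls,
    vision_config: Blip2VisionConfig,
    qformer_config: Blip2QFormerConfig,
    text_config: Optional[PretrainedConfig] = None,  # now optional
    **kwargs,
):
    return cls(
        vision_config=vision_config.to_dict(),
        qformer_config=qformer_config.to_dict(),
        # only serialize the language model config when one is provided
        text_config=text_config.to_dict() if text_config is not None else None,
        **kwargs,
    )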
from_vision_qformer_configs was removed.
if self.device != torch.device("cpu"):
    with torch.cuda.amp.autocast(dtype=torch.float16):
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
else:
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
Autocasting and typing should be handled outside of the model definition
vision_outputs = self.vision_model(
    pixel_values=pixel_values,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
This was done because in the original model the autocast was applied only to the vision layers; I don't know yet how to do this in a different way.
cc @amyeroberts
it was removed, as discussed in #29261 (comment)
if config.use_qformer_text_input:
    self.embeddings = Blip2TextEmbeddings(config)
Instead of using this config argument to conditionally create and call this layer, I'd suggest instead calling self.embeddings if input_ids is not None.
self.embeddings = Blip2TextEmbeddings(config)
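A minimal sketch of that pattern in the forward pass (the surrounding signature is assumed; only the conditional is the point):

# inside Blip2QFormerModel.forward (sketch)
if input_ids is not None:
    # run the text embeddings only when text input is actually provided,
    # instead of gating on config.use_qformer_text_input
    query_embeds = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        query_embeds=query_embeds,
    )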
When this layer is always created, I get the following errors and don't know how to fix them. Some BLIP-2 models do not use these BERT-based embeddings; they use OPT or Flan-T5 to create the query_embeds. Maybe I could try to refactor the code to move Blip2TextEmbeddings outside of Blip2QFormerModel and always pass query_embeds. What do you think?
FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ForConditionalGenerationDecoderOnlyTest::test_training_gradient_checkpointing - AssertionError: False is not true : qformer.embeddings.word_embeddings.weight in Blip2ForConditionalGeneration has no gradient!
FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ModelTest::test_training_gradient_checkpointing - AssertionError: False is not true : qformer.embeddings.word_embeddings.weight in Blip2ForConditionalGeneration has no gradient!
I did a refactor: the embeddings were removed from Blip2QFormerModel and placed into Blip2ForImageTextRetrieval and Blip2TextModelWithProjection, but to do so I needed to add a query_length param to Blip2QFormerModel.forward.
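Roughly, the resulting call pattern might look like this (a sketch based on the description above; attribute and argument names other than query_length are illustrative):

# sketch: the caller (e.g. Blip2ForImageTextRetrieval) now owns the text embeddings
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_embeds = self.embeddings(input_ids=input_ids, query_embeds=query_tokens)

qformer_outputs = self.qformer(
    query_embeds=query_embeds,
    query_length=query_tokens.shape[1],  # new parameter on Blip2QFormerModel.forward
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_attention_mask,
)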
# past_key_values_length
past_key_values_length = (
    past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
)

query_length = query_embeds.shape[1] if query_embeds is not None else 0

embedding_output = self.layernorm(query_embeds)
if self.config.use_qformer_text_input:
if input_ids is not None:
This is outdated, because the embeddings were removed from Blip2QFormerModel.
# TODO: maybe have a cleaner way to cast the input (from `Blip2Processor` side?)
expected_dtype = self.dtype
if encoder_hidden_states is not None and encoder_hidden_states.dtype != expected_dtype:
    encoder_hidden_states = encoder_hidden_states.to(expected_dtype)
Is this even necessary?
Should not be necessary indeed given that modeling code is by default in torch.float32
it was removed
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
)

if self.device != torch.device("cpu"):
    with torch.cuda.amp.autocast(dtype=torch.float16):
As far as I can tell we don't add torch.cuda.amp.autocast code to modeling files; they are just in float32 by default. This was discussed on the original BLIP-2 model addition PR from what I remember. It's up to users to call something like torch.cuda.amp.autocast themselves if they wish to load the model in a different precision than the default one (cc @younesbelkada).
Hence in the conversion script I cast both the original weights and my BLIP-2 implementation to float32 in order to verify the conversion.
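For illustration, the user-side pattern is roughly the following (a runnable sketch with a stand-in module; in practice the wrapped call would be the BLIP-2 forward pass):

import torch
import torch.nn as nn

# stand-in for the model; its weights stay in the default float32
model = nn.Linear(16, 16).to("cuda")
inputs = torch.randn(2, 16, device="cuda")

# the caller opts into mixed precision, the modeling code does not
with torch.cuda.amp.autocast(dtype=torch.float16):
    outputs = model(inputs)

print(outputs.dtype)  # torch.float16 under autocast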
Ok, so this means that I need to remove maybe_autocast from https://github.com/NielsRogge/LAVIS/blob/blip2_float32/lavis/models/blip2_models/blip2_image_text_matching.py#L57-L58, right?
Yes that's right
It was removed; a PR was opened on your fork to also remove the autocast from the ITM model: NielsRogge/LAVIS#1
@@ -84,6 +84,99 @@ def to_tuple(self) -> Tuple[Any]:
    )


@dataclass
class Blip2ImageTextMatchingModelOutput(ModelOutput):
Not sure if feasible, but it'd be nice to match the output class of CLIP, which is also an image-text matching model. It consists of the following keys:
- loss
- logits_per_image (this I assume is the itm_score)
- logits_per_text (this I assume is the itm_score transposed)
- and some other keys which are CLIP-specific.
Making sure that Blip2ForImageTextRetrieval matches this would allow it to be added to the zero-shot image classification pipeline, which relies on this output key:
"logits": outputs.logits_per_image, |
Otherwise we will have a hard time adding BLIP-2 support to the zero-shot image classification pipeline.
Hi @NielsRogge, I updated the output to match the CLIP output, but this PR is not being updated with my latest commits.
Thanks for your work! I would request some changes, however, in order to be able to make BLIP-2 compatible with the zero-shot image classification pipeline.
input_ids: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
past_key_values are not used I assume, so the `past_key_values_length: int = 0,` argument can be removed.
it was removed. thanks
@jpizarrom once the CI is green I can assign a core maintainer for a final approval.
I believe the CI errors are not related to this branch; I see modeling_mra and other unrelated error logs. I don't know how to make the CI green; maybe rebase onto a more recent commit of the main branch?
Could you rebase on main and push?
Thanks,
Ok, pinging @ydshieh here
Hi, I merged main again, there are some errors in
Those could be ignored. But we could probably get them fixed in another PR soon.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for working on this - exciting to have this feature finally added!
Mostly a few nits. My main comment is about identifying the cause of the change in the integration test values and possibly rectifying it, given it indicates a degradation in performance.
@@ -347,6 +361,6 @@ def from_vision_qformer_text_configs(
     return cls(
         vision_config=vision_config.to_dict(),
         qformer_config=qformer_config.to_dict(),
-        text_config=text_config.to_dict(),
+        text_config=text_config.to_dict() if text_config is not None else None,
Making this optional is a bit funny given the name of the method. We should at least update the docstring to indicate that the language model config is optional.
The docstring was updated.
assert unexpected_keys == ["qformer.embeddings.position_ids"]

if "itm" in model_name:
    unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
Why is this filtering necessary here?
There are some keys from the original model that were excluded:
Qformer.cls.predictions.bias, Qformer.cls.predictions.transform.dense.weight, Qformer.cls.predictions.transform.dense.bias, Qformer.cls.predictions.transform.LayerNorm.weight, Qformer.cls.predictions.transform.LayerNorm.bias, Qformer.cls.predictions.decoder.weight
    [2, 15610, 1597, 2977, 6, 13011, 1594, 43052, 50118],
)
- self.assertEqual(generated_text, "it's not a city, it's a beach")
+ self.assertEqual(generated_text, "san diego, california")
Hmmmm - this doesn't seem right (the picture is indeed of a beach, not a city).
Could you try the following to see if you're able to recover the previous generations:
- Try without the additional generation kwargs
- Try without the added scaling included in the modeling file c.f. https://github.com/huggingface/transformers/pull/29261/files#r1501123761
Ok, I will change the test. I was only trying to get a similar result to test_inference_t5, since in that one the answer shows san diego. It was not related to the scaling change.
Hi, thanks for the feedback, I am making the suggested changes. I don't know which values from the integration tests you are referring to that indicate performance degradation; could you give more context?
Thanks for all the work adding this!
The only thing left to do is to update the checkpoint references to point to ones under the Salesforce org.
Shall I do it? Can I publish a model under the Salesforce org?
@@ -79,6 +82,12 @@ def create_rename_keys(config):
    # QFormer
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
    rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
Hi, I just found that I got an error converting the blip2-opt-2.7b model:
KeyError: 'Qformer.bert.embeddings.word_embeddings.weight'
I'm going to have to apply these rename keys for the itm models only.
Yes, it'd be great to keep backwards compatibility for the existing checkpoints, and also to make sure that users don't get warnings when loading them (like unexpected keys in checkpoint, etc.).
- I fixed the keys issues in the convert script.
- Then the logit comparison in the convert script for blip2-opt-2.7b was failing; I needed to revert a change I had made, so now the scale is applied after the dot product between "query" and "key" (see the sketch after this list). The slow tests in test_modeling_blip_2.py are passing:
RUN_SLOW=1 python -m pytest tests/models/blip_2/test_modeling_blip_2.py
- But the tests on CI are failing, apparently for a reason unrelated to this PR.
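A minimal illustration of the two orderings mentioned above (shapes and names are illustrative, not the actual Q-Former code):

import math

import torch

batch, heads, seq, head_dim = 1, 2, 4, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)

# scaling applied after the query/key dot product (the behaviour kept in this PR)
scores_after = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)

# mathematically equivalent pre-scaling of the query; the tiny floating point
# differences can be enough to trip a strict logit comparison in a convert script
scores_before = torch.matmul(query / math.sqrt(head_dim), key.transpose(-1, -2))

print((scores_after - scores_before).abs().max())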
The convert script for blip2-opt-2.7b is failing in the generation step; it looks like this is not related to the changes of this PR, because I get the same error on the main branch:
ValueError: Input length of input_ids is 0, but `max_length` is set to -14. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.
cc @zucchini-nlp who fixed this error; we might need to update the conversion script of BLIP-2 for that.
Can you check that the model's generation config has a high enough max_length (if it doesn't have max_length, the default is 20)? Right now BLIP will count all tokens (image and text) towards max_length, so we can either add a higher max_length or max_new_tokens in the model's generation config, so that one can run generation directly with model.generate(**inputs).
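For example, something along these lines (a sketch using the existing OPT checkpoint; the value 50 is illustrative):

import requests
from PIL import Image

from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

# give generate() room for up to 50 new tokens instead of counting
# image + text tokens towards the default max_length of 20
model.generation_config.max_new_tokens = 50

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: what is in the image? Answer:", return_tensors="pt")

generated_ids = model.generate(**inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))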
I added max_length=50 to the generate call of the LAVIS model in the convert script; now the conversion of blip2-opt-2.7b works.
I'll make sure the checkpoints get transferred.
Hi @jpizarrom could you push the checkpoints to your HF profile so that I can transfer them to the Salesforce org?
Hi, I just ran the conversion script, so the checkpoints are updated at https://huggingface.co/jpizarrom/blip2-itm-vit-g
What does this PR do?
Add Blip2ForImageTextRetrieval, Blip2TextModelWithProjection, and Blip2VisionModelWithProjection models to be able to get image-text matching scores and extract text, image, and multimodal features.
Fixes part of #25300 and #25245.
This is a continuation of #25612; I tried to apply most of the feedback received in that PR.
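A rough usage sketch of the new retrieval head (the checkpoint name assumes the weights end up under the Salesforce org as discussed above, and the output fields assume the CLIP-style output class; both are assumptions, not the final API):

import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor

checkpoint = "Salesforce/blip2-itm-vit-g"  # assumed final location of the converted weights
processor = Blip2Processor.from_pretrained(checkpoint)
model = Blip2ForImageTextRetrieval.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="two cats sleeping on a couch", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# CLIP-style image-text matching logits, as discussed in the review
print(outputs.logits_per_image)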
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @amyeroberts