
🚨 Add Blip2ForImageTextRetrieval #29261

Open · wants to merge 26 commits into base: main
Conversation

@jpizarrom (Contributor) commented on Feb 23, 2024

What does this PR do?

Adds the Blip2ForImageTextRetrieval, Blip2TextModelWithProjection, and Blip2VisionModelWithProjection models, making it possible to compute image-text matching scores and to extract text, image, and multimodal features.
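To make the intended usage concrete, here is a minimal sketch of scoring an image-text pair with the retrieval model. The argument name use_image_text_matching_head and the checkpoint location are assumptions based on the discussion further down in this thread (where jpizarrom/blip2-itm-vit-g is mentioned), so the final merged API may differ.

```python
import requests
import torch
from PIL import Image

from transformers import Blip2ForImageTextRetrieval, Blip2Processor

# Checkpoint name taken from the end of this thread; it may later move under the Salesforce org.
checkpoint = "jpizarrom/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(checkpoint)
model = Blip2ForImageTextRetrieval.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "two cats sleeping on a couch"

inputs = processor(images=image, text=text, return_tensors="pt")

with torch.no_grad():
    # With the ITM head the model returns match / no-match logits;
    # without it, it returns contrastive image-text similarity scores.
    outputs = model(**inputs, use_image_text_matching_head=True)

match_probability = torch.softmax(outputs.logits_per_image, dim=1)[:, 1]
print(f"probability that the text matches the image: {match_probability.item():.3f}")
```

Blip2TextModelWithProjection and Blip2VisionModelWithProjection are the corresponding single-modality models for extracting projected text and image features.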

Fixes part of #25300 and #25245.

This is a continuation of #25612; I tried to apply most of the feedback received in that PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @amyeroberts

@jpizarrom marked this pull request as draft on February 23, 2024 19:41
@jpizarrom marked this pull request as ready for review on February 23, 2024 20:14
@jpizarrom changed the title from "WIP Add Blip2ForImageTextRetrieval" to "Add Blip2ForImageTextRetrieval" on Feb 23, 2024
@ArthurZucker (Collaborator)

cc @NielsRogge and @younesbelkada if one of you wants to review once @jpizarrom makes the CIs go green!

@jpizarrom changed the title from "Add Blip2ForImageTextRetrieval" to "🚨 Add Blip2ForImageTextRetrieval" on Mar 2, 2024
@jpizarrom (Contributor Author)

> cc @NielsRogge and @younesbelkada if one of you wants to review once @jpizarrom makes the CIs go green!

Hi, what could I do to make the CIs go green? Shall I just merge upstream/main into my branch, or rebase onto it?

@amyeroberts (Collaborator)

@jpizarrom It's preferable for you to rebase onto main. To see how to make the CIs green, you'll need to click on "Details" and look at the error logs from CircleCI. I'd suggest doing this after rebasing, to see which errors are coming from this branch.

@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from 0e82065 to 9aa9a15 on March 22, 2024 15:16
@amyeroberts (Collaborator) left a comment

Thanks for adding this! Overall it looks great, just a few small comments.

Once they're addressed, we can move the checkpoints to be under the Salesforce org.

Comment on lines 372 to 389
    @classmethod
    def from_vision_qformer_configs(
        cls,
        vision_config: Blip2VisionConfig,
        qformer_config: Blip2QFormerConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`Blip2Config`] (or a derived class) from a BLIP-2 vision and Q-Former model configurations.

        Returns:
            [`Blip2Config`]: An instance of a configuration object
        """

        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            **kwargs,
        )
Collaborator:

I don't think it's necessary to add a separate method here. We can just make text_config optional in from_vision_qformer_text_configs.

Contributor Author:

from_vision_qformer_configs was removed

[Resolved review thread on src/transformers/models/blip_2/modeling_blip_2.py]
Comment on lines 2365 to 2387
        if self.device != torch.device("cpu"):
            with torch.cuda.amp.autocast(dtype=torch.float16):
                vision_outputs = self.vision_model(
                    pixel_values=pixel_values,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        else:
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
Collaborator:

Autocasting and typing should be handled outside of the model definition

Suggested change
-        if self.device != torch.device("cpu"):
-            with torch.cuda.amp.autocast(dtype=torch.float16):
-                vision_outputs = self.vision_model(
-                    pixel_values=pixel_values,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                    return_dict=return_dict,
-                )
-        else:
-            vision_outputs = self.vision_model(
-                pixel_values=pixel_values,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )

Contributor Author:

This was done because in the original model the autocast was applied only to the vision layers; I don't yet know how to handle this in a different way.

https://github.com/salesforce/LAVIS/blob/ac8fc98c93c02e2dfb727e24a361c4c309c8dbbc/lavis/models/blip2_models/blip2_qformer.py#L423-L424
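For illustration, here is a rough sketch (an assumption about user-side code, not code from this PR) of how a caller could keep the original behaviour while leaving the modeling file alone, by running only the vision tower under autocast. model and pixel_values are placeholders for a BLIP-2 style model with a vision_model submodule and a preprocessed image batch on the GPU.

```python
import torch

with torch.no_grad():
    # Mixed precision only for the vision layers, mirroring the linked LAVIS code.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        vision_outputs = model.vision_model(pixel_values=pixel_values)

    # Cast back to float32 before the Q-Former, which stays in the default dtype.
    image_embeds = vision_outputs.last_hidden_state.float()
```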

Contributor Author:

it was removed, as discussed in #29261 (comment)

[Resolved review threads on tests/models/blip_2/test_modeling_blip_2.py and src/transformers/models/blip_2/modeling_blip_2.py]
Comment on lines 1199 to 1198
        if config.use_qformer_text_input:
            self.embeddings = Blip2TextEmbeddings(config)
Collaborator:

Instead of using this config argument to conditionally create and call this layer, I'd suggest instead calling self.embeddings only if input_ids is not None.

Suggested change
-        if config.use_qformer_text_input:
-            self.embeddings = Blip2TextEmbeddings(config)
+        self.embeddings = Blip2TextEmbeddings(config)

Contributor Author:

When this layer is created _always_, I get these kinds of errors and don't know how to fix them.
Some BLIP-2 models do not use these BERT-based embeddings; they use OPT or Flan-T5 to create the query_embeds. Maybe I could try to refactor the code to move the Blip2TextEmbeddings outside of Blip2QFormerModel and always pass query_embeds. What do you think?

FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ForConditionalGenerationDecoderOnlyTest::test_training_gradient_checkpointing - AssertionError: False is not true : qformer.embeddings.word_embeddings.weight in Blip2ForConditionalGeneration has no gradient!
FAILED tests/models/blip_2/test_modeling_blip_2.py::Blip2ModelTest::test_training_gradient_checkpointing - AssertionError: False is not true : qformer.embeddings.word_embeddings.weight in Blip2ForConditionalGeneration has no gradient!

Contributor Author:

I did a refactor: the embeddings were removed from Blip2QFormerModel and placed into Blip2ForImageTextRetrieval and Blip2TextModelWithProjection, but to do so I needed to add a query_length param to Blip2QFormerModel.forward.

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.layernorm(query_embeds)
        if self.config.use_qformer_text_input:
Collaborator:

Suggested change
-        if self.config.use_qformer_text_input:
+        if input_ids is not None:

Contributor Author:

this is outdated, because embeddings were removed from Blip2QFormerModel

Comment on lines 1373 to 1376
        # TODO: maybe have a cleaner way to cast the input (from `Blip2Processor` side?)
        expected_dtype = self.dtype
        if encoder_hidden_states is not None and encoder_hidden_states.dtype != expected_dtype:
            encoder_hidden_states = encoder_hidden_states.to(expected_dtype)
Collaborator:

Is this even necessary?

Contributor:

Should not be necessary indeed given that modeling code is by default in torch.float32

Contributor Author:

it was removed


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@jpizarrom changed the title from "🚨 Add Blip2ForImageTextRetrieval" to "WIP 🚨 Add Blip2ForImageTextRetrieval" on May 1, 2024
@jpizarrom force-pushed the add_blip2_image_text_retrieval_model branch from 05327aa to da0cc83 on May 1, 2024 06:57
        )

        if self.device != torch.device("cpu"):
            with torch.cuda.amp.autocast(dtype=torch.float16):
Contributor:

As far as I can tell we don't add torch.cuda.amp.autocast code to modeling files, they are just in float32 by default. This was discussed on the original BLIP-2 model addition PR from what I remember. It's up to users to call something like torch.cuda.amp.autocast themselves if they wish to load the model in a different precision than the default one (cc @younesbelkada).

Hence in the conversion script I cast both the original weights and my BLIP-2 implementation to float32 in order to verify the conversion.

Contributor Author: [comment body not captured]

Contributor:

Yes that's right

Contributor Author (@jpizarrom, May 1, 2024):

It was removed; a PR was opened on your fork to also remove the autocast from the ITM model: NielsRogge/LAVIS#1

@@ -84,6 +84,99 @@ def to_tuple(self) -> Tuple[Any]:
)


@dataclass
class Blip2ImageTextMatchingModelOutput(ModelOutput):
Contributor:

Not sure if feasible, but it'd be nice to match the output class of CLIP, which is also an image-text matching model. It consists of the following keys:

  • loss
  • logits_per_image (this I assume is the itm_score)
  • logits_per_text (this I assume is the itm_score transposed)
  • and some other keys which are CLIP-specific.

Making sure that Blip2ForImageTextRetrieval matches this would allow it to be added to the zero-shot image classification pipeline, which relies on the logits_per_image output key.
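For concreteness, a hedged sketch of what a CLIP-style output class could look like here; the exact set of fields is an assumption, not the PR's final definition.

```python
from dataclasses import dataclass
from typing import Optional

import torch

from transformers.utils import ModelOutput


@dataclass
class Blip2ImageTextMatchingModelOutput(ModelOutput):
    """Sketch of a CLIP-style output for image-text matching."""

    loss: Optional[torch.FloatTensor] = None    # optional contrastive/ITM loss
    logits_per_image: torch.FloatTensor = None  # image-to-text scores, shape (image_batch, text_batch)
    logits_per_text: torch.FloatTensor = None   # text-to-image scores, transpose of the above
    text_embeds: torch.FloatTensor = None       # projected text embeddings
    image_embeds: torch.FloatTensor = None      # projected image embeddings
```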

Contributor:

Otherwise we will have a hard time adding BLIP-2 support to the zero-shot image classification pipeline.
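For background, this is roughly how the zero-shot image classification pipeline is used with CLIP today (a plain usage example, not new BLIP-2 support); matching the output keys is what would later let BLIP-2 ITM checkpoints slot into the same pipeline.

```python
from transformers import pipeline

classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")

result = classifier(
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    candidate_labels=["two cats", "a dog", "a city skyline"],
)
print(result)  # list of {"score": ..., "label": ...} dicts, sorted by score
```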

Contributor Author:

Hi @NielsRogge, I updated the output to match the CLIP output, but this PR is not being updated with my latest commits.

@NielsRogge (Contributor) left a comment

Thanks for your work! I would request some changes, however, in order to make BLIP-2 compatible with the zero-shot image classification pipeline.

        input_ids: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        query_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
Contributor:

Suggested change
-        past_key_values_length: int = 0,

past_key_values are not used, I assume.

Contributor Author (@jpizarrom, May 18, 2024):

it was removed. thanks

@NielsRogge (Contributor)

@jpizarrom once the CI is green I can assign a core maintainer for a final approval

@jpizarrom (Contributor Author) commented on May 25, 2024

> @jpizarrom once the CI is green I can assign a core maintainer for a final approval

I believe the CI errors are not related to this branch; I see modeling_mra and other unrelated error logs. I don't know how to make the CI green. Maybe rebase onto a more recent commit of the main branch?

@NielsRogge (Contributor)

Could you rebase on main and push? (git fetch upstream followed by git merge upstream/main, assuming your upstream is set)

@jpizarrom (Contributor Author)

> Could you rebase on main and push? (git fetch upstream followed by git merge upstream/main, assuming your upstream is set)

Thanks, I merged with upstream, but the CI is still not green; other errors occur that I believe are not related to this branch.
RUN_SLOW=1 python -m pytest tests/models/blip_2/test_modeling_blip_2.py passes locally.

@NielsRogge (Contributor)

Ok, pinging @ydshieh here

@jpizarrom (Contributor Author)

> Ok, pinging @ydshieh here

Hi, I merged main again; there are some errors in tests/utils/test_offline.py.

@ydshieh (Collaborator) commented on May 27, 2024

> Ok, pinging @ydshieh here
>
> Hi, I merged main again; there are some errors in tests/utils/test_offline.py.

Those could be ignored. But we could probably get them fixed in another PR soon.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment

Thanks for working on this - exciting to have this feature finally added!

Mostly a few nits. The main comment is about identifying the cause of the change in the integration test values and possibly rectifying it, given that it indicates a degradation in performance.

@@ -347,6 +361,6 @@ def from_vision_qformer_text_configs(
         return cls(
             vision_config=vision_config.to_dict(),
             qformer_config=qformer_config.to_dict(),
-            text_config=text_config.to_dict(),
+            text_config=text_config.to_dict() if text_config is not None else None,
Collaborator:

Making this optional is a bit funny given the name of the method. We should at least update the docstring to indicate that language model config is optional.

Contributor Author:

The docstring was updated.

    assert unexpected_keys == ["qformer.embeddings.position_ids"]

    if "itm" in model_name:
        unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
Collaborator:

Why is this filtering necessary here?

Contributor Author:

There are some fields from the original model that were excluded:

Qformer.cls.predictions.bias, Qformer.cls.predictions.transform.dense.weight, Qformer.cls.predictions.transform.dense.bias, Qformer.cls.predictions.transform.LayerNorm.weight, Qformer.cls.predictions.transform.LayerNorm.bias, Qformer.cls.predictions.decoder.weight

[Resolved review threads on tests/models/blip_2/test_modeling_blip_2.py]
Comment on lines 1431 to 1433
            [2, 15610, 1597, 2977, 6, 13011, 1594, 43052, 50118],
        )
-        self.assertEqual(generated_text, "it's not a city, it's a beach")
+        self.assertEqual(generated_text, "san diego, california")
Collaborator:

Hmmmm - this doesn't seem right (the picture is indeed of a beach, not a city).

Could you try the following to see if you're able to recover the previous generations:

Contributor Author:

OK, I will change the test. I was only trying to get a result similar to test_inference_t5, since in that one the answer shows "san diego". It was not related to the scaling change.

[Resolved review threads on tests/models/blip_2/test_modeling_blip_2.py]
@jpizarrom (Contributor Author)

> Thanks for working on this - exciting to have this feature finally added!
>
> Mostly a few nits. The main comment is about identifying the cause of the change in the integration test values and possibly rectifying it, given that it indicates a degradation in performance.

Hi, thanks for the feedback, I am making the suggested changes.

I don't know what values from the integration tests you are referring to that indicate performance degradation, could you give more context?

@amyeroberts (Collaborator) left a comment

Thanks for all the work adding this!

The only thing left to do is to update the checkpoint references to point to ones under the Salesforce org.

@jpizarrom (Contributor Author)

> Thanks for all the work adding this!
>
> The only thing left to do is to update the checkpoint references to point to ones under the Salesforce org.

Shall I do it? Can I publish a model under the Salesforce org?

@@ -79,6 +82,12 @@ def create_rename_keys(config):
     # QFormer
     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
     rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
Contributor Author (@jpizarrom, Jun 16, 2024):

Hi, I just found that I got an error converting the blip2-opt-2.7b model:
KeyError: 'Qformer.bert.embeddings.word_embeddings.weight'

I'm going to have to apply these rename keys for the ITM models only.

Contributor:

Yes, it'd be great to keep backwards compatibility for the existing checkpoints, and also to make sure that users don't get warnings (like unexpected keys in the checkpoint, etc.) when loading them.
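As an aside, one quick way to sanity-check that an existing checkpoint still loads cleanly is the output_loading_info flag of from_pretrained; the checkpoint name below is just an example.

```python
from transformers import Blip2ForConditionalGeneration

model, loading_info = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", output_loading_info=True
)

# Backwards compatibility means these should stay (essentially) empty for existing checkpoints.
print("missing keys:", loading_info["missing_keys"])
print("unexpected keys:", loading_info["unexpected_keys"])
```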

Contributor Author:

  • I fixed the key issues in the convert script.
  • Then the logit comparison in the convert script for blip2-opt-2.7b was failing, so I needed to revert a change that I had made; now the scaling is done after the dot product between "query" and "key" (see the short sketch after this comment).
  • Slow tests in test_modeling_blip_2.py are passing: RUN_SLOW=1 python -m pytest tests/models/blip_2/test_modeling_blip_2.py
  • But the tests on CI are failing, and it looks like it is for a reason unrelated to this PR.

The convert script for blip2-opt-2.7b is failing in the generation step; it looks like it is not related to the changes of this PR, because I got the same error on the main branch:
ValueError: Input length of input_ids is 0, but `max_length` is set to -14. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.
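As mentioned in the list above, a short standalone sketch (an illustration, not code from the PR) of why the placement of the attention scaling matters for the logit comparison: scaling before versus after the query-key dot product is mathematically equivalent, but the two orderings round differently in low precision, which can break an exact-match check against the original implementation.

```python
import torch

torch.manual_seed(0)
head_dim = 64
# (batch, heads, seq_len, head_dim), in float16 to make the rounding differences visible
query = torch.randn(1, 8, 32, head_dim, dtype=torch.float16)
key = torch.randn(1, 8, 32, head_dim, dtype=torch.float16)
scale = head_dim**-0.5

# Scale applied to the query before the dot product.
scores_before = (query * scale) @ key.transpose(-1, -2)

# Scale applied after the dot product (the ordering kept in this PR).
scores_after = (query @ key.transpose(-1, -2)) * scale

# Same math, but not bit-identical in float16.
print((scores_before - scores_after).abs().max())
```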

Contributor:

cc @zucchini-nlp who fixed this error, we might need to update the conversion script of BLIP-2 for that

Member:

Can you check that model's generation config has a high enough max_length (if it doesn't have max_length the default is 20)?

Right now BLIP will count all tokens (image and text) towards max_length, so we can either add a higher max_length or a max_new_tokens value to the model's generation config, so that one can run generation directly with model.generate(**inputs).
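As a concrete illustration of both options, a hedged sketch assuming model is a BLIP-2 generation model (e.g. Blip2ForConditionalGeneration) with its processor and preprocessed inputs already set up; the value 50 is just an example.

```python
# Option 1: store a limit on newly generated tokens in the checkpoint's generation config,
# so that model.generate(**inputs) works without extra arguments.
model.generation_config.max_new_tokens = 50
generated_ids = model.generate(**inputs)

# Option 2: pass the limit explicitly at call time instead of relying on max_length.
generated_ids = model.generate(**inputs, max_new_tokens=50)

print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```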

Contributor Author:

I added max_length=50 to the generate call of the LAVIS model in the convert script; now the conversion of blip2-opt-2.7b works.

@NielsRogge (Contributor) commented on Jun 16, 2024

> Shall I do it? Can I publish a model under the Salesforce org?

I'll make sure the checkpoints get transferred

@NielsRogge (Contributor)

Hi @jpizarrom could you push the checkpoints to your HF profile so that I can transfer them to the Salesforce org?

@jpizarrom (Contributor Author)

> Hi @jpizarrom could you push the checkpoints to your HF profile so that I can transfer them to the Salesforce org?

Hi, I just ran the conversion script, so the checkpoints are updated at:

https://huggingface.co/jpizarrom/blip2-itm-vit-g
https://huggingface.co/jpizarrom/blip2-itm-vit-g-coco

Labels: none · Projects: none · Linked issues: none · 7 participants