Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix _merge_input_ids_with_image_features for llava model #28333

Merged
merged 9 commits into from
Jan 10, 2024

Conversation

VictorSanh
Copy link
Member

@VictorSanh VictorSanh commented Jan 3, 2024

Bug detected by @Sakshi-Bhargava

The method LlavaForConditionalGeneration._merge_input_ids_with_image_features takes care of merging the input_embeds with the hidden states obtained from the vision encoder. The merge output is fed to the language model part of the model.

However, labels was omitted from the merge, and when trying to compute a loss, the shapes of the logits and the labels are not compatible.

This fix ensures that labels is also properly merged.

Dummy reproduction case (still respect the model hidden sizes):

import torch
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-13b-hf")

pixel_values = torch.randn(
    (2, 3, 336, 336),
    dtype=torch.float
)
input_ids = torch.tensor(
    [
        [32001, 32001, 1, 15043,  7084, 32000, 29871,    13, 7900],
        [1, 15043,  7084, 29901, 29871, 32000, 29871,    13, 7900]
    ], dtype=torch.long
)
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]
    ], dtype=torch.long
)

output = model(
    pixel_values=pixel_values,
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=input_ids,
)

will yield the following error without the fix

    output = model(
  File "/victor/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/victor/code/transformers/src/transformers/models/llava/modeling_llava.py", line 486, in forward
    shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
IndexError: The shape of the mask [2, 583] at index 1 does not match the shape of the indexed tensor [2, 8] at index 1

cc @gullalc @younesbelkada @amyeroberts

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the support for training 😉
Let's add a test as well

src/transformers/models/llava/modeling_llava.py Outdated Show resolved Hide resolved
src/transformers/models/llava/modeling_llava.py Outdated Show resolved Hide resolved
VictorSanh and others added 3 commits January 4, 2024 11:50
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@VictorSanh
Copy link
Member Author

Adressed the comments and moved the dummy test case into proper tests.
let me know if you would like something more involved test wise!

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks, let's also make sure that loss.backward() works (we usually have an automatic test for this here

def test_training(self):

tests/models/llava/test_modeling_llava.py Outdated Show resolved Hide resolved
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! I left a single comment about the test

tests/models/llava/test_modeling_llava.py Outdated Show resolved Hide resolved
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome work @VictorSanh ! Thanks a lot for the fix!

@ArthurZucker
Copy link
Collaborator

Good to go ! Feel free to merge @VictorSanh 🤗

@VictorSanh
Copy link
Member Author

I am not cool enough to have merge access. The time where i am merging stuff whenever i wanted on hf transformers is well passed haha

@VictorSanh
Copy link
Member Author

so either you @ArthurZucker or @younesbelkada need to merge lol 😅
but perhaps i can be promoted to core maintainer with that PR @LysandreJik ?

@younesbelkada younesbelkada merged commit 0f2f0c6 into huggingface:main Jan 10, 2024
18 checks passed
@ArthurZucker
Copy link
Collaborator

Ooops 🤣

@LysandreJik
Copy link
Member

Considering it @VictorSanh!

staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
…ce#28333)

* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* adress comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
MadElf1337 pushed a commit to MadElf1337/transformers that referenced this pull request Jan 15, 2024
…ce#28333)

* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* adress comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
wgifford pushed a commit to wgifford/transformers that referenced this pull request Jan 21, 2024
…ce#28333)

* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* adress comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request Jan 22, 2024
…ce#28333)

* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* adress comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@alexandrosXe
Copy link

Hi,
I am still getting the following error when I'm trying to finetune the model for ConditionalGeneration in the forward call:

final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
IndexError: index 8 is out of bounds for dimension 1 with size 8

The same code works fine if I just change the model to another VLLM like InstructBlip.

Thank you and kind regards,
Alexandros Xenos

@VictorSanh
Copy link
Member Author

@alexandrosXe do you have a reproduction case we can start debugging from?

@alexandrosXe
Copy link

@alexandrosXe do you have a reproduction case we can start debugging from?
@VictorSanh Thank you for replying so fast!
This code can reproduce my error:

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
answer = "The image has a stop sign in the corner of the road"


inputs = processor(text=prompt, images=image, return_tensors="pt")

labels = processor.tokenizer(answer, return_tensors="pt")
label_ids = labels["input_ids"]
label_mask = labels["attention_mask"].bool()
label_ids = label_ids.masked_fill(~label_mask, -100) #We dont count the loss on the padded tokens
loss = model(**inputs, labels = label_ids).loss
print("loss: ", loss)

@ArthurZucker
Copy link
Collaborator

Can you also make sure you are using the latest version of transformers?

@alexandrosXe
Copy link

Can you also make sure you are using the latest version of transformers?

@ArthurZucker I am using the transformers 4.37.2 version.

@ArthurZucker
Copy link
Collaborator

Yes, it seems oss = model(**inputs, labels = inputs["input_ids"]) works well however. Loss was made to be of the same size as the input ids:

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

from the doc it should be length sequence length

@alexandrosXe
Copy link

@ArthurZucker Thank you, it was my fault using wrong labels. Now everything works fine!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants