Fix dtype mismatch in modeling_llava_next #42979
Godkunn wants to merge 1 commit into huggingface:main
Conversation
Ensure logits are computed with the correct dtype.
[For maintainers] Suggested jobs to run (before merge): run-slow: llava_next
I don't think this is the right fix for this issue. It could actually come from the loading logic:

```python
>>> from transformers import AutoModelForImageTextToText
>>> model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", dtype="auto")
>>> model.language_model.dtype, model.vision_tower.dtype
(torch.bfloat16, torch.float16)
```

but we expect both backbones to have the same dtype.
This is probably the issue: https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/blob/main/config.json#L43
Yep, this works:

```python
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", dtype="auto", revision="refs/pr/46")
input_ids = torch.randint(0, model.config.text_config.vocab_size, (2, 8))
model(input_ids=input_ids)
```
zucchini-nlp left a comment:
Thanks @Godkunn, though this is a general issue with how dtypes are dispatched during loading. It should not be fixed in each model class, but rather in the base modeling class. I will make a fix; we have discussed internally and decided to deprecate having different dtypes per backbone.
Thanks @zucchini-nlp and @qgallouedec for the deep dive! That makes total sense regarding the dispatching logic in the base modeling class. Since you've identified the root cause and will be implementing a central fix (and deprecating the mixed backbone dtypes), I'll close this PR so it doesn't clutter the queue. Glad I could help bring attention to the issue! Happy to help verify the fix once it's live if needed.
What does this PR do?
Fixes #42968
This PR fixes a `RuntimeError` when running `LlavaNextForConditionalGeneration` with mixed precision (e.g., BFloat16 vs. Float16).

The Fix:

I added a cast `.to(self.lm_head.weight.dtype)` to `hidden_states` before passing it to `lm_head`. This ensures the input tensor matches the linear layer's weight dtype.

Verification:
Verified locally on T4 GPU using the reproduction script from the issue. The forward pass now completes without crashing.
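The cast can be sanity-checked without the full model. A minimal sketch (a toy `nn.Linear` in place of the real `lm_head` — an assumption for illustration) showing that casting the activations to the head's weight dtype lets the projection go through:

```python
import torch
import torch.nn as nn

# Toy lm_head (assumption for illustration, not the actual model's head).
lm_head = nn.Linear(8, 16).to(torch.bfloat16)
hidden_states = torch.randn(2, 8, dtype=torch.float16)  # e.g. from a float16 backbone

# The patch in miniature: match the input dtype to the weight dtype
# before the projection, mirroring .to(self.lm_head.weight.dtype).
logits = lm_head(hidden_states.to(lm_head.weight.dtype))
print(logits.dtype)  # torch.bfloat16
```

Note the cast silently converts activations, which is why the maintainers prefer fixing the dtype dispatch at load time instead.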
Who can review?
@zucchini-nlp

