
Fix dtype mismatch in modeling_llava_next #42979

Open
Godkunn wants to merge 1 commit into huggingface:main from Godkunn:main

Conversation

@Godkunn

Godkunn commented Dec 21, 2025

What does this PR do?

Fixes #42968

This PR fixes a RuntimeError when running LlavaNextForConditionalGeneration with mixed precision (e.g., BFloat16 vs Float16).

The fix:
I added a cast .to(self.lm_head.weight.dtype) to hidden_states before passing it to lm_head. This ensures the input tensor matches the linear layer's weight dtype.
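The cast can be illustrated with a toy module (a hypothetical simplified stand-in, not the actual LlavaNext source): the lm_head holds bfloat16 weights while the incoming activations are float16, and the one-line cast aligns them before the projection.

```python
import torch
import torch.nn as nn

# Hypothetical toy stand-in for the affected code path, NOT the actual
# LlavaNext source: lm_head is in bfloat16, activations arrive in float16.
class TinyLMHeadModel(nn.Module):
    def __init__(self, hidden_size=8, vocab_size=16):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False).to(torch.bfloat16)

    def forward(self, hidden_states):
        # The fix: cast to the lm_head weight dtype before the projection.
        hidden_states = hidden_states.to(self.lm_head.weight.dtype)
        return self.lm_head(hidden_states)

model = TinyLMHeadModel()
hidden = torch.randn(2, 4, 8, dtype=torch.float16)  # mismatched input dtype
logits = model(hidden)
print(logits.dtype)  # torch.bfloat16
```

Without the cast, the matmul inside nn.Linear would raise a dtype-mismatch RuntimeError.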

Verification:
Verified locally on T4 GPU using the reproduction script from the issue. The forward pass now completes without crashing.

Who can review?

@zucchini-nlp

Ensure logits are computed with the correct dtype.
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: llava_next

@qgallouedec
Member

qgallouedec commented Dec 21, 2025

I don't think this is the right fix for this issue. It could actually come from the loading logic:

>>> from transformers import AutoModelForImageTextToText
>>> model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", dtype="auto")
>>> model.language_model.dtype, model.vision_tower.dtype
(torch.bfloat16, torch.float16)

but we expect (torch.float16, torch.float16)
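The mismatch above surfaces as a RuntimeError at the first cross-dtype matmul between the two backbones. A minimal standalone reproduction of that underlying error (independent of the model):

```python
import torch

# PyTorch does not type-promote across float dtypes in matmul:
# multiplying a bfloat16 tensor by a float16 tensor raises a RuntimeError.
a = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.randn(4, 4, dtype=torch.float16)

raised = False
try:
    a @ w
except RuntimeError as e:
    raised = True
    print(type(e).__name__)  # RuntimeError

print(raised)  # True
```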


@qgallouedec
Member

Yep, this works:

import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", dtype="auto", revision="refs/pr/46")
input_ids = torch.randint(0, model.config.text_config.vocab_size, (2, 8))
model(input_ids=input_ids)

Member

@zucchini-nlp zucchini-nlp left a comment


Thanks @Godkunn, though this is a general issue with how dtypes are dispatched during loading. It should not be fixed per model class, but rather in the base modeling class. I will make a fix; we have discussed internally and decided to deprecate having different dtypes per backbone.
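A hypothetical sketch of that direction (not the actual transformers patch): if per-backbone dtypes are deprecated, the whole model can be cast to a single dtype after loading, so the submodules can never diverge.

```python
import torch
import torch.nn as nn

# Hypothetical sketch, NOT the actual transformers fix: enforce one dtype
# across all backbones so lm_head and vision_tower can never disagree.
def unify_dtypes(model: nn.Module, dtype: torch.dtype) -> nn.Module:
    # Module.to(dtype) casts every floating-point parameter and buffer.
    return model.to(dtype)

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Simulate the observed mismatch: bf16 language model, fp16 vision tower.
        self.language_model = nn.Linear(4, 4).to(torch.bfloat16)
        self.vision_tower = nn.Linear(4, 4).to(torch.float16)

toy = unify_dtypes(Toy(), torch.float16)
print(toy.language_model.weight.dtype, toy.vision_tower.weight.dtype)
# torch.float16 torch.float16
```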

@Godkunn
Author

Godkunn commented Dec 22, 2025

Thanks @zucchini-nlp and @qgallouedec for the deep dive!

That makes total sense regarding the dispatching logic in the base modeling class. Since you've identified the root cause and will be implementing a central fix (and deprecating the mixed backbone dtypes), I'll close this PR so it doesn't clutter the queue.

Glad I could help bring attention to the issue! Happy to help verify the fix once it's live.



Development

Successfully merging this pull request may close these issues.

LlavaNextForConditionalGeneration forward pass broken

3 participants