
fix(models, testing): Fix Llama4 vision rotary meta tensor initialization and MyT5 get_tokenizer signature#44581

Merged
Rocketknight1 merged 5 commits into huggingface:main from harshaljanjani:fix/llama4-rope-meta-myt5-tokenizer on Mar 13, 2026

Conversation

@harshaljanjani
Contributor

@harshaljanjani harshaljanjani commented Mar 10, 2026

What does this PR do?

The following issues were identified and fixed in this PR:

Llama-4 Vision: freqs_ci is stored as a plain attribute on Llama4VisionRotaryEmbedding. When from_pretrained initializes the model with device_map="auto", all tensors are created on the meta device, but freqs_ci is not registered as a buffer and is therefore never materialized to a real device, so the forward pass fails with a "cannot copy out of meta tensor" error. Fixed by registering it as a buffer and adding a meta-device recompute guard in forward.
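The buffer-registration part of the fix can be sketched as a standalone toy module. This is not the actual Llama4 code: the class name, constructor arguments, and rotary math below are simplified stand-ins; only the register_buffer pattern mirrors the fix described above.

```python
import torch
import torch.nn as nn

class VisionRotaryEmbedding(nn.Module):
    # Simplified stand-in for Llama4VisionRotaryEmbedding; real config handling omitted.
    def __init__(self, dim: int, max_positions: int = 16):
        super().__init__()
        # Registering freqs_ci as a (non-persistent) buffer makes it visible to the
        # module system, so meta-device initialization can track and materialize it.
        # A plain attribute assigned in __init__ would be invisible to that machinery.
        self.register_buffer(
            "freqs_ci", self._compute_freqs_ci(dim, max_positions), persistent=False
        )

    @staticmethod
    def _compute_freqs_ci(dim: int, max_positions: int) -> torch.Tensor:
        # Toy rotary frequency table as a complex tensor.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)
        return torch.polar(torch.ones_like(freqs), freqs)

mod = VisionRotaryEmbedding(dim=8)
print("freqs_ci" in dict(mod.named_buffers()))  # True: tracked as a buffer
```

Because the buffer is non-persistent, it is excluded from the state dict (it is recomputed, not loaded from the checkpoint) while still being moved along with the module's device placement.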
MyT5: 05c0e1d ("rm slow tokenizers") refactored TokenizerTesterMixin.get_tokenizer to accept pretrained_name as a positional argument and updated the call sites in the base-class tests accordingly, but MyT5TokenizationTest.get_tokenizer was never updated to match. This PR fixes that by adopting the canonical pattern already used by the other models (BARTpho, CANINE, CLVP, etc.).
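The signature mismatch can be illustrated with a toy version of the mixin. All class names and bodies below are hypothetical stand-ins, not the real transformers test code; only the get_tokenizer/pretrained_name signature shape follows the PR description.

```python
class TokenizerTesterMixin:
    # Hypothetical stand-in for the real mixin.
    tokenizer_class = None

    def get_tokenizer(self, pretrained_name=None, **kwargs):
        # After the refactor, base-class tests call this with pretrained_name
        # as a positional argument.
        pretrained_name = pretrained_name or "default-checkpoint"
        return f"{self.tokenizer_class}:{pretrained_name}:{sorted(kwargs)}"

class MyT5TokenizationTest(TokenizerTesterMixin):
    tokenizer_class = "MyT5Tokenizer"

    # Before the fix the override looked like `def get_tokenizer(self, **kwargs)`,
    # so base-class calls passing pretrained_name positionally raised a TypeError.
    # The fix accepts and forwards the positional argument, matching the pattern
    # used by the other model test classes.
    def get_tokenizer(self, pretrained_name=None, **kwargs):
        return super().get_tokenizer(pretrained_name, **kwargs)

t = MyT5TokenizationTest().get_tokenizer("myt5-checkpoint", use_fast=True)
print(t)  # MyT5Tokenizer:myt5-checkpoint:['use_fast']
```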

cc: @Rocketknight1 @itazap

CI Failures

Before the fixes (feel free to cross-check; these errors are reproducible):

[screenshots: the failing test output]

After the fixes (feel free to cross-check):

[screenshots: the test output after the fixes]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you fix any necessary existing tests?

@harshaljanjani harshaljanjani marked this pull request as ready for review March 10, 2026 19:35
@harshaljanjani
Contributor Author

The failing test test_modeling_utils.py::InitializeMissingKeysTest::test_fsdp_non_rank0_end_to_end_no_reinit seems unrelated to this change.

@Rocketknight1
Member

Hi @harshaljanjani, I don't like this fix unfortunately! The forward() guard is a problem because it clobbers the buffer registration by assigning an attribute.

The standard way of handling meta device initialization is to do param initialization in init_weights and not in __init__(), since __init__() is run before parameters are materialized. Doing it that way should make this work without the hacky forward() guard.

@harshaljanjani
Contributor Author

harshaljanjani commented Mar 11, 2026

Thanks for your time @Rocketknight1. If I understand your point correctly, you're referring to this pattern, among others, right? I'll remove the forward meta-device check and have init_weights recompute freqs_ci. Is there anything else you'd like me to look at? I'll address this right away and push after making sure all the bells and whistles are tight :)
My repro script (as I wrote in the PR description previously) bypassed init_weights; I've updated the PR description with a more accurate reproduction 😓

Comment on lines +491 to +492
elif isinstance(module, Llama4VisionRotaryEmbedding):
module.freqs_ci = module._compute_freqs_ci(module.config)
Member


Sorry, I should have been clearer. The problem is that in the __init__ you register a buffer, but this line clobbers the freqs_ci attribute with a totally new tensor, which makes the __init__() line pointless. What you probably want to do is module.freqs_ci.copy_(module._compute_freqs_ci(module.config)), which will preserve the tensor object and simply initialize the right values for it; that's how weight init is supposed to work!
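The distinction the review draws can be shown with a minimal standalone example in plain PyTorch (not the actual Llama4 module): attribute assignment swaps in a brand-new tensor object, while copy_ fills the already-registered buffer in place.

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("freqs_ci", torch.zeros(4))
original = m.freqs_ci

# Assignment clobbers the registered buffer with a totally new tensor object,
# so any earlier materialization of `original` is lost.
m.freqs_ci = torch.ones(4)
print(m.freqs_ci is original)  # False

# copy_ writes values into the existing buffer in place, preserving the tensor
# object that the module system already tracks -- the pattern weight init wants.
m.register_buffer("freqs_ci", torch.zeros(4))
original = m.freqs_ci
m.freqs_ci.copy_(torch.ones(4))
print(m.freqs_ci is original)  # True
```

This identity preservation is exactly what matters under device_map="auto": the materialized buffer must be filled, not replaced.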

Contributor Author


Ahh gotchya, thanks for taking the time! Hopefully this should be better 🤗

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: llama4, myt5

Member

@Rocketknight1 Rocketknight1 left a comment


Yep, LGTM now!

@Rocketknight1 Rocketknight1 enabled auto-merge March 13, 2026 14:05
@Rocketknight1 Rocketknight1 added this pull request to the merge queue Mar 13, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Merged via the queue into huggingface:main with commit b762631 Mar 13, 2026
28 checks passed
@harshaljanjani harshaljanjani deleted the fix/llama4-rope-meta-myt5-tokenizer branch March 13, 2026 14:29
