Skip to content

Conversation

rcogill
Copy link
Contributor

@rcogill rcogill commented Sep 3, 2025

What does this PR do?

This PR addresses two issues arising from an in-place operation where audio and text embeddings are merged in the Voxtral model. This fixes Issue #40488 and an unreported issue resulting from using device_map="auto" with the Voxtral model.

More detail about the two issues and their resolution:

  • In the issue reported in #40488, the forward method of VoxtralForConditionalGeneration fails when using LoRA. The underlying issue is that when the text embedding layer is frozen, the inputs_embeds tensor extracted from the embedding layer is a leaf tensor. When inputs_embeds is a leaf tensor that requires gradients, values of this tensor cannot be reassigned. To address this, inputs_embeds will be now be cloned to enable tracking of gradients when it is a leaf tensor that requires gradients.
  • The second unreported issue arises when using device_map="auto" with Voxtral. When using device_map="auto", audio and text layers might be distributed across different devices. When this is the case, inputs_embeds and audio_embeds might be on different devices. In this case, attempting to assign values of audio_embeds to inputs_embeds will fail since both tensors are expected to be on the same device. To address this, audio_embeds is moved to the same device as inputs_embeds before updating values of inputs_embeds.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case. -> Voxtral model fails with LoRA due to in-place operation error #40488
  • Did you make sure to update the documentation with your changes? -> No documentation changes necessary.
  • Did you write any new necessary tests? -> No new features added.

Who can review?

@eustlb

Comment on lines 245 to 250
# Enable gradient tracking when inputs_embeds is a leaf tensor
if inputs_embeds.is_leaf and inputs_embeds.requires_grad:
inputs_embeds = inputs_embeds.clone()
# replace text-audio token placeholders with audio embeddings
audio_token_mask = input_ids == self.config.audio_token_id
inputs_embeds[audio_token_mask] = audio_embeds
inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a fan of using a clone tbh, can we use a masked scatter here instead, e.g. see

inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be out-of-place op then, so I would hope it resolves the inplace operation issue

(probably still needs the device movement)

Copy link
Contributor

@eustlb eustlb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rcogill for working on this and @vasqu for first review! Quick benchmarks show non-inplace masked_scatter is faster anyway. Can you add the suggested change and good to merge then.

Comment on lines 245 to 250
# Enable gradient tracking when inputs_embeds is a leaf tensor
if inputs_embeds.is_leaf and inputs_embeds.requires_grad:
inputs_embeds = inputs_embeds.clone()
# replace text-audio token placeholders with audio embeddings
audio_token_mask = input_ids == self.config.audio_token_id
inputs_embeds[audio_token_mask] = audio_embeds
inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree here with @vasqu, let's not forget also to move the mask to the correct device and make it's broadcastable!
Can you please change to:

# replace text-audio token placeholders with audio embeddings
audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushed directly to shortcut and merge! Thanks again, @vasqu and @rcogill

@eustlb eustlb enabled auto-merge (squash) September 4, 2025 14:37
Copy link
Contributor

github-actions bot commented Sep 4, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: voxtral

@vasqu
Copy link
Contributor

vasqu commented Sep 4, 2025

Adding a for patch labels here @eustlb? And glad to help ❤️

@eustlb eustlb added for patch Tag issues / labels that should be included in the next patch and removed for patch Tag issues / labels that should be included in the next patch labels Sep 4, 2025
@eustlb
Copy link
Contributor

eustlb commented Sep 4, 2025

Actually no, for-patch label is for what has been broken in last release, which is not the case here ;)

@eustlb eustlb merged commit 4cbca0d into huggingface:main Sep 4, 2025
17 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rcogill
Copy link
Contributor Author

rcogill commented Sep 4, 2025

@eustlb and @vasqu , thank you both!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Voxtral model fails with LoRA due to in-place operation error
4 participants