
fix: cast to proper dtype in EmbeddingParallel #44612

Merged
3outeille merged 4 commits into huggingface:main from michaelbenayoun:tp_embedding_cast_fix on Mar 12, 2026

Conversation

@michaelbenayoun
Member

What does this PR do?

The output function hook in EmbeddingParallel casts the mask to fp32, which breaks things on Neuron devices. Suggested fix: cast the mask to the outputs' dtype instead.
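As a rough illustration of the dtype issue described above, here is a minimal sketch in PyTorch. It is not the actual hook from the library: the helper name `mask_embedding_output`, the tensor shapes, and the way the mask is applied (a multiply) are all assumptions made for the example. It only shows that casting the mask to a hard-coded fp32 promotes a bf16 result to fp32, while casting to the outputs' dtype keeps it unchanged.

```python
import torch


def mask_embedding_output(outputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero out embedding rows where `mask` is False."""
    # Cast the boolean mask to the outputs' dtype instead of a hard-coded
    # float32, so the product keeps the dtype of `outputs`.
    return outputs * mask.unsqueeze(-1).to(outputs.dtype)


# Assumed shapes: (batch, sequence, hidden) outputs and a (batch, sequence) mask.
outputs = torch.ones(2, 3, 4, dtype=torch.bfloat16)
mask = torch.tensor([[True, False, True], [False, True, True]])

# Casting to the outputs' dtype preserves bfloat16.
fixed = mask_embedding_output(outputs, mask)

# Casting the mask to float32 promotes the whole result to float32.
promoted = outputs * mask.unsqueeze(-1).to(torch.float32)

print(fixed.dtype, promoted.dtype)
```

In this sketch the masked values are identical either way; only the resulting dtype differs, which is the part the PR description is concerned with.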

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille enabled auto-merge March 12, 2026 09:30
@3outeille 3outeille added this pull request to the merge queue Mar 12, 2026
Merged via the queue into huggingface:main with commit adc2f16 Mar 12, 2026
28 checks passed
@michaelbenayoun michaelbenayoun deleted the tp_embedding_cast_fix branch March 12, 2026 21:08


3 participants