[Gemma4] Replace one-hot matmul with F.embedding in position embeddings#46176
Conversation
Replace the one-hot encoding + matrix multiply pattern in Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding lookups (one per spatial axis) summed together. This is mathematically equivalent but avoids materializing a ~19 GiB intermediate one-hot tensor (int64) and its bf16 cast copy during training with large batch sizes. Fixes huggingface#46175 AI-assisted contribution (Claude Code).
| (shape ``(2, position_embedding_size, hidden_size)``). The result is the | ||
| sum of the x- and y-embeddings for each patch. | ||
| """ | ||
| clamped_positions = pixel_position_ids.clamp(min=0, max=self.position_embedding_size - 1) |
There was a problem hiding this comment.
i am not sure about clamping the upper bound, F.one_hot raises an error when the values are beyond the total number of classes
Can you check what exactly happened, and why we need the upper bound clamping?
There was a problem hiding this comment.
Good catch — removed the upper-bound clamp in the latest commit. The original code only had clamp(min=0) to handle negative padding sentinels, and F.one_hot would have raised on out-of-bounds values. I've matched that behavior: clamp(min=0) only, so valid positions remain in-range by construction and any out-of-bounds input would surface the same way it did before.
| x_emb = F.embedding(clamped_positions[..., 0], self.position_embedding_table[0]) | ||
| y_emb = F.embedding(clamped_positions[..., 1], self.position_embedding_table[1]) |
|
|
||
|
|
||
| @require_torch | ||
| class Gemma4VisionPatchEmbedderTest(unittest.TestCase): | ||
| """Unit tests for Gemma4VisionPatchEmbedder._position_embeddings.""" | ||
|
|
||
| def _make_embedder(self, position_embedding_size=64, hidden_size=32): | ||
| from transformers import Gemma4VisionConfig |
There was a problem hiding this comment.
to remove, we don't need a test imo. I will trigger slow CI to check that model isn't broken
There was a problem hiding this comment.
Done — removed the Gemma4VisionPatchEmbedderTest class entirely in the latest commit. Happy to let the slow CI integration tests cover this.
|
run-slow: gemma4 |
|
This comment contains models: ["models/gemma4"] |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
CI ResultsCommit Info
Model CI Report❌ 1 new failed tests from this PR 😭
|
- Drop max= from pixel_position_ids.clamp(): only negative values (padding sentinels) need guarding; valid positions are in-range by construction, matching the original F.one_hot behavior. - Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI integration tests are sufficient to catch regressions.
|
run-slow: gemma4 |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: gemma4 |
|
This comment contains models: ["models/gemma4"] |
CI ResultsCommit Info
Model CI Report❌ 1 new failed tests from this PR 😭
|
zucchini-nlp
left a comment
There was a problem hiding this comment.
Merging, thanks for iterating!
…gs (huggingface#46176) * [Gemma4] Replace one-hot matmul with F.embedding in position embeddings Replace the one-hot encoding + matrix multiply pattern in Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding lookups (one per spatial axis) summed together. This is mathematically equivalent but avoids materializing a ~19 GiB intermediate one-hot tensor (int64) and its bf16 cast copy during training with large batch sizes. Fixes huggingface#46175 AI-assisted contribution (Claude Code). * Address review feedback: remove upper-bound clamp and test class - Drop max= from pixel_position_ids.clamp(): only negative values (padding sentinels) need guarding; valid positions are in-range by construction, matching the original F.one_hot behavior. - Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI integration tests are sufficient to catch regressions. --------- Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
…gs (huggingface#46176) * [Gemma4] Replace one-hot matmul with F.embedding in position embeddings Replace the one-hot encoding + matrix multiply pattern in Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding lookups (one per spatial axis) summed together. This is mathematically equivalent but avoids materializing a ~19 GiB intermediate one-hot tensor (int64) and its bf16 cast copy during training with large batch sizes. Fixes huggingface#46175 AI-assisted contribution (Claude Code). * Address review feedback: remove upper-bound clamp and test class - Drop max= from pixel_position_ids.clamp(): only negative values (padding sentinels) need guarding; valid positions are in-range by construction, matching the original F.one_hot behavior. - Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI integration tests are sufficient to catch regressions. --------- Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Summary
Fixes #46175
Gemma4VisionPatchEmbedder._position_embeddingsmaterializes a one-hot tensor of shape[batch, num_patches, 2, position_embedding_size]in int64, then casts to the table's dtype, and matrix-multiplies againstposition_embedding_table. For typical training configs (position_embedding_size=10240, batch=40, 2520 patches), this allocates ~19 GiB of GPU memory for what is mathematically a 2-row embedding lookup.This PR replaces the
F.one_hot+ matmul pattern with twoF.embeddingcalls (one per spatial axis), summed. The change:[0, position_embedding_size - 1](the original only clampedmin=0;F.embeddingwould crash on OOB unlikeF.one_hotwhich silently handled it)Coordination
Gemma4VisionPatchEmbedder._position_embeddingsmaterializes a ~19 GiB one-hot tensor that's mathematically a 2-row embedding lookup #46175 (comment)Gemma4VisionPatchEmbedder._position_embeddingsmaterializes a ~19 GiB one-hot tensor that's mathematically a 2-row embedding lookup #46175 (comment)gh pr list --search "46175 in:body")Test
AI-assisted contribution (Claude Code).