Skip to content

[Gemma4] Replace one-hot matmul with F.embedding in position embeddings#46176

Merged
zucchini-nlp merged 3 commits into
huggingface:mainfrom
Sriniketh24:fix/gemma4-position-embeddings-memory
May 28, 2026
Merged

[Gemma4] Replace one-hot matmul with F.embedding in position embeddings#46176
zucchini-nlp merged 3 commits into
huggingface:mainfrom
Sriniketh24:fix/gemma4-position-embeddings-memory

Conversation

@Sriniketh24
Copy link
Copy Markdown
Contributor

Summary

Fixes #46175

Gemma4VisionPatchEmbedder._position_embeddings materializes a one-hot tensor of shape [batch, num_patches, 2, position_embedding_size] in int64, then casts to the table's dtype, and matrix-multiplies against position_embedding_table. For typical training configs (position_embedding_size=10240, batch=40, 2520 patches), this allocates ~19 GiB of GPU memory for what is mathematically a 2-row embedding lookup.

This PR replaces the F.one_hot + matmul pattern with two F.embedding calls (one per spatial axis), summed. The change:

  • Eliminates the 15.38 GiB int64 one-hot tensor and its 3.85 GiB bf16 cast copy
  • Preserves numerical equivalence (embedding lookup is the same operation)
  • Clamps position IDs to [0, position_embedding_size - 1] (the original only clamped min=0; F.embedding would crash on OOB unlike F.one_hot which silently handled it)

Coordination

Test

python -m pytest tests/models/gemma4/test_modeling_gemma4.py::Gemma4VisionPatchEmbedderTest -xvs
PASSED test_no_one_hot_intermediate
PASSED test_padding_zeroed
PASSED test_negative_positions_clamped
PASSED test_oob_positions_clamped
4 passed in 18.11s

AI-assisted contribution (Claude Code).

Replace the one-hot encoding + matrix multiply pattern in
Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding
lookups (one per spatial axis) summed together. This is mathematically
equivalent but avoids materializing a ~19 GiB intermediate one-hot
tensor (int64) and its bf16 cast copy during training with large
batch sizes.

Fixes huggingface#46175

AI-assisted contribution (Claude Code).
(shape ``(2, position_embedding_size, hidden_size)``). The result is the
sum of the x- and y-embeddings for each patch.
"""
clamped_positions = pixel_position_ids.clamp(min=0, max=self.position_embedding_size - 1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am not sure about clamping the upper bound, F.one_hot raises an error when the values are beyond the total number of classes

Can you check what exactly happened, and why we need the upper bound clamping?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — removed the upper-bound clamp in the latest commit. The original code only had clamp(min=0) to handle negative padding sentinels, and F.one_hot would have raised on out-of-bounds values. I've matched that behavior: clamp(min=0) only, so valid positions remain in-range by construction and any out-of-bounds input would surface the same way it did before.

Comment on lines +598 to +599
x_emb = F.embedding(clamped_positions[..., 0], self.position_embedding_table[0])
y_emb = F.embedding(clamped_positions[..., 1], self.position_embedding_table[1])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, lgtm!

Comment on lines +866 to +873


@require_torch
class Gemma4VisionPatchEmbedderTest(unittest.TestCase):
"""Unit tests for Gemma4VisionPatchEmbedder._position_embeddings."""

def _make_embedder(self, position_embedding_size=64, hidden_size=32):
from transformers import Gemma4VisionConfig
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to remove, we don't need a test imo. I will trigger slow CI to check that model isn't broken

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed the Gemma4VisionPatchEmbedderTest class entirely in the latest commit. Happy to let the slow CI integration tests cover this.

@zucchini-nlp
Copy link
Copy Markdown
Member

run-slow: gemma4

@github-actions
Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/gemma4"]
quantizations: []

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions
Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN fbdd74d7 workflow commit (merge commit)
PR af5be20a branch commit (from PR)
main 10555512 base commit (on main)

Model CI Report

1 new failed tests from this PR 😭

  • gemma4:
    tests/models/gemma4/test_modeling_gemma4.py::Gemma4IntegrationTest::test_export_text_only (❌ ⟹ ❌)

Sriniketh24 and others added 2 commits May 26, 2026 11:25
- Drop max= from pixel_position_ids.clamp(): only negative values
  (padding sentinels) need guarding; valid positions are in-range by
  construction, matching the original F.one_hot behavior.
- Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI
  integration tests are sufficient to catch regressions.
@zucchini-nlp
Copy link
Copy Markdown
Member

run-slow: gemma4

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

@github-actions
Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/gemma4"]
quantizations: []

@github-actions
Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b15cc21c workflow commit (merge commit)
PR 8696bd9b branch commit (from PR)
main f39b5c8b base commit (on main)

Model CI Report

1 new failed tests from this PR 😭

  • gemma4:
    tests/models/gemma4/test_modeling_gemma4.py::Gemma4IntegrationTest::test_export_text_only (❌ ⟹ ❌)

Copy link
Copy Markdown
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging, thanks for iterating!

@zucchini-nlp zucchini-nlp enabled auto-merge May 28, 2026 11:37
@zucchini-nlp zucchini-nlp added this pull request to the merge queue May 28, 2026
Merged via the queue into huggingface:main with commit bc8f70a May 28, 2026
23 of 24 checks passed
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
…gs (huggingface#46176)

* [Gemma4] Replace one-hot matmul with F.embedding in position embeddings

Replace the one-hot encoding + matrix multiply pattern in
Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding
lookups (one per spatial axis) summed together. This is mathematically
equivalent but avoids materializing a ~19 GiB intermediate one-hot
tensor (int64) and its bf16 cast copy during training with large
batch sizes.

Fixes huggingface#46175

AI-assisted contribution (Claude Code).

* Address review feedback: remove upper-bound clamp and test class

- Drop max= from pixel_position_ids.clamp(): only negative values
  (padding sentinels) need guarding; valid positions are in-range by
  construction, matching the original F.one_hot behavior.
- Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI
  integration tests are sufficient to catch regressions.

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
…gs (huggingface#46176)

* [Gemma4] Replace one-hot matmul with F.embedding in position embeddings

Replace the one-hot encoding + matrix multiply pattern in
Gemma4VisionPatchEmbedder._position_embeddings with two F.embedding
lookups (one per spatial axis) summed together. This is mathematically
equivalent but avoids materializing a ~19 GiB intermediate one-hot
tensor (int64) and its bf16 cast copy during training with large
batch sizes.

Fixes huggingface#46175

AI-assisted contribution (Claude Code).

* Address review feedback: remove upper-bound clamp and test class

- Drop max= from pixel_position_ids.clamp(): only negative values
  (padding sentinels) need guarding; valid positions are in-range by
  construction, matching the original F.one_hot behavior.
- Remove Gemma4VisionPatchEmbedderTest per reviewer request; slow CI
  integration tests are sufficient to catch regressions.

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Gemma4] Gemma4VisionPatchEmbedder._position_embeddings materializes a ~19 GiB one-hot tensor that's mathematically a 2-row embedding lookup

3 participants