[gemma4] infer from config instead of hardcoding #45606

Merged
vasqu merged 9 commits into main from remove-hardcoded-rel-pos-gemma4
Apr 27, 2026

Conversation

@eustlb
Contributor

@eustlb eustlb commented Apr 23, 2026

What does this PR do?

As per the title. Fixes #45468.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@mathceo
Contributor

mathceo commented Apr 23, 2026

@eustlb thanks for the clarification and the fix - context_size // 2 + 1 makes sense here given the blocked attention path and _rel_shift.
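
As a quick check of the arithmetic, here is a minimal sketch of the length computation, assuming `context_size` is defined as in the regression test proposed later in this comment. The helper name `rel_pos_length` is hypothetical and is not part of transformers.

```python
def rel_pos_length(chunk_size: int, context_left: int, context_right: int) -> int:
    """Number of relative positions, using the formula quoted in this thread."""
    # Assumption: context_size as computed in the proposed regression test.
    context_size = chunk_size + context_left - 1 + context_right
    # Length discussed in the thread: context_size // 2 + 1.
    return context_size // 2 + 1


# Values from the nondefault config used in the proposed test.
print(rel_pos_length(chunk_size=6, context_left=5, context_right=1))  # 6
```

With chunk size 6, left context 5 and right context 1, `context_size` is 11, so the encoding would have 6 positions.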

One thing that may still be useful is a small regression test for nondefault audio configs, so this does not silently stay tied to the default length again.

Something along these lines in tests/models/gemma4/test_modeling_gemma4.py:

def test_audio_rel_pos_encoding_uses_context_size_from_config(self):
    from transformers.models.gemma4.configuration_gemma4 import Gemma4AudioConfig
    from transformers.models.gemma4.modeling_gemma4 import Gemma4AudioRelPositionalEncoding

    config = Gemma4AudioConfig(
        hidden_size=32,
        attention_chunk_size=6,
        attention_context_left=5,
        attention_context_right=1,
        use_clipped_linears=False,
    )

    module = Gemma4AudioRelPositionalEncoding(config)
    hidden_states = torch.zeros(1, 3, config.hidden_size)

    pos = module(hidden_states)

    context_size = config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
    expected_len = context_size // 2 + 1

    self.assertEqual(pos.shape, (1, expected_len, config.hidden_size))

    position_ids = torch.arange(context_size // 2, -1, -1, device=hidden_states.device)[..., None]
    scaled_time = position_ids * module.inv_timescales.to(device=hidden_states.device)
    expected = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1).to(hidden_states.dtype)

    torch.testing.assert_close(pos, expected)

Happy to open a follow-up PR for the test as well if that is preferred.
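
For readers without the model code at hand, the expected tensor in the test above can be sketched in plain Python. This is only an illustration: the log-spaced `inv_timescales` and the `min_timescale` / `max_timescale` defaults are assumptions borrowed from the classic sinusoidal encoding, not values confirmed for `Gemma4AudioRelPositionalEncoding`, and `sinusoid_table` is a hypothetical helper.

```python
import math


def sinusoid_table(context_size, hidden_size, min_timescale=1.0, max_timescale=10000.0):
    """Build rows [sin | cos] for positions context_size // 2 down to 0."""
    num_timescales = hidden_size // 2
    # Assumed log-spaced inverse timescales; the real module may differ.
    log_inc = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = [min_timescale * math.exp(-i * log_inc) for i in range(num_timescales)]

    rows = []
    # Same descending position order as torch.arange(context_size // 2, -1, -1).
    for pos in range(context_size // 2, -1, -1):
        scaled = [pos * t for t in inv_timescales]
        # Concatenate sin and cos along the feature axis, as in the test.
        rows.append([math.sin(s) for s in scaled] + [math.cos(s) for s in scaled])
    return rows


table = sinusoid_table(context_size=11, hidden_size=32)
print(len(table), len(table[0]))  # 6 32
```

The last row corresponds to position 0, so its sine half is all zeros and its cosine half is all ones regardless of the timescales chosen.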

Contributor

@vasqu vasqu left a comment

Thanks! Like @mathceo also mentioned, let's add a small regression test just in case 🫡

@mathceo
Contributor

mathceo commented Apr 23, 2026

I already implemented the regression test in #45607. Since #45606 is not my branch, I can't push directly there from my side. Feel free to cherry-pick the test commit from #45607
@eustlb

@eustlb eustlb force-pushed the remove-hardcoded-rel-pos-gemma4 branch from 05e6e94 to cc87b12 Compare April 24, 2026 09:22
@eustlb
Contributor Author

eustlb commented Apr 24, 2026

thanks @mathceo! added you as a co-author here

@vasqu
Contributor

vasqu commented Apr 27, 2026

run-slow: gemma4

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma4

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/gemma4"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | 32ada79d | workflow commit (merge commit) |
| PR | d404da97 | branch commit (from PR) |
| main | bbb51c83 | base commit (on main) |

Model CI Report

1 new failed tests from this PR 😭

  • gemma4:
    tests/models/gemma4/test_modeling_gemma4.py::Gemma4IntegrationTest::test_export_text_only (❌ ⟹ ❌)

@vasqu vasqu added this pull request to the merge queue Apr 27, 2026
@vasqu
Contributor

vasqu commented Apr 27, 2026

The failure only shows up as different because of a different PID in the test itself lol, merging

Merged via the queue into main with commit 5d24d8c Apr 27, 2026
22 of 23 checks passed
@vasqu vasqu deleted the remove-hardcoded-rel-pos-gemma4 branch April 27, 2026 13:05
ArthurZucker pushed a commit that referenced this pull request Apr 28, 2026
* infer from config instead of hardcoding

* Update test_modeling_gemma4.py

* Update modeling_gemma4.py

* Update modeling_gemma4.py

* Update modeling_gemma4.py

* make style

* add small docstring for reference

---------

Co-authored-by: omar zoloev <ozoloevwork@gmail.com>
Co-authored-by: vasqu <antonprogamer@gmail.com>
Development

Successfully merging this pull request may close these issues.

[BUG] Gemma-4 Gemma4AudioRelPositionalEncoding

4 participants