[gemma4] infer from config instead of hardcoding#45606
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@eustlb thanks for the clarification and the fix - context_size // 2 + 1 makes sense here given the blocked attention path and _rel_shift. One thing that may still be useful is a small regression test for nondefault audio configs, so this does not silently stay tied to the default length again. Something along these lines in tests/models/gemma4/test_modeling_gemma4.py: def test_audio_rel_pos_encoding_uses_context_size_from_config(self):
from transformers.models.gemma4.configuration_gemma4 import Gemma4AudioConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4AudioRelPositionalEncoding
config = Gemma4AudioConfig(
hidden_size=32,
attention_chunk_size=6,
attention_context_left=5,
attention_context_right=1,
use_clipped_linears=False,
)
module = Gemma4AudioRelPositionalEncoding(config)
hidden_states = torch.zeros(1, 3, config.hidden_size)
pos = module(hidden_states)
context_size = config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
expected_len = context_size // 2 + 1
self.assertEqual(pos.shape, (1, expected_len, config.hidden_size))
position_ids = torch.arange(context_size // 2, -1, -1, device=hidden_states.device)[..., None]
scaled_time = position_ids * module.inv_timescales.to(device=hidden_states.device)
expected = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1).to(hidden_states.dtype)
torch.testing.assert_close(pos, expected)Happy to open a follow-up PR for the test as well if that is preferred. |
…s-gemma4 Co-Authored-By: Omar Zoloev <ozoloevwork@gmail.com>
05e6e94 to
cc87b12
Compare
|
thanks @mathceo! added you as a co-author here |
|
run-slow: gemma4 |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: gemma4 |
|
This comment contains models: ["models/gemma4"] |
CI ResultsCommit Info
Model CI Report❌ 1 new failed tests from this PR 😭
|
|
The different failure is a different PID in the test itself lol, merging |
* infer from config instead of hardcoding * Update test_modeling_gemma4.py * Update modeling_gemma4.py * Update modeling_gemma4.py * Update modeling_gemma4.py * make style * add small docstring for reference --------- Co-authored-by: omar zoloev <ozoloevwork@gmail.com> Co-authored-by: vasqu <antonprogamer@gmail.com>
What does this PR do?
As per title. Fix #45468