
[RecurrentGemma] Support attn_implementation dispatch #46320

Open
YangKai0616 wants to merge 3 commits into huggingface:main from YangKai0616:sdpa-RecurrentGemmaForCausalLM

Conversation

@YangKai0616
Contributor

What does this PR do?

As per the title.

@github-actions
Contributor

github-actions Bot commented Jun 1, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: recurrent_gemma

        return hidden_states


class RecurrentGemmaRecurrentDecoderLayer(GradientCheckpointingLayer):
Contributor Author

Refer to Jamba.

Contributor

I would not follow Jamba in this case. Could we fuse them into one class? The only difference is the temporal block, right? I.e. the class that is used.
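
A minimal sketch of what fusing the two decoder layers into one class could look like, assuming the layer type selects the temporal block from a mapping such as `TEMPORAL_BLOCK_CLASSES`. The `RecurrentBlock`, `AttentionBlock`, and `DecoderLayer` classes here are hypothetical stand-ins operating on plain lists, not the actual modules from this PR:

```python
class RecurrentBlock:
    """Hypothetical stand-in for the recurrent temporal block."""

    def __call__(self, hidden_states):
        return [h + 1.0 for h in hidden_states]


class AttentionBlock:
    """Hypothetical stand-in for the attention temporal block."""

    def __call__(self, hidden_states):
        return [h * 2.0 for h in hidden_states]


# Maps a layer type string to the temporal block class it should use.
TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentBlock, "attention": AttentionBlock}


class DecoderLayer:
    """One decoder layer for both layer types; only the temporal block varies."""

    def __init__(self, layer_type):
        if layer_type not in TEMPORAL_BLOCK_CLASSES:
            raise ValueError(f"Unknown layer type: {layer_type!r}")
        self.layer_type = layer_type
        self.temporal_block = TEMPORAL_BLOCK_CLASSES[layer_type]()

    def forward(self, hidden_states):
        # Residual connection around the temporal block (norms and MLP omitted).
        mixed = self.temporal_block(hidden_states)
        return [h + m for h, m in zip(hidden_states, mixed)]


# The same class is instantiated for every layer; only the type string differs.
layers = [DecoderLayer(t) for t in ("recurrent", "recurrent", "attention")]
```

With this shape, a second decoder-layer class is unnecessary because the only branch point is the dictionary lookup in `__init__`.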

@YangKai0616
Contributor Author

On my local A100, the output_text produced by the run-slow tests on this PR is identical to upstream/main. Note that upstream/main itself currently has 4 failed and 1 passed; the IntegrationTest failures will not be addressed in this PR for now.

@YangKai0616
Contributor Author

@vasqu

Contributor

@vasqu vasqu left a comment

Just one bigger comment, but other than that rather small changes 🤗 I'm not sure how long we will be maintaining this model, so my focus here will be a bit less.

@@ -178,7 +179,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-class RecurrentGemmaSdpaAttention(nn.Module):
+def eager_attention_forward(
Contributor

probably can be copied from 🤔
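
For context, a rough sketch of the pattern under discussion: a standalone eager attention function plus a lookup that dispatches on the requested implementation name. This is an illustration on nested Python lists for a single head, assuming no cache (queries and keys have the same length). The `ATTENTION_FUNCTIONS` registry and `attention_forward` helper are hypothetical names, not the library's API:

```python
import math


def eager_attention(query, key, value, scaling, is_causal=True):
    """Compute softmax(Q K^T * scaling) V for one head, on nested lists."""
    output = []
    for i, q in enumerate(query):
        # With a causal mask, position i only attends to positions <= i.
        last = i + 1 if is_causal else len(key)
        scores = [scaling * sum(a * b for a, b in zip(q, k)) for k in key[:last]]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        output.append(
            [
                sum(w / total * v[d] for w, v in zip(weights, value[:last]))
                for d in range(len(value[0]))
            ]
        )
    return output


# Hypothetical registry: implementation name -> attention function.
ATTENTION_FUNCTIONS = {"eager": eager_attention}


def attention_forward(attn_implementation, query, key, value, scaling):
    """Look up the requested implementation and call it."""
    if attn_implementation not in ATTENTION_FUNCTIONS:
        raise ValueError(f"Unknown attention implementation: {attn_implementation!r}")
    return ATTENTION_FUNCTIONS[attn_implementation](query, key, value, scaling)


head_dim = 2
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
out = attention_forward("eager", q, k, v, scaling=head_dim**-0.5)
```

Keeping the eager path as a free function means the attention module itself only needs the lookup, so other backends can be added to the registry without touching the module.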

        self.partial_rotary_factor = config.partial_rotary_factor
        self.scaling = self.head_dim**-0.5
        self.is_causal = True
        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
Contributor

Suggested change
-        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
+        self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)

nit

        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
Contributor

Hmm, we already have a few partial RoPE implementations; could you check whether this could be refactored to use existing functions as well? Not high priority, but it would be nice.
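
To illustrate what partial RoPE does in the snippet above, here is a small sketch for a single feature vector: only the first `rotary_dim` features are rotated, and the remainder passes through unchanged. The helper `apply_partial_rope`, the half-split rotation layout, and the base of 10000 are assumptions made for illustration, and `rotary_dim` is assumed to be even:

```python
import math


def rotate_half(x):
    """Map the two halves [x1, x2] to [-x2, x1]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_partial_rope(x, position, rotary_dim, base=10000.0):
    """Rotate the first `rotary_dim` features of x and pass the rest through."""
    x_rot, x_pass = x[:rotary_dim], x[rotary_dim:]
    half = rotary_dim // 2
    # One angle per frequency, repeated for both halves of the rotated slice.
    angles = [position * base ** (-2.0 * i / rotary_dim) for i in range(half)] * 2
    rotated = rotate_half(x_rot)
    x_rot = [
        v * math.cos(a) + r * math.sin(a) for v, r, a in zip(x_rot, rotated, angles)
    ]
    return x_rot + x_pass


head_dim, partial_rotary_factor = 8, 0.5
rotary_dim = int(head_dim * partial_rotary_factor)  # 4 of the 8 features rotate
query = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
rotated_query = apply_partial_rope(query, position=3, rotary_dim=rotary_dim)
```

Because the split, rotate, and concatenate steps are independent of the model, a shared helper of this shape is the kind of thing the existing partial RoPE code paths could be consolidated into.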

@@ -452,9 +487,6 @@ def _setup_cache(self, batch, device, dtype):
         self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype)


-TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention}
Contributor

This is too breaking imo and not worth the effort; let's keep this.


-    _supports_flash_attn = False
-    _supports_sdpa = False  # we can't compare with eager for now
+    _supports_flash_attn = True
+    _supports_sdpa = True
Contributor

What about flex attention and the attention backend flag?
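
As a rough illustration of why these class-level flags matter, the sketch below gates a requested `attn_implementation` on the corresponding `_supports_*` attribute. The classes, the `REQUIRED_FLAG` mapping, and `resolve_attn_implementation` are hypothetical and only mimic the idea; the library's actual validation logic may differ:

```python
class BaseModelSketch:
    """Hypothetical base class: optional backends are off by default."""

    _supports_flash_attn = False
    _supports_sdpa = False


class RecurrentGemmaSketch(BaseModelSketch):
    """Hypothetical model class that opts in, mirroring the flags in the diff."""

    _supports_flash_attn = True
    _supports_sdpa = True


# Hypothetical mapping from an implementation name to the flag that gates it.
REQUIRED_FLAG = {
    "sdpa": "_supports_sdpa",
    "flash_attention_2": "_supports_flash_attn",
}


def resolve_attn_implementation(model_cls, requested):
    """Return `requested` if the model class allows it, else raise ValueError."""
    if requested == "eager":
        return requested  # the eager path is assumed to be always available
    flag = REQUIRED_FLAG.get(requested)
    if flag is None:
        raise ValueError(f"Unknown attention implementation: {requested!r}")
    if not getattr(model_cls, flag, False):
        raise ValueError(f"{model_cls.__name__} does not support {requested!r}")
    return requested


chosen = resolve_attn_implementation(RecurrentGemmaSketch, "sdpa")
```

Under this scheme, supporting a further backend would amount to adding one more flag on the model class and one more entry in the mapping.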
