[RecurrentGemma] Support attn_implementation dispatch #46320
YangKai0616 wants to merge 3 commits into
[For maintainers] Suggested jobs to run (before merge) run-slow: recurrent_gemma
        return hidden_states


class RecurrentGemmaRecurrentDecoderLayer(GradientCheckpointingLayer):

I would not follow jamba in this case. Could we fuse them into one class? The only difference is the temporal block, right? I.e. which class is used.
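A minimal sketch of the fusion suggested above, assuming the decoder layer variants really do differ only in their temporal block. It is plain Python with stand-in classes: `DecoderLayer`, `RecurrentBlock`, `AttentionBlock`, and the `block_type` argument are hypothetical names for illustration and are not taken from the PR. Only the `TEMPORAL_BLOCK_CLASSES` name mirrors the mapping that appears elsewhere in this diff.

```python
class RecurrentBlock:
    """Stand-in for the recurrent temporal block (hypothetical)."""

    def __call__(self, hidden_states):
        return [h + 1 for h in hidden_states]


class AttentionBlock:
    """Stand-in for the attention temporal block (hypothetical)."""

    def __call__(self, hidden_states):
        return [h * 2 for h in hidden_states]


# One lookup table replaces the need for one decoder-layer class per block type.
TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentBlock, "attention": AttentionBlock}


class DecoderLayer:
    """Single layer class; only the temporal block depends on block_type."""

    def __init__(self, block_type):
        if block_type not in TEMPORAL_BLOCK_CLASSES:
            raise ValueError(f"Unknown temporal block type: {block_type!r}")
        self.temporal_block = TEMPORAL_BLOCK_CLASSES[block_type]()

    def forward(self, hidden_states):
        # Residual connection around the temporal block, shared by both variants.
        mixed = self.temporal_block(hidden_states)
        return [h + m for h, m in zip(hidden_states, mixed)]


# A stack mixing both block types is built from the same class.
layers = [DecoderLayer(t) for t in ("recurrent", "recurrent", "attention")]
```

The design choice here is that everything shared between the variants (residual handling, and in a real model the norms and MLP) lives once in the layer, while the per-layer difference is reduced to a dictionary lookup.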
On my local A100, the |
vasqu left a comment
Just one bigger comment, but other than that rather small changes 🤗 I'm not sure how long we will be maintaining this model, so I will focus on it a bit less.
@@ -178,7 +179,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class RecurrentGemmaSdpaAttention(nn.Module):
def eager_attention_forward(
probably can be copied from 🤔
        self.partial_rotary_factor = config.partial_rotary_factor
        self.scaling = self.head_dim**-0.5
        self.is_causal = True
        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)

Suggested change:
-        self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
+        self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
nit
        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)

Hmm, we have a few partial RoPE implementations; could you check whether this could be refactored to use existing functions as well? Not high priority, but it would be nice.
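For readers unfamiliar with partial RoPE, the split shown in the hunk above can be sketched as follows. This is an illustrative, dependency-free version that works on a single head vector stored as a Python list rather than on tensors; `apply_partial_rope`, `rotate_half`, and the `base` default are assumptions chosen for the example, not the helper the reviewer has in mind. Only the first `rotary_dim` entries are rotated; the rest pass through unchanged.

```python
import math


def rotate_half(x):
    """Map [x1, x2] to [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_partial_rope(vec, position, rotary_dim, base=10000.0):
    """Rotate the first rotary_dim entries of vec; leave the remainder untouched."""
    if rotary_dim % 2 != 0 or rotary_dim > len(vec):
        raise ValueError("rotary_dim must be even and no larger than len(vec)")
    rot, passthrough = vec[:rotary_dim], vec[rotary_dim:]
    half = rotary_dim // 2
    # One frequency per pair of dimensions, duplicated to cover both halves.
    inv_freq = [base ** (-2 * i / rotary_dim) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2
    rotated = rotate_half(rot)
    rot = [x * math.cos(a) + r * math.sin(a) for x, r, a in zip(rot, rotated, angles)]
    return rot + passthrough


# rotary_dim derived the same way as in the diff: int(head_dim * partial_rotary_factor).
head_dim, partial_rotary_factor = 6, 0.7
rotary_dim = int(head_dim * partial_rotary_factor)  # 4 rotated dims, 2 pass-through
out = apply_partial_rope([1.0, 0.0, 0.0, 0.0, 5.0, 6.0], position=3, rotary_dim=rotary_dim)
```

Because the pass-through slice is simply concatenated back, a shared helper would only need the rotary width as an extra parameter, which is presumably what makes the refactor feasible.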
@@ -452,9 +487,6 @@ def _setup_cache(self, batch, device, dtype):
        self.conv1d_state = torch.zeros((batch, self.hidden_size, self.conv1d_width - 1), device=device, dtype=dtype)


TEMPORAL_BLOCK_CLASSES = {"recurrent": RecurrentGemmaRecurrentBlock, "attention": RecurrentGemmaSdpaAttention}

This is too breaking imo and not worth the effort; let's keep this.
-    _supports_flash_attn = False
-    _supports_sdpa = False  # we can't compare with eager for now
+    _supports_flash_attn = True
+    _supports_sdpa = True
What does this PR do?
As per the title.
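As a rough illustration of what "attn_implementation dispatch" means, the sketch below selects an attention function by name, with the eager path as the default. It is plain Python operating on a single query vector and lists of key/value vectors; `ATTENTION_FUNCTIONS`, `attend`, `eager_attention`, and `fused_attention` are hypothetical names for this example and are not the library's API or the PR's code.

```python
import math


def eager_attention(query, key, value, scaling):
    """Reference path: softmax over scaled dot products, then a weighted sum of values."""
    scores = [sum(q * k for q, k in zip(query, k_vec)) * scaling for k_vec in key]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(value[0])
    return [sum(w / total * v[d] for w, v in zip(weights, value)) for d in range(dim)]


def fused_attention(query, key, value, scaling):
    """Placeholder for an optimized kernel; here it simply reuses the eager math."""
    return eager_attention(query, key, value, scaling)


# Hypothetical registry keyed by the attn_implementation string.
ATTENTION_FUNCTIONS = {"sdpa": fused_attention}


def attend(attn_implementation, query, key, value, scaling):
    # Eager is the default; any other name must be present in the registry.
    attention_fn = eager_attention
    if attn_implementation != "eager":
        attention_fn = ATTENTION_FUNCTIONS[attn_implementation]
    return attention_fn(query, key, value, scaling)


result = attend("sdpa", [1.0, 0.0], [[1.0, 0.0], [1.0, 0.0]], [[2.0], [4.0]], scaling=0.5)
```

With a single attention class calling one dispatch point like this, every backend shares the same projection and rotary code, and only the kernel differs. That is also what would make it possible to compare the backends against the eager output.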