[CTRL] Support attn_implementation="sdpa" dispatch#46073
Conversation
|
[For maintainers] Suggested jobs to run (before merge) run-slow: ctrl |
|
Hi @vasqu , please help review this PR, thank you! |
vasqu
left a comment
There was a problem hiding this comment.
Overall looks good, would like to align a few smaller things here and there but shouldn't be too much. nice work!
| self.is_causal = True | ||
|
|
||
| self.depth = int(d_model_size / self.num_heads) | ||
| self.depth = int(self.d_model_size / self.num_heads) |
There was a problem hiding this comment.
I think this model is super old, would rather align a bit more with modern terminology - ig this should be head_dim
| q = self.split_into_heads(q, batch_size) | ||
| k = self.split_into_heads(k, batch_size) | ||
| v = self.split_into_heads(v, batch_size) | ||
|
|
There was a problem hiding this comment.
Imo we should remove this function and just follow what llama did (with the same reshapes etc)
| ): | ||
| normed = self.layernorm1(x) | ||
| attn_outputs = self.multi_head_attention( | ||
| attn_output = self.multi_head_attention( |
There was a problem hiding this comment.
| attn_output = self.multi_head_attention( | |
| attn_output, _ = self.multi_head_attention( |
would rather do this then the ...[0]
| _supports_sdpa = True | ||
| _can_record_outputs = { | ||
| "hidden_states": EncoderLayer, | ||
| "attentions": OutputRecorder(MultiHeadAttention, index=1), |
There was a problem hiding this comment.
| "attentions": OutputRecorder(MultiHeadAttention, index=1), | |
| "attentions": MultiHeadAttention, |
pretty sure we default to index 1 on attention so no need but not 100% sure
| class CTRLPreTrainedModel(PreTrainedModel): | ||
| config: CTRLConfig | ||
| base_model_prefix = "transformer" | ||
| _supports_sdpa = True |
There was a problem hiding this comment.
The attention looks fairly simple, could we also support all the other flags for this - flex, flash, attention backend (needs interface + kwargs passing from top to bottom)?
| def set_input_embeddings(self, new_embeddings): | ||
| self.w = new_embeddings | ||
|
|
||
| @capture_outputs |
There was a problem hiding this comment.
missing merge config with defaults
| def scaled_dot_product_attention(q, k, v, mask, attention_mask=None): | ||
| # calculate attention | ||
| matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2)) | ||
| def eager_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs): |
There was a problem hiding this comment.
Imo, we could use a copied from here?
| @@ -289,26 +285,6 @@ | |||
| position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) | |||
| position_ids = position_ids.unsqueeze(0) | |||
There was a problem hiding this comment.
Not on you but imo we can refactor this as well, see e.g. llama where we create the input embeds earlier and get the shape from there then most of this can be reduced by quite a bit
What does this PR do?
The
CTRLmodel fails during the initialization phase when usingfrom_pretrained(..., attn_implementation="sdpa"). This PR enables standardattn_implementation="sdpa"dispatch toCTRL.Hi @ArthurZucker, pls help review, thx!