Skip to content

[CTRL] Support attn_implementation="sdpa" dispatch#46073

Open
YangKai0616 wants to merge 7 commits into
huggingface:mainfrom
YangKai0616:sdpa-ctrl
Open

[CTRL] Support attn_implementation="sdpa" dispatch#46073
YangKai0616 wants to merge 7 commits into
huggingface:mainfrom
YangKai0616:sdpa-ctrl

Conversation

@YangKai0616
Copy link
Copy Markdown
Contributor

@YangKai0616 YangKai0616 commented May 19, 2026

What does this PR do?

The CTRL model fails during the initialization phase when using from_pretrained(..., attn_implementation="sdpa"). This PR enables standard attn_implementation="sdpa" dispatch to CTRL .

Hi @ArthurZucker, pls help review, thx!

@YangKai0616 YangKai0616 changed the title [CTRL] Support attn_implementation dispatch [CTRL] Support attn_implementation="sdpa" dispatch May 20, 2026
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ctrl

@YangKai0616
Copy link
Copy Markdown
Contributor Author

Hi @vasqu , please help review this PR, thank you!

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, would like to align a few smaller things here and there but shouldn't be too much. nice work!

self.is_causal = True

self.depth = int(d_model_size / self.num_heads)
self.depth = int(self.d_model_size / self.num_heads)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this model is super old, would rather align a bit more with modern terminology - ig this should be head_dim

Comment on lines 115 to 118
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo we should remove this function and just follow what llama did (with the same reshapes etc)

):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
attn_output = self.multi_head_attention(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
attn_output = self.multi_head_attention(
attn_output, _ = self.multi_head_attention(

would rather do this then the ...[0]

_supports_sdpa = True
_can_record_outputs = {
"hidden_states": EncoderLayer,
"attentions": OutputRecorder(MultiHeadAttention, index=1),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"attentions": OutputRecorder(MultiHeadAttention, index=1),
"attentions": MultiHeadAttention,

pretty sure we default to index 1 on attention so no need but not 100% sure

class CTRLPreTrainedModel(PreTrainedModel):
config: CTRLConfig
base_model_prefix = "transformer"
_supports_sdpa = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attention looks fairly simple, could we also support all the other flags for this - flex, flash, attention backend (needs interface + kwargs passing from top to bottom)?

def set_input_embeddings(self, new_embeddings):
self.w = new_embeddings

@capture_outputs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing merge config with defaults

def scaled_dot_product_attention(q, k, v, mask, attention_mask=None):
# calculate attention
matmul_qk = torch.matmul(q, k.permute(0, 1, 3, 2))
def eager_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo, we could use a copied from here?

Comment on lines 265 to 286
@@ -289,26 +285,6 @@
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not on you but imo we can refactor this as well, see e.g. llama where we create the input embeds earlier and get the shape from there then most of this can be reduced by quite a bit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants