Describe the bug
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
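To illustrate why the missing mask matters, here is a minimal pure-Python sketch of scaled dot-product attention (not the diffusers implementation; toy 1-D embeddings and a hand-rolled softmax are assumptions for illustration). Without a mask, a padded key position still receives a non-zero attention weight; with the mask applied as -inf before the softmax, its weight is exactly zero.

```python
import math

def sdpa_weights(q, keys, mask=None):
    """Attention weights of one query over a list of keys.
    q, keys: lists of floats (toy embeddings); mask: list of bools,
    True marks a padded position that should be excluded."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    if mask is not None:
        # Masked positions get -inf, so exp(score) becomes 0 below.
        scores = [float("-inf") if m else s for s, m in zip(scores, mask)]
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]  # last key is padding
unmasked = sdpa_weights(q, keys)
masked = sdpa_weights(q, keys, mask=[False, False, True])
print(unmasked[-1] > 0)    # True: padding attracts attention weight
print(masked[-1] == 0.0)   # True: padding correctly zeroed out
```

This is exactly the leakage described below: every padded token dilutes the softmax over the real tokens.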
Reproduction
When we calculate the hidden states inside the AuraFlow attention class, we do not pass the attention mask into the SDPA call shown above.
This leads to non-zero attention scores on the padded positions of the input. When training on long sequence lengths, the model is then unnecessarily perturbed by those positions; the loss can be as high as 2.0, which is about as bad as reparameterising the model.
That is an issue for another day, but we should at least make the attention mask an optional argument of the transformer __call__ method and pass it through to the attention class, similar to how DeepFloyd handles it as an input to the UNet __call__ method.
Logs
No response
System Info
diffusers git
Who can help?
@sayakpaul @yiyixuxu (and @DN6 since your tag is related to SD3)