
auraflow: attention mask is not passed into the sdpa function #8886

@bghira


Describe the bug

        # note: no attn_mask argument, so padded positions still receive attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
        )

Reproduction

When we calculate the hidden states inside the AuraFlow attention class, we are not passing the attention mask into the SDPA call.

This leads to non-zero attention scores on the padded positions of the input. When training on long sequence lengths, the model is unnecessarily perturbed by this; the loss can climb as high as 2.0, which is about as bad as reparameterising the model.
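To make the effect concrete, here is a toy example (shapes and padding layout are illustrative, not AuraFlow's real dimensions): without attn_mask, every query attends to the padded keys, and the output changes accordingly.

    import torch
    import torch.nn.functional as F

    # Toy shapes for illustration only: one sequence of length 6
    # whose last two positions are padding.
    q = torch.randn(1, 4, 6, 8)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 4, 6, 8)
    v = torch.randn(1, 4, 6, 8)

    # Boolean mask, True = attend, False = ignore; broadcasts over heads and queries.
    attn_mask = torch.tensor([True, True, True, True, False, False]).view(1, 1, 1, 6)

    # Without the mask, softmax gives the padded keys non-zero weight.
    unmasked = F.scaled_dot_product_attention(q, k, v, is_causal=False)

    # With the mask, padded keys get -inf scores and contribute nothing.
    masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

    print((unmasked - masked).abs().max())  # non-trivial difference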

The loss degradation itself is an issue for another day, but we should at least make the attention mask an optional argument to the transformer __call__ method, passed through to the attention class, similar to how DeepFloyd accepts one as an input to the UNet __call__ method. A rough sketch of that plumbing follows.
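A minimal sketch, assuming nothing about the real diffusers internals; the class names, the single-head simplification, and the attention_mask threading are all illustrative, not the actual AuraFlow code:

    from typing import Optional

    import torch
    import torch.nn.functional as F
    from torch import nn

    class AttnBlockSketch(nn.Module):
        """Illustrative stand-in for the AuraFlow attention class (not the real code)."""

        def __init__(self, dim: int):
            super().__init__()
            self.to_q = nn.Linear(dim, dim)
            self.to_k = nn.Linear(dim, dim)
            self.to_v = nn.Linear(dim, dim)
            self.scale = dim ** -0.5

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            query = self.to_q(hidden_states)
            key = self.to_k(hidden_states)
            value = self.to_v(hidden_states)
            # Single-head here for brevity; the real processor reshapes to multi-head.
            return F.scaled_dot_product_attention(
                query, key, value,
                attn_mask=attention_mask,  # the piece missing from the current code
                dropout_p=0.0, scale=self.scale, is_causal=False,
            )

    class TransformerSketch(nn.Module):
        """The transformer forward grows an optional mask and threads it through."""

        def __init__(self, dim: int, depth: int):
            super().__init__()
            self.blocks = nn.ModuleList([AttnBlockSketch(dim) for _ in range(depth)])

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            for block in self.blocks:
                hidden_states = block(hidden_states, attention_mask=attention_mask)
            return hidden_states

Defaulting attention_mask to None keeps existing callers untouched while letting training code with padded batches opt in.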

Logs

No response

System Info

diffusers git

Who can help?

@sayakpaul @yiyixuxu (and @DN6 since your tag is related to SD3)
