Support out_dim argument for Attention block #7877

@tigerlittle1

Description
Is your feature request related to a problem? Please describe.
When I pass the out_dim argument to __init__ of the Attention block, a shape error is raised whenever query_dim != out_dim. The following code tries to keep the original channel count of hidden_states:

hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

But it should instead use the channel count produced by hidden_states = attn.to_out[0](hidden_states).

Describe the solution you'd like.
I suggest changing:

hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

to:

hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, -1, height, width)

so that the reshape respects the actual channel count of hidden_states.
Maybe I will make a PR later.
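To illustrate the mismatch, here is a minimal sketch using NumPy in place of PyTorch tensors (the transpose/reshape logic is the same). The concrete sizes (channel 320, out_dim 640) are made up for the example and are not from the issue:

```python
import numpy as np

# Hypothetical sizes: the block's input channel vs. the out_dim set in __init__.
batch_size, channel, height, width = 2, 320, 8, 8
out_dim = 640  # channel count produced by attn.to_out[0]

# hidden_states after the output projection: (batch, seq_len, out_dim)
hidden_states = np.zeros((batch_size, height * width, out_dim))

# Current code: hard-codes the *input* channel, which fails when out_dim != channel.
try:
    bad = hidden_states.transpose(0, 2, 1).reshape(batch_size, channel, height, width)
except ValueError as err:
    print("shape error:", err)

# Proposed fix: let reshape infer the channel dimension from the data itself.
good = hidden_states.transpose(0, 2, 1).reshape(batch_size, -1, height, width)
print(good.shape)  # (2, 640, 8, 8)
```

With -1, the reshape always agrees with whatever channel count to_out produced, so the same line works whether or not out_dim is set.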

Describe alternatives you've considered.
None.

Additional context.
None.
