nnx MultiHeadAttention support for different feature dimensions of inputs_q and inputs_k/v #4756

@ryan112358

Description

nnx MultiHeadAttention currently requires a single "in_features" value at construction time, but I'd like to construct the query from an input with feature dim H1 and the key/value from an input with feature dim H2. I believe the previous linen implementation supports this because it infers shapes from the inputs. Can nnx support this use case, either via an additional kwarg or by generalizing the allowed values for "in_features"?
