
MultiHeadAttention is missing a position_bias call argument, making Alibi or T5-style position embedding impossible to implement #18423

Closed
martin-gorner opened this issue Aug 2, 2023 · 10 comments

@martin-gorner (Contributor) commented Aug 2, 2023

Alibi and T5-style relative position embeddings modify the attention score computation instead of simply being added to the token embeddings.

The T5 implementation of MultiHeadAttention has a position_bias argument that allows this.

The Keras MultiHeadAttention seems to be missing this argument.

Without this, I don't think that implementing T5-style or Alibi relative position embeddings is possible.
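
For concreteness, the pattern such an argument enables is to add a bias to the attention scores before the softmax, rather than to the token embeddings. A minimal sketch in plain NumPy (the function and the position_bias shape are illustrative, not an existing Keras signature):

    import numpy as np

    def scaled_dot_product_attention(query, key, value, position_bias=None):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        # position_bias:   (1 or batch, num_heads, q_len, k_len), e.g. Alibi
        #                  slopes times distances, or T5 relative-position buckets
        scores = np.einsum("bhqd,bhkd->bhqk", query, key) / np.sqrt(query.shape[-1])
        if position_bias is not None:
            scores = scores + position_bias  # the bias shifts scores pre-softmax
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return np.einsum("bhqk,bhkd->bhqd", weights, value)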

@fchollet (Member) commented Aug 2, 2023

@mattdangerw what do you think?

@mattdangerw (Member) commented

Could add! Honestly, this is one of many extensions/variations to MHA we could consider...

  1. Cached attention, which is needed for any efficient decoding.
  2. Attention score biases, which are needed for alibi and t5 style attention.
  3. Rotary embeddings, which are applied to the key and query after the dense projections. (overall, seems a bit more popular than alibi or t5 bias)
  4. Multi query attention, grouped query attention. (I think this is supplanting multi-head as best practice)
  5. Probably more I am missing.

Of all of these, 1) is probably the most important, because without a caching solution, you wouldn't want to use the layer for anything generative in practice. And most of these new techniques are becoming popular strictly for generative models. But might be worth taking a step back...

Probably the first question we should answer is whether we want a robust attention layer offering to live in keras-core or keras-nlp.

After that, some design choices...

  • A single MHA layer with a lot of options?
  • A few different layers? And if so, where do we split? We face a cartesian product of features you could combine.
  • Subclassing as a preferred solution?

@fchollet (Member) commented

My strategy would be:

  1. Use subclassing for any one-off or new use case and make the layer subclass part of the model codebase.
  2. If a certain pattern of subclassing occurs often across multiple models, then we need to add it as a built-in feature for better UX.
  3. If it can be added to the layer in a "flat" way (meaning that it has no/low coupling with existing logic and does not add combinatorial complexity), we can add it as a layer argument on the MHA layer.
  4. If it interacts with the rest of the layer logic in a way that would make the MHA layer very complex, roll out a new layer (in KerasNLP, since it will be NLP-specific) that subclasses MHA, or maybe even rewrite it from scratch -- with the feature name in the name of the layer.

@mattdangerw (Member) commented Aug 15, 2023

Talked with Francois. I think we could make the following changes...

Add the following call arguments:

  • Add cache=None, or key_cache=None and value_cache=None, depending on whether we want a single tensor or two.
  • Add cache_index=None or cache_update_index=None (in KerasNLP we use the latter), which controls where the newly computed keys/values will update the cache (a rough usage sketch follows this list).
  • Add attention_bias=None. Not as popular as RoPE, but well defined and simple to add without making our layer spaghetti.
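
A rough usage sketch of how the cache arguments could work during generation (shapes and names are illustrative, not an implemented API): the new keys/values are written into a preallocated cache at cache_update_index, and attention then runs over everything seen so far.

    import numpy as np

    def cached_attention_step(query, new_key, new_value,
                              key_cache, value_cache, cache_update_index):
        # query/new_key/new_value: (batch, heads, step_len, head_dim) for this step
        # key_cache/value_cache:   (batch, heads, max_len, head_dim), preallocated
        end = cache_update_index + new_key.shape[2]
        key_cache[:, :, cache_update_index:end] = new_key
        value_cache[:, :, cache_update_index:end] = new_value
        # Attend over everything computed so far, not just the current step.
        keys, values = key_cache[:, :, :end], value_cache[:, :, :end]
        scores = np.einsum("bhqd,bhkd->bhqk", query, keys) / np.sqrt(query.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        return np.einsum("bhqk,bhkd->bhqd", weights, values)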

MultiQueryAttention/GroupQueryAttention I am not sure about. Either we add some init arguments or make a subclass. Essentially we just need to allow controlling the query and key/value head counts separately.
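
To illustrate the head-count split: with, say, 8 query heads and 2 key/value heads, each key/value head is shared by a group of 4 query heads. A sketch of the repeat trick (illustrative only):

    import numpy as np

    def expand_kv_heads(key, value, num_query_heads):
        # key/value: (batch, num_key_value_heads, seq_len, head_dim)
        # Repeat each key/value head so a contiguous group of query heads shares it;
        # num_query_heads must be a multiple of num_key_value_heads.
        repeats = num_query_heads // key.shape[1]
        return np.repeat(key, repeats, axis=1), np.repeat(value, repeats, axis=1)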

For RoPE, we can recommend a subclass for now. There is some variation in how it's configured/applied, so it would be awkward to add arguments for it. I think this would do it...

    def _compute_attention(self, query, key, value, **kwargs):
        # Apply rotary position embeddings to the projected query and key
        # before the standard attention computation.
        query = rotary_embedding(query)
        key = rotary_embedding(key)
        return super()._compute_attention(query, key, value, **kwargs)
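
For reference, one common way the rotary_embedding helper above could be implemented (a sketch of the standard RoPE rotation; conventions for pairing the features vary between implementations):

    import numpy as np

    def rotary_embedding(x, base=10000):
        # x: (batch, num_heads, seq_len, head_dim), head_dim must be even.
        # Rotates each (even, odd) feature pair by a position-dependent angle.
        seq_len, head_dim = x.shape[2], x.shape[3]
        positions = np.arange(seq_len)[:, None]
        inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
        angles = positions * inv_freq[None, :]  # (seq_len, head_dim // 2)
        cos, sin = np.cos(angles), np.sin(angles)
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        out = np.empty_like(x)
        out[..., 0::2] = x_even * cos - x_odd * sin
        out[..., 1::2] = x_even * sin + x_odd * cos
        return out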

@fchollet (Member) commented

> MultiQueryAttention/GroupQueryAttention I am not sure about. Either we add some init arguments or make a subclass. Essentially we just need to allow controlling the query and key/value head counts separately.

Ideally we just need to check that our current factoring of MHA makes it possible to implement this via a subclass without too much hassle.

@mattdangerw (Member) commented

> Ideally we just need to check that our current factoring of MHA makes it possible to implement this via a subclass without too much hassle.

I just did a multi-query attention layer for a model conversion, and I would say it's a hassle right now, sadly. I ended up abandoning a subclass and just rewriting the layer. But it should be doable to improve!

@martin-gorner (Contributor, Author) commented

Thank you for bringing a broader perspective to this conversation!

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@sachinprasadhs sachinprasadhs self-assigned this Apr 11, 2024
@sachinprasadhs sachinprasadhs added the type:feature The user is asking for a new feature. label Apr 11, 2024
@sachinprasadhs (Collaborator) commented

If the addition of the GroupedQueryAttention layer addresses the concern of this issue, feel free to close it. Thanks!

github-actions bot commented

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Apr 26, 2024
github-actions bot commented

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
