Grouped Query Attention + Refactor Attn #492
Conversation
`GeneralizedAttention` is a valid name, but should we call it `GroupedQueryAttention` to not confuse people?
Or alternatively create:

```python
class GroupedQueryAttention(GeneralizedAttention):
    def __init__(..., groups=G, ...):
        super().__init__(..., kv_n_heads=G, ...)
```

or something like that (following the convention of other implementations).
Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
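For concreteness, a minimal sketch of the suggested wrapper; the `d_model`/`n_heads` parameter names and the `GeneralizedAttention` signature are illustrative assumptions here, not the actual llm-foundry API:

```python
# Hypothetical sketch of the suggested subclass; parameter names are assumptions
# for illustration only, not the llm-foundry signature.
class GroupedQueryAttention(GeneralizedAttention):
    def __init__(self, d_model: int, n_heads: int, groups: int, **kwargs):
        # "groups" maps directly onto kv_n_heads in the generalized class:
        # groups == n_heads recovers MHA, groups == 1 recovers MQA.
        super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=groups, **kwargs)
```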
I've renamed it to `GroupedQueryAttention`, as per discussion offline.
Could you please include a comparison run for a normal MHA model before and after this PR? I'd like to make sure we don't have a perf regression (or at least know the magnitude of it).
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
A few nits, but LGTM pending the performance comparison.
Adding grouped query attention to LLM-foundry, and making the GQA class a superclass of MQA and MHA attention, as it is a generalization of those two variants of attention.

Things to note:
- We use `repeat_interleave` to make the grouped key/value tensors the same dimensions as in multi-head attention, which does allocate new memory, compared to using `expand`. This can be updated in the future, but is the safer bet for now, given that we previously saw edge cases with using `expand` vs `repeat` for particular `head_dim` settings causing NaNs (a minimal sketch of this tradeoff follows below).
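A minimal PyTorch sketch of the `repeat_interleave` vs `expand` tradeoff described above; the tensor sizes are illustrative assumptions, and this is not the llm-foundry implementation:

```python
import torch

B, S, n_heads, kv_n_heads, head_dim = 2, 16, 8, 2, 64  # illustrative sizes, not from the PR
k = torch.randn(B, kv_n_heads, S, head_dim)

# repeat_interleave copies each kv head n_heads // kv_n_heads times, allocating
# a new contiguous tensor: extra memory, but no zero-stride views.
k_rep = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
assert k_rep.shape == (B, n_heads, S, head_dim)

# expand would instead return a zero-stride view over the same storage (no copy),
# which is the cheaper alternative mentioned above; a later reshape or
# contiguous() call would still materialize it.
k_exp = k.unsqueeze(2).expand(B, kv_n_heads, n_heads // kv_n_heads, S, head_dim)
assert k_exp.reshape(B, n_heads, S, head_dim).shape == k_rep.shape
```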