
Grouped Query Attention + Refactor Attn #492

Merged: 31 commits into mosaicml:main on Aug 11, 2023

Conversation

@sashaDoubov (Contributor) commented on Jul 27, 2023

This PR adds grouped query attention (GQA) to LLM-foundry and makes the GQA class a superclass of the MQA and MHA attention classes, since GQA is a generalization of those two attention variants.
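
A minimal sketch of this relationship, with illustrative class and argument names rather than the exact LLM-foundry API: multi-head attention (MHA) is GQA with kv_n_heads equal to n_heads, and multi-query attention (MQA) is GQA with kv_n_heads equal to 1.

class GroupedQueryAttention:
    """n_heads query heads attend over kv_n_heads shared key/value heads."""

    def __init__(self, d_model: int, n_heads: int, kv_n_heads: int):
        assert n_heads % kv_n_heads == 0, 'n_heads must be divisible by kv_n_heads'
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads


class MultiheadAttention(GroupedQueryAttention):
    def __init__(self, d_model: int, n_heads: int):
        # MHA: every query head gets its own key/value head.
        super().__init__(d_model, n_heads, kv_n_heads=n_heads)


class MultiQueryAttention(GroupedQueryAttention):
    def __init__(self, d_model: int, n_heads: int):
        # MQA: all query heads share a single key/value head.
        super().__init__(d_model, n_heads, kv_n_heads=1)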
Things to note:

  • we currently use repeat_interleave to bring the grouped tensors to the same dimensions as multi-head attention, which allocates new memory, unlike expand. This can be revisited in the future, but it is the safer bet for now, given that we previously saw edge cases where expand vs. repeat caused NaNs for particular head_dim settings (a minimal sketch follows this list)
  • Part of this change also alters how we initialize the QKV matrix: it used to initialize the full (n_heads * head_dim, d_model) matrix at once, but now initializes each of the n_heads (head_dim, d_model) blocks separately (see the second sketch after this list)
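
A minimal sketch of the repeat_interleave step from the first note, using torch's built-in scaled_dot_product_attention as a stand-in for the actual attention kernel; shapes and variable names are illustrative, not the exact LLM-foundry code:

import torch

batch, seq_len, n_heads, kv_n_heads, head_dim = 2, 8, 16, 4, 64

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, kv_n_heads, seq_len, head_dim)
v = torch.randn(batch, kv_n_heads, seq_len, head_dim)

# Each group of n_heads // kv_n_heads query heads shares one key/value head.
# repeat_interleave copies the key/value heads into new memory; expand would
# return a view instead, but the copy is the safer choice given the NaN
# edge cases previously seen with expand vs. repeat for some head_dim settings.
k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 16, 8, 64])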
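
And a hypothetical illustration of the per-head initialization change from the second note; the fused projection, its shapes, and the normal init are assumptions for illustration, not the exact LLM-foundry code:

import torch
from torch import nn

d_model, n_heads, kv_n_heads = 512, 16, 4
head_dim = d_model // n_heads

# Assumed fused projection: queries for n_heads plus keys/values for kv_n_heads.
Wqkv = nn.Linear(d_model, d_model + 2 * kv_n_heads * head_dim, bias=False)

with torch.no_grad():
    # Previously: a single init call over the whole (out_features, d_model) weight.
    # Now: each (head_dim, d_model) block is initialized separately, so the
    # initializer sees per-head shapes rather than the full fused matrix.
    for start in range(0, Wqkv.weight.shape[0], head_dim):
        nn.init.normal_(Wqkv.weight[start:start + head_dim], mean=0.0, std=0.02)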

@sashaDoubov sashaDoubov changed the title Generalized attn Grouped Attention + Refactor Attn Jul 27, 2023
@sashaDoubov sashaDoubov changed the title Grouped Attention + Refactor Attn Grouped Query Attention + Refactor Attn Jul 27, 2023
@sashaDoubov sashaDoubov marked this pull request as ready for review August 2, 2023 17:52
@vchiley (Contributor) left a comment:


GeneralizedAttention is a valid name, but should we call it GroupedQueryAttention to avoid confusing people?

or alternatively create

class GroupedQueryAttention(GeneralizedAttention):
    def __init__(..., groups=G, ...):
        super().__init__(..., kv_n_heads=G, ...)

or something like that (following the convention of other implementations)

sashaDoubov and others added 2 commits August 3, 2023 18:16
Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
@sashaDoubov (Contributor, Author) replied:

> GeneralizedAttention is a valid name, but should we call it GroupedQueryAttention to avoid confusing people?
>
> or alternatively create
>
> class GroupedQueryAttention(GeneralizedAttention):
>     def __init__(..., groups=G, ...):
>         super().__init__(..., kv_n_heads=G, ...)
>
> or something like that (following the convention of other implementations)

I've renamed it to GroupedQueryAttention, per our offline discussion.

@dakinggg (Collaborator) left a comment:


Could you please include a comparison run for a normal MHA model before and after this PR? I'd like to make sure we don't have a perf regression (or at least know its magnitude).

Review comments on:
llmfoundry/models/mpt/configuration_mpt.py
llmfoundry/models/layers/attention.py
llmfoundry/models/mpt/modeling_mpt.py
llmfoundry/models/layers/blocks.py
@dakinggg (Collaborator) left a comment:


A few nits, but LGTM pending the performance comparison.

Review comments on:
llmfoundry/models/layers/attention.py
llmfoundry/models/layers/blocks.py
sashaDoubov and others added 7 commits August 10, 2023 16:17
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
@sashaDoubov (Contributor, Author) replied:

> Could you please include a comparison run for a normal MHA model before and after this PR? I'd like to make sure we don't have a perf regression (or at least know its magnitude).

Here are some perf numbers, which look good aside from the loss spikes. Note that I saw loss spikes both before and after the change:

[Three screenshots comparing the runs before and after the change]

@sashaDoubov merged commit d2fbc3b into mosaicml:main on Aug 11, 2023
8 checks passed

3 participants