
Grouped-Query Attention #384

Open
19h opened this issue Jul 19, 2023 · 3 comments
Labels: question (General questions about using Llama2), research-paper (Issues and questions relating to the published architecture or methodology)

Comments


19h commented Jul 19, 2023

Hello Meta GenAI team (cc @ruanslv),

Regarding the 70B model, I'm currently looking into the implementation of the GQA architecture. Specifically, after noticing the 8192 x 1024 layer shapes, I tried to identify the conditional GQA parts in your reference implementation but couldn't pin them down.

Given that there are some conditions that smell suspiciously GQA-related, could you please elaborate on the parts of the implementation that enable this architecture specifically for the 34B / 70B models?

Thanks

19h (Author) commented Jul 19, 2023

Out of impatience I asked Claude 2 about the differences between Implementation A (LLaMA 1) and Implementation B (LLaMA 2):

  • Increased model size (dim, n_layers, n_heads, etc). This will increase the model capacity.
  • Added a n_kv_heads argument to allow having separate key/value heads from query heads. This can improve attention computation efficiency.
  • Added repeat_kv function to repeat k/v projections to match number of query heads if n_kv_heads < n_heads. This enables the flexibility of having fewer key/value projections.
  • Added a ffn_dim_multiplier argument to optionally increase feedforward hidden layer size. This further increases model capacity.
  • Changed the feedforward output to not just predict the last token, but the full sequence. This allows the model to be used for tasks like masked language modeling.
  • Switched from float32 to float16 operations in some places. This can improve efficiency on hardware with tensor cores.
  • Various minor code style/organization changes.
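As a sanity check on the n_kv_heads point above: if I read the 70B config right as dim=8192, n_heads=64 and n_kv_heads=8 (my assumption, not confirmed), the shrunken K/V projections would explain the 8192 x 1024 shapes I saw. Rough sketch with plain nn.Linear, not the parallel layers from the reference code:

import torch.nn as nn

# Assumed 70B config (my reading, not confirmed): dim=8192, n_heads=64, n_kv_heads=8
dim, n_heads, n_kv_heads = 8192, 64, 8
head_dim = dim // n_heads                                # 128

wq = nn.Linear(dim, n_heads * head_dim, bias=False)      # 8192 -> 8192 (query projection)
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)   # 8192 -> 1024 (the 8192 x 1024 shape)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)   # 8192 -> 1024

print(wq.weight.shape, wk.weight.shape)                  # [8192, 8192] and [1024, 8192]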

Then I explained the concept of GQA and asked it for the parts enabling GQA:

  • The key difference between Implementation A and B that enables Grouped Query Attention is having separate n_kv_heads and n_heads arguments.

  • In Implementation B, n_kv_heads allows having fewer key/value projections than query projections. This enables queries to be grouped into fewer key/value heads, which is the core idea behind GQA.

  • The repeat_kv function is used to repeat the keys/values to match the number of query heads. This allows compatibility with a standard multi-head attention model trained with full n_heads.

  • During inference, n_kv_heads can be set to an intermediate value between 1 (MQA) and n_heads (full MHA) to trade off quality and speed.

  • In summary, the separate n_kv_heads argument and repeat_kv function provide the flexibility needed to easily switch to GQA after training a full MHA model. The other changes in Implementation B enable training a high quality MHA model in the first place.

Is this a proper assessment by Claude?
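If it is, the practical payoff would be the KV cache shrinking by n_heads / n_kv_heads. A back-of-the-envelope sketch, with numbers I'm assuming for the 70B model (80 layers, 64 query heads, 8 KV heads, head_dim 128, fp16, batch 1, 4096-token context):

# All numbers below are my assumptions, not taken from the checkpoint.
n_layers, n_heads, n_kv_heads, head_dim = 80, 64, 8, 128
seq_len, bytes_per_elem = 4096, 2  # fp16

def kv_cache_bytes(kv_heads: int) -> int:
    # K and V tensors per layer, each of shape (seq_len, kv_heads, head_dim)
    return 2 * n_layers * seq_len * kv_heads * head_dim * bytes_per_elem

mha = kv_cache_bytes(n_heads)      # full multi-head attention
gqa = kv_cache_bytes(n_kv_heads)   # grouped-query attention
print(f"MHA: {mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.2f} GiB, ratio: {mha // gqa}x")
# MHA: 10.0 GiB, GQA: 1.25 GiB, ratio: 8x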


byildiz commented Jul 19, 2023

Hi,

I think this image is a good summary of GQA:

[Screenshot: GQA summary diagram]

As far as I understand, GQA reduces the cache sizes for keys and values by a factor of n_heads / n_kv_heads. Because the cached keys and values have fewer heads, they later need to be expanded back to the full number of query heads, which is done by simply repeating them. This repetition happens in the repeat_kv function at:

https://github.com/facebookresearch/llama/blob/4d92db8a1db6c7f663252bf3477d2c4b8bad2385/llama/model.py#L77

As an example:

>>> import torch
>>> x = torch.rand(1, 2, 3, 4)
>>> x
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim)
tensor([[[[0.1269, 0.8517, 0.4630, 0.1814],
          [0.1269, 0.8517, 0.4630, 0.1814],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.3441, 0.1733, 0.3397, 0.5518],
          [0.2516, 0.6651, 0.1699, 0.0092],
          [0.2516, 0.6651, 0.1699, 0.0092]],

         [[0.9057, 0.8071, 0.6634, 0.5770],
          [0.9057, 0.8071, 0.6634, 0.5770],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.1865, 0.2643, 0.8765, 0.8715],
          [0.3958, 0.9162, 0.7325, 0.9555],
          [0.3958, 0.9162, 0.7325, 0.9555]]]])
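
To show where that expansion sits relative to the attention itself, here is a rough, simplified sketch of a GQA attention step (toy shapes of my own choosing, no causal mask or RoPE, plain PyTorch rather than the parallel layers in model.py):

import torch
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same trick as above: (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim))

# Toy shapes (my own numbers): 64 query heads grouped onto 8 KV heads.
bs, slen, n_heads, n_kv_heads, head_dim = 1, 16, 64, 8, 128
q = torch.randn(bs, slen, n_heads, head_dim)
k = torch.randn(bs, slen, n_kv_heads, head_dim)   # the cached K only has n_kv_heads heads
v = torch.randn(bs, slen, n_kv_heads, head_dim)   # the cached V only has n_kv_heads heads

k = repeat_kv(k, n_heads // n_kv_heads)           # expand back to n_heads before the matmul
v = repeat_kv(v, n_heads // n_kv_heads)

q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (bs, n_heads, slen, head_dim)
scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
out = F.softmax(scores, dim=-1) @ v               # (bs, n_heads, slen, head_dim)
print(out.shape)                                  # torch.Size([1, 64, 16, 128])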

The only major change that I notice is this repetition. I hope this helps you.

@missflash

> [quotes byildiz's explanation above]

Thanks for the great explanation! :)

macarran added the question and research-paper labels on Sep 6, 2023