add support for grouped query attention
lucidrains committed Apr 20, 2024
1 parent 8faa05e commit 6138cce
Showing 7 changed files with 112 additions and 15 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -10,6 +10,8 @@ I believe this is being used for the 1-10 million tokens for the latest Gemini.

In addition, the repository also contains the logic for <a href="https://arxiv.org/abs/2311.09431">Striped Attention</a>, a follow up paper that permutes the sequence for better workload balancing for autoregressive transformers.

It also contains support for <a href="https://arxiv.org/abs/2305.13245">grouped query attention</a>, popularized by the Llama series of large language models. This further reduces communication costs during the ring reduce, since fewer key / value heads need to be passed between machines.
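
As a rough standalone sketch of where that saving comes from (illustrative shapes only, not this repository's API), with 8 query heads grouped 4 to a key / value head, the ring-passed payload drops to a quarter of the multi-head size:

```python
import torch
from einops import repeat

# illustrative shapes - 8 query heads, 4 query heads sharing each key / value head
batch, seq, dim_head = 2, 1024, 64
heads, num_grouped_query_heads = 8, 4
kv_heads = heads // num_grouped_query_heads  # 2

q = torch.randn(batch, seq, heads, dim_head)
k = torch.randn(batch, seq, kv_heads, dim_head)
v = torch.randn(batch, seq, kv_heads, dim_head)

# only k and v circulate during the ring reduce, so the payload shrinks by the group factor
print((k.numel() + v.numel()) / (2 * batch * seq * heads * dim_head))  # 0.25

# at attention time the kv heads are tiled back up to match the query heads
k, v = (repeat(t, 'b n h d -> b n (g h) d', g = heads // kv_heads) for t in (k, v))
assert k.shape == q.shape
```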

## Appreciation

- <a href="https://a16z.com/supporting-the-open-source-ai-community/">A16Z Open Source AI Grant Program</a> for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
@@ -157,4 +159,15 @@ $ python assert.py --use-cuda --causal --striped-ring-attn
}
```

```bibtex
@article{Ainslie2023GQATG,
title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit K. Sanghai},
journal = {ArXiv},
year = {2023},
volume = {abs/2305.13245},
url = {https://api.semanticscholar.org/CorpusID:258833177}
}
```

*<a href="http://www.incompleteideas.net/IncIdeas/BitterLesson.html">The Bitter Lesson</a>* - Richard Sutton
12 changes: 12 additions & 0 deletions assert.py
@@ -38,6 +38,8 @@ def start(
causal,
striped_ring_attn,
dim,
heads,
num_grouped_query_heads,
dim_head,
use_cuda,
compare_regular_attn
@@ -52,6 +54,8 @@
dim = dim,
causal = causal,
depth = 2,
heads = heads,
num_grouped_query_heads = num_grouped_query_heads,
dim_head = dim_head,
ring_attn = True,
striped_ring_attn = striped_ring_attn,
@@ -64,6 +68,8 @@
dim = dim,
causal = causal,
depth = 2,
heads = heads,
num_grouped_query_heads = num_grouped_query_heads,
dim_head = dim_head,
ring_attn = False,
ring_seq_size = ring_seq_size,
@@ -143,6 +149,8 @@ def start(
@click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
@click.option('--model-dim', default = 8, help = 'model dimensions for testing')
@click.option('--heads', default = 8, help = 'number of query attention heads')
@click.option('--num-grouped-query-heads', default = 2, help = 'number of query heads per key / value head')
@click.option('--dim-head', default = 16, help = 'attention head dimension')
@click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
def test(
@@ -156,6 +164,8 @@
num_buckets: int,
seq_len: int,
model_dim: int,
heads: int,
num_grouped_query_heads: int,
dim_head: int,
compare_regular_attn: bool
):
@@ -173,6 +183,8 @@
causal,
striped_ring_attn,
model_dim,
heads,
num_grouped_query_heads,
dim_head,
use_cuda,
compare_regular_attn
12 changes: 12 additions & 0 deletions assert_attn.py
@@ -38,6 +38,8 @@ def start(
causal,
striped_ring_attn,
dim,
heads,
num_grouped_query_heads,
dim_head,
use_cuda,
compare_regular_attn
@@ -51,6 +53,8 @@
dim = dim,
causal = causal,
dim_head = dim_head,
heads = heads,
num_grouped_query_heads = num_grouped_query_heads,
ring_attn = True,
striped_ring_attn = striped_ring_attn,
ring_seq_size = ring_seq_size,
@@ -63,6 +67,8 @@
dim = dim,
causal = causal,
dim_head = dim_head,
heads = heads,
num_grouped_query_heads = num_grouped_query_heads,
ring_attn = False,
ring_seq_size = ring_seq_size,
bucket_size = bucket_size,
@@ -145,6 +151,8 @@ def start(
@click.option('--num-buckets', default = 2, help = 'number of buckets per machine (each sharded sequence is further windowed for flash attention to achieve even greater context lengths)')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
@click.option('--model-dim', default = 8, help = 'model dimensions for testing')
@click.option('--heads', default = 8, help = 'number of query attention heads')
@click.option('--num-grouped-query-heads', default = 2, help = 'number of query heads per key / value head')
@click.option('--dim-head', default = 16, help = 'attention head dimension')
@click.option('--compare-regular-attn', is_flag = True, help = 'compare ring to regular attention')
def test(
@@ -158,6 +166,8 @@
num_buckets: int,
seq_len: int,
model_dim: int,
heads: int,
num_grouped_query_heads: int,
dim_head: int,
compare_regular_attn: bool
):
@@ -175,6 +185,8 @@
causal,
striped_ring_attn,
model_dim,
heads,
num_grouped_query_heads,
dim_head,
use_cuda,
compare_regular_attn
27 changes: 25 additions & 2 deletions ring_attention_pytorch/ring_attention.py
@@ -53,6 +53,14 @@ def default_attention(

mask_value = -torch.finfo(q.dtype).max

# account for grouped query attention

heads, kv_heads = q.shape[-2], k.shape[-2]
assert divisible_by(heads, kv_heads)
q_head_groups = heads // kv_heads

k, v = map(lambda t: repeat(t, '... h d -> ... (g h) d', g = q_head_groups), (k, v))
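# k and v now match q's head count - each kv head is tiled across its group of query heads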

# similarity

sim = einsum('b i h d, b j h d -> b h i j', q, k)
@@ -264,6 +272,7 @@ def __init__(
*,
dim_head: int = 64,
heads: int = 8,
num_grouped_query_heads: int = 1,
causal: bool = False,
eps: float = 1e-10,
bucket_size: int = 512,
@@ -287,6 +296,14 @@

self.eps = eps
self.heads = heads
self.dim_head = dim_head

assert divisible_by(heads, num_grouped_query_heads), f'number of query heads ({heads}) must be divisible by the number of grouped query heads ({num_grouped_query_heads})'

kv_heads = heads // num_grouped_query_heads
self.num_grouped_query_heads = num_grouped_query_heads
self.qkv_head_breakdown = (heads, kv_heads, kv_heads)
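# head counts used to split the fused qkv projection along the head dimension in the forward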

self.scale = dim_head ** -0.5
self.causal = causal

@@ -321,10 +338,13 @@ def __init__(
# projections

dim_inner = dim_head * heads
dim_kv_inner = dim_head * kv_heads

self.to_qkv_split = (dim_inner, dim_kv_inner, dim_kv_inner)

self.to_qkv = nn.Sequential(
RMSNorm(dim) if prenorm else nn.Identity(),
nn.Linear(dim, dim_inner * 3, bias = False)
nn.Linear(dim, dim_inner + (dim_kv_inner * 2), bias = False)
)

self.to_out = nn.Linear(dim_inner, dim, bias = False)
@@ -369,7 +389,8 @@ def forward(
device = x.device

qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b n h d', qkv = 3, h = self.heads)

q, k, v = rearrange(qkv, 'b n (h d) -> b n h d', d = self.dim_head).split(self.qkv_head_breakdown, dim = -2)
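# q keeps all the query heads, while k and v each carry only (heads // num_grouped_query_heads) heads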

# rotary relative positions

@@ -460,6 +481,7 @@ def __init__(
dim_head: int = 64,
heads: int = 8,
ff_mult: int = 4,
num_grouped_query_heads: int = 1, # grouped query attention - kv heads = (heads // num_grouped_query_heads)
bucket_size: int = 512,
ring_attn: bool = False,
striped_ring_attn: bool = False,
@@ -516,6 +538,7 @@ def __init__(
causal = causal,
dim_head = dim_head,
heads = heads,
num_grouped_query_heads = num_grouped_query_heads,
bucket_size = bucket_size,
ring_attn = ring_attn,
ring_seq_size = ring_seq_size,
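
To make the fused projection arithmetic above concrete, here is a standalone sketch (illustrative sizes - a dim of 512, 8 query heads, 4 query heads per key / value head - not the module defaults):

```python
import torch
from torch import nn
from einops import rearrange

dim, dim_head, heads, num_grouped_query_heads = 512, 64, 8, 4
kv_heads = heads // num_grouped_query_heads           # 2 key / value heads

dim_inner = dim_head * heads                          # 512, for the queries
dim_kv_inner = dim_head * kv_heads                    # 128, for the keys (and again for the values)

to_qkv = nn.Linear(dim, dim_inner + dim_kv_inner * 2, bias = False)

x = torch.randn(2, 16, dim)
qkv = rearrange(to_qkv(x), 'b n (h d) -> b n h d', d = dim_head)   # (2, 16, 12, 64)
q, k, v = qkv.split((heads, kv_heads, kv_heads), dim = -2)
print(q.shape, k.shape, v.shape)   # (2, 16, 8, 64) (2, 16, 2, 64) (2, 16, 2, 64)
```
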
29 changes: 24 additions & 5 deletions ring_attention_pytorch/ring_flash_attention.py
@@ -6,7 +6,7 @@
from torch import nn, einsum, Tensor
from torch.autograd.function import Function

from einops import rearrange
from einops import rearrange, repeat, reduce

from ring_attention_pytorch.ring import (
ring_pass,
@@ -69,11 +69,15 @@ def forward(
):
ring_size = default(ring_size, get_world_size())

cross_attn = q.shape[-2] != k.shape[-2]
cross_attn = q.shape[-3] != k.shape[-3]
ring_reduce_col &= not cross_attn
striped_ring_attn &= not cross_attn

assert k.shape[-1] == v.shape[-1]
assert k.shape[-2:] == v.shape[-2:]
q_heads, kv_heads = q.shape[-2], k.shape[-2]

assert divisible_by(q_heads, kv_heads)
q_head_groups = q_heads // kv_heads

per_machine_seq_size = k.shape[1]

@@ -124,6 +128,10 @@

k, v = kv

# account for grouped query attention

k, v = map(lambda t: repeat(t, '... h d -> ... (g h) d', g = q_head_groups), (k, v))
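# the kv chunk passed around the ring carries only the kv heads - tile them up to the full query head count before attending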

col_splits = zip(
k.split(bucket_size, dim = -3),
v.split(bucket_size, dim = -3),
@@ -206,7 +214,8 @@ def forward(
max_ring_passes,
num_lookback_buckets,
striped_ring_attn,
ring_size
ring_size,
q_head_groups
)

ctx.save_for_backward(q, orig_k, orig_v, o, lse)
@@ -227,7 +236,8 @@ def backward(ctx, do):
max_ring_passes,
num_lookback_buckets,
striped_ring_attn,
ring_size
ring_size,
q_head_groups
) = ctx.args

q, k, v, o, lse = ctx.saved_tensors
@@ -268,6 +278,10 @@ def backward(ctx, do):

k, v, dk, dv = kv_and_dkv

# account for grouped query attention

k, v = map(lambda t: repeat(t, '... h d -> ... (g h) d', g = q_head_groups), (k, v))

col_splits = zip(
k.split(bucket_size, dim = 1),
v.split(bucket_size, dim = 1),
@@ -323,6 +337,11 @@ def backward(ctx, do):
dq_chunk = einsum('b h i j, b j h d -> b i h d', ds, kc)
dk_chunk = einsum('b h i j, b i h d -> b j h d', ds, qc)

# account for grouped query attention

dk_chunk = reduce(dk_chunk, '... (g h) d -> ... h d', g = q_head_groups, reduction = 'sum')
dv_chunk = reduce(dv_chunk, '... (g h) d -> ... h d', g = q_head_groups, reduction = 'sum')
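# backward of the repeat above - gradients from all query heads in a group are summed into their shared kv head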

dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
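
The forward repeat and the backward sum-reduce above are two views of the same operation; a standalone sketch using the same einops patterns checks this with autograd:

```python
import torch
from einops import repeat, reduce

kv_heads, q_head_groups, dim_head = 2, 4, 8
k = torch.randn(kv_heads, dim_head, requires_grad = True)

# forward: tile each kv head across its query head group
k_expanded = repeat(k, 'h d -> (g h) d', g = q_head_groups)

grad_out = torch.randn(q_head_groups * kv_heads, dim_head)
k_expanded.backward(grad_out)

# autograd sums the incoming gradient over each group, matching the explicit reduce used above
manual = reduce(grad_out, '(g h) d -> h d', 'sum', g = q_head_groups)
assert torch.allclose(k.grad, manual)
```
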
32 changes: 25 additions & 7 deletions ring_attention_pytorch/ring_flash_attention_cuda.py
@@ -17,7 +17,7 @@

from beartype import beartype

from einops import rearrange
from einops import rearrange, repeat, reduce

# helpers

@@ -55,6 +55,12 @@ def forward(
):
from ring_attention_pytorch.triton_flash_attn import flash_attn_forward

assert k.shape[-2:] == v.shape[-2:]
q_heads, kv_heads = q.shape[-2], k.shape[-2]

assert divisible_by(q_heads, kv_heads)
q_head_groups = q_heads // kv_heads

assert all([t.is_cuda for t in (q, k, v)]), 'inputs must be all on cuda'

dtype = q.dtype
@@ -127,6 +133,10 @@
for (ring_rank, (is_first, is_last)), ((kv, mask), (receive_kv, receive_mask)) in ring_pass_fn(kv, mask, receive_buffers = (receive_kv, receive_mask), max_iters = max_ring_passes, ring_size = ring_size):
k, v = kv

# account for grouped query attention

k, v = map(lambda t: repeat(t, '... h d -> ... (g h) d', g = q_head_groups), (k, v))

# translate key padding mask to bias

bias = None
@@ -180,6 +190,7 @@ def forward(
num_lookback_buckets,
striped_ring_attn,
ring_size,
q_head_groups,
dtype
)

@@ -206,6 +217,7 @@ def backward(ctx, do):
num_lookback_buckets,
striped_ring_attn,
ring_size,
q_head_groups,
dtype
) = ctx.args

@@ -232,7 +244,6 @@ def backward(ctx, do):
dv = torch.zeros_like(v, device = device)

# k and v will have 16 bits, and dk and dv can also be accumulated safely with the same type, i think
# view everything as float32 for ring passing

assert k.dtype == v.dtype
kv_dtype = k.dtype
@@ -252,6 +263,10 @@

k, v, dk, dv = kv_and_dkv

# account for grouped query attention

k, v = map(lambda t: repeat(t, '... h d -> ... (g h) d', g = q_head_groups), (k, v))

# translate key padding mask to bias

bias = None
@@ -300,12 +315,15 @@ def backward(ctx, do):
causal_mask_diagonal = causal_mask_diagonal,
softmax_scale = softmax_scale
)
else:
ring_dq, ring_dk, ring_dv = 0., 0., 0.

dq.add_(ring_dq)
dk.add_(ring_dk)
dv.add_(ring_dv)
# account for grouped query attention

ring_dk = reduce(ring_dk, '... (g h) d -> ... h d', g = q_head_groups, reduction = 'sum')
ring_dv = reduce(ring_dv, '... (g h) d -> ... h d', g = q_head_groups, reduction = 'sum')
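# as in the non-cuda path, sum the head gradients within each group back down to the kv heads before accumulating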

dq.add_(ring_dq)
dk.add_(ring_dk)
dv.add_(ring_dv)

if not ring_reduce_col:
continue
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ring-attention-pytorch',
packages = find_packages(exclude=[]),
version = '0.3.21',
version = '0.4.1',
license='MIT',
description = 'Ring Attention - Pytorch',
author = 'Phil Wang',
