Commit 63b8380

[Paddle-Inference]support GQA in variable_length_memory_efficient_attention (PaddlePaddle#58836)

zhoutianzi666 committed Nov 9, 2023
1 parent 18e39bd commit 63b8380
Showing 5 changed files with 78 additions and 13 deletions.
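For orientation before the per-file diffs: in grouped-query attention (GQA) the query tensor keeps `num_head` heads while key/value carry a smaller `kv_num_head`, and each group of `num_head / kv_num_head` query heads reads the same KV head. A minimal sketch of that mapping (pure NumPy, not part of the commit; `kv_head_for` is an illustrative helper name):

```python
import numpy as np

# Shapes mirroring the test cases in this commit:
#   query:      [batch, num_head,    seq_len, dim_head]
#   key/value:  [batch, kv_num_head, seq_len, dim_head]
batch, num_head, kv_num_head, seq_len, dim_head = 1, 8, 2, 64, 16

q = np.random.random((batch, num_head, seq_len, dim_head)).astype("float32")
k = np.random.random((batch, kv_num_head, seq_len, dim_head)).astype("float32")
v = np.random.random((batch, kv_num_head, seq_len, dim_head)).astype("float32")

def kv_head_for(q_head):
    # Which KV head serves a given query head.
    return q_head // (num_head // kv_num_head)

# Query heads 0..3 share KV head 0, heads 4..7 share KV head 1.
assert [kv_head_for(h) for h in range(num_head)] == [0, 0, 0, 0, 1, 1, 1, 1]
```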
16 changes: 12 additions & 4 deletions paddle/phi/infermeta/multiary.cc
@@ -2930,11 +2930,19 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
phi::errors::InvalidArgument(
"The batch size of Query, Key, Value should be equal."));

PADDLE_ENFORCE_EQ((key_num_head == value_num_head),
true,
phi::errors::InvalidArgument(
"The head number of Key, Value should be equal."));

PADDLE_ENFORCE_EQ(
-      ((query_num_head == key_num_head) && (key_num_head == value_num_head)),
-      true,
-      phi::errors::InvalidArgument(
-          "The head number of Query, Key, Value should be equal."));
+      query_num_head % key_num_head,
+      0,
+      errors::InvalidArgument(
+          "The num_head of query must be divisible by the num_head of key, but "
+          "recived num_head of query is %d, and the num_head of key is %d",
+          query_num_head,
+          key_num_head));

PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
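A hedged Python restatement of the two checks added above (equal K/V head counts, and query heads divisible by KV heads); the function name and error strings are illustrative, not the exact PADDLE_ENFORCE output:

```python
def check_gqa_head_counts(query_num_head, key_num_head, value_num_head):
    # Mirrors the added PADDLE_ENFORCE_EQ checks in
    # VariableLengthMemoryEfficientAttentionInferMeta.
    if key_num_head != value_num_head:
        raise ValueError("The head number of Key, Value should be equal.")
    if query_num_head % key_num_head != 0:
        raise ValueError(
            "The num_head of query must be divisible by the num_head of key, "
            f"but received num_head of query is {query_num_head}, "
            f"and the num_head of key is {key_num_head}"
        )

check_gqa_head_counts(8, 2, 2)      # OK: 8 % 2 == 0
# check_gqa_head_counts(8, 3, 3)    # would raise: 8 is not divisible by 3
```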
@@ -164,6 +164,7 @@ struct FMHAGrouped {
int problem_count;
int threadblock_count;
int num_heads;
int kv_num_heads;

ElementQ *ptr_Q;
ElementK *ptr_K;
@@ -205,6 +206,7 @@ struct FMHAGrouped {
: problem_count(0),
threadblock_count(0),
num_heads(0),
kv_num_heads(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
@@ -234,6 +236,7 @@ struct FMHAGrouped {
int problem_count,
int threadblock_count,
int num_heads,
int kv_num_heads,
ElementQ *ptr_Q,
ElementK *ptr_K,
ElementM *ptr_M,
@@ -259,6 +262,7 @@ struct FMHAGrouped {
problem_count(problem_count),
threadblock_count(threadblock_count),
num_heads(num_heads),
kv_num_heads(kv_num_heads),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_M(ptr_M),
@@ -307,6 +311,7 @@ struct FMHAGrouped {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int num_heads;
int kv_num_heads;

ElementQ *ptr_Q;
ElementK *ptr_K;
@@ -369,6 +374,7 @@ struct FMHAGrouped {
tile_count),
threadblock_count(args.threadblock_count),
num_heads(args.num_heads),
kv_num_heads(args.kv_num_heads),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
@@ -403,6 +409,7 @@ struct FMHAGrouped {
tile_count);
threadblock_count = args.threadblock_count;
num_heads = args.num_heads;
kv_num_heads = args.kv_num_heads;
ptr_Q = args.ptr_Q;
ptr_K = args.ptr_K;
ptr_P = args.ptr_P;
@@ -580,6 +587,8 @@ struct FMHAGrouped {

const int32_t problem_idx = problem_visitor.problem_index();
const int32_t batch_idx = problem_idx / params.num_heads;
// how many query head share a kv head?
const int32_t qhead_per_kv_head = params.num_heads / params.kv_num_heads;

if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = ElementAccumulator(0);
@@ -639,7 +648,8 @@ struct FMHAGrouped {
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)},
-            params.ptr_V + problem_idx * params.kElementV +
+            params.ptr_V +
+                (problem_idx / qhead_per_kv_head) * params.kElementV +
iter_key_start * params.ldv,
{problem_size_1_k, problem_size_1_n},
thread_id(),
@@ -679,7 +689,8 @@ struct FMHAGrouped {
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(params.ldk)),
-            params.ptr_K + problem_idx * params.kElementK +
+            params.ptr_K +
+                (problem_idx / qhead_per_kv_head) * params.kElementK +
iter_key_start * params.ldk,
{problem_size_0_k, problem_size_0_n},
thread_id(),
@@ -834,7 +845,8 @@ struct FMHAGrouped {

typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)},
-          params.ptr_V + problem_idx * params.kElementV +
+          params.ptr_V +
+              (problem_idx / qhead_per_kv_head) * params.kElementV +
iter_key_start * params.ldv,
{problem_size_1_k, problem_size_1_n},
thread_id(),
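In the grouped kernel above, each problem is one (batch, query head) pair, so `problem_idx = batch_idx * num_heads + q_head_idx`; dividing by `qhead_per_kv_head` collapses that to the (batch, KV head) slot used to offset `ptr_K`/`ptr_V`. A small sketch checking that index arithmetic (variable names are illustrative, not from the kernel):

```python
num_heads = 8      # query heads
kv_num_heads = 2   # KV heads
qhead_per_kv_head = num_heads // kv_num_heads  # query heads sharing one KV head

for batch_idx in range(3):
    for q_head_idx in range(num_heads):
        problem_idx = batch_idx * num_heads + q_head_idx
        # Index used above to offset ptr_K / ptr_V by kElementK / kElementV.
        kv_problem_idx = problem_idx // qhead_per_kv_head
        # It lands on the (batch, kv_head) pair this query head shares.
        assert kv_problem_idx == batch_idx * kv_num_heads + q_head_idx // qhead_per_kv_head
```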
@@ -179,6 +179,7 @@ def parse_args():
problem_count,
threadblock_count,
params.num_heads,
params.kv_num_heads,
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.query_ptr)),
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.key_ptr)),
params.mask_ptr
@@ -465,6 +466,7 @@ def write_main_header():
// Dimensions/strides
int32_t num_batches;
int32_t num_heads;
int32_t kv_num_heads;
int32_t query_seq_len;
int32_t key_value_seq_len;
int32_t head_size;
@@ -37,6 +37,7 @@ void MultiHeadAttentionVariableForwardKernel(

params.num_batches = query.dims()[0];
params.num_heads = query.dims()[1];
params.kv_num_heads = key.dims()[1];
params.query_seq_len = query.dims()[2];
params.head_size = query.dims()[3];
params.key_value_seq_len = key.dims()[2];
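The kernel entry above fills the params from the tensor dims; a small sketch of the layout it assumes (dimension order as in the `dims()` indexing; the dict and function name are illustrative only):

```python
# query:       [num_batches, num_heads,    query_seq_len,     head_size]
# key / value: [num_batches, kv_num_heads, key_value_seq_len, head_size]
def params_from_shapes(query_shape, key_shape):
    return {
        "num_batches": query_shape[0],
        "num_heads": query_shape[1],
        "kv_num_heads": key_shape[1],
        "query_seq_len": query_shape[2],
        "head_size": query_shape[3],
        "key_value_seq_len": key_shape[2],
    }

print(params_from_shapes((2, 8, 32, 128), (2, 2, 32, 128)))
# {'num_batches': 2, 'num_heads': 8, 'kv_num_heads': 2, 'query_seq_len': 32,
#  'head_size': 128, 'key_value_seq_len': 32}
```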
@@ -61,6 +61,20 @@ def create_attn_mask(


def naive_attention_impl(query, key, value, mask, scale):
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]

key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])

value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])

qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
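The reshape/tile/reshape above broadcasts each KV head across its group of query heads before the plain attention math. An equivalent formulation (an assumption for illustration, not what the test uses) is a repeat along the head axis:

```python
import paddle

def expand_kv_heads(kv, num_q_heads):
    # kv: [batch, kv_heads, seq_len, head_dim] -> [batch, num_q_heads, seq_len, head_dim]
    kv_heads = kv.shape[1]
    return paddle.repeat_interleave(kv, num_q_heads // kv_heads, axis=1)

# e.g. a key of shape [1, 2, 64, 16] expands to [1, 8, 64, 16], matching the
# reshape/tile/reshape result above for heads=8, kv_head=2.
```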
@@ -79,6 +93,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.num_head = 8
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
@@ -94,6 +109,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -112,11 +133,11 @@ def test_all(self):
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
-        key = np.random.random(self.shape)
+        key = np.random.random(self.shape_kv)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
-        value = np.random.random(self.shape)
+        value = np.random.random(self.shape_kv)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
@@ -148,6 +169,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
@@ -163,6 +185,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -181,6 +209,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 2
self.num_head = 8
self.kv_num_head = 2
self.seq_len = 32
self.dim_head = 128
self.seq_lens = paddle.to_tensor(
@@ -196,6 +225,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'bfloat16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -218,6 +253,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
@@ -233,6 +269,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -243,8 +285,8 @@ def setUp(self):
* self.batch_size,
).numpy()
self.q = np.random.random(self.shape).astype(self.dtype)
-        self.k = np.random.random(self.shape).astype(self.dtype)
-        self.v = np.random.random(self.shape).astype(self.dtype)
+        self.k = np.random.random(self.shape_kv).astype(self.dtype)
+        self.v = np.random.random(self.shape_kv).astype(self.dtype)
self.scale = 1.0 / np.sqrt(self.shape[-1])

self.ref_out = naive_attention_impl(
@@ -263,10 +305,10 @@ def test_all(self):
name="query", shape=self.shape, dtype=self.dtype
)
k = paddle.static.data(
name="key", shape=self.shape, dtype=self.dtype
name="key", shape=self.shape_kv, dtype=self.dtype
)
v = paddle.static.data(
name="value", shape=self.shape, dtype=self.dtype
name="value", shape=self.shape_kv, dtype=self.dtype
)
mask = paddle.static.data(
name="mask",
