Add Pallas attention kernel for GPU serving.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 607404565
2 people authored and jax authors committed Feb 15, 2024
1 parent 3708336 commit 9fcf9e5
Showing 4 changed files with 550 additions and 0 deletions.
13 changes: 13 additions & 0 deletions jax/BUILD
@@ -605,6 +605,19 @@ pytype_strict_library(
],
)

pytype_strict_library(
name = "pallas_gpu_ops",
srcs = glob(["experimental/pallas/ops/gpu/**/*.py"]),
visibility = [
":pallas_gpu_users",
],
deps = [
":jax",
":pallas",
":pallas_gpu",
] + py_deps("numpy"),
)

pytype_strict_library(
name = "pallas_tpu_ops",
srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]),
345 changes: 345 additions & 0 deletions jax/experimental/pallas/ops/gpu/decode_attention.py
@@ -0,0 +1,345 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing decode attention."""
from __future__ import annotations

import functools
from typing import Any

import jax
from jax import lax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def attn_forward_kernel(
q_ref, # [num_heads, head_dim]
k_ref, # [k_seq_len, head_dim]
v_ref, # [k_seq_len, head_dim]
o_ref: Any, # [num_heads, head_dim]
*residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,]
sm_scale: float,
block_k: int,
):
block_h, head_dim = q_ref.shape
k_seq_len, _ = k_ref.shape

# o is the buffer where we accumulate the output on sram.
  # m_i and l_i (see the FlashAttention-2 paper) are updated during the k,v loop.
m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf")
l_i = jnp.zeros(block_h, dtype=jnp.float32)
o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)

  # Load q: it stays resident on-chip for the whole k,v loop. The q BlockSpec
  # index map already selects this program's block of heads, so all indexing
  # here is block-local. q tile has shape [block_h, head_dim].
  curr_q_slice = pl.dslice(0, block_h)
  q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))

  def _dot(a, b):
    # NOTE: when block_h == 1, a matrix-vector product such as
    # (a.T * b).sum(axis=0, keepdims=True) could be used instead of pl.dot.
    return pl.dot(a, b)

  # Loop over blocks of k/v to process this split's entire k/v sequence.
  # The pallas grid handles the blocking over query heads and over k_splits.
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
curr_k_slice = pl.dslice(start_k * block_k, block_k)

k = pl.load(k_ref, (curr_k_slice, slice(None)))
qk = _dot(q, k.T) # [block_h, block_k]
if sm_scale != 1.0:
qk *= sm_scale # [block_h, block_k]

m_curr = qk.max(axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
correction = jnp.exp(m_prev - m_next)
l_prev_corr = correction * l_prev
s_curr = jnp.exp(
qk - m_next[:, None]
) # Use m_next instead of m_curr to avoid a correction on l_curr
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
v = pl.load(v_ref, (curr_k_slice, slice(None)))
o_curr = _dot(s_curr.astype(v.dtype), v)

    # FlashAttention-2 style: accumulate the unscaled output; it is rescaled
    # once in the final reduction.
o_next = correction[:, None] * o_prev + o_curr
return o_next, m_next, l_next

upper_bound = pl.cdiv(k_seq_len, block_k) # type: ignore
# o is left unscaled; it will be scaled in the final reduction step
o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))

if residual_refs:
l_ref, m_ref = residual_refs
pl.store(l_ref, (curr_q_slice,), l_i)
pl.store(m_ref, (curr_q_slice,), m_i)
# Write output to dram.
o = o.astype(o_ref.dtype)
pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
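

# Illustrative sketch (not part of the kernel): a pure-jnp version of the
# FlashAttention-2 online-softmax recurrence implemented by `body` above,
# assuming sm_scale == 1 and plain host-side arrays. The helper name below is
# hypothetical and only exists to document the math.
def _online_softmax_sketch(q, k, v, block_k=128):
  # q: [block_h, head_dim], k/v: [k_seq_len, head_dim]
  block_h, head_dim = q.shape
  m = jnp.full((block_h,), -jnp.inf, dtype=jnp.float32)   # running row max
  l = jnp.zeros((block_h,), dtype=jnp.float32)            # running normalizer
  o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)   # unscaled output
  for start in range(0, k.shape[0], block_k):
    kb, vb = k[start:start + block_k], v[start:start + block_k]
    qk = (q @ kb.T).astype(jnp.float32)                   # [block_h, block_k]
    m_next = jnp.maximum(m, qk.max(axis=-1))
    corr = jnp.exp(m - m_next)                            # rescales old partials
    s = jnp.exp(qk - m_next[:, None])
    l = corr * l + s.sum(axis=-1)
    o = corr[:, None] * o + s @ vb
    m = m_next
  return o / l[:, None]                                   # softmax(q @ k.T) @ v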


def attn_unbatched(
q, # [num_heads, head_dim]
k, # [k_seq_len, head_dim]
v, # [k_seq_len, head_dim]
sm_scale: float,
block_h: int,
block_k: int,
k_splits: int,
num_warps: int | None,
num_stages: int,
grid: tuple[int, ...] | None,
interpret: bool,
debug: bool,
):
num_heads, head_dim = q.shape
k_seq_len, _ = k.shape
# Pad num query heads to 16 if needed, and slice output at the end.
original_num_heads = None
if num_heads < 16:
q = jnp.pad(q, ((0, 16 - num_heads), (0, 0)))
original_num_heads = num_heads
num_heads = q.shape[0]
block_h = min(block_h, num_heads)
head_splits = pl.cdiv(num_heads, block_h)
grid_ = grid
if grid_ is None:
grid_ = (head_splits, k_splits)

assert (
k_seq_len % k_splits == 0
), f"{k_seq_len=} must be divisible by {k_splits=}"
k = k.reshape(k_splits, k_seq_len // k_splits, head_dim)
v = v.reshape(k_splits, k_seq_len // k_splits, head_dim)
k_seq_len = k_seq_len // k_splits
assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16"
block_k = min(block_k, k_seq_len)
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4
kernel = functools.partial(
attn_forward_kernel,
sm_scale=sm_scale,
block_k=block_k,
)

o, l, m = pl.pallas_call(
kernel,
grid=grid_,
in_specs=[
pl.BlockSpec(lambda i, j: (i, 0), (block_h, head_dim)),
pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)),
pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda i, j: (j, i, 0), (None, block_h, head_dim)), # o
          pl.BlockSpec(lambda i, j: (j, i), (None, block_h)),  # l
          pl.BlockSpec(lambda i, j: (j, i), (None, block_h)),  # m
],
num_warps=num_warps_,
num_stages=num_stages,
out_shape=[
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # l
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # m
],
debug=debug,
interpret=interpret,
name="mha_forward",
)(q, k, v)

# final round of flash
m_next = m.max(axis=0)
correction = jnp.exp(m - m_next[None])
o = o * correction[:, :, None]
l_next = (l * correction).sum(axis=0)
o = o.sum(axis=0) / l_next[:, None]

if original_num_heads is not None:
o = o[:original_num_heads, :]
return o
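

# Illustrative check (not part of the op): the split-k combine in the
# "final round of flash" above is exact. Computing partial (unscaled o,
# normalizer l, row max m) over two halves of the key sequence and merging
# them reproduces the softmax over the full sequence. The helper name is
# hypothetical; it assumes sm_scale == 1 and the module-level jax/jnp imports.
def _split_k_combine_sketch(q, k, v):
  # q: [num_heads, head_dim], k/v: [k_seq_len, head_dim], k_seq_len even.
  def partials(kb, vb):
    qk = (q @ kb.T).astype(jnp.float32)
    m = qk.max(axis=-1)
    s = jnp.exp(qk - m[:, None])
    return s @ vb, s.sum(axis=-1), m  # unscaled o, normalizer l, row max m

  half = k.shape[0] // 2
  o0, l0, m0 = partials(k[:half], v[:half])
  o1, l1, m1 = partials(k[half:], v[half:])
  m_max = jnp.maximum(m0, m1)
  c0, c1 = jnp.exp(m0 - m_max), jnp.exp(m1 - m_max)
  num = c0[:, None] * o0 + c1[:, None] * o1
  combined = num / (c0 * l0 + c1 * l1)[:, None]
  full = jax.nn.softmax((q @ k.T).astype(jnp.float32), axis=-1) @ v
  return jnp.allclose(combined, full, atol=1e-3)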


@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
],
)
def mqa(
q, # [batch_size, num_heads, head_dim]
k, # [batch_size, k_seq_len, head_dim]
v, # [batch_size, k_seq_len, head_dim]
sm_scale: float = 1.0,
block_h: int = 16,
block_k: int = 256,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
):
inner = functools.partial(
attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
)
return jax.vmap(inner)(q, k, v)


@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
],
)
def gqa(
q, # [batch_size, num_q_heads, head_dim]
k, # [batch_size, k_seq_len, num_kv_heads, head_dim]
v, # [batch_size, k_seq_len, num_kv_heads, head_dim]
sm_scale: float = 1.0,
block_h: int = 16,
block_k: int = 256,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
):
batch_size, q_heads, head_dim = q.shape
kv_heads = k.shape[2]
assert kv_heads == v.shape[2]
assert q_heads % kv_heads == 0
q_heads_per_kv_head = q_heads // kv_heads
q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
inner = functools.partial(
attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
)
with_kv_heads = jax.vmap(inner)
o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed)
return o.reshape(batch_size, q_heads, head_dim)
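

# Illustrative note (not part of the op): the grouped reshape + double vmap
# above is equivalent to repeating each kv head across its group of query
# heads and running ordinary multi-head attention. The hypothetical helper
# below documents that equivalence via mha_reference (defined further down).
def _gqa_as_repeated_kv_sketch(q, k, v, sm_scale=1.0):
  # q: [bs, num_q_heads, head_dim], k/v: [bs, k_seq_len, num_kv_heads, head_dim]
  group = q.shape[1] // k.shape[2]
  k_rep = jnp.repeat(k, group, axis=2)  # [bs, k_seq_len, num_q_heads, head_dim]
  v_rep = jnp.repeat(v, group, axis=2)
  return mha_reference(q, k_rep, v_rep, sm_scale=sm_scale)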


@functools.partial(jax.jit, static_argnames=["sm_scale"])
def mqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, head_dim]
v, # [bs, k_seq_len, head_dim]
sm_scale=1.0,
):
logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsd->bnd", weights, v)


@functools.partial(jax.jit, static_argnames=["sm_scale"])
def mha_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
sm_scale=1.0,
):
assert q.shape[1] == k.shape[2]
logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsnd->bnd", weights, v)


@functools.partial(jax.jit, static_argnames=["sm_scale"])
def gqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
sm_scale=1.0,
):
bs, num_q_heads, head_dim = q.shape
num_kv_heads = k.shape[2]
assert num_q_heads % num_kv_heads == 0
q_reshaped = q.reshape(
bs, num_kv_heads, num_q_heads // num_kv_heads, head_dim
)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
jnp.float32
)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
return o.reshape(bs, num_q_heads, head_dim)
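

# Example usage (illustrative, not part of the commit). The shapes below meet
# the kernel's constraints: k_seq_len divisible by k_splits (default 16) with
# at least 16 keys per split, and head_dim >= 16. interpret=True runs the
# kernels in Pallas interpreter mode so the comparison can run without a GPU.
if __name__ == "__main__":
  q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
  batch, q_heads, kv_heads, k_seq_len, head_dim = 2, 8, 2, 1024, 64

  # Multi-query attention: one set of k/v shared by all query heads.
  q = jax.random.normal(q_key, (batch, q_heads, head_dim), dtype=jnp.float16)
  k = jax.random.normal(k_key, (batch, k_seq_len, head_dim), dtype=jnp.float16)
  v = jax.random.normal(v_key, (batch, k_seq_len, head_dim), dtype=jnp.float16)
  print("mqa max abs err:",
        jnp.abs(mqa(q, k, v, interpret=True) - mqa_reference(q, k, v)).max())

  # Grouped-query attention: q_heads // kv_heads query heads per kv head.
  k = jax.random.normal(
      k_key, (batch, k_seq_len, kv_heads, head_dim), dtype=jnp.float16)
  v = jax.random.normal(
      v_key, (batch, k_seq_len, kv_heads, head_dim), dtype=jnp.float16)
  print("gqa max abs err:",
        jnp.abs(gqa(q, k, v, interpret=True) - gqa_reference(q, k, v)).max())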
