Add quantization support for PagedAttention TPU Pallas kernel.
PiperOrigin-RevId: 634914369
ashishenoyp authored and jax authors committed May 17, 2024
1 parent 2d6d408 commit 1043e24
Showing 4 changed files with 303 additions and 37 deletions.
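For orientation, the sketch below shows roughly how the new quantized path might be called. It is an assumption-laden illustration, not part of the commit: the QuantizedTensor constructor arguments, the kernel module name paged_attention_kernel, the concrete shapes, and the trailing scale dimension of 1 are guesses; only the quantization_utils import path, the weight/scales fields, and the paged_attention signature appear in the diff below.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention_kernel  # module name assumed
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils

num_kv_heads, total_pages, page_size, head_dim = 8, 1024, 16, 128
batch_size, num_heads, pages_per_sequence = 4, 32, 64

q = jnp.zeros((batch_size, num_heads, head_dim), jnp.bfloat16)
lengths = jnp.full((batch_size,), 512, jnp.int32)  # tokens per sequence (shape assumed)
page_indices = jnp.zeros((batch_size, pages_per_sequence), jnp.int32)

# int8 KV pages plus per-token scales; the scale layout (trailing dim of 1) is an assumption.
def make_quantized_pages():
  return quantization_utils.QuantizedTensor(
      weight=jnp.zeros((num_kv_heads, total_pages, page_size, head_dim), jnp.int8),
      scales=jnp.ones((num_kv_heads, total_pages, page_size, 1), jnp.float32),
  )

k_pages = make_quantized_pages()
v_pages = make_quantized_pages()

out = paged_attention_kernel.paged_attention(
    q, k_pages, v_pages, lengths, page_indices,
    pages_per_compute_block=4,
)  # -> [batch_size, num_heads, head_dim], in q.dtype

With plain bfloat16 k_pages/v_pages arrays the call is unchanged, since paged_attention now accepts either form.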
@@ -15,12 +15,13 @@
"""PagedAttention TPU kernel."""

import functools
from typing import Optional, Union

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
import jax.numpy as jnp
import numpy as np

@@ -34,43 +35,78 @@ class MultiPageAsyncCopyDescriptor:
def __init__(
self,
pages_hbm_ref,
scales_pages_hbm_ref,
vmem_buffer,
scales_vmem_buffer,
sem,
page_indices,
page_indices_start_offset,
num_pages_to_load,
head_index,
):
self._vmem_buffer = vmem_buffer
self._scales_vmem_buffer = scales_vmem_buffer
self._num_pages_to_load = num_pages_to_load
if head_index is not None:
self._pages_hbm_ref = pages_hbm_ref.at[head_index]
if scales_pages_hbm_ref is not None:
self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[head_index]
else:
self._scales_pages_hbm_ref = None
else:
self._pages_hbm_ref = pages_hbm_ref
self._scales_pages_hbm_ref = scales_pages_hbm_ref
self._sem = sem
self._page_indices = page_indices
self._page_indices_start_offset = page_indices_start_offset
self._async_copies = [
self._make_async_copy(i) for i in range(self._num_pages_to_load)
]
if (
self._scales_pages_hbm_ref is not None
and self._scales_vmem_buffer is not None
):
self._async_copies += [
self._make_scales_async_copy(i)
for i in range(self._num_pages_to_load)
]

def _make_async_copy(self, i):
page_index = self._page_indices[self._page_indices_start_offset + i]
return pltpu.make_async_copy(
self._pages_hbm_ref.at[page_index], self._vmem_buffer.at[i], self._sem
)

def _make_scales_async_copy(self, i):
page_index = self._page_indices[self._page_indices_start_offset + i]
return pltpu.make_async_copy(
self._scales_pages_hbm_ref.at[page_index], # pytype: disable=attribute-error
self._scales_vmem_buffer.at[i], # pytype: disable=attribute-error
self._sem,
)

def start(self):
"""Starts the async copies."""
for async_copy in self._async_copies:
async_copy.start()

def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16):
if x_scale is None:
return x.astype(dtype)
return quantization_utils.from_int8(x, x_scale, dtype=dtype)

def wait_and_get_loaded(self) -> jax.Array:
"""Wait async copies and gets the loaded buffer as a jax.Array."""
for async_copy in self._async_copies:
async_copy.wait()
head_dim = self._vmem_buffer.shape[-1]
jax_array = self._vmem_buffer[...].astype(jnp.float32)
if self._scales_vmem_buffer is not None:
scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32)
else:
scales_jax_array = None
jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
return jax_array.reshape(-1, head_dim)
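(The quantization helpers themselves are not part of this diff. For reference, a from_int8-style dequantization of this shape typically reduces to an element-wise multiply by the per-token scale; the helper below is a sketch of that assumed behavior, not quantization_utils' actual implementation.)

import jax.numpy as jnp

def from_int8_sketch(x_int8, x_scale, dtype=jnp.bfloat16):
  # Upcast the int8 values, apply the scale (broadcast across head_dim),
  # then cast to the requested compute dtype.
  return (x_int8.astype(jnp.float32) * x_scale.astype(jnp.float32)).astype(dtype)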


def paged_flash_attention_kernel(
@@ -80,12 +116,16 @@ def paged_flash_attention_kernel(
step_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
sem,
*,
batch_size: int,
@@ -153,7 +193,11 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index):
pages_to_load = pages_per_compute_block
async_copy_k = MultiPageAsyncCopyDescriptor(
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
k_vmem_buffer.at[buffer_index],
k_scales_vmem_buffer.at[buffer_index]
if k_scales_vmem_buffer is not None
else None,
sem,
page_indices_ref,
page_offset,
@@ -162,7 +206,11 @@ def create_kv_async_copy_descriptors(b, h, i, buffer_index):
)
async_copy_v = MultiPageAsyncCopyDescriptor(
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
v_vmem_buffer.at[buffer_index],
v_scales_vmem_buffer.at[buffer_index]
if v_scales_vmem_buffer is not None
else None,
sem,
page_indices_ref,
page_offset,
@@ -240,12 +288,16 @@ def paged_flash_attention_kernel_inline_seq_dim(
step_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
sem,
*,
batch_size: int,
@@ -270,12 +322,16 @@ def body(i, _):
step_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
sem,
batch_size=batch_size,
pages_per_compute_block=pages_per_compute_block,
@@ -308,8 +364,8 @@ def body(i, _):
)
def paged_attention(
q: jax.Array,
k_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
v_pages: Union[jax.Array, quantization_utils.QuantizedTensor],
lengths: jax.Array,
page_indices: jax.Array,
*,
@@ -333,7 +389,7 @@ def paged_attention(
pages_per_compute_block: how many pages to process in one flash
attention block in the pallas kernel.
megacore_mode: if set, enable megacore to parallelize the computation. Must
be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is
enabled, otherwise the kernel may hang. If you are not sure, leave it as
None.
* None: disable megacore parallelism.
@@ -347,14 +403,31 @@ def paged_attention(
Returns:
The output of attention([batch_size, num_heads, head_dim]).
"""
if isinstance(k_pages, quantization_utils.QuantizedTensor):
k_pages, k_scales_pages = k_pages.weight, k_pages.scales # type: ignore[union-attr]
assert isinstance(k_scales_pages, jax.Array) # For typing.
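# Per-page scales are broadcast up to head_dim so the scale pages share the
# quantized pages' layout and can go through the same async-copy path and
# VMEM scratch shapes in the kernel.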
k_scales_pages = jnp.broadcast_to(
k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1]) # type: ignore[union-attr]
)
else:
k_scales_pages = None
if isinstance(v_pages, quantization_utils.QuantizedTensor):
v_pages, v_scales_pages = v_pages.weight, v_pages.scales # type: ignore[union-attr]
assert isinstance(v_scales_pages, jax.Array) # For typing.
v_scales_pages = jnp.broadcast_to(
v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1]) # type: ignore[union-attr]
)
else:
v_scales_pages = None

batch_size, num_heads, head_dim = q.shape
num_kv_heads, _, page_size, head_dim_k = k_pages.shape # type: ignore[union-attr]
batch_size_paged_indices, pages_per_sequence = page_indices.shape

if k_pages.shape != v_pages.shape: # type: ignore[union-attr]
raise ValueError(
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
f" {v_pages.shape}"
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" # type: ignore[union-attr]
f" {v_pages.shape}" # pytype: disable=attribute-error
)
if num_heads % num_kv_heads != 0:
raise ValueError(
@@ -456,6 +529,85 @@ def paged_attention(
) # type: ignore
dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary") # type: ignore

if k_scales_pages is not None and v_scales_pages is not None:
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
scratch_shapes = (
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_pages.dtype, # type: ignore[union-attr]
), # k_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_scales_pages.dtype, # pytype: disable=attribute-error
), # k_scales_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_pages.dtype, # type: ignore[union-attr]
), # v_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_scales_pages.dtype, # pytype: disable=attribute-error
), # v_scales_pages buffer
pltpu.SemaphoreType.DMA,
)
else:
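# Unquantized path: there are no scale operands. The None placeholders keep
# the inputs aligned with the kernel signature, so the scale refs and scale
# VMEM buffers arrive as None and dequantization is skipped.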
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
None, # type: ignore[list-item]
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
None, # type: ignore[list-item]
]
scratch_shapes = (
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_pages.dtype, # type: ignore[union-attr]
), # k_pages buffer
None,
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_pages.dtype, # type: ignore[union-attr]
), # v_pages buffer
None,
pltpu.SemaphoreType.DMA,
)

out, _, _ = pl.pallas_call(
functools.partial(
kernel,
@@ -466,39 +618,17 @@ def paged_attention(
megacore_mode=megacore_mode,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
# There are 4 scalars prefetched per kernel call: `lengths_ref`,
# `page_indices_ref`, `buffer_index_ref`, `step_ref`
num_scalar_prefetch=4,
in_specs=in_specs,
out_specs=[
q_block_spec,
q_block_spec,
q_block_spec,
],
grid=grid,
scratch_shapes=scratch_shapes,
),
compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
out_shape=[
@@ -513,6 +643,8 @@ def paged_attention(
jnp.zeros((1,), jnp.int32), # step
q.astype(q_dtype_for_kernel_launch),
k_pages,
k_scales_pages,
v_pages,
v_scales_pages,
)
return out.reshape(batch_size, num_heads, head_dim).astype(q.dtype)
