Add layout control
morgandu committed May 23, 2024
1 parent af36df3 commit 2c3ecf8
Showing 4 changed files with 113 additions and 78 deletions.
4 changes: 4 additions & 0 deletions MaxText/common_types.py
@@ -36,6 +36,10 @@
LENGTH = "activation_length"
HEAD = "activation_heads"
D_KV = "activation_kv"
CACHE_BATCH = "cache_batch"
CACHE_SEQUENCE = "cache_sequence"
CACHE_HEADS = "cache_heads"
CACHE_KV = "cache_kv"

MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
MODEL_MODE_PREFILL = "prefill"
8 changes: 8 additions & 0 deletions MaxText/configs/base.yml
@@ -281,6 +281,14 @@ inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""

# KV cache layout control
# Logical axis order 0,1,2,3 corresponds to CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV.
# The default physical layout 1,2,0,3 stores the cache as CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV.
prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"

# Checkpoint Structured logging
enable_checkpoint_cloud_logger: False
enable_checkpoint_standard_logger: False
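
As a hedged illustration (not part of this commit) of how the four *_axis_order strings above are interpreted: each value permutes the logical KV-cache layout (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV) into the physical layout used to allocate the cache, mirroring the cached_kv_layout/cached_kv_shape helpers added in attentions.py below. Names such as parse_axis_order and the sizes are illustrative only.

LOGICAL_LAYOUT = ("cache_batch", "cache_sequence", "cache_heads", "cache_kv")

def parse_axis_order(value: str) -> tuple[int, ...]:
  # "1,2,0,3" -> (1, 2, 0, 3), the same split-and-int parsing used in attentions.py
  return tuple(int(i) for i in value.split(","))

def physical_layout(axis_order: tuple[int, ...]) -> tuple[str, ...]:
  # axis_order[i] names the logical axis stored at physical position i
  return tuple(LOGICAL_LAYOUT[i] for i in axis_order)

order = parse_axis_order("1,2,0,3")
assert physical_layout(order) == ("cache_sequence", "cache_heads", "cache_batch", "cache_kv")

# With illustrative sizes (batch=8, length=1024, heads=16, head_dim=128),
# the physical cache shape under the default order becomes (1024, 16, 8, 128).
logical_shape = (8, 1024, 16, 128)
print(tuple(logical_shape[i] for i in order))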
175 changes: 99 additions & 76 deletions MaxText/layers/attentions.py
@@ -52,6 +52,10 @@
LENGTH = common_types.LENGTH
HEAD = common_types.HEAD
D_KV = common_types.D_KV
CACHE_BATCH = common_types.CACHE_BATCH
CACHE_SEQUENCE = common_types.CACHE_SEQUENCE
CACHE_HEADS = common_types.CACHE_HEADS
CACHE_KV = common_types.CACHE_KV
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)


@@ -104,6 +108,11 @@ class AttentionOp(nn.Module):
max_prefill_predict_length: int = -1
float32_logits: bool = False
flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
kv_cache_logical_layout: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
prefill_key_axis_order: tuple[int, ...] = (1, 2, 0, 3)
prefill_value_axis_order: tuple[int, ...] = (1, 2, 0, 3)
ar_key_axis_order: tuple[int, ...] = (1, 2, 0, 3)
ar_value_axis_order: tuple[int, ...] = (1, 2, 0, 3)
dropout_rate: float = 0.0
dtype: DType = jnp.float32
quant: Optional[Quant] = None
@@ -355,7 +364,7 @@ def wv_product(self, attn_weights: Array, value: Array) -> Array:
result = jnp.reshape(out, (b, t, n_kv * g, d))
return result

def revert_kvlen_axis(self, kv):
def revert_kvlen_axis(self, kv, cached_axis_order):
"""Revert key/value length axis.
Args:
@@ -364,9 +373,9 @@ def revert_kvlen_axis(self, kv):
Returns:
reshaped kv as [b, ..., s, n, d]
"""
return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3))
return jax.numpy.moveaxis(kv, (0, 1, 2, 3), cached_axis_order)

def move_kvlen_axis(self, kv):
def move_kvlen_axis(self, kv, cached_axis_order):
"""Move key/value length axis to the end.
Args:
@@ -375,9 +384,14 @@ def move_kvlen_axis(self, kv):
Returns:
reshaped kv as [b, ..., n, d, s]
"""
return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3))
# cached_axis_order[i] is the logical axis stored at physical position i; invert
# that mapping to get, for each logical axis, its physical destination.
axis_order_to_index_mapping = {a: i for i, a in enumerate(cached_axis_order)}
axis_destination = tuple([i for a, i in sorted(axis_order_to_index_mapping.items())])
return jax.numpy.moveaxis(kv, (0, 1, 2, 3), axis_destination)

def cached_kv_shape(self, kv_shape):
def cached_kv_layout(self, kv_layout, cached_axis_order):
return tuple([kv_layout[i] for i in cached_axis_order])

def cached_kv_shape(self, kv_shape, cached_axis_order):
"""Cached KV shape.
The key and value have dimension [batch, length, num_heads, head_dim], but
@@ -389,55 +403,60 @@ def cached_kv_shape(self, kv_shape):
Returns:
Swapped kv_shape as [b, ..., n, d, s] for cache.
"""
return (kv_shape[1], kv_shape[2], kv_shape[0], kv_shape[3])
return tuple([kv_shape[i] for i in cached_axis_order])
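
A minimal, self-contained sketch (assumed sizes, not from this commit) of the axis bookkeeping in move_kvlen_axis and revert_kvlen_axis: the former permutes a logical [batch, length, heads, head_dim] array into the physical layout selected by cached_axis_order, and the latter inverts that permutation.

import jax.numpy as jnp

def destination_axes(cached_axis_order):
  # cached_axis_order[i] is the logical axis stored at physical position i; invert
  # it to get each logical axis's physical destination, as move_kvlen_axis does.
  mapping = {a: i for i, a in enumerate(cached_axis_order)}
  return tuple(i for _, i in sorted(mapping.items()))

order = (1, 2, 0, 3)                                  # default axis order
x = jnp.zeros((2, 16, 4, 8))                          # logical [batch, length, heads, head_dim]
cached = jnp.moveaxis(x, (0, 1, 2, 3), destination_axes(order))
assert cached.shape == (16, 4, 2, 8)                  # physical [length, heads, batch, head_dim]
reverted = jnp.moveaxis(cached, (0, 1, 2, 3), order)  # what revert_kvlen_axis does
assert reverted.shape == x.shape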

def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16

kv_cache_layout = (
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
)
cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)

key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)

key_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_key_axis_order)
value_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_value_axis_order)

cached_key = self.variable(
"cache",
"cached_prefill_key",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
nn.with_logical_partitioning(jnp.zeros, key_layout),
key_shape,
dtype,
)
cached_value = self.variable(
"cache",
"cached_prefill_value",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
nn.with_logical_partitioning(jnp.zeros, value_layout),
value_shape,
dtype,
)
cached_segment_id = self.variable(
"cache",
"cache_prefill_segment_id",
nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")),
nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
(cache_logical_shape[0], self.max_prefill_predict_length),
jnp.int32,
)

if self.quantize_kvcache:

cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1)

key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_key_axis_order)
value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_value_axis_order)

cached_key_scale_var = self.variable(
"cache",
"cached_prefill_key_scale",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape_scale),
nn.with_logical_partitioning(jnp.zeros, key_layout),
key_shape_scale,
jnp.bfloat16,
)
cached_value_scale_var = self.variable(
"cache",
"cached_prefill_value_scale",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape_scale),
nn.with_logical_partitioning(jnp.zeros, value_layout),
value_shape_scale,
jnp.bfloat16,
)
else:
Expand All @@ -451,71 +470,67 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):
dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16
cache_length = self.max_target_length - self.max_prefill_predict_length
kv_cache_layout = (
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
)

cache_logical_shape = (batch, cache_length, heads, kv_head_size)

key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)

key_shape = self.cached_kv_shape(cache_logical_shape, self.ar_key_axis_order)
value_shape = self.cached_kv_shape(cache_logical_shape, self.ar_value_axis_order)

# TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
cached_key = self.variable(
"cache",
"cached_ar_key",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
nn.with_logical_partitioning(jnp.zeros, key_layout),
key_shape,
dtype,
)
cached_key.value = nn.with_logical_constraint(
cached_key.value,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
key_layout,
)

cached_value = self.variable(
"cache",
"cached_ar_value",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape),
nn.with_logical_partitioning(jnp.zeros, value_layout),
value_shape,
dtype,
)
cached_value.value = nn.with_logical_constraint(
cached_value.value,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
value_layout,
)

cached_segment_id = self.variable(
"cache",
"cache_ar_segment_id",
nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")),
nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
(cache_logical_shape[0], cache_length),
jnp.int32,
)

if self.quantize_kvcache:

cache_logical_shape_scale = (batch, cache_length, heads, 1)

key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_key_axis_order)
value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_value_axis_order)

cached_key_scale_var = self.variable(
"cache",
"cached_ar_key_scale",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape_scale),
nn.with_logical_partitioning(jnp.zeros, key_layout),
key_shape_scale,
jnp.bfloat16,
)
cached_value_scale_var = self.variable(
"cache",
"cached_ar_value_scale",
nn.with_logical_partitioning(jnp.zeros, kv_cache_layout),
self.cached_kv_shape(cache_logical_shape_scale),
nn.with_logical_partitioning(jnp.zeros, value_layout),
value_shape_scale,
jnp.bfloat16,
)
else:
@@ -553,12 +568,15 @@ def kv_cache_prefill(
)
self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now

key_shaped_for_cache = self.move_kvlen_axis(key)
value_shaped_for_cache = self.move_kvlen_axis(value)
prefill_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
prefill_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)

key_shaped_for_cache = self.move_kvlen_axis(key, self.prefill_key_axis_order)
value_shaped_for_cache = self.move_kvlen_axis(value, self.prefill_value_axis_order)

if self.quantize_kvcache:
key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache)
value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache)
key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache, prefill_key_layout.index(CACHE_KV))
value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache, prefill_value_layout.index(CACHE_KV))
cached_prefill_key_var[1].value = key_scale
cached_prefill_value_var[1].value = value_scale
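
The quantize_kv calls above now receive the index of CACHE_KV within the physical layout, so the quantization scale is taken along whichever axis currently holds the head_dim features, regardless of the chosen order. A hedged sketch of that idea (not the MaxText implementation, whose actual scheme lives in quantizations.py):

import jax.numpy as jnp

def quantize_kv_sketch(kv, kv_axis):
  # int8 quantization with the per-group scale computed along kv_axis
  scale = jnp.maximum(jnp.max(jnp.abs(kv), axis=kv_axis, keepdims=True), 1e-6) / 127.0
  return jnp.round(kv / scale).astype(jnp.int8), scale

layout = ("cache_sequence", "cache_heads", "cache_batch", "cache_kv")
kv = jnp.ones((16, 4, 2, 8), dtype=jnp.bfloat16)      # physical layout under the default order
quantized, scale = quantize_kv_sketch(kv, layout.index("cache_kv"))
assert quantized.shape == kv.shape and scale.shape == (16, 4, 2, 1)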

@@ -596,64 +614,65 @@ def update_ar_key_value(

# In order to update the key, value caches with the current key and
# value, we move the length axis to the back
one_token_key = self.move_kvlen_axis(one_token_key)
one_token_value = self.move_kvlen_axis(one_token_value)
one_token_key_shaped_for_cache = self.move_kvlen_axis(one_token_key, self.ar_key_axis_order)
one_token_value_shaped_for_cache = self.move_kvlen_axis(one_token_value, self.ar_value_axis_order)

ar_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
ar_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)

if self.quantize_kvcache:
one_token_key, one_token_key_scale = quantizations.quantize_kv(one_token_key)
one_token_value, one_token_value_scale = quantizations.quantize_kv(one_token_value)
one_token_key_shaped_for_cache, one_token_key_scale = quantizations.quantize_kv(one_token_key_shaped_for_cache, ar_key_layout.index(CACHE_KV))
one_token_value_shaped_for_cache, one_token_value_scale = quantizations.quantize_kv(one_token_value_shaped_for_cache, ar_value_layout.index(CACHE_KV))

one_hot_indices = one_hot_indices.astype(int)

ar_key = cached_key_var.value
ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0)
ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key_shaped_for_cache, jnp.squeeze(one_hot_indices), 0)
ar_key = nn.with_logical_constraint(
ar_key,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
ar_key_layout
)
cached_key_var.value = ar_key

ar_value = cached_value_var.value
ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0)
ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value_shaped_for_cache, jnp.squeeze(one_hot_indices), 0)
ar_value = nn.with_logical_constraint(
ar_value,
(
"cache_sequence",
"cache_heads",
"cache_batch",
"cache_kv",
),
ar_value_layout,
)
cached_value_var.value = ar_value

if self.quantize_kvcache:
ar_key_scale = jax.lax.dynamic_update_index_in_dim(
cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0
)
ar_key_scale = nn.with_logical_constraint(
ar_key_scale,
ar_key_layout
)
ar_value_scale = jax.lax.dynamic_update_index_in_dim(
cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0
)
ar_value_scale = nn.with_logical_constraint(
ar_value_scale,
ar_value_layout
)
cached_key_scale_var.value = ar_key_scale
cached_value_scale_var.value = ar_value_scale

ar_key = quantizations.unquantize_kv(cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype)
ar_value = quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype)

# Move the keys and values back to their original shapes.
return self.revert_kvlen_axis(ar_key), self.revert_kvlen_axis(ar_value)
return self.revert_kvlen_axis(ar_key, self.ar_key_axis_order), self.revert_kvlen_axis(ar_value, self.ar_value_axis_order)
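
For reference, a minimal sketch (assumed shapes, not from this commit) of the jax.lax.dynamic_update_index_in_dim pattern used above: one generated token's key, already permuted into the physical layout, is written at the current decode position along physical axis 0, which holds CACHE_SEQUENCE under the default 1,2,0,3 order.

import jax
import jax.numpy as jnp

cache = jnp.zeros((16, 4, 2, 8))        # AR cache in physical [length, heads, batch, head_dim] layout
one_token = jnp.ones((1, 4, 2, 8))      # a single decode step in the same layout
position = 5                            # decode position within the AR cache
cache = jax.lax.dynamic_update_index_in_dim(cache, one_token, position, 0)
assert bool(cache[5].all()) and not bool(cache[4].any())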

def prefill_cache_var_model_var(self, cache_var, target_dtype):
def prefill_cache_var_model_var(self, cache_var, target_dtype, cache_axis_order):
if not self.quantize_kvcache:
return self.revert_kvlen_axis(cache_var[0].value)
return self.revert_kvlen_axis(cache_var[0].value, cache_axis_order)
else:
raw_cache, quant_scale = cache_var
raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype)
return self.revert_kvlen_axis(raw_cache_unquantized)
return self.revert_kvlen_axis(raw_cache_unquantized, cache_axis_order)

def kv_cache_autoregressive(
self,
@@ -700,8 +719,8 @@ def kv_cache_autoregressive(
)

cached_prefill = (
self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype),
self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype),
self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype, self.prefill_key_axis_order),
self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype, self.prefill_value_axis_order),
cached_prefill_segment_id.value,
)
return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value)
@@ -991,6 +1010,10 @@ def __call__(
num_kv_heads=self.num_kv_heads,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
prefill_key_axis_order=tuple([int(i) for i in self.config.prefill_key_axis_order.split(",")]),
prefill_value_axis_order=tuple([int(i) for i in self.config.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in self.config.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in self.config.ar_value_axis_order.split(",")]),
)

out = attention_op(query, key, value, decoder_segment_ids, model_mode)
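
Hedged usage sketch (values are the defaults from base.yml; not from this commit): each of the four orders is an independent comma-separated string, so the prefill and autoregressive key/value caches can in principle each be given their own physical layout, and the parsing above is plain split-and-int.

config_strings = {
    "prefill_key_axis_order": "1,2,0,3",
    "prefill_value_axis_order": "1,2,0,3",
    "ar_key_axis_order": "1,2,0,3",
    "ar_value_axis_order": "1,2,0,3",
}
parsed = {k: tuple(int(i) for i in v.split(",")) for k, v in config_strings.items()}
assert parsed["ar_key_axis_order"] == (1, 2, 0, 3)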
