Skip to content

Commit

Permalink
[Inference/opt] Fused KVCahce Memcopy (#5374)
Browse files Browse the repository at this point in the history
* fused kv memcopy

* add TODO in test_kvcache_copy.py
  • Loading branch information
isky-cd authored Feb 7, 2024
1 parent 58740b5 commit 6fb4bcb
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 31 deletions.
5 changes: 3 additions & 2 deletions colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,9 @@ def forward(
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
Expand Down
5 changes: 3 additions & 2 deletions colossalai/inference/modeling/models/padding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,9 @@ def forward(
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
Expand Down
69 changes: 53 additions & 16 deletions colossalai/kernel/triton/kvcache_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,26 @@
# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
KV, # K or V
KVCache, # KCache or VCache
K, # K
V, # V
KCache, # KCache
VCache, # VCache
BLOCK_TABLES,
context_lengths,
stride_kt,
stride_kh,
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_vt,
stride_vh,
stride_vd,
stride_cachekb,
stride_cachekh,
stride_cachekbs,
stride_cachekd,
stride_cachevb,
stride_cachevh,
stride_cachevbs,
stride_cachevd,
stride_bts,
stride_btb,
block_size,
Expand All @@ -32,37 +41,57 @@ def _copy_to_kvcache_seqlen1_kernel(
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)

k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)

offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
block_id * stride_cachekb
+ cur_kv_head_idx * stride_cachekh
+ offsets_in_last_block * stride_cachekbs
+ offsets_dmodel * stride_cachekd
)
tl.store(KVCache + offsets_kvcache, kv)
offsets_kvcache = (
block_id * stride_cachevb
+ cur_kv_head_idx * stride_cachevh
+ offsets_in_last_block * stride_cachevbs
+ offsets_dmodel * stride_cachevd
)

tl.store(KCache + offsets_kvcache, k)
tl.store(VCache + offsets_kvcache, v)
return


def copy_kv_to_blocked_cache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
kv_lengths: torch.Tensor,
block_tables: torch.Tensor,
):
"""
Copy keys or values to the blocked key/value cache during decoding stage.
Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."

k = k.squeeze(1) if k.dim() == 4 else k
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"

assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
v = v.squeeze(1) if v.dim() == 4 else v
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"

bsz, num_kv_heads, head_dim = k.shape

assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
Expand All @@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
block_size = k_cache.size(-2)

num_warps = 8 if head_dim > 128 else 4

grid = (bsz, num_kv_heads)
_copy_to_kvcache_seqlen1_kernel[grid](
k,
v,
k_cache,
v_cache,
block_tables,
kv_lengths,
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
block_size,
Expand Down
28 changes: 17 additions & 11 deletions tests/test_infer/test_ops/triton/test_kvcache_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ def prepare_data(
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

k_cache, _, block_tables = generate_caches_and_block_tables_v2(
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
block_tables = block_tables.to(device=device)

new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
# kv seq len = past kv seq len + seq len (1 during decoding stage)
kv_seq_lengths = past_kv_seq_lengths + 1

return new_k, k_cache, kv_seq_lengths, block_tables
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
Expand All @@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
dtype = torch.float16
device = get_current_device()

new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
Expand All @@ -93,16 +94,20 @@ def test_copy_kv_to_caches(
)
# k_cache_torch = k_cache.clone().detach()
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)

past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
target = k_cache[target_block_ids, :, offsets_in_block, :]
source = new_k.squeeze()

assert target.shape == source.shape
assert torch.equal(target, source)
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
v_source = new_v.squeeze()

assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
# assert target_torch.shape == source.shape
# assert torch.equal(target_torch, source)
Expand Down Expand Up @@ -143,7 +148,7 @@ def benchmark_kvcache_copy(

assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"

new_k, k_cache, context_lengths, block_tables = prepare_data(
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
Expand All @@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
)

quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
if provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)

ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms
Expand Down

0 comments on commit 6fb4bcb

Please sign in to comment.