
[Inference/opt] Fused KVCache Memcopy #5374

Merged
5 changes: 3 additions & 2 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -301,8 +301,9 @@ def forward(
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
5 changes: 3 additions & 2 deletions colossalai/inference/modeling/models/padding_llama.py
@@ -356,8 +356,9 @@ def forward(
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(
key_states, value_states, k_cache, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
69 changes: 53 additions & 16 deletions colossalai/kernel/triton/kvcache_copy.py
@@ -6,17 +6,26 @@
# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
KV, # K or V
KVCache, # KCache or VCache
K, # K
V, # V
KCache, # KCache
VCache, # VCache
BLOCK_TABLES,
context_lengths,
stride_kt,
stride_kh,
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cachebs,
stride_cached,
stride_vt,
stride_vh,
stride_vd,
stride_cachekb,
stride_cachekh,
stride_cachekbs,
stride_cachekd,
stride_cachevb,
stride_cachevh,
stride_cachevbs,
stride_cachevd,
stride_bts,
stride_btb,
block_size,
@@ -32,37 +41,57 @@ def _copy_to_kvcache_seqlen1_kernel(
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)

k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)

offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
block_id * stride_cachekb
+ cur_kv_head_idx * stride_cachekh
+ offsets_in_last_block * stride_cachekbs
+ offsets_dmodel * stride_cachekd
)
tl.store(KVCache + offsets_kvcache, kv)
offsets_kvcache = (
block_id * stride_cachevb
+ cur_kv_head_idx * stride_cachevh
+ offsets_in_last_block * stride_cachevbs
+ offsets_dmodel * stride_cachevd
)

tl.store(KCache + offsets_kvcache, k)
tl.store(VCache + offsets_kvcache, v)
return


def copy_kv_to_blocked_cache(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
kv_lengths: torch.Tensor,
block_tables: torch.Tensor,
):
"""
    Copy keys and values to the blocked key/value cache during the decoding stage.

Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.
v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."

k = k.squeeze(1) if k.dim() == 4 else k
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"

assert v.size(-1) == v_cache.size(-1), "Incompatible head dim"
assert v.dtype == v_cache.dtype, "Expected consistent dtype for tensor and cache."
v = v.squeeze(1) if v.dim() == 4 else v
assert v.dim() == 3, f"Incompatible v dim {v.dim()}"

bsz, num_kv_heads, head_dim = k.shape

assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
@@ -75,20 +104,28 @@ def copy_kv_to_blocked_cache(
block_size = k_cache.size(-2)

num_warps = 8 if head_dim > 128 else 4

grid = (bsz, num_kv_heads)
_copy_to_kvcache_seqlen1_kernel[grid](
k,
v,
k_cache,
v_cache,
block_tables,
kv_lengths,
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
block_size,
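For reference, below is a minimal usage sketch of the fused copy. The import path and the example sizes are assumptions made for illustration (they are not taken from this diff), and a CUDA device with Triton installed is required.

import torch
from colossalai.kernel.triton import copy_kv_to_blocked_cache  # assumed export path

bsz, num_kv_heads, head_dim = 4, 8, 128         # hypothetical sizes
block_size, max_blocks_per_seq = 16, 32
num_blocks = bsz * max_blocks_per_seq

# New key/value token per sequence at this decoding step: [bsz, 1, num_kv_heads, head_dim]
new_k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
new_v = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Blocked caches: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# Per-sequence kv lengths (past length + the current token) and block tables
kv_lengths = torch.full((bsz,), 17, dtype=torch.int32, device="cuda")
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(bsz, max_blocks_per_seq)

# One launch now writes both k and v into their caches (previously two separate calls)
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_lengths=kv_lengths, block_tables=block_tables)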
28 changes: 17 additions & 11 deletions tests/test_infer/test_ops/triton/test_kvcache_copy.py
@@ -44,18 +44,19 @@ def prepare_data(
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

k_cache, _, block_tables = generate_caches_and_block_tables_v2(
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=dtype, device=device
)
block_tables = block_tables.to(device=device)

new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
new_v = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
# kv seq len = past kv seq len + seq len (1 during decoding stage)
kv_seq_lengths = past_kv_seq_lengths + 1

return new_k, k_cache, kv_seq_lengths, block_tables
return new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@@ -80,7 +81,7 @@ def test_copy_kv_to_caches(
dtype = torch.float16
device = get_current_device()

new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
@@ -93,16 +94,20 @@
)
# k_cache_torch = k_cache.clone().detach()
# copy_to_cache(new_k, k_cache_torch, lengths=kv_seq_lengths, block_tables=block_tables, type="decoding")
copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)
copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)

past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
target = k_cache[target_block_ids, :, offsets_in_block, :]
source = new_k.squeeze()

assert target.shape == source.shape
assert torch.equal(target, source)
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
k_source = new_k.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
v_source = new_v.squeeze()

assert k_target.shape == k_source.shape
assert torch.equal(k_target, k_source)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
# assert target_torch.shape == source.shape
# assert torch.equal(target_torch, source)
@@ -143,7 +148,7 @@ def benchmark_kvcache_copy(

assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"

new_k, k_cache, context_lengths, block_tables = prepare_data(
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
HEAD_DIM,
Expand All @@ -156,10 +161,11 @@ def benchmark_kvcache_copy(
)

quantiles = [0.5, 0.2, 0.8]
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
if provider == "torch_copy_func":
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
if provider == "triton_copy_func":
fn = lambda: copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)

ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
return ms, min_ms, max_ms
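For readers unfamiliar with the blocked-cache layout, here is a pure-PyTorch sketch of what the fused copy does for a single decoding step. It only illustrates the indexing (the same indexing the test above uses for verification) and is not the Triton implementation; the function name and signature are invented for this sketch.

import torch

def blocked_copy_reference(new_k, new_v, k_cache, v_cache, kv_lengths, block_tables, block_size):
    # new_k / new_v: [bsz, 1, num_kv_heads, head_dim]; caches: [num_blocks, num_kv_heads, block_size, head_dim]
    past_len = kv_lengths - 1                                      # index of the newly generated token
    seq_ids = torch.arange(block_tables.size(0), device=block_tables.device)
    block_ids = block_tables[seq_ids, past_len // block_size]      # physical block holding that token
    offsets = past_len % block_size                                 # slot inside that block
    k_cache[block_ids, :, offsets, :] = new_k.squeeze(1)
    v_cache[block_ids, :, offsets, :] = new_v.squeeze(1)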