Skip to content

Commit

Permalink
sync ort-genai changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Jun 19, 2024
1 parent 4bfcfad commit 05c0ba2
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
12 changes: 8 additions & 4 deletions operators/cuda/attention_lib/flash_attention/flash_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params,
bool is_bf16,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
int window_size_right = -1,
bool paged_KV = false,
int page_block_size = -1) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
Expand Down Expand Up @@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params,

if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
} else {
params.q_batch_stride = 0;
Expand Down Expand Up @@ -401,7 +403,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_bf16,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
is_causal ? 0 : -1,
block_table != nullptr,
page_block_size);
params.dprops = &dprops;

if (k_new != nullptr && v_new != nullptr) {
Expand Down
1 change: 1 addition & 0 deletions operators/cuda/paged_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ struct PagedAttention {
int seqlen_knew = 1; // TODO(leca): Decoding case, the sequence of k will always be 1?
int max_num_blocks_per_seq = block_tables.Shape()[1];
int seqlen_k = max_num_blocks_per_seq * block_size;
parameters.causal = false; // flash code: if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
size_t workSpaceSize = cuda::GetAttentionWorkspaceSize(sizeof(T), parameters.batch_size, parameters.num_heads, parameters.head_size, parameters.v_head_size,
seqlen_knew, nullptr, true/*data.use_flash_attention*/, false/*data.use_memory_efficient_attention*/, true);
UniquePtrWithDeletor<T> workspace_unique = GetScratchBuffer<T>(allocator_->Alloc(allocator_.get(), workSpaceSize), allocator_.get()); // for softmax_lse
Expand Down
1 change: 0 additions & 1 deletion operators/cuda/paged_attention_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#pragma once
#include "ortx_common.h"
#include "gsl/narrow"
#include <cuda.h>
#include <cublas_v2.h>

Expand Down
76 changes: 67 additions & 9 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _create_pagedattention_test_model(batch_size, total_seqlen, hidden_size, slo
block_tables = helper.make_tensor_value_info(
'block_tables', onnx_proto.TensorProto.INT32, [batch_size, block_cnt_needed_by_longest_seq])
slot_mappings = helper.make_tensor_value_info(
'slot_mappings', onnx_proto.TensorProto.INT32, [total_seqlen])
'slot_mappings', onnx_proto.TensorProto.INT32, [None])
context_lens = helper.make_tensor_value_info(
'context_lens', onnx_proto.TensorProto.INT32, [batch_size])
is_prompt = helper.make_tensor_value_info(
Expand Down Expand Up @@ -340,14 +340,14 @@ def test_cuda_paged_attention3(self):
out_np = out.reshape(381, 512).numpy()
assert np.allclose(y_np, out_np, rtol=1e-3, atol=1e-3, equal_nan=True)

def test_cuda_paged_attention_prompt_decoding(self):
def test_cuda_paged_attention_prompt_decoding():
so = _ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
onnx_model = self._create_pagedattention_test_model(3, 381, 512, 16, 32, 8)
so.register_custom_ops_library('/home/leca/code/onnxruntime-genai/test/custom_ops/build/libgenai_custom_ops_test.so')
onnx_model = _create_pagedattention_test_model(3, 381, 512, 16, 32, 8)
sess = _ort.InferenceSession(onnx_model.SerializeToString(),
so,
providers=['CUDAExecutionProvider'])

query = np.random.randn(381,512).astype(np.float16) # 381 is the token num of all the sequences (127, 127, 127)
key = np.random.randn(381,512).astype(np.float16)
value = np.random.randn(381,512).astype(np.float16)
Expand All @@ -360,14 +360,14 @@ def test_cuda_paged_attention_prompt_decoding(self):
slot_mappings = np.concatenate((slot1, slot2, slot3))
context_lens = np.array([127, 127, 127]).astype(np.int32)
is_prompt = np.array([1]).astype(np.int32)

key_cache_ort = _ort.OrtValue.ortvalue_from_numpy(key_cache, "cuda")
value_cache_ort = _ort.OrtValue.ortvalue_from_numpy(value_cache, "cuda")
block_tables_ort = _ort.OrtValue.ortvalue_from_numpy(block_tables, "cuda")
slot_mappings_ort = _ort.OrtValue.ortvalue_from_numpy(slot_mappings, "cuda")
context_lens_ort = _ort.OrtValue.ortvalue_from_numpy(context_lens)
is_prompt_ort = _ort.OrtValue.ortvalue_from_numpy(is_prompt)

# prompt case
io_binding = sess.io_binding()
io_binding.bind_cpu_input("query", query)
Expand All @@ -381,18 +381,76 @@ def test_cuda_paged_attention_prompt_decoding(self):
io_binding.bind_ortvalue_input("is_prompt", is_prompt_ort)
io_binding.bind_output("attn_out")
sess.run_with_iobinding(io_binding)

# decoding case
query2 = np.random.randn(3, 512).astype(np.float16)
key2 = np.random.randn(3, 512).astype(np.float16)
value2 = np.random.randn(3, 512).astype(np.float16)
slot = np.array([127, 255, 383]).astype(np.int32)
io_binding.bind_cpu_input("query", query2)
io_binding.bind_cpu_input("key", key2)
io_binding.bind_cpu_input("value", value2)
io_binding.bind_cpu_input("slot_mappings", slot)
context_lens_ort.update_inplace(np.array([1,1,1]).astype(np.int32))
is_prompt_ort.update_inplace(np.array([0]).astype(np.int32))
sess.run_with_iobinding(io_binding)


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    """Build a paged KV cache plus the contiguous k/v views it represents.

    Over-allocates 3x the minimum number of blocks and assigns them to
    sequences via a random permutation, so the block table exercises a
    non-contiguous physical layout of the paged buffers.

    Args:
        seqlen_k: logical key/value sequence length per batch entry.
        paged_kv_block_size: token slots held by one physical block.
        batch_size: number of sequences.
        nheads_k: number of KV heads.
        d: head dimension.
        device, dtype: torch placement and element type of the caches.

    Returns:
        Tuple (k_cache, v_cache, block_table, k_cache_paged, v_cache_paged,
        num_blocks) where k_cache/v_cache are the contiguous
        [batch, seqlen_k, nheads_k, d] views gathered from the paged buffers
        through block_table.
    """
    # Fix: removed a leftover `pdb.set_trace()` breakpoint that stopped every run here.
    # x3 over-allocation leaves unused blocks, mimicking a partially filled pool.
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    # Random permutation of all blocks, reshaped to one row of block ids per sequence.
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    # Gather each sequence's blocks back into a contiguous view and trim the
    # over-allocated tail to exactly seqlen_k tokens.
    k_cache = rearrange(
        # pytorch 1.12 doesn't have indexing with int32
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks

def test_cuda_paged_attention_decoding():
    """Decoding-only PagedAttention smoke test: one new token per sequence.

    Two sequences, each holding a single cached token, append one decoded
    token through the custom PagedAttention op and run the session end to end
    on the CUDA execution provider.
    """
    opts = _ort.SessionOptions()
    opts.register_custom_ops_library('/home/leca/code/onnxruntime-genai/test/custom_ops/build/libgenai_custom_ops_test.so')
    model = _create_pagedattention_test_model(batch_size=2, total_seqlen=0, hidden_size=96, slot_cnt_per_block=256,
                                              block_cnt_per_layer=6, block_cnt_needed_by_longest_seq=3,
                                              num_heads=6, num_kv_heads=6, head_size=16)
    session = _ort.InferenceSession(model.SerializeToString(), opts, providers=['CUDAExecutionProvider'])

    # One token per sequence (batch of 2), hidden size 96 = 6 heads x 16.
    feeds = {
        'query': np.random.randn(2, 96).astype(np.float16),
        'key': np.random.randn(2, 96).astype(np.float16),
        'value': np.random.randn(2, 96).astype(np.float16),
        'key_cache': np.zeros([6, 24576]).astype(np.float16),    # 24576 = 256x6x16
        'value_cache': np.zeros([6, 24576]).astype(np.float16),
        'block_tables': np.array([[0, 1, 2], [3, 4, 5]]).astype(np.int32),
        'slot_mappings': np.array([250, 500]).astype(np.int32),
        'context_lens': np.array([1, 1]).astype(np.int32),
        'is_prompt': np.array([0]).astype(np.int32),
    }
    y = session.run(None, feeds)
    # TODO(review): no numerical verification yet — the attention_ref comparison
    # previously sketched here was copied from the prompt test and used its
    # (381, 512) shapes, which do not apply to this (2, 96) decoding case.

# NOTE(review): unittest.main() only discovers test methods defined on
# unittest.TestCase subclasses; the module-level functions added in this file
# (e.g. test_cuda_paged_attention_decoding) will NOT be executed by this entry
# point — confirm they are invoked elsewhere or move them into a TestCase.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 05c0ba2

Please sign in to comment.