bugfix
yzh119 committed Jan 19, 2024
1 parent 271caa2 commit 551858a
Showing 3 changed files with 67 additions and 49 deletions.
65 changes: 41 additions & 24 deletions include/flashinfer/handler.cuh
@@ -91,25 +91,35 @@ class BatchDecodeHandler {
     auto work_estimation_func =
         BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, DTypeIn, DTypeOut, IdType>;
 
-    SWITCH_DEVICE_PTR(indptr, indptr_h, batch_size + 1, stream_, {
-      SWITCH_DEVICE_PTR(last_page_len, last_page_len_h, batch_size, stream_, {
-        FLASHINFER_CUDA_CALL(work_estimation_func(
-            tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr_h,
-            num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
-        batch_size_after_partition_ = new_batch_size;
-        if (tmp_size > 0) {
-          FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
-          FLASHINFER_CUDA_CALL(cudaMallocAsync(
-              &int_buffer_,
-              sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2), stream_));
-          FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
-              max_num_pages_per_batch, batch_size, page_size, indptr_h, last_page_len_h,
-              GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
-              GetBatchIdxMap<IdType>(), GetChunkStartPos<IdType>(),
-              GetSeqLengthsBeforePartition<IdType>(), stream_));
-        }
-      });
-    });
+    std::vector<IdType> indptr_h(batch_size + 1), last_page_len_h(batch_size);
+    if (is_device_ptr(indptr)) {
+      FLASHINFER_CUDA_CALL(cudaMemcpyAsync(indptr_h.data(), indptr,
+                                           sizeof(IdType) * (batch_size + 1),
+                                           cudaMemcpyDeviceToHost, stream_));
+      FLASHINFER_CUDA_CALL(cudaMemcpyAsync(last_page_len_h.data(), last_page_len,
+                                           sizeof(IdType) * batch_size, cudaMemcpyDeviceToHost,
+                                           stream_));
+      FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream_));
+    } else {
+      indptr_h.assign(indptr, indptr + batch_size + 1);
+      last_page_len_h.assign(last_page_len, last_page_len + batch_size);
+    }
+
+    FLASHINFER_CUDA_CALL(work_estimation_func(
+        tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size,
+        indptr_h.data(), num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
+    batch_size_after_partition_ = new_batch_size;
+    if (tmp_size > 0) {
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(&float_buffer_, tmp_size, stream_));
+      FLASHINFER_CUDA_CALL(cudaMallocAsync(
+          &int_buffer_, sizeof(IdType) * (5 * new_batch_size + batch_size_before_partition_ + 2),
+          stream_));
+      FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
+          max_num_pages_per_batch, batch_size, page_size, indptr_h.data(), last_page_len_h.data(),
+          GetNewIndPtr<IdType>(), GetNewLastPageLen<IdType>(), GetChunkIndPtr<IdType>(),
+          GetBatchIdxMap<IdType>(), GetChunkStartPos<IdType>(),
+          GetSeqLengthsBeforePartition<IdType>(), stream_));
+    }
 
     forward_started_ = true;
     return cudaSuccess;
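
The rewritten path above branches on is_device_ptr, which is referenced by both the deleted macro and the new code but whose definition is outside this diff. For reference, a minimal sketch of how such a helper can be written on top of cudaPointerGetAttributes (an illustrative assumption, not necessarily FlashInfer's exact definition):

    #include <cuda_runtime.h>

    // Sketch: ask the CUDA runtime where a pointer lives. Only device-resident
    // memory returns true; host, managed, and unregistered pointers return false.
    inline bool is_device_ptr(const void* ptr) {
      cudaPointerAttributes attrs;
      if (cudaPointerGetAttributes(&attrs, ptr) != cudaSuccess) {
        cudaGetLastError();  // clear the sticky error older runtimes set for unregistered pointers
        return false;
      }
      return attrs.type == cudaMemoryTypeDevice;
    }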
@@ -183,11 +193,18 @@ class BatchPrefillHandler {
       abort();
     }
     uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-    std::vector<IdType> request_indices_h, tile_indices_h;
-    SWITCH_DEVICE_PTR(qo_indptr, qo_indptr_h, batch_size + 1, stream_, {
-      std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
-          split_qo_indptr(qo_indptr_h, batch_size, gqa_group_size, stream_);
-    });
+    std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices_h, tile_indices_h;
+    if (is_device_ptr(qo_indptr)) {
+      FLASHINFER_CUDA_CALL(cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr,
+                                           sizeof(IdType) * (batch_size + 1),
+                                           cudaMemcpyDeviceToHost, stream_));
+      FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream_));
+    } else {
+      qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1);
+    }
+
+    std::tie(num_frags_x_, num_qo_tiles_, request_indices_h, tile_indices_h) =
+        split_qo_indptr(qo_indptr_h.data(), batch_size, gqa_group_size, stream_);
     FLASHINFER_CUDA_CALL(cudaMalloc(&request_indices_, sizeof(IdType) * request_indices_h.size()));
     FLASHINFER_CUDA_CALL(cudaMalloc(&tile_indices_, sizeof(IdType) * tile_indices_h.size()));
     FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, request_indices_h.data(),
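This commit repeats the same stage-to-host idiom at three call sites: the decode handler above and the two prefill paths below. A hypothetical refactoring, not part of this commit, that names the pattern under the assumption that the is_device_ptr and FLASHINFER_CUDA_CALL utilities from this diff are in scope:

    #include <cuda_runtime.h>
    #include <vector>

    // Hypothetical helper (not in this commit): materialize `length` elements
    // on the host, staging through the stream when the source is a device
    // pointer and falling back to a plain host-side copy otherwise.
    template <typename T>
    cudaError_t CopyToHost(std::vector<T>& host_vec, const T* maybe_device_ptr, size_t length,
                           cudaStream_t stream) {
      host_vec.resize(length);
      if (is_device_ptr(maybe_device_ptr)) {
        FLASHINFER_CUDA_CALL(cudaMemcpyAsync(host_vec.data(), maybe_device_ptr,
                                             sizeof(T) * length, cudaMemcpyDeviceToHost, stream));
        FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));
      } else {
        host_vec.assign(maybe_device_ptr, maybe_device_ptr + length);
      }
      return cudaSuccess;
    }

Note that the decode handler above instead enqueues both of its copies before a single cudaStreamSynchronize, paying one synchronization rather than two.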
36 changes: 26 additions & 10 deletions include/flashinfer/prefill.cuh
@@ -1742,12 +1742,19 @@ cudaError_t BatchPrefillWithRaggedKVCache(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
 
   uint32_t num_frags_x, num_qo_tiles;
-  std::vector<IdType> request_indices_h, tile_indices_h;
+  std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices_h, tile_indices_h;
 
-  SWITCH_DEVICE_PTR(qo_indptr, qo_indptr_h, batch_size + 1, stream, {
-    std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) =
-        split_qo_indptr(qo_indptr_h, batch_size, group_size, stream);
-  });
+  if (is_device_ptr(qo_indptr)) {
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr,
+                                         sizeof(IdType) * (batch_size + 1), cudaMemcpyDeviceToHost,
+                                         stream));
+    FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));
+  } else {
+    qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1);
+  }
+
+  std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) =
+      split_qo_indptr(qo_indptr_h.data(), batch_size, group_size, stream);
   IdType* request_indices_d;
   IdType* tile_indices_d;
 
@@ -1925,11 +1932,20 @@ cudaError_t BatchPrefillWithPagedKVCache(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
 
   uint32_t num_frags_x, num_qo_tiles;
-  std::vector<IdType> request_indices_h, tile_indices_h;
-  SWITCH_DEVICE_PTR(qo_indptr, qo_indptr_h, batch_size + 1, stream, {
-    std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) =
-        split_qo_indptr(qo_indptr_h, batch_size, group_size, stream);
-  });
+  std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices_h, tile_indices_h;
+
+  if (is_device_ptr(qo_indptr)) {
+    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr,
+                                         sizeof(IdType) * (batch_size + 1), cudaMemcpyDeviceToHost,
+                                         stream));
+    FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));
+  } else {
+    qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1);
+  }
+
+  std::tie(num_frags_x, num_qo_tiles, request_indices_h, tile_indices_h) =
+      split_qo_indptr(qo_indptr_h.data(), batch_size, group_size, stream);
+
   IdType* request_indices_d;
   IdType* tile_indices_d;
 
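Both prefill paths use the same staging idiom: an asynchronous device-to-host copy into a pageable std::vector, then cudaStreamSynchronize before the host-side split_qo_indptr reads the data. The synchronize is load-bearing; condensed to its essentials (an illustrative sketch, not code from the repository):

    #include <cuda_runtime.h>
    #include <cstdint>
    #include <vector>

    void StageToHost(const int32_t* d, size_t n, cudaStream_t s) {
      std::vector<int32_t> h(n);
      cudaMemcpyAsync(h.data(), d, sizeof(int32_t) * n, cudaMemcpyDeviceToHost, s);
      // h cannot be read yet: the copy has only been enqueued on stream s.
      cudaStreamSynchronize(s);
      // The copy has now retired; h.data() is safe to read on the host.
    }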
15 changes: 0 additions & 15 deletions include/flashinfer/utils.cuh
@@ -33,21 +33,6 @@
     } \
   }
 
-#define SWITCH_DEVICE_PTR(maybe_device_ptr, host_ptr, length, stream, ...) \
-  using ptr_t = decltype(maybe_device_ptr); \
-  using elem_t = std::remove_pointer_t<ptr_t>; \
-  if (is_device_ptr(maybe_device_ptr)) { \
-    std::vector<elem_t> host_vec(length); \
-    ptr_t host_ptr = host_vec.data(); \
-    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(host_ptr, maybe_device_ptr, sizeof(elem_t) * length, \
-                                         cudaMemcpyDeviceToHost, stream)); \
-    FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); \
-    __VA_ARGS__ \
-  } else { \
-    ptr_t host_ptr = maybe_device_ptr; \
-    __VA_ARGS__ \
-  }
-
 #define SWITCH_SPLIT_QO_INDPTR(split_qo_indptr, SPLIT_QO_INDPTR, ...) \
   if (split_qo_indptr) { \
     constexpr bool SPLIT_QO_INDPTR = true; \
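The commit message says only "bugfix", so the exact defect is not documented here, but the deleted macro has a well-known hazard that the inline rewrites above sidestep: it injects the ptr_t and elem_t aliases directly into the caller's scope, so two expansions in one scope over pointers of different element types fail to compile. An illustrative, hypothetical misuse:

    // Illustration only, not code from the repository.
    cudaError_t Example(float* scores, int* offsets, cudaStream_t stream) {
      SWITCH_DEVICE_PTR(scores, scores_h, 16, stream, { /* use scores_h */ });
      // error: conflicting declaration of 'ptr_t', which the first expansion
      // already defined as 'float*' in this scope
      SWITCH_DEVICE_PTR(offsets, offsets_h, 16, stream, { /* use offsets_h */ });
      return cudaSuccess;
    }

Whether or not this was the bug fixed here, writing the branch out at each call site keeps every declaration properly scoped.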
