Skip to content

Commit

Permalink
Optim cinn block_reduce (PaddlePaddle#58196)
Browse files Browse the repository at this point in the history
* optim cinn block_reduce

* fix bugs

* simplify code

* replace tree reduce with butterfly reduce when active mask is 0x1f

* fix bugs

* fix sync bugs

* remove shared memory when blockdim less than 32
  • Loading branch information
Courtesy-Xs authored and jiahy0825 committed Oct 26, 2023
1 parent 37c061c commit a89c6e5
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,11 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left
tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \
return tmp_val; \
} else { \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 16, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 8, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 4, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 2, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 1, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 8, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \
tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \
return tmp_val; \
} \
}
Expand Down Expand Up @@ -530,25 +530,22 @@ __device__ inline float cinn_warp_reduce_avg_fp32(const float *buf, int offset,

// Function-body macro for a block-wide reduction along threadIdx.x.
//
// Flow: every warp reduces its own `value` with `cinn_warp_shuffle_internal`;
// when blockDim.x <= 32 that result is returned directly. Otherwise the first
// lane of each warp stores its partial in the shared array `tmp`, warp 0
// reduces those partials, and every thread returns tmp[0].
//
// Parameters:
//   TYPE                        element type of the reduction and of `tmp`.
//   value                       this thread's input to the reduction.
//   init_value                  value warp 0 writes into all 32 slots of `tmp`
//                               before partials are stored (presumably the
//                               identity of the reduction -- TODO confirm).
//   cinn_warp_shuffle_internal  callable applied to a TYPE value to perform the
//                               warp-level reduction.
//
// The expansion contains `return` statements, so it has to be the body of a
// function whose return type accepts TYPE. `tmp` has 32 entries and is indexed
// by threadIdx.x / 32, so blockDim.x must not exceed 1024.
//
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// lost, so the pre-change and post-change lines both appear below (`tmp` is
// declared twice, the lane-0 test appears twice). The in-body comments mark
// which group each duplicate seems to belong to -- confirm against the
// repository before compiling.
// NOTE(review): nothing separates the final read of tmp[0] from the
// `init_value` store a following expansion performs on tmp[0]; if this body is
// executed repeatedly inside one kernel that may race -- TODO confirm.
#define CINN_BLOCK_REDUCE_INTERNAL_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \
int warp_id = threadIdx.x / 32; \
/* Presumably the pre-change variant: scratch array declared and initialised before the warp reduce. */ \
__shared__ TYPE tmp[32]; \
if (warp_id == 0) { \
tmp[threadIdx.x] = init_value; \
} \
/* Warp-level reduction of this thread's value. */ \
TYPE tmp_val = cinn_warp_shuffle_internal(value); \
/* At most 32 threads along x: the warp-level result is returned as is, before any barrier. */ \
if (blockDim.x <= 32) { \
return tmp_val; \
} \
/* Presumably the post-change variant: scratch array declared only after the early return. */ \
__shared__ TYPE tmp[32]; \
if (warp_id == 0) { \
tmp[threadIdx.x] = init_value; \
} \
/* Barrier between the init_value stores and the per-warp partial stores below. */ \
__syncthreads(); \
/* Two spellings of the same "first lane of the warp" test; the `%` form is presumably the old line. */ \
if (threadIdx.x % 32 == 0) { \
if ((threadIdx.x & 31) == 0) { \
tmp[warp_id] = tmp_val; \
} \
/* Barrier so warp 0 reads the partials written by the other warps. */ \
__syncthreads(); \
if (warp_id == 0) { \
tmp_val = tmp[threadIdx.x]; \
/* Presumably the old form: reduce, then only thread 0 stores the result. */ \
tmp_val = cinn_warp_shuffle_internal(tmp_val); \
if (threadIdx.x == 0) { \
tmp[0] = tmp_val; \
} \
/* Presumably the new form: every warp-0 thread stores the reduced value to its own slot. */ \
tmp[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val); \
} \
/* Barrier before all threads read the final result from slot 0. */ \
__syncthreads(); \
return tmp[0];
Expand Down

0 comments on commit a89c6e5

Please sign in to comment.