diff --git a/csrc/src/block_info.h b/csrc/src/block_info.h index 64ea678..f9b70e8 100644 --- a/csrc/src/block_info.h +++ b/csrc/src/block_info.h @@ -36,14 +36,14 @@ struct BlockInfo { } template - __forceinline__ __device__ index_t zoh_offset(const index_t batch_stride, const int bidb + __forceinline__ __device__ index_t zoh_offset(const index_t batch_stride, const int row_stride, const int bidb ) const { - return bidb * batch_stride; + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template - __forceinline__ __device__ index_t active_mask_offset(const index_t batch_stride, const int bidb) const { - return bidb * batch_stride; + __forceinline__ __device__ index_t active_mask_offset(const index_t batch_stride, int row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } const int sum_s_q;