diff --git a/aten/src/ATen/native/attention.cpp b/aten/src/ATen/native/attention.cpp index 38db814ceba..26dca7ed4ed 100644 --- a/aten/src/ATen/native/attention.cpp +++ b/aten/src/ATen/native/attention.cpp @@ -42,7 +42,7 @@ std::tuple transform_bias_rescale_qkv( const scalar_t sqrt_dim_per_head = std::sqrt(static_cast(dim_per_head)); int64_t grain_size = - std::min(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1); + std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1); parallel_for( 0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) { for (auto i : c10::irange(begin, end)) {