From 0cdd160cb65bd249ce810dfba34f2c71c4039d0b Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 14 Feb 2022 18:12:50 -0800
Subject: [PATCH] [PyTorch] Fix MHA grain size computation (#72463)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72463

Maxing with 1 makes a lot more sense to me than minning with 1, but I have no idea what I'm doing.

ghstack-source-id: 149067332

Test Plan: CI

Reviewed By: zrphercule

Differential Revision: D33990633

fbshipit-source-id: c706148c357473c929020f5dc65cc5050611af8f
(cherry picked from commit 2adf3be11a59387bbab7fc73da236ab5fff7be9c)
---
 aten/src/ATen/native/attention.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/attention.cpp b/aten/src/ATen/native/attention.cpp
index 38db814ceba..26dca7ed4ed 100644
--- a/aten/src/ATen/native/attention.cpp
+++ b/aten/src/ATen/native/attention.cpp
@@ -42,7 +42,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
   const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));
   int64_t grain_size =
-      std::min(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
+      std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
   parallel_for(
       0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
         for (auto i : c10::irange(begin, end)) {
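
For context, here is a minimal standalone sketch (not part of the patch) of the before/after grain-size computation. It assumes ATen's default at::internal::GRAIN_SIZE of 32768 and a hypothetical dim_per_head of 64; the point is that std::min caps the grain size at 1, so parallel_for is handed single-element chunks, whereas std::max keeps the computed chunk size and only floors it at 1 when 3 * dim_per_head exceeds GRAIN_SIZE.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed values: GRAIN_SIZE matches ATen's default serial grain size,
  // dim_per_head is a hypothetical head dimension chosen for illustration.
  const int64_t GRAIN_SIZE = 32768;
  const int64_t dim_per_head = 64;

  // Before the fix: the computed chunk size is capped at 1, so every
  // element of the B * num_head * T range becomes its own task.
  int64_t old_grain =
      std::min(GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);

  // After the fix: the computed chunk size is kept, and 1 only acts as a
  // floor for the degenerate case where 3 * dim_per_head > GRAIN_SIZE.
  int64_t new_grain =
      std::max(GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);

  std::cout << "old grain size: " << old_grain << "\n";  // prints 1
  std::cout << "new grain size: " << new_grain << "\n";  // prints 170
}

With these assumed values the old expression always evaluates to 1, which defeats the purpose of a grain size; the fixed expression yields 170, so each worker gets a reasonably sized block of (batch, head, token) indices before the parallel overhead is worth paying.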