Commit 6507a16
[PyTorch] Fix MHA grain size computation (#72463)
Summary:
Pull Request resolved: pytorch/pytorch#72463

Maxing with 1 makes a lot more sense than minning with 1: grain_size is meant to be a floor of at least 1 on the per-task iteration count, but std::min was instead capping it at 1, forcing maximally fine-grained parallel tasks.
ghstack-source-id: 149067332

Test Plan: CI

Reviewed By: zrphercule

Differential Revision: D33990633

fbshipit-source-id: c706148c357473c929020f5dc65cc5050611af8f
(cherry picked from commit 2adf3be11a59387bbab7fc73da236ab5fff7be9c)
swolchok authored and cyyever committed Feb 16, 2022
1 parent af308ba commit 6507a16
Showing 1 changed file with 1 addition and 1 deletion.
aten/src/ATen/native/attention.cpp (1 addition, 1 deletion):

@@ -42,7 +42,7 @@ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv(
   const scalar_t sqrt_dim_per_head = std::sqrt(static_cast<scalar_t>(dim_per_head));

   int64_t grain_size =
-      std::min(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
+      std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
   parallel_for(
       0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
         for (auto i : c10::irange(begin, end)) {
