Skip to content

Commit

Permalink
Fix baddbmm producing NaN when `out` contains NaN values and beta=0 (#96086)
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaobingSuper authored and cyyever committed Mar 12, 2023
1 parent 6857f3e commit 3f500bc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,8 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T
r += s2[k] * m1[k][j];
}
} else {
r *= beta;
// When beta == 0, r's prior value must be ignored entirely (not multiplied),
// so that NaN/Inf values in the output buffer cannot propagate into the result.
r = beta == scalar_t(0) ? scalar_t(0) : beta * r;
for (const auto k : c10::irange(ks)) {
r += alpha * s2[k] * m1[k][j];
}
Expand Down
15 changes: 15 additions & 0 deletions test/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5565,6 +5565,21 @@ def test_addmm_baddbmm_overflow(self, device, dtype):
self.assertTrue((out == 10000.).all())
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig

@dtypes(torch.float)
def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
    """baddbmm with beta=0 must ignore the contents of ``input`` and ``out``.

    With beta == 0 the additive term is mathematically irrelevant, so the
    result must equal bmm(mat1, mat2) even when ``input`` or ``out`` is
    filled with NaN (0 * NaN would otherwise poison the output).
    """
    for shape in [[3, 2, 2], [2, 20, 20]]:
        mat1, mat2 = [torch.randn(shape, dtype=dtype, device=device) for _ in range(2)]
        # Exercise both well-formed and all-NaN additive inputs / out buffers.
        inputs = [torch.randn(shape, dtype=dtype, device=device),
                  torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
        outs = [None, torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
        # The reference depends only on (mat1, mat2); hoist it out of the
        # (input, out) product loop instead of recomputing it 6 times.
        y_ref = torch.bmm(mat1, mat2)
        # `inp` rather than `input` to avoid shadowing the builtin.
        for inp, out in itertools.product(inputs, outs):
            y = torch.baddbmm(inp, mat1, mat2, beta=0.0, out=out)
            self.assertEqual(y_ref, y)


@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
def test_matmul_45724(self, device):
Expand Down

0 comments on commit 3f500bc

Please sign in to comment.