tl.dot() has too large a precision error #2843
```python
import torch
import triton
import triton.language as tl


@triton.jit
def test_kernel(X, Y, O,
                stride_x_m, stride_x_k,
                stride_y_k, stride_y_n,
                stride_o_m, stride_o_n,
                m: tl.constexpr, k: tl.constexpr, n: tl.constexpr):
    offset_x = tl.arange(0, m)[:, None] * stride_x_m + tl.arange(0, k)[None, :] * stride_x_k
    x = tl.load(X + offset_x)
    offset_y = tl.arange(0, k)[:, None] * stride_y_k + tl.arange(0, n)[None, :] * stride_y_n
    y = tl.load(Y + offset_y)
    # o = tl.sum(x[:, :, None] * y[None, :, :], 1)
    # max differ: 0.01586423528981129
    o = tl.dot(x, y)
    # max differ: 0.01586423528981129
    # o = tl.sum(x[:, None, :] * tl.trans(y)[None, :, :], 2)
    # max differ: 2.9020518859113054e-06
    offset_o = tl.arange(0, m)[:, None] * stride_o_m + tl.arange(0, n)[None, :] * stride_o_n
    tl.store(O + offset_o, o)


def test(m, k, n):
    torch.manual_seed(0)
    x = torch.randn(size=[m, k], dtype=torch.float64, device='cuda')
    y = torch.randn(size=[k, n], dtype=torch.float64, device='cuda')
    o = torch.zeros(size=[m, n], dtype=torch.float32, device='cuda')
    test_kernel[(1,)](x.to(torch.float32), y.to(torch.float32), o,
                      x.stride(0), x.stride(1),
                      y.stride(0), y.stride(1),
                      o.stride(0), o.stride(1),
                      m, k, n)
    result = torch.mm(x, y)
    max_differ = torch.max(o - result)
    max_value_xy = torch.max(torch.max(x), torch.max(y))
    print(f"max differ: {max_differ}\nwhile max value in x and y: {max_value_xy}")


test(32, 32, 64)
```

Weird output. I tried three different ways to compute the dot product, and they produce very different results (the max differences are recorded in the comments above). This is very confusing. Could someone explain this output? |
Setting `allow_tf32=False` in `tl.dot()` makes the output normal. Sorry for my ignorance. |
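For anyone landing here: with `allow_tf32=True` (the default), `tl.dot()` can use TF32 tensor cores, which round each float32 input down to a 10-bit mantissa before multiplying. That input rounding, not a bug in `tl.dot()`, explains the ~1e-2 error versus the ~3e-6 error of full float32 arithmetic. The NumPy sketch below is a rough CPU-side model of that effect, not Triton's actual tensor-core path: it truncates (rather than rounds) the low 13 mantissa bits to mimic TF32 inputs, then compares both variants against a float64 reference.

```python
import numpy as np

def truncate_to_tf32(a):
    """Rough TF32 model: zero the low 13 mantissa bits of float32.

    TF32 keeps 10 explicit mantissa bits vs float32's 23; real hardware
    rounds to nearest, so this truncation slightly overstates the error.
    """
    bits = a.astype(np.float32).view(np.uint32)
    bits &= np.uint32(0xFFFFE000)  # drop the 13 low mantissa bits
    return bits.view(np.float32)

rng = np.random.default_rng(0)
m, k, n = 32, 32, 64  # same shapes as the test above
x = rng.standard_normal((m, k))
y = rng.standard_normal((k, n))

ref = x @ y  # float64 reference
fp32 = x.astype(np.float32) @ y.astype(np.float32)  # full float32 matmul
tf32 = truncate_to_tf32(x).astype(np.float64) @ truncate_to_tf32(y).astype(np.float64)

fp32_err = np.max(np.abs(fp32 - ref))
tf32_err = np.max(np.abs(tf32 - ref))
print(f"float32 max error:   {fp32_err:.2e}")  # on the order of 1e-6 for k=32
print(f"tf32-like max error: {tf32_err:.2e}")  # orders of magnitude larger
```

The two error scales mirror the issue: the `tl.sum`-based variant that stays in regular float32 matches the ~1e-6 figure, while the TF32-backed `tl.dot` lands near 1e-2.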
The code above prints the max differences noted in the comments. Is there a bug in tl.dot() or in my test code? This precision error seems unacceptable.