tl.dot() has too large a precision error #2843

Closed
DouHappy opened this issue Dec 25, 2023 · 2 comments

@DouHappy

import torch
import triton
import triton.language as tl

@triton.jit
def test_kernal(X, Y, O,
                stride_x_m, stride_x_k,
                stride_y_k, stride_y_n,
                stride_o_m, stride_o_n,
                m: tl.constexpr, k: tl.constexpr, n: tl.constexpr):
    # Load the full m x k tile of X and the full k x n tile of Y.
    offset_x = tl.arange(0, m)[:, None] * stride_x_m + tl.arange(0, k)[None, :] * stride_x_k
    x = tl.load(X + offset_x)
    offset_y = tl.arange(0, k)[:, None] * stride_y_k + tl.arange(0, n)[None, :] * stride_y_n
    y = tl.load(Y + offset_y)
    o = tl.dot(x, y)
    # Store the m x n product.
    offset_o = tl.arange(0, m)[:, None] * stride_o_m + tl.arange(0, n)[None, :] * stride_o_n
    tl.store(O + offset_o, o)

def test(m, k, n):
    torch.manual_seed(0)
    x = torch.randn(size=[m, k], dtype=torch.float64, device='cuda')
    y = torch.randn(size=[k, n], dtype=torch.float64, device='cuda')
    o = torch.zeros(size=[m, n], dtype=torch.float32, device='cuda')
    test_kernal[(1,)](x.to(torch.float32), y.to(torch.float32), o,
                      x.stride(0), x.stride(1),
                      y.stride(0), y.stride(1),
                      o.stride(0), o.stride(1),
                      m, k, n)
    result = torch.mm(x, y)
    max_differ = torch.max(o - result)
    max_value_xy = torch.max(x + y)
    print(f"max differ: {max_differ}\nwhile max value in x + y:{max_value_xy}")

test(64, 64, 64)

The code above prints:

max differ: 0.02577356549080534
while max value in x + y:4.6755498368741915

Are there bugs in tl.dot() or in my test code? A precision error this large shouldn't be acceptable.

@DouHappy
Author

import torch
import triton
import triton.language as tl

@triton.jit
def test_kernal(X, Y, O,
                stride_x_m, stride_x_k,
                stride_y_k, stride_y_n,
                stride_o_m, stride_o_n,
                m: tl.constexpr, k: tl.constexpr, n: tl.constexpr):
    offset_x = tl.arange(0, m)[:, None] * stride_x_m + tl.arange(0, k)[None, :] * stride_x_k
    x = tl.load(X + offset_x)
    offset_y = tl.arange(0, k)[:, None] * stride_y_k + tl.arange(0, n)[None, :] * stride_y_n
    y = tl.load(Y + offset_y)

    # Variant 1: broadcast multiply, reduce over the k axis (axis 1).
    # o = tl.sum(x[:, :, None] * y[None, :, :], 1)
    # max differ: 0.01586423528981129

    # Variant 2: tensor-core matrix multiply.
    o = tl.dot(x, y)
    # max differ: 0.01586423528981129

    # Variant 3: broadcast multiply against the transpose, reduce over axis 2.
    # o = tl.sum(x[:, None, :] * tl.trans(y)[None, :, :], 2)
    # max differ: 2.9020518859113054e-06

    offset_o = tl.arange(0, m)[:, None] * stride_o_m + tl.arange(0, n)[None, :] * stride_o_n
    tl.store(O + offset_o, o)

def test(m, k, n):
    torch.manual_seed(0)
    x = torch.randn(size=[m, k], dtype=torch.float64, device='cuda')
    y = torch.randn(size=[k, n], dtype=torch.float64, device='cuda')
    o = torch.zeros(size=[m, n], dtype=torch.float32, device='cuda')
    test_kernal[(1,)](x.to(torch.float32), y.to(torch.float32), o,
                      x.stride(0), x.stride(1),
                      y.stride(0), y.stride(1),
                      o.stride(0), o.stride(1),
                      m, k, n)
    result = torch.mm(x, y)
    max_differ = torch.max(o - result)
    max_value_xy = torch.max(torch.max(x), torch.max(y))
    print(f"max differ: {max_differ}\nwhile max value in x and y:{max_value_xy}")

test(32, 32, 64)

Weird output. I tried three different ways to compute the dot product, and they give very different results! This is very confusing. Could someone explain this output?

@DouHappy
Author

Setting allow_tf32=False in tl.dot() makes the output normal. Sorry for my ignorance.
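
For future readers, a minimal sketch of the fix (assuming a Triton build of this era where tl.dot still accepts the allow_tf32 keyword; newer releases expose the same control through an input_precision argument instead). The only line that changes in the kernel above is the tl.dot call:

    # By default tl.dot may use TF32 tensor cores on Ampere-class GPUs,
    # which round fp32 inputs to a 10-bit mantissa before multiplying, so
    # an absolute error around 1e-2 for reductions of this size is expected.
    o = tl.dot(x, y, allow_tf32=False)  # force true fp32 multiplies

With TF32 disabled, the error should drop to roughly the 1e-6 scale reported for the transposed tl.sum variant above.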
