[BUG] Using both reduction and atomic operations with autotune produces incorrect results #4217

Closed
Kitsunetic opened this issue Jun 26, 2024 · 2 comments

Kitsunetic commented Jun 26, 2024

Environment

  • Triton version: 2.3.1
  • PyTorch version: 2.3.0
  • CUDA version: 12.1

Issue Description

When reduction operations and atomic operations are used together in a kernel decorated with triton.autotune, the output is incorrect the first time the kernel is called with a new input shape.

Reproduction Code

import torch as th
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 32}),
    ],
    key=["N"],
)
@triton.jit
def vector_sum_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs_n = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs_n, mask=offs_n < N, other=0.0)

    y = tl.sum(x)
    tl.atomic_add(y_ptr, y)


def vector_sum(x):
    assert x.ndim == 1

    N = x.size(0)
    y = x.new_zeros(1)
    grid = lambda meta: ((triton.cdiv(N, meta["BLOCK_SIZE"]),))
    vector_sum_kernel[grid](x, y, N)
    return y

# Test cases
x = th.rand(120, device="cuda")

print(vector_sum(x))  # tensor([79561.5703], device='cuda:6'), incorrect
print(vector_sum(x))  # tensor([62.5476], device='cuda:6'), correct
print(vector_sum(x))  # tensor([62.5476], device='cuda:6'), correct
print(vector_sum(x))  # tensor([62.5476], device='cuda:6'), correct
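To check that a single launch of the kernel is correct on its own, its logic can be emulated on the CPU. The block below is a minimal pure-Python sketch, not Triton code; `simulate_launch` is a hypothetical helper that mirrors the masked `tl.load` (`other=0.0`), the per-block `tl.sum`, and the `tl.atomic_add` into `y` of `vector_sum_kernel`.

```python
import math


def simulate_launch(x, block_size, y):
    """Hypothetical CPU emulation of one launch of vector_sum_kernel.

    Each program id covers BLOCK_SIZE offsets; offsets >= N are masked to 0.0,
    the block is summed (tl.sum), and the partial sum is accumulated into
    y[0] (tl.atomic_add), so y is added to rather than overwritten.
    """
    n = len(x)
    num_programs = math.ceil(n / block_size)  # same as triton.cdiv(N, BLOCK_SIZE)
    for pid in range(num_programs):
        offs = range(pid * block_size, (pid + 1) * block_size)
        block = [x[i] if i < n else 0.0 for i in offs]  # masked load
        y[0] += sum(block)  # reduction + atomic accumulation
    return y


x = [0.5] * 120
for block_size in (256, 128, 64, 32):
    print(block_size, simulate_launch(x, block_size, [0.0])[0])  # 60.0 for each block size
```

Every block size yields the same total when `y` starts at zero. Because the result is accumulated, a launch on a buffer that already holds a value adds to it.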

Conclusion

The first call to vector_sum(x) produces an incorrect result of tensor([79561.5703], device='cuda:6').
Subsequent calls to vector_sum(x) produce correct results.

The same issue occurs with other reduction functions, such as tl.max, and other atomic functions, such as tl.atomic_max.
However, using only one of the two (a reduction or an atomic operation) does not trigger the issue.

Jokeren (Contributor) commented Jun 26, 2024

The conclusion might be wrong. I think it might be a problem in the user code instead of the autotuner. Try reset_to_zero=["y_ptr"].
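The suggestion points at how the autotuner benchmarks: on the first call for a new key, it launches each candidate config on the same arguments, so every benchmark launch's `tl.atomic_add` also lands in `y`. Passing `reset_to_zero=["y_ptr"]` to the decorator (e.g. `@triton.autotune(configs=[...], key=["N"], reset_to_zero=["y_ptr"])`) asks the autotuner to zero that argument before evaluating configs. The pure-Python model below is a simplified sketch of this behavior, not Triton's implementation; the helper names `launch` and `autotuned_call`, the fixed number of benchmark launches, and the choice of winning config are all hypothetical.

```python
import math


def launch(x, block_size, y):
    """Hypothetical CPU stand-in for one launch of vector_sum_kernel.

    Each program sums one masked block and atomically adds it into y[0].
    """
    n = len(x)
    for pid in range(math.ceil(n / block_size)):
        start = pid * block_size
        y[0] += sum(x[start:start + block_size])  # slicing acts as the mask


def autotuned_call(x, y, cache, configs, runs_per_config, reset_to_zero):
    """Simplified, hypothetical model of calling an autotuned kernel.

    Assumptions for illustration only:
    - the tuning key is len(x), mirroring key=["N"];
    - on a cache miss, each config is benchmarked with `runs_per_config`
      launches on the SAME output buffer y (the real count is time-based);
    - with reset_to_zero, y is zeroed before each benchmark launch and once
      more before the final launch;
    - the first config is treated as the winner.
    """
    key = len(x)
    if key not in cache:
        for block_size in configs:
            for _ in range(runs_per_config):
                if reset_to_zero:
                    y[0] = 0.0
                launch(x, block_size, y)
        cache[key] = configs[0]
        if reset_to_zero:
            y[0] = 0.0
    launch(x, cache[key], y)  # the launch the caller actually requested
    return y[0]


x = [0.5] * 120  # true sum is 60.0
configs = [256, 128, 64, 32]

cache = {}
print(autotuned_call(x, [0.0], cache, configs, 10, reset_to_zero=False))  # 2460.0: benchmark launches leak into y
print(autotuned_call(x, [0.0], cache, configs, 10, reset_to_zero=False))  # 60.0: cached config, one launch
print(autotuned_call(x, [0.0], {}, configs, 10, reset_to_zero=True))  # 60.0
```

In this model the first call is wrong and later calls are right, matching the reported symptom. A kernel that writes its result with a plain `tl.store` would overwrite `y` on each benchmark launch instead of accumulating, which plausibly explains why the problem only appeared together with atomics.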

Kitsunetic (Author) commented

Oh, I think I didn't read the documentation carefully. Thank you!
