Introduce tl.assume or use assert expression in non-debug builds to guide optimization? #4331

Closed
0x804d8000 opened this issue Jul 16, 2024 · 2 comments


@0x804d8000

In certain cases the compiler cannot infer that some situation cannot happen, e.g.:

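```python
import triton
import triton.language as tl


@triton.jit
def kernel(N: tl.constexpr, BLOCK_N: tl.constexpr):
    # Elements left for this program's block; depends on the runtime program id.
    current_size = N - tl.program_id(0) * BLOCK_N
    if current_size >= 128:
        # Fast path (helper and arguments are placeholders, as in the original).
        return impl_specialized_for_block_128(...)
    else:
        # Fallback for a partial block.
        return impl_slower_but_safe_for_smaller_block(...)


kernel[(1024 // 128, )](1024, 128)
```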

Note that both `N` and `BLOCK_N` are marked as `constexpr`, so the compiler is able to do constant propagation and dead code elimination as necessary.

Here the compiler cannot eliminate the `if` without external knowledge. `tl.program_id(0)` is only known at run time, so given a (wrongly) large grid, `current_size` could drop below 128 or even become negative, and the `else` branch would then be executed.
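To make the reachability argument concrete, here is a plain-Python model (run on the CPU, not Triton) of the value `current_size` takes in each program of a 1-D grid. The helper names are hypothetical. With the intended grid of `1024 // 128` programs every program sees `current_size >= 128`, but an oversized grid (10 programs, chosen purely for illustration) produces programs where the value is zero or negative, which is why the compiler has to keep the `else` branch.

```python
def current_sizes(N: int, BLOCK_N: int, grid: int) -> list[int]:
    """Mirror `N - tl.program_id(0) * BLOCK_N` for every program id in a 1-D grid."""
    return [N - pid * BLOCK_N for pid in range(grid)]


def else_branch_reachable(N: int, BLOCK_N: int, grid: int, threshold: int = 128) -> bool:
    """True if at least one program would take the `current_size < threshold` path."""
    return any(size < threshold for size in current_sizes(N, BLOCK_N, grid))


N, BLOCK_N = 1024, 128

# Intended launch: kernel[(1024 // 128, )](1024, 128) -> 8 programs.
print(current_sizes(N, BLOCK_N, N // BLOCK_N))  # [1024, 896, 768, 640, 512, 384, 256, 128]
print(else_branch_reachable(N, BLOCK_N, N // BLOCK_N))  # False

# A (wrongly) large grid of 10 programs: the last two see 0 and -128.
print(current_sizes(N, BLOCK_N, 10)[-2:])  # [0, -128]
print(else_branch_reachable(N, BLOCK_N, 10))  # True
```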

I think we can either:

  • Use the `assert` expression as an optimization hint in release builds, and add `assert current_size > 0` before the `if`. This way the compiler should be able to figure out that the `else` branch is not needed. The caveat is that a wrong assertion used to be harmless in a release build, whereas now it could lead to undefined behavior.
  • Introduce `tl.assume`. Its usage would be largely the same as `assert`. However, this might lead people to duplicate the same condition in both an `assert` and a `tl.assume` all over the place, which would violate the DRY principle.
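Either option gives the compiler the fact `current_size > 0`. The following plain-Python sketch is a hypothetical model of the integer range reasoning an optimizer could then apply; it is not how Triton or LLVM actually implements it. Because `pid >= 0` is an integer and `N`, `BLOCK_N` are compile-time constants, the assumption bounds `pid` from above, which in turn bounds `current_size` from below.

```python
def min_current_size_assuming_positive(N: int, BLOCK_N: int) -> int:
    """Smallest N - pid * BLOCK_N over integer pid >= 0, given the result is > 0.

    Hypothetical helper modelling the bound derivable from `current_size > 0`.
    """
    if N < 1 or BLOCK_N < 1:
        raise ValueError("N and BLOCK_N must be positive")
    # current_size > 0  <=>  pid * BLOCK_N < N  <=>  pid <= (N - 1) // BLOCK_N
    max_pid = (N - 1) // BLOCK_N
    # current_size decreases with pid, so its minimum is reached at max_pid.
    return N - max_pid * BLOCK_N


def else_branch_is_dead(N: int, BLOCK_N: int, threshold: int = 128) -> bool:
    """True if the assumption alone proves `current_size >= threshold`."""
    return min_current_size_assuming_positive(N, BLOCK_N) >= threshold


print(else_branch_is_dead(1024, 128))  # True: smallest possible size is 128
print(else_branch_is_dead(1000, 128))  # False: the last program would see 104
```

The second case shows the limit of the hint: when `N` is not a multiple of `BLOCK_N`, the last block is genuinely partial, so the `else` branch stays reachable regardless of the assumption.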
Ghost-LZW added a commit to Ghost-LZW/triton that referenced this issue Jul 26, 2024
Introduce `tl.assume` to hint the compiler to perform dead code elimination

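E.g.:
```python
import triton
import triton.language as tl


@triton.jit
def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    current_size = N - tl.program_id(0) * BLOCK_N
    # Promise to the compiler: every launched program has at least BLOCK_N elements.
    tl.assume(current_size >= BLOCK_N)
    if current_size >= 128:
        tl.store(out_ptr + tl.program_id(0), current_size)
    else:
        # Offset marks values written by this branch.
        tl.store(out_ptr + tl.program_id(0), current_size + 101024)
```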

If `BLOCK_N` is greater than or equal to 128, the `else` branch can never be taken, since then `current_size >= BLOCK_N >= 128`.

Fix issue triton-lang#4331
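This claim can be checked with a plain-Python emulation of `_kernel` (run on the CPU, not Triton). Each program writes one value and takes the `else` path only if `current_size < 128`, which shows up as the `+ 101024` offset. The `emulate_kernel` helper and its `grid` argument are hypothetical; a violated assumption would be undefined behavior on the GPU, so the sketch simply rejects such launches.

```python
def emulate_kernel(N: int, BLOCK_N: int, grid: int) -> list[int]:
    """CPU stand-in for `_kernel`: out[pid] holds what program `pid` would store."""
    out = [0] * grid
    for pid in range(grid):
        current_size = N - pid * BLOCK_N
        # Mirrors tl.assume(current_size >= BLOCK_N); refuse launches that break it.
        if current_size < BLOCK_N:
            raise ValueError(f"assumption violated for pid {pid}")
        if current_size >= 128:
            out[pid] = current_size
        else:
            out[pid] = current_size + 101024
    return out


# BLOCK_N >= 128: every program satisfies the assumption and takes the `if` branch.
print(emulate_kernel(1024, 128, 1024 // 128))  # [1024, 896, 768, 640, 512, 384, 256, 128]

# BLOCK_N < 128: the assumption still holds, but the `else` branch is reachable.
print(emulate_kernel(1024, 64, 1024 // 64)[-1])  # 101088
```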
Ghost-LZW mentioned this issue Jul 26, 2024
Ghost-LZW added a commit to Ghost-LZW/triton that referenced this issue Jul 26, 2024
Ghost-LZW added a commit to Ghost-LZW/triton that referenced this issue Jul 30, 2024
ThomasRaoux pushed a commit that referenced this issue Jul 30, 2024
Introduce `tl.assume` to hint the compiler to perform dead code elimination

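E.g.:
```python
import triton
import triton.language as tl


@triton.jit
def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    current_size = N - tl.program_id(0) * BLOCK_N
    # Promise to the compiler: every launched program has at least BLOCK_N elements.
    tl.assume(current_size >= BLOCK_N)
    if current_size >= 128:
        tl.store(out_ptr + tl.program_id(0), current_size)
    else:
        # Offset marks values written by this branch.
        tl.store(out_ptr + tl.program_id(0), current_size + 101024)
```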

If `BLOCK_N` is greater than or equal to 128, the `else` branch can never be taken.

Fix issue #4331

The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (fewer than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
@Ghost-LZW
Contributor

`tl.assume` has been introduced in #4396.

@0x804d8000
Author

Thanks, closing the issue.
