Introduce `tl.assume` or use `assert` expression in non-debug builds to guide optimization? #4331
Ghost-LZW added a commit to Ghost-LZW/triton that referenced this issue on Jul 26, 2024
Introduce `tl.assume` to hint the compiler to do dead code elimination, e.g.:

```python
import triton
import triton.language as tl


@triton.jit
def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    current_size = N - tl.program_id(0) * BLOCK_N
    tl.assume(current_size >= BLOCK_N)
    if current_size >= 128:
        tl.store(out_ptr + tl.program_id(0), current_size)
    else:
        tl.store(out_ptr + tl.program_id(0), current_size + 101024)
```

If `BLOCK_N` is greater than or equal to 128, the `else` branch will never happen. Fix issue triton-lang#4331
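The claim that the `else` branch is dead follows from simple value-range reasoning. Below is a toy sketch in plain Python, assuming a compiler that tracks an inclusive `[lo, hi]` range for each integer value; the `Range`, `assume_ge` and `fold_ge` names are hypothetical, and this is not how Triton or LLVM actually implement the optimization.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    lo: float  # inclusive lower bound (may be -inf)
    hi: float  # inclusive upper bound (may be +inf)


def assume_ge(r: Range, bound: float) -> Range:
    """Narrow a range with the assumed fact `x >= bound` (hypothetical helper)."""
    return Range(max(r.lo, bound), r.hi)


def fold_ge(r: Range, threshold: float):
    """Statically evaluate `x >= threshold`: True, False, or None if unknown."""
    if r.lo >= threshold:
        return True   # condition always holds: the `else` branch is dead
    if r.hi < threshold:
        return False  # condition never holds: the `then` branch is dead
    return None       # both branches must be kept


# Without any hint, current_size could be any integer.
unknown = Range(-math.inf, math.inf)

print(fold_ge(unknown, 128))                   # None
print(fold_ge(assume_ge(unknown, 128), 128))   # True  (BLOCK_N = 128)
print(fold_ge(assume_ge(unknown, 64), 128))    # None  (BLOCK_N = 64)
```

With `BLOCK_N = 64` the assumed lower bound is too weak to decide `current_size >= 128`, which matches the commit's condition that `BLOCK_N` must be at least 128.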
Ghost-LZW added a commit to Ghost-LZW/triton that referenced this issue on Jul 30, 2024
ThomasRaoux pushed a commit that referenced this issue on Jul 30, 2024
Introduce `tl.assume` to hint the compiler to do dead code elimination, e.g.:

```python
import triton
import triton.language as tl


@triton.jit
def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr):
    current_size = N - tl.program_id(0) * BLOCK_N
    tl.assume(current_size >= BLOCK_N)
    if current_size >= 128:
        tl.store(out_ptr + tl.program_id(0), current_size)
    else:
        tl.store(out_ptr + tl.program_id(0), current_size + 101024)
```

If `BLOCK_N` is greater than or equal to 128, the `else` branch will never happen. Fix issue #4331

The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
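The hint `tl.assume(current_size >= BLOCK_N)` is only true for some launch configurations, and the compiler never sees the grid. A host-side sketch in plain Python (no GPU needed) that replays the kernel's index arithmetic for every program id; the sizes `N = 1024` / `N = 1000`, `BLOCK_N = 128` and the helper names are hypothetical, and `cdiv` is a local stand-in written to behave like `triton.cdiv` for positive integers.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division for positive integers.
    return (a + b - 1) // b


def current_sizes(N: int, BLOCK_N: int, grid: int) -> list[int]:
    # Replays `current_size = N - tl.program_id(0) * BLOCK_N` for each program id.
    return [N - pid * BLOCK_N for pid in range(grid)]


def assumption_holds(N: int, BLOCK_N: int, grid: int) -> bool:
    # The commit's hint: current_size >= BLOCK_N in every program.
    return all(s >= BLOCK_N for s in current_sizes(N, BLOCK_N, grid))


BLOCK_N = 128
print(assumption_holds(1024, BLOCK_N, cdiv(1024, BLOCK_N)))      # True: N is a multiple of BLOCK_N
print(assumption_holds(1000, BLOCK_N, cdiv(1000, BLOCK_N)))      # False: the last block has 104 elements
print(assumption_holds(1024, BLOCK_N, cdiv(1024, BLOCK_N) + 1))  # False: the extra program sees 0
```

With a grid that is too large, `current_size` keeps decreasing and becomes negative (for example, `-128` for the tenth program when `N = 1024`), which is exactly the case that forces the compiler to keep the `else` branch unless it is told otherwise.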
Thanks, closing the issue.
In certain cases the compiler cannot infer that some situation cannot happen, e.g.:

Note that both `N` and `BLOCK_N` are marked as `constexpr`, so the compiler will be able to do constant propagation and dead code elimination as necessary. Here, however, the compiler won't be able to eliminate the `if` without external knowledge. This is because, given a (wrongly) large `grid`, `current_size` could become negative, and thus the `else` branch would be executed.

I think we can either:
1. Use the `assert` expression as a hint in release builds, and add an `assert current_size > 0` before the `if`. This way the compiler should be able to figure out that the `else` branch won't be needed. The caveat is that if the assertion itself is wrong, then where previously such a wrong assertion was harmless, it can now lead to undefined behavior.
2. Introduce `tl.assume`. Its usage would be largely the same as `assert`. However, this might result in people duplicating the same conditional expression for both `assert` and `tl.assume` all over the place, which would violate the DRY principle.