Merge OpenAI Triton commit 3379361 #1646

Merged: 10 commits merged into llvm-target from whitneywhtsang/merge on Jul 17, 2024
Conversation

@whitneywhtsang (Contributor) commented Jul 17, 2024

This PR changes the Triton base from 69539b8 to 3379361 (Jul 17).
Pass rate: 98.5%

Please do not squash and merge this PR.

peterbell10 and others added 8 commits July 15, 2024 22:28
I managed to reduce this reproducer from #4311 quite a lot; I believe it is
close to minimal now.

cc @Jokeren

---------

Co-authored-by: Peter Bell <peter@pop-os.Home>
1. Use the builtin `ast.increment_lineno` function to make it more robust
   (see the sketch below).
2. Clean up function rewrite logic.
3. Resolve global variable reference issues.
4. Enable line info tests.
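
For context, here is a minimal sketch of how `ast.increment_lineno` keeps re-parsed line numbers aligned with the original source file. The helper name and surrounding code are illustrative, not the actual JIT implementation:

```python
import ast
import inspect
import textwrap

def parse_with_original_linenos(fn):
    # Illustrative helper: re-parse a function's source while keeping line
    # numbers pointing at its location in the original file.
    src_lines, first_lineno = inspect.getsourcelines(fn)
    tree = ast.parse(textwrap.dedent("".join(src_lines)))
    # ast.parse numbers lines from 1; shift every node so the reported line
    # numbers match the function's position in its file.
    ast.increment_lineno(tree, first_lineno - 1)
    return tree
```
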
The symbol is only present in CUDA-12 compatible drivers and is missing in
CUDA-11 ones.

Spiritual follow-up after triton-lang/triton#2771, which allows querying the
symbol dynamically; if run on an older driver, the query returns an error.
Also fix the `occupancyMaxActiveClusters` behavior when the symbol is not
found: before this change it would crash with a null pointer dereference, now
it returns a structured exception.
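
A minimal Python/ctypes sketch of the idea (the real change lives in the C driver stub; the symbol name and error type below are illustrative assumptions):

```python
import ctypes

libcuda = ctypes.CDLL("libcuda.so.1")

def occupancy_max_active_clusters(*args):
    # Look the symbol up dynamically instead of binding it at build time.
    try:
        fn = libcuda.cuOccupancyMaxActiveClusters
    except AttributeError:
        # The symbol is absent on CUDA-11 drivers: surface a structured
        # error instead of dereferencing a null function pointer.
        raise RuntimeError(
            "cuOccupancyMaxActiveClusters is not available in this driver")
    return fn(*args)
```
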
**This pull request makes the test use the `device` fixture so that it is not
CUDA-specific. This simplifies testing on various devices downstream without
having to modify the test code itself.**
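
A rough sketch of the pattern (the fixture below is an assumption for illustration, not the exact fixture from the Triton test suite):

```python
import pytest
import torch

@pytest.fixture
def device():
    # Downstream forks can override this to target another backend,
    # e.g. "xpu" or "cpu".
    return "cuda"

def test_kernel_runs(device):
    # The test body only refers to `device`, so no test code changes are
    # needed when running against a different backend.
    x = torch.randn(128, device=device)
    assert x.device.type == device
```
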

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.

~- [ ] I am not making a trivial change, such as fixing a typo in a
comment.~

- [ ] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
- [x] This PR does not need a test because `it only modifies the test`.

- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
There was a function pointer lookup missing in the previous patch.
triton-lang/triton@f9f2960
`nvvm.reqntid` has stronger semantics and should allow better optimization in
the finalizer.
This PR fixes the bug demonstrated
[here](https://github.com/embg/triton/blob/ed125e4a44e397e9a40e691bb7ce40c698120a1a/tma_repro.py),
which is the probable root cause of
triton-lang/triton#4332.

## The problem
NVIDIA docs
[recommend](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#using-tma-to-transfer-multi-dimensional-arrays)
that TMA descriptors should be passed through immutable device memory,
but Triton currently passes them through mutable device memory. This is
unsafe unless the TMA descriptor cache is flushed *on every SM*.

The current implementation attempts to flush the cache by launching a
[dedicated TMA flush
kernel](https://github.com/triton-lang/triton/blob/3aeb223819d632303dd2b45f4dc533d6af90dc46/python/triton/tools/experimental_descriptor.py#L34).
Unfortunately, this kernel does not run on all SMs. As a result, Triton
TMA kernels may hang or return incorrect results.

According to @ThomasRaoux, it isn't possible to guarantee a kernel will
run on every SM (as there may be another workload on a different CUDA
stream). So flushing in a separate kernel is not possible.

## Proposed solution
* Add fences to all example code via inline assembly.
* Add documentation to inform users about the fence issue.
* Remove the existing cache flush code since it is incorrect.

## Why this solution?
Since each kernel needs to issue its own fence instruction, we have
three options:
* Inline assembly
* Add a new op, like `tl._experimental_tma_acquire_fence(addr)`
* Use compiler analysis to insert the fence automatically

I believe we should not add a new op or analysis pass until both
`__grid_constant__` and on-device descriptor mutation are landed. Once
host-side descriptors switch to `__grid_constant__`, the fence will only
be needed for on-device mutation, which won't require a separate op or
analysis pass (simply add a fence while lowering the mutation op).

If I'm wrong and we do end up needing a separate op or analysis pass, it
will be trivial to clean up 6 lines of inline assembly.
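
For concreteness, the fence added to the example kernels looks roughly like the sketch below (assuming Triton's `tl.inline_asm_elementwise`; the exact PTX string and constraints are illustrative rather than taken verbatim from this PR):

```python
import triton
import triton.language as tl

@triton.jit
def kernel_using_tma(desc_ptr):
    # Acquire fence so this SM's TMA descriptor cache observes the bytes the
    # host wrote into the descriptor before the kernel launch.
    tl.inline_asm_elementwise(
        "fence.proxy.tensormap::generic.acquire.gpu [$1], 128;",
        "=r, l",
        [desc_ptr],
        dtype=tl.int32,
        is_pure=False,
        pack=1,
    )
    # ... the rest of the kernel issues TMA transfers through desc_ptr ...
```
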

## Checklist

- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.

- Select one of the following.
  - [x] I have not added any `lit` tests.
@whitneywhtsang whitneywhtsang self-assigned this Jul 17, 2024
@whitneywhtsang whitneywhtsang changed the title Merge OpenAI Triton commit 69539b8 Merge OpenAI Triton commit 2946cd1 Jul 17, 2024
@whitneywhtsang whitneywhtsang marked this pull request as ready for review July 17, 2024 23:00
@whitneywhtsang whitneywhtsang merged commit 1ae1c3c into llvm-target Jul 17, 2024
4 checks passed
@whitneywhtsang whitneywhtsang deleted the whitneywhtsang/merge branch July 17, 2024 23:01
@whitneywhtsang whitneywhtsang changed the title Merge OpenAI Triton commit 2946cd1 Merge OpenAI Triton commit 3379361 Jul 17, 2024