-
Notifications
You must be signed in to change notification settings - Fork 74
Merge OpenAI commit dbc85fc
#5210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Failing to verify this can cause all sorts of weird crashes in the frontend due to violating invariants when calling into dialect (C++) code.
A simple transpose construction of the offsets tensor, followed by a
trans() operation, results in conservative contiguity analysis in
AxisInfo. The expected behavior is supposed to similar to a contiguous
offsets tensor construction (without transpose op).
```
@triton.jit
def transpose_read_kernel(
X_ptr,
stride_xa,
stride_xb,
):
offsets = (tl.arange(0, 64)[:, None] * stride_xb + tl.arange(0, 64)[None, :] * stride_xa)
offsets = tl.trans(offsets, (1, 0))
# remark: %11 = tt.trans %10 {order = array<i32: 1, 0>} : tensor<64x64xi32> -> tensor<64x64xi32> =>
# contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
# ideal remark:
# contiguity = [1, 64], divisibility = [2, 16], constancy = [1, 1], constant_value = <none>
tl.async_load(X_ptr + offsets, buffer)
if __name__ == "__main__":
x = torch.randn(
(128, 128),
device="cuda",
dtype=torch.float16,
)
transpose_read_kernel[(1,)](
x,
x.stride(0),
x.stride(1),
)
```
`TransOp` did not have an `AxisInfo` visitor, which was causing it to
fall back to pessimistic defaults that don't properly propagate
contiguity information. This PR adds a new visitor that handles
transpose operations in the `AxisInfo` lattice.
<!---
The core Triton is a small number of people, and we receive many PRs
(thank
you!). To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
- [x] I have added tests.
- `/test` for `lit` tests
- `/unittest` for C++ tests
- `/python/test` for end-to-end tests
- [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
- [ ] I have not added any `lit` tests.
- [x] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
and using the instructions it generates is not minimal.)
Co-authored-by: pka <pka@devgpu009.pnb3.facebook.com>
Adds a condition on `TRITON_BUILD_UT` before including Proton tests. When `TRITON_BUILD_UT` is `OFF` without adding this condition, a build failure occurs. This is because the Proton tests cmake calls the `add_triton_ut` function, however this function is not declared when `TRITON_BUILD_UT` is `OFF`. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because it's a build bug fix. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
* split_k was not working with batched matmul, I'll just disable it for now. * Epilogue handling was missing in _reduce_grouped.
Previous pc sampling and cupti uses different format, now that they use the same format as "file_name:line_number@function_name"
This commit switches to use a basic heuristic for improving support of preshuffled scale tensors--we try a few common scale tensor schemes and see which one gives the largest vectorization when global load.
Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
dbc85fc3f285394f24273e200d95b4142541e809dbc85fc
whitneywhtsang
approved these changes
Sep 29, 2025
Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR change the Triton base from 1b27b93 to dbc85fc (Sep 23).
Pass rate: 96.32%->96.23%
Please do not squash and merge this PR.