-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][NVVM] Fix the lowering of mbarrier.test.wait #166555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Fix the lowering of mbarrier.test.wait #166555
Conversation
|
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Durgadoss R (durga4github) ChangesPR #165993 accidentally broke the lowering of the This patch fixes the issue and adds tests to verify the lowering to intrinsics Additionally, the Full diff: https://github.com/llvm/llvm-project/pull/166555.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 10f0cc254ea97..80bc0e5986e51 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -949,7 +949,7 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
}];
string llvmBuilder = [{
- auto [id, args] = NVVM::MBarrierArriveNocompleteOp::getIntrinsicIDAndArgs(
+ auto [id, args] = NVVM::MBarrierTestWaitOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
$res = createIntrinsicCall(builder, id, args);
}];
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir b/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir
new file mode 100644
index 0000000000000..0064ae97eebba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbarriers.mlir
@@ -0,0 +1,117 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
+ // CHECK-LABEL: define void @cp_async_mbarrier_arrive(ptr addrspace(3) %0, ptr %1) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %1)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %1)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
+ nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
+ nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
+ nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @mbarrier_init_generic(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_init_generic(ptr %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init(ptr %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ nvvm.mbarrier.init %barrier, %count : !llvm.ptr, i32
+ llvm.return
+}
+
+
+llvm.func @mbarrier_init_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_init_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.init.shared(ptr addrspace(3) %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ nvvm.mbarrier.init %barrier, %count : !llvm.ptr<3>, i32
+ llvm.return
+}
+
+llvm.func @mbarrier_inval_generic(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_inval_generic(ptr %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval(ptr %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.inval %barrier : !llvm.ptr
+ llvm.return
+}
+
+llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_inval_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: call void @llvm.nvvm.mbarrier.inval.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_arrive(ptr %0) {
+ // CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive(ptr %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_arrive_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i64 @llvm.nvvm.mbarrier.arrive.shared(ptr addrspace(3) %0)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.arrive %barrier : !llvm.ptr<3> -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
+ // CHECK-LABEL: define void @mbarrier_arrive_nocomplete(ptr %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete(ptr %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK-LABEL: define void @mbarrier_arrive_nocomplete_shared(ptr addrspace(3) %0) {
+ // CHECK-NEXT: %2 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.noComplete.shared(ptr addrspace(3) %0, i32 %2)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
+ // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
+ // CHECK-NEXT: ret i1 %3
+ // CHECK-NEXT: }
+ %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
+ llvm.return %isComplete : i1
+}
+
+llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
+ // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %count = nvvm.read.ptx.sreg.ntid.x : i32
+ %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3fc09f371a347..1ec55408e97a5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -531,19 +531,6 @@ llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32)
llvm.return
}
-// CHECK-LABEL: @cp_async_mbarrier_arrive
-llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_gen : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc(ptr %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_gen {noinc = true} : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.shared(ptr addrspace(3) %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_shared : !llvm.ptr<3>
- // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared(ptr addrspace(3) %{{.*}})
- nvvm.cp.async.mbarrier.arrive %bar_shared {noinc = true} : !llvm.ptr<3>
- llvm.return
-}
-
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
// CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
|
PR llvm#165993 broke the lowering of the `test.wait` Op. This patch fixes the issue and adds tests to verify the lowering to intrinsics for all mbarrier Ops, ensuring similar regressions are caught in the future. Additionally, the `cp-async-mbarrier` test is moved to the `mbarriers.mlir` test file to keep all related tests together. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
b56669c to
7b07840
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix! LGTM!
PR #165993 accidentally broke the lowering of the
test.waitOp.This patch fixes the issue and adds tests to verify the lowering to intrinsics
for all mbarrier Ops, ensuring similar regressions are caught in the future.
Additionally, the
cp-async-mbarriertest is moved to thembarriers.mlirtest file to keep all related tests together.