Skip to content

Conversation

@durga4github
Copy link
Contributor

This patch extends mbarrier.test_wait to support scope,
semantics, and phase-parity, completing the updates for
this Op up to Blackwell. Corresponding lit tests are added
to verify the lowering.

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

Changes

This patch extends mbarrier.test_wait to support scope,
semantics, and phase-parity, completing the updates for
this Op up to Blackwell. Corresponding lit tests are added
to verify the lowering.


Full diff: https://github.com/llvm/llvm-project/pull/169698.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+20-8)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+32-10)
  • (modified) mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir (-20)
  • (modified) mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir (+8)
  • (added) mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir (+73)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57cb9b139111d..cfdf75242bd88 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -980,10 +980,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
 }
 
-def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
-  Results<(outs I1:$res)>,
-  Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
-                 I64:$state)> {
+def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> {
   let summary = "MBarrier Non-Blocking Test Wait Operation";
   let description = [{
     The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the 
@@ -1002,9 +999,16 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
     The operation takes the following operands:
     - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic 
       addressing, but the address must still be in the shared memory space.
-    - `state`: An opaque value returned by a previous `mbarrier.arrive` 
-      operation on the same *mbarrier object* during the current or immediately 
-      preceding phase.
+    - `stateOrPhase`: This argument represents a `state` when it is a 64-bit value
+      and represents a `phase` when it is a 32-bit value. The `state` is an opaque
+      value returned by a previous `mbarrier.arrive` operation on the same
+      *mbarrier object* during the current or immediately preceding phase.
+      The `phase` is an integer specifying the phase parity (0 or 1).
+      Even phases have parity 0, odd phases have parity 1.
+    - `scope`: This specifies the set of threads that directly observe the memory
+      synchronizing effect of the `mbarrier.test.wait` operation.
+    - `relaxed`: When set to true, the `arrive` operation has relaxed memory semantics
+      and does not provide any ordering or visibility guarantees.
 
     The operation returns a boolean value indicating whether the specified phase 
     has completed:
@@ -1031,7 +1035,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
   }];
 
-  let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
+  let results = (outs I1:$res);
+  let arguments = (ins
+    AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+    AnyTypeOf<[I64, I32]>:$stateOrPhase,
+    DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope,
+    DefaultValuedAttr<BoolAttr, "false">:$relaxed);
+
+  let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     static mlir::NVVM::IDArgPair
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e9949547aaea4..bf5d7aab7f19a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -252,10 +252,10 @@ LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
 static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
                                                 NVVM::MemScopeKind scope,
                                                 Value retVal = nullptr) {
-  bool isSharedCluster = isPtrInSharedClusterSpace(addr);
   if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
     return op->emitError("mbarrier scope must be either CTA or Cluster");
 
+  bool isSharedCluster = isPtrInSharedClusterSpace(addr);
   bool hasRetValue = static_cast<bool>(retVal);
   if (isSharedCluster && hasRetValue)
     return op->emitError(
@@ -282,6 +282,10 @@ LogicalResult MBarrierCompleteTxOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
 
+LogicalResult MBarrierTestWaitOp::verify() {
+  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -2084,16 +2088,34 @@ mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
 mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
-  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
-  llvm::Intrinsic::ID id = isShared
-                               ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
-                               : llvm::Intrinsic::nvvm_mbarrier_test_wait;
-  // Fill the Intrinsic Args
-  llvm::SmallVector<llvm::Value *> args;
-  args.push_back(mt.lookupValue(thisOp.getAddr()));
-  args.push_back(mt.lookupValue(thisOp.getState()));
+  bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+  // bit-0: isPhaseParity
+  // bit-1: Scope
+  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
 
-  return {id, std::move(args)};
+  static constexpr llvm::Intrinsic::ID IDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::
+          nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+      llvm::Intrinsic::
+          nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+  // Tidy-up the Intrinsic Args
+  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+  llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+  if (needCast)
+    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+  return {id, {mbar, input}};
 }
 
 mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
index ae9c7f29bc7a5..9c1d1cc0cdc31 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
@@ -54,23 +54,3 @@ llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
   nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
   llvm.return
 }
-
-llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
-  // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
-  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
-  // CHECK-NEXT: ret i1 %3
-  // CHECK-NEXT: }
-  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
-  llvm.return %isComplete : i1
-}
-
-llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
-  // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
-  // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
-  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
-  // CHECK-NEXT: ret void
-  // CHECK-NEXT: }
-  %count = nvvm.read.ptx.sreg.ntid.x : i32
-  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
-  llvm.return
-}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
index 4ad76248b7e25..1c0e23516ffe1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
@@ -47,3 +47,11 @@ llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) {
   nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32
   llvm.return
 }
+
+// -----
+
+llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+  // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
new file mode 100644
index 0000000000000..21ab72eeab167
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}

@durga4github
Copy link
Contributor Author

@grypp , I am inclined to change the name to "test_wait" instead of "test.wait", to align with the naming of other Ops in this family. Please let me know what do you think.

@durga4github durga4github force-pushed the durgadossr/mlir_blk_mbar_2 branch from ccce483 to f0ed106 Compare November 28, 2025 09:59
@durga4github durga4github force-pushed the durgadossr/mlir_blk_mbar_2 branch from f0ed106 to 7d39667 Compare November 28, 2025 10:02
@github-actions
Copy link

github-actions bot commented Nov 28, 2025

🐧 Linux x64 Test Results

  • 7159 tests passed
  • 595 tests skipped

@durga4github durga4github force-pushed the durgadossr/mlir_blk_mbar_2 branch from 7d39667 to a15c3ac Compare November 28, 2025 10:28
Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. But let's not change the name of the OP

@durga4github durga4github force-pushed the durgadossr/mlir_blk_mbar_2 branch from a15c3ac to 56c2a36 Compare December 1, 2025 13:28
This patch extends the mbarrier.test.wait Op to
support scope, semantics and phase-parity. This
completes updates to the test_wait up-to Blackwell.
lit tests are added to verify the lowering.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@durga4github durga4github force-pushed the durgadossr/mlir_blk_mbar_2 branch from 56c2a36 to eb6f233 Compare December 1, 2025 17:20
@durga4github durga4github changed the title [MLIR][NVVM] Update mbarrier.test_wait Op [MLIR][NVVM] Update mbarrier.test.wait Op Dec 1, 2025
@durga4github durga4github merged commit 8c7c940 into llvm:main Dec 2, 2025
10 checks passed
@durga4github durga4github deleted the durgadossr/mlir_blk_mbar_2 branch December 2, 2025 07:02
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025
This patch extends mbarrier.test_wait to support scope,
semantics, and phase-parity, completing the updates for
this Op up to Blackwell. Corresponding lit tests are added
to verify the lowering.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
honeygoyal pushed a commit to honeygoyal/llvm-project that referenced this pull request Dec 9, 2025
This patch extends mbarrier.test_wait to support scope,
semantics, and phase-parity, completing the updates for
this Op up to Blackwell. Corresponding lit tests are added
to verify the lowering.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants