diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index b3395b7e0a24e..a96d65d3fcacd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1052,31 +1052,35 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity" let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; } -def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, - Results<(outs I1:$res)>, - Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, - I64:$state)> { +def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> { let summary = "MBarrier Non-Blocking Test Wait Operation"; let description = [{ - The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the + The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the completion of a specific phase of an *mbarrier object*. It uses the default - `.acquire.cta` semantics. This acquire pattern establishes memory ordering for - operations occurring in program order after this wait instruction by making - operations from other threads in the CTA visible to subsequent operations in the current - thread. When this wait completes, it synchronizes with the corresponding release - pattern from the `mbarrier.arrive` operation, establishing memory ordering within + `.acquire.cta` semantics. This acquire pattern establishes memory ordering for + operations occurring in program order after this wait instruction by making + operations from other threads in the CTA visible to subsequent operations in the current + thread. When this wait completes, it synchronizes with the corresponding release + pattern from the `mbarrier.arrive` operation, establishing memory ordering within the CTA. - This operation tests whether the mbarrier phase specified by the state operand - has completed. 
It is a non-blocking instruction that immediately returns the + This operation tests whether the mbarrier phase specified by the state operand + has completed. It is a non-blocking instruction that immediately returns the completion status without suspending the executing thread. The operation takes the following operands: - - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic + - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic addressing, but the address must still be in the shared memory space. - - `state`: An opaque value returned by a previous `mbarrier.arrive` - operation on the same *mbarrier object* during the current or immediately - preceding phase. + - `stateOrPhase`: This argument represents a `state` when it is a 64-bit value + and represents a `phase` when it is a 32-bit value. The `state` is an opaque + value returned by a previous `mbarrier.arrive` operation on the same + *mbarrier object* during the current or immediately preceding phase. + The `phase` is an integer specifying the phase parity (0 or 1). + Even phases have parity 0, odd phases have parity 1. + - `scope`: This specifies the set of threads that directly observe the memory + synchronizing effect of the `mbarrier.test.wait` operation. + - `relaxed`: When set to true, the `mbarrier.test.wait` operation has relaxed memory semantics + and does not provide any ordering or visibility guarantees. 
The operation returns a boolean value indicating whether the specified phase has completed: @@ -1103,7 +1107,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) }]; - let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)"; + let results = (outs I1:$res); + let arguments = (ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + AnyTypeOf<[I64, I32]>:$stateOrPhase, + DefaultValuedAttr<MemScopeKindAttr, "NVVM::MemScopeKind::CTA">:$scope, + DefaultValuedAttr<BoolAttr, "false">:$relaxed); + + let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)"; + let hasVerifier = 1; let extraClassDeclaration = [{ static mlir::NVVM::IDArgPair diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 413125245aca8..ada4223ac12de 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -252,10 +252,10 @@ LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() { static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal = nullptr) { - bool isSharedCluster = isPtrInSharedClusterSpace(addr); if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER) return op->emitError("mbarrier scope must be either CTA or Cluster"); + bool isSharedCluster = isPtrInSharedClusterSpace(addr); bool hasRetValue = static_cast<bool>(retVal); if (isSharedCluster && hasRetValue) return op->emitError( "mbarrier in shared_cluster space cannot return any value"); @@ -310,6 +310,10 @@ LogicalResult MBarrierCompleteTxOp::verify() { return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); } +LogicalResult MBarrierTestWaitOp::verify() { + return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); +} + LogicalResult ConvertFloatToTF32Op::verify() { using RndMode = 
NVVM::FPRoundingMode; switch (getRnd()) { @@ -2718,16 +2722,34 @@ mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs( mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<MBarrierTestWaitOp>(op); - bool isShared = isPtrInSharedCTASpace(thisOp.getAddr()); - llvm::Intrinsic::ID id = isShared - ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared - : llvm::Intrinsic::nvvm_mbarrier_test_wait; - // Fill the Intrinsic Args - llvm::SmallVector<llvm::Value *> args; - args.push_back(mt.lookupValue(thisOp.getAddr())); - args.push_back(mt.lookupValue(thisOp.getState())); + bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32); + bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER; + // bit-0: isPhaseParity + // bit-1: Scope + size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0); - return {id, std::move(args)}; + // clang-format off + static constexpr llvm::Intrinsic::ID IDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta}; + static constexpr llvm::Intrinsic::ID relaxedIDs[] = { + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta, + llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta}; + // clang-format on + auto id = thisOp.getRelaxed() ? 
relaxedIDs[index] : IDs[index]; + + // Tidy-up the Intrinsic Args + llvm::Value *mbar = mt.lookupValue(thisOp.getAddr()); + llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase()); + bool needCast = isPtrInGenericSpace(thisOp.getAddr()); + if (needCast) + mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared); + + return {id, {mbar, input}}; } mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs( diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index cd7bd37da5763..6f67a50c1a946 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -464,19 +464,6 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { llvm.return } -llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { - // CHECK: nvvm.mbarrier.test.wait %{{.*}} - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1 - llvm.return %isComplete : i1 -} - -llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) { - %count = nvvm.read.ptx.sreg.ntid.x : i32 - // CHECK: nvvm.mbarrier.test.wait %{{.*}} - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1 - llvm.return -} - // CHECK-LABEL: @wgmma_fence_aligned func.func @wgmma_fence_aligned() { // CHECK: nvvm.wgmma.fence.aligned diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir index ae9c7f29bc7a5..9c1d1cc0cdc31 100644 --- a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir @@ -54,23 +54,3 @@ llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) { nvvm.mbarrier.inval %barrier : !llvm.ptr<3> llvm.return } - -llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 { - // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) { - // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1) - // CHECK-NEXT: ret 
i1 %3 - // CHECK-NEXT: } - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1 - llvm.return %isComplete : i1 -} - -llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) { - // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) { - // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1) - // CHECK-NEXT: ret void - // CHECK-NEXT: } - %count = nvvm.read.ptx.sreg.ntid.x : i32 - %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1 - llvm.return -} diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir index 2bb90943d4ce1..d8cb9853f3374 100644 --- a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir @@ -112,3 +112,11 @@ llvm.func @mbarrier_arr_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %tx_count llvm.return } +// ----- + +llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) { + // expected-error @below {{mbarrier scope must be either CTA or Cluster}} + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +} + diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir new file mode 100644 index 0000000000000..21ab72eeab167 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr 
addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1 + 
llvm.return +} + +llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) { + // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1) + // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1) + // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1) + // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3) + // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope} : !llvm.ptr, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope} : !llvm.ptr, i32 -> i1 + llvm.return +} + +llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) { + // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) { + // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1) + // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, 
i32 %1) + // CHECK-NEXT: ret void + // CHECK-NEXT: } + %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1 + %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + + %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1 + %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1 + llvm.return +}