Patch summary: add the `nvvm.mbarrier.try_wait` Op to the MLIR NVVM dialect.

Contents:
  * NVVMOps.td: definition of NVVM_MBarrierTryWaitOp (operands `addr`,
    `stateOrPhase`, optional `ticks`; attributes `scope`, `relaxed`).
  * NVVMDialect.cpp: MBarrierTryWaitOp::verify(), which forwards to
    verifyMBarrierArriveLikeOp, and MBarrierTryWaitOp::getIntrinsicIDAndArgs().
    The intrinsic is picked from an 8-entry table indexed by
    (hasTicks << 2) | (isClusterScope << 1) | isPhaseParity, with a parallel
    table for the `relaxed` variants. A barrier pointer in the generic address
    space is cast to the shared address space before the call.
  * mbar_invalid.mlir: two scope diagnostics for the new Op.
  * mbar_try_wait.mlir: translation tests covering state/phase, generic/shared
    pointer, with/without `ticks`, CTA/cluster scope, and `relaxed`.

NOTE(review): this patch text was recovered from a copy in which all line
breaks were collapsed and every angle-bracket argument that looked like a
markup tag had been dropped. The arguments were restored from context
(operand types used by the tests, enclosing method names, CHECK lines).
Two things could not be recovered from the text itself and should be
confirmed against the original change:
  * the out-of-range scope used in the two mbar_invalid.mlir tests is assumed
    to be `sys`; presumably any scope other than cta/cluster produces the
    same diagnostic;
  * exact leading whitespace was re-derived from LLVM/MLIR formatting
    conventions.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..308bae21d98e9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1130,6 +1130,47 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> {
   }];
 }
 
+def NVVM_MBarrierTryWaitOp : NVVM_Op<"mbarrier.try_wait"> {
+  let summary = "MBarrier try wait on state or phase with an optional timelimit";
+  let description = [{
+    The `nvvm.mbarrier.try_wait` operation checks whether the specified
+    *mbarrier object* at `addr` has completed the given phase. Note that
+    unlike the `nvvm.mbarrier.test.wait` operation, the try_wait operation
+    is a potentially-blocking one. If the phase is not yet complete, the
+    calling thread may be suspended. A suspended thread resumes execution
+    once the phase completes or when a system-defined timeout occurs.
+    Optionally, the `ticks` operand can be used to provide a custom timeout
+    (in nanoseconds), overriding the system-defined one. The semantics of
+    this operation and its operands are otherwise similar to those of the
+    `nvvm.mbarrier.test.wait` Op.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
+  }];
+
+  let results = (outs I1:$res);
+  let arguments = (ins
+    AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+    AnyTypeOf<[I64, I32]>:$stateOrPhase,
+    Optional<I32>:$ticks,
+    DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope,
+    DefaultValuedAttr<BoolAttr, "false">:$relaxed);
+
+  let assemblyFormat = "$addr `,` $stateOrPhase (`,` $ticks^)? attr-dict `:` type(operands) `->` type($res)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::MBarrierTryWaitOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ada4223ac12de..cb0d70361aec9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -314,6 +314,10 @@ LogicalResult MBarrierTestWaitOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
 
+LogicalResult MBarrierTryWaitOp::verify() {
+  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -2752,6 +2756,56 @@ mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
   return {id, {mbar, input}};
 }
 
+mlir::NVVM::IDArgPair MBarrierTryWaitOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::MBarrierTryWaitOp>(op);
+  bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+  bool hasTicks = static_cast<bool>(thisOp.getTicks());
+  // bit-0: isPhaseParity
+  // bit-1: Scope
+  // bit-2: hasTicks
+  size_t index = ((hasTicks ? 1 : 0) << 2) | ((isClusterScope ? 1 : 0) << 1) |
+                 (isPhaseParity ? 1 : 0);
+
+  // clang-format off
+  static constexpr llvm::Intrinsic::ID IDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_scope_cluster_space_cta};
+  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_relaxed_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_relaxed_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_tl_relaxed_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_try_wait_parity_tl_relaxed_scope_cluster_space_cta};
+  // clang-format on
+  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+  // Tidy-up the mbarrier pointer
+  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+  if (needCast)
+    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+  // Fill the Intrinsic Args
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mbar);
+  args.push_back(mt.lookupValue(thisOp.getStateOrPhase()));
+  if (hasTicks)
+    args.push_back(mt.lookupValue(thisOp.getTicks()));
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncMBarrierArriveOp>(op);
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
index d8cb9853f3374..4a7776d86b28e 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
@@ -120,3 +120,19 @@ llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
   llvm.return
 }
 
+// -----
+
+llvm.func @mbarrier_try_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+  // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}
+
+// -----
+
+llvm.func @mbarrier_try_wait_with_timelimit(%barrier: !llvm.ptr<3>, %phase: i32, %ticks: i32) {
+  // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32, i32 -> i1
+  llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir
new file mode 100644
index 0000000000000..18aaf0e451e20
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_try_wait.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_try_wait_state(%barrier: !llvm.ptr, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_state(ptr %0, i64 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr, i64 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_state_with_timelimit(%barrier: !llvm.ptr, %state : i64, %ticks : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_state_with_timelimit(ptr %0, i64 %1, i32 %2) {
+  // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %4, i64 %1, i32 %2)
+  // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i64 %1, i32 %2)
+  // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i64 %1, i32 %2)
+  // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i64 %1, i32 %2)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr, i64, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr, i64, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64, i32 -> i1
+
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_state_with_timelimit(%barrier: !llvm.ptr<3>, %state : i64, %ticks : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_shared_state_with_timelimit(ptr addrspace(3) %0, i64 %1, i32 %2) {
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+  // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1, i32 %2)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %state, %ticks : !llvm.ptr<3>, i64, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true} : !llvm.ptr<3>, i64, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %state, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_phase(ptr %0, i32 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_phase_with_timelimit(%barrier: !llvm.ptr, %phase : i32, %ticks : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_phase_with_timelimit(ptr %0, i32 %1, i32 %2) {
+  // CHECK-NEXT: %4 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %4, i32 %1, i32 %2)
+  // CHECK-NEXT: %6 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %6, i32 %1, i32 %2)
+  // CHECK-NEXT: %8 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %9 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %8, i32 %1, i32 %2)
+  // CHECK-NEXT: %10 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %11 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %10, i32 %1, i32 %2)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr, i32, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_try_wait_shared_phase_with_timelimit(%barrier: !llvm.ptr<3>, %phase : i32, %ticks : i32) {
+  // CHECK-LABEL: define void @mbarrier_try_wait_shared_phase_with_timelimit(ptr addrspace(3) %0, i32 %1, i32 %2) {
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+  // CHECK-NEXT: %7 = call i1 @llvm.nvvm.mbarrier.try.wait.parity.tl.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1, i32 %2)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+  %1 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1
+
+  %2 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true} : !llvm.ptr<3>, i32, i32 -> i1
+  %3 = nvvm.mbarrier.try_wait %barrier, %phase, %ticks {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32, i32 -> i1
+  llvm.return
+}