diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 1cc5b74a3cb67..13a2af0efe87d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">, } def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">, - Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> { + Arguments<(ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + I32:$txcount, PtxPredicate:$predicate)> { let summary = "MBarrier Arrive with Expected Transaction Count"; let description = [{ The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation @@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive) }]; let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); } - }]; -} - -def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">, - Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> { - let summary = "Shared MBarrier Arrive with Expected Transaction Count"; - let description = [{ - This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object* - should be accessed using a shared-memory pointer instead of a generic-memory pointer. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive) - }]; - let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); } - }]; } def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">, - Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> { + Arguments<(ins + AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, + I32:$phase, I32:$ticks)> { let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity"; let description = [{ The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking @@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity" [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) }]; let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); - } - }]; -} - -def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">, - Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> { - let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity"; - let description = [{ - This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object* - should be accessed using a shared-memory pointer instead of a generic-memory pointer. - - [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait) - }]; - let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); - } - }]; } def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 9348d3c172a07..3a70f787da124 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp( - op, barrier, txcount, adaptor.getPredicate()); - return success(); - } - rewriter.replaceOpWithNewOp( op, barrier, txcount, adaptor.getPredicate()); return success(); @@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity()); - - if (isMbarrierShared(op.getBarriers().getType())) { - rewriter.replaceOpWithNewOp( - op, barrier, phase, ticks); - return success(); - } - rewriter.replaceOpWithNewOp(op, barrier, phase, ticks); return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index d43f8815be16d..e0c25ab6cdef7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -47,6 +47,19 @@ using namespace NVVM; static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic; +//===----------------------------------------------------------------------===// +// Helper/Utility methods +//===----------------------------------------------------------------------===// + +static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { + auto ptrTy = llvm::cast(ptr.getType()); + return ptrTy.getAddressSpace() == static_cast(targetAS); +} + +static bool isPtrInSharedCTASpace(mlir::Value ptr) { + return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +} + //===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -1741,26 +1754,37 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, //===----------------------------------------------------------------------===// std::string NVVM::MBarrierInitOp::getPtx() { - unsigned addressSpace = - llvm::cast(getAddr().getType()).getAddressSpace(); - return (addressSpace == NVVMMemorySpace::Shared) - ? std::string("mbarrier.init.shared.b64 [%0], %1;") - : std::string("mbarrier.init.b64 [%0], %1;"); + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); } -//===----------------------------------------------------------------------===// -// getIntrinsicID/getIntrinsicIDAndArgs methods -//===----------------------------------------------------------------------===// - -static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) { - auto ptrTy = llvm::cast(ptr.getType()); - return ptrTy.getAddressSpace() == static_cast(targetAS); +std::string NVVM::MBarrierArriveExpectTxOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + return isShared + ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;") + : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); } -static bool isPtrInSharedCTASpace(mlir::Value ptr) { - return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared); +std::string NVVM::MBarrierTryWaitParityOp::getPtx() { + bool isShared = isPtrInSharedCTASpace(getAddr()); + llvm::StringRef space = isShared ? ".shared" : ""; + + return llvm::formatv("{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}", + space); } +//===----------------------------------------------------------------------===// +// getIntrinsicID/getIntrinsicIDAndArgs methods +//===----------------------------------------------------------------------===// + mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index dcf4ddb2dd48c..0eb44789fe31d 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -603,14 +603,14 @@ func.func @mbarrier_txcount() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } else { %txcount = arith.constant 0 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType scf.yield } @@ -620,7 +620,7 @@ func.func @mbarrier_txcount() { %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return @@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() { %txcount = arith.constant 256 : index // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]] + // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType %phase_c0 = arith.constant 0 : i1 %ticks = arith.constant 10000000 : index // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]] nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType func.return diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index a9356c5cb60bb..a94fcb4856db4 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b" - nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 llvm.return } @@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, // CHECK-SAME: DONE: // CHECK-SAME: }", // CHECK-SAME: "r,r,r" - nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 + nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 llvm.return }