-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][NVVM] Update mbarrier Ops to use AnyTypeOf[] (3/3) #167567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update mbarrier Ops to use AnyTypeOf[] (3/3) #167567
Conversation
|
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Durgadoss R (durga4github) Changes: This is a follow-up of PR #165558 and #165993. This patch updates the remaining two Ops to use the AnyTypeOf[] construct, completing the migration for the mbarrier family of Ops. Full diff: https://github.com/llvm/llvm-project/pull/167567.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1cc5b74a3cb67..13a2af0efe87d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
}
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$txcount, PtxPredicate:$predicate)> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
}];
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
- }];
-}
-
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
- let summary = "Shared MBarrier Arrive with Expected Transaction Count";
- let description = [{
- This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
- }];
- let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
- }];
}
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$phase, I32:$ticks)> {
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
let description = [{
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
@@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
-}
-
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
- let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
- let description = [{
- This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
- }];
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
}
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9348d3c172a07..3a70f787da124 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
op, barrier, txcount, adaptor.getPredicate());
return success();
@@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d43f8815be16d..d3bbda175093d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -47,6 +47,19 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1741,26 +1754,38 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//
std::string NVVM::MBarrierInitOp::getPtx() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
- return (addressSpace == NVVMMemorySpace::Shared)
- ? std::string("mbarrier.init.shared.b64 [%0], %1;")
- : std::string("mbarrier.init.b64 [%0], %1;");
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
}
-//===----------------------------------------------------------------------===//
-// getIntrinsicID/getIntrinsicIDAndArgs methods
-//===----------------------------------------------------------------------===//
-
-static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
- auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
- return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}
-static bool isPtrInSharedCTASpace(mlir::Value ptr) {
- return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ std::string space = isShared ? ".shared" : "";
+
+ return "{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity" +
+ space +
+ ".b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}";
}
+//===----------------------------------------------------------------------===//
+// getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dcf4ddb2dd48c..0eb44789fe31d 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9356c5cb60bb..a94fcb4856db4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}
@@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
}
|
|
@llvm/pr-subscribers-mlir-gpu Author: Durgadoss R (durga4github) Changes: This is a follow-up of PR #165558 and #165993. This patch updates the remaining two Ops to use the AnyTypeOf[] construct, completing the migration for the mbarrier family of Ops. Full diff: https://github.com/llvm/llvm-project/pull/167567.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1cc5b74a3cb67..13a2af0efe87d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
}
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$txcount, PtxPredicate:$predicate)> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
}];
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
- }];
-}
-
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
- let summary = "Shared MBarrier Arrive with Expected Transaction Count";
- let description = [{
- This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
- }];
- let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
- }];
}
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$phase, I32:$ticks)> {
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
let description = [{
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
@@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
-}
-
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
- let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
- let description = [{
- This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
- }];
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
}
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9348d3c172a07..3a70f787da124 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
op, barrier, txcount, adaptor.getPredicate());
return success();
@@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d43f8815be16d..d3bbda175093d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -47,6 +47,19 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1741,26 +1754,38 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//
std::string NVVM::MBarrierInitOp::getPtx() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
- return (addressSpace == NVVMMemorySpace::Shared)
- ? std::string("mbarrier.init.shared.b64 [%0], %1;")
- : std::string("mbarrier.init.b64 [%0], %1;");
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
}
-//===----------------------------------------------------------------------===//
-// getIntrinsicID/getIntrinsicIDAndArgs methods
-//===----------------------------------------------------------------------===//
-
-static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
- auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
- return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}
-static bool isPtrInSharedCTASpace(mlir::Value ptr) {
- return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ std::string space = isShared ? ".shared" : "";
+
+ return "{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity" +
+ space +
+ ".b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}";
}
+//===----------------------------------------------------------------------===//
+// getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dcf4ddb2dd48c..0eb44789fe31d 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9356c5cb60bb..a94fcb4856db4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}
@@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
}
|
This patch updates the remaining two Ops to use the AnyTypeOf[] construct, completing the migration for the mbarrier family of Ops. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1d721e5 to
8efa918
Compare
This is a follow-up of PR llvm#165558 and llvm#165993.

This patch updates the remaining two Ops to use the AnyTypeOf[] construct, completing the migration for the mbarrier family of Ops.

```
mbarrier.arrive.expect_tx
mbarrier.try_wait.parity
```

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
This is a follow-up of PR #165558 and #165993.
This patch updates the remaining two Ops to use the AnyTypeOf[]
construct, completing the migration for the mbarrier family of Ops.