Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 6 additions & 60 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
}

def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
Arguments<(ins
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
I32:$txcount, PtxPredicate:$predicate)> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
Expand Down Expand Up @@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
}];
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
}];
}

// Shared-pointer-only counterpart of `nvvm.mbarrier.arrive.expect_tx`: `$addr`
// is constrained to `LLVM_PointerShared`, and the PTX template below uses the
// `.shared` spelling of the instruction. The assembly format is otherwise the
// same (`$addr, $txcount` with an optional `predicate = ...` clause).
def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
let summary = "Shared MBarrier Arrive with Expected Transaction Count";
let description = [{
This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
should be accessed using a shared-memory pointer instead of a generic-memory pointer.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
}];
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
// PTX template returned by getPtx(); `%0` and `%1` are operand placeholders
// (presumably $addr and $txcount in argument order -- confirm against the
// PTX builder).
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
}];
}

def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
Arguments<(ins
AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
I32:$phase, I32:$ticks)> {
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
let description = [{
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
Expand Down Expand Up @@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
return std::string(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
);
}
}];
}

// Shared-pointer-only counterpart of `nvvm.mbarrier.try_wait.parity`: `$addr`
// is constrained to `LLVM_PointerShared`, and the PTX template below uses the
// `.shared` spelling of the instruction.
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
let description = [{
This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
should be accessed using a shared-memory pointer instead of a generic-memory pointer.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
// PTX template returned by getPtx(): a scoped block that declares predicate
// P1 and loops -- issue try_wait into P1, branch to DONE when P1 is set,
// otherwise branch back to LAB_WAIT. `%0`..`%2` are operand placeholders
// (presumably $addr, $phase, $ticks in argument order -- confirm against the
// PTX builder).
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
return std::string(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
);
}
}];
}

def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
Expand Down
14 changes: 0 additions & 14 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());

if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
op, barrier, txcount, adaptor.getPredicate());
return success();
}

rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
op, barrier, txcount, adaptor.getPredicate());
return success();
Expand All @@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());

if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
op, barrier, phase, ticks);
return success();
}

rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
Expand Down
52 changes: 38 additions & 14 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ using namespace NVVM;

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;

//===----------------------------------------------------------------------===//
// Helper/Utility methods
//===----------------------------------------------------------------------===//

/// Returns true when the LLVM pointer type of `ptr` is in address space
/// `targetAS`. `ptr` must have `LLVM::LLVMPointerType`; `llvm::cast` asserts
/// otherwise.
static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
  const unsigned wantedAS = static_cast<unsigned>(targetAS);
  const unsigned actualAS =
      llvm::cast<LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace();
  return actualAS == wantedAS;
}

/// Convenience wrapper: true when `ptr` is in `NVVMMemorySpace::Shared`.
static bool isPtrInSharedCTASpace(mlir::Value ptr) {
  constexpr NVVMMemorySpace sharedAS = NVVMMemorySpace::Shared;
  return isPtrInAddrSpace(ptr, sharedAS);
}

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1741,26 +1754,37 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//

/// Returns the inline-PTX template for `nvvm.mbarrier.init`.
///
/// The `.shared` form of the instruction is selected when the barrier address
/// is a pointer in the shared address space; any other pointer gets the form
/// without a state-space qualifier. `%0` and `%1` are operand placeholders.
std::string NVVM::MBarrierInitOp::getPtx() {
  // Single address-space query through the shared helper, rather than
  // re-deriving the pointer type inline.
  bool isShared = isPtrInSharedCTASpace(getAddr());
  return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
                  : std::string("mbarrier.init.b64 [%0], %1;");
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
bool isShared = isPtrInSharedCTASpace(getAddr());
return isShared
? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
: std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}

static bool isPtrInSharedCTASpace(mlir::Value ptr) {
return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
bool isShared = isPtrInSharedCTASpace(getAddr());
llvm::StringRef space = isShared ? ".shared" : "";

return llvm::formatv("{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity{0}.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}",
space);
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
}
Expand All @@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

func.return
Expand Down Expand Up @@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
// CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType

%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
// CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

func.return
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
// Checks the inline asm produced for arrive.expect_tx on a shared-memory
// (`!llvm.ptr<3>`) barrier, first without and then with a predicate.
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
// Unpredicated form: two register operands ("r,r").
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
// Predicated form: the predicate is appended as operand $2 (constraint "b")
// and guards the instruction via the `@$2` prefix.
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}

Expand All @@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
}

Expand Down
Loading