diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 2df9349269a69..fd4a6c5897364 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -6542,7 +6542,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType, mlir::Value arg1 = builder.create(loc, builder.getI1Type(), args[1]); return builder - .create(loc, resultType, args[0], arg1) + .create(loc, resultType, args[0], arg1, + mlir::NVVM::VoteSyncKind::ballot) .getResult(); } diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index a7f9038761b51..7d6d920dfb2e8 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -303,7 +303,7 @@ end subroutine ! CHECK-LABEL: func.func @_QPtestvote() ! CHECK: fir.call @llvm.nvvm.vote.all.sync ! CHECK: fir.call @llvm.nvvm.vote.any.sync -! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32 +! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32 ! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref>, !fir.ref>) ! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref>, !fir.ref>) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8a54804b220a1..0a6e66919f021 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -808,15 +808,49 @@ def NVVM_ShflOp : let hasVerifier = 1; } -def NVVM_VoteBallotOp : - NVVM_Op<"vote.ballot.sync">, - Results<(outs LLVM_Type:$res)>, - Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> { +def VoteSyncKindAny : I32EnumAttrCase<"any", 0>; +def VoteSyncKindAll : I32EnumAttrCase<"all", 1>; +def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>; +def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>; + +def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind", + [VoteSyncKindAny, VoteSyncKindAll, + VoteSyncKindBallot, VoteSyncKindUni]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} + +def VoteSyncKindAttr : EnumAttr; + +def NVVM_VoteSyncOp + : NVVM_Op<"vote.sync">, + Results<(outs AnyTypeOf<[I32, I1]>:$res)>, + Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> { + let summary = "Vote across thread group"; + let description = [{ + The `vote.sync` op will cause executing thread to wait until all non-exited + threads corresponding to membermask have executed `vote.sync` with the same + qualifiers and same membermask value before resuming execution. + + The vote operation kinds are: + - `any`: True if source predicate is True for some thread in membermask. + - `all`: True if source predicate is True for all non-exited threads in + membermask. + - `uni`: True if source predicate has the same value in all non-exited + threads in membermask. + - `ballot`: In the ballot form, the destination result is a 32 bit integer. + In this form, the predicate from each thread in membermask are copied into + the corresponding bit position of the result, where the bit position + corresponds to the thread’s lane id. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync) + }]; string llvmBuilder = [{ - $res = createIntrinsicCall(builder, - llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred}); + auto intId = getVoteSyncIntrinsicId($kind); + $res = createIntrinsicCall(builder, intId, {$mask, $pred}); }]; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)"; + let hasVerifier = 1; } def NVVM_SyncWarpOp : diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 556114f4370b3..09bff6101edd3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -48,34 +48,6 @@ using namespace NVVM; #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" -//===----------------------------------------------------------------------===// -// Printing/parsing for NVVM ops -//===----------------------------------------------------------------------===// - -static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { - p << " " << op->getOperands(); - if (op->getNumResults() > 0) - p << " : " << op->getResultTypes(); -} - -// ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type -ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) { - MLIRContext *context = parser.getContext(); - auto int32Ty = IntegerType::get(context, 32); - auto int1Ty = IntegerType::get(context, 1); - - SmallVector ops; - Type type; - return failure(parser.parseOperandList(ops) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.addTypeToList(type, result.types) || - parser.resolveOperands(ops, {int32Ty, int1Ty}, - parser.getNameLoc(), result.operands)); -} - -void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); } - //===----------------------------------------------------------------------===// // Verifier methods //===----------------------------------------------------------------------===// @@ -1160,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() { return success(); } +LogicalResult NVVM::VoteSyncOp::verify() { + if (getKind() == NVVM::VoteSyncKind::ballot) { + if (!getType().isInteger(32)) { + return emitOpError("vote.sync 'ballot' returns an i32"); + } + } else { + if (!getType().isInteger(1)) { + return emitOpError("vote.sync 'any', 'all' and 'uni' returns an i1"); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 9d14ff09ab434..beff90237562d 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, } } +static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { + switch (kind) { + case NVVM::VoteSyncKind::any: + return llvm::Intrinsic::nvvm_vote_any_sync; + case NVVM::VoteSyncKind::all: + return llvm::Intrinsic::nvvm_vote_all_sync; + case NVVM::VoteSyncKind::ballot: + return llvm::Intrinsic::nvvm_vote_ballot_sync; + case NVVM::VoteSyncKind::uni: + return llvm::Intrinsic::nvvm_vote_uni_sync; + default: + llvm_unreachable("unsupported vote kind"); + } +} + /// Return the intrinsic ID associated with ldmatrix for the given paramters. static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num) { diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 18bf39424f0bf..d3915492c38a0 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -129,8 +129,14 @@ func.func @nvvm_shfl_pred( // CHECK-LABEL: @nvvm_vote( func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 { - // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32 - %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32 + // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32 + %0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32 + // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1 + %1 = nvvm.vote.sync all %arg0, %arg1 -> i1 + // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1 + %2 = nvvm.vote.sync any %arg0, %arg1 -> i1 + // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1 + %3 = nvvm.vote.sync uni %arg0, %arg1 -> i1 llvm.return %0 : i32 } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index c3ec88db1d694..3a0713f2feee8 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred( // CHECK-LABEL: @nvvm_vote llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 { // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}}) - %3 = nvvm.vote.ballot.sync %0, %1 : i32 + %3 = nvvm.vote.sync ballot %0, %1 -> i32 + // CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}}) + %4 = nvvm.vote.sync all %0, %1 -> i1 + // CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}}) + %5 = nvvm.vote.sync any %0, %1 -> i1 + // CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}}) + %6 = nvvm.vote.sync uni %0, %1 -> i1 llvm.return %3 : i32 }