diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index f8e3167b42c35..bfbf8edef58f8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3256,35 +3256,15 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp : attr-dict `:` type($dstMem) `,` type($srcMem) }]; + let extraClassDeclaration = [{ + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]; string llvmBuilder = [{ - // Arguments to the intrinsic: - // dst, mbar, src, size - // multicast_mask, cache_hint, - // flag for multicast_mask, - // flag for cache_hint - llvm::SmallVector<llvm::Value *> translatedOperands; - translatedOperands.push_back($dstMem); - translatedOperands.push_back($mbar); - translatedOperands.push_back($srcMem); - translatedOperands.push_back($size); - - // Multicast, if available - llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); - auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0); - bool isMulticast = op.getMulticastMask() ? true : false; - translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused); - - // Cachehint, if available - auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); - bool isCacheHint = op.getL2CacheHint() ? true : false; - translatedOperands.push_back(isCacheHint ? 
$l2CacheHint : i64Unused); - - // Flag arguments for multicast and cachehint - translatedOperands.push_back(builder.getInt1(isMulticast)); - translatedOperands.push_back(builder.getInt1(isCacheHint)); - - createIntrinsicCall(builder, - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands); + auto [id, args] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e8f8824d47de0..81bf99476d092 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1555,6 +1555,39 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } +mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args: dst, mbar, src, size. + args.push_back(mt.lookupValue(thisOp.getDstMem())); + args.push_back(mt.lookupValue(thisOp.getMbar())); + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getSize())); + + // Multicast mask, if available. + mlir::Value multicastMask = thisOp.getMulticastMask(); + const bool hasMulticastMask = static_cast<bool>(multicastMask); + llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0); + args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused); + + // Cache hint, if available. + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = llvm::ConstantInt::get(builder.getInt64Ty(), 0); + args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); + + // Flag arguments for multicast and cachehint. 
+ args.push_back(builder.getInt1(hasMulticastMask)); + args.push_back(builder.getInt1(hasCacheHint)); + + llvm::Intrinsic::ID id = + llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster; + + return {id, std::move(args)}; +} + mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);