-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[NFC][MLIR] Refactor NVVM_CpAsyncBulkGlobalToSharedClusterOp's lowering #162611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NFC][MLIR] Refactor NVVM_CpAsyncBulkGlobalToSharedClusterOp's lowering #162611
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Dharuni R Acharya (DharuniRAcharya) ChangesThis patch refactors Full diff: https://github.com/llvm/llvm-project/pull/162611.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..bfbf8edef58f8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3256,35 +3256,15 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
attr-dict `:` type($dstMem) `,` type($srcMem)
}];
+ let extraClassDeclaration = [{
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }];
string llvmBuilder = [{
- // Arguments to the intrinsic:
- // dst, mbar, src, size
- // multicast_mask, cache_hint,
- // flag for multicast_mask,
- // flag for cache_hint
- llvm::SmallVector<llvm::Value *> translatedOperands;
- translatedOperands.push_back($dstMem);
- translatedOperands.push_back($mbar);
- translatedOperands.push_back($srcMem);
- translatedOperands.push_back($size);
-
- // Multicast, if available
- llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
- auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0);
- bool isMulticast = op.getMulticastMask() ? true : false;
- translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused);
-
- // Cachehint, if available
- auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
- bool isCacheHint = op.getL2CacheHint() ? true : false;
- translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-
- // Flag arguments for multicast and cachehint
- translatedOperands.push_back(builder.getInt1(isMulticast));
- translatedOperands.push_back(builder.getInt1(isCacheHint));
-
- createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands);
+ auto [id, args] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..7d2e1b0983e11 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1555,6 +1555,41 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkGlobalToSharedClusterOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args: dst, mbar, src, size.
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getSize()));
+
+ // Multicast mask, if available.
+ mlir::Value multicastMask = thisOp.getMulticastMask();
+ const bool hasMulticastMask = static_cast<bool>(multicastMask);
+ llvm::Value *i16Unused =
+ llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+ args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+
+ // Cache hint, if available.
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(builder.getInt64Ty(), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+ // Flag arguments for multicast and cachehint.
+ args.push_back(builder.getInt1(hasMulticastMask));
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ llvm::Intrinsic::ID id =
+ llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch refactors NVVM_CpAsyncBulkGlobalToSharedClusterOp's lowering by implementing getIntrinsicIDAndArgs for this op. Signed-off-by: Dharuni R Acharya <dharunira@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thanks for the clean-up!
9627b5c
to
1d7d83b
Compare
This patch moves the lowering code of the
NVVM_CpAsyncBulkGlobalToSharedClusterOp
fromtd
file to NVVMDialect.cpp file. This makes it consistent with the lowering of other TMA Ops in NVVM Dialect.