-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[flang][cuda] Extract element count computation into helper function #168937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir. Author: Zhen Wang (wangzpgi). Changes: This patch extracts the common logic for computing array element counts from shape operands into a reusable helper function in CUFCommon. Full diff: https://github.com/llvm/llvm-project/pull/168937.diff. 3 files affected:
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 6e2442745f9a0..98d01958846f7 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type,
fir::KindMapping &kindMap,
bool emitErrorOnFailure = true);
+/// Emit IR computing the total number of array elements as a value of
+/// \p targetType.
+///
+/// When \p shapeOperand is set, the count is the product of the extents of
+/// the fir.shape / fir.shape_shift operation defining it, each converted to
+/// \p targetType. Otherwise the count is a constant taken from the constant
+/// array size of \p seqType when that type is a fir::SequenceType.
+///
+/// \returns the element count, or a null mlir::Value when no extents could be
+/// recovered from \p shapeOperand or \p seqType is not a sequence type.
+mlir::Value computeElementCount(mlir::PatternRewriter &rewriter,
+ mlir::Location loc, mlir::Value shapeOperand,
+ mlir::Type seqType, mlir::Type targetType);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index 461deb8e4cb55..2266f4d47a0cf 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
mlir::emitError(loc, "unsupported type");
return 0;
}
+
+/// Emit IR computing the total number of array elements as a value of
+/// \p targetType.
+///
+/// \param rewriter      Rewriter used to create the convert/multiply/constant
+///                      operations.
+/// \param loc           Location attached to every created operation.
+/// \param shapeOperand  Optional shape value. When set, the extents of its
+///                      defining fir.shape or fir.shape_shift are multiplied.
+/// \param seqType       Array type consulted only when \p shapeOperand is
+///                      null; its constant array size becomes the count.
+/// \param targetType    Integer type of the resulting count.
+/// \returns the element count, or a null mlir::Value when \p shapeOperand
+/// yields no extents (including when it has no defining operation) or when
+/// \p seqType is not a fir::SequenceType.
+mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
+                                     mlir::Location loc,
+                                     mlir::Value shapeOperand,
+                                     mlir::Type seqType,
+                                     mlir::Type targetType) {
+  if (!shapeOperand) {
+    // Static extent - use constant array size
+    if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType)) {
+      mlir::IntegerAttr attr =
+          rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
+      return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
+    }
+    return mlir::Value();
+  }
+
+  // Dynamic extent - extract from shape operand.
+  // Value::getDefiningOp<OpTy>() tolerates a value without a defining
+  // operation (e.g. a block argument); mlir::dyn_cast on the raw
+  // getDefiningOp() result must not be handed a null pointer.
+  llvm::SmallVector<mlir::Value> extents;
+  if (auto shapeOp = shapeOperand.getDefiningOp<fir::ShapeOp>()) {
+    extents = shapeOp.getExtents();
+  } else if (auto shapeShiftOp =
+                 shapeOperand.getDefiningOp<fir::ShapeShiftOp>()) {
+    // Only the odd positions of the pair list are collected as extents.
+    for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
+      if (i.index() & 1)
+        extents.push_back(i.value());
+  }
+
+  // Unsupported or missing shape producer: report failure to the caller.
+  if (extents.empty())
+    return mlir::Value();
+
+  // Compute total element count by multiplying all dimensions
+  mlir::Value count =
+      fir::ConvertOp::create(rewriter, loc, targetType, extents[0]);
+  for (unsigned i = 1; i < extents.size(); ++i) {
+    auto operand =
+        fir::ConvertOp::create(rewriter, loc, targetType, extents[i]);
+    count = mlir::arith::MulIOp::create(rewriter, loc, count, operand);
+  }
+  return count;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 5b1b0a2f6feab..caf9b7b8b38f2 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion
}
mlir::Type i64Ty = builder.getI64Type();
- mlir::Value nbElement;
- if (op.getShape()) {
- llvm::SmallVector<mlir::Value> extents;
- if (auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
- extents = shapeOp.getExtents();
- } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
- op.getShape().getDefiningOp())) {
- for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
- if (i.index() & 1)
- extents.push_back(i.value());
- }
-
- nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
- for (unsigned i = 1; i < extents.size(); ++i) {
- auto operand =
- fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
- nbElement =
- mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
- }
- } else {
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
- nbElement = builder.createIntegerConstant(
- loc, i64Ty, seqTy.getConstantArraySize());
- }
+ mlir::Value nbElement =
+ cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
unsigned width = 0;
if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
mlir::Type structTy =
|
clementval
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
🐧 Linux x64 Test Results
|
This patch extracts the common logic for computing array element counts from shape operands into a reusable helper function in CUFCommon.