From 3dd6b91f6c14f76ec14b8b1a75e87c14ca0a9291 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 20 Nov 2025 11:34:49 -0800 Subject: [PATCH] Extract element count computation into helper function --- .../flang/Optimizer/Builder/CUFCommon.h | 4 ++ flang/lib/Optimizer/Builder/CUFCommon.cpp | 41 +++++++++++++++++++ .../Optimizer/Transforms/CUFOpConversion.cpp | 27 +----------- 3 files changed, 47 insertions(+), 25 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h index 6e2442745f9a0..98d01958846f7 100644 --- a/flang/include/flang/Optimizer/Builder/CUFCommon.h +++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h @@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type, fir::KindMapping &kindMap, bool emitErrorOnFailure = true); +mlir::Value computeElementCount(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value shapeOperand, + mlir::Type seqType, mlir::Type targetType); + } // namespace cuf #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_ diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp index 461deb8e4cb55..2266f4d47a0cf 100644 --- a/flang/lib/Optimizer/Builder/CUFCommon.cpp +++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp @@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type, mlir::emitError(loc, "unsupported type"); return 0; } + +mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter, + mlir::Location loc, + mlir::Value shapeOperand, + mlir::Type seqType, + mlir::Type targetType) { + if (shapeOperand) { + // Dynamic extent - extract from shape operand + llvm::SmallVector extents; + if (auto shapeOp = + mlir::dyn_cast(shapeOperand.getDefiningOp())) { + extents = shapeOp.getExtents(); + } else if (auto shapeShiftOp = mlir::dyn_cast( + shapeOperand.getDefiningOp())) { + for (auto i : llvm::enumerate(shapeShiftOp.getPairs())) + if (i.index() & 1) + extents.push_back(i.value()); + } + + if (extents.empty()) + return mlir::Value(); + + // Compute total element count by multiplying all dimensions + mlir::Value count = + fir::ConvertOp::create(rewriter, loc, targetType, extents[0]); + for (unsigned i = 1; i < extents.size(); ++i) { + auto operand = + fir::ConvertOp::create(rewriter, loc, targetType, extents[i]); + count = mlir::arith::MulIOp::create(rewriter, loc, count, operand); + } + return count; + } else { + // Static extent - use constant array size + if (auto seqTy = mlir::dyn_cast_or_null(seqType)) { + mlir::IntegerAttr attr = + rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize()); + return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr); + } + } + return mlir::Value(); +} diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 5b1b0a2f6feab..caf9b7b8b38f2 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion } mlir::Type i64Ty = builder.getI64Type(); - mlir::Value nbElement; - if (op.getShape()) { - llvm::SmallVector extents; - if (auto shapeOp = - mlir::dyn_cast(op.getShape().getDefiningOp())) { - extents = shapeOp.getExtents(); - } else if (auto shapeShiftOp = mlir::dyn_cast( - op.getShape().getDefiningOp())) { - for (auto i : llvm::enumerate(shapeShiftOp.getPairs())) - if (i.index() & 1) - extents.push_back(i.value()); - } - - nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]); - for (unsigned i = 1; i < extents.size(); ++i) { - auto operand = - fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]); - nbElement = - mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand); - } - } else { - if (auto seqTy = mlir::dyn_cast_or_null(dstTy)) - nbElement = builder.createIntegerConstant( - loc, i64Ty, seqTy.getConstantArraySize()); - } + mlir::Value nbElement = + cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty); unsigned width = 0; if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) { mlir::Type structTy =