New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flang] Generate inline reduction loops for elemental count intrinsics #75774
Conversation
@llvm/pr-subscribers-flang-fir-hlfir Author: David Green (davemgreen) ChangesThis adds a ReductionElementalConversion transform to OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and generating the inline loop to perform the count of true elements. This lets us generate a single loop instead of ending up as two plus a temporary. This is currently part of OptimizedBufferization, similar to #74828. I attempted to move it to LowerHLFIRIntrinsics to make it part of the existing lowering, but it hit problems with inlining elementals that contain operations that are being legalized by the same pass. Any and All should be able to share the same code with a different function/initial value. Patch is 26.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75774.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7abfa20493c736..59a3a433d5d182 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -659,6 +659,124 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
return mlir::success();
}
+using GenBodyFn =
+ std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
+ const llvm::SmallVectorImpl<mlir::Value> &)>;
+static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
+ mlir::Location loc, mlir::Value init,
+ mlir::Value shape, GenBodyFn genBody) {
+ auto extents = hlfir::getIndexExtents(loc, builder, shape);
+ mlir::Value reduction = init;
+ mlir::IndexType idxTy = builder.getIndexType();
+ mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+ // Create a reduction loop nest. We use one-based indices so that they can be
+ // passed to the elemental.
+ llvm::SmallVector<mlir::Value> indices;
+ for (unsigned i = 0; i < extents.size(); ++i) {
+ auto loop =
+ builder.create<fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false,
+ /*finalCountValue=*/false, reduction);
+ reduction = loop.getRegionIterArgs()[0];
+ indices.push_back(loop.getInductionVar());
+ // Set insertion point to the loop body so that the next loop
+ // is inserted inside the current one.
+ builder.setInsertionPointToStart(loop.getBody());
+ }
+
+ // Generate the body
+ reduction = genBody(builder, loc, reduction, indices);
+
+ // Unwind the loop nest.
+ for (unsigned i = 0; i < extents.size(); ++i) {
+ auto result = builder.create<fir::ResultOp>(loc, reduction);
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+ reduction = loop.getResult(0);
+ // Set insertion point after the loop operation that we have
+ // just processed.
+ builder.setInsertionPointAfter(loop.getOperation());
+ }
+
+ return reduction;
+}
+
+/// Given a reduction operation with an elemental mask, attempt to generate a
+/// do-loop to perform the operation inline.
+/// %e = hlfir.elemental %shape unordered
+/// %r = hlfir.count %e
+/// =>
+/// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init) {
+/// <inline elemental>
+/// <reduce count>
+/// }
+template <typename Op>
+class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
+public:
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+ mlir::Location loc = op.getLoc();
+ hlfir::ElementalOp elemental =
+ op.getMask().template getDefiningOp<hlfir::ElementalOp>();
+ if (!elemental || op.getDim())
+ return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
+
+ fir::KindMapping kindMap =
+ fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
+ fir::FirOpBuilder builder{op, kindMap};
+
+ mlir::Value init;
+ GenBodyFn genBodyFn;
+ if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
+ init = builder.createIntegerConstant(loc, op.getType(), 0);
+ genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
+ mlir::Value reduction,
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
+ -> mlir::Value {
+ // Inline the elemental and get the condition from it.
+ auto yield = inlineElementalOp(loc, builder, elemental, indices);
+ mlir::Value cond = builder.create<fir::ConvertOp>(
+ loc, builder.getI1Type(), yield.getElementValue());
+ yield->erase();
+
+ // Conditionally add one to the current value
+ mlir::Value one =
+ builder.createIntegerConstant(loc, reduction.getType(), 1);
+ mlir::Value add1 =
+ builder.create<mlir::arith::AddIOp>(loc, reduction, one);
+ return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
+ reduction);
+ };
+ } else {
+ static_assert("Expected Op to be handled");
+ return mlir::failure();
+ }
+
+ mlir::Value res = generateReductionLoop(builder, loc, init,
+ elemental.getOperand(0), genBodyFn);
+ if (res.getType() != op.getType())
+ res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
+
+ // Check if the op was the only user of the elemental (apart from a
+ // destroy), and remove it if so.
+ mlir::Operation::user_range elemUsers = elemental->getUsers();
+ hlfir::DestroyOp elemDestroy;
+ if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
+ if (!elemDestroy)
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
+ }
+
+ rewriter.replaceOp(op, res);
+ if (elemDestroy) {
+ rewriter.eraseOp(elemDestroy);
+ rewriter.eraseOp(elemental);
+ }
+ return mlir::success();
+ }
+};
+
class OptimizedBufferizationPass
: public hlfir::impl::OptimizedBufferizationBase<
OptimizedBufferizationPass> {
@@ -681,6 +799,7 @@ class OptimizedBufferizationPass
patterns.insert<ElementalAssignBufferization>(context);
patterns.insert<BroadcastAssignBufferization>(context);
patterns.insert<VariableAssignBufferization>(context);
+ patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/count-elemental.fir b/flang/test/HLFIR/count-elemental.fir
new file mode 100644
index 00000000000000..2be0eb17de1d88
--- /dev/null
+++ b/flang/test/HLFIR/count-elemental.fir
@@ -0,0 +1,314 @@
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c7 = arith.constant 7 : index
+ %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+ %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+ %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %6 = fir.load %2#0 : !fir.ref<i32>
+ %7 = fir.convert %6 : (i32) -> i64
+ %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+ %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1) shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+ %10 = fir.load %5#0 : !fir.ref<i32>
+ %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+ ^bb0(%arg3: index):
+ %14 = hlfir.designate %9 (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+ %15 = fir.load %14 : !fir.ref<i32>
+ %16 = arith.cmpi sge, %15, %10 : i32
+ %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+ hlfir.yield_element %17 : !fir.logical<4>
+ }
+ %12 = hlfir.count %11 : (!hlfir.expr<7x!fir.logical<4>>) -> i32
+ hlfir.assign %12 to %4#0 : i32, !fir.ref<i32>
+ hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+ %13 = fir.load %4#1 : !fir.ref<i32>
+ return %13 : i32
+}
+// CHECK-LABEL: func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c4 = arith.constant 4 : index
+// CHECK-NEXT: %c7 = arith.constant 7 : index
+// CHECK-NEXT: %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT: %[[V3:.*]] = fir.alloca i32
+// CHECK-NEXT: %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT: %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT: %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT: %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT: %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT: %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1) shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT: %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[V9]] (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT: %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT: %[[V15:.*]] = arith.cmpi sge, %[[V14]], %[[V10]] : i32
+// CHECK-NEXT: %[[V16:.*]] = arith.addi %arg4, %c1_i32 : i32
+// CHECK-NEXT: %[[V17:.*]] = arith.select %[[V15]], %[[V16]], %arg4 : i32
+// CHECK-NEXT: fir.result %[[V17]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: hlfir.assign %[[V11]] to %[[V4]]#0 : i32, !fir.ref<i32>
+// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V4]]#1 : !fir.ref<i32>
+// CHECK-NEXT: return %[[V12]] : i32
+
+func.func @_QFPtest_kind2(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i16 {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c7 = arith.constant 7 : index
+ %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+ %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.alloca i16 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+ %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i16>) -> (!fir.ref<i16>, !fir.ref<i16>)
+ %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %6 = fir.load %2#0 : !fir.ref<i32>
+ %7 = fir.convert %6 : (i32) -> i64
+ %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+ %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1) shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+ %10 = fir.load %5#0 : !fir.ref<i32>
+ %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+ ^bb0(%arg3: index):
+ %14 = hlfir.designate %9 (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+ %15 = fir.load %14 : !fir.ref<i32>
+ %16 = arith.cmpi sge, %15, %10 : i32
+ %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+ hlfir.yield_element %17 : !fir.logical<4>
+ }
+ %12 = hlfir.count %11 : (!hlfir.expr<7x!fir.logical<4>>) -> i16
+ hlfir.assign %12 to %4#0 : i16, !fir.ref<i16>
+ hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+ %13 = fir.load %4#1 : !fir.ref<i16>
+ return %13 : i16
+}
+// CHECK-LABEL: func.func @_QFPtest_kind2(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i16 {
+// CHECK-NEXT: %c1_i16 = arith.constant 1 : i16
+// CHECK-NEXT: %c0_i16 = arith.constant 0 : i16
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c4 = arith.constant 4 : index
+// CHECK-NEXT: %c7 = arith.constant 7 : index
+// CHECK-NEXT: %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT: %[[V3:.*]] = fir.alloca i16
+// CHECK-NEXT: %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT: %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT: %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT: %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT: %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT: %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1) shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT: %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %c0_i16) -> (i16) {
+// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[V9]] (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT: %[[V14:.*]] = fir.load %[[V13]] : !fir.ref<i32>
+// CHECK-NEXT: %[[V15:.*]] = arith.cmpi sge, %[[V14]], %[[V10]] : i32
+// CHECK-NEXT: %[[V16:.*]] = arith.addi %arg4, %c1_i16 : i16
+// CHECK-NEXT: %[[V17:.*]] = arith.select %[[V15]], %[[V16]], %arg4 : i16
+// CHECK-NEXT: fir.result %[[V17]] : i16
+// CHECK-NEXT: }
+// CHECK-NEXT: hlfir.assign %[[V11]] to %[[V4]]#0 : i16, !fir.ref<i16>
+// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V4]]#1 : !fir.ref<i16>
+// CHECK-NEXT: return %[[V12]] : i16
+
+func.func @_QFPtest_dim(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.array<7xi32> {
+ %c1_i32 = arith.constant 1 : i32
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c7 = arith.constant 7 : index
+ %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+ %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.alloca !fir.array<7xi32> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+ %4 = fir.shape %c7 : (index) -> !fir.shape<1>
+ %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.array<7xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<7xi32>>, !fir.ref<!fir.array<7xi32>>)
+ %6:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %7 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1) shape %0 : (!fir.ref<!fir.array<4x7xi32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.ref<!fir.array<4x7xi32>>
+ %8 = fir.load %6#0 : !fir.ref<i32>
+ %9 = hlfir.elemental %0 unordered : (!fir.shape<2>) -> !hlfir.expr<4x7x!fir.logical<4>> {
+ ^bb0(%arg3: index, %arg4: index):
+ %12 = hlfir.designate %7 (%arg3, %arg4) : (!fir.ref<!fir.array<4x7xi32>>, index, index) -> !fir.ref<i32>
+ %13 = fir.load %12 : !fir.ref<i32>
+ %14 = arith.cmpi sge, %13, %8 : i32
+ %15 = fir.convert %14 : (i1) -> !fir.logical<4>
+ hlfir.yield_element %15 : !fir.logical<4>
+ }
+ %10 = hlfir.count %9 dim %c1_i32 : (!hlfir.expr<4x7x!fir.logical<4>>, i32) -> !hlfir.expr<7xi32>
+ hlfir.assign %10 to %5#0 : !hlfir.expr<7xi32>, !fir.ref<!fir.array<7xi32>>
+ hlfir.destroy %10 : !hlfir.expr<7xi32>
+ hlfir.destroy %9 : !hlfir.expr<4x7x!fir.logical<4>>
+ %11 = fir.load %5#1 : !fir.ref<!fir.array<7xi32>>
+ return %11 : !fir.array<7xi32>
+}
+// CHECK-LABEL: func.func @_QFPtest_dim(
+// CHECK: %{{.*}} = hlfir.count %{{.*}} dim %c1_i32
+
+
+func.func @_QFPtest_multi(%arg0: !fir.ref<!fir.array<4x7x2xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c7 = arith.constant 7 : index
+ %c2 = arith.constant 2 : index
+ %0 = fir.shape %c4, %c7, %c2 : (index, index, index) -> !fir.shape<3>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7x2xi32>>, !fir.shape<3>) -> (!fir.ref<!fir.array<4x7x2xi32>>, !fir.ref<!fir.array<4x7x2xi32>>)
+ %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %3 = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+ %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %6 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1, %c1:%c2:%c1) shape %0 : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index, index, index, index, index, index, index, !fir.shape<3>) -> !fir.ref<!fir.array<4x7x2xi32>>
+ %7 = fir.load %5#0 : !fir.ref<i32>
+ %8 = hlfir.elemental %0 unordered : (!fir.shape<3>) -> !hlfir.expr<4x7x2x!fir.logical<4>> {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ %11 = hlfir.designate %6 (%arg3, %arg4, %arg5) : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index) -> !fir.ref<i32>
+ %12 = fir.load %11 : !fir.ref<i32>
+ %13 = arith.cmpi sge, %12, %7 : i32
+ %14 = fir.convert %13 : (i1) -> !fir.logical<4>
+ hlfir.yield_element %14 : !fir.logical<4>
+ }
+ %9 = hlfir.count %8 : (!hlfir.expr<4x7x2x!fir.logical<4>>) -> i32
+ hlfir.assign %9 to %4#0 : i32, !fir.ref<i32>
+ hlfir.destroy %8 : !hlfir.expr<4x7x2x!fir.logical<4>>
+ %10 = fir.load %4#1 : !fir.ref<i32>
+ return %10 : i32
+}
+// CHECK-LABEL: func.func @_QFPtest_multi(%arg0: !fir.ref<!fir.array<4x7x2xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> i32 {
+// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32
+// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
+// CHECK-NEXT: %c1 = arith.constant 1 : index
+// CHECK-NEXT: %c4 = arith.constant 4 : index
+// CHECK-NEXT: %c7 = arith.constant 7 : index
+// CHECK-NEXT: %c2 = arith.constant 2 : index
+// CHECK-NEXT: %[[V0:.*]] = fir.shape %c4, %c7, %c2 : (index, index, index) -> !fir.shape<3>
+// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]]) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7x2xi32>>, !fir.shape<3>) -> (!fir.ref<!fir.array<4x7x2xi32>>, !fir.ref<!fir.array<4x7x2xi32>>)
+// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT: %[[V3:.*]] = fir.alloca i32 {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+// CHECK-NEXT: %[[V4:.*]]:2 = hlfir.declare %[[V3]] {uniq_name = "_QFFtestEtest"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT: %[[V5:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK-NEXT: %[[V6:.*]] = hlfir.designate %[[V1]]#0 (%c1:%c4:%c1, %c1:%c7:%c1, %c1:%c2:%c1) shape %[[V0]] : (!fir.ref<!fir.array<4x7x2xi32>>, index, index, index, index, index, index, index, index, index, !fir.shape<3>) -> !fir.ref<!fir.array<4x7x2xi32>>
+// CHECK-NEXT: %[[V7:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT: %[[V8:.*]] = fir.do_loop %arg3 = %c1 to %c4 step %c1 iter_args(%arg4 = %c0_i32) -> (i32) {
+// CHECK-NEXT: %[[V10:.*]] = fir.do_loop %arg5 = %c1 to %c7 step %c1 iter_args(%arg6 = %arg4) -> (i32) {
+// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg7 = %c1 to %c2 step %c1 iter_args(%arg8 = %arg6) -> (i32) {
+// CHECK-NEXT: %[[V12:.*]] = hlfir.designate %[[V6]] (%arg3, %arg5, %arg7) : (!fir.ref...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
1d0c5b2
to
f61d44f
Compare
Thanks for your work on this! I think creating operations that are legalized by the same pass should work. But you need to create the correct listeners. See
This makes sure the rest of the pass infrastructure realizes that the new operations were created so that it can then translate them as though they were there at the start of the pass. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
This looks good, but I agree with @jeanPerier that this would be better done with another hlfir.elemental and then however many elementals can be inlined together in the elemental inlining pass. But if there are good reasons why that would not be possible, this current implementation is good.
Hi. Could you explain more about what those elementals would look like? As far as I understood elementals acted on each element in a shape (a map step), where as these operations (count, any, all), would be reductions to a scalar. So the idea is that we already optimize any elementals together, then perform a reduction to a scalar with the code in this patch. Let me know if you think it should work differently. |
Ahh you are quite right. Apologies for the confusion. Please wait for Jean or Slava to take a look before merging. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aside from the loop order, this looks good to me. My comment about rewriting reductions to hlfir.elemental is indeed only valid when there is a DIM argument, otherwise I think this needs to be done like you are doing.
flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Outdated
Show resolved
Hide resolved
f61d44f
to
943f1ea
Compare
Ah, that looks like it might work. That means I could move this to LowerHLFIRIntrinsics if that would be better. It would be similar code to this, with a name like Edit: It might involve changing OpRewritePattern to OpConversionPattern too. |
Ping. Thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did a basic build and ran check-flang, and all looked good. Then I ran our full internal test suited and got many failures. I'm currently doing more testing to narrow things down.
Please do not merge until you hear from me again.
Also, please let me know if you have access to the NAG test suite. If so, that will simplify our communication.
Ah I see. I have never heard of that, but might have access. I will see if I can give it a try. |
Arm has access to NAG tests. Company is here: https://nag.com/about-nag/ but their products are definitely not open source...
[https://nag.com/wp-content/uploads/2023/06/NAG-Founders.jpg]<https://nag.com/about-nag/>
About NAG - NAG<https://nag.com/about-nag/>
NAG provides industry-leading numerical software and technical services to banking and finance, energy, engineering, and more...
nag.com
I will try to out where it lives and send you a link in Slack (it's an internal Arm link, so not much point publishing it).
…--
Mats
________________________________
From: David Green ***@***.***>
Sent: 05 January 2024 19:31
To: llvm/llvm-project ***@***.***>
Cc: Mats Petersson ***@***.***>; Review requested ***@***.***>
Subject: Re: [llvm/llvm-project] [Flang] Generate inline reduction loops for elemental count intrinsics (PR #75774)
Ah I see. I have never heard of that, but might have access. I will see if I can give it a try.
—
Reply to this email directly, view it on GitHub<#75774 (comment)>, or unsubscribe<https://github.com/notifications/unsubscribe-auth/ABUDZBPAHUVJEX5H3OLOJSLYNBIHRAVCNFSM6AAAAABAZF7CHOVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTQNZZGE3DGNRYGM>.
You are receiving this because your review was requested.Message ID: ***@***.***>< /p>
IMPORTANT NOTICE: The contents of this email and any attachments are confidential and may also be privileged. If you are not the intended recipient, please notify the sender immediately and do not disclose the contents to any other person, use it for any purpose, or store or copy the information in any medium. Thank you.
|
Hello. I've been attempting to test this in various cases of count I could think of, but not managed to hit any errors yet. It sounds like many cases are failing for you, so there must be something I am missing. Can you provide some extra info on what options you are using, whether it is a compile-time crash or a miscompile, and any details on the kind of code that is going wrong? Thanks |
My previous tests were a mismatch between the state of your fork and our internal tests. I've corrected that mismatch and re-run all of our tests, and they all look good. Note that I have little expertise in the actual content of this change, so someone with more expertise than I should approve. |
This adds a ReductionElementalConversion transform to OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and generating the inline loop to perform the count of true elements. This lets us generate a single loop instead of ending up as two plus a temporary. This is currently part of OptimizedBufferization, similar to llvm#74828. I attempted to move it to LowerHLFIRIntrinsics to make it part of the existing lowering, but it hit problems with inlining elementals that contain operations that are being legalized by the same pass. Any and All should be able to share the same code with a different function/initial value.
943f1ea
to
a377173
Compare
Ah, OK. Thanks for checking. It's good to see it seems to be doing OK. I think this is hopefully OK to go then? It just needs @jeanPerier to take a look maybe. |
I think a blessing from either @jeanPerier or @tblah would do. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for solving my previous comment
Thanks. I gave this another test run of the llvm-test-suite and it seemed OK. Let me know if any issues arise. |
This is an extension of llvm#75774, with Any and All lowering added along side Count.
This is an extension of #75774, with Any and All lowering added alongside Count.
llvm#75774) This adds a ReductionElementalConversion transform to OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and generating the inline loop to perform the count of true elements. This lets us generate a single loop instead of ending up as two plus a temporary. Any and All should be able to share the same code with a different function/initial value.
This is an extension of llvm#75774, with Any and All lowering added alongside Count.
This adds a ReductionElementalConversion transform to OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and generating the inline loop to perform the count of true elements. This lets us generate a single loop instead of ending up as two plus a temporary.
This is currently part of OptimizedBufferization, similar to #74828. I attempted to move it to LowerHLFIRIntrinsics to make it part of the existing lowering, but it hit problems with inlining elementals that contain operations that are being legalized by the same pass.
Any and All should be able to share the same code with a different function/initial value.