diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index 3f4ec4f3bccc8..eb6434b274663 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -812,54 +812,59 @@ class ReductionElementalConversion : public mlir::OpRewritePattern { // inlined elemental. // %e = hlfir.elemental %shape ({ ... }) // %m = hlfir.minloc %array mask %e -class MinMaxlocElementalConversion - : public mlir::OpRewritePattern { +template +class MinMaxlocElementalConversion : public mlir::OpRewritePattern { public: - using mlir::OpRewritePattern::OpRewritePattern; + using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult - matchAndRewrite(hlfir::MinlocOp minloc, - mlir::PatternRewriter &rewriter) const override { - if (!minloc.getMask() || minloc.getDim() || minloc.getBack()) - return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc"); + matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override { + if (!mloc.getMask() || mloc.getDim() || mloc.getBack()) + return rewriter.notifyMatchFailure(mloc, + "Did not find valid minloc/maxloc"); - auto elemental = minloc.getMask().getDefiningOp(); + constexpr bool isMax = std::is_same_v; + + auto elemental = + mloc.getMask().template getDefiningOp(); if (!elemental || hlfir::elementalOpMustProduceTemp(elemental)) - return rewriter.notifyMatchFailure(minloc, "Did not find elemental"); + return rewriter.notifyMatchFailure(mloc, "Did not find elemental"); - mlir::Value array = minloc.getArray(); + mlir::Value array = mloc.getArray(); - unsigned rank = mlir::cast(minloc.getType()).getShape()[0]; + unsigned rank = mlir::cast(mloc.getType()).getShape()[0]; mlir::Type arrayType = array.getType(); if (!arrayType.isa()) return rewriter.notifyMatchFailure( - minloc, "Currently requires a boxed type input"); + mloc, "Currently requires a boxed type input"); mlir::Type elementType = hlfir::getFortranElementType(arrayType); if (!fir::isa_trivial(elementType)) return rewriter.notifyMatchFailure( - minloc, "Character arrays are currently not handled"); + mloc, "Character arrays are currently not handled"); - mlir::Location loc = minloc.getLoc(); - fir::FirOpBuilder builder{rewriter, minloc.getOperation()}; + mlir::Location loc = mloc.getLoc(); + fir::FirOpBuilder builder{rewriter, mloc.getOperation()}; mlir::Value resultArr = builder.createTemporary( loc, fir::SequenceType::get( - rank, hlfir::getFortranElementType(minloc.getType()))); + rank, hlfir::getFortranElementType(mloc.getType()))); - auto init = [](fir::FirOpBuilder builder, mlir::Location loc, - mlir::Type elementType) { + auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc, + mlir::Type elementType) { if (auto ty = elementType.dyn_cast()) { const llvm::fltSemantics &sem = ty.getFloatSemantics(); return builder.createRealConstant( loc, elementType, - llvm::APFloat::getLargest(sem, /*Negative=*/false)); + llvm::APFloat::getLargest(sem, /*Negative=*/!isMax)); } unsigned bits = elementType.getIntOrFloatBitWidth(); - int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, elementType, maxInt); + int64_t limitInt = + isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue() + : llvm::APInt::getSignedMaxValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, elementType, limitInt); }; auto genBodyOp = - [&rank, &resultArr, &elemental]( + [&rank, &resultArr, &elemental, isMax]( fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType, mlir::Value array, mlir::Value flagRef, mlir::Value reduction, @@ -899,10 +904,16 @@ class MinMaxlocElementalConversion mlir::Value cmp; if (elementType.isa()) { cmp = builder.create( - loc, mlir::arith::CmpFPredicate::OLT, elem, reduction); + loc, + isMax ? mlir::arith::CmpFPredicate::OGT + : mlir::arith::CmpFPredicate::OLT, + elem, reduction); } else if (elementType.isa()) { cmp = builder.create( - loc, mlir::arith::CmpIPredicate::slt, elem, reduction); + loc, + isMax ? mlir::arith::CmpIPredicate::sgt + : mlir::arith::CmpIPredicate::slt, + elem, reduction); } else { llvm_unreachable("unsupported type"); } @@ -975,15 +986,15 @@ class MinMaxlocElementalConversion // AsExpr for the temporary resultArr. llvm::SmallVector destroys; llvm::SmallVector assigns; - for (auto user : minloc->getUsers()) { + for (auto user : mloc->getUsers()) { if (auto destroy = mlir::dyn_cast(user)) destroys.push_back(destroy); else if (auto assign = mlir::dyn_cast(user)) assigns.push_back(assign); } - // Check if the minloc was the only user of the elemental (apart from a - // destroy), and remove it if so. + // Check if the minloc/maxloc was the only user of the elemental (apart from + // a destroy), and remove it if so. mlir::Operation::user_range elemUsers = elemental->getUsers(); hlfir::DestroyOp elemDestroy; if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) { @@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion rewriter.eraseOp(d); for (auto a : assigns) a.setOperand(0, resultArr); - rewriter.replaceOp(minloc, asExpr); + rewriter.replaceOp(mloc, asExpr); if (elemDestroy) { rewriter.eraseOp(elemDestroy); rewriter.eraseOp(elemental); @@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass patterns.insert>(context); patterns.insert>(context); patterns.insert>(context); - patterns.insert(context); + patterns.insert>(context); + patterns.insert>(context); if (mlir::failed(mlir::applyPatternsAndFoldGreedily( func, std::move(patterns), config))) { diff --git a/flang/test/HLFIR/maxloc-elemental.fir b/flang/test/HLFIR/maxloc-elemental.fir new file mode 100644 index 0000000000000..67cd9ee4bb75a --- /dev/null +++ b/flang/test/HLFIR/maxloc-elemental.fir @@ -0,0 +1,140 @@ +// RUN: fir-opt %s -opt-bufferization | FileCheck %s + +func.func @_QPtest(%arg0: !fir.box> {fir.bindc_name = "array"}, %arg1: !fir.ref {fir.bindc_name = "val"}, %arg2: !fir.box> {fir.bindc_name = "m"}) { + %c0 = arith.constant 0 : index + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %3 = fir.load %2#0 : !fir.ref + %4:3 = fir.box_dims %0#0, %c0 : (!fir.box>, index) -> (index, index, index) + %5 = fir.shape %4#1 : (index) -> !fir.shape<1> + %6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr> { + ^bb0(%arg3: index): + %8 = hlfir.designate %0#0 (%arg3) : (!fir.box>, index) -> !fir.ref + %9 = fir.load %8 : !fir.ref + %10 = arith.cmpi sge, %9, %3 : i32 + %11 = fir.convert %10 : (i1) -> !fir.logical<4> + hlfir.yield_element %11 : !fir.logical<4> + } + %7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath} : (!fir.box>, !hlfir.expr>) -> !hlfir.expr<1xi32> + hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box> + hlfir.destroy %7 : !hlfir.expr<1xi32> + hlfir.destroy %6 : !hlfir.expr> + return +} +// CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box> {fir.bindc_name = "array"}, %arg1: !fir.ref {fir.bindc_name = "val"}, %arg2: !fir.box> {fir.bindc_name = "m"}) { +// CHECK-NEXT: %c-2147483648_i32 = arith.constant -2147483648 : i32 +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32 +// CHECK-NEXT: %[[RES:.*]] = fir.alloca !fir.array<1xi32> +// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box>) -> (!fir.box>, !fir.box>) +// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box>) -> (!fir.box>, !fir.box>) +// CHECK-NEXT: %[[V3:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref) -> (!fir.ref, !fir.ref) +// CHECK-NEXT: %[[V4:.*]] = fir.load %[[V3]]#0 : !fir.ref +// CHECK-NEXT: %[[V8:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref>, index) -> !fir.ref +// CHECK-NEXT: fir.store %c0_i32 to %[[V8]] : !fir.ref +// CHECK-NEXT: fir.store %c0_i32 to %[[V0]] : !fir.ref +// CHECK-NEXT: %[[V9:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box>, index) -> (index, index, index) +// CHECK-NEXT: %[[V10:.*]] = arith.subi %[[V9]]#1, %c1 : index +// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10]] step %c1 iter_args(%arg4 = %c-2147483648_i32) -> (i32) { +// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index +// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1]]#0 (%[[V14]]) : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref +// CHECK-NEXT: %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32 +// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (i32) { +// CHECK-NEXT: fir.store %c1_i32 to %[[V0]] : !fir.ref +// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box>, index) -> (index, index, index) +// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index +// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref +// CHECK-NEXT: %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32 +// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (i32) { +// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref>, index) -> !fir.ref +// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32 +// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref +// CHECK-NEXT: fir.result %[[V20]] : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: fir.result %arg4 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: fir.result %[[V22]] : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: fir.result %arg4 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: fir.result %[[V18]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V12:.*]] = fir.load %[[V0]] : !fir.ref +// CHECK-NEXT: %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32 +// CHECK-NEXT: fir.if %[[V13]] { +// CHECK-NEXT: %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32 +// CHECK-NEXT: fir.if %[[V14]] { +// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref>, index) -> !fir.ref +// CHECK-NEXT: fir.store %c1_i32 to %[[V15]] : !fir.ref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box>, index) -> (index, index, index) +// CHECK-NEXT: fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered { +// CHECK-NEXT: %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3) : (!fir.ref>, index) -> !fir.ref +// CHECK-NEXT: %[[V14:.*]] = fir.load %[[V13]] : !fir.ref +// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V2]]#0 (%arg3) : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: hlfir.assign %[[V14]] to %[[V15]] : i32, !fir.ref +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + + + +func.func @_QPtest_float(%arg0: !fir.box> {fir.bindc_name = "array"}, %arg1: !fir.ref {fir.bindc_name = "val"}, %arg2: !fir.box> {fir.bindc_name = "m"}) { + %c0 = arith.constant 0 : index + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFtestEarray"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFtestEm"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFtestEval"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %3 = fir.load %2#0 : !fir.ref + %4:3 = fir.box_dims %0#0, %c0 : (!fir.box>, index) -> (index, index, index) + %5 = fir.shape %4#1 : (index) -> !fir.shape<1> + %6 = hlfir.elemental %5 unordered : (!fir.shape<1>) -> !hlfir.expr> { + ^bb0(%arg3: index): + %8 = hlfir.designate %0#0 (%arg3) : (!fir.box>, index) -> !fir.ref + %9 = fir.load %8 : !fir.ref + %10 = arith.cmpf oge, %9, %3 : f32 + %11 = fir.convert %10 : (i1) -> !fir.logical<4> + hlfir.yield_element %11 : !fir.logical<4> + } + %7 = hlfir.maxloc %0#0 mask %6 {fastmath = #arith.fastmath} : (!fir.box>, !hlfir.expr>) -> !hlfir.expr<1xi32> + hlfir.assign %7 to %1#0 : !hlfir.expr<1xi32>, !fir.box> + hlfir.destroy %7 : !hlfir.expr<1xi32> + hlfir.destroy %6 : !hlfir.expr> + return +} +// CHECK-LABEL: _QPtest_float +// CHECK: %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) { +// CHECK-NEXT: %[[V14:.*]] = arith.addi %arg3, %c1 : index +// CHECK-NEXT: %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]]) : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: %[[V16:.*]] = fir.load %[[V15]] : !fir.ref +// CHECK-NEXT: %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32 +// CHECK-NEXT: %[[V18:.*]] = fir.if %[[V17]] -> (f32) { +// CHECK-NEXT: fir.store %c1_i32 to %[[V0:.*]] : !fir.ref +// CHECK-NEXT: %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box>, index) -> (index, index, index) +// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index +// CHECK-NEXT: %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: %[[V20:.*]] = fir.load %[[V19]] : !fir.ref +// CHECK-NEXT: %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath : f32 +// CHECK-NEXT: %[[V22:.*]] = fir.if %[[V21]] -> (f32) { +// CHECK-NEXT: %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref>, index) -> !fir.ref +// CHECK-NEXT: %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32 +// CHECK-NEXT: fir.store %[[V24]] to %[[V23]] : !fir.ref +// CHECK-NEXT: fir.result %[[V20]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: fir.result %arg4 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: fir.result %[[V22]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: fir.result %arg4 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: fir.result %[[V18]] : f32 +// CHECK-NEXT: } +