diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp index f483651a68dc1..a11aa38c771bd 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -656,7 +656,7 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, unsigned rank, int maskRank, mlir::Type elementType, mlir::Type maskElemType, - mlir::Type resultElemTy) { + mlir::Type resultElemTy, bool isDim) { auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType) { if (auto ty = elementType.dyn_cast()) { @@ -858,16 +858,27 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder, maskElemType, resultArr, maskRank == 0); // Store newly created output array to the reference passed in - fir::SequenceType::Shape resultShape(1, rank); - mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy); - mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); - mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); - mlir::Type outputRefTy = builder.getRefType(outputBoxTy); - mlir::Value outputArr = builder.create( - loc, outputRefTy, funcOp.front().getArgument(0)); - - // Store nearly created array to output array - builder.create(loc, resultArr, outputArr); + if (isDim) { + mlir::Type resultBoxTy = + fir::BoxType::get(fir::HeapType::get(resultElemTy)); + mlir::Value outputArr = builder.create( + loc, builder.getRefType(resultBoxTy), funcOp.front().getArgument(0)); + mlir::Value resultArrScalar = builder.create( + loc, fir::HeapType::get(resultElemTy), resultArrInit); + mlir::Value resultBox = + builder.create(loc, resultBoxTy, resultArrScalar); + builder.create(loc, resultBox, outputArr); + } else { + fir::SequenceType::Shape resultShape(1, rank); + mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy); + mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy); + mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy); + mlir::Type outputRefTy = builder.getRefType(outputBoxTy); + mlir::Value outputArr = builder.create( + loc, outputRefTy, funcOp.front().getArgument(0)); + builder.create(loc, resultArr, outputArr); + } + builder.create(loc); } @@ -1146,11 +1157,14 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( mlir::Operation::operand_range args = call.getArgs(); - mlir::Value back = args[6]; + mlir::SymbolRefAttr callee = call.getCalleeAttr(); + mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); + bool isDim = funcNameBase.ends_with("Dim"); + mlir::Value back = args[isDim ? 7 : 6]; if (isTrueOrNotConstant(back)) return; - mlir::Value mask = args[5]; + mlir::Value mask = args[isDim ? 6 : 5]; mlir::Value maskDef = findMaskDef(mask); // maskDef is set to NULL when the defining op is not one we accept. @@ -1159,10 +1173,8 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( if (maskDef == NULL) return; - mlir::SymbolRefAttr callee = call.getCalleeAttr(); - mlir::StringRef funcNameBase = callee.getLeafReference().getValue(); unsigned rank = getDimCount(args[1]); - if (funcNameBase.ends_with("Dim") || !(rank > 0)) + if ((isDim && rank != 1) || !(rank > 0)) return; fir::FirOpBuilder builder{getSimplificationBuilder(call, kindMap)}; @@ -1203,22 +1215,24 @@ void SimplifyIntrinsicsPass::simplifyMinMaxlocReduction( llvm::raw_string_ostream nameOS(funcName); outType.print(nameOS); + if (isDim) + nameOS << '_' << inputType; nameOS << '_' << fmfString; auto typeGenerator = [rank](fir::FirOpBuilder &builder) { return genRuntimeMinlocType(builder, rank); }; auto bodyGenerator = [rank, maskRank, inputType, logicalElemType, outType, - isMax](fir::FirOpBuilder &builder, - mlir::func::FuncOp &funcOp) { + isMax, isDim](fir::FirOpBuilder &builder, + mlir::func::FuncOp &funcOp) { genRuntimeMinMaxlocBody(builder, funcOp, isMax, rank, maskRank, inputType, - logicalElemType, outType); + logicalElemType, outType, isDim); }; mlir::func::FuncOp newFunc = getOrCreateFunction(builder, funcName, typeGenerator, bodyGenerator); builder.create(loc, newFunc, - mlir::ValueRange{args[0], args[1], args[5]}); + mlir::ValueRange{args[0], args[1], mask}); call->dropAllReferences(); call->erase(); } diff --git a/flang/test/Transforms/simplifyintrinsics.fir b/flang/test/Transforms/simplifyintrinsics.fir index ce9f2dbd3e0fb..f21776e03ded8 100644 --- a/flang/test/Transforms/simplifyintrinsics.fir +++ b/flang/test/Transforms/simplifyintrinsics.fir @@ -2098,13 +2098,13 @@ func.func @_QPtestminloc_doesntwork1d_back(%arg0: !fir.ref> { // CHECK-NOT: fir.call @_FortranAMinlocInteger4x1_i32_contract_simplified({{.*}}) fastmath : (!fir.ref>, !fir.box, !fir.box) -> () // ----- -// Check Minloc is not simplified when DIM arg is set +// Check Minloc is simplified when DIM arg is set so long as the result is scalar -func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> !fir.array<1xi32> { +func.func @_QPtestminloc_1d_dim(%arg0: !fir.ref> {fir.bindc_name = "a"}) -> !fir.array<1xi32> { %0 = fir.alloca !fir.box> %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_doesntwork1d_dim", uniq_name = "_QFtestminloc_doesntwork1d_dimEtestminloc_doesntwork1d_dim"} + %1 = fir.alloca !fir.array<1xi32> {bindc_name = "testminloc_1d_dim", uniq_name = "_QFtestminloc_1d_dimEtestminloc_1d_dim"} %2 = fir.shape %c1 : (index) -> !fir.shape<1> %3 = fir.array_load %1(%2) : (!fir.ref>, !fir.shape<1>) -> !fir.array<1xi32> %4 = fir.shape %c10 : (index) -> !fir.shape<1> @@ -2139,11 +2139,62 @@ func.func @_QPtestminloc_doesntwork1d_dim(%arg0: !fir.ref> {f %21 = fir.load %1 : !fir.ref> return %21 : !fir.array<1xi32> } -// CHECK-LABEL: func.func @_QPtestminloc_doesntwork1d_dim( +// CHECK-LABEL: func.func @_QPtestminloc_1d_dim( // CHECK-SAME: %[[ARR:.*]]: !fir.ref> {fir.bindc_name = "a"}) -> !fir.array<1xi32> { -// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath : (!fir.ref>, !fir.box, !fir.box) -> () -// CHECK: fir.call @_FortranAMinlocDim({{.*}}) fastmath : (!fir.ref>, !fir.box, i32, i32, !fir.ref, i32, !fir.box, i1) -> none -// CHECK-NOT: fir.call @_FortranAMinlocDimx1_i32_contract_simplified({{.*}}) fastmath : (!fir.ref>, !fir.box, !fir.box) -> () +// CHECK: fir.call @_FortranAMinlocDimx1_i32_i32_contract_simplified({{.*}}) fastmath : (!fir.ref>, !fir.box, !fir.box) -> () + +// CHECK-LABEL: func.func private @_FortranAMinlocDimx1_i32_i32_contract_simplified(%arg0: !fir.ref>, %arg1: !fir.box, %arg2: !fir.box) attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %[[V0:.*]] = fir.alloca i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %[[V1:.*]] = fir.allocmem !fir.array<1xi32> +// CHECK-NEXT: %[[V2:.*]] = fir.shape %c1 : (index) -> !fir.shape<1> +// CHECK-NEXT: %[[V3:.*]] = fir.embox %[[V1]](%[[V2]]) : (!fir.heap>, !fir.shape<1>) -> !fir.box>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[V4:.*]] = fir.coordinate_of %[[V3]], %c0 : (!fir.box>>, index) -> !fir.ref +// CHECK-NEXT: fir.store %c0_i32 to %[[V4]] : !fir.ref +// CHECK-NEXT: %c0_0 = arith.constant 0 : index +// CHECK-NEXT: %[[V5:.*]] = fir.convert %arg1 : (!fir.box) -> !fir.box> +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 +// CHECK-NEXT: fir.store %c0_i32_1 to %[[V0]] : !fir.ref +// CHECK-NEXT: %c2147483647_i32 = arith.constant 2147483647 : i32 +// CHECK-NEXT: %c1_2 = arith.constant 1 : index +// CHECK-NEXT: %c0_3 = arith.constant 0 : index +// CHECK-NEXT: %[[V6:.*]]:3 = fir.box_dims %[[V5]], %c0_3 : (!fir.box>, index) -> (index, index, index) +// CHECK-NEXT: %[[V7:.*]] = arith.subi %[[V6]]#1, %c1_2 : index +// CHECK-NEXT: %[[V8:.*]] = fir.do_loop %arg3 = %c0_0 to %[[V7]] step %c1_2 iter_args(%arg4 = %c2147483647_i32) -> (i32) { +// CHECK-NEXT: %c1_i32_4 = arith.constant 1 : i32 +// CHECK-NEXT: %[[ISFIRST:.*]] = fir.load %[[FLAG_ALLOC]] : !fir.ref +// CHECK-NEXT: %[[V12:.*]] = fir.coordinate_of %[[V5]], %arg3 : (!fir.box>, index) -> !fir.ref +// CHECK-NEXT: %[[V13:.*]] = fir.load %[[V12]] : !fir.ref +// CHECK-NEXT: %[[V14:.*]] = arith.cmpi slt, %[[V13]], %arg4 : i32 +// CHECK-NEXT: %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1 +// CHECK-NEXT: %true = arith.constant true +// CHECK-NEXT: %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1 +// CHECK-NEXT: %[[ORCOND:.*]] = arith.ori %[[V14]], %[[ISFIRSTNOT]] : i1 +// CHECK-NEXT: %[[V15:.*]] = fir.if %[[ORCOND]] -> (i32) { +// CHECK-NEXT: fir.store %c1_i32_4 to %[[V0]] : !fir.ref +// CHECK-NEXT: %c1_i32_5 = arith.constant 1 : i32 +// CHECK-NEXT: %c0_6 = arith.constant 0 : index +// CHECK-NEXT: %[[V16:.*]] = fir.coordinate_of %[[V3]], %c0_6 : (!fir.box>>, index) -> !fir.ref +// CHECK-NEXT: %[[V17:.*]] = fir.convert %arg3 : (index) -> i32 +// CHECK-NEXT: %[[V18:.*]] = arith.addi %[[V17]], %c1_i32_5 : i32 +// CHECK-NEXT: fir.store %[[V18]] to %[[V16]] : !fir.ref +// CHECK-NEXT: fir.result %[[V13]] : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: fir.result %arg4 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: fir.result %[[V15]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V11:.*]] = fir.convert %arg0 : (!fir.ref>) -> !fir.ref>> +// CHECK-NEXT: %[[V12:.*]] = fir.convert %[[V1]] : (!fir.heap>) -> !fir.heap +// CHECK-NEXT: %[[V13:.*]] = fir.embox %[[V12]] : (!fir.heap) -> !fir.box> +// CHECK-NEXT: fir.store %[[V13]] to %[[V11]] : !fir.ref>> +// CHECK-NEXT: return +// CHECK-NEXT: } + + // ----- // Check Minloc is not simplified when dimension of inputArr is unknown