From 7c57195c49c7847bd2fcab2a91dcdea178b60d42 Mon Sep 17 00:00:00 2001 From: Jacob Crawley Date: Tue, 2 May 2023 10:15:54 +0000 Subject: [PATCH] [flang][hlfir] lower hlfir.product into fir runtime call The shared code for lowering the sum and product operations in flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp have been moved into a new class HlfirReductionIntrinsicConverion. Depends on: D148719 Differential Revision: https://reviews.llvm.org/D149644 --- .../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp | 51 ++++-- flang/test/HLFIR/product-lowering.fir | 171 ++++++++++++++++++ 2 files changed, 206 insertions(+), 16 deletions(-) create mode 100644 flang/test/HLFIR/product-lowering.fir diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index 95192a3e79d49..fbf7670413df2 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -163,39 +163,56 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern { } }; -struct SumOpConversion : public HlfirIntrinsicConversion { - using HlfirIntrinsicConversion::HlfirIntrinsicConversion; +template +class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion { + using HlfirIntrinsicConversion::HlfirIntrinsicConversion; + using IntrinsicArgument = + typename HlfirIntrinsicConversion::IntrinsicArgument; +public: mlir::LogicalResult - matchAndRewrite(hlfir::SumOp sum, + matchAndRewrite(OP operation, mlir::PatternRewriter &rewriter) const override { + std::string opName; + if constexpr (std::is_same_v) { + opName = "sum"; + } else if constexpr (std::is_same_v) { + opName = "product"; + } else { + return mlir::failure(); + } fir::KindMapping kindMapping{rewriter.getContext()}; fir::FirOpBuilder builder{rewriter, kindMapping}; - const mlir::Location &loc = sum->getLoc(); + const mlir::Location &loc = operation->getLoc(); mlir::Type i32 = builder.getI32Type(); mlir::Type logicalType = fir::LogicalType::get( builder.getContext(), builder.getKindMap().defaultLogicalKind()); - llvm::SmallVector inArgs; - inArgs.push_back({sum.getArray(), sum.getArray().getType()}); - inArgs.push_back({sum.getDim(), i32}); - inArgs.push_back({sum.getMask(), logicalType}); + inArgs.push_back({operation.getArray(), operation.getArray().getType()}); + inArgs.push_back({operation.getDim(), i32}); + inArgs.push_back({operation.getMask(), logicalType}); - auto *argLowering = fir::getIntrinsicArgumentLowering("sum"); + auto *argLowering = fir::getIntrinsicArgumentLowering(opName); llvm::SmallVector args = - lowerArguments(sum, inArgs, rewriter, argLowering); + this->lowerArguments(operation, inArgs, rewriter, argLowering); - mlir::Type scalarResultType = hlfir::getFortranElementType(sum.getType()); + mlir::Type scalarResultType = + hlfir::getFortranElementType(operation.getType()); auto [resultExv, mustBeFreed] = - fir::genIntrinsicCall(builder, loc, "sum", scalarResultType, args); + fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); - processReturnValue(sum, resultExv, mustBeFreed, builder, rewriter); + this->processReturnValue(operation, resultExv, mustBeFreed, builder, + rewriter); return mlir::success(); } }; +using SumOpConversion = HlfirReductionIntrinsicConversion; + +using ProductOpConversion = HlfirReductionIntrinsicConversion; + struct MatmulOpConversion : public HlfirIntrinsicConversion { using HlfirIntrinsicConversion::HlfirIntrinsicConversion; @@ -304,14 +321,16 @@ class LowerHLFIRIntrinsics mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns + .insert( + context); mlir::ConversionTarget target(*context); target.addLegalDialect(); target.addIllegalOp(); + hlfir::ProductOp, hlfir::TransposeOp>(); target.markUnknownOpDynamicallyLegal( [](mlir::Operation *) { return true; }); if (mlir::failed( diff --git a/flang/test/HLFIR/product-lowering.fir b/flang/test/HLFIR/product-lowering.fir new file mode 100644 index 0000000000000..e22869ffc4c7d --- /dev/null +++ b/flang/test/HLFIR/product-lowering.fir @@ -0,0 +1,171 @@ +// Test hlfir.product operation lowering to fir runtime call +// RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s + +// one argument product +func.func @_QPproduct1(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "s"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsum1Ea"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg1 {uniq_name = "_QFsum1Es"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2 = hlfir.product %0#0 {fastmath = #arith.fastmath} : (!fir.box>) -> !hlfir.expr + hlfir.assign %2 to %1#0 : !hlfir.expr, !fir.ref + hlfir.destroy %2 : !hlfir.expr + return +} + +// CHECK-LABEL: func.func @_QPproduct1( +// CHECK: %[[ARG0:.*]]: !fir.box> +// CHECK: %[[ARG1:.*]]: !fir.ref +// CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] +// CHECK-DAG: %[[MASK:.*]] = fir.absent !fir.box +// CHECK-DAG: %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box>) -> !fir.box +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]] : (!fir.box) -> !fir.box +// CHECK: %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 +// CHECK-NEXT: hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref +// CHECK-NEXT: return +// CHECK-NEXT: } + +// product with DIM argument by-ref +func.func @_QPproduct2(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.box> {fir.bindc_name = "s"}, %arg2: !fir.ref {fir.bindc_name = "d"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFproduct2Ea"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFproduct2Ed"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFproduct2Es"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %3 = fir.load %1#0 : !fir.ref + %4 = hlfir.product %0#0 dim %3 {fastmath = #arith.fastmath} : (!fir.box>, index) -> !hlfir.expr + hlfir.assign %4 to %2#0 : !hlfir.expr, !fir.box> + hlfir.destroy %4 : !hlfir.expr + return +} + +// CHECK-LABEL: func.func @_QPproduct2( +// CHECK: %[[ARG0:.*]]: !fir.box> +// CHECK: %[[ARG1:.*]]: !fir.box> +// CHECK: %[[ARG2:.*]]: !fir.ref +// CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] + +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>> +// CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] + +// CHECK-DAG: %[[MASK:.*]] = fir.absent !fir.box +// CHECK-DAG: %[[DIM_IDX:.*]] = fir.load %[[DIM_VAR]]#0 : !fir.ref +// CHECK-DAG: %[[DIM:.*]] = fir.convert %[[DIM_IDX]] : (index) -> i32 + +// CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] +// CHECK-DAG: %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]] +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]] + +// CHECK: %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) : (!fir.ref>, !fir.box, i32, !fir.ref, i32, !fir.box) -> none +// CHECK: %[[RET:.*]] = fir.load %[[RET_BOX]] +// CHECK: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[RET]] +// CHECK-NEXT: %[[ADDR:.*]] = fir.box_addr %[[RET]] +// CHECK-NEXT: %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1 +// CHECK-NEXT: %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"} +// CHECK: %[[TRUE:.*]] = arith.constant true +// CHECK: %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box>, i1) -> !hlfir.expr +// CHECK: hlfir.assign %[[EXPR]] to %[[RES]]#0 +// CHECK: hlfir.destroy %[[EXPR]] +// CHECK-NEXT: return +// CHECK-NEXT: } + +// product with scalar mask +func.func @_QPproduct3(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "s"}, %arg2: !fir.ref> {fir.bindc_name = "m"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFproduct3Ea"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFproduct3Em"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFproduct3Es"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %3 = hlfir.product %0#0 mask %1#0 {fastmath = #arith.fastmath} : (!fir.box>, !fir.ref>) -> !hlfir.expr + hlfir.assign %3 to %2#0 : !hlfir.expr, !fir.ref + hlfir.destroy %3 : !hlfir.expr + return +} + +// CHECK-LABEL: func.func @_QPproduct3( +// CHECK: %[[ARG0:.*]]: !fir.box> +// CHECK: %[[ARG1:.*]]: !fir.ref +// CHECK: %[[ARG2:.*]]: !fir.ref> +// CHECK-DAG: %[[ARRAY:.*]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[RES:.*]]:2 = hlfir.declare %[[ARG1]] +// CHECK-DAG: %[[MASK:.*]]:2 = hlfir.declare %[[ARG2]] +// CHECK-DAG: %[[MASK_BOX:.*]] = fir.embox %[[MASK]]#1 : (!fir.ref>) -> !fir.box> +// CHECK-DAG: %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box>) -> !fir.box +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box>) -> !fir.box +// CHECK: %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 +// CHECK-NEXT: hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref +// CHECK-NEXT: return +// CHECK-NEXT: } + +// product with array mask +func.func @_QPproduct4(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "s"}, %arg2: !fir.box>> {fir.bindc_name = "m"}) { + %0:2 = hlfir.declare %arg0 {uniq_name = "_QFproduct4Ea"} : (!fir.box>) -> (!fir.box>, !fir.box>) + %1:2 = hlfir.declare %arg2 {uniq_name = "_QFproduct4Em"} : (!fir.box>>) -> (!fir.box>>, !fir.box>>) + %2:2 = hlfir.declare %arg1 {uniq_name = "_QFproduct4Es"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %3 = hlfir.product %0#0 mask %1#0 {fastmath = #arith.fastmath} : (!fir.box>, !fir.box>>) -> !hlfir.expr + hlfir.assign %3 to %2#0 : !hlfir.expr, !fir.ref + hlfir.destroy %3 : !hlfir.expr + return +} + +// CHECK-LABEL: func.func @_QPproduct4( +// CHECK: %[[ARG0:.*]]: !fir.box +// CHECK: %[[ARG1:.*]]: !fir.ref +// CHECK: %[[ARG2:.*]]: !fir.box>> +// CHECK-DAG: %[[ARRAY]]:2 = hlfir.declare %[[ARG0]] +// CHECK-DAG: %[[RES]]:2 = hlfir.declare %[[ARG1]] +// CHECK-DAG: %[[MASK]]:2 = hlfir.declare %[[ARG2]] +// CHECK-DAG: %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY]]#1 : (!fir.box>) -> !fir.box +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK]]#1 : (!fir.box>>) -> !fir.box +// CHECK: %[[RET:.*]] = fir.call @_FortranAProductInteger4(%[[ARRAY_ARG]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[INT:.*]], %[[MASK_ARG]]) : (!fir.box, !fir.ref, i32, i32, !fir.box) -> i32 +// CHECK-NEXT: hlfir.assign %[[RET]] to %[[RES]]#0 : i32, !fir.ref +// CHECK-NEXT: return +// CHECK-NEXT: } + + +// product with all 3 arguments +func.func @_QPproduct5(%arg0: !fir.ref> {fir.bindc_name = "s"}) { + %0 = fir.address_of(@_QFproduct5Ea) : !fir.ref> + %c2 = arith.constant 2 : index + %c2_0 = arith.constant 2 : index + %1 = fir.shape %c2, %c2_0 : (index, index) -> !fir.shape<2> + %2:2 = hlfir.declare %0(%1) {uniq_name = "_QFproduct5Ea"} : (!fir.ref>, !fir.shape<2>) -> (!fir.ref>, !fir.ref>) + %c2_1 = arith.constant 2 : index + %3 = fir.shape %c2_1 : (index) -> !fir.shape<1> + %4:2 = hlfir.declare %arg0(%3) {uniq_name = "_QFproduct5Es"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) + %c1_i32 = arith.constant 1 : i32 + %true = arith.constant true + %5 = hlfir.product %2#0 dim %c1_i32 mask %true {fastmath = #arith.fastmath} : (!fir.ref>, i32, i1) -> !hlfir.expr<2xi32> + hlfir.assign %5 to %4#0 : !hlfir.expr<2xi32>, !fir.ref> + hlfir.destroy %5 : !hlfir.expr<2xi32> + return +} + +// CHECK-LABEL: func.func @_QPproduct5( +// CHECK: %[[ARG0:.*]]: !fir.ref> +// CHECK-DAG: %[[RET_BOX:.*]] = fir.alloca !fir.box>> +// CHECK-DAG: %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[RET_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1> +// CHECK-DAG: %[[RET_EMBOX:.*]] = fir.embox %[[RET_ADDR]](%[[RET_SHAPE]]) +// CHECK-DAG: fir.store %[[RET_EMBOX]] to %[[RET_BOX]] + +// CHECK-DAG: %[[RES_VAR:.*]] = hlfir.declare %[[ARG0]](%[[RES_SHAPE:.*]]) + +// CHECK-DAG: %[[MASK_ALLOC:.*]] = fir.alloca !fir.logical<4> +// CHECK-DAG: %[[TRUE:.*]] = arith.constant true +// CHECK-DAG: %[[MASK_VAL:.*]] = fir.convert %[[TRUE]] : (i1) -> !fir.logical<4> +// CHECK-DAG: fir.store %[[MASK_VAL]] to %[[MASK_ALLOC]] : !fir.ref> +// CHECK-DAG: %[[MASK_BOX:.*]] = fir.embox %[[MASK_ALLOC]] + +// CHECK-DAG: %[[ARRAY_ADDR:.*]] = fir.address_of +// CHECK-DAG: %[[ARRAY_VAR:.*]]:2 = hlfir.declare %[[ARRAY_ADDR]](%[[ARRAY_SHAPE:.*]]) +// CHECK-DAG: %[[ARRAY_BOX:.*]] = fir.embox %[[ARRAY_VAR]]#1(%[[ARRAY_SHAPE:.*]]) + +// CHECK-DAG: %[[DIM:.*]] = arith.constant 1 : i32 + +// CHECK-DAG: %[[RET_ARG:.*]] = fir.convert %[[RET_BOX]] +// CHECK-DAG: %[[ARRAY_ARG:.*]] = fir.convert %[[ARRAY_BOX]] : (!fir.box>) -> !fir.box +// CHECK-DAG: %[[MASK_ARG:.*]] = fir.convert %[[MASK_BOX]] : (!fir.box>) -> !fir.box +// CHECK: %[[NONE:.*]] = fir.call @_FortranAProductDim(%[[RET_ARG]], %[[ARRAY_ARG]], %[[DIM]], %[[LOC_STR:.*]], %[[LOC_N:.*]], %[[MASK_ARG]]) : (!fir.ref>, !fir.box, i32, !fir.ref, i32, !fir.box) -> none