Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][HLFIR] Use GreedyPatternRewriter in LowerHLFIRIntrinsics #83438

Merged
merged 2 commits into from
Mar 1, 2024

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Feb 29, 2024

In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver.

The intention of this pass is to lower HLFIR intrinsic operations into FIR so it conceptually fits dialect conversion. However, dialect conversion is much stricter about changing types when replacing operations. This pass sometimes looses track of array bounds, resulting in replacements with operations with different but compatible types (expressions of the same rank and element types but with or without compile time known array bounds). This is difficult to accommodate with the dialect conversion driver and so I have changed to use the greedy pattern rewriter.

There is a lot of test churn because the greedy pattern rewriter also performs canonicalization.

In llvm#83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp
should not be using rewrite patterns with the dialect conversion driver.

The intention of this pass is to lower HLFIR intrinsic operations into
FIR so it conceptually fits dialect conversion. However, dialect
conversion is much stricter about changing types when replacing
operations. This pass sometimes looses track of array bounds, resulting
in replacements with operations with different but compatible types
(expressions of the same rank and element types but with or without
compile time known array bounds). This is difficult to accommodate with
the dialect conversion driver and so I have changed to use the greedy
pattern rewriter.

There is a lot of test churn because the greedy pattern rewriter
also performs canonicalization.
@tblah tblah changed the title [flang][HLFIR] Use GreedyPatternRewriter in LowrHLFIRIntrinsics [flang][HLFIR] Use GreedyPatternRewriter in LowerHLFIRIntrinsics Feb 29, 2024
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 29, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Tom Eccles (tblah)

Changes

In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver.

The intention of this pass is to lower HLFIR intrinsic operations into FIR so it conceptually fits dialect conversion. However, dialect conversion is much stricter about changing types when replacing operations. This pass sometimes looses track of array bounds, resulting in replacements with operations with different but compatible types (expressions of the same rank and element types but with or without compile time known array bounds). This is difficult to accommodate with the dialect conversion driver and so I have changed to use the greedy pattern rewriter.

There is a lot of test churn because the greedy pattern rewriter also performs canonicalization.


Patch is 99.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83438.diff

17 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp (+16-23)
  • (modified) flang/test/HLFIR/all-lowering.fir (+4-4)
  • (modified) flang/test/HLFIR/any-lowering.fir (+3-3)
  • (modified) flang/test/HLFIR/count-lowering-default-int-kinds.fir (+8-8)
  • (modified) flang/test/HLFIR/count-lowering.fir (+3-3)
  • (modified) flang/test/HLFIR/dot_product-lowering.fir (-1)
  • (modified) flang/test/HLFIR/extents-of-shape-of.f90 (+5-6)
  • (modified) flang/test/HLFIR/matmul-lowering.fir (+1-1)
  • (modified) flang/test/HLFIR/maxloc-lowering.fir (+64-73)
  • (modified) flang/test/HLFIR/maxval-lowering.fir (+2-2)
  • (modified) flang/test/HLFIR/minloc-lowering.fir (+64-73)
  • (modified) flang/test/HLFIR/minval-lowering.fir (+3-3)
  • (modified) flang/test/HLFIR/mul_transpose.f90 (+11-16)
  • (modified) flang/test/HLFIR/product-lowering.fir (+2-2)
  • (modified) flang/test/HLFIR/sum-lowering.fir (+2-2)
  • (modified) flang/test/HLFIR/transpose-lowering.fir (+1-1)
  • (modified) flang/test/Lower/convert.f90 (+1-1)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 377cc44392028f..0142fb0cfb0bb0 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -18,12 +18,12 @@
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include <mlir/IR/MLIRContext.h>
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
 
 namespace hlfir {
@@ -176,14 +176,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
           rewriter.eraseOp(use);
       }
     }
-    // TODO: This entire pass should be a greedy pattern rewrite or a manual
-    // IR traversal. A dialect conversion cannot be used here because
-    // `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not
-    // suitable because "op->getResult(0)" and "base" can have different types.
-    // In such a case, the dialect conversion will attempt to convert the type,
-    // but no type converter is specified in this pass. Also note that all
-    // patterns in this pass are actually rewrite patterns.
-    op->getResult(0).replaceAllUsesWith(base);
+
     rewriter.replaceOp(op, base);
   }
 };
@@ -491,19 +484,19 @@ class LowerHLFIRIntrinsics
                 ProductOpConversion, TransposeOpConversion, CountOpConversion,
                 DotProductOpConversion, MaxvalOpConversion, MinvalOpConversion,
                 MinlocOpConversion, MaxlocOpConversion>(context);
-    mlir::ConversionTarget target(*context);
-    target.addLegalDialect<mlir::BuiltinDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect, fir::FIROpsDialect,
-                           hlfir::hlfirDialect>();
-    target.addIllegalOp<hlfir::MatmulOp, hlfir::MatmulTransposeOp, hlfir::SumOp,
-                        hlfir::ProductOp, hlfir::TransposeOp, hlfir::AnyOp,
-                        hlfir::AllOp, hlfir::DotProductOp, hlfir::CountOp,
-                        hlfir::MaxvalOp, hlfir::MinvalOp, hlfir::MinlocOp,
-                        hlfir::MaxlocOp>();
-    target.markUnknownOpDynamicallyLegal(
-        [](mlir::Operation *) { return true; });
-    if (mlir::failed(
-            mlir::applyFullConversion(module, target, std::move(patterns)))) {
+
+    // While conceptually this pass is performing dialect conversion, we use
+    // pattern rewrites here instead of dialect conversion because this pass
+    // looses array bounds from some of the expressions e.g.
+    // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32>
+    // MLIR thinks this is a different type so dialect conversion fails.
+    // Pattern rewriting only requires that the resulting IR is still valid
+    mlir::GreedyRewriteConfig config;
+    // Prevent the pattern driver from merging blocks
+    config.enableRegionSimplification = false;
+
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            module, std::move(patterns), config))) {
       mlir::emitError(mlir::UnknownLoc::get(context),
                       "failure in HLFIR intrinsic lowering");
       signalPassFailure();
diff --git a/flang/test/HLFIR/all-lowering.fir b/flang/test/HLFIR/all-lowering.fir
index dfd1ace947d68d..e83378eacf9c9f 100644
--- a/flang/test/HLFIR/all-lowering.fir
+++ b/flang/test/HLFIR/all-lowering.fir
@@ -34,6 +34,7 @@ func.func @_QPall2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<i32>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
@@ -55,7 +56,6 @@ func.func @_QPall2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]#0
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -79,6 +79,7 @@ func.func @_QPall3(%arg0: !fir.ref<!fir.array<2x!fir.logical<4>>> {fir.bindc_nam
 }
 // CHECK-LABEL:  func.func @_QPall3(
 // CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<2x!fir.logical<4>>>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
@@ -101,7 +102,6 @@ func.func @_QPall3(%arg0: !fir.ref<!fir.array<2x!fir.logical<4>>> {fir.bindc_nam
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -125,6 +125,7 @@ func.func @_QPall4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_ARG:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
@@ -149,9 +150,8 @@ func.func @_QPall4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]
 // CHECK:         hlfir.destroy %[[EXPR]]
 // CHECK-NEXT:    return
-// CHECK-NEXT:  }
\ No newline at end of file
+// CHECK-NEXT:  }
diff --git a/flang/test/HLFIR/any-lowering.fir b/flang/test/HLFIR/any-lowering.fir
index ef8b8950293190..039146727d3f56 100644
--- a/flang/test/HLFIR/any-lowering.fir
+++ b/flang/test/HLFIR/any-lowering.fir
@@ -36,6 +36,7 @@ func.func @_QPany2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<i32>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
@@ -57,7 +58,6 @@ func.func @_QPany2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]#0
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -82,6 +82,7 @@ func.func @_QPany3(%arg0: !fir.ref<!fir.array<2x!fir.logical<4>>> {fir.bindc_nam
 }
 // CHECK-LABEL:  func.func @_QPany3(
 // CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<2x!fir.logical<4>>>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.logical<4>>>>
 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.logical<4>>>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
@@ -104,7 +105,6 @@ func.func @_QPany3(%arg0: !fir.ref<!fir.array<2x!fir.logical<4>>> {fir.bindc_nam
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -129,6 +129,7 @@ func.func @_QPany4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_ARG:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
@@ -153,7 +154,6 @@ func.func @_QPany4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_n
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i1) -> !hlfir.expr<?x!fir.logical<4>>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]
 // CHECK:         hlfir.destroy %[[EXPR]]
diff --git a/flang/test/HLFIR/count-lowering-default-int-kinds.fir b/flang/test/HLFIR/count-lowering-default-int-kinds.fir
index ea66c435e6a8a7..68bc7fdbaad876 100644
--- a/flang/test/HLFIR/count-lowering-default-int-kinds.fir
+++ b/flang/test/HLFIR/count-lowering-default-int-kinds.fir
@@ -2,9 +2,9 @@
 // RUN: fir-opt %s -lower-hlfir-intrinsics | FileCheck %s
 
 module attributes {fir.defaultkind = "a1c4d8i8l4r4", fir.kindmap = ""} {
-  func.func @test_i8(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+  func.func @test_i8(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) -> !hlfir.expr<?xi64> {
     %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi64>
-    return
+    return %4 : !hlfir.expr<?xi64>
   }
 }
 // CHECK-LABEL: func.func @test_i8
@@ -12,9 +12,9 @@ module attributes {fir.defaultkind = "a1c4d8i8l4r4", fir.kindmap = ""} {
 // CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
 
 module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = ""} {
-  func.func @test_i4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+  func.func @test_i4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) -> !hlfir.expr<?xi32> {
     %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi32>
-    return
+    return %4 : !hlfir.expr<?xi32>
   }
 }
 // CHECK-LABEL: func.func @test_i4
@@ -22,9 +22,9 @@ module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = ""} {
 // CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
 
 module attributes {fir.defaultkind = "a1c4d8i2l4r4", fir.kindmap = ""} {
-  func.func @test_i2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+  func.func @test_i2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) -> !hlfir.expr<?xi16> {
     %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi16>
-    return
+    return %4 : !hlfir.expr<?xi16>
   }
 }
 // CHECK-LABEL: func.func @test_i2
@@ -32,9 +32,9 @@ module attributes {fir.defaultkind = "a1c4d8i2l4r4", fir.kindmap = ""} {
 // CHECK: fir.call @_FortranACountDim(%{{.*}}, %{{.*}}, %{{.*}}, %[[KIND]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, i32, !fir.ref<i8>, i32) -> none
 
 module attributes {fir.defaultkind = "a1c4d8i1l4r4", fir.kindmap = ""} {
-  func.func @test_i1(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) {
+  func.func @test_i1(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc_name = "x"}, %arg1: i64) -> !hlfir.expr<?xi8> {
     %4 = hlfir.count %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i64) -> !hlfir.expr<?xi8>
-    return
+    return %4 : !hlfir.expr<?xi8>
   }
 }
 // CHECK-LABEL: func.func @test_i1
diff --git a/flang/test/HLFIR/count-lowering.fir b/flang/test/HLFIR/count-lowering.fir
index da0f250dceef35..c3309724981a3f 100644
--- a/flang/test/HLFIR/count-lowering.fir
+++ b/flang/test/HLFIR/count-lowering.fir
@@ -34,6 +34,7 @@ func.func @_QPcount2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<i32>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[KIND:.*]] = arith.constant 4 : i32
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
@@ -56,7 +57,6 @@ func.func @_QPcount2(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<?xi32>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]#0
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -80,6 +80,7 @@ func.func @_QPcount3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"})
 }
 // CHECK-LABEL:  func.func @_QPcount3(
 // CHECK:           %[[ARG0:.*]]: !fir.ref<!fir.array<2xi32>>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[RET_BOX:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
 // CHECK-DAG:     %[[KIND:.*]] = arith.constant 4 : i32
 // CHECK-DAG:     %[[RET_ADDR:.*]] = fir.zero_bits !fir.heap<!fir.array<?xi32>>
@@ -104,7 +105,6 @@ func.func @_QPcount3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "s"})
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?xi32>>, i1) -> !hlfir.expr<?xi32>
 // CHECK:         hlfir.assign %[[EXPR]] to %[[RES]]
 // CHECK:         hlfir.destroy %[[EXPR]]
@@ -133,6 +133,7 @@ func.func @_QPcount4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>
 // CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<?xi32>
 // CHECK:           %[[ARG2:.*]]: !fir.ref<i32>
+// CHECK-DAG:     %[[TRUE:.*]] = arith.constant true
 // CHECK-DAG:     %[[MASK:.*]]:2 = hlfir.declare %[[ARG0]]
 // CHECK-DAG:     %[[DIM_VAR:.*]]:2 = hlfir.declare %[[ARG2]]
 // CHECK-DAG:     %[[RES:.*]]:2 = hlfir.declare %[[ARG1]]
@@ -155,7 +156,6 @@ func.func @_QPcount4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>> {fir.bindc
 // CHECK-NEXT:    %[[ADDR:.*]] = fir.box_addr %[[RET]]
 // CHECK-NEXT:    %[[SHIFT:.*]] = fir.shape_shift %[[BOX_DIMS]]#0, %[[BOX_DIMS]]#1
 // CHECK-NEXT:    %[[TMP:.*]]:2 = hlfir.declare %[[ADDR]](%[[SHIFT]]) {uniq_name = ".tmp.intrinsic_result"}
-// CHECK:         %[[TRUE:.*]] = arith.constant true
 // CHECK:         %[[EXPR:.*]] = hlfir.as_expr %[[TMP]]#0 move %[[TRUE]] : (!fir.box<!fir.array<?xi64>>, i1) -> !hlfir.expr<?xi64>
 // CHECK-NEXT:    %[[OUT_SHAPE:.*]] = hlfir.shape_of %[[EXPR]]
 // CHECK-NEXT:    %[[OUT:.*]] = hlfir.elemental %[[OUT_SHAPE]] : (!fir.shape<1>) -> !hlfir.expr<?xi32>
diff --git a/flang/test/HLFIR/dot_product-lowering.fir b/flang/test/HLFIR/dot_product-lowering.fir
index e4f91eabfc0991..64d65665433f15 100644
--- a/flang/test/HLFIR/dot_product-lowering.fir
+++ b/flang/test/HLFIR/dot_product-lowering.fir
@@ -96,7 +96,6 @@ func.func @_QPdot_product4(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>> {fir.b
 // CHECK:           %[[VAL_2:.*]] = fir.alloca !fir.logical<4>
 // CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdot_product2Elhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
 // CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFdot_product2Erhs"} : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> (!fir.box<!fir.array<?x!fir.logical<4>>>, !fir.box<!fir.array<?x!fir.logical<4>>>)
-// CHECK:           %[[VAL_5:.*]] = fir.absent !fir.box<!fir.logical<4>>
 // CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_3]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
 // CHECK:           %[[VAL_10:.*]] = fir.convert %[[VAL_4]]#1 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.box<none>
 // CHECK:           %[[VAL_12:.*]] = fir.call @_FortranADotProductLogical(%[[VAL_9]], %[[VAL_10]], %{{.*}}, %{{.*}}) fastmath<contract> : (!fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> i1
diff --git a/flang/test/HLFIR/extents-of-shape-of.f90 b/flang/test/HLFIR/extents-of-shape-of.f90
index d807f8b70302b6..1168004597d191 100644
--- a/flang/test/HLFIR/extents-of-shape-of.f90
+++ b/flang/test/HLFIR/extents-of-shape-of.f90
@@ -31,18 +31,17 @@ elemental subroutine elem_sub(x)
 ! CHECK-HLFIR-NEXT:    hlfir.destroy %[[MUL]]
 
 ! ...
+! CHECK-FIR-DAG:       %[[C0:.*]] = arith.constant 0 : index
+! CHECK-FIR-DAG:       %[[C1:.*]] = arith.constant 1 : index
+! CHECK-FIR-DAG:       %[[C2:.*]] = arith.constant 2 : index
 ! CHECK-FIR:           fir.call @_FortranAMatmul
 ! CHECK-FIR-NEXT:      %[[MUL:.*]] = fir.load %[[MUL_BOX:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
-! CHECK-FIR-NEXT:      %[[C0:.*]] = arith.constant 0 : index
 ! CHECK-FIR-NEXT:      %[[DIMS0:.*]]:3 = fir.box_dims %[[MUL]], %[[C0]]
-! CHECK-FIR-NEXT:      %[[C1:.*]] = arith.constant 1 : index
 ! CHECK-FIR-NEXT:      %[[DIMS1:.*]]:3 = fir.box_dims %[[MUL]], %[[C1]]
 ! ...
 ! CHECK-FIR:          ...
[truncated]

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

I'm thinking of adding additional options to GreedyRewriteConfig to deactivate folding and/or CSE'ing/moving of constants. That should make changes like this one easier in the future and also generally gives more control over what's happening during a greedy pattern rewrite.

// CHECK-NEXT: %[[V0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xi32>>>
// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent

@tblah tblah merged commit 44c0bdb into llvm:main Mar 1, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants