Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang] Any and All elemental lowering #75776

Merged
merged 1 commit into from Jan 10, 2024

Conversation

davemgreen
Copy link
Collaborator

This is an extension of #75774, with Any and All lowering added alongside Count.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 18, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 18, 2023

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Green (davemgreen)

Changes

This is an extension of #75774, with Any and All lowering added alongside Count.


Patch is 48.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75776.diff

4 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+151)
  • (added) flang/test/HLFIR/all-elemental.fir (+91)
  • (added) flang/test/HLFIR/any-elemental.fir (+190)
  • (added) flang/test/HLFIR/count-elemental.fir (+314)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7abfa20493c736..a41796d256b6ab 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -659,6 +659,154 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
   return mlir::success();
 }
 
+using GenBodyFn =
+    std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
+                              const llvm::SmallVectorImpl<mlir::Value> &)>;
+static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
+                                         mlir::Location loc, mlir::Value init,
+                                         mlir::Value shape, GenBodyFn genBody) {
+  auto extents = hlfir::getIndexExtents(loc, builder, shape);
+  mlir::Value reduction = init;
+  mlir::IndexType idxTy = builder.getIndexType();
+  mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
+
+  // Create a reduction loop nest. We use one-based indices so that they can be
+  // passed to the elemental.
+  llvm::SmallVector<mlir::Value> indices;
+  for (unsigned i = 0; i < extents.size(); ++i) {
+    auto loop =
+        builder.create<fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false,
+                                      /*finalCountValue=*/false, reduction);
+    reduction = loop.getRegionIterArgs()[0];
+    indices.push_back(loop.getInductionVar());
+    // Set insertion point to the loop body so that the next loop
+    // is inserted inside the current one.
+    builder.setInsertionPointToStart(loop.getBody());
+  }
+
+  // Generate the body
+  reduction = genBody(builder, loc, reduction, indices);
+
+  // Unwind the loop nest.
+  for (unsigned i = 0; i < extents.size(); ++i) {
+    auto result = builder.create<fir::ResultOp>(loc, reduction);
+    auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
+    reduction = loop.getResult(0);
+    // Set insertion point after the loop operation that we have
+    // just processed.
+    builder.setInsertionPointAfter(loop.getOperation());
+  }
+
+  return reduction;
+}
+
+/// Given a reduction operation with an elemental mask, attempt to generate a
+/// do-loop to perform the operation inline.
+///   %e = hlfir.elemental %shape unordered
+///   %r = hlfir.count %e
+/// =>
+///   %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init) {
+///     <inline elemental>
+///     <reduce count>
+///   }
+template <typename Op>
+class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
+public:
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
+    hlfir::ElementalOp elemental =
+        op.getMask().template getDefiningOp<hlfir::ElementalOp>();
+    if (!elemental || op.getDim())
+      return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
+
+    fir::KindMapping kindMap =
+        fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
+    fir::FirOpBuilder builder{op, kindMap};
+
+    mlir::Value init;
+    GenBodyFn genBodyFn;
+    if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
+      init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
+      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
+                              mlir::Value reduction,
+                              const llvm::SmallVectorImpl<mlir::Value> &indices)
+          -> mlir::Value {
+        // Inline the elemental and get the condition from it.
+        auto yield = inlineElementalOp(loc, builder, elemental, indices);
+        mlir::Value cond = builder.create<fir::ConvertOp>(
+            loc, builder.getI1Type(), yield.getElementValue());
+        yield->erase();
+
+        // Conditionally set the reduction variable.
+        return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
+      };
+    } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
+      init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
+                              mlir::Value reduction,
+                              const llvm::SmallVectorImpl<mlir::Value> &indices)
+          -> mlir::Value {
+        // Inline the elemental and get the condition from it.
+        auto yield = inlineElementalOp(loc, builder, elemental, indices);
+        mlir::Value cond = builder.create<fir::ConvertOp>(
+            loc, builder.getI1Type(), yield.getElementValue());
+        yield->erase();
+
+        // Conditionally set the reduction variable.
+        return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
+      };
+    } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
+      init = builder.createIntegerConstant(loc, op.getType(), 0);
+      genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
+                              mlir::Value reduction,
+                              const llvm::SmallVectorImpl<mlir::Value> &indices)
+          -> mlir::Value {
+        // Inline the elemental and get the condition from it.
+        auto yield = inlineElementalOp(loc, builder, elemental, indices);
+        mlir::Value cond = builder.create<fir::ConvertOp>(
+            loc, builder.getI1Type(), yield.getElementValue());
+        yield->erase();
+
+        // Conditionally add one to the current value
+        mlir::Value one =
+            builder.createIntegerConstant(loc, reduction.getType(), 1);
+        mlir::Value add1 =
+            builder.create<mlir::arith::AddIOp>(loc, reduction, one);
+        return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
+                                                     reduction);
+      };
+    } else {
+      static_assert("Expected Op to be handled");
+      return mlir::failure();
+    }
+
+    mlir::Value res = generateReductionLoop(builder, loc, init,
+                                            elemental.getOperand(0), genBodyFn);
+    if (res.getType() != op.getType())
+      res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
+
+    // Check if the op was the only user of the elemental (apart from a
+    // destroy), and remove it if so.
+    mlir::Operation::user_range elemUsers = elemental->getUsers();
+    hlfir::DestroyOp elemDestroy;
+    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
+      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
+      if (!elemDestroy)
+        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
+    }
+
+    rewriter.replaceOp(op, res);
+    if (elemDestroy) {
+      rewriter.eraseOp(elemDestroy);
+      rewriter.eraseOp(elemental);
+    }
+    return mlir::success();
+  }
+};
+
 class OptimizedBufferizationPass
     : public hlfir::impl::OptimizedBufferizationBase<
           OptimizedBufferizationPass> {
@@ -681,6 +829,9 @@ class OptimizedBufferizationPass
     patterns.insert<ElementalAssignBufferization>(context);
     patterns.insert<BroadcastAssignBufferization>(context);
     patterns.insert<VariableAssignBufferization>(context);
+    patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
+    patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
+    patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
 
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/test/HLFIR/all-elemental.fir b/flang/test/HLFIR/all-elemental.fir
new file mode 100644
index 00000000000000..1ba8bb1b7a5fb4
--- /dev/null
+++ b/flang/test/HLFIR/all-elemental.fir
@@ -0,0 +1,91 @@
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca !fir.logical<4> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %14 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %15 = fir.load %14 : !fir.ref<i32>
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %12 = hlfir.all %11 : (!hlfir.expr<7x!fir.logical<4>>) -> !fir.logical<4>
+  hlfir.assign %12 to %4#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+  hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+  %13 = fir.load %4#1 : !fir.ref<!fir.logical<4>>
+  return %13 : !fir.logical<4>
+}
+// CHECK-LABEL:  func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
+// CHECK-NEXT:     %true = arith.constant true
+// CHECK-NEXT:     %c1 = arith.constant 1 : index
+// CHECK-NEXT:     %c4 = arith.constant 4 : index
+// CHECK-NEXT:     %c7 = arith.constant 7 : index
+// CHECK-NEXT:     %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT:     %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT:     %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT:     %[[V3:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:     %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT:     %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT:     %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT:     %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1)  shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT:     %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %true) -> (i1) {
+// CHECK-NEXT:       %[[V14:.*]] = hlfir.designate %[[V9]] (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:       %[[V15:.*]] = fir.load %[[V14]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[V16:.*]] = arith.cmpi sge, %[[V15]], %[[V10]] : i32
+// CHECK-NEXT:       %[[V17:.*]] = arith.andi %arg4, %[[V16]] : i1
+// CHECK-NEXT:       fir.result %[[V17]] : i1
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[V12:.*]] = fir.convert %[[V11]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:     hlfir.assign %[[V12]] to %[[V4]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:     %[[V13:.*]] = fir.load %[[V4]]#1 : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:     return %[[V13]] : !fir.logical<4>
+
+
+func.func @_QFPtest_dim(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.array<4x!fir.logical<4>> {
+  %c2_i32 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca !fir.array<4x!fir.logical<4>> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4 = fir.shape %c4 : (index) -> !fir.shape<1>
+  %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.ref<!fir.array<4x!fir.logical<4>>>)
+  %6:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %7 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1)  shape %0 : (!fir.ref<!fir.array<4x7xi32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.ref<!fir.array<4x7xi32>>
+  %8 = fir.load %6#0 : !fir.ref<i32>
+  %9 = hlfir.elemental %0 unordered : (!fir.shape<2>) -> !hlfir.expr<4x7x!fir.logical<4>> {
+  ^bb0(%arg3: index, %arg4: index):
+    %12 = hlfir.designate %7 (%arg3, %arg4)  : (!fir.ref<!fir.array<4x7xi32>>, index, index) -> !fir.ref<i32>
+    %13 = fir.load %12 : !fir.ref<i32>
+    %14 = arith.cmpi sge, %13, %8 : i32
+    %15 = fir.convert %14 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %15 : !fir.logical<4>
+  }
+  %10 = hlfir.all %9 dim %c2_i32 : (!hlfir.expr<4x7x!fir.logical<4>>, i32) -> !hlfir.expr<4x!fir.logical<4>>
+  hlfir.assign %10 to %5#0 : !hlfir.expr<4x!fir.logical<4>>, !fir.ref<!fir.array<4x!fir.logical<4>>>
+  hlfir.destroy %10 : !hlfir.expr<4x!fir.logical<4>>
+  hlfir.destroy %9 : !hlfir.expr<4x7x!fir.logical<4>>
+  %11 = fir.load %5#1 : !fir.ref<!fir.array<4x!fir.logical<4>>>
+  return %11 : !fir.array<4x!fir.logical<4>>
+}
+// CHECK-LABEL:  func.func @_QFPtest_dim(
+// CHECK: %10 = hlfir.all %9 dim %c2_i32
\ No newline at end of file
diff --git a/flang/test/HLFIR/any-elemental.fir b/flang/test/HLFIR/any-elemental.fir
new file mode 100644
index 00000000000000..6e233068d2e9b5
--- /dev/null
+++ b/flang/test/HLFIR/any-elemental.fir
@@ -0,0 +1,190 @@
+// RUN: fir-opt %s -opt-bufferization | FileCheck %s
+
+func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca !fir.logical<4> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
+  %5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %6 = fir.load %2#0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = fir.shape %c7 : (index) -> !fir.shape<1>
+  %9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1)  shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+  %10 = fir.load %5#0 : !fir.ref<i32>
+  %11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
+  ^bb0(%arg3: index):
+    %14 = hlfir.designate %9 (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+    %15 = fir.load %14 : !fir.ref<i32>
+    %16 = arith.cmpi sge, %15, %10 : i32
+    %17 = fir.convert %16 : (i1) -> !fir.logical<4>
+    hlfir.yield_element %17 : !fir.logical<4>
+  }
+  %12 = hlfir.any %11 : (!hlfir.expr<7x!fir.logical<4>>) -> !fir.logical<4>
+  hlfir.assign %12 to %4#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+  hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
+  %13 = fir.load %4#1 : !fir.ref<!fir.logical<4>>
+  return %13 : !fir.logical<4>
+}
+// CHECK-LABEL:  func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
+// CHECK-NEXT:     %false = arith.constant false
+// CHECK-NEXT:     %c1 = arith.constant 1 : index
+// CHECK-NEXT:     %c4 = arith.constant 4 : index
+// CHECK-NEXT:     %c7 = arith.constant 7 : index
+// CHECK-NEXT:     %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+// CHECK-NEXT:     %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
+// CHECK-NEXT:     %[[V2:.*]]:2 = hlfir.declare %arg1
+// CHECK-NEXT:     %[[V3:.*]] = fir.alloca !fir.logical<4>
+// CHECK-NEXT:     %[[V4:.*]]:2 = hlfir.declare %[[V3]]
+// CHECK-NEXT:     %[[V5:.*]]:2 = hlfir.declare %arg2
+// CHECK-NEXT:     %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
+// CHECK-NEXT:     %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
+// CHECK-NEXT:     %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1)  shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
+// CHECK-NEXT:     %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
+// CHECK-NEXT:     %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %false) -> (i1) {
+// CHECK-NEXT:       %[[V14:.*]] = hlfir.designate %[[V9]] (%arg3)  : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
+// CHECK-NEXT:       %[[V15:.*]] = fir.load %[[V14]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[V16:.*]] = arith.cmpi sge, %[[V15]], %[[V10]] : i32
+// CHECK-NEXT:       %[[V17:.*]] = arith.ori %arg4, %[[V16]] : i1
+// CHECK-NEXT:       fir.result %[[V17]] : i1
+// CHECK-NEXT:     }
+// CHECK-NEXT:     %[[V12:.*]] = fir.convert %[[V11]] : (i1) -> !fir.logical<4>
+// CHECK-NEXT:     hlfir.assign %[[V12]] to %[[V4]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:     %[[V13:.*]] = fir.load %[[V4]]#1 : !fir.ref<!fir.logical<4>>
+// CHECK-NEXT:     return %[[V13]] : !fir.logical<4>
+
+
+func.func @_QFPtest_dim(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.array<4x!fir.logical<4>> {
+  %c2_i32 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c7 = arith.constant 7 : index
+  %0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
+  %1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
+  %2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.alloca !fir.array<4x!fir.logical<4>> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
+  %4 = fir.shape %c4 : (index) -> !fir.shape<1>
+  %5:2 = hlfir.declare %3(%4) {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.ref<!fir.array<4x!fir.logical<4>>>)
+  %6:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %7 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1)  shape %0 : (!fir.ref<!fir.array<4x7xi32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.ref<!fir.array<4x7xi32>>
+  %8 = fir.load %6#0 : !fir.ref<i32>
+  %9 = hlfir.elemental %0 unordered : (!fir.shape<2>) -> !hlfir.expr<4x7x!fir.logical<4>> {
+  ^bb0(%arg3: index, %arg4: index):
+    %12 = hlfir.designate %7 (%arg3, %arg4)  : (!fir.ref<!fir.array<4x7xi32>>, index, index) -> !fir.ref<i3...
[truncated]

Copy link

github-actions bot commented Dec 18, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@davemgreen davemgreen changed the title [Flang] And and All elemental lowering [Flang] Any and All elemental lowering Dec 18, 2023
This is an extension of llvm#75774, with Any and All lowering added along side Count.
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for this!

One could rewrite this with less code duplication, but I think it is much easier to read the way you have done it.

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@davemgreen
Copy link
Collaborator Author

Thanks

@davemgreen davemgreen merged commit e22cb93 into llvm:main Jan 10, 2024
4 checks passed
@davemgreen davemgreen deleted the gh-flang-anyallele branch January 10, 2024 09:52
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
This is an extension of llvm#75774,
with Any and All lowering added alongside Count.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants