diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index 15b92385a7720..142a70c639127 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> { let hasVerifier = 1; } +def hlfir_AllOp : hlfir_Op<"all", []> { + let summary = "ALL transformational intrinsic"; + let description = [{ + Takes a logical array MASK as argument, optionally along a particular dimension, + and returns true if all elements of MASK are true. + }]; + + let arguments = (ins + AnyFortranLogicalArrayObject:$mask, + Optional:$dim + ); + + let results = (outs AnyFortranValue); + + let assemblyFormat = [{ + $mask (`dim` $dim^)? attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + def hlfir_AnyOp : hlfir_Op<"any", []> { let summary = "ANY transformational intrinsic"; let description = [{ diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 4547c4247241e..adf8b72993e4c 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() { } //===----------------------------------------------------------------------===// -// AnyOp +// LogicalReductionOp //===----------------------------------------------------------------------===// -mlir::LogicalResult hlfir::AnyOp::verify() { - mlir::Operation *op = getOperation(); +template +static mlir::LogicalResult +verifyLogicalReductionOp(LogicalReductionOp reductionOp) { + mlir::Operation *op = reductionOp->getOperation(); auto results = op->getResultTypes(); assert(results.size() == 1); - mlir::Value mask = getMask(); - mlir::Value dim = getDim(); + mlir::Value mask = reductionOp->getMask(); + mlir::Value dim = reductionOp->getDim(); + fir::SequenceType maskTy = hlfir::getFortranElementOrSequenceType(mask.getType()) .cast(); @@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() { if (mlir::isa(resultType)) { // Result is of the same type as MASK if (resultType != logicalTy) - return emitOpError( + return reductionOp->emitOpError( "result must have the same element type as MASK argument"); } else if (auto resultExpr = @@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() { // Result should only be in hlfir.expr form if it is an array if (maskShape.size() > 1 && dim != nullptr) { if (!resultExpr.isArray()) - return emitOpError("result must be an array"); + return reductionOp->emitOpError("result must be an array"); if (resultExpr.getEleTy() != logicalTy) - return emitOpError( + return reductionOp->emitOpError( "result must have the same element type as MASK argument"); llvm::ArrayRef resultShape = resultExpr.getShape(); // Result has rank n-1 if (resultShape.size() != (maskShape.size() - 1)) - return emitOpError("result rank must be one less than MASK"); + return reductionOp->emitOpError( + "result rank must be one less than MASK"); } else { - return emitOpError("result must be of logical type"); + return reductionOp->emitOpError("result must be of logical type"); } } else { - return emitOpError("result must be of logical type"); + return reductionOp->emitOpError("result must be of logical type"); } return mlir::success(); } +//===----------------------------------------------------------------------===// +// AllOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult hlfir::AllOp::verify() { + return verifyLogicalReductionOp(this); +} + +//===----------------------------------------------------------------------===// +// AnyOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult hlfir::AnyOp::verify() { + return verifyLogicalReductionOp(this); +} + //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// @@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder, } //===----------------------------------------------------------------------===// -// ReductionOp +// NumericalReductionOp //===----------------------------------------------------------------------===// -template -static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) { +template +static mlir::LogicalResult +verifyNumericalReductionOp(NumericalReductionOp reductionOp) { mlir::Operation *op = reductionOp->getOperation(); auto results = op->getResultTypes(); @@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) { //===----------------------------------------------------------------------===// mlir::LogicalResult hlfir::ProductOp::verify() { - return verifyReductionOp(this); + return verifyNumericalReductionOp(this); } //===----------------------------------------------------------------------===// @@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder, //===----------------------------------------------------------------------===// mlir::LogicalResult hlfir::SumOp::verify() { - return verifyReductionOp(this); + return verifyNumericalReductionOp(this); } //===----------------------------------------------------------------------===// diff --git a/flang/test/HLFIR/all.fir b/flang/test/HLFIR/all.fir new file mode 100644 index 0000000000000..00ce1b3a5fbae --- /dev/null +++ b/flang/test/HLFIR/all.fir @@ -0,0 +1,113 @@ +// Test hlfir.all operation parse, verify (no errors), and unparse + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +// mask is an expression of known shape +func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) { + %all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4> + return +} +// CHECK: func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask is an expression of assumed shape +func.func @all1(%arg0: !hlfir.expr>) { + %all = hlfir.all %arg0 : (!hlfir.expr>) -> !fir.logical<4> + return +} +// CHECK: func.func @all1(%[[ARRAY:.*]]: !hlfir.expr>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask is a boxed array +func.func @all2(%arg0: !fir.box>>) { + %all = hlfir.all %arg0 : (!fir.box>>) -> !fir.logical<4> + return +} +// CHECK: func.func @all2(%[[ARRAY:.*]]: !fir.box>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box>>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask is an assumed shape boxed array +func.func @all3(%arg0: !fir.box>>){ + %all = hlfir.all %arg0 : (!fir.box>>) -> !fir.logical<4> + return +} +// CHECK: func.func @all3(%[[ARRAY:.*]]: !fir.box>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box>>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask is a 2-dimensional array +func.func @all4(%arg0: !fir.box>>){ + %all = hlfir.all %arg0 : (!fir.box>>) -> !fir.logical<4> + return +} +// CHECK: func.func @all4(%[[ARRAY:.*]]: !fir.box>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box>>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask and dim argument +func.func @all5(%arg0: !fir.box>>, %arg1: i32) { + %all = hlfir.all %arg0 dim %arg1 : (!fir.box>>, i32) -> !fir.logical<4> + return +} +// CHECK: func.func @all5(%[[ARRAY:.*]]: !fir.box>>, %[[DIM:.*]]: i32) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box>>, i32) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// hlfir.all with dim argument with an unusual type +func.func @all6(%arg0: !fir.box>>, %arg1: index) { + %all = hlfir.all %arg0 dim %arg1 : (!fir.box>>, index) ->!fir.logical<4> + return +} +// CHECK: func.func @all6(%[[ARRAY:.*]]: !fir.box>>, %[[DIM:.*]]: index) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box>>, index) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// mask is a 2 dimensional array with dim +func.func @all7(%arg0: !fir.box>>, %arg1: i32) { + %all = hlfir.all %arg0 dim %arg1 : (!fir.box>>, i32) -> !hlfir.expr> + return +} +// CHECK: func.func @all7(%[[ARRAY:.*]]: !fir.box>>, %[[DIM:.*]]: i32) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box>>, i32) -> !hlfir.expr> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// known shape expr return +func.func @all8(%arg0: !fir.box>>, %arg1: i32) { + %all = hlfir.all %arg0 dim %arg1 : (!fir.box>>, i32) -> !hlfir.expr<2x!fir.logical<4>> + return +} +// CHECK: func.func @all8(%[[ARRAY:.*]]: !fir.box>>, %[[DIM:.*]]: i32) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box>>, i32) -> !hlfir.expr<2x!fir.logical<4>> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// hlfir.all with mask argument of ref> type +func.func @all9(%arg0: !fir.ref>>) { + %all = hlfir.all %arg0 : (!fir.ref>>) -> !fir.logical<4> + return +} +// CHECK: func.func @all9(%[[ARRAY:.*]]: !fir.ref>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref>>) -> !fir.logical<4> +// CHECK-NEXT: return +// CHECK-NEXT: } + +// hlfir.all with fir.logical<8> type +func.func @all10(%arg0: !fir.box>>) { + %all = hlfir.all %arg0 : (!fir.box>>) -> !fir.logical<8> + return +} +// CHECK: func.func @all10(%[[ARRAY:.*]]: !fir.box>>) { +// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box>>) -> !fir.logical<8> +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir index e1c95c1046dc4..8dc5679346bc1 100644 --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr>) { %0 = hlfir.any %arg0 : (!hlfir.expr>) -> !hlfir.expr> } +// ----- +func.func @bad_all1(%arg0: !hlfir.expr>) { + // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}} + %0 = hlfir.all %arg0 : (!hlfir.expr>) -> !fir.logical<8> +} + +// ----- +func.func @bad_all2(%arg0: !hlfir.expr>, %arg1: i32) { + // expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}} + %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr>, i32) -> !hlfir.expr> +} + +// ----- +func.func @bad_all3(%arg0: !hlfir.expr>, %arg1: i32){ + // expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}} + %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr>, i32) -> !hlfir.expr> +} + +// ----- +func.func @bad_all4(%arg0: !hlfir.expr>, %arg1: i32) { + // expected-error@+1 {{'hlfir.all' op result must be an array}} + %0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr>, i32) -> !hlfir.expr> +} + +// ----- +func.func @bad_all5(%arg0: !hlfir.expr>) { + // expected-error@+1 {{'hlfir.all' op result must be of logical type}} + %0 = hlfir.all %arg0 : (!hlfir.expr>) -> i32 +} + +// ----- +func.func @bad_all6(%arg0: !hlfir.expr>) { + // expected-error@+1 {{'hlfir.all' op result must be of logical type}} + %0 = hlfir.all %arg0 : (!hlfir.expr>) -> !hlfir.expr> +} + // ----- func.func @bad_product1(%arg0: !hlfir.expr, %arg1: i32, %arg2: !fir.box>) { // expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}