diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index 792cfe1d005a9..de21464ae79ce 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -939,7 +939,7 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre } def hlfir_YieldOp : hlfir_Op<"yield", [Terminator, ParentOneOf<["RegionAssignOp", - "ElementalAddrOp", "ForallOp"]>, + "ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp"]>, SingleBlockImplicitTerminator<"fir::FirEndOp">]> { let summary = "Yield a value or variable inside a forall, where or region assignment"; @@ -1114,4 +1114,123 @@ def hlfir_ForallOp : hlfir_Op<"forall", [hlfir_OrderedAssignmentTreeOpInterface] }]; } +/// Shared definition for hlfir.forall_mask and hlfir.where +/// that have the same structure and assembly format, but not the same +/// constraints. +class hlfir_AssignmentMaskOp : hlfir_Op { + let regions = (region SizedRegion<1>:$mask_region, + SizedRegion<1>:$body); + + let extraClassDeclaration = [{ + void getLeafRegions(llvm::SmallVectorImpl& regions) { + regions.push_back(&getMaskRegion()); + } + mlir::Region* getSubTreeRegion() { return &getBody(); } + }]; + + let assemblyFormat = [{ + $mask_region + attr-dict `do` + custom($body) + }]; +} + +def hlfir_ForallMaskOp : hlfir_AssignmentMaskOp<"forall_mask"> { + let summary = "Represent a Fortran forall mask"; + let description = [{ + Fortran Forall can have a scalar mask expression that depends on the + Forall index-name value. + hlfir.forall_mask allows representing this mask. The expression + evaluation is held in the mask region that must yield an i1 scalar + value. + An hlfir.forall_mask must be directly nested in the body region of + an hlfir.forall. It is a separate operation so that it can use the + index SSA value defined by the hlfir.forall body region. + + Example: "FORALL(I=1:10, SOME_CONDITION(I)) X(I) = FOO(I)" + ``` + hlfir.forall lb { + hlfir.yield %c1 : index + } ub { + hlfir.yield %c10 : index + } (%i : index) { + hlfir.forall_mask { + %mask = fir.call @some_condition(%i) : (index) -> i1 + hlfir.yield %mask : i1 + } do { + hlfir.region_assign { + %res = fir.call @foo(%i) : (index) -> f32 + hlfir.yield %res : f32 + } to { + %xi = hlfir.designate %x(%i) : (!fir.box>, index) -> !fir.ref + hlfir.yield %xi : !fir.ref + } + } + } + ``` + }]; + let hasVerifier = 1; +} + +def hlfir_WhereOp : hlfir_AssignmentMaskOp<"where"> { + let summary = "Represent a Fortran where construct or statement"; + let description = [{ + Represent Fortran "where" construct or statement. The mask + expression evaluation is held in the mask region that must yield + logical array that has the same shape as all the nested + hlfir.region_assign left-hand sides, and all the nested hlfir.where + or hlfir.elsewhere masks. + + The values of the where and elsewhere masks form a control mask that + controls all the nested hlfir.region_assign: only the array element for + which the related control mask value is true are assigned. Any right-hand + side elemental expression is only evaluated for elements where the control + mask is true. See Fortran standard 2018 section 10.2.3 for more detailed + about the control mask semantic. + + An hlfir.where must not contain any hlfir.forall but it may be contained + in such operation. This matches Fortran rules. + }]; + let hasVerifier = 1; +} + +def hlfir_ElseWhereOp : hlfir_Op<"elsewhere", [Terminator, + ParentOneOf<["WhereOp", "ElseWhereOp"]>, hlfir_OrderedAssignmentTreeOpInterface]> { + let summary = "Represent a Fortran elsewhere statement"; + + let description = [{ + Represent Fortran "elsewhere" construct or statement. + + It has an optional mask region to hold the evaluation of Fortran + optional elsewhere mask expressions. If this region is provided, + it must satisfy the same constraints as hlfir.where mask region. + + An hlfir.elsewhere must be the last operation of an hlfir.where or, + hlfir.elsewhere body, which is enforced by its terminator property. + + Like in Fortran, an hlfir.elsewhere negate the current control mask, + and if provided, adds the mask the resulting control mask (with a logical + AND). + }]; + + let regions = (region MaxSizedRegion<1>:$mask_region, + SizedRegion<1>:$body); + + let extraClassDeclaration = [{ + void getLeafRegions(llvm::SmallVectorImpl& regions) { + if (!getMaskRegion().empty()) + regions.push_back(&getMaskRegion()); + } + mlir::Region* getSubTreeRegion() { return &getBody(); } + }]; + + let assemblyFormat = [{ + (`mask` $mask_region^)? + attr-dict `do` + custom($body) + }]; + let hasVerifier = 1; +} + #endif // FORTRAN_DIALECT_HLFIR_OPS diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index c03b7c92248fb..7220c0860fc27 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1151,6 +1151,81 @@ static bool yieldsIntegerOrEmpty(mlir::Region ®ion) { return yield && fir::isa_integer(yield.getEntity().getType()); } +//===----------------------------------------------------------------------===// +// ForallMaskOp +//===----------------------------------------------------------------------===// + +static mlir::ParseResult parseAssignmentMaskOpBody(mlir::OpAsmParser &parser, + mlir::Region &body) { + if (parser.parseRegion(body)) + return mlir::failure(); + ensureTerminator(body, parser.getBuilder(), + parser.getBuilder().getUnknownLoc()); + return mlir::success(); +} + +template +static void printAssignmentMaskOpBody(mlir::OpAsmPrinter &p, ConcreteOp, + mlir::Region &body) { + // ElseWhereOp is a WhereOp/ElseWhereOp terminator that should be printed. + bool printBlockTerminators = + !body.empty() && + mlir::isa_and_nonnull(body.back().getTerminator()); + p.printRegion(body, /*printEntryBlockArgs=*/false, printBlockTerminators); +} + +static bool yieldsLogical(mlir::Region ®ion, bool mustBeScalarI1) { + if (region.empty()) + return false; + auto yield = mlir::dyn_cast_or_null(getTerminator(region)); + if (!yield) + return false; + mlir::Type yieldType = yield.getEntity().getType(); + if (mustBeScalarI1) + return hlfir::isI1Type(yieldType); + return hlfir::isMaskArgument(yieldType) && + hlfir::getFortranElementOrSequenceType(yieldType) + .isa(); +} + +mlir::LogicalResult hlfir::ForallMaskOp::verify() { + if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/true)) + return emitOpError("mask region must yield a scalar i1"); + mlir::Operation *op = getOperation(); + hlfir::ForallOp forallOp = + mlir::dyn_cast_or_null(op->getParentOp()); + if (!forallOp || op->getParentRegion() != &forallOp.getBody()) + return emitOpError("must be inside the body region of an hlfir.forall"); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// WhereOp and ElseWhereOp +//===----------------------------------------------------------------------===// + +template +static mlir::LogicalResult verifyWhereAndElseWhereBody(ConcreteOp &concreteOp) { + for (mlir::Operation &op : concreteOp.getBody().front()) + if (mlir::isa(op)) + return concreteOp.emitOpError( + "body region must not contain hlfir.forall"); + return mlir::success(); +} + +mlir::LogicalResult hlfir::WhereOp::verify() { + if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/false)) + return emitOpError("mask region must yield a logical array"); + return verifyWhereAndElseWhereBody(*this); +} + +mlir::LogicalResult hlfir::ElseWhereOp::verify() { + if (!getMaskRegion().empty()) + if (!yieldsLogical(getMaskRegion(), /*mustBeScalarI1=*/false)) + return emitOpError( + "mask region must yield a logical array when provided"); + return verifyWhereAndElseWhereBody(*this); +} + #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc" #define GET_OP_CLASSES #include "flang/Optimizer/HLFIR/HLFIROps.cpp.inc" diff --git a/flang/test/HLFIR/elsewhere.fir b/flang/test/HLFIR/elsewhere.fir new file mode 100644 index 0000000000000..b0033a6dc5b4f --- /dev/null +++ b/flang/test/HLFIR/elsewhere.fir @@ -0,0 +1,82 @@ +// Test hlfir.elsewhere operation parse, verify (no errors), and unparse. +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @test_elsewhere(%mask: !fir.ref>>, %x: !fir.ref>, %y: !fir.box>) { + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + hlfir.elsewhere do { + hlfir.region_assign { + hlfir.yield %y : !fir.box> + } to { + hlfir.yield %x : !fir.ref> + } + } + } + return +} +// CHECK-LABEL: func.func @test_elsewhere( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>, +// CHECK-SAME: %[[VAL_2:.*]]: !fir.box>) { +// CHECK: hlfir.where { +// CHECK: hlfir.yield %[[VAL_0]] : !fir.ref>> +// CHECK: } do { +// CHECK: hlfir.elsewhere do { +// CHECK: hlfir.region_assign { +// CHECK: hlfir.yield %[[VAL_2]] : !fir.box> +// CHECK: } to { +// CHECK: hlfir.yield %[[VAL_1]] : !fir.ref> +// CHECK: } +// CHECK: } +// CHECK: } + +func.func @test_masked_elsewhere(%mask: !fir.ref>>, %x: !fir.ref>, %y: !fir.box>) { + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + hlfir.elsewhere mask { + %other_mask = fir.call @get_mask() : () -> !fir.ptr>> + hlfir.yield %other_mask : !fir.ptr>> + } do { + hlfir.region_assign { + hlfir.yield %y : !fir.box> + } to { + hlfir.yield %x : !fir.ref> + } + hlfir.elsewhere do { + hlfir.region_assign { + hlfir.yield %x : !fir.ref> + } to { + hlfir.yield %y : !fir.box> + } + } + } + } + return +} +// CHECK-LABEL: func.func @test_masked_elsewhere( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>, +// CHECK-SAME: %[[VAL_2:.*]]: !fir.box>) { +// CHECK: hlfir.where { +// CHECK: hlfir.yield %[[VAL_0]] : !fir.ref>> +// CHECK: } do { +// CHECK: hlfir.elsewhere mask { +// CHECK: %[[VAL_3:.*]] = fir.call @get_mask() : () -> !fir.ptr>> +// CHECK: hlfir.yield %[[VAL_3]] : !fir.ptr>> +// CHECK: } do { +// CHECK: hlfir.region_assign { +// CHECK: hlfir.yield %[[VAL_2]] : !fir.box> +// CHECK: } to { +// CHECK: hlfir.yield %[[VAL_1]] : !fir.ref> +// CHECK: } +// CHECK: hlfir.elsewhere do { +// CHECK: hlfir.region_assign { +// CHECK: hlfir.yield %[[VAL_1]] : !fir.ref> +// CHECK: } to { +// CHECK: hlfir.yield %[[VAL_2]] : !fir.box> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } diff --git a/flang/test/HLFIR/forall_mask.fir b/flang/test/HLFIR/forall_mask.fir new file mode 100644 index 0000000000000..b0ca270084e38 --- /dev/null +++ b/flang/test/HLFIR/forall_mask.fir @@ -0,0 +1,49 @@ +// Test hlfir.forall_mask operation parse, verify (no errors), and unparse. +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @forall_mask_test(%x: !fir.box>) { + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + hlfir.forall lb { + hlfir.yield %c1 : index + } ub { + hlfir.yield %c10 : index + } (%i : index) { + hlfir.forall_mask { + %mask = fir.call @some_condition(%i) : (index) -> i1 + hlfir.yield %mask : i1 + } do { + hlfir.region_assign { + %res = fir.call @foo(%i) : (index) -> f32 + hlfir.yield %res : f32 + } to { + %xi = hlfir.designate %x(%i) : (!fir.box>, index) -> !fir.ref + hlfir.yield %xi : !fir.ref + } + } + } + return +} + +// CHECK-LABEL: func.func @forall_mask_test( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 10 : index +// CHECK: hlfir.forall lb { +// CHECK: hlfir.yield %[[VAL_1]] : index +// CHECK: } ub { +// CHECK: hlfir.yield %[[VAL_2]] : index +// CHECK: } (%[[VAL_3:.*]]: index) { +// CHECK: hlfir.forall_mask { +// CHECK: %[[VAL_4:.*]] = fir.call @some_condition(%[[VAL_3]]) : (index) -> i1 +// CHECK: hlfir.yield %[[VAL_4]] : i1 +// CHECK: } do { +// CHECK: hlfir.region_assign { +// CHECK: %[[VAL_5:.*]] = fir.call @foo(%[[VAL_3]]) : (index) -> f32 +// CHECK: hlfir.yield %[[VAL_5]] : f32 +// CHECK: } to { +// CHECK: %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_3]]) : (!fir.box>, index) -> !fir.ref +// CHECK: hlfir.yield %[[VAL_6]] : !fir.ref +// CHECK: } +// CHECK: } +// CHECK: } diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir index c1bcdaf687c78..29b61f0487c21 100644 --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -642,3 +642,110 @@ func.func @bad_forall_2(%x : !fir.box>, %y: f32) { } return } + +// ----- +func.func @bad_forall_mask(%i: index) { + // expected-error@+1 {{'hlfir.forall_mask' op must be inside the body region of an hlfir.forall}} + hlfir.forall_mask { + %mask = fir.call @some_condition(%i) : (index) -> i1 + hlfir.yield %mask : i1 + } do { + } + return +} + +// ----- +func.func @bad_forall_mask_2(%mask: !fir.ref>>) { + %c1 = arith.constant 1 : index + hlfir.forall lb { + hlfir.yield %c1 : index + } ub { + hlfir.yield %c1 : index + } (%i: index) { + // expected-error@+1 {{'hlfir.forall_mask' op mask region must yield a scalar i1}} + hlfir.forall_mask { + hlfir.yield %mask : !fir.ref>> + } do { + } + } + return +} + +// ----- +func.func @bad_where_1(%bad_mask: !fir.ref>) { + // expected-error@+1 {{'hlfir.where' op mask region must yield a logical array}} + hlfir.where { + hlfir.yield %bad_mask : !fir.ref> + } do { + } + return +} + +// ----- +func.func @bad_where_2(%bad_mask: i1) { + // expected-error@+1 {{'hlfir.where' op mask region must yield a logical array}} + hlfir.where { + hlfir.yield %bad_mask : i1 + } do { + } + return +} + +// ----- +func.func @bad_where_3(%mask: !fir.ref>>, %n: index) { + // expected-error@+1 {{'hlfir.where' op body region must not contain hlfir.forall}} + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + hlfir.forall lb { + hlfir.yield %n : index + } ub { + hlfir.yield %n : index + } (%i: index) { + } + } + return +} + +// ----- +func.func @bad_elsewhere_1(%mask: !fir.ref>>, %bad_mask: i1) { + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + // expected-error@+1 {{'hlfir.elsewhere' op mask region must yield a logical array when provided}} + hlfir.elsewhere mask { + hlfir.yield %bad_mask : i1 + } do { + } + } + return +} + +// ----- +func.func @bad_elsewhere_2(%mask: !fir.ref>>) { + // expected-error@+1 {{'hlfir.elsewhere' op expects parent op to be one of 'hlfir.where, hlfir.elsewhere'}} + hlfir.elsewhere mask { + hlfir.yield %mask : !fir.ref>> + } do { + } + return +} + +// ----- +func.func @bad_elsewhere_3(%mask: !fir.ref>>, %x: !fir.ref>, %y: !fir.box>) { + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + // expected-error@+1 {{'hlfir.elsewhere' op must be the last operation in the parent block}} + hlfir.elsewhere mask { + hlfir.yield %mask : !fir.ref>> + } do { + } + hlfir.region_assign { + hlfir.yield %y : !fir.box> + } to { + hlfir.yield %x : !fir.ref> + } + } + return +} diff --git a/flang/test/HLFIR/where.fir b/flang/test/HLFIR/where.fir new file mode 100644 index 0000000000000..2aa242c9491df --- /dev/null +++ b/flang/test/HLFIR/where.fir @@ -0,0 +1,28 @@ +// Test hlfir.where operation parse, verify (no errors), and unparse. +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @test_where(%mask: !fir.ref>>, %x: !fir.ref>, %y: !fir.box>) { + hlfir.where { + hlfir.yield %mask : !fir.ref>> + } do { + hlfir.region_assign { + hlfir.yield %y : !fir.box> + } to { + hlfir.yield %x : !fir.ref> + } + } + return +} +// CHECK-LABEL: func.func @test_where( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>>, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>, +// CHECK-SAME: %[[VAL_2:.*]]: !fir.box>) { +// CHECK: hlfir.where { +// CHECK: hlfir.yield %[[VAL_0]] : !fir.ref>> +// CHECK: } do { +// CHECK: hlfir.region_assign { +// CHECK: hlfir.yield %[[VAL_2]] : !fir.box> +// CHECK: } to { +// CHECK: hlfir.yield %[[VAL_1]] : !fir.ref> +// CHECK: } +// CHECK: }