diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 9d631dba9412c..6d73ebc3a7e1d 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -374,14 +374,15 @@ using ElementalKernelGenerator = std::function; /// Generate an hlfir.elementalOp given call back to generate the element /// value at for each iteration. -/// If exprType is specified, this will be the return type of the elemental op -hlfir::ElementalOp genElementalOp(mlir::Location loc, - fir::FirOpBuilder &builder, - mlir::Type elementType, mlir::Value shape, - mlir::ValueRange typeParams, - const ElementalKernelGenerator &genKernel, - bool isUnordered = false, - mlir::Type exprType = mlir::Type{}); +/// If exprType is specified, this will be the return type of the elemental op. +/// If exprType is not specified, the resulting expression type is computed +/// from the given \p elementType and \p shape, and the type is polymorphic +/// if \p polymorphicMold is present. +hlfir::ElementalOp genElementalOp( + mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType, + mlir::Value shape, mlir::ValueRange typeParams, + const ElementalKernelGenerator &genKernel, bool isUnordered = false, + mlir::Value polymorphicMold = {}, mlir::Type exprType = mlir::Type{}); /// Structure to describe a loop nest. struct LoopNest { diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h index d080286f0e092..b76063fb7c535 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h +++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h @@ -87,6 +87,7 @@ bool isPassByRefOrIntegerType(mlir::Type); bool isI1Type(mlir::Type); // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr) bool isMaskArgument(mlir::Type); +bool isPolymorphicObject(mlir::Type); /// If an expression's extents are known at compile time, generate a fir.shape /// for this expression. Otherwise return {} diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td index 324689d22d4cb..018e187ed46e6 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td @@ -149,6 +149,11 @@ def IsFortranLogicalArrayPred def AnyFortranLogicalArrayObject : Type; +def IsPolymorphicObjectPred + : CPred<"::hlfir::isPolymorphicObject($_self)">; +def AnyPolymorphicObject : Type; + def hlfir_CharExtremumPredicateAttr : I32EnumAttr< "CharExtremumPredicate", "", [ diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index d5114ec3de9b7..24c2dad497fd2 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -740,7 +740,7 @@ def hlfir_ElementalOpInterface : OpInterface<"ElementalOpInterface"> { let cppNamespace = "hlfir"; } -def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface]> { +def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_ElementalOpInterface, AttrSizedOperandSegments]> { let summary = "elemental expression"; let description = [{ Represent an elemental expression as a function of the indices. @@ -753,6 +753,12 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele The shape and typeparams operands represent the extents and type parameters of the resulting array value. + The optional mold is an entity carrying the information about + the dynamic type of the polymorphic result. Note that the shape + of the mold does not necessarily match the shape of the result, + for example, the result of `merge(poly_scalar1, poly_scalar2, mask_array)` + will have the shape of `mask_array` and the dynamic type of `poly_scalar*`. + The unordered attribute can be set to allow out of order processing of the indices. This is safe only if the operations in the body of the elemental do not have side effects. @@ -775,6 +781,7 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele let arguments = (ins AnyShapeType:$shape, + Optional:$mold, Variadic:$typeparams, OptionalAttr:$unordered ); @@ -783,7 +790,8 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ - $shape (`typeparams` $typeparams^)? (`unordered` $unordered^)? + $shape (`mold` $mold^)? (`typeparams` $typeparams^)? + (`unordered` $unordered^)? attr-dict `:` functional-type(operands, results) $region }]; @@ -808,10 +816,12 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape, + CArg<"mlir::Value", "{}">:$mold, CArg<"mlir::ValueRange", "{}">:$typeparams, CArg<"bool", "false">:$isUnordered)> ]; + let hasVerifier = 1; } def hlfir_YieldElementOp : hlfir_Op<"yield_element", [Terminator, HasParent<"ElementalOp">, Pure]> { diff --git a/flang/lib/Lower/ConvertArrayConstructor.cpp b/flang/lib/Lower/ConvertArrayConstructor.cpp index 2ef500ecf22db..24aa9beba6bf4 100644 --- a/flang/lib/Lower/ConvertArrayConstructor.cpp +++ b/flang/lib/Lower/ConvertArrayConstructor.cpp @@ -214,9 +214,9 @@ class AsElementalStrategy : public StrategyBase { assert(!elementalOp && "expected only one implied-do"); mlir::Value one = builder.createIntegerConstant(loc, builder.getIndexType(), 1); - elementalOp = - builder.create(loc, exprType, shape, lengthParams, - /*isUnordered=*/true); + elementalOp = builder.create( + loc, exprType, shape, + /*mold=*/nullptr, lengthParams, /*isUnordered=*/true); builder.setInsertionPointToStart(elementalOp.getBody()); // implied-do-index = lower+((i-1)*stride) mlir::Value diff = builder.create( diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 3934609491164..dd62aa0e37012 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -737,16 +737,15 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType, isPolymorphic); } -hlfir::ElementalOp -hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder, - mlir::Type elementType, mlir::Value shape, - mlir::ValueRange typeParams, - const ElementalKernelGenerator &genKernel, - bool isUnordered, mlir::Type exprType) { +hlfir::ElementalOp hlfir::genElementalOp( + mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType, + mlir::Value shape, mlir::ValueRange typeParams, + const ElementalKernelGenerator &genKernel, bool isUnordered, + mlir::Value polymorphicMold, mlir::Type exprType) { if (!exprType) - exprType = getArrayExprType(elementType, shape, false); + exprType = getArrayExprType(elementType, shape, !!polymorphicMold); auto elementalOp = builder.create( - loc, exprType, shape, typeParams, isUnordered); + loc, exprType, shape, polymorphicMold, typeParams, isUnordered); auto insertPt = builder.saveInsertionPoint(); builder.setInsertionPointToStart(elementalOp.getBody()); mlir::Value elementResult = genKernel(loc, builder, elementalOp.getIndices()); diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index 1f4f62f29e3db..7ca6108a31acb 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -181,6 +181,13 @@ bool hlfir::isMaskArgument(mlir::Type type) { return mlir::isa(elementType) || isI1Type(elementType); } +bool hlfir::isPolymorphicObject(mlir::Type type) { + if (auto exprType = mlir::dyn_cast(type)) + return exprType.isPolymorphic(); + + return fir::isPolymorphicType(type); +} + mlir::Value hlfir::genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc, const hlfir::ExprType &expr) { diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index a74d6f94f4df1..0b4b9c1588efa 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1036,10 +1036,17 @@ void hlfir::AsExprOp::build(mlir::OpBuilder &builder, void hlfir::ElementalOp::build(mlir::OpBuilder &builder, mlir::OperationState &odsState, mlir::Type resultType, mlir::Value shape, - mlir::ValueRange typeparams, bool isUnordered) { + mlir::Value mold, mlir::ValueRange typeparams, + bool isUnordered) { odsState.addOperands(shape); + if (mold) + odsState.addOperands(mold); odsState.addOperands(typeparams); odsState.addTypes(resultType); + odsState.addAttribute( + getOperandSegmentSizesAttrName(odsState.name), + builder.getDenseI32ArrayAttr({/*shape=*/1, (mold ? 1 : 0), + static_cast(typeparams.size())})); if (isUnordered) odsState.addAttribute(getUnorderedAttrName(odsState.name), isUnordered ? builder.getUnitAttr() : nullptr); @@ -1057,6 +1064,16 @@ mlir::Value hlfir::ElementalOp::getElementEntity() { return mlir::cast(getBody()->back()).getElementValue(); } +mlir::LogicalResult hlfir::ElementalOp::verify() { + mlir::Value mold = getMold(); + hlfir::ExprType resultType = mlir::cast(getType()); + if (!!mold != resultType.isPolymorphic()) + return emitOpError("result must be polymorphic when mold is present " + "and vice versa"); + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // ApplyOp //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 6206deee411c3..5f065056bac00 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -58,7 +58,8 @@ class TransposeAsElementalConversion }; hlfir::ElementalOp elementalOp = hlfir::genElementalOp( loc, builder, elementType, resultShape, typeParams, genKernel, - /*isUnordered=*/true, transpose.getResult().getType()); + /*isUnordered=*/true, /*polymorphicMold=*/nullptr, + transpose.getResult().getType()); // it wouldn't be safe to replace block arguments with a different // hlfir.expr type. Types can differ due to differing amounts of shape diff --git a/flang/test/HLFIR/elemental.fir b/flang/test/HLFIR/elemental.fir index d4cef6705b176..174c39b99b372 100644 --- a/flang/test/HLFIR/elemental.fir +++ b/flang/test/HLFIR/elemental.fir @@ -99,3 +99,45 @@ func.func @unordered() { // CHECK: } // CHECK: return // CHECK: } + +func.func @polymorphic_mold_var(%arg0: !fir.class>>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.class>>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} +// CHECK-LABEL: func.func @polymorphic_mold_var( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.class>>, %[[VAL_1:.*]]: index) { +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = hlfir.elemental %[[VAL_2]] mold %[[VAL_0]] unordered : (!fir.shape<1>, !fir.class>>) -> !hlfir.expr?> { +// CHECK: ^bb0(%[[VAL_4:.*]]: index): +// CHECK: %[[VAL_5:.*]] = fir.undefined !hlfir.expr?> +// CHECK: hlfir.yield_element %[[VAL_5]] : !hlfir.expr?> +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @polymorphic_mold_expr(%shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + %mold = fir.undefined !hlfir.expr?> + %4 = hlfir.elemental %3 mold %mold unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} +// CHECK-LABEL: func.func @polymorphic_mold_expr( +// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK: %[[VAL_1:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_2:.*]] = fir.undefined !hlfir.expr?> +// CHECK: %[[VAL_3:.*]] = hlfir.elemental %[[VAL_1]] mold %[[VAL_2]] unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr?> { +// CHECK: ^bb0(%[[VAL_4:.*]]: index): +// CHECK: %[[VAL_5:.*]] = fir.undefined !hlfir.expr?> +// CHECK: hlfir.yield_element %[[VAL_5]] : !hlfir.expr?> +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir index bbfe543a3427e..db16cb29c6238 100644 --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -961,3 +961,51 @@ func.func @bad_get_length_3(%arg0: !hlfir.expr>) { %1 = hlfir.get_length %arg0 : (!hlfir.expr>) -> index return } + +// ----- +func.func @elemental_poly_1(%arg0: !fir.box>>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!fir.box>>'}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !fir.box>>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} + +// ----- +func.func @elemental_poly_2(%arg0: !hlfir.expr>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> + // expected-error@+1 {{'hlfir.elemental' op operand #1 must be any polymorphic object, but got '!hlfir.expr>'}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +} + +// ----- +func.func @elemental_poly_3(%arg0: !hlfir.expr?>, %shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> +// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}} + %4 = hlfir.elemental %3 mold %arg0 unordered : (!fir.shape<1>, !hlfir.expr?>) -> !hlfir.expr> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr> + hlfir.yield_element %6 : !hlfir.expr> + } + return +} + +// ----- +func.func @elemental_poly_4(%shape : index) { + %3 = fir.shape %shape : (index) -> !fir.shape<1> +// expected-error@+1 {{'hlfir.elemental' op result must be polymorphic when mold is present and vice versa}} + %4 = hlfir.elemental %3 unordered : (!fir.shape<1>) -> !hlfir.expr?> { + ^bb0(%arg2: index): + %6 = fir.undefined !hlfir.expr?> + hlfir.yield_element %6 : !hlfir.expr?> + } + return +}