Skip to content

Commit

Permalink
[flang][hlfir] Set/propagate 'unordered' attribute for elementals.
Browse files Browse the repository at this point in the history
This patch adds 'unordered' attribute handling the HLFIR elementals'
builders and fixes the attribute handling in lowering and transformations.

Depends on D154031, D154032

Reviewed By: jeanPerier, tblah

Differential Revision: https://reviews.llvm.org/D154035
  • Loading branch information
vzakhari committed Jun 29, 2023
1 parent 65379d4 commit 7b4aa95
Show file tree
Hide file tree
Showing 25 changed files with 101 additions and 84 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ hlfir::ElementalOp genElementalOp(mlir::Location loc,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel,
bool isUnordered = false,
mlir::Type exprType = mlir::Type{});

/// Structure to describe a loop nest.
Expand Down
6 changes: 4 additions & 2 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,9 @@ def hlfir_ElementalOp : hlfir_Op<"elemental", [RecursiveMemoryEffects, hlfir_Ele
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
CArg<"mlir::ValueRange", "{}">:$typeparams)>];
CArg<"mlir::ValueRange", "{}">:$typeparams,
CArg<"bool", "false">:$isUnordered)>
];

}

Expand Down Expand Up @@ -1216,7 +1218,7 @@ def hlfir_ElementalAddrOp : hlfir_Op<"elemental_addr", [Terminator, HasParent<"R
MaxSizedRegion<1>:$cleanup);

let builders = [
OpBuilder<(ins "mlir::Value":$shape)>
OpBuilder<(ins "mlir::Value":$shape, CArg<"bool", "false">:$isUnordered)>
];

let assemblyFormat = [{
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3165,7 +3165,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return hlfir::EntityWithAttributes{builder.createConvert(loc, toTy, val)};
};
mlir::Value convertedRhs = hlfir::genElementalOp(
loc, builder, toTy, shape, /*typeParams=*/{}, genKernel);
loc, builder, toTy, shape, /*typeParams=*/{}, genKernel,
/*isUnordered=*/true);
fir::FirOpBuilder *bldr = &builder;
stmtCtx.attachCleanup([loc, bldr, convertedRhs]() {
bldr->create<hlfir::DestroyOp>(loc, convertedRhs);
Expand Down
10 changes: 6 additions & 4 deletions flang/lib/Lower/ConvertArrayConstructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ class AsElementalStrategy : public StrategyBase {
mlir::Value one =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
elementalOp =
builder.create<hlfir::ElementalOp>(loc, exprType, shape, lengthParams);
builder.create<hlfir::ElementalOp>(loc, exprType, shape, lengthParams,
/*isUnordered=*/true);
builder.setInsertionPointToStart(elementalOp.getBody());
// implied-do-index = lower+((i-1)*stride)
mlir::Value diff = builder.create<mlir::arith::SubIOp>(
Expand Down Expand Up @@ -686,9 +687,10 @@ static ArrayCtorLoweringStrategy selectArrayCtorLoweringStrategy(
loc, builder, stmtCtx, symMap, declaredType,
extent ? std::optional<mlir::Value>(extent) : std::nullopt, lengths,
needToEvaluateOneExprToGetLengthParameters);
// Note: array constructors containing impure ac-value expr are currently not
// rewritten to hlfir.elemental because impure expressions should be evaluated
// in order, and hlfir.elemental currently misses a way to indicate that.
// Note: the generated hlfir.elemental is always unordered, thus,
// AsElementalStrategy can only be used for array constructors without
// impure ac-value expressions. If/when this changes, make sure
// the 'unordered' attribute is set accordingly for the hlfir.elemental.
if (analysis.isSingleImpliedDoWithOneScalarPureExpr())
return AsElementalStrategy(loc, builder, stmtCtx, symMap, declaredType,
extent, lengths);
Expand Down
7 changes: 3 additions & 4 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1612,8 +1612,6 @@ class ElementalCallBuilder {
return std::nullopt;
}
// Function case: generate call inside hlfir.elemental
if (mustBeOrdered)
TODO(loc, "ordered elemental calls in HLFIR");
mlir::Type elementType =
hlfir::getFortranElementType(*callContext.resultType);
// Get result length parameters.
Expand Down Expand Up @@ -1645,8 +1643,9 @@ class ElementalCallBuilder {
// use.
return res;
};
mlir::Value elemental = hlfir::genElementalOp(loc, builder, elementType,
shape, typeParams, genKernel);
mlir::Value elemental =
hlfir::genElementalOp(loc, builder, elementType, shape, typeParams,
genKernel, !mustBeOrdered);
fir::FirOpBuilder *bldr = &builder;
callContext.stmtCtx.attachCleanup(
[=]() { bldr->create<hlfir::DestroyOp>(loc, elemental); });
Expand Down
10 changes: 7 additions & 3 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,9 @@ class HlfirDesignatorBuilder {
// of the whole designator (not the ones of the vector subscripted part).
// These are not yet known and will be added when finalizing the designator
// lowering.
auto elementalAddrOp = builder.create<hlfir::ElementalAddrOp>(loc, shape);
auto elementalAddrOp =
builder.create<hlfir::ElementalAddrOp>(loc, shape,
/*isUnordered=*/true);
setVectorSubscriptElementAddrOp(elementalAddrOp);
builder.setInsertionPointToEnd(&elementalAddrOp.getBody().front());
mlir::Region::BlockArgListType indices = elementalAddrOp.getIndices();
Expand Down Expand Up @@ -1512,7 +1514,8 @@ class HlfirBuilder {
return unaryOp.gen(l, b, op.derived(), leftVal);
};
mlir::Value elemental = hlfir::genElementalOp(loc, builder, elementType,
shape, typeParams, genKernel);
shape, typeParams, genKernel,
/*isUnordered=*/true);
fir::FirOpBuilder *bldr = &builder;
getStmtCtx().attachCleanup(
[=]() { bldr->create<hlfir::DestroyOp>(loc, elemental); });
Expand Down Expand Up @@ -1557,7 +1560,8 @@ class HlfirBuilder {
return binaryOp.gen(l, b, op.derived(), leftVal, rightVal);
};
mlir::Value elemental = hlfir::genElementalOp(loc, builder, elementType,
shape, typeParams, genKernel);
shape, typeParams, genKernel,
/*isUnordered=*/true);
fir::FirOpBuilder *bldr = &builder;
getStmtCtx().attachCleanup(
[=]() { bldr->create<hlfir::DestroyOp>(loc, elemental); });
Expand Down
16 changes: 9 additions & 7 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,14 +732,16 @@ static hlfir::ExprType getArrayExprType(mlir::Type elementType,
isPolymorphic);
}

hlfir::ElementalOp hlfir::genElementalOp(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type elementType,
mlir::Value shape, mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel, mlir::Type exprType) {
hlfir::ElementalOp
hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Type elementType, mlir::Value shape,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel,
bool isUnordered, mlir::Type exprType) {
if (!exprType)
exprType = getArrayExprType(elementType, shape, false);
auto elementalOp =
builder.create<hlfir::ElementalOp>(loc, exprType, shape, typeParams);
auto elementalOp = builder.create<hlfir::ElementalOp>(
loc, exprType, shape, typeParams, isUnordered);
auto insertPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(elementalOp.getBody());
mlir::Value elementResult = genKernel(loc, builder, elementalOp.getIndices());
Expand Down Expand Up @@ -1013,5 +1015,5 @@ hlfir::cloneToElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Type elementType = scalarAddress.getFortranElementType();
return hlfir::genElementalOp(loc, builder, elementType,
elementalAddrOp.getShape(), typeParams,
genKernel);
genKernel, !elementalAddrOp.isOrdered());
}
10 changes: 8 additions & 2 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,10 +1029,13 @@ void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
void hlfir::ElementalOp::build(mlir::OpBuilder &builder,
mlir::OperationState &odsState,
mlir::Type resultType, mlir::Value shape,
mlir::ValueRange typeparams) {
mlir::ValueRange typeparams, bool isUnordered) {
odsState.addOperands(shape);
odsState.addOperands(typeparams);
odsState.addTypes(resultType);
if (isUnordered)
odsState.addAttribute(getUnorderedAttrName(odsState.name),
isUnordered ? builder.getUnitAttr() : nullptr);
mlir::Region *bodyRegion = odsState.addRegion();
bodyRegion->push_back(new mlir::Block{});
if (auto exprType = resultType.dyn_cast<hlfir::ExprType>()) {
Expand Down Expand Up @@ -1264,8 +1267,11 @@ static void printYieldOpCleanup(mlir::OpAsmPrinter &p, YieldOp yieldOp,

void hlfir::ElementalAddrOp::build(mlir::OpBuilder &builder,
mlir::OperationState &odsState,
mlir::Value shape) {
mlir::Value shape, bool isUnordered) {
odsState.addOperands(shape);
if (isUnordered)
odsState.addAttribute(getUnorderedAttrName(odsState.name),
isUnordered ? builder.getUnitAttr() : nullptr);
mlir::Region *bodyRegion = odsState.addRegion();
bodyRegion->push_back(new mlir::Block{});
if (auto shapeType = shape.getType().dyn_cast<fir::ShapeType>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TransposeAsElementalConversion
};
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, typeParams, genKernel,
transpose.getResult().getType());
/*isUnordered=*/true, transpose.getResult().getType());

// it wouldn't be safe to replace block arguments with a different
// hlfir.expr type. Types can differ due to differing amounts of shape
Expand Down
10 changes: 5 additions & 5 deletions flang/test/HLFIR/simplify-hlfir-intrinsics.fir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func.func @transpose0(%arg0: !fir.box<!fir.array<1x2xi32>>) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<1x2xi32>>, index) -> (index, index, index)
Expand Down Expand Up @@ -38,7 +38,7 @@ func.func @transpose1(%arg0: !hlfir.expr<1x2xi32>) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[C1]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<2>) -> !hlfir.expr<2x1xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<1x2xi32>, index, index) -> i32
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
Expand All @@ -57,7 +57,7 @@ func.func @transpose2(%arg0: !fir.box<!fir.array<?x2xi32>>) {
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index)
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[SHAPE:.*]] = fir.shape %[[C2]], %[[DIMS0]]#1 : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0_1]] : (!fir.box<!fir.array<?x2xi32>>, index) -> (index, index, index)
Expand Down Expand Up @@ -86,7 +86,7 @@ func.func @transpose3(%arg0: !hlfir.expr<?x2xi32>) {
// CHECK: %[[EXTENT0:.*]] = hlfir.get_extent %[[IN_SHAPE]] {dim = 0 : index} : (!fir.shape<2>) -> index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[OUT_SHAPE:.*]] = fir.shape %[[C2]], %[[EXTENT0]] : (index, index) -> !fir.shape<2>
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[OUT_SHAPE]] : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: %[[EXPR:.*]] = hlfir.elemental %[[OUT_SHAPE]] unordered : (!fir.shape<2>) -> !hlfir.expr<2x?xi32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELEMENT:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<?x2xi32>, index, index) -> i32
// CHECK: hlfir.yield_element %[[ELEMENT]] : i32
Expand All @@ -113,7 +113,7 @@ func.func @transpose4(%arg0: !hlfir.expr<2x2xf32>, %arg1: !fir.ref<!fir.box<!fir
// CHECK-SAME: %[[ARG0:.*]]: !hlfir.expr<2x2xf32>
// CHECK-SAME: %[[ARG1:.*]]:
// CHECK: %[[SHAPE0:.*]] = fir.shape
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: %[[TRANSPOSE:.*]] = hlfir.elemental %[[SHAPE0]] unordered : (!fir.shape<2>) -> !hlfir.expr<2x2xf32> {
// CHECK: ^bb0(%[[I:.*]]: index, %[[J:.*]]: index):
// CHECK: %[[ELE:.*]] = hlfir.apply %[[ARG0]], %[[J]], %[[I]] : (!hlfir.expr<2x2xf32>, index, index) -> f32
// CHECK: hlfir.yield_element %[[ELE]] : f32
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/allocatables-and-pointers.f90
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ subroutine elemental_expr(x)
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_7:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_6]] : (!fir.box<!fir.ptr<!fir.array<?x?xi32>>>, index) -> (index, index, index)
! CHECK: %[[VAL_8:.*]] = fir.shape %[[VAL_5]]#1, %[[VAL_7]]#1 : (index, index) -> !fir.shape<2>
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_8]] : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_8]] unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xi32> {
! CHECK: ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
! CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_2]], %[[VAL_12]] : (!fir.box<!fir.ptr<!fir.array<?x?xi32>>>, index) -> (index, index, index)
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/array-ctor-as-elemental-nested.f90
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFtestEpi"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_13:.*]] = fir.shape %[[VAL_12]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_14:.*]] = hlfir.elemental %[[VAL_13]] : (!fir.shape<1>) -> !hlfir.expr<2xf32> {
! CHECK: %[[VAL_14:.*]] = hlfir.elemental %[[VAL_13]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xf32> {
! CHECK: ^bb0(%[[VAL_15:.*]]: index):
! CHECK: %[[VAL_16:.*]] = arith.constant 2 : index
! CHECK: %[[VAL_17:.*]] = fir.shape %[[VAL_16]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_18:.*]] = hlfir.elemental %[[VAL_17]] : (!fir.shape<1>) -> !hlfir.expr<2xf32> {
! CHECK: %[[VAL_18:.*]] = hlfir.elemental %[[VAL_17]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xf32> {
! CHECK: ^bb0(%[[VAL_19:.*]]: index):
! CHECK: %[[VAL_20:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<f32>
! CHECK: hlfir.yield_element %[[VAL_20]] : f32
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Lower/HLFIR/array-ctor-as-elemental.f90
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ subroutine test_as_simple_elemental(n)
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i64
! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i64) -> index
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<4xi32> {
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] unordered : (!fir.shape<1>) -> !hlfir.expr<4xi32> {
! CHECK: ^bb0(%[[VAL_10:.*]]: index):
! CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_10]], %[[VAL_8]] : index
! CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_7]] : index
Expand Down Expand Up @@ -63,7 +63,7 @@ subroutine test_as_strided_elemental(lb, ub, stride)
! CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i64>
! CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_22]] : (i64) -> index
! CHECK: %[[VAL_24:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_25:.*]] = hlfir.elemental %[[VAL_19]] : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
! CHECK: %[[VAL_25:.*]] = hlfir.elemental %[[VAL_19]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xi32> {
! CHECK: ^bb0(%[[VAL_26:.*]]: index):
! CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_26]], %[[VAL_24]] : index
! CHECK: %[[VAL_28:.*]] = arith.muli %[[VAL_27]], %[[VAL_23]] : index
Expand Down Expand Up @@ -99,7 +99,7 @@ integer pure function foo(i)
! CHECK: %[[VAL_6:.*]] = arith.constant 1 : i64
! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_6]] : (i64) -> index
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] : (!fir.shape<1>) -> !hlfir.expr<4xi32> {
! CHECK: %[[VAL_9:.*]] = hlfir.elemental %[[VAL_3]] unordered : (!fir.shape<1>) -> !hlfir.expr<4xi32> {
! CHECK: ^bb0(%[[VAL_10:.*]]: index):
! CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_10]], %[[VAL_8]] : index
! CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_7]] : index
Expand Down

0 comments on commit 7b4aa95

Please sign in to comment.