diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index d49bc1e6bdff3..0317f83063f5e 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -141,7 +141,13 @@ class OrderedAssignmentRewriter {
   /// code, except the final yield, at the current execution point.
   /// If the value was saved in a previous run, this fetches the saved value
   /// from the temporary storage and returns the value.
-  mlir::Value generateYieldedScalarValue(mlir::Region &region);
+  /// Inside Forall, the value will be hoisted outside of the forall loops if
+  /// it does not depend on the forall indices.
+  /// An optional type can be provided to get the value as a specific type
+  /// (the cast will be hoisted if the computation is hoisted).
+  mlir::Value generateYieldedScalarValue(
+      mlir::Region &region,
+      std::optional<mlir::Type> castToType = std::nullopt);
 
   /// Generate an entity yielded by an ordered assignment tree region, and
   /// optionally return the (uncloned) yield if there is any clean-up that
@@ -149,7 +155,8 @@ class OrderedAssignmentRewriter {
   /// this will return the saved value if the region was saved in a previous
   /// run.
   std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
-  generateYieldedEntity(mlir::Region &region);
+  generateYieldedEntity(mlir::Region &region,
+                        std::optional<mlir::Type> castToType = std::nullopt);
 
   /// If \p maybeYield is present and has a clean-up, generate the clean-up
   /// at the current insertion point (by cloning).
@@ -215,20 +222,20 @@ void OrderedAssignmentRewriter::walk(
 
 void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
   /// Create a fir.do_loop given the hlfir.forall control values.
-  mlir::Value rawLowerBound =
-      generateYieldedScalarValue(forallOp.getLbRegion());
-  mlir::Location loc = forallOp.getLoc();
   mlir::Type idxTy = builder.getIndexType();
-  mlir::Value lb = builder.createConvert(loc, idxTy, rawLowerBound);
-  mlir::Value rawUpperBound =
-      generateYieldedScalarValue(forallOp.getUbRegion());
-  mlir::Value ub = builder.createConvert(loc, idxTy, rawUpperBound);
+  mlir::Location loc = forallOp.getLoc();
+  mlir::Value lb = generateYieldedScalarValue(forallOp.getLbRegion(), idxTy);
+  mlir::Value ub = generateYieldedScalarValue(forallOp.getUbRegion(), idxTy);
   mlir::Value step;
   if (forallOp.getStepRegion().empty()) {
+    auto insertionPoint = builder.saveInsertionPoint();
+    if (!constructStack.empty())
+      builder.setInsertionPoint(constructStack[0]);
     step = builder.createIntegerConstant(loc, idxTy, 1);
+    if (!constructStack.empty())
+      builder.restoreInsertionPoint(insertionPoint);
   } else {
-    step = generateYieldedScalarValue(forallOp.getStepRegion());
-    step = builder.createConvert(loc, idxTy, step);
+    step = generateYieldedScalarValue(forallOp.getStepRegion(), idxTy);
   }
   auto doLoop = builder.create<fir::DoLoopOp>(loc, lb, ub, step);
   builder.setInsertionPointToStart(doLoop.getBody());
@@ -256,8 +263,8 @@ void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
 
 void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
   mlir::Location loc = forallMaskOp.getLoc();
-  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion());
-  mask = builder.createConvert(loc, builder.getI1Type(), mask);
+  mlir::Value mask = generateYieldedScalarValue(forallMaskOp.getMaskRegion(),
+                                                builder.getI1Type());
   auto ifOp = builder.create<fir::IfOp>(loc, std::nullopt, mask, false);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   constructStack.push_back(ifOp);
@@ -350,35 +357,84 @@ void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
   builder.setInsertionPointAfter(constructStack.pop_back_val());
 }
 
+/// Is this value a Forall index?
+/// Forall indices are block arguments of the hlfir.forall body, or the result
+/// of hlfir.forall_index.
+static bool isForallIndex(mlir::Value value) {
+  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) {
+    if (mlir::Block *block = blockArg.getOwner())
+      return block->isEntryBlock() &&
+             mlir::isa_and_nonnull<hlfir::ForallOp>(block->getParentOp());
+    return false;
+  }
+  return value.getDefiningOp<hlfir::ForallIndexOp>();
+}
+
 std::pair<mlir::Value, std::optional<hlfir::YieldOp>>
-OrderedAssignmentRewriter::generateYieldedEntity(mlir::Region &region) {
+OrderedAssignmentRewriter::generateYieldedEntity(
+    mlir::Region &region, std::optional<mlir::Type> castToType) {
   // TODO: if the region was saved, use that instead of generating code again.
   if (whereLoopNest.has_value()) {
     mlir::Location loc = region.getParentOp()->getLoc();
    return {generateMaskedEntity(loc, region), std::nullopt};
   }
   assert(region.hasOneBlock() && "region must contain one block");
-  // Clone all operations except the final hlfir.yield.
+  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
+      region.back().getOperations().back());
+  assert(oldYield && "region computing entities must end with a YieldOp");
   mlir::Block::OpListType &ops = region.back().getOperations();
+
+  // Inside Forall, scalars that do not depend on forall indices can be hoisted
+  // here because their evaluation may only call pure procedures, and if they
+  // depend on a variable previously assigned to in a forall assignment, this
+  // assignment must have been scheduled in a previous run. Hoisting of scalars
+  // is done here to help create simple temporary storage if needed.
+  // Inner forall bounds can often be hoisted, and this allows computing the
+  // total number of iterations to create temporary storages.
+  bool hoistComputation = false;
+  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
+      !constructStack.empty()) {
+    hoistComputation = true;
+    for (mlir::Operation &op : ops)
+      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
+            return isForallIndex(value);
+          })) {
+        hoistComputation = false;
+        break;
+      }
+  }
+  auto insertionPoint = builder.saveInsertionPoint();
+  if (hoistComputation)
+    builder.setInsertionPoint(constructStack[0]);
+
+  // Clone all operations except the final hlfir.yield.
   assert(!ops.empty() && "yield block cannot be empty");
   auto end = ops.end();
   for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
     (void)builder.clone(*opIt, mapper);
-  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
-      region.back().getOperations().back());
-  assert(oldYield && "region computing scalar must end with a YieldOp");
   // Get the value for the yielded entity, it may be the result of an operation
   // that was cloned, or it may be the same as the previous value if the yield
   // operand was created before the ordered assignment tree.
   mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
+  if (castToType.has_value())
+    newEntity =
+        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);
+
+  if (hoistComputation) {
+    // Clean-up of hoisted trivial scalars can be done right away since the
+    // value is in registers.
+    generateCleanupIfAny(oldYield);
+    builder.restoreInsertionPoint(insertionPoint);
+    return {newEntity, std::nullopt};
+  }
   if (oldYield.getCleanup().empty())
     return {newEntity, std::nullopt};
   return {newEntity, oldYield};
 }
 
-mlir::Value
-OrderedAssignmentRewriter::generateYieldedScalarValue(mlir::Region &region) {
-  auto [value, maybeYield] = generateYieldedEntity(region);
+mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
+    mlir::Region &region, std::optional<mlir::Type> castToType) {
+  auto [value, maybeYield] = generateYieldedEntity(region, castToType);
   assert(fir::isa_trivial(value.getType()) && "not a trivial scalar value");
   generateCleanupIfAny(maybeYield);
   return value;
diff --git a/flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir b/flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir
index dace9b2f245df..784367f4b05df 100644
--- a/flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir
+++ b/flang/test/HLFIR/order_assignments/forall-codegen-no-conflict.fir
@@ -24,10 +24,10 @@ func.func @test_simple(%x: !fir.ref<!fir.array<10xi32>>) {
 // CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_2:.*]] = arith.constant 10 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK: fir.do_loop %[[VAL_4:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
-// CHECK: %[[VAL_5:.*]] = arith.constant 42 : i32
-// CHECK: %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_4]]) : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
-// CHECK: hlfir.assign %[[VAL_5]] to %[[VAL_6]] : i32, !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 42 : i32
+// CHECK: fir.do_loop %[[VAL_5:.*]] = %[[VAL_1]] to %[[VAL_2]] step %[[VAL_3]] {
+// CHECK: %[[VAL_6:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_5]]) : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_4]] to %[[VAL_6]] : i32, !fir.ref<i32>
 // CHECK: }
 
 func.func @test_index(%x: !fir.ref<!fir.array<10xi32>>) {
@@ -122,11 +122,11 @@ func.func @split_schedule(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: !fir.box<
 // CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
 // CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
 // CHECK: %[[VAL_19:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
+// CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK: %[[VAL_24:.*]] = arith.constant 1 : index
 // CHECK: fir.do_loop %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_19]] {
 // CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
-// CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
-// CHECK: %[[VAL_23:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
-// CHECK: %[[VAL_24:.*]] = arith.constant 1 : index
 // CHECK: fir.do_loop %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_24]] {
 // CHECK: %[[VAL_26:.*]] = fir.convert %[[VAL_25]] : (index) -> i64
 // CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_3]], %[[VAL_21]] : i64
@@ -181,15 +181,15 @@ func.func @test_mask(%arg0: !fir.box<!fir.array<?x?xf32>>, %arg1: !fir.box<
 // CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_2]] : (i64) -> index
 // CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
 // CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
+// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
 // CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_10]] {
 // CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (index) -> i64
 // CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_12]]) : (!fir.box<!fir.array<?x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
 // CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<!fir.logical<4>>
 // CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.logical<4>) -> i1
 // CHECK: fir.if %[[VAL_15]] {
-// CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
 // CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_12]] : (i64) -> index
-// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
 // CHECK: fir.do_loop %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_18]] {
 // CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_19]] : (index) -> i64
 // CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_12]], %[[VAL_20]]) : (!fir.box<!fir.array<?x?xf32>>, i64, i64) -> !fir.ref<f32>
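
Two of the changes above rely on the same mechanism: the implicit step constant in pre(hlfir::ForallOp) and trivial scalar yields in generateYieldedEntity are emitted right before constructStack[0] (the outermost forall/where construct), after which the builder is put back where it was. Below is a minimal sketch of that insertion-point pattern using only a plain mlir::OpBuilder; the emitHoisted helper and its callback parameter are illustrative assumptions, not part of the patch:

#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLFunctionalExtras.h"

// Illustrative only: emit a value right before `outermostConstruct` (the patch
// uses constructStack[0]) so it is computed once instead of once per
// iteration, then restore the builder to where code generation was happening.
static mlir::Value
emitHoisted(mlir::OpBuilder &builder, mlir::Operation *outermostConstruct,
            llvm::function_ref<mlir::Value(mlir::OpBuilder &)> emit) {
  mlir::OpBuilder::InsertPoint savedIP = builder.saveInsertionPoint();
  builder.setInsertionPoint(outermostConstruct);
  mlir::Value hoisted = emit(builder);
  builder.restoreInsertionPoint(savedIP);
  return hoisted;
}

The patch guards this dance with !constructStack.empty(), so code that is not nested inside any forall/where construct keeps being emitted in place.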
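The hoisting decision itself only fires for trivial (register-sized) scalar yields whose computation never consumes a forall index. The patch inlines this scan in generateYieldedEntity; the following sketch shows the same check as a standalone predicate, assuming it lives in the same file so it can reuse the isForallIndex helper added above (dependsOnForallIndices is a hypothetical name):

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Region.h"

// Illustrative only: true if any operation in the single-block region consumes
// a forall index, in which case the computation must stay inside the loop
// nest and cannot be hoisted.
static bool dependsOnForallIndices(mlir::Region &region) {
  for (mlir::Operation &op : region.front().getOperations())
    if (llvm::any_of(op.getOperands(),
                     [](mlir::Value value) { return isForallIndex(value); }))
      return true;
  return false;
}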
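The new comment's remark about computing "the total number of iterations to create temporary storages" points at follow-up work enabled by having lb, ub, and step available above the loop nest. A hedged sketch of how a trip count could be materialized there with the usual Fortran formula max((ub - lb + step) / step, 0); computeTripCount is a hypothetical helper, not something added by this patch:

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

// Illustrative only: number of iterations of one forall level, computed from
// hoisted bounds so it can size a temporary allocated before the loops.
static mlir::Value computeTripCount(fir::FirOpBuilder &builder,
                                    mlir::Location loc, mlir::Value lb,
                                    mlir::Value ub, mlir::Value step) {
  mlir::Value diff = builder.create<mlir::arith::SubIOp>(loc, ub, lb);
  mlir::Value add = builder.create<mlir::arith::AddIOp>(loc, diff, step);
  mlir::Value count = builder.create<mlir::arith::DivSIOp>(loc, add, step);
  mlir::Value zero = builder.createIntegerConstant(loc, count.getType(), 0);
  mlir::Value isEmpty = builder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, count, zero);
  return builder.create<mlir::arith::SelectOp>(loc, isEmpty, zero, count);
}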