Skip to content

Commit

Permalink
[Flang] Generate inline reduction loops for elemental count intrinsics (
Browse files Browse the repository at this point in the history
#75774)

This adds a ReductionElementalConversion transform to
OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and
generating an inline loop that counts the true elements. This
lets us generate a single loop instead of ending up with two loops plus a
temporary.

Any and All should be able to share the same code with a different
function/initial value.
  • Loading branch information
davemgreen committed Jan 9, 2024
1 parent 6eb372e commit 810c291
Show file tree
Hide file tree
Showing 2 changed files with 434 additions and 0 deletions.
120 changes: 120 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Expand Up @@ -659,6 +659,125 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
return mlir::success();
}

/// Callback used by generateReductionLoop to emit the body of the innermost
/// loop. It is invoked with the builder, the location, the current reduction
/// value and the loop induction variables, and returns the updated reduction
/// value.
using GenBodyFn =
std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
const llvm::SmallVectorImpl<mlir::Value> &)>;
/// Build a nest of fir.do_loop operations, one per extent of \p shape, that
/// threads a single reduction value through every iteration.
///
/// \param builder builder used to create the loops; on return its insertion
///        point is placed right after the outermost loop.
/// \param loc     location attached to every created operation.
/// \param init    initial value of the reduction.
/// \param shape   shape whose extents give the loop upper bounds.
/// \param genBody callback emitting the innermost body; it receives the
///        current reduction value and the induction variables, and returns
///        the updated reduction value.
/// \return the result of the outermost loop (or the value returned by
///         \p genBody when there are no extents).
static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
                                         mlir::Location loc, mlir::Value init,
                                         mlir::Value shape, GenBodyFn genBody) {
  auto extents = hlfir::getIndexExtents(loc, builder, shape);
  const auto rank = extents.size();
  mlir::Value accumulator = init;
  mlir::IndexType indexType = builder.getIndexType();
  mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);

  // Open the loop nest. Loops run from 1 to the extent so the induction
  // variables can be handed to the elemental directly. The last dimension is
  // the outermost loop and the first dimension the innermost one, which
  // yields a column-major traversal for better performance.
  llvm::SmallVector<mlir::Value> inductionVars(rank, mlir::Value{});
  for (auto dim = rank; dim-- > 0;) {
    auto doLoop = builder.create<fir::DoLoopOp>(loc, one, extents[dim], one,
                                                false,
                                                /*finalCountValue=*/false,
                                                accumulator);
    accumulator = doLoop.getRegionIterArgs()[0];
    inductionVars[dim] = doLoop.getInductionVar();
    // Everything created from now on (including the next loop) goes inside
    // the body of the loop that was just created.
    builder.setInsertionPointToStart(doLoop.getBody());
  }

  // Emit the innermost body.
  accumulator = genBody(builder, loc, accumulator, inductionVars);

  // Close the loop nest from the inside out: terminate the current loop body
  // with the value computed in it, then continue with that loop's result in
  // the enclosing scope.
  for (auto level = rank; level > 0; --level) {
    auto terminator = builder.create<fir::ResultOp>(loc, accumulator);
    auto enclosingLoop = mlir::cast<fir::DoLoopOp>(terminator->getParentOp());
    accumulator = enclosingLoop.getResult(0);
    builder.setInsertionPointAfter(enclosingLoop.getOperation());
  }

  return accumulator;
}

/// Given a reduction operation with an elemental mask, attempt to generate a
/// do-loop to perform the operation inline.
///   %e = hlfir.elemental %shape unordered
///   %r = hlfir.count %e
/// =>
///   %r = fir.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
///     %i = <inline elemental>
///     %c = <reduce count> %i
///     fir.result %c
///
/// The pattern does not match when the mask is not defined by an
/// hlfir.elemental, or when the operation has a dim operand.
template <typename Op>
class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
public:
  using mlir::OpRewritePattern<Op>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    hlfir::ElementalOp elemental =
        op.getMask().template getDefiningOp<hlfir::ElementalOp>();
    if (!elemental || op.getDim())
      return rewriter.notifyMatchFailure(op, "Did not find valid elemental");

    fir::KindMapping kindMap =
        fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
    fir::FirOpBuilder builder{op, kindMap};

    // Per-operation initial value and loop body generator.
    mlir::Value init;
    GenBodyFn genBodyFn;
    if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
      init = builder.createIntegerConstant(loc, op.getType(), 0);
      // The builder is taken by reference, matching GenBodyFn, so that the
      // body is emitted through the caller's builder rather than a copy.
      genBodyFn = [elemental](fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value reduction,
                              const llvm::SmallVectorImpl<mlir::Value> &indices)
          -> mlir::Value {
        // Inline the elemental and get the condition from it.
        auto yield = inlineElementalOp(loc, builder, elemental, indices);
        mlir::Value cond = builder.create<fir::ConvertOp>(
            loc, builder.getI1Type(), yield.getElementValue());
        // Only the yielded value is needed; drop the inlined terminator.
        yield->erase();

        // Conditionally add one to the current value
        mlir::Value one =
            builder.createIntegerConstant(loc, reduction.getType(), 1);
        mlir::Value add1 =
            builder.create<mlir::arith::AddIOp>(loc, reduction, one);
        return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
                                                     reduction);
      };
    } else {
      // The condition must depend on Op so that it is only evaluated when
      // this branch is instantiated. A bare string literal would convert to
      // true and never trigger the assertion.
      static_assert(!std::is_same_v<Op, Op>, "Expected Op to be handled");
      return mlir::failure();
    }

    mlir::Value res = generateReductionLoop(builder, loc, init,
                                            elemental.getOperand(0), genBodyFn);
    if (res.getType() != op.getType())
      res = builder.create<fir::ConvertOp>(loc, op.getType(), res);

    // Check if the op was the only user of the elemental (apart from a
    // destroy), and remove it if so. With exactly two users, one of them is
    // the op being rewritten, so the elemental is dead after the rewrite if
    // the other one is the destroy.
    mlir::Operation::user_range elemUsers = elemental->getUsers();
    hlfir::DestroyOp elemDestroy;
    if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
      elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
      if (!elemDestroy)
        elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
    }

    rewriter.replaceOp(op, res);
    if (elemDestroy) {
      // Erase the destroy first so the elemental has no remaining users.
      rewriter.eraseOp(elemDestroy);
      rewriter.eraseOp(elemental);
    }
    return mlir::success();
  }
};

class OptimizedBufferizationPass
: public hlfir::impl::OptimizedBufferizationBase<
OptimizedBufferizationPass> {
Expand All @@ -681,6 +800,7 @@ class OptimizedBufferizationPass
patterns.insert<ElementalAssignBufferization>(context);
patterns.insert<BroadcastAssignBufferization>(context);
patterns.insert<VariableAssignBufferization>(context);
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
Expand Down

0 comments on commit 810c291

Please sign in to comment.