Skip to content

Commit

Permalink
[flang] hlfir.elemental codegen
Browse files Browse the repository at this point in the history
Without any optimization or when it cannot be optimized before
bufferization, an hlfir.elemental lowers to an array temporary.
Its codegen consists of:
- allocating a temp given the type, shape, and length parameter arguments.
- generating a loop nest given the elemental shape
- inlining the body of the elemental inside the loops, and replacing the
  yield_element by an assignment to an element of the temp.

Differential Revision: https://reviews.llvm.org/D140093
  • Loading branch information
jeanPerier committed Dec 16, 2022
1 parent 95ec1a6 commit c2e3cb3
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 14 deletions.
17 changes: 12 additions & 5 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Expand Up @@ -182,6 +182,13 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
return createTemporary(loc, type, name, {}, {}, attrs);
}

/// Create a temporary on the heap.
/// The storage is created with a fir.allocmem operation of type \p type,
/// which must not be a fir.ref type. \p name is attached to the allocation
/// as its name; no unique name is set.
/// \p shape and \p lenParams provide the extents and length parameters of
/// the temporary; the ones that are already encoded in \p type are dropped
/// before creating the allocation.
/// \p attrs are forwarded to the created operation.
/// Return the value produced by the fir.allocmem operation.
mlir::Value
createHeapTemporary(mlir::Location loc, mlir::Type type,
llvm::StringRef name = {}, mlir::ValueRange shape = {},
mlir::ValueRange lenParams = {},
llvm::ArrayRef<mlir::NamedAttribute> attrs = {});

/// Create a global value.
fir::GlobalOp createGlobal(mlir::Location loc, mlir::Type type,
llvm::StringRef name,
Expand Down Expand Up @@ -425,16 +432,16 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
/// Dump the current function. (debug)
LLVM_DUMP_METHOD void dumpFunc();

private:
/// Set attributes (e.g. FastMathAttr) to \p op operation
/// based on the current attributes setting.
void setCommonAttributes(mlir::Operation *op) const;

/// FirOpBuilder hook for creating new operation.
void notifyOperationInserted(mlir::Operation *op) override {
setCommonAttributes(op);
}

private:
/// Set attributes (e.g. FastMathAttr) to \p op operation
/// based on the current attributes setting.
void setCommonAttributes(mlir::Operation *op) const;

const KindMapping &kindMap;

/// FastMathFlags that need to be set for operations that support
Expand Down
18 changes: 18 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Expand Up @@ -14,6 +14,7 @@
#define FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H

#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"

Expand All @@ -25,6 +26,7 @@ namespace hlfir {

class AssociateOp;
class ElementalOp;
class YieldElementOp;

/// Is this an SSA value type for the value of a Fortran expression?
inline bool isFortranValueType(mlir::Type type) {
Expand Down Expand Up @@ -253,6 +255,22 @@ hlfir::ElementalOp genElementalOp(mlir::Location loc,
mlir::ValueRange typeParams,
const ElementalKernelGenerator &genKernel);

/// Generate a fir.do_loop nest looping from 1 to extents[i] with a step of 1.
/// \p extents must contain at least one extent; each extent is converted to
/// the index type before being used as a loop upper bound.
/// The loop over the last extent is the outermost loop and the loop over the
/// first extent is the innermost one.
/// Return the inner fir.do_loop and the indices of the loops. The returned
/// indices are in the same order as \p extents (indices[i] is the induction
/// variable of the loop over extents[i]).
/// The builder insertion point is the same after the call as before it.
std::pair<fir::DoLoopOp, llvm::SmallVector<mlir::Value>>
genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange extents);

/// Inline the body of an hlfir.elemental at the current insertion point
/// given a list of one based indices. This generates the computation
/// of one element of the elemental expression. Return the YieldElementOp
/// whose value argument is the element value.
/// The operations of the hlfir.elemental block are cloned, with the
/// elemental indices replaced by \p oneBasedIndices in the clones.
/// The hlfir.elemental region is expected to contain a single block whose
/// last operation is an hlfir.yield_element.
/// The original hlfir::ElementalOp is left untouched.
hlfir::YieldElementOp inlineElementalOp(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::ElementalOp elemental,
mlir::ValueRange oneBasedIndices);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
14 changes: 14 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Expand Up @@ -232,6 +232,20 @@ fir::FirOpBuilder::createTemporary(mlir::Location loc, mlir::Type type,
return ae;
}

/// Allocate a temporary of type \p type with a fir.allocmem operation.
/// Extents from \p shape and length parameters from \p lenParams that are
/// already part of \p type are not passed to the allocation. \p type must not
/// be a fir.ref type. \p name is used as the allocation name (the unique name
/// is left empty) and \p attrs are forwarded to the created operation.
mlir::Value fir::FirOpBuilder::createHeapTemporary(
    mlir::Location loc, mlir::Type type, llvm::StringRef name,
    mlir::ValueRange shape, mlir::ValueRange lenParams,
    llvm::ArrayRef<mlir::NamedAttribute> attrs) {
  // Only the extents and lengths that the type does not already carry need to
  // be given to the allocation.
  llvm::SmallVector<mlir::Value> runtimeExtents =
      elideExtentsAlreadyInType(type, shape);
  llvm::SmallVector<mlir::Value> runtimeLengths =
      elideLengthsAlreadyInType(type, lenParams);

  assert(!type.isa<fir::ReferenceType>() && "cannot be a reference");
  const llvm::StringRef noUniqueName{};
  fir::AllocMemOp heapAlloc = create<fir::AllocMemOp>(
      loc, type, /*unique_name=*/noUniqueName, name, runtimeLengths,
      runtimeExtents, attrs);
  return heapAlloc;
}

/// Create a global variable in the (read-only) data section. A global variable
/// must have a unique name to identify and reference it.
fir::GlobalOp fir::FirOpBuilder::createGlobal(mlir::Location loc,
Expand Down
40 changes: 40 additions & 0 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Expand Up @@ -15,6 +15,7 @@
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/IR/BlockAndValueMapping.h"

// Return explicit extents. If the base is a fir.box, this won't read it to
// return the extents and will instead return an empty vector.
Expand Down Expand Up @@ -484,3 +485,42 @@ hlfir::genElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
builder.restoreInsertionPoint(insertPt);
return elementalOp;
}

/// Clone the operations of the single block of \p elemental at the current
/// insertion point of \p builder. In the clones, the elemental indices are
/// replaced by \p oneBasedIndices. The original \p elemental is not modified.
/// Return the clone of the last operation of the block, which must be an
/// hlfir.yield_element. \p loc is part of the interface but is not used by
/// this implementation.
hlfir::YieldElementOp
hlfir::inlineElementalOp(mlir::Location loc, fir::FirOpBuilder &builder,
                         hlfir::ElementalOp elemental,
                         mlir::ValueRange oneBasedIndices) {
  // hlfir.elemental region is a SizedRegion<1>.
  assert(elemental.getRegion().hasOneBlock() &&
         "expect elemental region to have one block");
  // Map the block index arguments to the provided indices so that the cloned
  // operations use the latter.
  mlir::BlockAndValueMapping mapper;
  mapper.map(elemental.getIndices(), oneBasedIndices);
  // This must start out as null: if the block contains no operation, the loop
  // below never assigns it, and dyn_cast_or_null would otherwise read an
  // indeterminate pointer instead of reaching the assert.
  mlir::Operation *newOp = nullptr;
  for (auto &op : elemental.getRegion().back().getOperations())
    newOp = builder.clone(op, mapper);
  // After the loop, newOp is the clone of the last operation of the block.
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>(newOp);
  assert(yield && "last ElementalOp operation must be am hlfir.yield_element");
  return yield;
}

/// Create one fir.do_loop per extent, each going from 1 to the extent with a
/// step of 1. The loop over the last extent is created first (outermost) and
/// the loop over the first extent is created last (innermost).
/// Return the innermost loop and the induction variables, stored in the same
/// order as \p extents. The builder insertion point is restored on exit.
std::pair<fir::DoLoopOp, llvm::SmallVector<mlir::Value>>
hlfir::genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                   mlir::ValueRange extents) {
  assert(!extents.empty() && "must have at least one extent");
  auto savedInsertionPoint = builder.saveInsertionPoint();
  const unsigned rank = extents.size();
  llvm::SmallVector<mlir::Value> inductionVars(rank);
  // The constant 1 is used both as the lower bound and as the step.
  auto unit = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
  mlir::Type idxTy = builder.getIndexType();
  fir::DoLoopOp currentLoop;
  // Visit the dimensions from the last one to the first one, nesting each new
  // loop inside the body of the previous one.
  for (unsigned dim = rank; dim-- > 0;) {
    auto upperBound = builder.createConvert(loc, idxTy, extents[dim]);
    currentLoop = builder.create<fir::DoLoopOp>(loc, unit, upperBound, unit);
    builder.setInsertionPointToStart(currentLoop.getBody());
    // Store the induction variable at the position of its extent.
    inductionVars[dim] = currentLoop.getInductionVar();
  }
  builder.restoreInsertionPoint(savedInsertionPoint);
  return {currentLoop, inductionVars};
}
137 changes: 128 additions & 9 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Expand Up @@ -95,6 +95,26 @@ static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) {
TODO(bufferizedExpr.getLoc(), "general extract storage case");
}

/// Collect the extents of \p shape and convert each of them to the index
/// type. Only shapes defined by a fir.shape or fir.shape_shift operation are
/// supported; any other case hits a TODO.
static llvm::SmallVector<mlir::Value>
getIndexExtents(mlir::Location loc, fir::FirOpBuilder &builder,
                mlir::Value shape) {
  llvm::SmallVector<mlir::Value> result;
  if (auto shapeOp = shape.getDefiningOp<fir::ShapeOp>()) {
    auto opExtents = shapeOp.getExtents();
    result.append(opExtents.begin(), opExtents.end());
  } else if (auto shapeShiftOp = shape.getDefiningOp<fir::ShapeShiftOp>()) {
    auto opExtents = shapeShiftOp.getExtents();
    result.append(opExtents.begin(), opExtents.end());
  } else {
    // TODO: add fir.get_extent ops on fir.shape<> ops.
    TODO(loc, "get extents from fir.shape without fir::ShapeOp parent op");
  }
  // Replace every gathered extent by its conversion to the index type.
  mlir::Type idxTy = builder.getIndexType();
  for (unsigned i = 0, count = result.size(); i != count; ++i)
    result[i] = builder.createConvert(loc, idxTy, result[i]);
  return result;
}

static std::pair<hlfir::Entity, mlir::Value>
createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity mold) {
Expand All @@ -113,6 +133,21 @@ createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
return {hlfir::Entity{declareOp.getBase()}, falseVal};
}

/// Create a heap temporary named ".tmp.array" for an expression of type
/// \p exprType and declare it with hlfir.declare using \p shape and
/// \p lenParams. \p extents and \p lenParams are given to the heap allocation.
/// Return the declared entity together with a boolean flag that is always
/// true here (the temporary is heap allocated).
static std::pair<hlfir::Entity, mlir::Value>
createArrayTemp(mlir::Location loc, fir::FirOpBuilder &builder,
                mlir::Type exprType, mlir::Value shape,
                mlir::ValueRange extents, mlir::ValueRange lenParams) {
  const llvm::StringRef tempName{".tmp.array"};
  mlir::Type storageType = hlfir::getFortranElementOrSequenceType(exprType);
  mlir::Value heapStorage = builder.createHeapTemporary(
      loc, storageType, tempName, extents, lenParams);
  auto declared = builder.create<hlfir::DeclareOp>(
      loc, heapStorage, tempName, shape, lenParams,
      fir::FortranVariableFlagsAttr{});
  mlir::Value mustFree = builder.createBool(loc, true);
  hlfir::Entity tempEntity{declared.getBase()};
  return {tempEntity, mustFree};
}

struct AsExprOpConversion : public mlir::OpConversionPattern<hlfir::AsExprOp> {
using mlir::OpConversionPattern<hlfir::AsExprOp>::OpConversionPattern;
explicit AsExprOpConversion(mlir::MLIRContext *ctx)
Expand Down Expand Up @@ -236,11 +271,20 @@ struct EndAssociateOpConversion
matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Value mustFree = adaptor.getMustFree();
if (auto cstMustFree = fir::factory::getIntIfConstant(mustFree))
if (*cstMustFree == 0) {
rewriter.eraseOp(endAssociate);
return mlir::success(); // nothing to do.
}
mlir::Location loc = endAssociate->getLoc();
rewriter.eraseOp(endAssociate);
auto genFree = [&]() {
mlir::Value var = adaptor.getVar();
if (var.getType().isa<fir::BaseBoxType>())
TODO(loc, "unbox");
rewriter.create<fir::FreeMemOp>(loc, var);
};
if (auto cstMustFree = fir::factory::getIntIfConstant(mustFree)) {
if (*cstMustFree != 0)
genFree();
// else, nothing to do.
return mlir::success();
}
TODO(endAssociate.getLoc(), "conditional free");
}
};
Expand All @@ -259,6 +303,79 @@ struct NoReassocOpConversion
}
};

/// Listener that forwards every notification to both a FirOpBuilder and a
/// ConversionPatternRewriter. It is needed when a pattern calls FirOpBuilder
/// helpers that may create illegal operations: the rewriter must be told
/// about them so that they get translated too.
struct HLFIRListener : public mlir::OpBuilder::Listener {
  HLFIRListener(fir::FirOpBuilder &firBuilder,
                mlir::ConversionPatternRewriter &patternRewriter)
      : builder{firBuilder}, rewriter{patternRewriter} {}

  /// Notify the builder first, then the rewriter, about the new operation.
  void notifyOperationInserted(mlir::Operation *op) override {
    builder.notifyOperationInserted(op);
    rewriter.notifyOperationInserted(op);
  }

  /// Notify the builder first, then the rewriter, about the new block.
  void notifyBlockCreated(mlir::Block *block) override {
    builder.notifyBlockCreated(block);
    rewriter.notifyBlockCreated(block);
  }

  fir::FirOpBuilder &builder;
  mlir::ConversionPatternRewriter &rewriter;
};

/// Conversion pattern that bufferizes an hlfir.elemental into an array
/// temporary: a heap temporary is created from the elemental type, shape and
/// type parameters, a loop nest is generated over the shape extents, and the
/// elemental body is inlined in the innermost loop with its
/// hlfir.yield_element replaced by an hlfir.assign to the temporary element.
struct ElementalOpConversion
: public mlir::OpConversionPattern<hlfir::ElementalOp> {
using mlir::OpConversionPattern<hlfir::ElementalOp>::OpConversionPattern;
explicit ElementalOpConversion(mlir::MLIRContext *ctx)
: mlir::OpConversionPattern<hlfir::ElementalOp>{ctx} {}
/// Replace \p elemental by the bufferized expression built from the array
/// temporary and its cleanup flag. Always returns success (unsupported
/// cases are reported through TODOs in the helpers).
mlir::LogicalResult
matchAndRewrite(hlfir::ElementalOp elemental, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = elemental->getLoc();
auto module = elemental->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
// The body of the elemental op may contain operations that will require
// to be translated. Notify the rewriter about the cloned operations.
HLFIRListener listener{builder, rewriter};
builder.setListener(&listener);

// Use the converted operands (from the adaptor) to build the temporary.
mlir::Value shape = adaptor.getShape();
auto extents = getIndexExtents(loc, builder, shape);
auto [temp, cleanup] =
createArrayTemp(loc, builder, elemental.getType(), shape, extents,
adaptor.getTypeparams());
// Generate a loop nest looping around the hlfir.elemental shape and clone
// the hlfir.elemental region inside the inner loop.
auto [innerLoop, oneBasedLoopIndices] =
hlfir::genLoopNest(loc, builder, extents);
// Save the insertion point so that the code following the loop nest is
// not emitted inside the innermost loop body.
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(innerLoop.getBody());
auto yield =
hlfir::inlineElementalOp(loc, builder, elemental, oneBasedLoopIndices);
hlfir::Entity elementValue(yield.getElementValue());
// Skip final AsExpr if any. It would create an element temporary,
// which is not needed since the element will be assigned right away in
// the array temporary. An hlfir.as_expr may have been added if the
// elemental is a "view" over a variable (e.g parentheses or transpose).
if (auto asExpr = elementValue.getDefiningOp<hlfir::AsExprOp>()) {
elementValue = hlfir::Entity{asExpr.getVar()};
// The only remaining use is the cloned yield, which is erased below.
if (asExpr->hasOneUse())
rewriter.eraseOp(asExpr);
}
// The cloned yield is replaced by the assignment generated below.
rewriter.eraseOp(yield);
// Assign the element value to the temp element for this iteration.
auto tempElement =
hlfir::getElementAt(loc, builder, temp, oneBasedLoopIndices);
builder.create<hlfir::AssignOp>(loc, elementValue, tempElement);
builder.restoreInsertionPoint(insPt);

// NOTE(review): packageBufferizedExpr is defined outside of this view; it
// presumably bundles the temporary with its must-free flag - confirm.
mlir::Value bufferizedExpr =
packageBufferizedExpr(loc, builder, temp, cleanup);
rewriter.replaceOp(elemental, bufferizedExpr);
return mlir::success();
}
};

class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
public:
void runOnOperation() override {
Expand All @@ -272,11 +389,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
auto module = this->getOperation();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.insert<AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, ConcatOpConversion,
EndAssociateOpConversion, NoReassocOpConversion>(context);
patterns
.insert<AsExprOpConversion, AssignOpConversion, AssociateOpConversion,
ConcatOpConversion, ElementalOpConversion,
EndAssociateOpConversion, NoReassocOpConversion>(context);
mlir::ConversionTarget target(*context);
target.addIllegalOp<hlfir::AssociateOp, hlfir::EndAssociateOp>();
target.addIllegalOp<hlfir::AssociateOp, hlfir::ElementalOp,
hlfir::EndAssociateOp, hlfir::YieldElementOp>();
target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
return llvm::all_of(
op->getResultTypes(),
Expand Down

0 comments on commit c2e3cb3

Please sign in to comment.