Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,11 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}

def EmitC_ExpressionOp : EmitC_Op<"expression",
[HasOnlyGraphRegion, OpAsmOpInterface,
SingleBlockImplicitTerminator<"emitc::YieldOp">, NoRegionArguments]> {
def EmitC_ExpressionOp
: EmitC_Op<
"expression", [HasOnlyGraphRegion, OpAsmOpInterface,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"emitc::YieldOp">]> {
let summary = "Expression operation";
let description = [{
The `emitc.expression` operation returns a single SSA value which is yielded by
Expand Down Expand Up @@ -494,12 +496,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
at its use.
}];

let arguments = (ins UnitAttr:$do_not_inline);
let arguments = (ins Variadic<AnyTypeOf<[EmitCType, EmitC_LValueType]>>:$defs,
UnitAttr:$do_not_inline);
let results = (outs EmitCType:$result);
let regions = (region SizedRegion<1>:$region);

let hasVerifier = 1;
let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
bool hasSideEffects() {
Expand All @@ -510,6 +513,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
return llvm::any_of(getRegion().front().without_terminator(), predicate);
};
Operation *getRootOp();
Block &createBody() {
assert(getRegion().empty() && "expression already has a body");
Block &block = getRegion().emplaceBlock();
for (auto operand : getOperands())
block.addArgument(operand.getType(), operand.getLoc());
return block;
}

//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
Expand Down
17 changes: 10 additions & 7 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,16 +610,19 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));

emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
emitc::ExpressionOp ternary =
emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
ValueRange({lhs, rhs, excessCheck, poison}),
/*do_not_inline=*/false);
Block &bodyBlock = ternary.createBody();
auto currentPoint = rewriter.getInsertionPoint();
rewriter.setInsertionPointToStart(&bodyBlock);
Value arithmeticResult =
EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value resultOrPoison =
emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
excessCheck, arithmeticResult, poison);
EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
bodyBlock.getArgument(0), bodyBlock.getArgument(1));
Value resultOrPoison = emitc::ConditionalOp::create(
rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
arithmeticResult, bodyBlock.getArgument(3));
emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
rewriter.setInsertionPoint(op->getBlock(), currentPoint);

Expand Down
46 changes: 46 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
// ExpressionOp
//===----------------------------------------------------------------------===//

ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands))
return parser.emitError(parser.getCurrentLocation()) << "expected operands";
if (succeeded(parser.parseOptionalKeyword("noinline")))
result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
parser.getBuilder().getUnitAttr());
Type type;
if (parser.parseColonType(type))
return parser.emitError(parser.getCurrentLocation(),
"expected function type");
auto fnType = llvm::dyn_cast<FunctionType>(type);
if (!fnType)
return parser.emitError(parser.getCurrentLocation(),
"expected function type");
if (parser.resolveOperands(operands, fnType.getInputs(),
parser.getCurrentLocation(), result.operands))
return failure();
if (fnType.getNumResults() != 1)
return parser.emitError(parser.getCurrentLocation(),
"expected single return type");
result.addTypes(fnType.getResults());
Region *body = result.addRegion();
SmallVector<OpAsmParser::Argument> argsInfo;
for (auto [unresolvedOperand, operandType] :
llvm::zip(operands, fnType.getInputs())) {
OpAsmParser::Argument argInfo;
argInfo.ssaName = unresolvedOperand;
argInfo.type = operandType;
argsInfo.push_back(argInfo);
}
if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
return failure();
return success();
}

void emitc::ExpressionOp::print(OpAsmPrinter &p) {
p << ' ';
p.printOperands(getDefs());
p << " : ";
p.printFunctionalType(getOperation());
p.shadowRegionArgs(getRegion(), getDefs());
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

Operation *ExpressionOp::getRootOp() {
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
Value yieldedValue = yieldOp.getResult();
Expand Down
150 changes: 99 additions & 51 deletions mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace emitc {
Expand All @@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
Location loc = op->getLoc();

builder.setInsertionPointAfter(op);
auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType);
auto expressionOp =
emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());

// Replace all op's uses with the new expression's result.
result.replaceAllUsesWith(expressionOp.getResult());

// Create an op to yield op's value.
Region &region = expressionOp.getRegion();
Block &block = region.emplaceBlock();
Block &block = expressionOp.createBody();
IRMapping mapper;
for (auto [operand, arg] :
llvm::zip(expressionOp.getOperands(), block.getArguments()))
mapper.map(operand, arg);
builder.setInsertionPointToEnd(&block);
auto yieldOp = emitc::YieldOp::create(builder, loc, result);

// Move op into the new expression.
op->moveBefore(yieldOp);
Operation *rootOp = builder.clone(*op, mapper);
op->erase();

// Create an op to yield op's value.
emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
return expressionOp;
}

Expand All @@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpressionOp expressionOp,
PatternRewriter &rewriter) const override {
bool anythingFolded = false;
for (Operation &op : llvm::make_early_inc_range(
expressionOp.getBody()->without_terminator())) {
// Don't fold expressions whose result value has its address taken.
auto applyOp = dyn_cast<emitc::ApplyOp>(op);
if (applyOp && applyOp.getApplicableOperator() == "&")
continue;

for (Value operand : op.getOperands()) {
auto usedExpression = operand.getDefiningOp<ExpressionOp>();
if (!usedExpression)
continue;

// Don't fold expressions with multiple users: assume any
// re-materialization was done separately.
if (!usedExpression.getResult().hasOneUse())
continue;

// Don't fold expressions with side effects.
if (usedExpression.hasSideEffects())
continue;

// Fold the used expression into this expression by cloning all
// instructions in the used expression just before the operation using
// its value.
rewriter.setInsertionPoint(&op);
IRMapping mapper;
for (Operation &opToClone :
usedExpression.getBody()->without_terminator()) {
Operation *clone = rewriter.clone(opToClone, mapper);
mapper.map(&opToClone, clone);
}

Operation *expressionRoot = usedExpression.getRootOp();
Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
assert(clonedExpressionRootOp &&
"Expected cloned expression root to be in mapper");
assert(clonedExpressionRootOp->getNumResults() == 1 &&
"Expected cloned root to have a single result");

rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
anythingFolded = true;
}
Block *expressionBody = expressionOp.getBody();
ExpressionOp usedExpression;
SetVector<Value> foldedOperands;

auto takesItsOperandsAddress = [](Operation *user) {
auto applyOp = dyn_cast<emitc::ApplyOp>(user);
return applyOp && applyOp.getApplicableOperator() == "&";
};

// Select as expression to fold the first operand expression that
// - doesn't have its result value's address taken,
// - has a single user: assume any re-materialization was done separately,
// - has no side effects,
// and save all other operands to be used later as operands in the folded
// expression.
for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
expressionBody->getArguments())) {
ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
if (usedExpression || !operandExpression ||
llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
!operandExpression.getResult().hasOneUse() ||
operandExpression.hasSideEffects())
foldedOperands.insert(operand);
else
usedExpression = operandExpression;
}
return anythingFolded ? success() : failure();

// If no operand expression was selected, bail out.
if (!usedExpression)
return failure();

// Collect additional operands from the folded expression.
for (Value operand : usedExpression.getOperands())
foldedOperands.insert(operand);

// Create a new expression to hold the folding result.
rewriter.setInsertionPointAfter(expressionOp);
auto foldedExpression = emitc::ExpressionOp::create(
rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
Block &foldedExpressionBody = foldedExpression.createBody();

// Map each operand of the new expression to its matching block argument.
IRMapping mapper;
for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
foldedExpressionBody.getArguments()))
mapper.map(operand, arg);

// Prepare to fold the used expression and the matched expression into the
// newly created folded expression.
auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
bool withTerminator) {
Block *expressionToFoldBody = expressionToFold.getBody();
for (auto [operand, arg] :
llvm::zip(expressionToFold.getOperands(),
expressionToFoldBody->getArguments())) {
mapper.map(arg, mapper.lookup(operand));
}

for (Operation &opToClone : expressionToFoldBody->without_terminator())
rewriter.clone(opToClone, mapper);

if (withTerminator)
rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
};
rewriter.setInsertionPointToStart(&foldedExpressionBody);

// First, fold the used expression into the new expression and map its
// result to the clone of its root operation within the new expression.
foldExpression(usedExpression, /*withTerminator=*/false);
Operation *expressionRoot = usedExpression.getRootOp();
Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
assert(clonedExpressionRootOp &&
"Expected cloned expression root to be in mapper");
assert(clonedExpressionRootOp->getNumResults() == 1 &&
"Expected cloned root to have a single result");
mapper.map(usedExpression.getResult(),
clonedExpressionRootOp->getResults()[0]);

// Now fold the matched expression into the new expression.
foldExpression(expressionOp, /*withTerminator=*/true);

// Complete the rewrite.
rewriter.replaceOp(expressionOp, foldedExpression);
rewriter.eraseOp(usedExpression);

return success();
}
};

Expand Down
22 changes: 19 additions & 3 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
Expand Down Expand Up @@ -364,9 +365,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (hasDeferredEmission(user))
return false;

// Do not inline expressions used by ops with the CExpressionInterface. If
// this was intended, the user could have been merged into the expression op.
return !isa<emitc::CExpressionInterface>(*user);
// Do not inline expressions used by other expressions or by ops with the
// CExpressionInterface. If this was intended, the user could have been merged
// into the expression op.
return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user);
}

static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
Expand Down Expand Up @@ -1532,6 +1534,20 @@ LogicalResult CppEmitter::emitOperand(Value value) {
if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);

if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
// If this operand is a block argument of an expression, emit instead the
// matching expression parameter.
Operation *argOp = arg.getParentBlock()->getParentOp();
if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
// This scenario is only expected when one of the operations within the
// expression being emitted references one of the expression's block
// arguments.
assert(expressionOp == emittedExpression &&
"Expected expression being emitted");
value = expressionOp->getOperand(arg.getArgNumber());
}
}

os << getOrCreateName(value);
return success();
}
Expand Down
Loading