Skip to content

Commit c618df8

Browse files
committed
[mlir][emitc] Isolate expressions from above
The expression op is currently not isolated from above. This served its original usage as an optional, translation-oriented op, but is becoming less convenient now that expressions appear earlier in the emitc compilation flow and are gaining use as components of other emitc ops. This patch therefore adds the isolated-from-above trait to expressions. Syntactically, the only change is in the expression's signature which now includes the values being used in the expression as arguments and their types. The region's argument's names shadow the used values to keep the def-use relations clear.
1 parent d7d8703 commit c618df8

File tree

13 files changed

+256
-122
lines changed

13 files changed

+256
-122
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,11 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
455455
let results = (outs FloatIntegerIndexOrOpaqueType);
456456
}
457457

458-
def EmitC_ExpressionOp : EmitC_Op<"expression",
459-
[HasOnlyGraphRegion, OpAsmOpInterface,
460-
SingleBlockImplicitTerminator<"emitc::YieldOp">, NoRegionArguments]> {
458+
def EmitC_ExpressionOp
459+
: EmitC_Op<
460+
"expression", [HasOnlyGraphRegion, OpAsmOpInterface,
461+
IsolatedFromAbove,
462+
SingleBlockImplicitTerminator<"emitc::YieldOp">]> {
461463
let summary = "Expression operation";
462464
let description = [{
463465
The `emitc.expression` operation returns a single SSA value which is yielded by
@@ -494,12 +496,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
494496
at its use.
495497
}];
496498

497-
let arguments = (ins UnitAttr:$do_not_inline);
499+
let arguments = (ins Variadic<AnyTypeOf<[EmitCType, EmitC_LValueType]>>:$defs,
500+
UnitAttr:$do_not_inline);
498501
let results = (outs EmitCType:$result);
499502
let regions = (region SizedRegion<1>:$region);
500503

501504
let hasVerifier = 1;
502-
let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
505+
let hasCustomAssemblyFormat = 1;
503506

504507
let extraClassDeclaration = [{
505508
bool hasSideEffects() {
@@ -510,6 +513,13 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
510513
return llvm::any_of(getRegion().front().without_terminator(), predicate);
511514
};
512515
Operation *getRootOp();
516+
Block &createBody() {
517+
assert(getRegion().empty() && "expression already has a body");
518+
Block &block = getRegion().emplaceBlock();
519+
for (auto operand : getOperands())
520+
block.addArgument(operand.getType(), operand.getLoc());
521+
return block;
522+
}
513523

514524
//===------------------------------------------------------------------===//
515525
// OpAsmOpInterface Methods

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -610,16 +610,19 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
610610
? rewriter.getIntegerAttr(arithmeticType, 0)
611611
: rewriter.getIndexAttr(0)));
612612

613-
emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
614-
rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
615-
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
613+
emitc::ExpressionOp ternary =
614+
emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
615+
ValueRange({lhs, rhs, excessCheck, poison}),
616+
/*do_not_inline=*/false);
617+
Block &bodyBlock = ternary.createBody();
616618
auto currentPoint = rewriter.getInsertionPoint();
617619
rewriter.setInsertionPointToStart(&bodyBlock);
618620
Value arithmeticResult =
619-
EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
620-
Value resultOrPoison =
621-
emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
622-
excessCheck, arithmeticResult, poison);
621+
EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
622+
bodyBlock.getArgument(0), bodyBlock.getArgument(1));
623+
Value resultOrPoison = emitc::ConditionalOp::create(
624+
rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
625+
arithmeticResult, bodyBlock.getArgument(3));
623626
emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
624627
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
625628

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
381381
// ExpressionOp
382382
//===----------------------------------------------------------------------===//
383383

384+
ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
385+
SmallVector<OpAsmParser::UnresolvedOperand> operands;
386+
if (parser.parseOperandList(operands))
387+
return parser.emitError(parser.getCurrentLocation()) << "expected operands";
388+
if (succeeded(parser.parseOptionalKeyword("noinline")))
389+
result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
390+
parser.getBuilder().getUnitAttr());
391+
Type type;
392+
if (parser.parseColonType(type))
393+
return parser.emitError(parser.getCurrentLocation(),
394+
"expected function type");
395+
auto fnType = llvm::dyn_cast<FunctionType>(type);
396+
if (!fnType)
397+
return parser.emitError(parser.getCurrentLocation(),
398+
"expected function type");
399+
if (parser.resolveOperands(operands, fnType.getInputs(),
400+
parser.getCurrentLocation(), result.operands))
401+
return failure();
402+
if (fnType.getNumResults() != 1)
403+
return parser.emitError(parser.getCurrentLocation(),
404+
"expected single return type");
405+
result.addTypes(fnType.getResults());
406+
Region *body = result.addRegion();
407+
SmallVector<OpAsmParser::Argument> argsInfo;
408+
for (auto [unresolvedOperand, operandType] :
409+
llvm::zip(operands, fnType.getInputs())) {
410+
OpAsmParser::Argument argInfo;
411+
argInfo.ssaName = unresolvedOperand;
412+
argInfo.type = operandType;
413+
argsInfo.push_back(argInfo);
414+
}
415+
if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
416+
return failure();
417+
return success();
418+
}
419+
420+
void emitc::ExpressionOp::print(OpAsmPrinter &p) {
421+
p << ' ';
422+
p.printOperands(getDefs());
423+
p << " : ";
424+
p.printFunctionalType(getOperation());
425+
p.shadowRegionArgs(getRegion(), getDefs());
426+
p << ' ';
427+
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
428+
}
429+
384430
Operation *ExpressionOp::getRootOp() {
385431
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
386432
Value yieldedValue = yieldOp.getResult();

mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp

Lines changed: 99 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
1010
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1111
#include "mlir/IR/IRMapping.h"
12+
#include "mlir/IR/Location.h"
1213
#include "mlir/IR/PatternMatch.h"
14+
#include "llvm/ADT/STLExtras.h"
1315

1416
namespace mlir {
1517
namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
2426
Location loc = op->getLoc();
2527

2628
builder.setInsertionPointAfter(op);
27-
auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType);
29+
auto expressionOp =
30+
emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands());
2831

2932
// Replace all op's uses with the new expression's result.
3033
result.replaceAllUsesWith(expressionOp.getResult());
3134

32-
// Create an op to yield op's value.
33-
Region &region = expressionOp.getRegion();
34-
Block &block = region.emplaceBlock();
35+
Block &block = expressionOp.createBody();
36+
IRMapping mapper;
37+
for (auto [operand, arg] :
38+
llvm::zip(expressionOp.getOperands(), block.getArguments()))
39+
mapper.map(operand, arg);
3540
builder.setInsertionPointToEnd(&block);
36-
auto yieldOp = emitc::YieldOp::create(builder, loc, result);
3741

38-
// Move op into the new expression.
39-
op->moveBefore(yieldOp);
42+
Operation *rootOp = builder.clone(*op, mapper);
43+
op->erase();
4044

45+
// Create an op to yield op's value.
46+
emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]);
4147
return expressionOp;
4248
}
4349

@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
5359
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
5460
LogicalResult matchAndRewrite(ExpressionOp expressionOp,
5561
PatternRewriter &rewriter) const override {
56-
bool anythingFolded = false;
57-
for (Operation &op : llvm::make_early_inc_range(
58-
expressionOp.getBody()->without_terminator())) {
59-
// Don't fold expressions whose result value has its address taken.
60-
auto applyOp = dyn_cast<emitc::ApplyOp>(op);
61-
if (applyOp && applyOp.getApplicableOperator() == "&")
62-
continue;
63-
64-
for (Value operand : op.getOperands()) {
65-
auto usedExpression = operand.getDefiningOp<ExpressionOp>();
66-
if (!usedExpression)
67-
continue;
68-
69-
// Don't fold expressions with multiple users: assume any
70-
// re-materialization was done separately.
71-
if (!usedExpression.getResult().hasOneUse())
72-
continue;
73-
74-
// Don't fold expressions with side effects.
75-
if (usedExpression.hasSideEffects())
76-
continue;
77-
78-
// Fold the used expression into this expression by cloning all
79-
// instructions in the used expression just before the operation using
80-
// its value.
81-
rewriter.setInsertionPoint(&op);
82-
IRMapping mapper;
83-
for (Operation &opToClone :
84-
usedExpression.getBody()->without_terminator()) {
85-
Operation *clone = rewriter.clone(opToClone, mapper);
86-
mapper.map(&opToClone, clone);
87-
}
88-
89-
Operation *expressionRoot = usedExpression.getRootOp();
90-
Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
91-
assert(clonedExpressionRootOp &&
92-
"Expected cloned expression root to be in mapper");
93-
assert(clonedExpressionRootOp->getNumResults() == 1 &&
94-
"Expected cloned root to have a single result");
95-
96-
rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
97-
anythingFolded = true;
98-
}
62+
Block *expressionBody = expressionOp.getBody();
63+
ExpressionOp usedExpression;
64+
SetVector<Value> foldedOperands;
65+
66+
auto takesItsOperandsAddress = [](Operation *user) {
67+
auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68+
return applyOp && applyOp.getApplicableOperator() == "&";
69+
};
70+
71+
// Select as expression to fold the first operand expression that
72+
// - doesn't have its result value's address taken,
73+
// - has a single user: assume any re-materialization was done separately,
74+
// - has no side effects,
75+
// and save all other operands to be used later as operands in the folded
76+
// expression.
77+
for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(),
78+
expressionBody->getArguments())) {
79+
ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>();
80+
if (usedExpression || !operandExpression ||
81+
llvm::any_of(arg.getUsers(), takesItsOperandsAddress) ||
82+
!operandExpression.getResult().hasOneUse() ||
83+
operandExpression.hasSideEffects())
84+
foldedOperands.insert(operand);
85+
else
86+
usedExpression = operandExpression;
9987
}
100-
return anythingFolded ? success() : failure();
88+
89+
// If no operand expression was selected, bail out.
90+
if (!usedExpression)
91+
return failure();
92+
93+
// Collect additional operands from the folded expression.
94+
for (Value operand : usedExpression.getOperands())
95+
foldedOperands.insert(operand);
96+
97+
// Create a new expression to hold the folding result.
98+
rewriter.setInsertionPointAfter(expressionOp);
99+
auto foldedExpression = emitc::ExpressionOp::create(
100+
rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
101+
foldedOperands.getArrayRef(), expressionOp.getDoNotInline());
102+
Block &foldedExpressionBody = foldedExpression.createBody();
103+
104+
// Map each operand of the new expression to its matching block argument.
105+
IRMapping mapper;
106+
for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(),
107+
foldedExpressionBody.getArguments()))
108+
mapper.map(operand, arg);
109+
110+
// Prepare to fold the used expression and the matched expression into the
111+
// newly created folded expression.
112+
auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113+
bool withTerminator) {
114+
Block *expressionToFoldBody = expressionToFold.getBody();
115+
for (auto [operand, arg] :
116+
llvm::zip(expressionToFold.getOperands(),
117+
expressionToFoldBody->getArguments())) {
118+
mapper.map(arg, mapper.lookup(operand));
119+
}
120+
121+
for (Operation &opToClone : expressionToFoldBody->without_terminator())
122+
rewriter.clone(opToClone, mapper);
123+
124+
if (withTerminator)
125+
rewriter.clone(*expressionToFoldBody->getTerminator(), mapper);
126+
};
127+
rewriter.setInsertionPointToStart(&foldedExpressionBody);
128+
129+
// First, fold the used expression into the new expression and map its
130+
// result to the clone of its root operation within the new expression.
131+
foldExpression(usedExpression, /*withTerminator=*/false);
132+
Operation *expressionRoot = usedExpression.getRootOp();
133+
Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
134+
assert(clonedExpressionRootOp &&
135+
"Expected cloned expression root to be in mapper");
136+
assert(clonedExpressionRootOp->getNumResults() == 1 &&
137+
"Expected cloned root to have a single result");
138+
mapper.map(usedExpression.getResult(),
139+
clonedExpressionRootOp->getResults()[0]);
140+
141+
// Now fold the matched expression into the new expression.
142+
foldExpression(expressionOp, /*withTerminator=*/true);
143+
144+
// Complete the rewrite.
145+
rewriter.replaceOp(expressionOp, foldedExpression);
146+
rewriter.eraseOp(usedExpression);
147+
148+
return success();
101149
}
102150
};
103151

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/Dialect.h"
1515
#include "mlir/IR/Operation.h"
1616
#include "mlir/IR/SymbolTable.h"
17+
#include "mlir/IR/Value.h"
1718
#include "mlir/Support/IndentedOstream.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -364,9 +365,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
364365
if (hasDeferredEmission(user))
365366
return false;
366367

367-
// Do not inline expressions used by ops with the CExpressionInterface. If
368-
// this was intended, the user could have been merged into the expression op.
369-
return !isa<emitc::CExpressionInterface>(*user);
368+
// Do not inline expressions used by other expressions or by ops with the
369+
// CExpressionInterface. If this was intended, the user could have been merged
370+
// into the expression op.
371+
return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user);
370372
}
371373

372374
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1532,6 +1534,20 @@ LogicalResult CppEmitter::emitOperand(Value value) {
15321534
if (expressionOp && shouldBeInlined(expressionOp))
15331535
return emitExpression(expressionOp);
15341536

1537+
if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
1538+
// If this operand is a block argument of an expression, emit instead the
1539+
// matching expression parameter.
1540+
Operation *argOp = arg.getParentBlock()->getParentOp();
1541+
if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) {
1542+
// This scenario is only expected when one of the operations within the
1543+
// expression being emitted references one of the expression's block
1544+
// arguments.
1545+
assert(expressionOp == emittedExpression &&
1546+
"Expected expression being emitted");
1547+
value = expressionOp->getOperand(arg.getArgNumber());
1548+
}
1549+
}
1550+
15351551
os << getOrCreateName(value);
15361552
return success();
15371553
}

0 commit comments

Comments
 (0)