Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][emitc] Add a structured for operation #68206

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions mlir/docs/Dialects/emitc.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,5 @@ translating the following operations:
* `func.constant`
* `func.func`
* `func.return`
* 'scf' Dialect
* `scf.for`
* `scf.yield`
* 'arith' Dialect
* `arith.constant`
64 changes: 63 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,67 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}

def EmitC_ForOp : EmitC_Op<"for",
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
let description = [{
The `emitc.for` operation represents a C loop of the following form:

```c++
for (T i = lb; i < ub; i += step) { /* ... */ } // where T is typeof(lb)
```

The operation takes 3 SSA values as operands that represent the lower bound,
upper bound and step respectively, and defines an SSA value for its
induction variable. It has one region capturing the loop body. The induction
variable is represented as an argument of this region. This SSA value is a
signless integer or index. The step is a value of same type.

This operation has no result. The body region must contain exactly one block
that terminates with `emitc.yield`. Calling ForOp::build will create such a
region and insert the terminator implicitly if none is defined, so will the
parsing even in cases when it is absent from the custom format. For example:

```mlir
// Index case.
emitc.for %iv = %lb to %ub step %step {
... // body
}
...
// Integer case.
emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
... // body
}
```
}];
let arguments = (ins IntegerIndexOrOpaqueType:$lowerBound,
IntegerIndexOrOpaqueType:$upperBound,
IntegerIndexOrOpaqueType:$step);
let results = (outs);
let regions = (region SizedRegion<1>:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
];

let extraClassDeclaration = [{
using BodyBuilderFn =
function_ref<void(OpBuilder &, Location, Value)>;
Value getInductionVar() { return getBody()->getArgument(0); }
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
}];

let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

def EmitC_IncludeOp
: EmitC_Op<"include", [HasParent<"ModuleOp">]> {
let summary = "Include operation";
Expand Down Expand Up @@ -430,7 +491,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
}

def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
Expand Down
123 changes: 99 additions & 24 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};

// Lower scf::if to emitc::if, implementing return values as emitc::variable's
// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
};

// Create an uninitialized emitc::variable op for each result of the given op.
template <typename T>
static SmallVector<Value> createVariablesForResults(T op,
PatternRewriter &rewriter) {
SmallVector<Value> resultVariables;

if (!op.getNumResults())
return resultVariables;

Location loc = op->getLoc();
MLIRContext *context = op.getContext();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
resultVariables.push_back(var);
}

return resultVariables;
}

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}

static void lowerYield(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables =
createVariablesForResults(forOp, rewriter);
SmallVector<Value> iterArgsVariables =
createVariablesForResults(forOp, rewriter);

assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);

emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());

Block *loweredBody = loweredFor.getBody();

// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(iterArgsVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

// Copy iterArgs into results after the for loop.
assignValues(iterArgsVariables, resultVariables, rewriter, loc);

rewriter.replaceOp(forOp, resultVariables);
return success();
}

// Lower scf::if to emitc::if, implementing result values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
Expand All @@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();

SmallVector<Value> resultVariables;

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
if (ifOp.getNumResults()) {
MLIRContext *context = ifOp.getContext();
rewriter.setInsertionPoint(ifOp);
for (OpResult result : ifOp.getResults()) {
Type resultType = result.getType();
auto noInit = emitc::OpaqueAttr::get(context, "");
auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
resultVariables.push_back(var);
}
}
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
Expand All @@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
Location terminatorLoc = terminator->getLoc();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&loweredRegion.back());
for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
Value resultValue = std::get<0>(value2Var);
Value resultVar = std::get<1>(value2Var);
rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
}
rewriter.create<emitc::YieldOp>(terminatorLoc);
rewriter.eraseOp(terminator);
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
};

Region &thenRegion = ifOp.getThenRegion();
Expand All @@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
}

Expand All @@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::IfOp>();
target.addIllegalOp<scf::ForOp, scf::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down
95 changes: 95 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,101 @@ LogicalResult emitc::ConstantOp::verify() {

OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, BodyBuilderFn bodyBuilder) {
result.addOperands({lb, ub, step});
Type t = lb.getType();
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArgument(t, result.location);

// Create the default terminator if the builder is not provided.
if (!bodyBuilder) {
ForOp::ensureTerminator(*bodyRegion, builder, result.location);
} else {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&bodyBlock);
bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
}
}

void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}

ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
Type type;

OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;

// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
parser.parseOperand(lb) || parser.parseKeyword("to") ||
parser.parseOperand(ub) || parser.parseKeyword("step") ||
parser.parseOperand(step))
return failure();

// Parse the optional initial iteration arguments.
SmallVector<OpAsmParser::Argument, 4> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
regionArgs.push_back(inductionVariable);

// Parse optional type, else assume Index.
if (parser.parseOptionalColon())
type = builder.getIndexType();
else if (parser.parseType(type))
return failure();

// Resolve input operands.
regionArgs.front().type = type;
if (parser.resolveOperand(lb, type, result.operands) ||
parser.resolveOperand(ub, type, result.operands) ||
parser.resolveOperand(step, type, result.operands))
return failure();

// Parse the body region.
Region *body = result.addRegion();
if (parser.parseRegion(*body, regionArgs))
return failure();

ForOp::ensureTerminator(*body, builder, result.location);

// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();

return success();
}

void ForOp::print(OpAsmPrinter &p) {
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();

p << ' ';
if (Type t = getInductionVar().getType(); !t.isIndex())
p << " : " << t << ' ';
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult ForOp::verifyRegions() {
// Check that the body defines as single block argument for the induction
// variable.
if (getInductionVar().getType() != getLowerBound().getType())
return emitOpError(
"expected induction variable to be same type as bounds and step");

return success();
}

//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
Expand Down