Skip to content

Commit

Permalink
[mlir] SCF: provide function_ref builders for IfOp
Browse files Browse the repository at this point in the history
Now that OpBuilder is available in `build` functions, it becomes possible to
populate the "then" and "else" regions directly when building the "if"
operation. This is desirable in more structured forms of builders, especially
in when conditionals are mixed with loops. Provide new `build` APIs taking
callbacks for body constructors, similarly to scf::ForOp, and replace more
clunky edsc::BlockBuilder uses with these. The original APIs remain available
and go through the new implementation.

Differential Revision: https://reviews.llvm.org/D80527
  • Loading branch information
ftynse committed May 27, 2020
1 parent 5ee902b commit cadb7cc
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 46 deletions.
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ scf::ValueVector loopNestBuilder(
Value lb, Value ub, Value step, ValueRange iterArgInitValues,
function_ref<scf::ValueVector(Value, ValueRange)> fun = nullptr);

/// Adapters for building if conditions using the builder and the location
/// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted
/// if the condition should not have an 'else' part.
ValueRange
conditionBuilder(TypeRange results, Value condition,
function_ref<scf::ValueVector()> thenBody,
function_ref<scf::ValueVector()> elseBody = nullptr);
ValueRange conditionBuilder(Value condition, function_ref<void()> thenBody,
function_ref<void()> elseBody = nullptr);

} // namespace edsc
} // namespace mlir

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/SCF/SCF.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
namespace mlir {
namespace scf {

void buildTerminatedBody(OpBuilder &builder, Location loc);

#include "mlir/Dialect/SCF/SCFOpsDialect.h.inc"

#define GET_OP_CLASSES
Expand Down
13 changes: 12 additions & 1 deletion mlir/include/mlir/Dialect/SCF/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,18 @@ def IfOp : SCF_Op<"if",
OpBuilder<"OpBuilder &builder, OperationState &result, "
"Value cond, bool withElseRegion">,
OpBuilder<"OpBuilder &builder, OperationState &result, "
"TypeRange resultTypes, Value cond, bool withElseRegion">
"TypeRange resultTypes, Value cond, bool withElseRegion">,
OpBuilder<
"OpBuilder &builder, OperationState &result, TypeRange resultTypes, "
"Value cond, "
"function_ref<void(OpBuilder &, Location)> thenBuilder "
" = buildTerminatedBody, "
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">,
OpBuilder<
"OpBuilder &builder, OperationState &result, Value cond, "
"function_ref<void(OpBuilder &, Location)> thenBuilder "
" = buildTerminatedBody, "
"function_ref<void(OpBuilder &, Location)> elseBuilder = nullptr">
];

let extraClassDeclaration = [{
Expand Down
70 changes: 33 additions & 37 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,39 +235,38 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
SmallVector<Type, 1> resultType;
if (options.unroll)
resultType.push_back(vectorType);
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), resultType, inBoundsCondition,
/*withElseRegion=*/true);

// 3.a. If in-bounds, progressively lower to a 1-D transfer read.
BlockBuilder(&ifOp.thenRegion().front(), Append())([&] {
Value vector = load1DVector(majorIvsPlusOffsets);
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `else` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
(loop_yield(vector));
return;
}
// 3.a.ii. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
});

// 3.b. If not in-bounds, splat a 1-D vector.
BlockBuilder(&ifOp.elseRegion().front(), Append())([&] {
Value vector = std_splat(minorVectorType, xferOp.padding());
// 3.a.i. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `then` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
(loop_yield(vector));
return;
}
// 3.b.ii. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
});

// 3. If in-bounds, progressively lower to a 1-D transfer read, otherwise
// splat a 1-D vector.
ValueRange ifResults = conditionBuilder(
resultType, inBoundsCondition,
[&]() -> scf::ValueVector {
Value vector = load1DVector(majorIvsPlusOffsets);
// 3.a. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `else` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
return {vector};
}
// 3.b. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
return {};
},
[&]() -> scf::ValueVector {
Value vector = std_splat(minorVectorType, xferOp.padding());
// 3.c. If `options.unroll` is true, insert the 1-D vector in the
// aggregate. We must yield and merge with the `then` branch.
if (options.unroll) {
vector = vector_insert(vector, result, majorIvs);
return {vector};
}
// 3.d. Otherwise, just go through the temporary `alloc`.
std_store(vector, alloc, majorIvs);
return {};
});

if (!resultType.empty())
result = *ifOp.results().begin();
result = *ifResults.begin();
} else {
// 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read.
Value loaded1D = load1DVector(majorIvsPlusOffsets);
Expand Down Expand Up @@ -336,11 +335,8 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
if (inBoundsCondition) {
// 2.a. If the condition is not null, we need an IfOp, to write
// conditionally. Progressively lower to a 1-D transfer write.
auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBoundsCondition,
/*withElseRegion=*/false);
BlockBuilder(&ifOp.thenRegion().front(),
Append())([&] { emitTransferWrite(majorIvsPlusOffsets); });
conditionBuilder(inBoundsCondition,
[&] { emitTransferWrite(majorIvsPlusOffsets); });
} else {
// 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
emitTransferWrite(majorIvsPlusOffsets);
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/Dialect/SCF/EDSC/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,51 @@ mlir::scf::ValueVector mlir::edsc::loopNestBuilder(
iterArgInitValues.end());
});
}

static std::function<void(OpBuilder &, Location)>
wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
(void)expectedTypes;
return [=](OpBuilder &builder, Location loc) {
ScopedContext context(builder, loc);
scf::ValueVector returned = body();
assert(ValueRange(returned).getTypes() == expectedTypes &&
"'if' body builder returned values of unexpected type");
builder.create<scf::YieldOp>(loc, returned);
};
}

ValueRange
mlir::edsc::conditionBuilder(TypeRange results, Value condition,
function_ref<scf::ValueVector()> thenBody,
function_ref<scf::ValueVector()> elseBody) {
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
assert(thenBody && "thenBody is mandatory");

auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), results, condition,
wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
return ifOp.getResults();
}

static std::function<void(OpBuilder &, Location)>
wrapZeroResultIfBody(function_ref<void()> body) {
return [=](OpBuilder &builder, Location loc) {
ScopedContext context(builder, loc);
body();
builder.create<scf::YieldOp>(loc);
};
}

ValueRange mlir::edsc::conditionBuilder(Value condition,
function_ref<void()> thenBody,
function_ref<void()> elseBody) {
assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
assert(thenBody && "thenBody is mandatory");

ScopedContext::getBuilderRef().create<scf::IfOp>(
ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
wrapZeroResultIfBody(elseBody))
: llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
return {};
}
44 changes: 36 additions & 8 deletions mlir/lib/Dialect/SCF/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ SCFDialect::SCFDialect(MLIRContext *context)
>();
}

/// Default callback for IfOp builders. Inserts a yield without arguments.
void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc);
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -338,20 +343,43 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,

void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond, bool withElseRegion) {
auto addTerminator = [&](OpBuilder &nested, Location loc) {
if (resultTypes.empty())
IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
loc);
};

build(builder, result, resultTypes, cond, addTerminator,
withElseRegion ? addTerminator
: function_ref<void(OpBuilder &, Location)>());
}

void IfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, Value cond,
function_ref<void(OpBuilder &, Location)> thenBuilder,
function_ref<void(OpBuilder &, Location)> elseBuilder) {
assert(thenBuilder && "the builder callback for 'then' must be present");

result.addOperands(cond);
result.addTypes(resultTypes);

OpBuilder::InsertionGuard guard(builder);
Region *thenRegion = result.addRegion();
thenRegion->push_back(new Block());
if (resultTypes.empty())
IfOp::ensureTerminator(*thenRegion, builder, result.location);
builder.createBlock(thenRegion);
thenBuilder(builder, result.location);

Region *elseRegion = result.addRegion();
if (withElseRegion) {
elseRegion->push_back(new Block());
if (resultTypes.empty())
IfOp::ensureTerminator(*elseRegion, builder, result.location);
}
if (!elseBuilder)
return;

builder.createBlock(elseRegion);
elseBuilder(builder, result.location);
}

void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
function_ref<void(OpBuilder &, Location)> thenBuilder,
function_ref<void(OpBuilder &, Location)> elseBuilder) {
build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
}

static LogicalResult verify(IfOp op) {
Expand Down

0 comments on commit cadb7cc

Please sign in to comment.