Skip to content

Commit

Permalink
[mlir][EDSC] Make use of InsertGuard
Browse files Browse the repository at this point in the history
Summary:
This revision cleans up a layer of complexity in ScopedContext and uses InsertGuard instead of previously manual bookkeeping.
The method `getBuilder` is renamed to `getBuilderRef` and spurious copies of OpBuilder are tracked.

This results in some canonicalizations not happening anymore in the Linalg matmul to vector test. This test is retired because relying on DRRs for this has been shaky at best. The solution will be better support to write fused passes in C++ with more idiomatic pattern composition and application.

Differential Revision: https://reviews.llvm.org/D79208
  • Loading branch information
Nicolas Vasilache committed Apr 30, 2020
1 parent 45b7d44 commit 0d61dcf
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 188 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ struct FoldedValueBuilder {
// Builder-based
template <typename... Args>
FoldedValueBuilder(OperationFolder *folder, Args... args) {
value = folder ? folder->create<Op>(ScopedContext::getBuilder(),
value = folder ? folder->create<Op>(ScopedContext::getBuilderRef(),
ScopedContext::getLocation(), args...)
: ScopedContext::getBuilder().create<Op>(
: ScopedContext::getBuilderRef().create<Op>(
ScopedContext::getLocation(), args...);
}

Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td)
mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen)

Expand Down
26 changes: 10 additions & 16 deletions mlir/include/mlir/EDSC/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ class NestedBuilder;
/// setting and restoring of insertion points.
class ScopedContext {
public:
ScopedContext(OpBuilder &builder, Location location);
ScopedContext(OpBuilder &b, Location location);

/// Sets the insertion point of the builder to 'newInsertPt' for the duration
/// of the scope. The existing insertion point of the builder is restored on
/// destruction.
ScopedContext(OpBuilder &builder, OpBuilder::InsertPoint newInsertPt,
ScopedContext(OpBuilder &b, OpBuilder::InsertPoint newInsertPt,
Location location);
~ScopedContext();

static MLIRContext *getContext();
static OpBuilder &getBuilder();
static OpBuilder &getBuilderRef();
static Location getLocation();

private:
Expand All @@ -59,22 +59,19 @@ class ScopedContext {

/// Top level OpBuilder.
OpBuilder &builder;
/// The previous insertion point of the builder.
Optional<OpBuilder::InsertPoint> prevBuilderInsertPoint;
/// Guard to the previous insertion point.
OpBuilder::InsertionGuard guard;
/// Current location.
Location location;
/// Parent context we return into.
ScopedContext *enclosingScopedContext;
/// Defensively keeps track of the current NestedBuilder to ensure proper
/// scoping usage.
NestedBuilder *nestedBuilder;
};

template <typename Op>
struct ValueBuilder {
template <typename... Args>
ValueBuilder(Args... args) {
value = ScopedContext::getBuilder()
value = ScopedContext::getBuilderRef()
.create<Op>(ScopedContext::getLocation(), args...)
.getResult();
}
Expand All @@ -86,8 +83,8 @@ template <typename Op>
struct OperationBuilder {
template <typename... Args>
OperationBuilder(Args... args) {
op = ScopedContext::getBuilder().create<Op>(ScopedContext::getLocation(),
args...);
op = ScopedContext::getBuilderRef().create<Op>(ScopedContext::getLocation(),
args...);
}
operator Op() { return op; }
operator Operation *() { return op.getOperation(); }
Expand Down Expand Up @@ -122,22 +119,19 @@ class NestedBuilder {
/// let the escape.
void enter(mlir::Block *block) {
bodyScope = new ScopedContext(
ScopedContext::getBuilder(),
ScopedContext::getBuilderRef(),
OpBuilder::InsertPoint(block, std::prev(block->end())),
ScopedContext::getLocation());
if (!block->empty()) {
auto &termOp = block->back();
if (termOp.isKnownTerminator())
ScopedContext::getBuilder().setInsertionPoint(&termOp);
ScopedContext::getBuilderRef().setInsertionPoint(&termOp);
}
bodyScope->nestedBuilder = this;
}

/// Exit the current mlir::Block by explicitly deleting the dynamically
/// allocated OpBuilder and ScopedContext.
void exit() {
// Reclaim now to exit the scope.
bodyScope->nestedBuilder = nullptr;
delete bodyScope;
bodyScope = nullptr;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class BaseViewConversionHelper {
operator Value() { return d; }

private:
OpBuilder &rewriter() { return ScopedContext::getBuilder(); }
OpBuilder &rewriter() { return ScopedContext::getBuilderRef(); }
Location loc() { return ScopedContext::getLocation(); }

MemRefDescriptor d;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void NDTransferOpHelper<ConcreteOp>::emitInBounds(
inBounds = inBounds && inBounds2;
}

auto ifOp = ScopedContext::getBuilder().create<loop::IfOp>(
auto ifOp = ScopedContext::getBuilderRef().create<loop::IfOp>(
ScopedContext::getLocation(), TypeRange{}, inBounds,
/*withElseRegion=*/std::is_same<ConcreteOp, TransferReadOp>());
BlockBuilder(&ifOp.thenRegion().front(),
Expand Down
29 changes: 16 additions & 13 deletions mlir/lib/Dialect/Affine/EDSC/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static Optional<Value> emitStaticFor(ArrayRef<Value> lbs, ArrayRef<Value> ubs,
auto ubConst = dyn_cast<ConstantIndexOp>(ubDef);
if (!lbConst || !ubConst)
return Optional<Value>();
return ScopedContext::getBuilder()
return ScopedContext::getBuilderRef()
.create<AffineForOp>(ScopedContext::getLocation(), lbConst.getValue(),
ubConst.getValue(), step)
.getInductionVar();
Expand All @@ -38,16 +38,19 @@ LoopBuilder mlir::edsc::makeAffineLoopBuilder(Value *iv, ArrayRef<Value> lbs,
ArrayRef<Value> ubs,
int64_t step) {
mlir::edsc::LoopBuilder result;
if (auto staticForIv = emitStaticFor(lbs, ubs, step)) {
if (auto staticForIv = emitStaticFor(lbs, ubs, step))
*iv = staticForIv.getValue();
} else {
auto b = ScopedContext::getBuilder();
*iv =
Value(b.create<AffineForOp>(ScopedContext::getLocation(), lbs,
b.getMultiDimIdentityMap(lbs.size()), ubs,
b.getMultiDimIdentityMap(ubs.size()), step)
.getInductionVar());
}
else
*iv = ScopedContext::getBuilderRef()
.create<AffineForOp>(
ScopedContext::getLocation(), lbs,
ScopedContext::getBuilderRef().getMultiDimIdentityMap(
lbs.size()),
ubs,
ScopedContext::getBuilderRef().getMultiDimIdentityMap(
ubs.size()),
step)
.getInductionVar();

auto *body = getForInductionVarOwner(*iv).getBody();
result.enter(body);
Expand Down Expand Up @@ -122,7 +125,7 @@ static Value createBinaryIndexHandle(

// TODO: createOrFold when available.
Operation *op =
makeComposedAffineApply(ScopedContext::getBuilder(),
makeComposedAffineApply(ScopedContext::getBuilderRef(),
ScopedContext::getLocation(), map, operands)
.getOperation();
assert(op->getNumResults() == 1 && "Expected single result AffineApply");
Expand Down Expand Up @@ -218,7 +221,7 @@ static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs,
assert((lhsType.isa<IndexType>() || lhsType.isSignlessInteger()) &&
"only integer comparisons are supported");

return ScopedContext::getBuilder().create<CmpIOp>(
return ScopedContext::getBuilderRef().create<CmpIOp>(
ScopedContext::getLocation(), predicate, lhs, rhs);
}

Expand All @@ -231,7 +234,7 @@ static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs,
assert(lhsType == rhsType && "cannot mix types in operators");
assert(lhsType.isa<FloatType>() && "only float comparisons are supported");

return ScopedContext::getBuilder().create<CmpFOp>(
return ScopedContext::getBuilderRef().create<CmpFOp>(
ScopedContext::getLocation(), predicate, lhs, rhs);
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
assert(!(outputs[i].getType().isa<RankedTensorType>() &&
outputs[i + 1].getType().isa<MemRefType>()) &&
"output tensors must be passed after output buffers");
auto &builder = edsc::ScopedContext::getBuilder();
auto &builder = edsc::ScopedContext::getBuilderRef();
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
unsigned nOutputs = outputs.size();
Expand Down Expand Up @@ -157,7 +157,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
// clang-format off
auto *op =
edsc::ScopedContext::getBuilder()
edsc::ScopedContext::getBuilderRef()
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
types,
Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,

static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
Optional<AffineMap> permutation) {
return permutation ? applyMapToValues(ScopedContext::getBuilder(),
return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
ScopedContext::getLocation(),
permutation.getValue(), ivs)
: SmallVector<Value, 4>(ivs.begin(), ivs.end());
Expand Down Expand Up @@ -82,7 +82,7 @@ template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing,
ArrayRef<Value> outputBuffers) {
auto &b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto &block = op.region().front();
BlockAndValueMapping map;
map.map(block.getArguments(), indexedValues);
Expand Down Expand Up @@ -110,7 +110,7 @@ struct InputAndOutputIndices {
template <typename SingleInputPoolingOp>
static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
SingleInputPoolingOp op) {
auto &b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(
Expand Down Expand Up @@ -159,7 +159,7 @@ class LinalgScopedEmitter {
LinalgOpType linalgOp) {
assert(linalgOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
unsigned nInputs = linalgOp.getNumInputs();
unsigned nOutputs = linalgOp.getNumOutputs();
Expand Down Expand Up @@ -331,7 +331,7 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
affine_max(dim.getType(), maxMap, ValueRange{dim}));
}

auto b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
Type type = convOp.input().getType().cast<MemRefType>().getElementType();
Value zero = std_constant(type, b.getZeroAttr(type));
Value readInput = im(clampedImIdx);
Expand All @@ -342,7 +342,7 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
assert(convOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
auto mapsRange = convOp.indexing_maps().getAsRange<AffineMapAttr>();
auto maps = llvm::to_vector<8>(llvm::map_range(
Expand Down Expand Up @@ -445,7 +445,7 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
IndexedGenericOp indexedGenericOp) {
assert(indexedGenericOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
unsigned nInputs = indexedGenericOp.getNumInputs();
unsigned nOutputs = indexedGenericOp.getNumOutputs();
Expand Down Expand Up @@ -606,7 +606,7 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,

SmallVector<Value, 4> allIvs(nLoops);
auto loopRanges =
emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
emitLoopRanges(scope.getBuilderRef(), scope.getLocation(), invertedMap,
getViewSizes(rewriter, linalgOp));
assert(loopRanges.size() == allIvs.size());
Impl::doit(linalgOp, loopRanges, allIvs);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
linalg_range(range.offset, range.size, range.stride));
}
GenericLoopNestRangeBuilder<LoopTy>(ivs, linalgRanges)([&] {
auto b = ScopedContext::getBuilder();
auto &b = ScopedContext::getBuilderRef();
auto loc = ScopedContext::getLocation();
SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());

Expand Down
37 changes: 14 additions & 23 deletions mlir/lib/EDSC/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,25 @@
using namespace mlir;
using namespace mlir::edsc;

mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location)
: builder(builder), location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
nestedBuilder(nullptr) {
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b, Location location)
: builder(b), guard(builder), location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()) {
getCurrentScopedContext() = this;
}

/// Sets the insertion point of the builder to 'newInsertPt' for the duration
/// of the scope. The existing insertion point of the builder is restored on
/// destruction.
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder,
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b,
OpBuilder::InsertPoint newInsertPt,
Location location)
: builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()),
location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
nestedBuilder(nullptr) {
: builder(b), guard(builder), location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()) {
getCurrentScopedContext() = this;
builder.restoreInsertionPoint(newInsertPt);
}

mlir::edsc::ScopedContext::~ScopedContext() {
assert(!nestedBuilder &&
"Active NestedBuilder must have been exited at this point!");
if (prevBuilderInsertPoint)
builder.restoreInsertionPoint(*prevBuilderInsertPoint);
getCurrentScopedContext() = enclosingScopedContext;
}

Expand All @@ -49,7 +42,7 @@ ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() {
return context;
}

OpBuilder &mlir::edsc::ScopedContext::getBuilder() {
OpBuilder &mlir::edsc::ScopedContext::getBuilderRef() {
assert(ScopedContext::getCurrentScopedContext() &&
"Unexpected Null ScopedContext");
return ScopedContext::getCurrentScopedContext()->builder;
Expand All @@ -62,15 +55,15 @@ Location mlir::edsc::ScopedContext::getLocation() {
}

MLIRContext *mlir::edsc::ScopedContext::getContext() {
return getBuilder().getContext();
return getBuilderRef().getContext();
}

BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
auto &currentB = ScopedContext::getBuilder();
auto &currentB = ScopedContext::getBuilderRef();
auto *ib = currentB.getInsertionBlock();
auto ip = currentB.getInsertionPoint();
BlockHandle res;
res.block = ScopedContext::getBuilder().createBlock(ib->getParent());
res.block = ScopedContext::getBuilderRef().createBlock(ib->getParent());
// createBlock sets the insertion point inside the block.
// We do not want this behavior when using declarative builders with nesting.
currentB.setInsertionPoint(ib, ip);
Expand All @@ -82,17 +75,15 @@ BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {

BlockHandle mlir::edsc::BlockHandle::createInRegion(Region &region,
ArrayRef<Type> argTypes) {
auto &currentB = ScopedContext::getBuilder();
BlockHandle res;
region.push_back(new Block);
res.block = &region.back();
// createBlock sets the insertion point inside the block.
// We do not want this behavior when using declarative builders with nesting.
OpBuilder::InsertionGuard g(currentB);
currentB.setInsertionPoint(res.block, res.block->begin());
for (auto t : argTypes) {
res.block->addArgument(t);
}
OpBuilder::InsertionGuard g(ScopedContext::getBuilderRef());
ScopedContext::getBuilderRef().setInsertionPoint(res.block,
res.block->begin());
res.block->addArguments(argTypes);
return res;
}

Expand Down
16 changes: 0 additions & 16 deletions mlir/test/Dialect/Linalg/matmul-to-vector.mlir

This file was deleted.

6 changes: 0 additions & 6 deletions mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,3 @@ add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
# Including Linalg in TableGen requires to depends on generated files
add_dependencies(MLIRTestLinalgTransformPatternsIncGen LinalgOdsGen)
Loading

0 comments on commit 0d61dcf

Please sign in to comment.