Skip to content

Commit

Permalink
Make OpBuilder::insert virtual instead of OpBuilder::createOperation.
Browse files Browse the repository at this point in the history
It is sometimes useful to create operations separately from the builder before insertion as it may be easier to erase them in isolation if necessary. One example use case for this is folding, as we will only want to insert newly generated constant operations on success. This has the added benefit of fixing some silent PatternRewriter failures related to cloning, as the OpBuilder 'clone' methods don't call createOperation.

PiperOrigin-RevId: 285086242
  • Loading branch information
River707 authored and tensorflower-gardener committed Dec 12, 2019
1 parent 9dfa84a commit 851a851
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 57 deletions.
29 changes: 11 additions & 18 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -281,6 +281,9 @@ class OpBuilder : public Builder {
/// Returns the current insertion point of the builder.
Block::iterator getInsertionPoint() const { return insertPoint; }

/// Insert the given operation at the current insertion point and return it.
virtual Operation *insert(Operation *op);

/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *createBlock(Region *parent, Region::iterator insertPt = {});
Expand All @@ -293,7 +296,7 @@ class OpBuilder : public Builder {
Block *getBlock() const { return block; }

/// Creates an operation given the fields represented as an OperationState.
virtual Operation *createOperation(const OperationState &state);
Operation *createOperation(const OperationState &state);

/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
Expand Down Expand Up @@ -346,38 +349,28 @@ class OpBuilder : public Builder {
/// cloned sub-operations to the corresponding operation that is copied,
/// and adds those mappings to the map.
Operation *clone(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.clone(mapper);
insert(cloneOp);
return cloneOp;
}
Operation *clone(Operation &op) {
Operation *cloneOp = op.clone();
insert(cloneOp);
return cloneOp;
return insert(op.clone(mapper));
}
Operation *clone(Operation &op) { return insert(op.clone()); }

/// Creates a deep copy of this operation but keep the operation regions
/// empty. Operands are remapped using `mapper` (if present), and `mapper` is
/// updated to contain the results.
Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
Operation *cloneOp = op.cloneWithoutRegions(mapper);
insert(cloneOp);
return cloneOp;
return insert(op.cloneWithoutRegions(mapper));
}
Operation *cloneWithoutRegions(Operation &op) {
Operation *cloneOp = op.cloneWithoutRegions();
insert(cloneOp);
return cloneOp;
return insert(op.cloneWithoutRegions());
}
template <typename OpT> OpT cloneWithoutRegions(OpT op) {
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
}

private:
/// Attempts to fold the given operation and places new results within
/// 'results'.
void tryFold(Operation *op, SmallVectorImpl<Value *> &results);

/// Insert the given operation at the current insertion point.
void insert(Operation *op);

Block *block = nullptr;
Block::iterator insertPoint;
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -302,9 +302,9 @@ class PatternRewriter : public OpBuilder {
return OpTy();
}

/// This is implemented to create the specified operations and serves as a
/// This is implemented to insert the specified operation and serves as a
/// notification hook for rewriters that want to know about new operations.
virtual Operation *createOperation(const OperationState &state) = 0;
virtual Operation *insert(Operation *op) = 0;

/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
Expand Down
10 changes: 2 additions & 8 deletions mlir/include/mlir/Transforms/DialectConversion.h
Expand Up @@ -332,12 +332,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument *from, Value *to);

/// Clone the given operation without cloning its regions.
Operation *cloneWithoutRegions(Operation *op);
template <typename OpT> OpT cloneWithoutRegions(OpT op) {
return cast<OpT>(cloneWithoutRegions(op.getOperation()));
}

/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
Value *getRemappedValue(Value *key);
Expand Down Expand Up @@ -376,8 +370,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
BlockAndValueMapping &mapping) override;
using PatternRewriter::cloneRegionBefore;

/// PatternRewriter hook for creating a new operation.
Operation *createOperation(const OperationState &state) override;
/// PatternRewriter hook for inserting a new operation.
Operation *insert(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
void notifyRootUpdated(Operation *op) override;
Expand Down
18 changes: 8 additions & 10 deletions mlir/lib/IR/Builders.cpp
Expand Up @@ -306,6 +306,13 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {

OpBuilder::~OpBuilder() {}

/// Insert the given operation at the current insertion point and return it.
Operation *OpBuilder::insert(Operation *op) {
if (block)
block->getOperations().insert(insertPoint, op);
return op;
}

/// Add new block and set the insertion point to the end of it. The block is
/// inserted at the provided insertion point of 'parent'.
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
Expand All @@ -328,10 +335,7 @@ Block *OpBuilder::createBlock(Block *insertBefore) {

/// Create an operation given the fields represented as an OperationState.
Operation *OpBuilder::createOperation(const OperationState &state) {
assert(block && "createOperation() called without setting builder's block");
auto *op = Operation::create(state);
insert(op);
return op;
return insert(Operation::create(state));
}

/// Attempts to fold the given operation and places new results within
Expand Down Expand Up @@ -359,9 +363,3 @@ void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
[](OpFoldResult result) { return result.get<Value *>(); });
op->erase();
}

/// Insert the given operation at the current insertion point.
void OpBuilder::insert(Operation *op) {
if (block)
block->getOperations().insert(insertPoint, op);
}
18 changes: 5 additions & 13 deletions mlir/lib/Transforms/DialectConversion.cpp
Expand Up @@ -802,13 +802,6 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}

/// Clone the given operation without cloning its regions.
Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
impl->createdOps.push_back(newOp);
return newOp;
}

/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
Value *ConversionPatternRewriter::getRemappedValue(Value *key) {
Expand Down Expand Up @@ -854,12 +847,11 @@ void ConversionPatternRewriter::cloneRegionBefore(
}

/// PatternRewriter hook for creating a new operation.
Operation *
ConversionPatternRewriter::createOperation(const OperationState &state) {
LLVM_DEBUG(llvm::dbgs() << "** Creating operation : " << state.name << "\n");
auto *result = OpBuilder::createOperation(state);
impl->createdOps.push_back(result);
return result;
Operation *ConversionPatternRewriter::insert(Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName()
<< "\n");
impl->createdOps.push_back(op);
return OpBuilder::insert(op);
}

/// PatternRewriter hook for updating the root operation in-place.
Expand Down
11 changes: 5 additions & 6 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Expand Up @@ -86,12 +86,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter {

// These are hooks implemented for PatternRewriter.
protected:
// Implement the hook for creating operations, and make sure that newly
// created ops are added to the worklist for processing.
Operation *createOperation(const OperationState &state) override {
auto *result = OpBuilder::createOperation(state);
addToWorklist(result);
return result;
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
Operation *insert(Operation *op) override {
addToWorklist(op);
return OpBuilder::insert(op);
}

// If an operation is about to be removed, make sure it is not in our
Expand Down

0 comments on commit 851a851

Please sign in to comment.