Skip to content

Commit

Permalink
[mlir] Canonicalization of shape.assuming
Browse files Browse the repository at this point in the history
Summary:
This will inline the region to a shape.assuming in the case that the
input witness is found to be statically true.

Differential Revision: https://reviews.llvm.org/D80302
  • Loading branch information
tpopp committed Jun 5, 2020
1 parent 0a554e6 commit 655e08c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ def Shape_AssumingOp : Shape_Op<"assuming",

let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];

let hasCanonicalizer = 1;
}

def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,44 @@ static void print(OpAsmPrinter &p, AssumingOp op) {
p.printOptionalAttrDict(op.getAttrs());
}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
using OpRewritePattern<AssumingOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AssumingOp op,
PatternRewriter &rewriter) const override {
auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
if (!witness || !witness.passingAttr())
return failure();

auto *blockBeforeAssuming = rewriter.getInsertionBlock();
auto *assumingBlock = op.getBody();
auto initPosition = rewriter.getInsertionPoint();
auto *blockAfterAssuming =
rewriter.splitBlock(blockBeforeAssuming, initPosition);

// Remove the AssumingOp and AssumingYieldOp.
auto &yieldOp = assumingBlock->back();
rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
rewriter.replaceOp(op, yieldOp.getOperands());
rewriter.eraseOp(&yieldOp);

// Merge blocks together as there was no branching behavior from the
// AssumingOp.
rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
return success();
}
};
}; // namespace

void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
// If taking a passing witness, inline region
patterns.insert<AssumingWithTrue>(context);
}

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,42 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
return %1 : !shape.shape
}

// -----
// assuming with a known passing witness can be removed
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: source
// CHECK-NEXT: sink
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
}
"test.sink"(%1) : (index) -> ()
return
}

// -----
// assuming without a known passing passing witness cannot be removed
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: test.source
// CHECK-NEXT: shape.assuming
// CHECK-NEXT: test.source
// CHECK-NEXT: shape.assuming_yield
// CHECK-NEXT: }
// CHECK-NEXT: test.sink
// CHECK-NEXT: return
%0 = "test.source"() : () -> (!shape.witness)
%1 = shape.assuming %0 -> index {
%2 = "test.source"() : () -> (index)
shape.assuming_yield %2 : index
}
"test.sink"(%1) : (index) -> ()
return
}

// -----
// Broadcastable with broadcastable constant shapes can be removed.
// CHECK-LABEL: func @f
Expand Down

0 comments on commit 655e08c

Please sign in to comment.