Skip to content

Commit

Permalink
Add the inline interface to the shape dialect
Browse files Browse the repository at this point in the history
This patch also fixes a minor issue that shape.rank should allow
returning !shape.size. The dialect doc has such an example for
shape.rank.

Differential Revision: https://reviews.llvm.org/D85556
  • Loading branch information
fengliu committed Aug 8, 2020
1 parent b6d9add commit 5c9c4ad
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
39 changes: 35 additions & 4 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"

Expand Down Expand Up @@ -59,13 +60,40 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
return success();
}

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;

// Returns true if the given region 'src' can be inlined into the region
// 'dest' that is attached to an operation registered to the current dialect.
bool isLegalToInline(Region *dest, Region *src,
BlockAndValueMapping &) const final {
return true;
}

// Returns true if the given operation 'op', that is registered to this
// dialect, can be inlined into the region 'dest' that is attached to an
// operation registered to the current dialect.
bool isLegalToInline(Operation *op, Region *dest,
BlockAndValueMapping &) const final {
return true;
}
};
} // namespace

void ShapeDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
>();
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
WitnessType>();
addInterfaces<ShapeInlinerInterface>();
// Allow unknown operations during prototyping and testing. As the dialect is
// still evolving it makes it simple to start with an unregistered ops and
// try different variants before actually defining the op.
Expand Down Expand Up @@ -640,11 +668,14 @@ struct RankShapeOfCanonicalizationPattern
shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType)
return failure();
assert(op.getType().isa<IndexType>() &&
"expected `rank(shape_of( ... )]` based on a shaped argument to "
"yield an index type");
int64_t rank = rankedTensorType.getRank();
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
if (op.getType().isa<IndexType>()) {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
} else if (op.getType().isa<shape::SizeType>()) {
rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
} else {
return failure();
}
return success();
}
};
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,18 @@ func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {

// -----

// Canonicalize `rank` when shape is derived from ranked tensor.
// CHECK-LABEL: @canonicalize_rank
func @canonicalize_rank_size(%arg : tensor<1x2x?xf32>) -> !shape.size {
// CHECK: %[[RESULT:.*]] = shape.const_size 3
// CHECK: return %[[RESULT]] : !shape.size
%shape = shape.shape_of %arg : tensor<1x2x?xf32> -> !shape.shape
%rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}

// -----

// Do not canonicalize `rank` when shape is derived from unranked tensor.
// CHECK-LABEL: @dont_canonicalize_rank
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
Expand Down

0 comments on commit 5c9c4ad

Please sign in to comment.