[MLIR][Shape] Allow shape.rank to operate on extent tensors
Differential Revision: https://reviews.llvm.org/D84429
frgossen committed Jul 24, 2020
1 parent 8046220 commit 23a6564
Showing 7 changed files with 90 additions and 37 deletions.
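
In short: `shape.rank` may now take an extent tensor (`tensor<?xindex>`) as well as a `!shape.shape`, and the assembly format now spells out the result type. A minimal sketch of the two accepted forms, mirroring the updated tests below (value names are illustrative):

    %r0 = shape.rank %shape : !shape.shape -> !shape.size
    %r1 = shape.rank %extents : tensor<?xindex> -> index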
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -122,4 +122,6 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
Shape_ExtentTensorType],
"shape or extent tensor">;

+def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;

#endif // SHAPE_BASE_TD
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -201,12 +201,13 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
}];

let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
-  let results = (outs Shape_SizeType:$rank);
+  let results = (outs Shape_SizeOrIndexType:$rank);

-  let assemblyFormat = "$shape `:` type($shape) attr-dict";
+  let assemblyFormat = "$shape `:` type($shape) `->` type($rank) attr-dict";

let hasFolder = 1;
let hasCanonicalizer = 1;
+  let verifier = [{ return ::verify(*this); }];
}

def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
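The relaxed result type is constrained by the verifier added in Shape.cpp below: a `!shape.shape` operand may hold an error value, which only a `!shape.size` result can propagate. A sketch of the resulting typing rules, assuming the semantics stated in the verifier message (value names are illustrative):

    %0 = shape.rank %s : !shape.shape -> !shape.size    // ok: errors can propagate
    %1 = shape.rank %t : tensor<?xindex> -> index       // ok: extent tensors cannot represent errors
    %2 = shape.rank %s : !shape.shape -> index          // rejected by the verifier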
31 changes: 24 additions & 7 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -8,6 +8,7 @@

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -52,6 +53,8 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+  if (type.isa<IndexType>())
+    return builder.create<ConstantOp>(loc, type, value);
return nullptr;
}

@@ -563,7 +566,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
// RankOp
//===----------------------------------------------------------------------===//

-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+static LogicalResult verify(shape::RankOp op) {
+  Type argTy = op.shape().getType();
+  Type resultTy = op.rank().getType();
+  if (argTy.isa<ShapeType>() && !resultTy.isa<SizeType>())
+    return op.emitOpError()
+           << "if operand is of type `shape` then the result must be of type "
+              "`size` to propagate potential errors";
+  return success();
+}
+
+OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!shape)
return {};
@@ -587,10 +600,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
/// %rank = shape.const_size 3

namespace {
-struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
-  using OpRewritePattern<RankOp>::OpRewritePattern;
+struct RankShapeOfCanonicalizationPattern
+    : public OpRewritePattern<shape::RankOp> {
+  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

-  LogicalResult matchAndRewrite(RankOp op,
+  LogicalResult matchAndRewrite(shape::RankOp op,
PatternRewriter &rewriter) const override {
auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>();
if (!shapeOfOp)
@@ -599,15 +613,18 @@ struct RankShapeOfCanonicalizationPattern : public OpRewritePattern<RankOp> {
shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
if (!rankedTensorType)
return failure();
+    assert(op.getType().isa<IndexType>() &&
+           "expected `rank(shape_of( ... )]` based on a shaped argument to "
+           "yield an index type");
int64_t rank = rankedTensorType.getRank();
-    rewriter.replaceOpWithNewOp<ConstSizeOp>(op.getOperation(), rank);
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(), rank);
return success();
}
};
} // namespace

-void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
-                                         MLIRContext *context) {
+void shape::RankOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<RankShapeOfCanonicalizationPattern>(context);
}

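The `materializeConstant` extension matters because folding `shape.rank` on a constant extent tensor now produces an `index`-typed attribute, which the dialect must materialize as a standard `constant` op. A sketch of the fold this enables, matching the @fold_rank test in canonicalize.mlir below:

    %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
    %rank = shape.rank %shape : tensor<?xindex> -> index
    // after folding, %rank becomes:
    %rank = constant 5 : index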
12 changes: 6 additions & 6 deletions mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -122,12 +122,12 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
// Convert `rank` to `dim` of the first dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
-  // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape : !shape.shape
-  return %rank : !shape.size
+func @rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
}

// -----
62 changes: 44 additions & 18 deletions mlir/test/Dialect/Shape/canonicalize.mlir
@@ -496,10 +496,10 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
// Fold `rank` based on constant shape.
// CHECK-LABEL: @fold_rank
func @fold_rank() -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  // CHECK: %[[RESULT:.*]] = shape.const_size 5
+  // CHECK: return %[[RESULT]] : !shape.size
%shape = shape.const_shape [3, 4, 5, 6, 7] : !shape.shape
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}

@@ -509,38 +509,64 @@ func @fold_rank() -> !shape.size {
// CHECK-LABEL: @dont_fold_rank
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size
func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape : !shape.shape
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[RESULT]] : !shape.size
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}

// -----

+// Fold `rank` based on constant extent tensor.
+// CHECK-LABEL: @fold_rank
+func @fold_rank() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 5 : index
+  // CHECK: return %[[RESULT]] : index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----
+
+// Do not fold `rank` for non-constant extent tensors.
+// CHECK-LABEL: @dont_fold_rank
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @dont_fold_rank(%shape : tensor<?xindex>) -> index {
+  // CHECK: %[[RESULT:.*]] = shape.rank %[[SHAPE]] : tensor<?xindex> -> index
+  // CHECK: return %[[RESULT]] : index
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
+}
+
+// -----

// Canonicalize `rank` when shape is derived from ranked tensor.
// CHECK-LABEL: @canonicalize_rank
-func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-  // CHECK-DAG: return %[[RESULT]] : !shape.size
+func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> index {
+  // CHECK: %[[RESULT:.*]] = constant 3 : index
+  // CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<1x2x?xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
}

// -----

// Do not canonicalize `rank` when shape is derived from unranked tensor.
// CHECK-LABEL: @dont_canonicalize_rank
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
-func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
-  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-  // CHECK-DAG: return %[[SIZE]] : !shape.size
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> index
+func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> -> tensor<?xindex>
+  // CHECK: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK: return %[[SIZE]] : index
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
}

// -----

// Canonicalize redundant conversion from `index` to `size` and back.
// CHECK-LABEL: @index_to_size_to_index
// CHECK-SAME: (%[[IDX:.*]]: index) -> index
7 changes: 7 additions & 0 deletions mlir/test/Dialect/Shape/invalid.mlir
@@ -95,3 +95,10 @@ func @shape_of(%value_arg : !shape.value_shape,
%1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
}

+// -----
+
+func @rank(%arg : !shape.shape) {
+  // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+  %0 = shape.rank %arg : !shape.shape -> index
+}

8 changes: 4 additions & 4 deletions mlir/test/Dialect/Shape/ops.mlir
@@ -137,13 +137,13 @@ func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
}

func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape : !shape.shape
+  %rank = shape.rank %shape : !shape.shape -> !shape.size
return %rank : !shape.size
}

-func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
-  %rank = shape.rank %shape : tensor<?xindex>
-  return %rank : !shape.size
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
+  %rank = shape.rank %shape : tensor<?xindex> -> index
+  return %rank : index
}


