Skip to content

Commit

Permalink
[MLIR][Shape] Limit shape to SCF lowering patterns to their supported types
Browse files Browse the repository at this point in the history

Differential Revision: https://reviews.llvm.org/D84444
  • Loading branch information
frgossen committed Jul 29, 2020
1 parent 1aaf8aa commit 5fc34fa
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
11 changes: 7 additions & 4 deletions mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
Expand Up @@ -121,7 +121,7 @@ LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (!op.shape().getType().isa<RankedTensorType>())
+  if (op.shape().getType().isa<ShapeType>())
return failure();

auto loc = op.getLoc();
Expand Down Expand Up @@ -171,12 +171,15 @@ class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
+  // For now, this lowering supports only error-free arguments.
+  if (op.getType().isa<ShapeType>())
+    return failure();

   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
   if (argTy.isa<RankedTensorType>())
     return failure();

Expand Down
33 changes: 22 additions & 11 deletions mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Expand Up @@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {

// -----

// Don't lower `shape_of` for result type of `shape.shape`.
// CHECK-LABEL: @shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of(%arg : tensor<*xf32>) {
// CHECK: shape.shape
%shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
return
}

// -----

// Lower `shape_of` for unranked tensors.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK-DAG: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
-  // CHECK: }
-  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
return
}
Expand Down

0 comments on commit 5fc34fa

Please sign in to comment.