diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 824db2685a777..0101c9e7fdc01 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -121,7 +121,7 @@ LogicalResult
 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<index>` operands.
-  if (!op.shape().getType().isa<RankedTensorType>())
+  if (op.shape().getType().isa<ShapeType>())
     return failure();
 
   auto loc = op.getLoc();
@@ -171,12 +171,15 @@ class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
 LogicalResult
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
+  // For now, this lowering supports only error-free arguments.
+  if (op.getType().isa<ShapeType>())
+    return failure();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
   if (argTy.isa<RankedTensorType>())
     return failure();
 
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 441b2e92cc3d9..97d2bce5a0948 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
 
 // -----
 
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+  // CHECK: shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+  return
+}
+
+// -----
+
 // Lower `shape_of` for unranked tensors.
 // CHECK-LABEL: @shape_of_unranked
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK-DAG: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
-  // CHECK: }
-  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return
 }
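
A rough sketch of the two cases the new guard distinguishes, drawn only from the IR forms already used in the tests above (the pass invocation itself comes from the test's RUN line, which is not shown here):

  // Result typed as an error-carrying shape: this lowering bails out and
  // the op is left untouched for other conversions.
  %s = shape.shape_of %arg : tensor<*xf32> -> !shape.shape

  // Result typed as an error-free extent tensor: lowered to the
  // rank/alloca/scf.for sequence checked in @shape_of_unranked above.
  %t = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>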