diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 59a979c74a7d51..62b4e022408aa6 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -827,14 +827,10 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
   Type lhs = l.front();
   Type rhs = r.front();
 
-  if (lhs == rhs)
-    return true;
-
   if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
     // Shape type is compatible with all other valid return types.
     return true;
-
-  return succeeded(verifyCompatibleShapes(lhs, rhs));
+  return lhs == rhs;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 8f847b1b28c563..7460dc5f3d33d9 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -12,6 +12,10 @@ def HasSingleElement : Constraint<CPred< [{
   $0.size() == 1
 }]>>;
 
+def HasStaticShape : Constraint<CPred< [{
+  $0.getType().cast<mlir::ShapedType>().hasStaticShape()
+}]>>;
+
 // Canonicalization patterns.
 
 def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
@@ -37,4 +41,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
 // Fold tensor.cast(const_shape) to const_shape. This changes the type of
 // const_shape to the destination type of the cast.
 def TensorCastConstShape : Pat <
-  (Tensor_CastOp (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg)>;
+  (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
+  [(HasStaticShape $res)]>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index b32b7ac9052cb1..b0c2181b5b7bac 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: func @f
-func @f(%arg0: tensor<2x3x4xf32>) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<?xindex>
-  return %0 : tensor<?xindex>
+func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
+  // CHECK: shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %0 = shape.shape_of %arg0 : tensor<2x3x4xf32> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
 }
 
 // -----
@@ -62,13 +62,13 @@ func @f() -> !shape.shape {
 
 // Basic case including extent tensors.
 // CHECK-LABEL: @broadcast
-func @broadcast() -> tensor<?xindex> {
-  // CHECK: shape.const_shape [7, 2] : tensor<?xindex>
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
+func @broadcast() -> tensor<2xindex> {
+  // CHECK: shape.const_shape [7, 2] : tensor<2xindex>
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
   %2 = shape.broadcast %0, %1
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+      : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
+  return %2 : tensor<2xindex>
 }
 
 // -----
@@ -77,9 +77,9 @@ func @broadcast() -> tensor<?xindex> {
 // CHECK-LABEL: @broadcast
 func @broadcast() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
-  %0 = shape.const_shape [1, 2] : tensor<?xindex>
-  %1 = shape.const_shape [7, 1] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
+  %0 = shape.const_shape [1, 2] : tensor<2xindex>
+  %1 = shape.const_shape [7, 1] : tensor<2xindex>
+  %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<2xindex> -> !shape.shape
   return %2 : !shape.shape
 }
 
@@ -317,9 +317,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 // CHECK-LABEL: func @basic
 func @basic() -> index {
   // CHECK: constant 2 : index
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c2 = constant 2 : index
-  %1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
   return %1 : index
 }
 
@@ -330,9 +330,9 @@ func @basic() -> index {
 func @out_of_bounds() -> index {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
+  %0 = shape.const_shape [0, 1, 2] : tensor<3xindex>
   %c3 = constant 3 : index
-  %1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
+  %1 = shape.get_extent %0, %c3 : tensor<3xindex>, index -> index
   return %1 : index
 }
 
@@ -559,12 +559,12 @@ func @f(%arg : !shape.shape) -> !shape.shape {
 
 // any can be replaced with a constant input if it has one.
 // CHECK-LABEL: func @f
-func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
-  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
-  %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %1 : tensor<?xindex>
+func @f(%arg : tensor<?xindex>) -> tensor<3xindex> {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  // CHECK-NEXT: return %[[CS]] : tensor<3xindex>
+  %0 = shape.const_shape [2, 3, 4] : tensor<3xindex>
+  %1 = shape.any %0, %arg : tensor<3xindex>, tensor<?xindex> -> tensor<3xindex>
+  return %1 : tensor<3xindex>
 }
 
 // -----
@@ -837,8 +837,8 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 func @fold_rank() -> index {
   // CHECK: %[[RESULT:.*]] = constant 5 : index
   // CHECK: return %[[RESULT]] : index
-  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<?xindex>
-  %rank = shape.rank %shape : tensor<?xindex> -> index
+  %shape = shape.const_shape [3, 4, 5, 6, 7] : tensor<5xindex>
+  %rank = shape.rank %shape : tensor<5xindex> -> index
   return %rank : index
 }
 
@@ -971,9 +971,9 @@ func @shape_eq_fold_1() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant true
   // CHECK: return %[[RESULT]] : i1
   %a = shape.const_shape [1, 2, 3] : !shape.shape
-  %b = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<?xindex>, tensor<?xindex>
+  %b = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %c = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : !shape.shape, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
 
@@ -984,10 +984,10 @@ func @shape_eq_fold_1() -> i1 {
 func @shape_eq_fold_0() -> i1 {
   // CHECK: %[[RESULT:.*]] = constant false
   // CHECK: return %[[RESULT]] : i1
-  %a = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %b = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %c = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %result = shape.shape_eq %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>
+  %a = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %b = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %c = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %result = shape.shape_eq %a, %b, %c : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>
   return %result : i1
 }
 
@@ -1161,18 +1161,17 @@ func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
 func @fold_tensor.cast_of_const_shape_returned(%arg: i1) -> tensor<1xindex> {
   // CHECK: shape.const_shape [2] : tensor<1xindex>
   // CHECK-NOT: tensor.cast
-  %0 = shape.const_shape [2] : tensor<?xindex>
-  %1 = tensor.cast %0 : tensor<?xindex> to tensor<1xindex>
+  %0 = shape.const_shape [2] : tensor<1xindex>
+  %1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
   return %1 : tensor<1xindex>
 }
 
 // -----
 
-// Verify that tensor.cast folding uses the correct type
-// CHECK-LABEL: @fold_tensor.cast_of_const_shape_returned_dynamic
-func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
-  // CHECK: shape.const_shape [2] : tensor<?xindex>
-  // CHECK-NOT: tensor.cast
+// CHECK-LABEL: @dont_fold_tensor.cast_of_const_shape_returned_dynamic
+func @dont_fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xindex> {
+  // CHECK: %[[CONST_SHAPE:.*]] = shape.const_shape [2] : tensor<1xindex>
+  // CHECK: tensor.cast %[[CONST_SHAPE]] : tensor<1xindex> to tensor<?xindex>
   %0 = shape.const_shape [2] : tensor<1xindex>
   %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
   return %1 : tensor<?xindex>
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
index e7b501e8e23523..a41e7b5936e13d 100644
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -35,7 +35,6 @@ func @test_shape_num_elements_unknown() {
 
 func @const_shape() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
-  %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
   %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
   return
 }
@@ -55,11 +54,11 @@ func @test_broadcast_fixed() {
   return
 }
 
-func @test_broadcast_extents() -> tensor<?xindex> {
-  %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
-  %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
-  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return %2 : tensor<?xindex>
+func @test_broadcast_extents() -> tensor<4xindex> {
+  %0 = shape.const_shape [10, 1, 57, 92] : tensor<4xindex>
+  %1 = shape.const_shape [4, 57, 92] : tensor<3xindex>
+  %2 = shape.broadcast %0, %1 : tensor<4xindex>, tensor<3xindex> -> tensor<4xindex>
+  return %2 : tensor<4xindex>
 }
 
 func @test_shape_any_fixed() {
@@ -89,7 +88,7 @@ func @test_shape_any_fixed_mismatch() {
 func @test_parse_const_shape() {
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  %2 = shape.const_shape [1, 2, 3] : tensor<3xindex>
   return
 }
 
@@ -222,9 +221,9 @@ func @any() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : !shape.shape
   %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
-  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
-  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %4 = shape.const_shape [4, 5, 6] : tensor<3xindex>
+  %5 = "shape.any"(%3, %4) : (tensor<3xindex>, tensor<3xindex>) -> tensor<3xindex>
   return
 }
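Note (not part of the patch): the net effect of the TensorCastConstShape change, as exercised by the tests above, is that a tensor.cast of shape.const_shape now folds only when the cast's result type is fully static; a cast to a dynamic extent tensor is left in place. A minimal sketch of the two cases, assuming the usual mlir-opt -canonicalize invocation and hypothetical function names:

// Static result type: the cast is redundant and canonicalization removes it,
// leaving shape.const_shape with the target type tensor<1xindex>.
func @cast_to_static_folds() -> tensor<1xindex> {
  %0 = shape.const_shape [2] : tensor<1xindex>
  %1 = tensor.cast %0 : tensor<1xindex> to tensor<1xindex>
  return %1 : tensor<1xindex>
}

// Dynamic result type: the HasStaticShape constraint rejects the pattern,
// so the tensor.cast to tensor<?xindex> survives canonicalization.
func @cast_to_dynamic_is_kept() -> tensor<?xindex> {
  %0 = shape.const_shape [2] : tensor<1xindex>
  %1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
  return %1 : tensor<?xindex>
}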