diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 78b86bfb22aa5..a6f579cec5057 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -155,7 +155,7 @@ def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> { let hasFolder = 1; } -def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> { +def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> { let summary = "Creates a shape from a tensor of extents"; let description = [{ Creates a shape from a 1D integral tensor of extents. The rank of the @@ -165,26 +165,25 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> { let arguments = (ins IndexTensor:$input); let results = (outs Shape_ShapeType:$result); + + let assemblyFormat = "attr-dict $input `:` type($input)"; } -def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", []> { +def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { let summary = "Creates a dimension tensor from a shape"; - // TODO: Think more about the error situation. Perhaps factor out the - // error detection into a separate op so downstream consumers can control - // their error behavior. Then this op would assume that the input has - // been properly checked to not be an error (and could thus be a - // NoSideEffect op). let description = [{ Converts a shape to a 1D integral tensor of extents. The number of elements in the tensor equals the rank of the shape, and the elements equal the extents of the shape. - If the shape represents an error, then this op currently aborts the program. + If the shape represents an error, this op's behavior is undefined. }]; let arguments = (ins Shape_ShapeType:$input); let results = (outs IndexTensor:$result); + let assemblyFormat = "attr-dict $input `:` type($result)"; + let hasFolder = 1; } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 292579d98f8e3..a56a8f9861de5 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -89,7 +89,7 @@ func @f() -> !shape.shape { func @f() -> tensor<2xindex> { // CHECK: constant dense<[0, 1]> : tensor<2xindex> %cs = shape.const_shape [0, 1] - %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex> + %0 = shape.to_extent_tensor %cs : tensor<2xindex> return %0 : tensor<2xindex> } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index a6718c73a7792..2e40211e5a638 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -101,3 +101,13 @@ func @const_size() { %2 = shape.const_size 2 return } + +func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> { + %0 = shape.to_extent_tensor %arg : tensor<3xindex> + return %0 : tensor<3xindex> +} + +func @test_from_extent_tensor(%arg: tensor) -> !shape.shape { + %0 = shape.from_extent_tensor %arg : tensor + return %0 : !shape.shape +}