From e74e6afcf13aeb7d0a30e55b2eda89f5910d6e68 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 6 Apr 2021 17:58:12 -0700 Subject: [PATCH] [shape] Add min and max ops These are element-wise operations that operates on shapes with equal ranks. Also add missing printer/parser for join operator. Differential Revision: https://reviews.llvm.org/D99986 --- .../include/mlir/Dialect/Shape/IR/ShapeOps.td | 41 ++++++++++++++++++- mlir/test/Dialect/Shape/ops.mlir | 35 +++++++++++++++- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 34a12275eabeb..0b8c26dc91565 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> { used to return an error to the user upon mismatch of dimensions. ```mlir - %c = shape.join %a, %b, error="" : !shape.shape + %c = shape.join %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape ``` }]; let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1, OptionalAttr:$error); let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:` + type($arg0) `,` type($arg1) `->` type($result) + }]; +} + +def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { + let summary = "Elementwise maximum"; + let description = [{ + Computes the elementwise maximum of two shapes with equal ranks. If either + operand is an error, then an error will be propagated to the result. If the + input types mismatch or the ranks do not match, then the result is an + error. + }]; + + let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); + let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; +} + +def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { + let summary = "Elementwise minimum"; + let description = [{ + Computes the elementwise maximum of two shapes with equal ranks. If either + operand is an error, then an error will be propagated to the result. If the + input types mismatch or the ranks do not match, then the result is an + error. + }]; + + let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); + let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; } def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index ca838e7f8dc7b..b9ae301d55799 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -115,7 +115,7 @@ func @test_constraints() { } func @eq_on_extent_tensors(%lhs : tensor, - %rhs : tensor) { + %rhs : tensor) { %w0 = shape.cstr_eq %lhs, %rhs : tensor, tensor return } @@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor) -> index { return %rank : index } - func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 { %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape return %result : i1 @@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape, : !shape.shape, !shape.shape return %result : i1 } + +func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape { + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape + %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} + +func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape { + %0 = shape.const_shape [4, 57, 92] : !shape.shape + %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape + %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + !shape.shape, !shape.shape -> !shape.shape + return %2 : !shape.shape +} + +func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size { + %0 = shape.const_size 5 + %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size + %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + !shape.size, !shape.size -> !shape.size + return %2 : !shape.size +} + +func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size { + %0 = shape.const_size 9 + %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size + %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + !shape.size, !shape.size -> !shape.size + return %2 : !shape.size +}