Skip to content

Commit

Permalink
[shape] Add min and max ops
Browse files Browse the repository at this point in the history
These are element-wise operations that operates on shapes with equal ranks.
Also add missing printer/parser for join operator.

Differential Revision: https://reviews.llvm.org/D99986
  • Loading branch information
jpienaar committed Apr 7, 2021
1 parent 86175d5 commit e74e6af
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
41 changes: 40 additions & 1 deletion mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Expand Up @@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
used to return an error to the user upon mismatch of dimensions.

```mlir
%c = shape.join %a, %b, error="<reason>" : !shape.shape
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];

let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);

let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
}

def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two shapes with equal ranks. If either
operand is an error, then an error will be propagated to the result. If the
input types mismatch or the ranks do not match, then the result is an
error.
}];

let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);

let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
}

def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
let summary = "Elementwise minimum";
let description = [{
Computes the elementwise maximum of two shapes with equal ranks. If either
operand is an error, then an error will be propagated to the result. If the
input types mismatch or the ranks do not match, then the result is an
error.
}];

let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);

let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
}

def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
Expand Down
35 changes: 33 additions & 2 deletions mlir/test/Dialect/Shape/ops.mlir
Expand Up @@ -115,7 +115,7 @@ func @test_constraints() {
}

func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
%rhs : tensor<?xindex>) {
%rhs : tensor<?xindex>) {
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
return
}
Expand Down Expand Up @@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
return %rank : index
}


func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
Expand Down Expand Up @@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
: !shape.shape, !shape.shape
return %result : i1
}

func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}

func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}

0 comments on commit e74e6af

Please sign in to comment.