Skip to content

Commit

Permalink
[MLIR] Rename Shape dialect's join to meet.
Browse files Browse the repository at this point in the history
For the type lattice, we (now) use the "less specialized or equal" partial
order, leading to the bottom representing the empty set, and the top
representing any type.

This naming is more in line with the generally used conventions, where the top
of the lattice is the full set, and the bottom of the lattice is the empty set.
A typical example is the powerset of a finite set: generally, meet would be the
intersection, and join would be the union.

```
top:  {a,b,c}
     /   |   \
 {a,b} {a,c} {b,c}
   |  X     X  |
   {a} { b } {c}
      \  |  /
bottom: { }
```

This is in line with the examined lattice representations in LLVM:
* lattice for `BitTracker::BitValue` in `Hexagon/BitTracker.h`
* lattice for constant propagation in `HexagonConstPropagation.cpp`
* lattice in `VarLocBasedImpl.cpp`
* lattice for address space inference code in `InferAddressSpaces.cpp`

Reviewed By: silvas, jpienaar

Differential Revision: https://reviews.llvm.org/D110766
  • Loading branch information
Alexandre Rames committed Oct 6, 2021
1 parent 1301a8b commit fd96133
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 49 deletions.
76 changes: 38 additions & 38 deletions mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
Expand Up @@ -397,7 +397,34 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
let hasCanonicalizer = 1;
}

def Shape_JoinOp : Shape_Op<"join",
def Shape_MaxOp : Shape_Op<"max",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks.
If either operand is an error, then an error will be propagated to the
result. If the input types mismatch or the ranks do not match, then the
result is an error.
}];

let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);

let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];

let hasFolder = 1;

let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}

def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let description = [{
Expand All @@ -408,21 +435,21 @@ def Shape_JoinOp : Shape_Op<"join",
contradictory requirements. E.g., using pseudo code

```
shape.join([*], [*]) -> [*]
shape.join([*], [1, ?]) -> [1, ?]
shape.join([1, 2], [1, ?]) -> [1, 2]
shape.join([*], [1, 2]) -> [1, 2]
shape.join([], []) -> []
shape.join([], [*]) -> []
shape.join([], [?, ?]) -> [invalid]
shape.join([1, ?], [2, ?, ?]) -> [invalid]
shape.meet([*], [*]) -> [*]
shape.meet([*], [1, ?]) -> [1, ?]
shape.meet([1, 2], [1, ?]) -> [1, 2]
shape.meet([*], [1, 2]) -> [1, 2]
shape.meet([], []) -> []
shape.meet([], [*]) -> []
shape.meet([], [?, ?]) -> [invalid]
shape.meet([1, ?], [2, ?, ?]) -> [invalid]
```

`shape.join` also allows specifying an optional error string, that may be
`shape.meet` also allows specifying an optional error string, that may be
used to return an error to the user upon mismatch of dimensions.

```mlir
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
%c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];

Expand All @@ -442,33 +469,6 @@ def Shape_JoinOp : Shape_Op<"join",
}];
}

def Shape_MaxOp : Shape_Op<"max",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks.
If either operand is an error, then an error will be propagated to the
result. If the input types mismatch or the ranks do not match, then the
result is an error.
}];

let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);

let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];

let hasFolder = 1;

let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}

def Shape_MinOp : Shape_Op<"min",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Expand Up @@ -1177,18 +1177,18 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
}

//===----------------------------------------------------------------------===//
// JoinOp
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::JoinOp::inferReturnTypes(
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands[0].getType()});
return success();
}

bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
if (l == r)
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/Shape/ops.mlir
Expand Up @@ -65,23 +65,23 @@ func @test_broadcast_extents() -> tensor<?xindex> {
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92] : !shape.shape
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}

func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [2, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
Expand Down Expand Up @@ -243,7 +243,7 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
Expand Down Expand Up @@ -293,31 +293,31 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}

func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}

func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}

0 comments on commit fd96133

Please sign in to comment.