From fd9613324d305f036d8177469e4f1d0ed521b3a7 Mon Sep 17 00:00:00 2001 From: Alexandre Rames Date: Tue, 5 Oct 2021 10:53:02 -0700 Subject: [PATCH] [MLIR] Rename Shape dialect's `join` to `meet`. For the type lattice, we (now) use the "less specialized or equal" partial order, leading to the bottom representing the empty set, and the top representing any type. This naming is more in line with the generally used conventions, where the top of the lattice is the full set, and the bottom of the lattice is the empty set. A typical example is the powerset of a finite set: generally, meet would be the intersection, and join would be the union. ``` top: {a,b,c} / | \ {a,b} {a,c} {b,c} | X X | {a} { b } {c} \ | / bottom: { } ``` This is in line with the examined lattice representations in LLVM: * lattice for `BitTracker::BitValue` in `Hexagon/BitTracker.h` * lattice for constant propagation in `HexagonConstPropagation.cpp` * lattice in `VarLocBasedImpl.cpp` * lattice for address space inference code in `InferAddressSpaces.cpp` Reviewed By: silvas, jpienaar Differential Revision: https://reviews.llvm.org/D110766 --- .../include/mlir/Dialect/Shape/IR/ShapeOps.td | 76 +++++++++---------- mlir/lib/Dialect/Shape/IR/Shape.cpp | 6 +- mlir/test/Dialect/Shape/ops.mlir | 16 ++-- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a9c82bf41849c..606f906290e3a 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -397,7 +397,34 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> { let hasCanonicalizer = 1; } -def Shape_JoinOp : Shape_Op<"join", +def Shape_MaxOp : Shape_Op<"max", + [Commutative, NoSideEffect, + DeclareOpInterfaceMethods]> { + let summary = "Elementwise maximum"; + let description = [{ + Computes the elementwise maximum of two sizes or shapes with equal ranks. + If either operand is an error, then an error will be propagated to the + result. If the input types mismatch or the ranks do not match, then the + result is an error. + }]; + + let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); + let results = (outs Shape_ShapeOrSizeType:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + + let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; +} + +def Shape_MeetOp : Shape_Op<"meet", [Commutative, DeclareOpInterfaceMethods]> { let summary = "Returns the least general shape.shape of its operands"; let description = [{ @@ -408,21 +435,21 @@ def Shape_JoinOp : Shape_Op<"join", contradictory requirements. E.g., using pseudo code ``` - shape.join([*], [*]) -> [*] - shape.join([*], [1, ?]) -> [1, ?] - shape.join([1, 2], [1, ?]) -> [1, 2] - shape.join([*], [1, 2]) -> [1, 2] - shape.join([], []) -> [] - shape.join([], [*]) -> [] - shape.join([], [?, ?]) -> [invalid] - shape.join([1, ?], [2, ?, ?]) -> [invalid] + shape.meet([*], [*]) -> [*] + shape.meet([*], [1, ?]) -> [1, ?] + shape.meet([1, 2], [1, ?]) -> [1, 2] + shape.meet([*], [1, 2]) -> [1, 2] + shape.meet([], []) -> [] + shape.meet([], [*]) -> [] + shape.meet([], [?, ?]) -> [invalid] + shape.meet([1, ?], [2, ?, ?]) -> [invalid] ``` - `shape.join` also allows specifying an optional error string, that may be + `shape.meet` also allows specifying an optional error string, that may be used to return an error to the user upon mismatch of dimensions. ```mlir - %c = shape.join %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape + %c = shape.meet %a, %b, error="" : !shape.shape, !shape.shape -> !shape.shape ``` }]; @@ -442,33 +469,6 @@ def Shape_JoinOp : Shape_Op<"join", }]; } -def Shape_MaxOp : Shape_Op<"max", - [Commutative, NoSideEffect, - DeclareOpInterfaceMethods]> { - let summary = "Elementwise maximum"; - let description = [{ - Computes the elementwise maximum of two sizes or shapes with equal ranks. - If either operand is an error, then an error will be propagated to the - result. If the input types mismatch or the ranks do not match, then the - result is an error. - }]; - - let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs); - let results = (outs Shape_ShapeOrSizeType:$result); - - let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) - }]; - - let hasFolder = 1; - - let extraClassDeclaration = [{ - // Returns when two result types are compatible for this op; method used by - // InferTypeOpInterface - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); - }]; -} - def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect, DeclareOpInterfaceMethods]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index d3229ea319b16..59a979c74a7d5 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1177,10 +1177,10 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// -// JoinOp +// MeetOp //===----------------------------------------------------------------------===// -LogicalResult mlir::shape::JoinOp::inferReturnTypes( +LogicalResult mlir::shape::MeetOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { @@ -1188,7 +1188,7 @@ LogicalResult mlir::shape::JoinOp::inferReturnTypes( return success(); } -bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { +bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != 1 || r.size() != 1) return false; if (l == r) diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index 24e7c2a6a2559..e7b501e8e2352 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -65,7 +65,7 @@ func @test_broadcast_extents() -> tensor { func @test_shape_any_fixed() { %0 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.const_shape [4, 57, 92] : !shape.shape - %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } @@ -73,7 +73,7 @@ func @test_shape_any_fixed() { func @test_shape_any_unknown() { %0 = shape.const_shape [4, -1, 92] : !shape.shape %1 = shape.const_shape [-1, 57, 92] : !shape.shape - %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } @@ -81,7 +81,7 @@ func @test_shape_any_unknown() { func @test_shape_any_fixed_mismatch() { %0 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.const_shape [2, 57, 92] : !shape.shape - %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape return } @@ -243,7 +243,7 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size { func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape { %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape %1 = shape.shape_of %b : !shape.value_shape -> !shape.shape - %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape return %2 : !shape.shape } func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape { @@ -293,7 +293,7 @@ func @is_broadcastable_on_shapes(%a : !shape.shape, func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape { %0 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape - %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } @@ -301,7 +301,7 @@ func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape { func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape { %0 = shape.const_shape [4, 57, 92] : !shape.shape %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape - %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" : !shape.shape, !shape.shape -> !shape.shape return %2 : !shape.shape } @@ -309,7 +309,7 @@ func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape { func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size { %0 = shape.const_size 5 %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size - %2 = shape.join %0, %1, error="exceeded element-wise upper bound" : + %2 = shape.meet %0, %1, error="exceeded element-wise upper bound" : !shape.size, !shape.size -> !shape.size return %2 : !shape.size } @@ -317,7 +317,7 @@ func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size { func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size { %0 = shape.const_size 9 %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size - %2 = shape.join %0, %1, error="lower bound element-wise exceeded" : + %2 = shape.meet %0, %1, error="lower bound element-wise exceeded" : !shape.size, !shape.size -> !shape.size return %2 : !shape.size }