Skip to content

Commit

Permalink
TOSA-to-Linalg lowering for element-wise ops
Browse files Browse the repository at this point in the history
- Wrote complete documentation for the `Broadcastable` op trait. This is mostly meant as a thorough description of its previous behavior, with the exception of minor feature updates.

- Restricted legality criteria for a `Broadcastable` op in order to simplify current and future lowering passes and increase efficiency of code generated by those passes. New restriction are: 1) A dynamic dimension in an inferred result is not compatible with a static dimension in the actual result. 2) Broadcast semantics are restricted to input operands and not supported between inferred and actual result shapes.

- Implemented TOSA-to-Linalg lowering support for unary, binary, tertiary element-wise ops. This support is complete for all legal cases described in the `Broadcastable` trait documentation.

- Added unit tests for `tosa.abs`, `tosa.add`, and `tosa.select` as examples of unary, binary, and tertiary ops.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D153291
  • Loading branch information
rafaelubalmw authored and eric-k256 committed Jul 21, 2023
1 parent cbf2a6c commit b2d76a0
Show file tree
Hide file tree
Showing 8 changed files with 969 additions and 285 deletions.
197 changes: 197 additions & 0 deletions mlir/docs/Traits/Broadcastable.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# The `Broadcastable` Trait

[TOC]

## Description

The `Broadcastable` trait enforces the following properties on an operation:

- The operation has at least one input operand.

- The operation has exactly one result.

- All input operands and result are of type `tensor` or `vector`.

- A shape inference mechanism is able to compute the result shape solely based on input operand shapes.

- Input operands have broadcast-compatible shapes, according to the verification rules presented below.

- The operation's result shape is compatible with —though not necessarily identical to— the shape inferred from its input operands, according to the verification rules presented below.


## Dimension inference

Given an operation with two input operands, the size of dimension `i` of its result can be inferred from dimension `i` of the operands according to the table below. Here, `dim0` and `dim1` represent dimension `i` of the input operands in an interchangeable order, while `inferredDim` represents the inferred size for dimension `i` of the operation result. Dimensions are classified in three categories: dynamic ("?"), static equal to 1 ("1"), and static greater than 1 (">1").


| `dim0` | `dim1` | `inferredDim` | Notes |
| -------- | -------- | ------------- | ----- |
| ? | ? | ? | If `RuntimeSize(dim0)` is 1, dimension `dim0` is broadcast to `RuntimeSize(dim1)`. If `RuntimeSize(dim1)` is 1, dimension `dim1` is broadcast to `RuntimeSize(dim0)`. The operation produces undefined behavior if both runtime sizes are greater than 1 and not equal. |
| ? | 1 | ? | Dimension `dim1` is broadcast to `RuntimeSize(dim0)`. |
| ? | >1 | `dim1` | If `RuntimeSize(dim0)` is 1, `dim0` is broadcast to `dim1`. The operation produces undefined behavior if `RuntimeSize(dim0)` is greater than 1 and not equal to `dim1`. |
| 1 | 1 | 1 | |
| 1 | >1 | `dim1` | Dimension `dim0` is broadcast to `dim1`. |
| >1 | >1 | `dim0` | The operation verifier produces a compile-time error if `dim0` != `dim1`. |


The following pseudo-function is a formal representation of the dimension inference process:

```python
InferDim(dim0, dim1):
switch (dim0, dim1):
case (?, ?):
case (?, 1):
case (1, 1):
case (>1, ?):
case (>1, 1):
return dim0
case (?, >1):
case (1, ?):
case (1, >1):
return dim1
case (>1, >1):
ERROR_IF(dim0 != dim1)
return dim0
```

## Shape inference

The shape inference process begins by correcting rank differences in input operands. A shape is expanded by adding additional dimensions of size 1 on its left until the desired rank is reached, as shown here:

```python
ExpandRank(shape, rank):
while len(shape) < rank:
shape.prepend(1)
```
Given the shapes of two ranked input operands, the result's shape is inferred by equalizing input ranks and inferring individual dimensions, as shown here:

```python
InferShape(shape0, shape1):

# Equalize ranks
rank = max(GetRank(shape0), GetRank(shape1))
ExpandRank(shape0, rank)
ExpandRank(shape1, rank)

# Infer shape
inferredShape = []
for (dim0, dim1) in zip(shape0, shape1):
inferredDim = InferDim(dim0, dim1)
inferredShape.append(inferredDim)
return inferredShape
```

The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:

```python
InferResultShape(op):

# Filter ranked operands
rankedOperands = filter(op.operands, IsRanked)
if len(rankedOperands) == 0:
return None

# Infer result shape
inferredShape = GetShape(rankedOperands[0])
for operand in rankedOperands[1:]:
inferredShape = InferShape(inferredShape, GetShape(operand))
return inferredShape
```

## Verification

The legality of an operation with the `Broadcastable` trait is verified by first running the shape inference process. If a failure occurs during shape inference, it is concluded that input operands are not broadcast-compatible, and verification fails. If shape inference succeeds, verification continues.

If either the result is unranked or all input operands are unranked, no further verification steps are needed, and the process ends here successfully. If, on the contrary, both the result and at least one input operand are ranked, verification continues by checking for a matching rank between the previously inferred shape and the result.

Once a rank match is guaranteed, each dimension of the inferred shape is compared with the corresponding dimension of the actual result shape according to the following table table:


| `inferredDim` | `actualDim` | Verification outcome |
| ------------- | ----------- | -------------------- |
| ? | ? | **OK** |
| ? | static | **Error** <br> An inferred dimension being dynamic indicates that its size cannot be inferred at compile time from its input operands. The presence of a static dimension in the actual result is counterintuitive and is therefore not allowed. |
| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |


The full verification process can be formally specified as follows:

```python
Verify(op):

# Run shape inference
inferredShape = InferResultShape(op.operands)

# Done if result is unranked or all operands are unranked
if not IsRanked(op.result) or inferredShape is None:
return

# Rank must match
actualShape = GetShape(op.result):
ERROR_IF(len(inferredShape) != len(actualShape))

# Verify
for (inferredDim, actualDim) in zip(inferredShape, actualShape):
ERROR_IF(IsDynamic(inferredDim) and IsStatic(actualDim))
ERROR_IF(IsStatic(actualDim) and inferredDim != actualDim)
```
## Examples

The following are correct uses of broadcastable ops:

```mlir
// Exact match of static sizes.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<1x2xi32) -> tensor<1x2xi32>
// Dynamic sizes match. The programmer must guarantee that the runtime sizes of
// %arg0 and %arg1 are equal at runtime.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32) -> tensor<?xi32>
// The shape of %arg0 is broadcast from tensor<1xi32> to tensor<4xi32>.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<4xi32) -> tensor<4xi32>
// The shape of %result is inferred as tensor<4xi32>, while the actual result
// type is tensor<?xi32>. The inferred shape is compatible with the actual shape.
%result = "test.broadcastable"(%arg0) : (tensor<4xi32) -> tensor<?xi32>
// The shape of %arg0 is first expanded to tensor<1x1x4xi32> and then broadcast
// to tensor<2x3x4xi32>.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x3x4xi32) -> tensor<2x3x4xi32>
// Input and results tensors have different element types (i1, i32, i64). The
// 'Broadcastable' trait has no restrictions on element types.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi1>, tensor<2xi32) -> tensor<2xi64>
// No result shape verification is needed when the result is unranked.
%result = "test.broadcastable"(%arg0) : (tensor<2xi32>) -> tensor<*xi32>
// No result shape verification needed when all inputs are unranked.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
```


The following are incorrect uses of broadcastable ops:

```mlir
// Dimension 0 of input operands is static but not equal.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32) -> tensor<?xi32>
// The inferred result shape is tensor<3xi32>, but the actual result shape is
// tensor<1x3xi32>. Inferred and actual shapes differ in rank.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32) -> tensor<1x3xi32>
// The inferred result shape is tensor<?xi32>, but the actual shape is
// tensor<4xi32>. The inferred shape is not compatible with the actual shape.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32) -> tensor<4xi32>
// The inferred result shape is tensor<2xi32>, but the actual result shape is
// tensor<4xi32>, which is not compatible.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32) -> tensor<4xi32>
// The inferred result shape is tensor<1xi32>, but the actual result shape is
// tensor<4xi32>. Broadcast semantics are not applicable for results.
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32) -> tensor<4xi32>
```
11 changes: 1 addition & 10 deletions mlir/docs/Traits.md → mlir/docs/Traits/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,16 +241,7 @@ that has the trait AutomaticAllocationScope.

This trait adds the property that the operation is known to have
[broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
operands and its result types' shape is the broadcast compatible with the shape
of the broadcasted operands. Specifically, starting from the most varying
dimension, each dimension pair of the two operands' shapes should either be the
same or one of them is one. Also, the result shape should have the corresponding
dimension equal to the larger one, if known. Shapes are checked partially if
ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
broadcast-compatible.

This trait requires that the operands are either vector or tensor types.
operands and that its result type is compatible with the inferred broadcast shape. See [The `Broadcastable` Trait](Traits/Broadcastable.md) for details.

### Commutative

Expand Down
12 changes: 3 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,12 @@ class Tosa_Op<string mnemonic, list<Trait> traits = []> :
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
}

class Tosa_ElemWiseUnaryOp<string mnemonic, list<Trait> traits = []> :
class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure, SameOperandsAndResultElementType])> {
}

class Tosa_ElemWiseBinaryOp<string mnemonic, list<Trait> traits = []> :
Tosa_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> {
ResultsBroadcastableShape,
Pure])> {
}

#endif // TOSA_OP_BASE

0 comments on commit b2d76a0

Please sign in to comment.