Skip to content

Commit

Permalink
[mlir] tosa.reshape - Add InferTensorType interface
Browse files Browse the repository at this point in the history
When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D148498
  • Loading branch information
AviadCo committed Apr 22, 2023
1 parent 6fb4c9f commit 2dd396c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1441,8 +1441,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
// Operator: concat
//===----------------------------------------------------------------------===//
def Tosa_ConcatOp : Tosa_Op<"concat", [
InferTensorType,
Pure]> {
InferTensorType, Pure]> {
let summary = "Concatenates tensors along one dimension.";

let description = [{
Expand Down Expand Up @@ -1503,9 +1502,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [
// Operator: reshape
//===----------------------------------------------------------------------===//
def Tosa_ReshapeOp: Tosa_Op<"reshape", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
Pure]> {
InferTensorType, Pure]> {
let summary = "Reshape operator";

let description = [{
Expand All @@ -1526,6 +1523,12 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
let results = (outs
Tosa_RankedTensor:$output
);

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}

//===----------------------------------------------------------------------===//
Expand Down
13 changes: 11 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,19 +674,27 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
return success();
}

bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != r.size() || l.size() != 1)
return false;
return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ReshapeOpAdaptor adaptor(operands, attributes);
ShapeAdaptor inputShape = operands.getShape(0);
Type inputType = getElementTypeOrSelf(operands.getType()[0]);
llvm::SmallVector<int64_t> newShapeValue =
convertToMlirShape(adaptor.getNewShape());

// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
inferredReturnShapes.push_back(
ShapedTypeComponents(newShapeValue, inputType));
return success();
}

Expand All @@ -707,7 +715,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
val = numElements / staticMul;
}

inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
inferredReturnShapes.push_back(
ShapedTypeComponents(newShapeValue, inputType));
return success();
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,11 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
%0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
}

// -----

func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
return
}

0 comments on commit 2dd396c

Please sign in to comment.