Skip to content

Commit

Permalink
[mlir][tosa] Add verifier for tosa.tile, fix shape inference crash (#…
Browse files Browse the repository at this point in the history
…70972)

This patch adds an verifier to `tosa.tile` which checks input/output
ranks and the length of the `multiples` array. The patch also fixes a
crash in the shape inference when an invalid `multiples` array is
supplied.

Fix #70415
  • Loading branch information
ubfx committed Nov 2, 2023
1 parent 9e0a5be commit b6d67af
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 1 deletion.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
);

let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
21 changes: 20 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
outputShape.resize(multiples.size(), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
} else if (inputShape.getRank() != multiples.size())
return failure();

// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
Expand All @@ -878,6 +879,24 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
return success();
}

LogicalResult tosa::TileOp::verify() {
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
auto multiples = getMultiples();

if (inputType.hasRank()) {
if (inputType.getRank() != multiples.size())
return emitOpError("expect 'multiples' array to have length ")
<< inputType.getRank() << " but got " << multiples.size() << ".";
if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
return emitOpError("expect same input and output tensor rank.");
} else if (outputType.hasRank() && outputType.getRank() != multiples.size())
return emitOpError("expect 'multiples' array to have length ")
<< outputType.getRank() << " but got " << multiples.size() << ".";

return success();
}

bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != r.size() || l.size() != 1)
return false;
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,13 @@ func.func nested @fold_reduce_rank_zero() {
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
return
}

// -----

// CHECK-LABEL: @fold_tile_rank_zero
func.func nested @fold_tile_rank_zero() -> tensor<i32> {
// CHECK-NOT: tosa.tile
%0 = tensor.empty() : tensor<i32>
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,12 @@ func.func @test_slice_invalid_size() {
%1 = tosa.slice %0 {size = array<i64: 1>, start = array<i64: 1, 1, 1>} : (tensor<4x31x31xf32>) -> tensor<*xf32>
return
}

// -----

func.func @test_tile_invalid_multiples() {
%0 = tensor.empty() : tensor<4x31x31xf32>
// expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}}
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
return
}

0 comments on commit b6d67af

Please sign in to comment.