Skip to content

Commit 2ee5671

Browse files
[mlir][tosa] handle unranked tensors in tosa::table::verify (#156321)
Seen when running TOSA PRO-INT conformance tests in our SUT. This leads to verify being called with unranked tensors causing exception/error when trying to call getShape on them. Made some variables const for consistency with other verify functions in same file.
1 parent a434a7a commit 2ee5671

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,11 +2189,13 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
21892189
}
21902190

21912191
LogicalResult tosa::TableOp::verify() {
2192-
TensorType inputType = getInput1().getType();
2193-
TensorType outputType = getOutput().getType();
2192+
const TensorType inputType = getInput1().getType();
2193+
const TensorType outputType = getOutput().getType();
2194+
2195+
if (!inputType.hasRank() || !outputType.hasRank())
2196+
return success();
21942197

2195-
if (inputType.hasRank() && outputType.hasRank() &&
2196-
inputType.getRank() != outputType.getRank())
2198+
if (inputType.getRank() != outputType.getRank())
21972199
return emitOpError()
21982200
<< "expected input tensor rank to equal result tensor rank";
21992201

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: ten
177177

178178
// -----
179179

180+
func.func @test_table_unranked_tensor(%arg0: tensor<*xi8>) -> (tensor<*xi8>) {
181+
%0 = "tosa.const"() <{values = dense<"0x47CE492BAE8FF8AC700D8903ECF3BC45BC865CA9C35C14DBD3A9E4D1B4AEB8B6A1F20F03486D513ABFC212A4E07118ADFEA5D6B736D4510F7685692B88FFA19B0F7414B0B56635237B48E95B048E96A36001B3388971683E82E6BC40C69D6B6218F6576AF384396BC16F1D437174EA1FB5466AD719344BB8E21BE628893039F831BA1A39C30C413D3C6AA60F91F4D70F1F20473DBAC203C66FC02CBB2E9F11FB2352DD5D6A7F85CFA2F7B697489D9738E1B3C91CB4D2A59B0757C39A2C52619290B43F47340806FD6E0F7400C9373DA037E2FE35967B5D025F29D98AD5EE58BF41EB0C4E49EF73ED167BCE66D58596181DF78F8194D258B51807CAB4A4020239"> : tensor<256xi8>}> : () -> tensor<256xi8>
182+
// expected-error@+1 {{'tosa.table' op failed level check: unranked tensor}}
183+
%1 = tosa.table %arg0, %0 : (tensor<*xi8>, tensor<256xi8>) -> tensor<*xi8>
184+
return %1 : tensor<*xi8>
185+
}
186+
187+
// -----
188+
180189
func.func @test_abs_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
181190
// expected-error@+1 {{'tosa.abs' op failed level check: operand rank(shape) <= MAX_RANK}}
182191
%0 = tosa.abs %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>

0 commit comments

Comments
 (0)