diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3cafb199d2db3..63e99814bf17b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1951,11 +1951,13 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents( } LogicalResult tosa::TableOp::verify() { - TensorType inputType = getInput1().getType(); - TensorType outputType = getOutput().getType(); + const TensorType inputType = getInput1().getType(); + const TensorType outputType = getOutput().getType(); + + if (!inputType.hasRank() || !outputType.hasRank()) + return success(); - if (inputType.hasRank() && outputType.hasRank() && - inputType.getRank() != outputType.getRank()) + if (inputType.getRank() != outputType.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index cbe0056bafe22..9abe2daa10830 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -177,6 +177,15 @@ func.func @test_table_rank_invalid(%arg0: tensor<1x1x1x1x1x1x64xi16>, %arg1: ten // ----- +func.func @test_table_unranked_tensor(%arg0: tensor<*xi8>) -> (tensor<*xi8>) { + %0 = "tosa.const"() <{values = dense<"0x47CE492BAE8FF8AC700D8903ECF3BC45BC865CA9C35C14DBD3A9E4D1B4AEB8B6A1F20F03486D513ABFC212A4E07118ADFEA5D6B736D4510F7685692B88FFA19B0F7414B0B56635237B48E95B048E96A36001B3388971683E82E6BC40C69D6B6218F6576AF384396BC16F1D437174EA1FB5466AD719344BB8E21BE628893039F831BA1A39C30C413D3C6AA60F91F4D70F1F20473DBAC203C66FC02CBB2E9F11FB2352DD5D6A7F85CFA2F7B697489D9738E1B3C91CB4D2A59B0757C39A2C52619290B43F47340806FD6E0F7400C9373DA037E2FE35967B5D025F29D98AD5EE58BF41EB0C4E49EF73ED167BCE66D58596181DF78F8194D258B51807CAB4A4020239"> : tensor<256xi8>}> : () -> tensor<256xi8> + // expected-error@+1 {{'tosa.table' op failed level check: unranked tensor}} + %1 = tosa.table %arg0, %0 : (tensor<*xi8>, tensor<256xi8>) -> tensor<*xi8> + return %1 : tensor<*xi8> +} + +// ----- + func.func @test_abs_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { // expected-error@+1 {{'tosa.abs' op failed level check: operand rank(shape) <= MAX_RANK}} %0 = tosa.abs %arg0 : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>