diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 11eb0d969d78b..ef9d27f8df0ad 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1012,8 +1012,30 @@ bool checkErrorIfMul(Operation *op) { return true; } +bool checkErrorIfTable(Operation *op) { + auto table = dyn_cast(op); + if (!table) + return true; + + // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513 + const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType()); + const int tableSize = inputElemType.isInteger(8) ? 256 : 513; + + const ShapeAdaptor tableShape(table.getTable().getType()); + if (tableShape.hasStaticShape()) { + const auto numElements = tableShape.getNumElements(); + if (numElements != tableSize) { + op->emitOpError() << "requires table size of " << tableSize << ", got " + << numElements; + return false; + } + } + + return true; +} + LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { - if (!checkErrorIfResize(op) || !checkErrorIfMul(op)) + if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op)) return failure(); return success(); } diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index 0ec46022157d7..25e1aa195c3a0 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -13,8 +13,8 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8 // ----- -func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { - %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8> +func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () { + %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8> return } diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index f7ca0faa8bc9e..65a69be91e0c8 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -113,3 +113,19 @@ func.func @test_mul_non_zero_shift(%arg0: tensor<1x8x8x8xi16>, %arg1: tensor<1x8 %mul = tosa.mul %arg0, %arg1, %shift : (tensor<1x8x8x8xi16>, tensor<1x8x8x8xi16>, tensor<1xi8>) -> tensor<1x8x8x8xi32> return %mul : tensor<1x8x8x8xi32> } + +// ----- +// CHECK-LABEL: test_i16_table_size +func.func @test_i16_table_size(%arg0: tensor<2x64xi16>, %arg1: tensor<256xi16>) -> tensor<2x64xi32> { + // expected-error@+1 {{'tosa.table' op requires table size of 513, got 256}} + %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi16>, tensor<256xi16>) -> tensor<2x64xi32> + return %0 : tensor<2x64xi32> +} + +// ----- +// CHECK-LABEL: test_i8_table_size +func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) -> tensor<2x64xi8> { + // expected-error@+1 {{'tosa.table' op requires table size of 256, got 513}} + %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8> + return %0 : tensor<2x64xi8> +} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 241e603e91c61..7386b1ba9df99 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -497,9 +497,9 @@ func.func @test_mul_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x1x3xi8 // ----- -func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { +func.func @test_table_non_const(%arg0 : tensor<4x5xi8>, %arg1 : tensor<256xi8>) -> () { // expected-error@+1 {{'tosa.table' op expected compile time resolvable constant, but got variable value for operand #1}} - %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<513xi8>) -> tensor<4x5xi8> + %0 = tosa.table %arg0, %arg1 : (tensor<4x5xi8>, tensor<256xi8>) -> tensor<4x5xi8> return }