diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index f3413ac6d0d29..546201484e8cb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1879,24 +1879,14 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents( } LogicalResult MatMulOp::verify() { - auto aType = llvm::dyn_cast(getA().getType()); - auto bType = llvm::dyn_cast(getB().getType()); + const ShapeAdaptor aShape(getA().getType()); + const ShapeAdaptor bShape(getB().getType()); + const Type aElementType = aShape.getElementType(); + const Type bElementType = bShape.getElementType(); - // Must be shaped tensor types - if (!aType) - return emitOpError("expect a shaped tensor for input a, got ") - << getA().getType(); - - if (!bType) - return emitOpError("expect a shaped tensor for input b, got ") - << getB().getType(); - - auto aElementType = aType.getElementType(); - auto bElementType = bType.getElementType(); - - auto aQuantizedEType = + const auto aQuantizedEType = llvm::dyn_cast(aElementType); - auto bQuantizedEType = + const auto bQuantizedEType = llvm::dyn_cast(bElementType); if (aQuantizedEType || bQuantizedEType) { @@ -1915,21 +1905,19 @@ LogicalResult MatMulOp::verify() { } // check a_zp and b_zp - auto aEType = getStorageElementTypeOrSelf(aType); + auto aEType = getStorageElementTypeOrSelf(aElementType); auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType()); - if (aEType != aZpEType) { + if (aEType != aZpEType) return emitOpError("expect input a and a_zp have the same " "element type, got ") << aEType << " and " << aZpEType; - } - auto bEType = getStorageElementTypeOrSelf(bType); - auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType()); - if (bEType != bZpEType) { + const Type bEType = getStorageElementTypeOrSelf(bElementType); + const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType()); + if (bEType != bZpEType) return emitOpError("expect input b and b_zp have the same " "element type, got ") << bEType << " and " << bZpEType; - } FailureOr maybeAZp = getAZeroPoint(); if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed()) @@ -1939,6 +1927,45 @@ LogicalResult MatMulOp::verify() { if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed()) return failure(); + // Verify input/output shapes + int64_t N = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + + if (aShape.hasRank()) { + N = aShape.getDimSize(0); + H = aShape.getDimSize(1); + C = aShape.getDimSize(2); + } + + if (bShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b", + "batch")) || + failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b", + "channels"))) + return failure(); + W = bShape.getDimSize(2); + } + + const SmallVector expectedOutputShape = {N, H, W}; + const auto outputType = cast(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + return success(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index b7334fb4246a7..e18cc40b78e8b 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1698,46 +1698,6 @@ func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tenso return %0 : tensor<1xi16> } -// ----- -// CHECK-LABEL: test_matmul_a_zp_same_element_type -func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_b_zp_same_element_type -func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> -// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_a_zp_non_zero -func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - -// ----- -// CHECK-LABEL: test_matmul_b_zp_non_zero -func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { -%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> -%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32> -// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}} -%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> - return %0 : tensor<1x14x28xf32> -} - // ----- // CHECK-LABEL: test_negate_same_element_type diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 80d5bca039909..931659fb435cc 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1147,6 +1147,143 @@ func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32 // ----- +func.func @test_matmul_output_batch_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x3x6xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32> + return %0 : tensor<2x3x6xf32> +} + +// ----- + +func.func @test_matmul_output_channel_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x7x6xf32>) -> tensor<2x3x6xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x7x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x3x6xf32> + return %0 : tensor<2x3x6xf32> +} + +// ----- + +func.func @test_matmul_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4x6xf32>) -> tensor<2x5x6xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected output shape 2, 5, 6 to be compatible with expected output shape 2, 3, 6}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x6xf32> + return %0 : tensor<2x5x6xf32> +} + +// ----- + + +func.func @test_matmul_dynamic_batch_mismatch(%arg0: tensor<2x?x4xf32>, %arg1: tensor<5x4x6xf32>) -> tensor<2x?x6xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected batch of b to match size 2, got 5}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x?x4xf32>, tensor<5x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x?x6xf32> + return %0 : tensor<2x?x6xf32> +} + +// ----- + +func.func @test_matmul_dynamic_channel_mismatch(%arg0: tensor, %arg1: tensor) -> tensor { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected channels of b to match size 4, got 7}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor, tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + return %0 : tensor +} + +// ----- + +func.func @test_matmul_dynamic_output_shape_mismatch(%arg0: tensor, %arg1: tensor<2x4x6xf32>) -> tensor<5x3x6xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected output shape 5, 3, 6 to be compatible with expected output shape 2, 3, 6}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor, tensor<2x4x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x3x6xf32> + return %0 : tensor<5x3x6xf32> +} + +// ----- + + +func.func @test_matmul_unranked_b_output_shape_mismatch(%arg0: tensor<2x3x4xf32>, %arg1: tensor<*xf32>) -> tensor<2x5x?xf32> { + %azp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expected output shape 2, 5, ? to be compatible with expected output shape 2, 3, ?}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4xf32>, tensor<*xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x5x?xf32> + return %0 : tensor<2x5x?xf32> +} + +// ----- + + +func.func @test_matmul_quantized_mixed_operands(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x4x6xf32>) -> tensor<2x3x6xi32> { + %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %bzp0 = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + // expected-error@+1 {{'tosa.matmul' op expect operands to be both quantized or both not quantized, got '!quant.uniform' and 'f32'}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform>, tensor<2x4x6xf32>, tensor<1xi8>, tensor<1xf32>) -> tensor<2x3x6xi32> + return %0 : tensor<2x3x6xi32> +} + +// ----- + +func.func @test_matmul_quantized_width_mismatch(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x4x6x!quant.uniform>) -> tensor<2x3x6xi32> { + %azp0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %bzp0 = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16> + // expected-error@+1 {{'tosa.matmul' op expect quantized operands to have same widths, got 8 and 16}} + %0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<2x3x4x!quant.uniform>, tensor<2x4x6x!quant.uniform>, tensor<1xi8>, tensor<1xi16>) -> tensor<2x3x6xi32> + return %0 : tensor<2x3x6xi32> +} + +// ----- + +// CHECK-LABEL: test_matmul_a_zp_same_element_type +func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { +%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> +%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> +// expected-error@+1 {{'tosa.matmul' op expect input a and a_zp have the same element type, got 'f32' and 'f16'}} +%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf16>, tensor<1xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- + +// CHECK-LABEL: test_matmul_b_zp_same_element_type +func.func @test_matmul_b_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { +%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> +%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16> +// expected-error@+1 {{'tosa.matmul' op expect input b and b_zp have the same element type, got 'f32' and 'f16'}} +%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf16>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- + +// CHECK-LABEL: test_matmul_a_zp_non_zero +func.func @test_matmul_a_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { +%azp0 = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32> +%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> +// expected-error@+1 {{'tosa.matmul' op a zero point must be zero for non-int8 integer types}} +%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- + +// CHECK-LABEL: test_matmul_b_zp_non_zero +func.func @test_matmul_b_zp_non_zero(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { +%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> +%bzp0 = "tosa.const"() <{values = dense<-1.0> : tensor<1xf32>}> : () -> tensor<1xf32> +// expected-error@+1 {{'tosa.matmul' op b zero point must be zero for non-int8 integer types}} +%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf32>, tensor<1x19x28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> +} + +// ----- + func.func @test_matmul_t_block_scaled_data_mismatch(%arg0: tensor<4x8x32xf8E4M3FN>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32xf8E5M2>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> { // expected-error@+1 {{'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'}} %0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size : i32} : (tensor<4x8x32xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32xf8E5M2>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>