diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index fa4bc120e9c1e..a09b06fec3928 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -662,6 +662,23 @@ static void printShapeToDiagnostic(InFlightDiagnostic &diag, llvm::interleaveComma(shape, diag, printDim); } +static LogicalResult +verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType, + ArrayRef expectedShape, + StringRef outputName = "output") { + assert(outputType.hasRank() && "expected output type to be ranked"); + + if (succeeded(verifyCompatibleShape(outputType.getShape(), expectedShape))) + return success(); + + InFlightDiagnostic diag = op->emitOpError("expected "); + diag << outputName << " shape "; + printShapeToDiagnostic(diag, outputType.getShape()); + diag << " to be compatible with inferred shape "; + printShapeToDiagnostic(diag, expectedShape); + return diag; +} + LogicalResult verifyConvOutputSize( Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, @@ -2542,6 +2559,10 @@ LogicalResult tosa::MulOp::verify() { const bool aHasRank = aType.hasRank(); const bool bHasRank = bType.hasRank(); + + bool hasExpectedOutputShape = false; + SmallVector expectedOutputShape; + if (aHasRank && bHasRank) { const int64_t aRank = aType.getRank(); const int64_t bRank = bType.getRank(); @@ -2550,12 +2571,12 @@ LogicalResult tosa::MulOp::verify() { << aRank << " and " << bRank; // check for broadcast compatible shapes - SmallVector resultShape; if (!mlir::OpTrait::util::getBroadcastedShape( - aType.getShape(), bType.getShape(), resultShape)) + aType.getShape(), bType.getShape(), expectedOutputShape)) return emitOpError("a and b operands don't have broadcast-compatible " "shapes, got ") << aType << " and " << bType; + hasExpectedOutputShape = true; } ShapedType resultType = cast(output.getType()); @@ -2570,6 +2591,11 @@ LogicalResult tosa::MulOp::verify() { return emitOpError("result type has different rank than b, got ") << resultRank << " vs " << bType.getRank(); + if (hasExpectedOutputShape && + failed(verifyOutputShapeCompatibleWithExpected(getOperation(), resultType, + expectedOutputShape))) + return failure(); + return success(); } @@ -4846,12 +4872,7 @@ LogicalResult TransposeConv2DOp::verify() { } LogicalResult RescaleOp::verify() { - auto inputType = llvm::dyn_cast(getInput().getType()); - if (!inputType) { - emitOpError("expect shaped tensor for input, got ") << getInput().getType(); - return failure(); - } - + const auto inputType = llvm::cast(getInput().getType()); auto inputElementType = getStorageElementTypeOrSelf(inputType.getElementType()); if (!mlir::isa(inputElementType)) { @@ -4860,13 +4881,7 @@ LogicalResult RescaleOp::verify() { return failure(); } - auto outputType = llvm::dyn_cast(getOutput().getType()); - if (!outputType) { - emitOpError("expect shaped tensor for output, got ") - << getOutput().getType(); - return failure(); - } - + const auto outputType = llvm::cast(getOutput().getType()); auto outputElementType = getStorageElementTypeOrSelf(outputType.getElementType()); if (!mlir::isa(outputElementType)) { @@ -4891,19 +4906,7 @@ LogicalResult RescaleOp::verify() { if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed()) return failure(); - auto multiplierType = llvm::dyn_cast(getMultiplier().getType()); - if (!multiplierType) { - emitOpError("expect shaped tensor for multiplier, got ") - << getMultiplier().getType(); - return failure(); - } - - auto shiftType = llvm::dyn_cast(getShift().getType()); - if (!shiftType) { - emitOpError("expect shaped tensor for shift, got ") << getShift().getType(); - return failure(); - } - + const auto multiplierType = llvm::cast(getMultiplier().getType()); // multiplier element type must be i32 for scale32 = true if (getScale32() && !multiplierType.getElementType().isInteger(32)) { emitOpError("expect i32 element type for multiplier for scale32=true, got ") @@ -4936,28 +4939,34 @@ LogicalResult RescaleOp::verify() { numChannels = inputType.getDimSize(inputType.getRank() - 1); } - if (!multiplierType.hasRank()) - return success(); - - ArrayRef multiplierShape = multiplierType.getShape(); - // multiplier input has rank 1 by dialect definition - if (multiplierShape[0] != ShapedType::kDynamic && - multiplierShape[0] != numChannels) { - emitOpError("expect shape of { ") - << numChannels << " } for multiplier input, got { " - << multiplierShape[0] << " }"; - return failure(); + if (outputType.hasRank()) { + if (failed(verifyOutputShapeCompatibleWithExpected( + getOperation(), outputType, inputType.getShape()))) + return failure(); } - if (!shiftType.hasRank()) - return success(); + if (multiplierType.hasRank()) { + ArrayRef multiplierShape = multiplierType.getShape(); + // multiplier input has rank 1 by dialect definition + if (multiplierShape[0] != ShapedType::kDynamic && + multiplierShape[0] != numChannels) { + emitOpError("expect shape of { ") + << numChannels << " } for multiplier input, got { " + << multiplierShape[0] << " }"; + return failure(); + } + } - ArrayRef shiftShape = shiftType.getShape(); - // shift input has rank 1 by dialect definition - if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) { - emitOpError("expect shape of { ") - << numChannels << " } for shift input, got { " << shiftShape[0] << " }"; - return failure(); + const auto shiftType = llvm::cast(getShift().getType()); + if (shiftType.hasRank()) { + ArrayRef shiftShape = shiftType.getShape(); + // shift input has rank 1 by dialect definition + if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) { + emitOpError("expect shape of { ") + << numChannels << " } for shift input, got { " << shiftShape[0] + << " }"; + return failure(); + } } return success(); diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 1572df5357877..a5976b79adb65 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -423,6 +423,27 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor) -> tensor // ----- +func.func @test_rescale_invalid_static_output_shape(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x4xi8> { + %multiplier = "tosa.const"() <{values = dense<42> : tensor<1xi16>}> : () -> tensor<1xi16> + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.rescale' op expected output shape 13, 21, 4 to be compatible with inferred shape 13, 21, 3}} + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x4xi8> + return %0 : tensor<13x21x4xi8> +} + +// ----- + +func.func @test_mul_invalid_static_output_shape(%arg0: tensor, %arg1: tensor) -> tensor { + %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op expected output shape ?, 21, 2 to be compatible with inferred shape ?, 21, 3}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor, tensor, tensor<1xi8>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @test_gather_invalid_indices_N func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> { // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}