diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index e9fdcbdc15837..4fc7ce81d9821 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { constCheckers.emplace_back(checkConstantOperandNegate); } - bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) { - op->emitOpError() << "failed level check: " << checkDesc; - return false; - } - return true; + LogicalResult levelCheckKernel(Operation *op, int32_t v, + const StringRef checkDesc) { + if (v > tosaLevel.MAX_KERNEL) + return op->emitOpError() << "failed level check: " << checkDesc; + return success(); } - bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) { - op->emitOpError() << "failed level check: " << checkDesc; - return false; - } - return true; + LogicalResult levelCheckStride(Operation *op, int32_t v, + const StringRef checkDesc) { + if (v > tosaLevel.MAX_STRIDE) + return op->emitOpError() << "failed level check: " << checkDesc; + return success(); } - bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) { - op->emitOpError() << "failed level check: " << checkDesc; - return false; - } - return true; + LogicalResult levelCheckScale(Operation *op, int32_t v, + const StringRef checkDesc) { + if (v > tosaLevel.MAX_SCALE) + return op->emitOpError() << "failed level check: " << checkDesc; + return success(); } - bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) { - op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " - << checkDesc; - return false; - } - return true; + LogicalResult levelCheckListSize(Operation *op, int32_t v, + const StringRef checkDesc) { + if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + return op->emitOpError() + << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; + return success(); } // Perform the Level Rank check on the tensor type. - bool levelCheckRank(Operation *op, const Type typeToCheck, - const StringRef operandOrResult, int32_t highest_rank) { + LogicalResult levelCheckRank(Operation *op, const Type typeToCheck, + const StringRef operandOrResult, + int32_t highest_rank) { if (ShapedType type = dyn_cast(typeToCheck)) { - if (!type.hasRank()) { - op->emitOpError() << "failed level check: unranked tensor"; - return false; - } - if (type.getRank() > highest_rank) { - op->emitOpError() << "failed level check: " << operandOrResult - << " rank(shape) <= MAX_RANK"; - return false; - } + if (!type.hasRank()) + return op->emitOpError() << "failed level check: unranked tensor"; + if (type.getRank() > highest_rank) + return op->emitOpError() << "failed level check: " << operandOrResult + << " rank(shape) <= MAX_RANK"; } - return true; + return success(); } // Perform the Level Rank check on the tensor value. - bool levelCheckRank(Operation *op, const Value &v, - const StringRef operandOrResult, int32_t highest_rank) { + LogicalResult levelCheckRank(Operation *op, const Value &v, + const StringRef operandOrResult, + int32_t highest_rank) { return levelCheckRank(op, v.getType(), operandOrResult, highest_rank); } // Perform the Level tensor size check on the tensor type. - bool levelCheckSize(Operation *op, const Type &typeToCheck, - const StringRef operandOrResult); + LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck, + const StringRef operandOrResult); // Perform the Level tensor size check on the tensor value. - bool levelCheckSize(Operation *op, const Value &v, - const StringRef operandOrResult) { + LogicalResult levelCheckSize(Operation *op, const Value &v, + const StringRef operandOrResult) { return levelCheckSize(op, v.getType(), operandOrResult); } // Level check sizes of all operands and results of the operation. template - bool levelCheckSizes(T tosaOp) { + LogicalResult levelCheckSizes(T tosaOp) { auto op = tosaOp.getOperation(); for (auto v : op->getOperands()) { - if (!levelCheckSize(op, v, "operand")) - return false; + if (failed(levelCheckSize(op, v, "operand"))) + return failure(); } for (auto v : op->getResults()) { - if (!levelCheckSize(op, v, "result")) - return false; + if (failed(levelCheckSize(op, v, "result"))) + return failure(); } - return true; + return success(); } // Level check ranks of all operands, attribute and results of the operation. template - bool levelCheckRanks(T tosaOp) { + LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); for (auto v : op->getOperands()) { - if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)) - return false; + if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) + return failure(); } for (auto v : op->getResults()) { - if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)) - return false; + if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))) + return failure(); } - return true; + return success(); } // Level check ranks and sizes. - bool levelCheckRanksAndSizes(Operation *op); + LogicalResult levelCheckRanksAndSizes(Operation *op); // Pool Op: level check kernel/stride/pad values template - bool levelCheckPool(Operation *op) { + LogicalResult levelCheckPool(Operation *op) { if (auto poolOp = dyn_cast(op)) { for (auto k : poolOp.getKernel()) { - if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) { + return failure(); } } for (auto s : poolOp.getStride()) { - if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { - return false; + if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) { + return failure(); } } for (auto p : poolOp.getPad()) { - if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) { + return failure(); } } } - return true; + return success(); } // Conv Op: level check dilation/stride/pad values template - bool levelCheckConv(Operation *op) { + LogicalResult levelCheckConv(Operation *op) { if (auto convOp = dyn_cast(op)) { for (auto k : convOp.getDilation()) { - if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) { + return failure(); } } for (auto p : convOp.getPad()) { - if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) { + return failure(); } } for (auto s : convOp.getStride()) { - if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { - return false; + if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) { + return failure(); } } auto dilation = convOp.getDilation(); @@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { if (isa(op)) { assert(shape.size() == 4); assert(dilation.size() == 2); - if (!levelCheckKernel(op, dilation[0] * shape[1], - "dilation_y * KH <= MAX_KERNEL)") || - !levelCheckKernel(op, dilation[1] * shape[2], - "dilation_x * KW <= MAX_KERNEL)")) - return false; + if (failed(levelCheckKernel(op, dilation[0] * shape[1], + "dilation_y * KH <= MAX_KERNEL)")) || + failed(levelCheckKernel(op, dilation[1] * shape[2], + "dilation_x * KW <= MAX_KERNEL)"))) + return failure(); } else if (isa(op)) { assert(shape.size() == 5); assert(dilation.size() == 3); - if (!levelCheckKernel(op, dilation[0] * shape[1], - "dilation_d * KD <= MAX_KERNEL)") || - !levelCheckKernel(op, dilation[1] * shape[2], - "dilation_y * KH <= MAX_KERNEL)") || - !levelCheckKernel(op, dilation[2] * shape[3], - "dilation_x * KW <= MAX_KERNEL)")) - return false; + if (failed(levelCheckKernel(op, dilation[0] * shape[1], + "dilation_d * KD <= MAX_KERNEL)")) || + failed(levelCheckKernel(op, dilation[1] * shape[2], + "dilation_y * KH <= MAX_KERNEL)")) || + failed(levelCheckKernel(op, dilation[2] * shape[3], + "dilation_x * KW <= MAX_KERNEL)"))) + return failure(); } else if (isa(op)) { assert(shape.size() == 4); assert(dilation.size() == 2); - if (!levelCheckKernel(op, dilation[0] * shape[0], - "dilation_y * KH <= MAX_KERNEL)") || - !levelCheckKernel(op, dilation[1] * shape[1], - "dilation_x * KW <= MAX_KERNEL)")) - return false; + if (failed(levelCheckKernel(op, dilation[0] * shape[0], + "dilation_y * KH <= MAX_KERNEL)")) || + failed(levelCheckKernel(op, dilation[1] * shape[1], + "dilation_x * KW <= MAX_KERNEL)"))) + return failure(); } } } - return true; + return success(); } // FFT op: level check H, W in input shape [N,H,W] template - bool levelCheckFFT(Operation *op) { + LogicalResult levelCheckFFT(Operation *op) { if (isa(op)) { for (auto v : op->getOperands()) { if (ShapedType type = dyn_cast(v.getType())) { auto shape = type.getShape(); assert(shape.size() == 3); - if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") || - !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) || + failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) { + return failure(); } } } } - return true; + return success(); } // TransposeConv2d op: level check kH/kW, outpad, and stride - bool levelCheckTransposeConv2d(Operation *op) { + LogicalResult levelCheckTransposeConv2d(Operation *op) { if (auto transpose = dyn_cast(op)) { if (ShapedType filterType = dyn_cast(transpose.getWeight().getType())) { auto shape = filterType.getShape(); assert(shape.size() == 4); // level check kernel sizes for kH and KW - if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") || - !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) || + failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) { + return failure(); } } for (auto p : transpose.getOutPad()) { - if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { - return false; + if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) { + return failure(); } } for (auto s : transpose.getStride()) { - if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { - return false; + if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) { + return failure(); } } } - return true; + return success(); } // Resize op: level check max scales - bool levelCheckResize(Operation *op) { + LogicalResult levelCheckResize(Operation *op) { if (auto resize = dyn_cast(op)) { SmallVector scale; if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) { - return false; + return failure(); } const int64_t scaleYN = scale[0]; const int64_t scaleYD = scale[1]; const int64_t scaleXN = scale[2]; const int64_t scaleXD = scale[3]; - if (!levelCheckScale(op, scaleYN / scaleYD, - "scale_y_n/scale_y_d <= MAX_SCALE") || - !levelCheckScale(op, scaleXN / scaleXD, - "scale_x_n/scale_x_d <= MAX_SCALE")) { - return false; + if (failed(levelCheckScale(op, scaleYN / scaleYD, + "scale_y_n/scale_y_d <= MAX_SCALE")) || + failed(levelCheckScale(op, scaleXN / scaleXD, + "scale_x_n/scale_x_d <= MAX_SCALE"))) { + return failure(); } } - return true; + return success(); } // Recursively perform a bottom-up search to determine the maximum nesting @@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { getMaxNestedDepth(op, depth); } - bool levelCheckMaxNesting(Operation *op) { + LogicalResult levelCheckMaxNesting(Operation *op) { int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); if (maxNestedDepth >= tosaLevel.MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; - return false; + return failure(); } - return true; + return success(); } - bool levelCheckListSize(Operation *op) { + LogicalResult levelCheckListSize(Operation *op) { if (auto concat = dyn_cast(op)) { return levelCheckListSize(op, concat.getInput1().size(), "input1"); } if (auto custom = dyn_cast(op)) { - if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") || - !levelCheckListSize(op, custom.getOutputList().size(), - "output_list")) { - return false; + if (failed(levelCheckListSize(op, custom.getInputList().size(), + "input_list")) || + failed(levelCheckListSize(op, custom.getOutputList().size(), + "output_list"))) { + return failure(); } } if (auto condIf = dyn_cast(op)) { - if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") || - !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) { - return false; + if (failed( + levelCheckListSize(op, condIf.getInputList().size(), "inputs")) || + failed(levelCheckListSize(op, condIf.getOutputList().size(), + "outputs"))) { + return failure(); } } if (auto w = dyn_cast(op)) { - if (!levelCheckListSize(op, w.getInputList().size(), "inputs") || - !levelCheckListSize(op, w.getOutputList().size(), "outputs")) { - return false; + if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) || + failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) { + return failure(); } } - return true; + return success(); } - bool attributeCheckRescale(Operation *op) { + LogicalResult attributeCheckRescale(Operation *op) { if (auto rescale = dyn_cast(op)) { if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !targetEnv.allows(Extension::doubleround)) { op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND " << "requires extension [doubleround]"; - return false; + return failure(); } if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND && !targetEnv.allows(Extension::inexactround)) { op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND " << "requires extension [inexactround]"; - return false; + return failure(); } } - return true; + return success(); } // configure profile and level values from pass options profileName and @@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } } - bool CheckVariable(Operation *op); - bool CheckVariableReadOrWrite(Operation *op); + LogicalResult CheckVariable(Operation *op); + LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); SmallVector< @@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { }; template <> -bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { +LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)) - return false; + if (failed( + levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + return failure(); // rank(output) = rank(input) - 1 - if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1)) - return false; + if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", + tosaLevel.MAX_RANK - 1))) + return failure(); - return true; + return success(); } template <> -bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { +LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { auto *op = tosaOp.getOperation(); // Only the condition input has rank limitation. - if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK)) - return false; + if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", + tosaLevel.MAX_RANK))) + return failure(); - return true; + return success(); } template <> -bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { +LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); - if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK)) - return false; + if (failed(levelCheckRank(op, variableType, "variable type", + tosaLevel.MAX_RANK))) + return failure(); - return true; + return success(); } template <> -bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) { +LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); - if (!levelCheckSize(op, variableType, "variable type")) - return false; + if (failed(levelCheckSize(op, variableType, "variable type"))) + return failure(); - return true; + return success(); } -bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { +LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { #define CHECK_RANKS_AND_SIZES(tosaOp) \ if (isa(op)) { \ - if (!levelCheckRanks(cast(op))) \ - return false; \ - if (!levelCheckSizes(cast(op))) \ - return false; \ + if (failed(levelCheckRanks(cast(op)))) \ + return failure(); \ + if (failed(levelCheckSizes(cast(op)))) \ + return failure(); \ } #define CHECK_SIZES(tosaOp) \ if (isa(op)) { \ - if (!levelCheckSizes(cast(op))) \ - return false; \ + if (failed(levelCheckSizes(cast(op)))) \ + return failure(); \ } // Tensor Operators @@ -735,24 +736,21 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { #undef CHECK_RANKS_AND_SIZES #undef CHECK_SIZES - return true; + return success(); } // Perform the Level tensor size check on the tensor type. -bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, - const StringRef operandOrResult) { +LogicalResult TosaValidation::levelCheckSize(Operation *op, + const Type &typeToCheck, + const StringRef operandOrResult) { if (ShapedType type = dyn_cast(typeToCheck)) { - if (!type.hasRank()) { - op->emitOpError() << "failed level check: unranked tensor"; - return false; - } + if (!type.hasRank()) + return op->emitOpError() << "failed level check: unranked tensor"; auto shape = type.getShape(); for (auto dim : shape) { - if (mlir::ShapedType::isDynamic(dim)) { - op->emitOpError() << "failed level check: " << operandOrResult - << " shape dimension cannot be dynamic"; - return false; - } + if (mlir::ShapedType::isDynamic(dim)) + return op->emitOpError() << "failed level check: " << operandOrResult + << " shape dimension cannot be dynamic"; } int64_t element_bits = type.getElementTypeBitWidth(); @@ -765,14 +763,12 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; - if (size > max_size) { - op->emitOpError() - << "failed level check: " << operandOrResult - << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)"; - return false; - } + if (size > max_size) + return op->emitOpError() + << "failed level check: " << operandOrResult + << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)"; } - return true; + return success(); } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { @@ -782,28 +778,28 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) { } // check rank and sizes early so later checks can assume shaped operands - if (!levelCheckRanksAndSizes(op)) + if (failed(levelCheckRanksAndSizes(op))) return failure(); // additional level checks from spec 0.70 - if (!levelCheckPool(op) || - !levelCheckConv(op) || - !levelCheckConv(op) || - !levelCheckConv(op) || - !levelCheckFFT(op) || - !levelCheckPool(op) || - !levelCheckFFT(op) || !levelCheckTransposeConv2d(op) || - !levelCheckResize(op)) { + if (failed(levelCheckPool(op)) || + failed(levelCheckConv(op)) || + failed(levelCheckConv(op)) || + failed(levelCheckConv(op)) || + failed(levelCheckFFT(op)) || + failed(levelCheckPool(op)) || + failed(levelCheckFFT(op)) || + failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) { return failure(); } // level check MAX_TENSOR_LIST_SIZE - if (!levelCheckListSize(op)) { + if (failed(levelCheckListSize(op))) { return failure(); } if (isa(op) || isa(op)) { - if (!levelCheckMaxNesting(op)) { + if (failed(levelCheckMaxNesting(op))) { return failure(); } } @@ -812,7 +808,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) { } LogicalResult TosaValidation::applyAttributeCheck(Operation *op) { - if (!attributeCheckRescale(op)) + if (failed(attributeCheckRescale(op))) return failure(); return success(); } @@ -823,14 +819,12 @@ inline bool CompatibleTypes(const mlir::Type &type, return type == declaredType; } -bool TosaValidation::CheckVariable(Operation *op) { +LogicalResult TosaValidation::CheckVariable(Operation *op) { if (auto variableOp = dyn_cast(op)) { mlir::StringAttr nameAttr = variableOp.getNameAttr(); - if (variablesMap.count(nameAttr)) { - op->emitOpError() << "name has already been declared"; - return false; - } + if (variablesMap.count(nameAttr)) + return op->emitOpError() << "name has already been declared"; auto elementType = variableOp.getType(); DenseIntElementsAttr varShapeAttr = variableOp.getVarShape(); @@ -841,51 +835,44 @@ bool TosaValidation::CheckVariable(Operation *op) { variablesMap[nameAttr] = variableType; } - return true; + return success(); } -bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { +LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) { if (isa(op) || isa(op)) { mlir::StringAttr nameAttr = cast(op->getAttr("name")); - if (!variablesMap.count(nameAttr)) { - op->emitOpError() << "name has not been declared"; - return false; - } + if (!variablesMap.count(nameAttr)) + return op->emitOpError() << "name has not been declared"; auto varType = variablesMap[nameAttr]; for (auto v : op->getOperands()) { auto type = v.getType(); - if (!CompatibleTypes(type, varType)) { - op->emitOpError() << "operand type does not equal variable type"; - return false; - } + if (!CompatibleTypes(type, varType)) + return op->emitOpError() << "operand type does not equal variable type"; } for (auto v : op->getResults()) { auto type = v.getType(); - if (!CompatibleTypes(type, varType)) { - op->emitOpError() << "result type does not equal variable type"; - return false; - } + if (!CompatibleTypes(type, varType)) + return op->emitOpError() << "result type does not equal variable type"; } } - return true; + return success(); } LogicalResult TosaValidation::applyVariableCheck(Operation *op) { - if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { + if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op))) return failure(); - } return success(); } -bool checkErrorIfResize(Operation *op) { +LogicalResult checkErrorIfResize(Operation *op) { auto resize = dyn_cast(op); if (!resize) - return true; + return success(); const Value input = resize.getInput(); const Value output = resize.getOutput(); @@ -894,10 +881,8 @@ bool checkErrorIfResize(Operation *op) { const RankedTensorType outputType = llvm::dyn_cast(output.getType()); - if (!inputType || !outputType) { - op->emitOpError("expect ranked input/output tensor"); - return false; - } + if (!inputType || !outputType) + return op->emitOpError("expect ranked input/output tensor"); // Ensure the image size is supported by GPU APIs and that for integer // implementations, position * stride does not overflow int32_t. @@ -906,17 +891,15 @@ bool checkErrorIfResize(Operation *op) { outputType.getDimSize(1), outputType.getDimSize(2), inputType.getDimSize(1), inputType.getDimSize(2)}; const int64_t *maxDim = llvm::max_element(sizes); - if (maxDim != sizes.end() && *maxDim >= 16384) { - op->emitOpError("expect input/output height/width dims to be < 16384, ") - << "got [OH, OW, IH, IW] = " << sizes; - return false; - } + if (maxDim != sizes.end() && *maxDim >= 16384) + return op->emitOpError( + "expect input/output height/width dims to be < 16384, ") + << "got [OH, OW, IH, IW] = " << sizes; } SmallVector scale; - if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) { - return false; - } + if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) + return failure(); const int64_t scaleYN = scale[0]; const int64_t scaleYD = scale[1]; @@ -924,57 +907,45 @@ bool checkErrorIfResize(Operation *op) { const int64_t scaleXD = scale[3]; // Ensure scale values don't overflow int32 accumulator - if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) { - op->emitOpError("expect all scale numerator values to be <= (1 << 11), " - "got scale_y_n=") - << scaleYN << ", scale_x_n=" << scaleXN; - return false; - } + if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) + return op->emitOpError( + "expect all scale numerator values to be <= (1 << 11), " + "got scale_y_n=") + << scaleYN << ", scale_x_n=" << scaleXN; - if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) { - op->emitOpError("expect a downscale ratio larger than 1/16, got y=") - << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD; - return false; - } + if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) + return op->emitOpError("expect a downscale ratio larger than 1/16, got y=") + << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD; SmallVector offset; SmallVector border; if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) || - !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) { - return false; - } + !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) + return failure(); const int64_t offsetY = offset[0]; const int64_t offsetX = offset[1]; // Set a consistent lower limit of 1/16 downscale to simplify // implementations - if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) { - op->emitOpError( - "expect offsetY / scaleYNumerator to be in range [-1, 16), got ") - << offsetY << "/" << scaleYN; - return false; - } - if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) { - op->emitOpError( - "expect offsetX / scaleXNumerator to be in range [-1, 16), got ") - << offsetX << "/" << scaleXN; - return false; - } + if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) + return op->emitOpError( + "expect offsetY / scaleYNumerator to be in range [-1, 16), got ") + << offsetY << "/" << scaleYN; + if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) + return op->emitOpError( + "expect offsetX / scaleXNumerator to be in range [-1, 16), got ") + << offsetX << "/" << scaleXN; const int64_t borderY = border[0]; const int64_t borderX = border[1]; - if (borderY < -16 * scaleYN || borderY >= scaleYN) { - op->emitOpError( - "expect borderY / scaleYNumerator to be in range [-16, 1), got ") - << borderY << "/" << scaleYN; - return false; - } - if (borderX < -16 * scaleXN || borderX >= scaleXN) { - op->emitOpError( - "expect borderX / scaleXNumerator to be in range [-16, 1), got ") - << borderX << "/" << scaleXN; - return false; - } + if (borderY < -16 * scaleYN || borderY >= scaleYN) + return op->emitOpError( + "expect borderY / scaleYNumerator to be in range [-16, 1), got ") + << borderY << "/" << scaleYN; + if (borderX < -16 * scaleXN || borderX >= scaleXN) + return op->emitOpError( + "expect borderX / scaleXNumerator to be in range [-16, 1), got ") + << borderX << "/" << scaleXN; // The following section of code is mostly duplicated with ResizeOp::verify(). // @@ -1001,81 +972,72 @@ bool checkErrorIfResize(Operation *op) { if (ih != ShapedType::kDynamic) { const std::optional calculatedOutHeightMinusOne = idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD); - if (!calculatedOutHeightMinusOne.has_value()) { - op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + " - "border_y ") - << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * " - << scaleYN << " - " << offsetY << " + " << borderY << ") / " - << scaleYD; - return false; - } + if (!calculatedOutHeightMinusOne.has_value()) + return op->emitOpError( + "expected (input_height - 1) * scale_y_n - offset_y + " + "border_y ") + << "to be wholly divisible by scale_y_d, got ((" << ih + << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY + << ") / " << scaleYD; const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1; - if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) { - op->emitOpError("calculated output height did not match expected: ") - << "calculated=" << calculatedOutHeight << ", expected=" << oh; - return false; - } + if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) + return op->emitOpError( + "calculated output height did not match expected: ") + << "calculated=" << calculatedOutHeight << ", expected=" << oh; } if (iw != ShapedType::kDynamic) { const std::optional calculatedOutWidthMinusOne = idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD); - if (!calculatedOutWidthMinusOne.has_value()) { - op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + " - "border_x ") - << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * " - << scaleXN << " - " << offsetX << " + " << borderX << ") / " - << scaleXD; - return false; - } + if (!calculatedOutWidthMinusOne.has_value()) + return op->emitOpError( + "expected (input_width - 1) * scale_x_n - offset_x + " + "border_x ") + << "to be wholly divisible by scale_x_d, got ((" << iw + << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX + << ") / " << scaleXD; const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1; - if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) { - op->emitOpError("calculated output width did not match expected: ") - << "calculated=" << calculatedOutWidth << ", expected=" << ow; - return false; - } + if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) + return op->emitOpError("calculated output width did not match expected: ") + << "calculated=" << calculatedOutWidth << ", expected=" << ow; } - return true; + return success(); } -bool checkErrorIfMul(Operation *op) { +LogicalResult checkErrorIfMul(Operation *op) { auto mul = dyn_cast(op); if (!mul) - return true; + return success(); // REQUIRE(0 <= shift && shift <= 63); // REQUIRE(is_same() || shift == 0); ElementsAttr shift_elem; - if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) { - return true; - } + if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) + return success(); int32_t shift = shift_elem.getValues()[0].getInt(); auto inputElemType = getElementTypeOrSelf(mul.getInput1()); if (inputElemType.isInteger(32)) { // 0 <= shift <= 63 for int32_t type - if (shift < 0 || shift > 63) { - op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: " - << shift; - return false; - } + if (shift < 0 || shift > 63) + return op->emitOpError() + << "requires 0 <= shift && shift <= 63, but got: " << shift; } else { // shift must be 0 for all other types - if (shift != 0) { - op->emitOpError() << "requires shift = 0 for all input data types that " - "are not int32_t, but got: " - << shift; - return false; - } + if (shift != 0) + return op->emitOpError() + << "requires shift = 0 for all input data types that " + "are not int32_t, but got: " + << shift; } - return true; + return success(); } -bool checkErrorIfTable(Operation *op) { +LogicalResult checkErrorIfTable(Operation *op) { auto table = dyn_cast(op); if (!table) - return true; + return success(); // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513 const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType()); @@ -1084,26 +1046,24 @@ bool checkErrorIfTable(Operation *op) { const ShapeAdaptor tableShape(table.getTable().getType()); if (tableShape.hasStaticShape()) { const auto numElements = tableShape.getNumElements(); - if (numElements != tableSize) { - op->emitOpError() << "requires table size of " << tableSize << ", got " - << numElements; - return false; - } + if (numElements != tableSize) + return op->emitOpError() << "requires table size of " << tableSize + << ", got " << numElements; } - return true; + return success(); } -bool checkErrorIfRescale(Operation *op) { +LogicalResult checkErrorIfRescale(Operation *op) { auto rescale = dyn_cast(op); if (!rescale) - return true; + return success(); auto inputType = llvm::dyn_cast(rescale.getInput().getType()); auto outputType = llvm::dyn_cast(rescale.getOutput().getType()); if (!inputType || !outputType || !inputType.getElementType().isInteger() || !outputType.getElementType().isInteger()) - return true; + return success(); auto inElemType = inputType.getElementType(); auto outElemType = outputType.getElementType(); @@ -1117,81 +1077,65 @@ bool checkErrorIfRescale(Operation *op) { auto roundingMode = rescale.getRoundingMode(); // ERROR_IF(scale32 && is_same()) - if (scale32 && inWidth == 48) { - op->emitOpError() << "scale32 is not allowed with 48-bit input."; - return false; - } + if (scale32 && inWidth == 48) + return op->emitOpError() << "scale32 is not allowed with 48-bit input."; // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND)) - if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) { - op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true."; - return false; - } + if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) + return op->emitOpError() + << "DOUBLE_ROUND is only allowed with scale32=true."; // ERROR_IF(input_unsigned && output_unsigned) - if (inputUnsigned && outputUnsigned) { - op->emitOpError() << "input and output cannot be both unsigned."; - return false; - } + if (inputUnsigned && outputUnsigned) + return op->emitOpError() << "input and output cannot be both unsigned."; // ERROR_IF(is_same() && input_unsigned) - if (outWidth == 32 && inputUnsigned) { - op->emitOpError() << "i32 output type is not allowed with unsigned input."; - return false; - } + if (outWidth == 32 && inputUnsigned) + return op->emitOpError() + << "i32 output type is not allowed with unsigned input."; // ERROR_IF(is_same() && output_unsigned) - if (inWidth == 32 && outputUnsigned) { - op->emitOpError() << "i32 input type is not allowed with unsigned output."; - return false; - } + if (inWidth == 32 && outputUnsigned) + return op->emitOpError() + << "i32 input type is not allowed with unsigned output."; // ERROR_IF(is_same() && output_unsigned) - if (inWidth == 48 && outputUnsigned) { - op->emitOpError() << "i48 input type is not allowed with unsigned output."; - return false; - } + if (inWidth == 48 && outputUnsigned) + return op->emitOpError() + << "i48 input type is not allowed with unsigned output."; // ERROR_IF(is_same && input_unsigned) - if (inWidth == 48 && inputUnsigned) { - op->emitOpError() << "i48 input type cannot be unsigned."; - return false; - } + if (inWidth == 48 && inputUnsigned) + return op->emitOpError() << "i48 input type cannot be unsigned."; // ERROR_IF(is_same && input_unsigned) - if (inWidth == 32 && inputUnsigned) { - op->emitOpError() << "i32 input type cannot be unsigned."; - return false; - } + if (inWidth == 32 && inputUnsigned) + return op->emitOpError() << "i32 input type cannot be unsigned."; // ERROR_IF(is_same && output_unsigned) - if (outWidth == 32 && outputUnsigned) { - op->emitOpError() << "i32 output type cannot be unsigned."; - return false; - } + if (outWidth == 32 && outputUnsigned) + return op->emitOpError() << "i32 output type cannot be unsigned."; - return true; + return success(); } -bool checkErrorIfPad(Operation *op) { +LogicalResult checkErrorIfPad(Operation *op) { auto pad = dyn_cast(op); if (!pad) - return true; + return success(); DenseIntElementsAttr paddingAttr; if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr))) // Pad verifier will catch this - return true; + return success(); for (const APInt &val : paddingAttr.getValues()) { - if (val.getSExtValue() < 0) { - op->emitOpError() << "padding value must all be non-negative, got " - << val.getSExtValue(); - return false; - } + if (val.getSExtValue() < 0) + return op->emitOpError() << "padding value must all be non-negative, got " + << val.getSExtValue(); } - return true; + return success(); } static bool isOpIsolatedWithinRegion(Operation *op, Region *region) { @@ -1201,7 +1145,7 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) { }); } -static bool isRegionIsolatedFromAbove(Region ®ionToCheck) { +static LogicalResult isRegionIsolatedFromAbove(Region ®ionToCheck) { bool noLiveInValue = true; regionToCheck.walk([&noLiveInValue, ®ionToCheck](Operation *op) { if (!isOpIsolatedWithinRegion(op, ®ionToCheck)) { @@ -1210,23 +1154,22 @@ static bool isRegionIsolatedFromAbove(Region ®ionToCheck) { } return WalkResult::advance(); }); - return noLiveInValue; + return noLiveInValue ? success() : failure(); } LogicalResult checkIsolatedRegion(Operation *op, Region ®ionToCheck, StringRef regionName) { - if (isRegionIsolatedFromAbove(regionToCheck)) + if (succeeded(isRegionIsolatedFromAbove(regionToCheck))) return success(); - op->emitOpError() - << "is not conformant to the TOSA specification. It requires the '" - << regionName << "' region is isolated from above.\n"; - return failure(); + return op->emitOpError() + << "is not conformant to the TOSA specification. It requires the '" + << regionName << "' region is isolated from above.\n"; } -bool checkErrorIfCondIf(Operation *op) { +LogicalResult checkErrorIfCondIf(Operation *op) { auto ifOp = dyn_cast(op); if (!ifOp) - return true; + return success(); // Currently the dialect supports declaring cond_if operations that // have then/else regions that reference values from outside these @@ -1257,49 +1200,53 @@ bool checkErrorIfCondIf(Operation *op) { // tosa.yield %arg4 // } - return succeeded(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) && - succeeded(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); + if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || + failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"))) + return failure(); + return success(); } -bool checkErrorIfWhileLoop(Operation *op) { +LogicalResult checkErrorIfWhileLoop(Operation *op) { auto whileOp = dyn_cast(op); if (!whileOp) - return true; + return success(); - return succeeded(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) && - succeeded(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")); + if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) || + failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"))) + return failure(); + return success(); } -bool checkErrorIfScatter(Operation *op) { +LogicalResult checkErrorIfScatter(Operation *op) { auto scatterOp = dyn_cast(op); if (!scatterOp) - return true; + return success(); // for constant indices, check that there are no duplicate values DenseIntElementsAttr indicesAttr; if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr))) - return true; + return success(); auto const indicesType = dyn_cast(scatterOp.getIndices().getType()); if (!indicesType || !indicesType.hasRank()) { op->emitOpError("expect ranked indices tensor"); - return false; + return failure(); } if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) { op->emitOpError("indices values contain duplicates"); - return false; + return failure(); } - return true; + return success(); } LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { - if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || - !checkErrorIfTable(op) || !checkErrorIfRescale(op) || - !checkErrorIfPad(op) || !checkErrorIfCondIf(op) || - !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op)) + if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) || + failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) || + failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) || + failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op))) return failure(); return success(); }