-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] Use LogicalResult
in validation functions
#160052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThis commit replaces functions that previously returned Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754. Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
constCheckers.emplace_back(checkConstantOperandNegate);
}
- bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckKernel(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_KERNEL)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckStride(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_STRIDE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckScale(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_SCALE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
- op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
- << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckListSize(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ return op->emitOpError()
+ << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+ return success();
}
// Perform the Level Rank check on the tensor type.
- bool levelCheckRank(Operation *op, const Type typeToCheck,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
- if (!type.hasRank()) {
- op->emitOpError() << "failed level check: unranked tensor";
- return false;
- }
- if (type.getRank() > highest_rank) {
- op->emitOpError() << "failed level check: " << operandOrResult
- << " rank(shape) <= MAX_RANK";
- return false;
- }
+ if (!type.hasRank())
+ return op->emitOpError() << "failed level check: unranked tensor";
+ if (type.getRank() > highest_rank)
+ return op->emitOpError() << "failed level check: " << operandOrResult
+ << " rank(shape) <= MAX_RANK";
}
- return true;
+ return success();
}
// Perform the Level Rank check on the tensor value.
- bool levelCheckRank(Operation *op, const Value &v,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Value &v,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
}
// Perform the Level tensor size check on the tensor type.
- bool levelCheckSize(Operation *op, const Type &typeToCheck,
- const StringRef operandOrResult);
+ LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+ const StringRef operandOrResult);
// Perform the Level tensor size check on the tensor value.
- bool levelCheckSize(Operation *op, const Value &v,
- const StringRef operandOrResult) {
+ LogicalResult levelCheckSize(Operation *op, const Value &v,
+ const StringRef operandOrResult) {
return levelCheckSize(op, v.getType(), operandOrResult);
}
// Level check sizes of all operands and results of the operation.
template <typename T>
- bool levelCheckSizes(T tosaOp) {
+ LogicalResult levelCheckSizes(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckSize(op, v, "operand"))
- return false;
+ if (failed(levelCheckSize(op, v, "operand")))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckSize(op, v, "result"))
- return false;
+ if (failed(levelCheckSize(op, v, "result")))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks of all operands, attribute and results of the operation.
template <typename T>
- bool levelCheckRanks(T tosaOp) {
+ LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks and sizes.
- bool levelCheckRanksAndSizes(Operation *op);
+ LogicalResult levelCheckRanksAndSizes(Operation *op);
// Pool Op: level check kernel/stride/pad values
template <typename T>
- bool levelCheckPool(Operation *op) {
+ LogicalResult levelCheckPool(Operation *op) {
if (auto poolOp = dyn_cast<T>(op)) {
for (auto k : poolOp.getKernel()) {
- if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : poolOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
for (auto p : poolOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Conv Op: level check dilation/stride/pad values
template <typename T>
- bool levelCheckConv(Operation *op) {
+ LogicalResult levelCheckConv(Operation *op) {
if (auto convOp = dyn_cast<T>(op)) {
for (auto k : convOp.getDilation()) {
- if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : convOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : convOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (isa<tosa::Conv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::Conv3DOp>(op)) {
assert(shape.size() == 5);
assert(dilation.size() == 3);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_d * KD <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[2] * shape[3],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_d * KD <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[2] * shape[3],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::DepthwiseConv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[0],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[1],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[1],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
}
}
}
- return true;
+ return success();
}
// FFT op: level check H, W in input shape [N,H,W]
template <typename T>
- bool levelCheckFFT(Operation *op) {
+ LogicalResult levelCheckFFT(Operation *op) {
if (isa<T>(op)) {
for (auto v : op->getOperands()) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
auto shape = type.getShape();
assert(shape.size() == 3);
- if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+ return failure();
}
}
}
}
- return true;
+ return success();
}
// TransposeConv2d op: level check kH/kW, outpad, and stride
- bool levelCheckTransposeConv2d(Operation *op) {
+ LogicalResult levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
dyn_cast<ShapedType>(transpose.getWeight().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW
- if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : transpose.getOutPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : transpose.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Resize op: level check max scales
- bool levelCheckResize(Operation *op) {
+ LogicalResult levelCheckResize(Operation *op) {
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
scale)) {
- return false;
+ return failure();
}
const int64_t scaleYN = scale[0];
const int64_t scaleYD = scale[1];
const int64_t scaleXN = scale[2];
const int64_t scaleXD = scale[3];
- if (!levelCheckScale(op, scaleYN / scaleYD,
- "scale_y_n/scale_y_d <= MAX_SCALE") ||
- !levelCheckScale(op, scaleXN / scaleXD,
- "scale_x_n/scale_x_d <= MAX_SCALE")) {
- return false;
+ if (failed(levelCheckScale(op, scaleYN / scaleYD,
+ "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+ failed(levelCheckScale(op, scaleXN / scaleXD,
+ "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+ return failure();
}
}
- return true;
+ return success();
}
// Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
getMaxNestedDepth(op, depth);
}
- bool levelCheckMaxNesting(Operation *op) {
+ LogicalResult levelCheckMaxNesting(Operation *op) {
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
- return false;
+ return failure();
}
- return true;
+ return success();
}
- bool levelCheckListSize(Operation *op) {
+ LogicalResult levelCheckListSize(Operation *op) {
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
return levelCheckListSize(op, concat.getInput1().size(), "input1");
}
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
- if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
- !levelCheckListSize(op, custom.getOutputList().size(),
- "output_list")) {
- return false;
+ if (failed(levelCheckListSize(op, custom.getInputList().size(),
+ "input_list")) ||
+ failed(levelCheckListSize(op, custom.getOutputList().size(),
+ "output_list"))) {
+ return failure();
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
- if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(
+ levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, condIf.getOutputList().size(),
+ "outputs"))) {
+ return failure();
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
- if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+ return failure();
}
}
- return true;
+ return success();
}
- bool attributeCheckRescale(Operation *op) {
+ LogicalResult attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
- return false;
+ return failure();
}
if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
!targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
- return false;
+ return failure();
}
}
- return true;
+ return success();
}
// configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
- bool CheckVariable(Operation *op);
- bool CheckVariableReadOrWrite(Operation *op);
+ LogicalResult CheckVariable(Operation *op);
+ LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
};
template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto op = tosaOp.getOperation();
- if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(
+ levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ return failure();
// rank(output) = rank(input) - 1
- if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+ tosaLevel.MAX_RANK - 1)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
- if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, variableType, "variable type",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckSize(op, variableType, "variable type"))
- return false;
+ if (failed(levelCheckSize(op, variableType, "variable type")))
+ ret...
[truncated]
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit replaces functions that previously returned Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754. Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
constCheckers.emplace_back(checkConstantOperandNegate);
}
- bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckKernel(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_KERNEL)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckStride(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_STRIDE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE) {
- op->emitOpError() << "failed level check: " << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckScale(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_SCALE)
+ return op->emitOpError() << "failed level check: " << checkDesc;
+ return success();
}
- bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
- op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
- << checkDesc;
- return false;
- }
- return true;
+ LogicalResult levelCheckListSize(Operation *op, int32_t v,
+ const StringRef checkDesc) {
+ if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ return op->emitOpError()
+ << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+ return success();
}
// Perform the Level Rank check on the tensor type.
- bool levelCheckRank(Operation *op, const Type typeToCheck,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
- if (!type.hasRank()) {
- op->emitOpError() << "failed level check: unranked tensor";
- return false;
- }
- if (type.getRank() > highest_rank) {
- op->emitOpError() << "failed level check: " << operandOrResult
- << " rank(shape) <= MAX_RANK";
- return false;
- }
+ if (!type.hasRank())
+ return op->emitOpError() << "failed level check: unranked tensor";
+ if (type.getRank() > highest_rank)
+ return op->emitOpError() << "failed level check: " << operandOrResult
+ << " rank(shape) <= MAX_RANK";
}
- return true;
+ return success();
}
// Perform the Level Rank check on the tensor value.
- bool levelCheckRank(Operation *op, const Value &v,
- const StringRef operandOrResult, int32_t highest_rank) {
+ LogicalResult levelCheckRank(Operation *op, const Value &v,
+ const StringRef operandOrResult,
+ int32_t highest_rank) {
return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
}
// Perform the Level tensor size check on the tensor type.
- bool levelCheckSize(Operation *op, const Type &typeToCheck,
- const StringRef operandOrResult);
+ LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+ const StringRef operandOrResult);
// Perform the Level tensor size check on the tensor value.
- bool levelCheckSize(Operation *op, const Value &v,
- const StringRef operandOrResult) {
+ LogicalResult levelCheckSize(Operation *op, const Value &v,
+ const StringRef operandOrResult) {
return levelCheckSize(op, v.getType(), operandOrResult);
}
// Level check sizes of all operands and results of the operation.
template <typename T>
- bool levelCheckSizes(T tosaOp) {
+ LogicalResult levelCheckSizes(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckSize(op, v, "operand"))
- return false;
+ if (failed(levelCheckSize(op, v, "operand")))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckSize(op, v, "result"))
- return false;
+ if (failed(levelCheckSize(op, v, "result")))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks of all operands, attribute and results of the operation.
template <typename T>
- bool levelCheckRanks(T tosaOp) {
+ LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
- if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+ return failure();
}
for (auto v : op->getResults()) {
- if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+ return failure();
}
- return true;
+ return success();
}
// Level check ranks and sizes.
- bool levelCheckRanksAndSizes(Operation *op);
+ LogicalResult levelCheckRanksAndSizes(Operation *op);
// Pool Op: level check kernel/stride/pad values
template <typename T>
- bool levelCheckPool(Operation *op) {
+ LogicalResult levelCheckPool(Operation *op) {
if (auto poolOp = dyn_cast<T>(op)) {
for (auto k : poolOp.getKernel()) {
- if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : poolOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
for (auto p : poolOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Conv Op: level check dilation/stride/pad values
template <typename T>
- bool levelCheckConv(Operation *op) {
+ LogicalResult levelCheckConv(Operation *op) {
if (auto convOp = dyn_cast<T>(op)) {
for (auto k : convOp.getDilation()) {
- if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : convOp.getPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : convOp.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
if (isa<tosa::Conv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::Conv3DOp>(op)) {
assert(shape.size() == 5);
assert(dilation.size() == 3);
- if (!levelCheckKernel(op, dilation[0] * shape[1],
- "dilation_d * KD <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[2],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[2] * shape[3],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+ "dilation_d * KD <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[2],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[2] * shape[3],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
} else if (isa<tosa::DepthwiseConv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
- if (!levelCheckKernel(op, dilation[0] * shape[0],
- "dilation_y * KH <= MAX_KERNEL)") ||
- !levelCheckKernel(op, dilation[1] * shape[1],
- "dilation_x * KW <= MAX_KERNEL)"))
- return false;
+ if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+ "dilation_y * KH <= MAX_KERNEL)")) ||
+ failed(levelCheckKernel(op, dilation[1] * shape[1],
+ "dilation_x * KW <= MAX_KERNEL)")))
+ return failure();
}
}
}
- return true;
+ return success();
}
// FFT op: level check H, W in input shape [N,H,W]
template <typename T>
- bool levelCheckFFT(Operation *op) {
+ LogicalResult levelCheckFFT(Operation *op) {
if (isa<T>(op)) {
for (auto v : op->getOperands()) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
auto shape = type.getShape();
assert(shape.size() == 3);
- if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+ return failure();
}
}
}
}
- return true;
+ return success();
}
// TransposeConv2d op: level check kH/kW, outpad, and stride
- bool levelCheckTransposeConv2d(Operation *op) {
+ LogicalResult levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
dyn_cast<ShapedType>(transpose.getWeight().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW
- if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
- !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+ failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto p : transpose.getOutPad()) {
- if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
- return false;
+ if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+ return failure();
}
}
for (auto s : transpose.getStride()) {
- if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
- return false;
+ if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+ return failure();
}
}
}
- return true;
+ return success();
}
// Resize op: level check max scales
- bool levelCheckResize(Operation *op) {
+ LogicalResult levelCheckResize(Operation *op) {
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
scale)) {
- return false;
+ return failure();
}
const int64_t scaleYN = scale[0];
const int64_t scaleYD = scale[1];
const int64_t scaleXN = scale[2];
const int64_t scaleXD = scale[3];
- if (!levelCheckScale(op, scaleYN / scaleYD,
- "scale_y_n/scale_y_d <= MAX_SCALE") ||
- !levelCheckScale(op, scaleXN / scaleXD,
- "scale_x_n/scale_x_d <= MAX_SCALE")) {
- return false;
+ if (failed(levelCheckScale(op, scaleYN / scaleYD,
+ "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+ failed(levelCheckScale(op, scaleXN / scaleXD,
+ "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+ return failure();
}
}
- return true;
+ return success();
}
// Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
getMaxNestedDepth(op, depth);
}
- bool levelCheckMaxNesting(Operation *op) {
+ LogicalResult levelCheckMaxNesting(Operation *op) {
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
- return false;
+ return failure();
}
- return true;
+ return success();
}
- bool levelCheckListSize(Operation *op) {
+ LogicalResult levelCheckListSize(Operation *op) {
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
return levelCheckListSize(op, concat.getInput1().size(), "input1");
}
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
- if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
- !levelCheckListSize(op, custom.getOutputList().size(),
- "output_list")) {
- return false;
+ if (failed(levelCheckListSize(op, custom.getInputList().size(),
+ "input_list")) ||
+ failed(levelCheckListSize(op, custom.getOutputList().size(),
+ "output_list"))) {
+ return failure();
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
- if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(
+ levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, condIf.getOutputList().size(),
+ "outputs"))) {
+ return failure();
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
- if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
- !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
- return false;
+ if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+ failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+ return failure();
}
}
- return true;
+ return success();
}
- bool attributeCheckRescale(Operation *op) {
+ LogicalResult attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
- return false;
+ return failure();
}
if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
!targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
- return false;
+ return failure();
}
}
- return true;
+ return success();
}
// configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
- bool CheckVariable(Operation *op);
- bool CheckVariableReadOrWrite(Operation *op);
+ LogicalResult CheckVariable(Operation *op);
+ LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
};
template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto op = tosaOp.getOperation();
- if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(
+ levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ return failure();
// rank(output) = rank(input) - 1
- if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+ tosaLevel.MAX_RANK - 1)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
- if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
- return false;
+ if (failed(levelCheckRank(op, variableType, "variable type",
+ tosaLevel.MAX_RANK)))
+ return failure();
- return true;
+ return success();
}
template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
auto op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
- if (!levelCheckSize(op, variableType, "variable type"))
- return false;
+ if (failed(levelCheckSize(op, variableType, "variable type")))
+ ret...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I find this much better :)
798879f
to
9e67b51
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
LGTM |
This commit replaces functions that previously returned `bool` to indicate validation success or failure with `LogicalResult`. Change-Id: Iec3b54e3cc5462e981e1e9eb8639608c62a128ed
9e67b51
to
32e5a11
Compare
This commit replaces functions that previously returned `bool` to indicate validation success or failure with `LogicalResult`.
This commit replaces functions that previously returned
bool
to indicate validation success or failure withLogicalResult
.Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754.