Skip to content

Conversation

lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented Sep 22, 2025

This commit replaces functions that previously returned bool to indicate validation success or failure with LogicalResult.

Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754.

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit replaces functions that previously returned bool to indicate validation success or failure with LogicalResult.

Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754.


Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+335-388)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (-33)
  • (added) mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir (+34)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_KERNEL) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckKernel(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_KERNEL)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_STRIDE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckStride(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_STRIDE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_SCALE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckScale(Operation *op, int32_t v,
+                                const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_SCALE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
-      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
-                        << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckListSize(Operation *op, int32_t v,
+                                   const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+      return op->emitOpError()
+             << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor type.
-  bool levelCheckRank(Operation *op, const Type typeToCheck,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
-      if (!type.hasRank()) {
-        op->emitOpError() << "failed level check: unranked tensor";
-        return false;
-      }
-      if (type.getRank() > highest_rank) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " rank(shape) <= MAX_RANK";
-        return false;
-      }
+      if (!type.hasRank())
+        return op->emitOpError() << "failed level check: unranked tensor";
+      if (type.getRank() > highest_rank)
+        return op->emitOpError() << "failed level check: " << operandOrResult
+                                 << " rank(shape) <= MAX_RANK";
     }
-    return true;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor value.
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Value &v,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
   }
 
   // Perform the Level tensor size check on the tensor type.
-  bool levelCheckSize(Operation *op, const Type &typeToCheck,
-                      const StringRef operandOrResult);
+  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+                               const StringRef operandOrResult);
 
   // Perform the Level tensor size check on the tensor value.
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult) {
+  LogicalResult levelCheckSize(Operation *op, const Value &v,
+                               const StringRef operandOrResult) {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
   // Level check sizes of all operands and results of the operation.
   template <typename T>
-  bool levelCheckSizes(T tosaOp) {
+  LogicalResult levelCheckSizes(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckSize(op, v, "operand"))
-        return false;
+      if (failed(levelCheckSize(op, v, "operand")))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckSize(op, v, "result"))
-        return false;
+      if (failed(levelCheckSize(op, v, "result")))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks of all operands, attribute and results of the operation.
   template <typename T>
-  bool levelCheckRanks(T tosaOp) {
+  LogicalResult levelCheckRanks(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks and sizes.
-  bool levelCheckRanksAndSizes(Operation *op);
+  LogicalResult levelCheckRanksAndSizes(Operation *op);
 
   // Pool Op: level check kernel/stride/pad values
   template <typename T>
-  bool levelCheckPool(Operation *op) {
+  LogicalResult levelCheckPool(Operation *op) {
     if (auto poolOp = dyn_cast<T>(op)) {
       for (auto k : poolOp.getKernel()) {
-        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : poolOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       for (auto p : poolOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Conv Op: level check dilation/stride/pad values
   template <typename T>
-  bool levelCheckConv(Operation *op) {
+  LogicalResult levelCheckConv(Operation *op) {
     if (auto convOp = dyn_cast<T>(op)) {
 
       for (auto k : convOp.getDilation()) {
-        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : convOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : convOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         if (isa<tosa::Conv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::Conv3DOp>(op)) {
           assert(shape.size() == 5);
           assert(dilation.size() == 3);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_d * KD <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[2] * shape[3],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_d * KD <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[2] * shape[3],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[0],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[1],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[1],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
-  bool levelCheckFFT(Operation *op) {
+  LogicalResult levelCheckFFT(Operation *op) {
     if (isa<T>(op)) {
       for (auto v : op->getOperands()) {
         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
           auto shape = type.getShape();
           assert(shape.size() == 3);
-          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
-              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
-            return false;
+          if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+              failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+            return failure();
           }
         }
       }
     }
-    return true;
+    return success();
   }
 
   // TransposeConv2d op: level check kH/kW, outpad, and stride
-  bool levelCheckTransposeConv2d(Operation *op) {
+  LogicalResult levelCheckTransposeConv2d(Operation *op) {
     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
       if (ShapedType filterType =
               dyn_cast<ShapedType>(transpose.getWeight().getType())) {
         auto shape = filterType.getShape();
         assert(shape.size() == 4);
         // level check kernel sizes for kH and KW
-        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
-            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+            failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : transpose.getOutPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : transpose.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Resize op: level check max scales
-  bool levelCheckResize(Operation *op) {
+  LogicalResult levelCheckResize(Operation *op) {
     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
       SmallVector<int64_t> scale;
       if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                      scale)) {
-        return false;
+        return failure();
       }
       const int64_t scaleYN = scale[0];
       const int64_t scaleYD = scale[1];
       const int64_t scaleXN = scale[2];
       const int64_t scaleXD = scale[3];
-      if (!levelCheckScale(op, scaleYN / scaleYD,
-                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
-          !levelCheckScale(op, scaleXN / scaleXD,
-                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
-        return false;
+      if (failed(levelCheckScale(op, scaleYN / scaleYD,
+                                 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+          failed(levelCheckScale(op, scaleXN / scaleXD,
+                                 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     getMaxNestedDepth(op, depth);
   }
 
-  bool levelCheckMaxNesting(Operation *op) {
+  LogicalResult levelCheckMaxNesting(Operation *op) {
     int32_t maxNestedDepth = 0;
     getMaxNestedDepth(op, maxNestedDepth);
 
     if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
       op->emitOpError() << "failed level check: " << maxNestedDepth
                         << " >= MAX_NESTING";
-      return false;
+      return failure();
     }
-    return true;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op) {
+  LogicalResult levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
     }
     if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
-      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
-          !levelCheckListSize(op, custom.getOutputList().size(),
-                              "output_list")) {
-        return false;
+      if (failed(levelCheckListSize(op, custom.getInputList().size(),
+                                    "input_list")) ||
+          failed(levelCheckListSize(op, custom.getOutputList().size(),
+                                    "output_list"))) {
+        return failure();
       }
     }
     if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(
+              levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, condIf.getOutputList().size(),
+                                    "outputs"))) {
+        return failure();
       }
     }
     if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
-  bool attributeCheckRescale(Operation *op) {
+  LogicalResult attributeCheckRescale(Operation *op) {
     if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
       if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
           !targetEnv.allows(Extension::doubleround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = DOUBLE_ROUND "
             << "requires extension [doubleround]";
-        return false;
+        return failure();
       }
       if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
           !targetEnv.allows(Extension::inexactround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = INEXACT_ROUND "
             << "requires extension [inexactround]";
-        return false;
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }
 
-  bool CheckVariable(Operation *op);
-  bool CheckVariableReadOrWrite(Operation *op);
+  LogicalResult CheckVariable(Operation *op);
+  LogicalResult CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 };
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
   auto op = tosaOp.getOperation();
-  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(
+          levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+    return failure();
 
   // rank(output) = rank(input) - 1
-  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+                            tosaLevel.MAX_RANK - 1)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   auto op = tosaOp.getOperation();
 
   // Only the condition input has rank limitation.
-  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, variableType, "variable type",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckSize(op, variableType, "variable type"))
-    return false;
+  if (failed(levelCheckSize(op, variableType, "variable type")))
+    ret...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit replaces functions that previously returned bool to indicate validation success or failure with LogicalResult.

Note: this PR also contains the contents of #159754, so shouldn't be merged before #159754.


Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+335-388)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (-33)
  • (added) mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir (+34)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_KERNEL) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckKernel(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_KERNEL)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_STRIDE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckStride(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_STRIDE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_SCALE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckScale(Operation *op, int32_t v,
+                                const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_SCALE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
-      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
-                        << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckListSize(Operation *op, int32_t v,
+                                   const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+      return op->emitOpError()
+             << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor type.
-  bool levelCheckRank(Operation *op, const Type typeToCheck,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
-      if (!type.hasRank()) {
-        op->emitOpError() << "failed level check: unranked tensor";
-        return false;
-      }
-      if (type.getRank() > highest_rank) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " rank(shape) <= MAX_RANK";
-        return false;
-      }
+      if (!type.hasRank())
+        return op->emitOpError() << "failed level check: unranked tensor";
+      if (type.getRank() > highest_rank)
+        return op->emitOpError() << "failed level check: " << operandOrResult
+                                 << " rank(shape) <= MAX_RANK";
     }
-    return true;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor value.
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Value &v,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
   }
 
   // Perform the Level tensor size check on the tensor type.
-  bool levelCheckSize(Operation *op, const Type &typeToCheck,
-                      const StringRef operandOrResult);
+  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+                               const StringRef operandOrResult);
 
   // Perform the Level tensor size check on the tensor value.
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult) {
+  LogicalResult levelCheckSize(Operation *op, const Value &v,
+                               const StringRef operandOrResult) {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
   // Level check sizes of all operands and results of the operation.
   template <typename T>
-  bool levelCheckSizes(T tosaOp) {
+  LogicalResult levelCheckSizes(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckSize(op, v, "operand"))
-        return false;
+      if (failed(levelCheckSize(op, v, "operand")))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckSize(op, v, "result"))
-        return false;
+      if (failed(levelCheckSize(op, v, "result")))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks of all operands, attribute and results of the operation.
   template <typename T>
-  bool levelCheckRanks(T tosaOp) {
+  LogicalResult levelCheckRanks(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks and sizes.
-  bool levelCheckRanksAndSizes(Operation *op);
+  LogicalResult levelCheckRanksAndSizes(Operation *op);
 
   // Pool Op: level check kernel/stride/pad values
   template <typename T>
-  bool levelCheckPool(Operation *op) {
+  LogicalResult levelCheckPool(Operation *op) {
     if (auto poolOp = dyn_cast<T>(op)) {
       for (auto k : poolOp.getKernel()) {
-        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : poolOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       for (auto p : poolOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Conv Op: level check dilation/stride/pad values
   template <typename T>
-  bool levelCheckConv(Operation *op) {
+  LogicalResult levelCheckConv(Operation *op) {
     if (auto convOp = dyn_cast<T>(op)) {
 
       for (auto k : convOp.getDilation()) {
-        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : convOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : convOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         if (isa<tosa::Conv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::Conv3DOp>(op)) {
           assert(shape.size() == 5);
           assert(dilation.size() == 3);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_d * KD <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[2] * shape[3],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_d * KD <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[2] * shape[3],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[0],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[1],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[1],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
-  bool levelCheckFFT(Operation *op) {
+  LogicalResult levelCheckFFT(Operation *op) {
     if (isa<T>(op)) {
       for (auto v : op->getOperands()) {
         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
           auto shape = type.getShape();
           assert(shape.size() == 3);
-          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
-              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
-            return false;
+          if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+              failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+            return failure();
           }
         }
       }
     }
-    return true;
+    return success();
   }
 
   // TransposeConv2d op: level check kH/kW, outpad, and stride
-  bool levelCheckTransposeConv2d(Operation *op) {
+  LogicalResult levelCheckTransposeConv2d(Operation *op) {
     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
       if (ShapedType filterType =
               dyn_cast<ShapedType>(transpose.getWeight().getType())) {
         auto shape = filterType.getShape();
         assert(shape.size() == 4);
         // level check kernel sizes for kH and KW
-        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
-            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+            failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : transpose.getOutPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : transpose.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Resize op: level check max scales
-  bool levelCheckResize(Operation *op) {
+  LogicalResult levelCheckResize(Operation *op) {
     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
       SmallVector<int64_t> scale;
       if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                      scale)) {
-        return false;
+        return failure();
       }
       const int64_t scaleYN = scale[0];
       const int64_t scaleYD = scale[1];
       const int64_t scaleXN = scale[2];
       const int64_t scaleXD = scale[3];
-      if (!levelCheckScale(op, scaleYN / scaleYD,
-                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
-          !levelCheckScale(op, scaleXN / scaleXD,
-                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
-        return false;
+      if (failed(levelCheckScale(op, scaleYN / scaleYD,
+                                 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+          failed(levelCheckScale(op, scaleXN / scaleXD,
+                                 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     getMaxNestedDepth(op, depth);
   }
 
-  bool levelCheckMaxNesting(Operation *op) {
+  LogicalResult levelCheckMaxNesting(Operation *op) {
     int32_t maxNestedDepth = 0;
     getMaxNestedDepth(op, maxNestedDepth);
 
     if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
       op->emitOpError() << "failed level check: " << maxNestedDepth
                         << " >= MAX_NESTING";
-      return false;
+      return failure();
     }
-    return true;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op) {
+  LogicalResult levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
     }
     if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
-      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
-          !levelCheckListSize(op, custom.getOutputList().size(),
-                              "output_list")) {
-        return false;
+      if (failed(levelCheckListSize(op, custom.getInputList().size(),
+                                    "input_list")) ||
+          failed(levelCheckListSize(op, custom.getOutputList().size(),
+                                    "output_list"))) {
+        return failure();
       }
     }
     if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(
+              levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, condIf.getOutputList().size(),
+                                    "outputs"))) {
+        return failure();
       }
     }
     if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
-  bool attributeCheckRescale(Operation *op) {
+  LogicalResult attributeCheckRescale(Operation *op) {
     if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
       if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
           !targetEnv.allows(Extension::doubleround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = DOUBLE_ROUND "
             << "requires extension [doubleround]";
-        return false;
+        return failure();
       }
       if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
           !targetEnv.allows(Extension::inexactround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = INEXACT_ROUND "
             << "requires extension [inexactround]";
-        return false;
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }
 
-  bool CheckVariable(Operation *op);
-  bool CheckVariableReadOrWrite(Operation *op);
+  LogicalResult CheckVariable(Operation *op);
+  LogicalResult CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 };
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
   auto op = tosaOp.getOperation();
-  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(
+          levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+    return failure();
 
   // rank(output) = rank(input) - 1
-  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+                            tosaLevel.MAX_RANK - 1)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   auto op = tosaOp.getOperation();
 
   // Only the condition input has rank limitation.
-  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, variableType, "variable type",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckSize(op, variableType, "variable type"))
-    return false;
+  if (failed(levelCheckSize(op, variableType, "variable type")))
+    ret...
[truncated]

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I find this much better :)

@lhutton1 lhutton1 force-pushed the bool-to-logical-result branch 2 times, most recently from 798879f to 9e67b51 Compare September 24, 2025 08:33
Copy link

github-actions bot commented Sep 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Tai78641
Copy link
Contributor

LGTM

This commit replaces functions that previously returned `bool`
to indicate validation success or failure with `LogicalResult`.

Change-Id: Iec3b54e3cc5462e981e1e9eb8639608c62a128ed
@lhutton1 lhutton1 force-pushed the bool-to-logical-result branch from 9e67b51 to 32e5a11 Compare September 24, 2025 17:00
@lhutton1 lhutton1 merged commit 4b99547 into llvm:main Sep 24, 2025
9 checks passed
@lhutton1 lhutton1 deleted the bool-to-logical-result branch September 24, 2025 18:31
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
This commit replaces functions that previously returned `bool` to
indicate validation success or failure with `LogicalResult`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants