[MLIR] Add options to generate-runtime-verification to enable faster pass running #160331
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-linalg

Author: Hanchenng Wu (HanchengWu)

Changes

The pass generate-runtime-verification generates additional runtime op verification checks.

Currently, the pass is extremely expensive. For example, with a MobileNet v2 SSD network (converted to MLIR), running this pass alone takes 30 minutes. The same behavior has been observed on other networks as small as 5 Mb.

The culprit is the line "op->print(stream, flags);" in the function "RuntimeVerifiableOpInterface::generateErrorMessage" in the file mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp.

Because the op is printed with all of its operand names in the middle end, each op->print(...) call constructs a new SSANameState. As a result, a new SSA analysis runs for every error message that is generated. Perf profiling shows that 98% of the time is spent in the constructor of SSANameState.

This change adds a verbose-level option to the generate-runtime-verification pass. Verbose level 2 is the current behavior and is very expensive; it remains the default.

Switching from verbose level 2 to verbose level 0 or 1 gives the following improvement: for MLIR imported from MobileNet v2 SSD, the running time of the pass drops from 32 minutes to 21 seconds.

Full diff: https://github.com/llvm/llvm-project/pull/160331.diff

8 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index 6fd0df59d9d2e..e5c9336c8d8dc 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -32,14 +32,16 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
/*retTy=*/"void",
/*methodName=*/"generateRuntimeVerification",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::Location":$loc)
+ "::mlir::Location":$loc,
+ "unsigned":$verboseLevel)
>,
];
let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
- static std::string generateErrorMessage(Operation *op, const std::string &msg);
+ static std::string generateErrorMessage(Operation *op, const std::string &msg,
+ unsigned verboseLevel = 0);
}];
}
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..58ba0892df113 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..3d643d8a168db 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -271,8 +271,16 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
passes that are suspected to introduce faulty IR.
}];
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+ let options = [
+ Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"2",
+ "Verbosity level for runtime verification messages: "
+ "0 = Minimum (only source location), "
+ "1 = Basic (include operation type and operand type), "
+ "2 = Detailed (include full operation details, names, types, shapes, etc.)">
+ ];
}
+
def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index b30182dc84079..608a6801af267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -32,7 +32,7 @@ struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -73,7 +73,8 @@ struct StructuredOpInterface
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
- std::to_string(opOperand.getOperandNumber()));
+ std::to_string(opOperand.getOperandNumber()),
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
// Generate:
@@ -104,7 +105,8 @@ struct StructuredOpInterface
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
- " is incompatible with inferred dimension size");
+ " is incompatible with inferred dimension size",
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index cd92026562da9..d8a7a89a3fbe7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -39,7 +39,7 @@ struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
loc, assumeOp.getMemref());
@@ -53,7 +53,8 @@ struct AssumeAlignmentOpInterface
loc, isAligned,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "memref is not aligned to " +
- std::to_string(assumeOp.getAlignment())));
+ std::to_string(assumeOp.getAlignment()),
+ verboseLevel));
}
};
@@ -61,7 +62,7 @@ struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
@@ -79,8 +80,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
builder.create<cf::AssertOp>(
loc, isSameRank,
- RuntimeVerifiableOpInterface::generateErrorMessage(op,
- "rank mismatch"));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "rank mismatch", verboseLevel));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -119,7 +120,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "size mismatch of dim " + std::to_string(it.index())));
+ op, "size mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
// Get result offset and strides.
@@ -139,7 +141,7 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset mismatch"));
+ op, "offset mismatch", verboseLevel));
}
// Check strides.
@@ -157,7 +159,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "stride mismatch of dim " + std::to_string(it.index())));
+ op, "stride mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
}
};
@@ -166,7 +169,7 @@ struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -201,7 +204,7 @@ struct CopyOpInterface
loc, sameDimSize,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size of " + std::to_string(i) +
- "-th source/target dim does not match"));
+ "-th source/target dim does not match", verboseLevel));
}
}
};
@@ -210,14 +213,14 @@ struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto dimOp = cast<DimOp>(op);
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
builder.create<cf::AssertOp>(
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "index is out of bounds"));
+ op, "index is out of bounds", verboseLevel));
}
};
@@ -228,7 +231,7 @@ struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto loadStoreOp = cast<LoadStoreOp>(op);
auto memref = loadStoreOp.getMemref();
@@ -251,7 +254,7 @@ struct LoadStoreOpInterface
builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "out-of-bounds access"));
+ op, "out-of-bounds access", verboseLevel));
}
};
@@ -295,7 +298,7 @@ struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
@@ -323,7 +326,8 @@ struct ReinterpretCastOpInterface
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op,
- "result of reinterpret_cast is out-of-bounds of the base memref"));
+ "result of reinterpret_cast is out-of-bounds of the base memref",
+ verboseLevel));
}
};
@@ -331,7 +335,7 @@ struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto subView = cast<SubViewOp>(op);
MemRefType sourceType = subView.getSource().getType();
@@ -357,7 +361,7 @@ struct SubViewOpInterface
builder.create<cf::AssertOp>(
loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset " + std::to_string(i) + " is out-of-bounds"));
+ op, "offset " + std::to_string(i) + " is out-of-bounds", verboseLevel));
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
@@ -371,7 +375,7 @@ struct SubViewOpInterface
loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
- std::to_string(i)));
+ std::to_string(i), verboseLevel));
}
}
};
@@ -380,7 +384,7 @@ struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -414,7 +418,7 @@ struct ExpandShapeOpInterface
loc, isModZero,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ "divide src dim evenly", verboseLevel));
}
}
};
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 8aa194befb420..8b54ed1dc3780 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -15,7 +15,7 @@ class OpBuilder;
/// Generate an error message string for the given op and the specified error.
std::string
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
- const std::string &msg) {
+ const std::string &msg, unsigned verboseLevel) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
@@ -26,9 +26,25 @@ RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
flags.skipRegions();
flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
- op->print(stream, flags);
- stream << "\n^ " << msg;
- stream << "\nLocation: ";
+ if (verboseLevel == 2) {
+ // Print the full op, including operand names. This is very expensive.
+ op->print(stream, flags);
+ stream << "\n " << msg;
+ } else if (verboseLevel == 1) {
+ // Print the op name, operand types, and result types.
+ stream << "Op: " << op->getName().getStringRef() << "\n";
+ stream << "Operand Types:";
+ for (const auto &operand : op->getOpOperands()) {
+ stream << " " << operand.get().getType();
+ }
+ stream << "\n";
+ stream << "Result Types:";
+ for (const auto &result : op->getResults()) {
+ stream << " " << result.getType();
+ }
+ stream << "\n" << msg;
+ }
+ stream << "^\nLocation: ";
op->getLoc().print(stream);
return buffer;
}
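The three verbosity levels in the hunk above can be illustrated outside of MLIR with a small standalone sketch. Everything here is hypothetical: `MockOp` and `buildErrorMessage` are stand-ins invented for illustration, and the exact line layout is one possible choice rather than the output this patch produces. The point it tries to show is that the expensive full print is wrapped in a callback and only invoked at level 2.

```cpp
#include <cassert>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the parts of an operation that the message needs.
struct MockOp {
  std::string name;                        // e.g. "linalg.generic"
  std::vector<std::string> operandTypes;   // pre-rendered operand types
  std::vector<std::string> resultTypes;    // pre-rendered result types
  std::string location;                    // pre-rendered source location
  std::function<std::string()> printFull;  // models the costly full op print
};

// Builds an error message at the requested verbosity level (0, 1 or 2).
// Assumption: level 0 carries only the header and the location, as the
// option description "only source location" suggests.
std::string buildErrorMessage(const MockOp &op, const std::string &msg,
                              unsigned verboseLevel) {
  assert(verboseLevel <= 2 && "verbose level must be 0, 1 or 2");
  std::ostringstream os;
  os << "ERROR: Runtime op verification failed\n";
  if (verboseLevel == 2) {
    // Only this branch pays for the full print.
    os << op.printFull() << "\n^ " << msg << "\n";
  } else if (verboseLevel == 1) {
    // Cheap summary: op name plus operand and result types.
    os << "Op: " << op.name << "\n";
    os << "Operand Types:";
    for (const std::string &type : op.operandTypes)
      os << " " << type;
    os << "\nResult Types:";
    for (const std::string &type : op.resultTypes)
      os << " " << type;
    os << "\n^ " << msg << "\n";
  }
  os << "Location: " << op.location;
  return os.str();
}
```

With this shape, a caller that requests level 0 or 1 never triggers the callback, which mirrors how the lower levels in this patch avoid the op->print(...) call.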
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..7a54ce667c6ad 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,6 +28,14 @@ struct GenerateRuntimeVerificationPass
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
+ // Check verboseLevel is in range [0, 2].
+ if (verboseLevel > 2) {
+ getOperation()->emitError(
+ "generate-runtime-verification pass: set verboseLevel to 0, 1 or 2");
+ signalPassFailure();
+ return;
+ }
+
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
// performs verification, so gather all runtime-verifiable ops first.
@@ -39,7 +47,8 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
- verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+ verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
+ verboseLevel);
};
}
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
index a4f29d8457e58..238169adf496e 100644
--- a/mlir/test/Dialect/Linalg/runtime-verification.mlir
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -1,13 +1,25 @@
// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+// RUN: mlir-opt %s --generate-runtime-verification="verbose-level=1" | FileCheck %s --check-prefix=VERBOSE1
+// RUN: mlir-opt %s --generate-runtime-verification="verbose-level=0" | FileCheck %s --check-prefix=VERBOSE0
// Most of the tests for linalg runtime-verification are implemented as integration tests.
#identity = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @static_dims
+// VERBOSE1-LABEL: @static_dims
+// VERBOSE0-LABEL: @static_dims
func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
// CHECK: %[[TRUE:.*]] = index.bool.constant true
// CHECK: cf.assert %[[TRUE]]
+ // VERBOSE1: %[[TRUE:.*]] = index.bool.constant true
+ // VERBOSE1: cf.assert %[[TRUE]]
+ // VERBOSE1: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE1: Result Types
+ // VERBOSE1: Location: loc
+ // VERBOSE0-NOT: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE0-NOT: Result Types
+ // VERBOSE0: Location: loc
%result = tensor.empty() : tensor<5xf32>
%0 = linalg.generic {
indexing_maps = [#identity, #identity, #identity],
@@ -26,9 +38,11 @@ func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5x
#map = affine_map<() -> ()>
// CHECK-LABEL: @scalars
+// VERBOSE1-LABEL: @scalars
func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// No runtime checks are required if the operands are all scalars
// CHECK-NOT: cf.assert
+ // VERBOSE1-NOT: cf.assert
%result = tensor.empty() : tensor<f32>
%0 = linalg.generic {
indexing_maps = [#map, #map, #map],
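The cost described in the PR summary (a fresh SSA naming analysis for every printed op) can be modelled with a toy counter. This is only a sketch under stated assumptions: `numberValues` is a hypothetical stand-in for the naming analysis, and it assumes each print visits every value in the enclosing scope. The shared-table variant is included purely for comparison; this patch does not take that route and instead skips the full print at verbose levels 0 and 1.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical model: a scope is a list of value ids, and "numbering"
// assigns each value a printable name such as %0, %1, ...
using NameTable = std::unordered_map<int, std::string>;

// Stand-in for the naming analysis: visits every value in the scope and
// bumps `visited` once per value so the total work can be compared.
NameTable numberValues(const std::vector<int> &values, std::size_t &visited) {
  NameTable names;
  names.reserve(values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    names.emplace(values[i], "%" + std::to_string(i));
    ++visited;
  }
  return names;
}

// A fresh numbering for every printed value: total visits grow as n * n.
std::size_t visitsWhenRenumberingPerPrint(const std::vector<int> &values) {
  std::size_t visited = 0;
  for (int value : values) {
    NameTable names = numberValues(values, visited);
    assert(names.count(value) == 1);
  }
  return visited;
}

// A single numbering shared by all prints: total visits grow as n.
std::size_t visitsWhenNumberingOnce(const std::vector<int> &values) {
  std::size_t visited = 0;
  NameTable names = numberValues(values, visited);
  for (int value : values)
    assert(names.count(value) == 1);
  return visited;
}
```

Under this model, a scope with 1000 values costs 1,000,000 visits when every print renumbers and 1000 visits when the numbering is done once, which is consistent in spirit with the drop from minutes to seconds reported above.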
+ // VERBOSE1: cf.assert %[[TRUE]]
+ // VERBOSE1: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE1: Result Types
+ // VERBOSE1: Location: loc
+ // VERBOSE0-NOT: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE0-NOT: Result Types
+ // VERBOSE0: Location: loc
%result = tensor.empty() : tensor<5xf32>
%0 = linalg.generic {
indexing_maps = [#identity, #identity, #identity],
@@ -26,9 +38,11 @@ func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5x
#map = affine_map<() -> ()>
// CHECK-LABEL: @scalars
+// VERBOSE1-LABEL: @scalars
func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// No runtime checks are required if the operands are all scalars
// CHECK-NOT: cf.assert
+ // VERBOSE1-NOT: cf.assert
%result = tensor.empty() : tensor<f32>
%0 = linalg.generic {
indexing_maps = [#map, #map, #map],
|
@llvm/pr-subscribers-mlir-memref Author: Hanchenng Wu (HanchengWu)
Changes
The pass generate-runtime-verification generates additional runtime op verification checks. Currently, the pass is extremely expensive. For example, with a MobileNet v2 SSD network (converted to MLIR), running this pass alone takes 30 minutes. The same behavior has been observed with other networks as small as 5 Mb.
The culprit is the line "op->print(stream, flags);" in the function "RuntimeVerifiableOpInterface::generateErrorMessage" in the file mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp. As we are printing the op with all the names of the operands in the middle end, we are constructing a new SSANameState for each op->print(...) call. Thus, we are doing a new SSA analysis for each error message printed. Perf profiling shows that 98% of the time is spent in the constructor of SSANameState.
This change adds verbose options to the generate-runtime-verification pass. Verbose 2 is the current behavior and is very expensive. I still keep the default as verbose 2. When we switch from verbose 2 to verbose 0/1, we see the improvements below. For MLIR imported from MobileNet v2 SSD, the running time of the pass is reduced from 32 minutes to 21 seconds.
Full diff: https://github.com/llvm/llvm-project/pull/160331.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index 6fd0df59d9d2e..e5c9336c8d8dc 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -32,14 +32,16 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
/*retTy=*/"void",
/*methodName=*/"generateRuntimeVerification",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::Location":$loc)
+ "::mlir::Location":$loc,
+ "unsigned":$verboseLevel)
>,
];
let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
- static std::string generateErrorMessage(Operation *op, const std::string &msg);
+ static std::string generateErrorMessage(Operation *op, const std::string &msg,
+ unsigned verboseLevel = 0);
}];
}
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..58ba0892df113 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..3d643d8a168db 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -271,8 +271,16 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
passes that are suspected to introduce faulty IR.
}];
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+ let options = [
+ Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"2",
+ "Verbosity level for runtime verification messages: "
+ "0 = Minimum (only source location), "
+ "1 = Basic (include operation type and operand type), "
+ "2 = Detailed (include full operation details, names, types, shapes, etc.)">
+ ];
}
+
def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index b30182dc84079..608a6801af267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -32,7 +32,7 @@ struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -73,7 +73,8 @@ struct StructuredOpInterface
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
- std::to_string(opOperand.getOperandNumber()));
+ std::to_string(opOperand.getOperandNumber()),
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
// Generate:
@@ -104,7 +105,8 @@ struct StructuredOpInterface
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
- " is incompatible with inferred dimension size");
+ " is incompatible with inferred dimension size",
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index cd92026562da9..d8a7a89a3fbe7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -39,7 +39,7 @@ struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
loc, assumeOp.getMemref());
@@ -53,7 +53,8 @@ struct AssumeAlignmentOpInterface
loc, isAligned,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "memref is not aligned to " +
- std::to_string(assumeOp.getAlignment())));
+ std::to_string(assumeOp.getAlignment()),
+ verboseLevel));
}
};
@@ -61,7 +62,7 @@ struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
@@ -79,8 +80,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
builder.create<cf::AssertOp>(
loc, isSameRank,
- RuntimeVerifiableOpInterface::generateErrorMessage(op,
- "rank mismatch"));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "rank mismatch", verboseLevel));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -119,7 +120,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "size mismatch of dim " + std::to_string(it.index())));
+ op, "size mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
// Get result offset and strides.
@@ -139,7 +141,7 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset mismatch"));
+ op, "offset mismatch", verboseLevel));
}
// Check strides.
@@ -157,7 +159,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "stride mismatch of dim " + std::to_string(it.index())));
+ op, "stride mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
}
};
@@ -166,7 +169,7 @@ struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -201,7 +204,7 @@ struct CopyOpInterface
loc, sameDimSize,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size of " + std::to_string(i) +
- "-th source/target dim does not match"));
+ "-th source/target dim does not match", verboseLevel));
}
}
};
@@ -210,14 +213,14 @@ struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto dimOp = cast<DimOp>(op);
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
builder.create<cf::AssertOp>(
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "index is out of bounds"));
+ op, "index is out of bounds", verboseLevel));
}
};
@@ -228,7 +231,7 @@ struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto loadStoreOp = cast<LoadStoreOp>(op);
auto memref = loadStoreOp.getMemref();
@@ -251,7 +254,7 @@ struct LoadStoreOpInterface
builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "out-of-bounds access"));
+ op, "out-of-bounds access", verboseLevel));
}
};
@@ -295,7 +298,7 @@ struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
@@ -323,7 +326,8 @@ struct ReinterpretCastOpInterface
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op,
- "result of reinterpret_cast is out-of-bounds of the base memref"));
+ "result of reinterpret_cast is out-of-bounds of the base memref",
+ verboseLevel));
}
};
@@ -331,7 +335,7 @@ struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto subView = cast<SubViewOp>(op);
MemRefType sourceType = subView.getSource().getType();
@@ -357,7 +361,7 @@ struct SubViewOpInterface
builder.create<cf::AssertOp>(
loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset " + std::to_string(i) + " is out-of-bounds"));
+ op, "offset " + std::to_string(i) + " is out-of-bounds", verboseLevel));
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
@@ -371,7 +375,7 @@ struct SubViewOpInterface
loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
- std::to_string(i)));
+ std::to_string(i), verboseLevel));
}
}
};
@@ -380,7 +384,7 @@ struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -414,7 +418,7 @@ struct ExpandShapeOpInterface
loc, isModZero,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ "divide src dim evenly", verboseLevel));
}
}
};
|
That seems to point to a caching issue to me; should we start there first? |
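To make the caching concern concrete, here is a minimal, self-contained sketch of the cost pattern described in the PR summary. It is a toy model, not MLIR code: `ToyValue`, `buildNameTable`, `printAllUncached` and `printAllCached` are hypothetical names, and the "visited" counter is only a proxy for the work done when a name table is rebuilt for every printed message versus built once and shared.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an IR value; this is not an MLIR type.
struct ToyValue {
  int id;
};

// Counts how many values were visited while numbering, as a proxy for cost.
struct NumberingCost {
  std::size_t visited = 0;
};

// Builds a value -> name table by walking every value. This loosely models
// the work of constructing a fresh name state before printing.
std::unordered_map<int, std::string>
buildNameTable(const std::vector<ToyValue> &values, NumberingCost &cost) {
  std::unordered_map<int, std::string> names;
  for (std::size_t i = 0; i < values.size(); ++i) {
    names[values[i].id] = "%" + std::to_string(i);
    ++cost.visited;
  }
  return names;
}

// Uncached: every message rebuilds the table, so total work is quadratic.
std::size_t printAllUncached(const std::vector<ToyValue> &values) {
  NumberingCost cost;
  for (const ToyValue &v : values) {
    auto names = buildNameTable(values, cost);
    assert(names.count(v.id) == 1);
    (void)v;
  }
  return cost.visited;
}

// Cached: the table is built once and reused, so total work is linear.
std::size_t printAllCached(const std::vector<ToyValue> &values) {
  NumberingCost cost;
  auto names = buildNameTable(values, cost);
  for (const ToyValue &v : values) {
    assert(names.count(v.id) == 1);
    (void)v;
  }
  return cost.visited;
}
```

With 100 values, the uncached variant visits 100 x 100 entries while the cached one visits 100. Whether a real cached printer state stays valid while the pass mutates the IR is a separate question that this toy model does not answer.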
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::Location":$loc)
+ "::mlir::Location":$loc,
+ "unsigned":$verboseLevel)
- "unsigned":$verboseLevel)
+ "function_ref<std::string(Operation *, StringRef msg)>":$generateErrorMessage)
Can we just use injection here?
That will actually allow the customer to control this all however they want.
Hi Mehdi,
Thanks for your feedback. In the latest commit, I've updated the implementation so that each operation implementing RuntimeVerifiableOpInterface receives a function_ref callback through the generateRuntimeVerification method. The function_ref is default-initialized based on the verbosity options, but clients can override it with custom implementations as needed.
The changes maintain backward compatibility while providing the flexibility you suggested for different verification strategies.
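The injection approach described above can be sketched roughly as follows. This is a simplified, self-contained model rather than the PR's code: `ToyOp`, `ErrorMessageFn` and `makeDefaultGenerator` are hypothetical names, `std::function` stands in for `function_ref`, the message layout is simplified, and verbosity levels 1 and 2 are collapsed into one case.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-in for mlir::Operation; only what the sketch needs.
struct ToyOp {
  std::string name;
  std::string location;
};

// Stand-in for function_ref<std::string(Operation *, StringRef)>.
using ErrorMessageFn =
    std::function<std::string(const ToyOp &, const std::string &)>;

// Picks a default generator from the verbosity level. Levels 1 and 2 are
// treated alike here; the real pass distinguishes them.
ErrorMessageFn makeDefaultGenerator(unsigned verboseLevel) {
  assert(verboseLevel <= 2 && "verbose level must be 0, 1 or 2");
  return [verboseLevel](const ToyOp &op, const std::string &msg) {
    std::string out = "ERROR: Runtime op verification failed\n";
    if (verboseLevel >= 1)
      out += "Op: " + op.name + "\n";
    out += msg + "\nLocation: " + op.location;
    return out;
  };
}

// An op's verification hook receives the generator instead of a level, so
// the caller decides how messages are built.
std::string
generateRuntimeVerification(const ToyOp &op,
                            const ErrorMessageFn &generateErrorMessage) {
  return generateErrorMessage(op, "out-of-bounds access");
}
```

A client that wants a different format passes its own callable, for example a lambda returning `"custom: " + msg`, without touching any op implementation.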
5eac44b to 3b8f7dd
@@ -36,10 +44,46 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
ops.push_back(verifiableOp);
});

// Create error message generator based on verboseLevel
Can we extract an AsmState here, set the flags, and re-use it?
I'm curious how it would help the performance issue you reported.
The rest of the change LG I think.
Hi Mehdi,
Thanks for your guidance. I'm relatively new to LLVM/MLIR and have a question about AsmState usage in op->print().
Looking at the current implementation, the pass first collects all RuntimeVerifiableOpInterface ops to avoid generating verification for IR that itself performs verification. Then it iterates through these ops, generating verification code for each.
My concern/puzzle is about SSA name state consistency: since each call to generateRuntimeVerification() modifies the IR by injecting verification code, the SSA name mappings may become stale between errorMsgGenerator invocations. My understanding is that AsmState/SSANameState needs to be reconstructed after IR modifications to maintain accurate SSA name printing.
Again, I could be wrong, and there might be ways to reuse AsmState. Let me know what you think, and I can give it a try.
It's an excellent point, but isn't it an issue of the current pass that we're reconstructing the SSA numbering between each application of the instrumentation?
I would think the user would likely want to match the message to the IR as printed before the pass instead of some state that only exists in the middle of the application of the pass?
That said we can definitely land this PR as-is, to not affect the behavior too much, and then try to land an update where we construct an AsmState and cache it for the duration of the pass?
Hi, Mehdi
I agree with your observation on the current behavior. See my answers below.
-isn't it an issue of the current pass that we're reconstructing the SSA numbering between each application of the instrumentation?
I think my previous understanding of how AsmState works was wrong. Currently, each op->print() call will create a new AsmState for each op's assert/error message, and I think the SSA name created by this AsmState will be stored to the op and used by the emitter later for the op. But it seems that this is not the case. In the MLIR file dumped after the pass, the operand names in the assert/error message do seem to be names that only exist in the middle of the application of the pass.
-I would think the user would likely want to match the message to the IR as printed before the pass instead of some state that only exists in the middle of the application of the pass?
Similar to above - the operand names dumped in the assert/error message do seem to be some names that only exist in the middle of the application of the pass. However, the location that gets printed as part of the error message does point to the original location. So, if the input is a textual MLIR file, the error message can still direct the user to the original line of the instruction for which the error is thrown.
-That said we can definitely land this PR as-is, to not affect the behavior too much, and then try to land an update where we construct an AsmState and cache it for the duration of the pass?
Thanks for your thoughts on this. I have submitted a new version that's rebased to the tip of LLVM. Can you help move this forward?
In addition, I took a look at AsmState definition. It says "The IR should not be mutated in-between invocations using this state, and the IR being printed must not be an parent of the IR originally used to initialize this state. " Since we keep adding new code for verification purposes throughout the pass, reusing the same AsmState for the duration of the pass does not seem feasible. I will look into this and update you later on this issue.
> I think the SSA name created by this AsmState will be stored to the op and used by the emitter later for the op. But it seems that this is not the case.
I don't follow what you mean here?
Especially because it seems to contradict what you said before:
> Similar to above - the operand names dumped in the assert/error message do seem to be some names that only exist in the middle of the application of the pass.
> In addition, I took a look at AsmState definition. It says "The IR should not be mutated in-between invocations using this state
I need to look at exactly what kind of issues this is trying to protect against. I suspect it is a concern about state invalidation and stale pointers, which may be safe in the context of what we're doing here.
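The naming mismatch discussed in this thread can be illustrated with a small toy model. It is only a sketch under stated assumptions: `ToyBlock`, `numberOps` and `instrumentAndName` are hypothetical names, ops are plain strings, and names are assigned by position. It shows that a numbering snapshot taken before instrumentation matches the input IR, while a numbering computed after check ops are inserted does not. It makes no claim about whether reusing a real AsmState across IR mutations is safe.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy "IR": an ordered list of op identifiers. Hypothetical, not MLIR.
using ToyBlock = std::vector<std::string>;

// Assigns %0, %1, ... in block order, like a printer numbering results.
std::map<std::string, std::string> numberOps(const ToyBlock &block) {
  std::map<std::string, std::string> names;
  for (std::size_t i = 0; i < block.size(); ++i)
    names[block[i]] = "%" + std::to_string(i);
  return names;
}

// The name of one op under both strategies.
struct Names {
  std::string snapshot;   // numbering taken before any mutation
  std::string renumbered; // numbering of the instrumented block
};

// Inserts one check op before every original op, then reports how `target`
// is named by the pre-pass snapshot and by a fresh numbering.
Names instrumentAndName(ToyBlock block, const std::string &target) {
  const auto snapshot = numberOps(block);
  assert(snapshot.count(target) == 1 && "target must be in the block");
  const ToyBlock original = block;
  for (const std::string &op : original) {
    auto it = std::find(block.begin(), block.end(), op);
    block.insert(it, "check_" + op);
  }
  const auto renumbered = numberOps(block);
  return {snapshot.at(target), renumbered.at(target)};
}
```

For a block `{"a", "b", "c"}`, the snapshot names `c` as `%2`, matching the input, while the fresh numbering of the instrumented block names it `%5`.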
|
Apologies, I accidentally clicked the close button (the screen is a bit laggy in vncviewer) and have reopened the PR. Please see my previous answer above. |
You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions h,cpp -- mlir/include/mlir/Transforms/Passes.h mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp mlir/lib/Transforms/GenerateRuntimeVerification.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 6c45bccb2..15eb51a6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -31,8 +31,10 @@ template <typename T>
struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 66d4df230..291da1f76 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
assumeOp.getMemref());
@@ -48,9 +50,9 @@ struct AssumeAlignmentOpInterface
Value isAligned =
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
arith::ConstantIndexOp::create(builder, loc, 0));
- cf::AssertOp::create(builder, loc, isAligned,
- generateErrorMessage(
- op, "memref is not aligned to " +
+ cf::AssertOp::create(
+ builder, loc, isAligned,
+ generateErrorMessage(op, "memref is not aligned to " +
std::to_string(assumeOp.getAlignment())));
}
};
@@ -58,8 +60,10 @@ struct AssumeAlignmentOpInterface
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
@@ -115,8 +119,8 @@ struct CastOpInterface
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
cf::AssertOp::create(
builder, loc, isSameSz,
- generateErrorMessage(
- op, "size mismatch of dim " + std::to_string(it.index())));
+ generateErrorMessage(op, "size mismatch of dim " +
+ std::to_string(it.index())));
}
// Get result offset and strides.
@@ -151,8 +155,8 @@ struct CastOpInterface
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
cf::AssertOp::create(
builder, loc, isSameStride,
- generateErrorMessage(
- op, "stride mismatch of dim " + std::to_string(it.index())));
+ generateErrorMessage(op, "stride mismatch of dim " +
+ std::to_string(it.index())));
}
}
};
@@ -160,8 +164,10 @@ struct CastOpInterface
struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -191,9 +197,9 @@ struct CopyOpInterface
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
Value sameDimSize = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
- cf::AssertOp::create(builder, loc, sameDimSize,
- generateErrorMessage(
- op, "size of " + std::to_string(i) +
+ cf::AssertOp::create(
+ builder, loc, sameDimSize,
+ generateErrorMessage(op, "size of " + std::to_string(i) +
"-th source/target dim does not match"));
}
}
@@ -202,8 +208,10 @@ struct CopyOpInterface
struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto dimOp = cast<DimOp>(op);
Value rank = RankOp::create(builder, loc, dimOp.getSource());
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
@@ -220,8 +228,10 @@ template <typename LoadStoreOp>
struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto loadStoreOp = cast<LoadStoreOp>(op);
auto memref = loadStoreOp.getMemref();
@@ -249,8 +259,10 @@ struct LoadStoreOpInterface
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto subView = cast<SubViewOp>(op);
MemRefType sourceType = subView.getSource().getType();
@@ -273,10 +285,10 @@ struct SubViewOpInterface
Value dimSize = metadataOp.getSizes()[i];
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(
- builder, loc, offsetInBounds,
- generateErrorMessage(
- op, "offset " + std::to_string(i) + " is out-of-bounds"));
+ cf::AssertOp::create(builder, loc, offsetInBounds,
+ generateErrorMessage(op, "offset " +
+ std::to_string(i) +
+ " is out-of-bounds"));
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
@@ -288,9 +300,9 @@ struct SubViewOpInterface
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
cf::AssertOp::create(
builder, loc, lastPosInBounds,
- generateErrorMessage(
- op, "subview runs out-of-bounds along dimension " +
- std::to_string(i)));
+ generateErrorMessage(op,
+ "subview runs out-of-bounds along dimension " +
+ std::to_string(i)));
}
}
};
@@ -298,8 +310,10 @@ struct SubViewOpInterface
struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -329,9 +343,9 @@ struct ExpandShapeOpInterface
Value isModZero = arith::CmpIOp::create(
builder, loc, arith::CmpIPredicate::eq, mod,
arith::ConstantIndexOp::create(builder, loc, 0));
- cf::AssertOp::create(builder, loc, isModZero,
- generateErrorMessage(
- op, "static result dims in reassoc group do not "
+ cf::AssertOp::create(
+ builder, loc, isModZero,
+ generateErrorMessage(op, "static result dims in reassoc group do not "
"divide src dim evenly"));
}
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index fd51f5010..c03111860 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -35,8 +35,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<TensorType>(castOp.getSource().getType());
@@ -75,8 +77,8 @@ struct CastOpInterface
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
cf::AssertOp::create(
builder, loc, isSameSz,
- generateErrorMessage(
- op, "size mismatch of dim " + std::to_string(it.index())));
+ generateErrorMessage(op, "size mismatch of dim " +
+ std::to_string(it.index())));
}
}
};
@@ -84,8 +86,10 @@ struct CastOpInterface
struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto dimOp = cast<DimOp>(op);
Value rank = RankOp::create(builder, loc, dimOp.getSource());
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
@@ -102,8 +106,10 @@ template <typename OpTy>
struct ExtractInsertOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ExtractInsertOpInterface<OpTy>, OpTy> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto extractInsertOp = cast<OpTy>(op);
Value tensor;
@@ -140,8 +146,10 @@ struct ExtractInsertOpInterface
struct ExtractSliceOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ExtractSliceOpInterface, ExtractSliceOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
+ void
+ generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
+ function_ref<std::string(Operation *, StringRef)>
+ generateErrorMessage) const {
auto extractSliceOp = cast<ExtractSliceOp>(op);
RankedTensorType sourceType = extractSliceOp.getSource().getType();
@@ -163,10 +171,10 @@ struct ExtractSliceOpInterface
loc, extractSliceOp.getSource(), i);
Value offsetInBounds =
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
- cf::AssertOp::create(
- builder, loc, offsetInBounds,
- generateErrorMessage(
- op, "offset " + std::to_string(i) + " is out-of-bounds"));
+ cf::AssertOp::create(builder, loc, offsetInBounds,
+ generateErrorMessage(op, "offset " +
+ std::to_string(i) +
+ " is out-of-bounds"));
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index e3098a4c6..cfe531385 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -31,7 +31,7 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
// Check verboseLevel is in range [0, 2].
if (verboseLevel > 2) {
getOperation()->emitError(
- "generate-runtime-verification pass: set verboseLevel to 0, 1 or 2");
+ "generate-runtime-verification pass: set verboseLevel to 0, 1 or 2");
signalPassFailure();
return;
}
@@ -45,7 +45,8 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
});
// Create error message generator based on verboseLevel
- auto errorMsgGenerator = [vLevel = verboseLevel.getValue()](Operation *op, StringRef msg) -> std::string {
+ auto errorMsgGenerator = [vLevel = verboseLevel.getValue()](
+ Operation *op, StringRef msg) -> std::string {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
@@ -83,7 +84,7 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
- errorMsgGenerator);
+ errorMsgGenerator);
};
}
This LGTM, but can we get an upstream repro for:
I'd like to be able to benchmark this upstream to check how to improve this pass.
Operation *op, StringRef msg) -> std::string {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
Flags are used only for vLevel = 2; move them into the if check.
stream << "\n " << msg;
} else if (vLevel == 1) {
// print op name and operand types
stream << "Op: " << op->getName().getStringRef() << "\n";
As we're generating a lot of output, I think it's worth picking a shorter syntax. How about this (it fits on one line, looks almost like generic op syntax, and can be grepped easily):
"dialect.op_name"(...) : (operand types) -> (result_types) loc(location)
Something along the lines of:
stream << op->getName().getStringRef() << "(...) : (";
llvm::interleaveComma(TypeRange(op->getOperands()), stream);
stream << ") -> (";
llvm::interleaveComma(TypeRange(op->getResultTypes()), stream);
stream << ") loc(";
op->getLoc().print(stream);
stream << ")";
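To make the proposed one-line shape concrete, here is a minimal sketch in plain C++ with no MLIR dependency. `ToyOp`, `joinComma`, and `formatCompact` are hypothetical stand-ins (not MLIR APIs), and the types and location are stored as pre-printed strings. Unlike the snippet above, this sketch also emits the quotes around the op name, as shown in the format line.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation; only the fields needed here.
struct ToyOp {
  std::string name;                      // e.g. "memref.load"
  std::vector<std::string> operandTypes; // pre-printed operand types
  std::vector<std::string> resultTypes;  // pre-printed result types
  std::string loc;                       // pre-printed location
};

// Writes the items separated by ", " (a plain-C++ analogue of
// llvm::interleaveComma).
static void joinComma(const std::vector<std::string> &items,
                      std::ostream &os) {
  for (std::size_t i = 0; i < items.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << items[i];
  }
}

// Builds: "dialect.op_name"(...) : (operand types) -> (result types) loc(...)
std::string formatCompact(const ToyOp &op) {
  assert(!op.name.empty() && "op must have a name");
  std::ostringstream os;
  os << '"' << op.name << "\"(...) : (";
  joinComma(op.operandTypes, os);
  os << ") -> (";
  joinComma(op.resultTypes, os);
  os << ") loc(" << op.loc << ")";
  return os.str();
}
```

Since no operand names are printed, nothing in this shape needs SSA value numbering, which is the point of verbose levels 0 and 1.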
The pass generate-runtime-verification generates additional runtime op verification checks.
Currently, the pass is extremely expensive. For example, with a MobileNet v2 SSD network (converted to MLIR), running this pass alone takes 30 minutes. The same behavior has been observed with other networks as small as 5 MB.
The culprit is the line "op->print(stream, flags);" in the function "RuntimeVerifiableOpInterface::generateErrorMessage" in the file mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp.
Because we print the op with all of its operand names in the middle end, a new SSANameState is constructed for each op->print(...) call. Thus, a new SSA analysis runs for each error message printed.
Perf profiling shows that 98% of the time is spent in the constructor of SSANameState.
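As a rough illustration of why this scales so badly, the toy cost model below compares rebuilding a naming state for every printed op against building it once. This is only a sketch under stated assumptions: `ToyNameState` is a hypothetical stand-in that visits every value once when constructed, and it does not reproduce what SSANameState actually does internally.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the printer's naming state: constructing it
// visits every value once to assign names like %0, %1, ...
struct ToyNameState {
  std::unordered_map<int, std::string> names;
  std::size_t steps = 0; // number of values visited while building

  explicit ToyNameState(const std::vector<int> &allValues) {
    for (int v : allValues) {
      bool inserted =
          names.emplace(v, "%" + std::to_string(names.size())).second;
      assert(inserted && "values are assumed to be unique");
      (void)inserted;
      ++steps;
    }
  }
};

// Rebuilds the naming state for every printed op: N builds of N steps each.
std::size_t costRebuildPerPrint(const std::vector<int> &allValues) {
  std::size_t total = 0;
  for (std::size_t i = 0; i < allValues.size(); ++i) {
    ToyNameState state(allValues);
    total += state.steps;
  }
  return total;
}

// Builds the naming state once, then counts one lookup per printed op.
std::size_t costBuildOnce(const std::vector<int> &allValues) {
  ToyNameState state(allValues);
  return state.steps + allValues.size();
}
```

In this model, 1000 values cost 1,000,000 steps when the state is rebuilt per print and 2,000 steps when it is built once. The real ratio depends on the IR, but the quadratic growth is in line with the runtimes reported here.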
This change adds verbose options to the generate-runtime-verification pass.
verbose 0: print only source location with error message.
verbose 1: print source location, operation name, and operand types with error message.
verbose 2: print the full op, including the name of the operands.
Both verbose 0 and 1 avoid the expensive "op->print(...)" call that invokes the SSANameState analysis.
verbose 2 is the current behavior.
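The three levels can be sketched in plain C++ as follows. This is a hypothetical model, not the code in this PR: `ToyOp`, `printFullOp`, `makeErrorMessage`, and the exact message layout are assumptions made for illustration. The `fullPrintCalls` counter stands in for the expensive op->print(...) path so that it is easy to see which levels reach it.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical stand-in for an operation; all fields are pre-printed strings.
struct ToyOp {
  std::string name;         // e.g. "memref.load"
  std::string operandTypes; // e.g. "memref<?xf32>, index"
  std::string loc;          // source location
  std::string fullText;     // what a full print of the op would produce
};

// Counts how often the expensive full print is requested.
int fullPrintCalls = 0;

// Stand-in for the expensive op->print(...) call.
static std::string printFullOp(const ToyOp &op) {
  ++fullPrintCalls;
  return op.fullText;
}

// Builds an error message for the given verbose level (0, 1 or 2).
std::string makeErrorMessage(unsigned level, const ToyOp &op,
                             const std::string &msg) {
  assert(level <= 2 && "verbose level must be 0, 1 or 2");
  std::ostringstream os;
  if (level == 2) {
    // Full op, including operand names: the only branch that pays for
    // the full print.
    os << printFullOp(op) << "\n";
  } else if (level == 1) {
    // Op name and operand types only.
    os << "Op: " << op.name << "\nOperand types: " << op.operandTypes << "\n";
  }
  // Every level reports the message and the source location.
  os << msg << "\nLocation: " << op.loc;
  return os.str();
}
```

Calling `makeErrorMessage` with level 0 or 1 leaves `fullPrintCalls` unchanged; only level 2 increments it.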
When we switch from verbose 2 to verbose 0/1, we see the following improvements.
For MLIR imported from MobileNet v2 SSD, the running time of the pass is reduced from 32 minutes to 21 seconds.
For another small network (only 5 MB in size), the running time of the pass is reduced from 15 minutes to 4 seconds.