Skip to content

Commit 3b8f7dd

Browse files
committed
Add options to generate-runtime-verification to enable a faster pass
1 parent 69194be commit 3b8f7dd

File tree

9 files changed

+101
-71
lines changed

9 files changed

+101
-71
lines changed

mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,11 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
3232
/*retTy=*/"void",
3333
/*methodName=*/"generateRuntimeVerification",
3434
/*args=*/(ins "::mlir::OpBuilder &":$builder,
35-
"::mlir::Location":$loc)
35+
"::mlir::Location":$loc,
36+
"function_ref<std::string(Operation *, StringRef)>":$generateErrorMessage)
3637
>,
3738
];
3839

39-
let extraClassDeclaration = [{
40-
/// Generate the error message that will be printed to the user when
41-
/// verification fails.
42-
static std::string generateErrorMessage(Operation *op, const std::string &msg);
43-
}];
4440
}
4541

4642
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class GreedyRewriteConfig;
4747
#define GEN_PASS_DECL_TOPOLOGICALSORT
4848
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
4949
#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
50+
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
5051
#include "mlir/Transforms/Passes.h.inc"
5152

5253
/// Creates an instance of the Canonicalizer pass, configured with default

mlir/include/mlir/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,16 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
270270
passes that are suspected to introduce faulty IR.
271271
}];
272272
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
273+
let options = [
274+
Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"2",
275+
"Verbosity level for runtime verification messages: "
276+
"0 = Minimum (only source location), "
277+
"1 = Basic (include operation type and operand type), "
278+
"2 = Detailed (include full operation details, names, types, shapes, etc.)">
279+
];
273280
}
274281

282+
275283
def Inliner : Pass<"inline"> {
276284
let summary = "Inline function calls";
277285
let constructor = "mlir::createInlinerPass()";

mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct StructuredOpInterface
3232
: public RuntimeVerifiableOpInterface::ExternalModel<
3333
StructuredOpInterface<T>, T> {
3434
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35-
Location loc) const {
35+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
3636
auto linalgOp = llvm::cast<LinalgOp>(op);
3737

3838
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -70,7 +70,7 @@ struct StructuredOpInterface
7070
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
7171
auto cmpOp = builder.createOrFold<index::CmpOp>(
7272
loc, index::IndexCmpPredicate::SGE, min, zero);
73-
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
73+
auto msg = generateErrorMessage(
7474
linalgOp, "unexpected negative result on dimension #" +
7575
std::to_string(dim) + " of input/output operand #" +
7676
std::to_string(opOperand.getOperandNumber()));
@@ -100,7 +100,7 @@ struct StructuredOpInterface
100100

101101
cmpOp = builder.createOrFold<index::CmpOp>(
102102
loc, predicate, inferredDimSize, actualDimSize);
103-
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
103+
msg = generateErrorMessage(
104104
linalgOp, "dimension #" + std::to_string(dim) +
105105
" of input/output operand #" +
106106
std::to_string(opOperand.getOperandNumber()) +

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct AssumeAlignmentOpInterface
3838
: public RuntimeVerifiableOpInterface::ExternalModel<
3939
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
4040
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
41-
Location loc) const {
41+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
4242
auto assumeOp = cast<AssumeAlignmentOp>(op);
4343
Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
4444
assumeOp.getMemref());
@@ -49,7 +49,7 @@ struct AssumeAlignmentOpInterface
4949
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
5050
arith::ConstantIndexOp::create(builder, loc, 0));
5151
cf::AssertOp::create(builder, loc, isAligned,
52-
RuntimeVerifiableOpInterface::generateErrorMessage(
52+
generateErrorMessage(
5353
op, "memref is not aligned to " +
5454
std::to_string(assumeOp.getAlignment())));
5555
}
@@ -59,7 +59,7 @@ struct CastOpInterface
5959
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
6060
CastOp> {
6161
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
62-
Location loc) const {
62+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
6363
auto castOp = cast<CastOp>(op);
6464
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
6565

@@ -76,8 +76,7 @@ struct CastOpInterface
7676
Value isSameRank = arith::CmpIOp::create(
7777
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
7878
cf::AssertOp::create(builder, loc, isSameRank,
79-
RuntimeVerifiableOpInterface::generateErrorMessage(
80-
op, "rank mismatch"));
79+
generateErrorMessage(op, "rank mismatch"));
8180
}
8281

8382
// Get source offset and strides. We do not have an op to get offsets and
@@ -116,7 +115,7 @@ struct CastOpInterface
116115
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117116
cf::AssertOp::create(
118117
builder, loc, isSameSz,
119-
RuntimeVerifiableOpInterface::generateErrorMessage(
118+
generateErrorMessage(
120119
op, "size mismatch of dim " + std::to_string(it.index())));
121120
}
122121

@@ -135,8 +134,7 @@ struct CastOpInterface
135134
Value isSameOffset = arith::CmpIOp::create(
136135
builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137136
cf::AssertOp::create(builder, loc, isSameOffset,
138-
RuntimeVerifiableOpInterface::generateErrorMessage(
139-
op, "offset mismatch"));
137+
generateErrorMessage(op, "offset mismatch"));
140138
}
141139

142140
// Check strides.
@@ -153,7 +151,7 @@ struct CastOpInterface
153151
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154152
cf::AssertOp::create(
155153
builder, loc, isSameStride,
156-
RuntimeVerifiableOpInterface::generateErrorMessage(
154+
generateErrorMessage(
157155
op, "stride mismatch of dim " + std::to_string(it.index())));
158156
}
159157
}
@@ -163,7 +161,7 @@ struct CopyOpInterface
163161
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164162
CopyOp> {
165163
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
166-
Location loc) const {
164+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
167165
auto copyOp = cast<CopyOp>(op);
168166
BaseMemRefType sourceType = copyOp.getSource().getType();
169167
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -194,7 +192,7 @@ struct CopyOpInterface
194192
Value sameDimSize = arith::CmpIOp::create(
195193
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196194
cf::AssertOp::create(builder, loc, sameDimSize,
197-
RuntimeVerifiableOpInterface::generateErrorMessage(
195+
generateErrorMessage(
198196
op, "size of " + std::to_string(i) +
199197
"-th source/target dim does not match"));
200198
}
@@ -205,15 +203,14 @@ struct DimOpInterface
205203
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206204
DimOp> {
207205
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
208-
Location loc) const {
206+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
209207
auto dimOp = cast<DimOp>(op);
210208
Value rank = RankOp::create(builder, loc, dimOp.getSource());
211209
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
212210
cf::AssertOp::create(
213211
builder, loc,
214212
generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
215-
RuntimeVerifiableOpInterface::generateErrorMessage(
216-
op, "index is out of bounds"));
213+
generateErrorMessage(op, "index is out of bounds"));
217214
}
218215
};
219216

@@ -224,7 +221,7 @@ struct LoadStoreOpInterface
224221
: public RuntimeVerifiableOpInterface::ExternalModel<
225222
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226223
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
227-
Location loc) const {
224+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
228225
auto loadStoreOp = cast<LoadStoreOp>(op);
229226

230227
auto memref = loadStoreOp.getMemref();
@@ -245,16 +242,15 @@ struct LoadStoreOpInterface
245242
: inBounds;
246243
}
247244
cf::AssertOp::create(builder, loc, assertCond,
248-
RuntimeVerifiableOpInterface::generateErrorMessage(
249-
op, "out-of-bounds access"));
245+
generateErrorMessage(op, "out-of-bounds access"));
250246
}
251247
};
252248

253249
struct SubViewOpInterface
254250
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255251
SubViewOp> {
256252
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
257-
Location loc) const {
253+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
258254
auto subView = cast<SubViewOp>(op);
259255
MemRefType sourceType = subView.getSource().getType();
260256

@@ -279,7 +275,7 @@ struct SubViewOpInterface
279275
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
280276
cf::AssertOp::create(
281277
builder, loc, offsetInBounds,
282-
RuntimeVerifiableOpInterface::generateErrorMessage(
278+
generateErrorMessage(
283279
op, "offset " + std::to_string(i) + " is out-of-bounds"));
284280

285281
// Verify that slice does not run out-of-bounds.
@@ -292,7 +288,7 @@ struct SubViewOpInterface
292288
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
293289
cf::AssertOp::create(
294290
builder, loc, lastPosInBounds,
295-
RuntimeVerifiableOpInterface::generateErrorMessage(
291+
generateErrorMessage(
296292
op, "subview runs out-of-bounds along dimension " +
297293
std::to_string(i)));
298294
}
@@ -303,7 +299,7 @@ struct ExpandShapeOpInterface
303299
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304300
ExpandShapeOp> {
305301
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
306-
Location loc) const {
302+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
307303
auto expandShapeOp = cast<ExpandShapeOp>(op);
308304

309305
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -334,7 +330,7 @@ struct ExpandShapeOpInterface
334330
builder, loc, arith::CmpIPredicate::eq, mod,
335331
arith::ConstantIndexOp::create(builder, loc, 0));
336332
cf::AssertOp::create(builder, loc, isModZero,
337-
RuntimeVerifiableOpInterface::generateErrorMessage(
333+
generateErrorMessage(
338334
op, "static result dims in reassoc group do not "
339335
"divide src dim evenly"));
340336
}

mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct CastOpInterface
3636
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
3737
CastOp> {
3838
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
39-
Location loc) const {
39+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
4040
auto castOp = cast<CastOp>(op);
4141
auto srcType = cast<TensorType>(castOp.getSource().getType());
4242

@@ -53,8 +53,7 @@ struct CastOpInterface
5353
Value isSameRank = arith::CmpIOp::create(
5454
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
5555
cf::AssertOp::create(builder, loc, isSameRank,
56-
RuntimeVerifiableOpInterface::generateErrorMessage(
57-
op, "rank mismatch"));
56+
generateErrorMessage(op, "rank mismatch"));
5857
}
5958

6059
// Check dimension sizes.
@@ -76,7 +75,7 @@ struct CastOpInterface
7675
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
7776
cf::AssertOp::create(
7877
builder, loc, isSameSz,
79-
RuntimeVerifiableOpInterface::generateErrorMessage(
78+
generateErrorMessage(
8079
op, "size mismatch of dim " + std::to_string(it.index())));
8180
}
8281
}
@@ -86,15 +85,14 @@ struct DimOpInterface
8685
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
8786
DimOp> {
8887
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
89-
Location loc) const {
88+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
9089
auto dimOp = cast<DimOp>(op);
9190
Value rank = RankOp::create(builder, loc, dimOp.getSource());
9291
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
9392
cf::AssertOp::create(
9493
builder, loc,
9594
generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
96-
RuntimeVerifiableOpInterface::generateErrorMessage(
97-
op, "index is out of bounds"));
95+
generateErrorMessage(op, "index is out of bounds"));
9896
}
9997
};
10098

@@ -105,7 +103,7 @@ struct ExtractInsertOpInterface
105103
: public RuntimeVerifiableOpInterface::ExternalModel<
106104
ExtractInsertOpInterface<OpTy>, OpTy> {
107105
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
108-
Location loc) const {
106+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
109107
auto extractInsertOp = cast<OpTy>(op);
110108

111109
Value tensor;
@@ -135,16 +133,15 @@ struct ExtractInsertOpInterface
135133
: inBounds;
136134
}
137135
cf::AssertOp::create(builder, loc, assertCond,
138-
RuntimeVerifiableOpInterface::generateErrorMessage(
139-
op, "out-of-bounds access"));
136+
generateErrorMessage(op, "out-of-bounds access"));
140137
}
141138
};
142139

143140
struct ExtractSliceOpInterface
144141
: public RuntimeVerifiableOpInterface::ExternalModel<
145142
ExtractSliceOpInterface, ExtractSliceOp> {
146143
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
147-
Location loc) const {
144+
Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage) const {
148145
auto extractSliceOp = cast<ExtractSliceOp>(op);
149146
RankedTensorType sourceType = extractSliceOp.getSource().getType();
150147

@@ -168,7 +165,7 @@ struct ExtractSliceOpInterface
168165
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
169166
cf::AssertOp::create(
170167
builder, loc, offsetInBounds,
171-
RuntimeVerifiableOpInterface::generateErrorMessage(
168+
generateErrorMessage(
172169
op, "offset " + std::to_string(i) + " is out-of-bounds"));
173170

174171
// Verify that slice does not run out-of-bounds.
@@ -181,7 +178,7 @@ struct ExtractSliceOpInterface
181178
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
182179
cf::AssertOp::create(
183180
builder, loc, lastPosInBounds,
184-
RuntimeVerifiableOpInterface::generateErrorMessage(
181+
generateErrorMessage(
185182
op, "extract_slice runs out-of-bounds along dimension " +
186183
std::to_string(i)));
187184
}

mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,5 @@
88

99
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
1010

11-
namespace mlir {
12-
class Location;
13-
class OpBuilder;
14-
15-
/// Generate an error message string for the given op and the specified error.
16-
std::string
17-
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
18-
const std::string &msg) {
19-
std::string buffer;
20-
llvm::raw_string_ostream stream(buffer);
21-
OpPrintingFlags flags;
22-
// We may generate a lot of error messages and so we need to ensure the
23-
// printing is fast.
24-
flags.elideLargeElementsAttrs();
25-
flags.printGenericOpForm();
26-
flags.skipRegions();
27-
flags.useLocalScope();
28-
stream << "ERROR: Runtime op verification failed\n";
29-
op->print(stream, flags);
30-
stream << "\n^ " << msg;
31-
stream << "\nLocation: ";
32-
op->getLoc().print(stream);
33-
return buffer;
34-
}
35-
} // namespace mlir
36-
3711
/// Include the definitions of the interface.
3812
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"

0 commit comments

Comments
 (0)