Skip to content

Commit c111294

Browse files
committed
Add options to generate-runtime-verification to enable a faster pass
1 parent 5843ffb commit c111294

File tree

9 files changed

+155
-100
lines changed

9 files changed

+155
-100
lines changed

mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,11 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
3232
/*retTy=*/"void",
3333
/*methodName=*/"generateRuntimeVerification",
3434
/*args=*/(ins "::mlir::OpBuilder &":$builder,
35-
"::mlir::Location":$loc)
35+
"::mlir::Location":$loc,
36+
"function_ref<std::string(Operation *, StringRef)>":$generateErrorMessage)
3637
>,
3738
];
3839

39-
let extraClassDeclaration = [{
40-
/// Generate the error message that will be printed to the user when
41-
/// verification fails.
42-
static std::string generateErrorMessage(Operation *op, const std::string &msg);
43-
}];
4440
}
4541

4642
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class GreedyRewriteConfig;
4747
#define GEN_PASS_DECL_TOPOLOGICALSORT
4848
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
4949
#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
50+
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
5051
#include "mlir/Transforms/Passes.h.inc"
5152

5253
/// Creates an instance of the Canonicalizer pass, configured with default

mlir/include/mlir/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,16 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
270270
passes that are suspected to introduce faulty IR.
271271
}];
272272
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
273+
let options = [
274+
Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"2",
275+
"Verbosity level for runtime verification messages: "
276+
"0 = Minimum (only source location), "
277+
"1 = Basic (include operation type and operand type), "
278+
"2 = Detailed (include full operation details, names, types, shapes, etc.)">
279+
];
273280
}
274281

282+
275283
def Inliner : Pass<"inline"> {
276284
let summary = "Inline function calls";
277285
let constructor = "mlir::createInlinerPass()";

mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ template <typename T>
3131
struct StructuredOpInterface
3232
: public RuntimeVerifiableOpInterface::ExternalModel<
3333
StructuredOpInterface<T>, T> {
34-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
35-
Location loc) const {
34+
void
35+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
36+
function_ref<std::string(Operation *, StringRef)>
37+
generateErrorMessage) const {
3638
auto linalgOp = llvm::cast<LinalgOp>(op);
3739

3840
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -70,7 +72,7 @@ struct StructuredOpInterface
7072
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
7173
auto cmpOp = builder.createOrFold<index::CmpOp>(
7274
loc, index::IndexCmpPredicate::SGE, min, zero);
73-
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
75+
auto msg = generateErrorMessage(
7476
linalgOp, "unexpected negative result on dimension #" +
7577
std::to_string(dim) + " of input/output operand #" +
7678
std::to_string(opOperand.getOperandNumber()));
@@ -100,7 +102,7 @@ struct StructuredOpInterface
100102

101103
cmpOp = builder.createOrFold<index::CmpOp>(
102104
loc, predicate, inferredDimSize, actualDimSize);
103-
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
105+
msg = generateErrorMessage(
104106
linalgOp, "dimension #" + std::to_string(dim) +
105107
" of input/output operand #" +
106108
std::to_string(opOperand.getOperandNumber()) +

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3737
struct AssumeAlignmentOpInterface
3838
: public RuntimeVerifiableOpInterface::ExternalModel<
3939
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
41-
Location loc) const {
40+
void
41+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
42+
function_ref<std::string(Operation *, StringRef)>
43+
generateErrorMessage) const {
4244
auto assumeOp = cast<AssumeAlignmentOp>(op);
4345
Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc,
4446
assumeOp.getMemref());
@@ -48,18 +50,20 @@ struct AssumeAlignmentOpInterface
4850
Value isAligned =
4951
arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest,
5052
arith::ConstantIndexOp::create(builder, loc, 0));
51-
cf::AssertOp::create(builder, loc, isAligned,
52-
RuntimeVerifiableOpInterface::generateErrorMessage(
53-
op, "memref is not aligned to " +
53+
cf::AssertOp::create(
54+
builder, loc, isAligned,
55+
generateErrorMessage(op, "memref is not aligned to " +
5456
std::to_string(assumeOp.getAlignment())));
5557
}
5658
};
5759

5860
struct CastOpInterface
5961
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
6062
CastOp> {
61-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
62-
Location loc) const {
63+
void
64+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
65+
function_ref<std::string(Operation *, StringRef)>
66+
generateErrorMessage) const {
6367
auto castOp = cast<CastOp>(op);
6468
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
6569

@@ -76,8 +80,7 @@ struct CastOpInterface
7680
Value isSameRank = arith::CmpIOp::create(
7781
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
7882
cf::AssertOp::create(builder, loc, isSameRank,
79-
RuntimeVerifiableOpInterface::generateErrorMessage(
80-
op, "rank mismatch"));
83+
generateErrorMessage(op, "rank mismatch"));
8184
}
8285

8386
// Get source offset and strides. We do not have an op to get offsets and
@@ -116,8 +119,8 @@ struct CastOpInterface
116119
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117120
cf::AssertOp::create(
118121
builder, loc, isSameSz,
119-
RuntimeVerifiableOpInterface::generateErrorMessage(
120-
op, "size mismatch of dim " + std::to_string(it.index())));
122+
generateErrorMessage(op, "size mismatch of dim " +
123+
std::to_string(it.index())));
121124
}
122125

123126
// Get result offset and strides.
@@ -135,8 +138,7 @@ struct CastOpInterface
135138
Value isSameOffset = arith::CmpIOp::create(
136139
builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137140
cf::AssertOp::create(builder, loc, isSameOffset,
138-
RuntimeVerifiableOpInterface::generateErrorMessage(
139-
op, "offset mismatch"));
141+
generateErrorMessage(op, "offset mismatch"));
140142
}
141143

142144
// Check strides.
@@ -153,17 +155,19 @@ struct CastOpInterface
153155
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154156
cf::AssertOp::create(
155157
builder, loc, isSameStride,
156-
RuntimeVerifiableOpInterface::generateErrorMessage(
157-
op, "stride mismatch of dim " + std::to_string(it.index())));
158+
generateErrorMessage(op, "stride mismatch of dim " +
159+
std::to_string(it.index())));
158160
}
159161
}
160162
};
161163

162164
struct CopyOpInterface
163165
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164166
CopyOp> {
165-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
166-
Location loc) const {
167+
void
168+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
169+
function_ref<std::string(Operation *, StringRef)>
170+
generateErrorMessage) const {
167171
auto copyOp = cast<CopyOp>(op);
168172
BaseMemRefType sourceType = copyOp.getSource().getType();
169173
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -193,9 +197,9 @@ struct CopyOpInterface
193197
Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i);
194198
Value sameDimSize = arith::CmpIOp::create(
195199
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196-
cf::AssertOp::create(builder, loc, sameDimSize,
197-
RuntimeVerifiableOpInterface::generateErrorMessage(
198-
op, "size of " + std::to_string(i) +
200+
cf::AssertOp::create(
201+
builder, loc, sameDimSize,
202+
generateErrorMessage(op, "size of " + std::to_string(i) +
199203
"-th source/target dim does not match"));
200204
}
201205
}
@@ -204,16 +208,17 @@ struct CopyOpInterface
204208
struct DimOpInterface
205209
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206210
DimOp> {
207-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
208-
Location loc) const {
211+
void
212+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
213+
function_ref<std::string(Operation *, StringRef)>
214+
generateErrorMessage) const {
209215
auto dimOp = cast<DimOp>(op);
210216
Value rank = RankOp::create(builder, loc, dimOp.getSource());
211217
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
212218
cf::AssertOp::create(
213219
builder, loc,
214220
generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
215-
RuntimeVerifiableOpInterface::generateErrorMessage(
216-
op, "index is out of bounds"));
221+
generateErrorMessage(op, "index is out of bounds"));
217222
}
218223
};
219224

@@ -223,8 +228,10 @@ template <typename LoadStoreOp>
223228
struct LoadStoreOpInterface
224229
: public RuntimeVerifiableOpInterface::ExternalModel<
225230
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
227-
Location loc) const {
231+
void
232+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
233+
function_ref<std::string(Operation *, StringRef)>
234+
generateErrorMessage) const {
228235
auto loadStoreOp = cast<LoadStoreOp>(op);
229236

230237
auto memref = loadStoreOp.getMemref();
@@ -245,16 +252,17 @@ struct LoadStoreOpInterface
245252
: inBounds;
246253
}
247254
cf::AssertOp::create(builder, loc, assertCond,
248-
RuntimeVerifiableOpInterface::generateErrorMessage(
249-
op, "out-of-bounds access"));
255+
generateErrorMessage(op, "out-of-bounds access"));
250256
}
251257
};
252258

253259
struct SubViewOpInterface
254260
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255261
SubViewOp> {
256-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
257-
Location loc) const {
262+
void
263+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
264+
function_ref<std::string(Operation *, StringRef)>
265+
generateErrorMessage) const {
258266
auto subView = cast<SubViewOp>(op);
259267
MemRefType sourceType = subView.getSource().getType();
260268

@@ -277,10 +285,10 @@ struct SubViewOpInterface
277285
Value dimSize = metadataOp.getSizes()[i];
278286
Value offsetInBounds =
279287
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
280-
cf::AssertOp::create(
281-
builder, loc, offsetInBounds,
282-
RuntimeVerifiableOpInterface::generateErrorMessage(
283-
op, "offset " + std::to_string(i) + " is out-of-bounds"));
288+
cf::AssertOp::create(builder, loc, offsetInBounds,
289+
generateErrorMessage(op, "offset " +
290+
std::to_string(i) +
291+
" is out-of-bounds"));
284292

285293
// Verify that slice does not run out-of-bounds.
286294
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
@@ -292,18 +300,20 @@ struct SubViewOpInterface
292300
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
293301
cf::AssertOp::create(
294302
builder, loc, lastPosInBounds,
295-
RuntimeVerifiableOpInterface::generateErrorMessage(
296-
op, "subview runs out-of-bounds along dimension " +
297-
std::to_string(i)));
303+
generateErrorMessage(op,
304+
"subview runs out-of-bounds along dimension " +
305+
std::to_string(i)));
298306
}
299307
}
300308
};
301309

302310
struct ExpandShapeOpInterface
303311
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304312
ExpandShapeOp> {
305-
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
306-
Location loc) const {
313+
void
314+
generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc,
315+
function_ref<std::string(Operation *, StringRef)>
316+
generateErrorMessage) const {
307317
auto expandShapeOp = cast<ExpandShapeOp>(op);
308318

309319
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -333,9 +343,9 @@ struct ExpandShapeOpInterface
333343
Value isModZero = arith::CmpIOp::create(
334344
builder, loc, arith::CmpIPredicate::eq, mod,
335345
arith::ConstantIndexOp::create(builder, loc, 0));
336-
cf::AssertOp::create(builder, loc, isModZero,
337-
RuntimeVerifiableOpInterface::generateErrorMessage(
338-
op, "static result dims in reassoc group do not "
346+
cf::AssertOp::create(
347+
builder, loc, isModZero,
348+
generateErrorMessage(op, "static result dims in reassoc group do not "
339349
"divide src dim evenly"));
340350
}
341351
}

0 commit comments

Comments
 (0)