@@ -38,7 +38,7 @@ struct AssumeAlignmentOpInterface
38
38
: public RuntimeVerifiableOpInterface::ExternalModel<
39
39
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40
40
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
41
- Location loc) const {
41
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
42
42
auto assumeOp = cast<AssumeAlignmentOp>(op);
43
43
Value ptr = ExtractAlignedPointerAsIndexOp::create (builder, loc,
44
44
assumeOp.getMemref ());
@@ -49,7 +49,7 @@ struct AssumeAlignmentOpInterface
49
49
arith::CmpIOp::create (builder, loc, arith::CmpIPredicate::eq, rest,
50
50
arith::ConstantIndexOp::create (builder, loc, 0 ));
51
51
cf::AssertOp::create (builder, loc, isAligned,
52
- RuntimeVerifiableOpInterface:: generateErrorMessage (
52
+ generateErrorMessage (
53
53
op, " memref is not aligned to " +
54
54
std::to_string (assumeOp.getAlignment ())));
55
55
}
@@ -59,7 +59,7 @@ struct CastOpInterface
59
59
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
60
60
CastOp> {
61
61
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
62
- Location loc) const {
62
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
63
63
auto castOp = cast<CastOp>(op);
64
64
auto srcType = cast<BaseMemRefType>(castOp.getSource ().getType ());
65
65
@@ -76,8 +76,7 @@ struct CastOpInterface
76
76
Value isSameRank = arith::CmpIOp::create (
77
77
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
78
78
cf::AssertOp::create (builder, loc, isSameRank,
79
- RuntimeVerifiableOpInterface::generateErrorMessage (
80
- op, " rank mismatch" ));
79
+ generateErrorMessage (op, " rank mismatch" ));
81
80
}
82
81
83
82
// Get source offset and strides. We do not have an op to get offsets and
@@ -116,7 +115,7 @@ struct CastOpInterface
116
115
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117
116
cf::AssertOp::create (
118
117
builder, loc, isSameSz,
119
- RuntimeVerifiableOpInterface:: generateErrorMessage (
118
+ generateErrorMessage (
120
119
op, " size mismatch of dim " + std::to_string (it.index ())));
121
120
}
122
121
@@ -135,8 +134,7 @@ struct CastOpInterface
135
134
Value isSameOffset = arith::CmpIOp::create (
136
135
builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137
136
cf::AssertOp::create (builder, loc, isSameOffset,
138
- RuntimeVerifiableOpInterface::generateErrorMessage (
139
- op, " offset mismatch" ));
137
+ generateErrorMessage (op, " offset mismatch" ));
140
138
}
141
139
142
140
// Check strides.
@@ -153,7 +151,7 @@ struct CastOpInterface
153
151
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154
152
cf::AssertOp::create (
155
153
builder, loc, isSameStride,
156
- RuntimeVerifiableOpInterface:: generateErrorMessage (
154
+ generateErrorMessage (
157
155
op, " stride mismatch of dim " + std::to_string (it.index ())));
158
156
}
159
157
}
@@ -163,7 +161,7 @@ struct CopyOpInterface
163
161
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164
162
CopyOp> {
165
163
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
166
- Location loc) const {
164
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
167
165
auto copyOp = cast<CopyOp>(op);
168
166
BaseMemRefType sourceType = copyOp.getSource ().getType ();
169
167
BaseMemRefType targetType = copyOp.getTarget ().getType ();
@@ -194,7 +192,7 @@ struct CopyOpInterface
194
192
Value sameDimSize = arith::CmpIOp::create (
195
193
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196
194
cf::AssertOp::create (builder, loc, sameDimSize,
197
- RuntimeVerifiableOpInterface:: generateErrorMessage (
195
+ generateErrorMessage (
198
196
op, " size of " + std::to_string (i) +
199
197
" -th source/target dim does not match" ));
200
198
}
@@ -205,15 +203,14 @@ struct DimOpInterface
205
203
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206
204
DimOp> {
207
205
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
208
- Location loc) const {
206
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
209
207
auto dimOp = cast<DimOp>(op);
210
208
Value rank = RankOp::create (builder, loc, dimOp.getSource ());
211
209
Value zero = arith::ConstantIndexOp::create (builder, loc, 0 );
212
210
cf::AssertOp::create (
213
211
builder, loc,
214
212
generateInBoundsCheck (builder, loc, dimOp.getIndex (), zero, rank),
215
- RuntimeVerifiableOpInterface::generateErrorMessage (
216
- op, " index is out of bounds" ));
213
+ generateErrorMessage (op, " index is out of bounds" ));
217
214
}
218
215
};
219
216
@@ -224,7 +221,7 @@ struct LoadStoreOpInterface
224
221
: public RuntimeVerifiableOpInterface::ExternalModel<
225
222
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226
223
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
227
- Location loc) const {
224
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
228
225
auto loadStoreOp = cast<LoadStoreOp>(op);
229
226
230
227
auto memref = loadStoreOp.getMemref ();
@@ -245,16 +242,15 @@ struct LoadStoreOpInterface
245
242
: inBounds;
246
243
}
247
244
cf::AssertOp::create (builder, loc, assertCond,
248
- RuntimeVerifiableOpInterface::generateErrorMessage (
249
- op, " out-of-bounds access" ));
245
+ generateErrorMessage (op, " out-of-bounds access" ));
250
246
}
251
247
};
252
248
253
249
struct SubViewOpInterface
254
250
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255
251
SubViewOp> {
256
252
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
257
- Location loc) const {
253
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
258
254
auto subView = cast<SubViewOp>(op);
259
255
MemRefType sourceType = subView.getSource ().getType ();
260
256
@@ -279,7 +275,7 @@ struct SubViewOpInterface
279
275
generateInBoundsCheck (builder, loc, offset, zero, dimSize);
280
276
cf::AssertOp::create (
281
277
builder, loc, offsetInBounds,
282
- RuntimeVerifiableOpInterface:: generateErrorMessage (
278
+ generateErrorMessage (
283
279
op, " offset " + std::to_string (i) + " is out-of-bounds" ));
284
280
285
281
// Verify that slice does not run out-of-bounds.
@@ -292,7 +288,7 @@ struct SubViewOpInterface
292
288
generateInBoundsCheck (builder, loc, lastPos, zero, dimSize);
293
289
cf::AssertOp::create (
294
290
builder, loc, lastPosInBounds,
295
- RuntimeVerifiableOpInterface:: generateErrorMessage (
291
+ generateErrorMessage (
296
292
op, " subview runs out-of-bounds along dimension " +
297
293
std::to_string (i)));
298
294
}
@@ -303,7 +299,7 @@ struct ExpandShapeOpInterface
303
299
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304
300
ExpandShapeOp> {
305
301
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
306
- Location loc) const {
302
+ Location loc, function_ref<std::string(Operation *, StringRef)> generateErrorMessage ) const {
307
303
auto expandShapeOp = cast<ExpandShapeOp>(op);
308
304
309
305
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -334,7 +330,7 @@ struct ExpandShapeOpInterface
334
330
builder, loc, arith::CmpIPredicate::eq, mod,
335
331
arith::ConstantIndexOp::create (builder, loc, 0 ));
336
332
cf::AssertOp::create (builder, loc, isModZero,
337
- RuntimeVerifiableOpInterface:: generateErrorMessage (
333
+ generateErrorMessage (
338
334
op, " static result dims in reassoc group do not "
339
335
" divide src dim evenly" ));
340
336
}
0 commit comments