@@ -37,8 +37,10 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
37
37
struct AssumeAlignmentOpInterface
38
38
: public RuntimeVerifiableOpInterface::ExternalModel<
39
39
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
40
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
41
- Location loc) const {
40
+ void
41
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
42
+ function_ref<std::string(Operation *, StringRef)>
43
+ generateErrorMessage) const {
42
44
auto assumeOp = cast<AssumeAlignmentOp>(op);
43
45
Value ptr = ExtractAlignedPointerAsIndexOp::create (builder, loc,
44
46
assumeOp.getMemref ());
@@ -48,18 +50,20 @@ struct AssumeAlignmentOpInterface
48
50
Value isAligned =
49
51
arith::CmpIOp::create (builder, loc, arith::CmpIPredicate::eq, rest,
50
52
arith::ConstantIndexOp::create (builder, loc, 0 ));
51
- cf::AssertOp::create (builder, loc, isAligned,
52
- RuntimeVerifiableOpInterface::generateErrorMessage (
53
- op, " memref is not aligned to " +
53
+ cf::AssertOp::create (
54
+ builder, loc, isAligned,
55
+ generateErrorMessage ( op, " memref is not aligned to " +
54
56
std::to_string (assumeOp.getAlignment ())));
55
57
}
56
58
};
57
59
58
60
struct CastOpInterface
59
61
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
60
62
CastOp> {
61
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
62
- Location loc) const {
63
+ void
64
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
65
+ function_ref<std::string(Operation *, StringRef)>
66
+ generateErrorMessage) const {
63
67
auto castOp = cast<CastOp>(op);
64
68
auto srcType = cast<BaseMemRefType>(castOp.getSource ().getType ());
65
69
@@ -76,8 +80,7 @@ struct CastOpInterface
76
80
Value isSameRank = arith::CmpIOp::create (
77
81
builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank);
78
82
cf::AssertOp::create (builder, loc, isSameRank,
79
- RuntimeVerifiableOpInterface::generateErrorMessage (
80
- op, " rank mismatch" ));
83
+ generateErrorMessage (op, " rank mismatch" ));
81
84
}
82
85
83
86
// Get source offset and strides. We do not have an op to get offsets and
@@ -116,8 +119,8 @@ struct CastOpInterface
116
119
builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
117
120
cf::AssertOp::create (
118
121
builder, loc, isSameSz,
119
- RuntimeVerifiableOpInterface:: generateErrorMessage (
120
- op, " size mismatch of dim " + std::to_string (it.index ())));
122
+ generateErrorMessage (op, " size mismatch of dim " +
123
+ std::to_string (it.index ())));
121
124
}
122
125
123
126
// Get result offset and strides.
@@ -135,8 +138,7 @@ struct CastOpInterface
135
138
Value isSameOffset = arith::CmpIOp::create (
136
139
builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
137
140
cf::AssertOp::create (builder, loc, isSameOffset,
138
- RuntimeVerifiableOpInterface::generateErrorMessage (
139
- op, " offset mismatch" ));
141
+ generateErrorMessage (op, " offset mismatch" ));
140
142
}
141
143
142
144
// Check strides.
@@ -153,17 +155,19 @@ struct CastOpInterface
153
155
builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
154
156
cf::AssertOp::create (
155
157
builder, loc, isSameStride,
156
- RuntimeVerifiableOpInterface:: generateErrorMessage (
157
- op, " stride mismatch of dim " + std::to_string (it.index ())));
158
+ generateErrorMessage (op, " stride mismatch of dim " +
159
+ std::to_string (it.index ())));
158
160
}
159
161
}
160
162
};
161
163
162
164
struct CopyOpInterface
163
165
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
164
166
CopyOp> {
165
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
166
- Location loc) const {
167
+ void
168
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
169
+ function_ref<std::string(Operation *, StringRef)>
170
+ generateErrorMessage) const {
167
171
auto copyOp = cast<CopyOp>(op);
168
172
BaseMemRefType sourceType = copyOp.getSource ().getType ();
169
173
BaseMemRefType targetType = copyOp.getTarget ().getType ();
@@ -193,9 +197,9 @@ struct CopyOpInterface
193
197
Value targetDim = getDimSize (copyOp.getTarget (), rankedTargetType, i);
194
198
Value sameDimSize = arith::CmpIOp::create (
195
199
builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim);
196
- cf::AssertOp::create (builder, loc, sameDimSize,
197
- RuntimeVerifiableOpInterface::generateErrorMessage (
198
- op, " size of " + std::to_string (i) +
200
+ cf::AssertOp::create (
201
+ builder, loc, sameDimSize,
202
+ generateErrorMessage ( op, " size of " + std::to_string (i) +
199
203
" -th source/target dim does not match" ));
200
204
}
201
205
}
@@ -204,16 +208,17 @@ struct CopyOpInterface
204
208
struct DimOpInterface
205
209
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
206
210
DimOp> {
207
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
208
- Location loc) const {
211
+ void
212
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
213
+ function_ref<std::string(Operation *, StringRef)>
214
+ generateErrorMessage) const {
209
215
auto dimOp = cast<DimOp>(op);
210
216
Value rank = RankOp::create (builder, loc, dimOp.getSource ());
211
217
Value zero = arith::ConstantIndexOp::create (builder, loc, 0 );
212
218
cf::AssertOp::create (
213
219
builder, loc,
214
220
generateInBoundsCheck (builder, loc, dimOp.getIndex (), zero, rank),
215
- RuntimeVerifiableOpInterface::generateErrorMessage (
216
- op, " index is out of bounds" ));
221
+ generateErrorMessage (op, " index is out of bounds" ));
217
222
}
218
223
};
219
224
@@ -223,8 +228,10 @@ template <typename LoadStoreOp>
223
228
struct LoadStoreOpInterface
224
229
: public RuntimeVerifiableOpInterface::ExternalModel<
225
230
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
226
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
227
- Location loc) const {
231
+ void
232
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
233
+ function_ref<std::string(Operation *, StringRef)>
234
+ generateErrorMessage) const {
228
235
auto loadStoreOp = cast<LoadStoreOp>(op);
229
236
230
237
auto memref = loadStoreOp.getMemref ();
@@ -245,16 +252,17 @@ struct LoadStoreOpInterface
245
252
: inBounds;
246
253
}
247
254
cf::AssertOp::create (builder, loc, assertCond,
248
- RuntimeVerifiableOpInterface::generateErrorMessage (
249
- op, " out-of-bounds access" ));
255
+ generateErrorMessage (op, " out-of-bounds access" ));
250
256
}
251
257
};
252
258
253
259
struct SubViewOpInterface
254
260
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
255
261
SubViewOp> {
256
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
257
- Location loc) const {
262
+ void
263
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
264
+ function_ref<std::string(Operation *, StringRef)>
265
+ generateErrorMessage) const {
258
266
auto subView = cast<SubViewOp>(op);
259
267
MemRefType sourceType = subView.getSource ().getType ();
260
268
@@ -277,10 +285,10 @@ struct SubViewOpInterface
277
285
Value dimSize = metadataOp.getSizes ()[i];
278
286
Value offsetInBounds =
279
287
generateInBoundsCheck (builder, loc, offset, zero, dimSize);
280
- cf::AssertOp::create (
281
- builder, loc, offsetInBounds,
282
- RuntimeVerifiableOpInterface::generateErrorMessage (
283
- op, " offset " + std::to_string (i) + " is out-of-bounds" ));
288
+ cf::AssertOp::create (builder, loc, offsetInBounds,
289
+ generateErrorMessage (op, " offset " +
290
+ std::to_string (i) +
291
+ " is out-of-bounds" ));
284
292
285
293
// Verify that slice does not run out-of-bounds.
286
294
Value sizeMinusOne = arith::SubIOp::create (builder, loc, size, one);
@@ -292,18 +300,20 @@ struct SubViewOpInterface
292
300
generateInBoundsCheck (builder, loc, lastPos, zero, dimSize);
293
301
cf::AssertOp::create (
294
302
builder, loc, lastPosInBounds,
295
- RuntimeVerifiableOpInterface:: generateErrorMessage (
296
- op, " subview runs out-of-bounds along dimension " +
297
- std::to_string (i)));
303
+ generateErrorMessage (op,
304
+ " subview runs out-of-bounds along dimension " +
305
+ std::to_string (i)));
298
306
}
299
307
}
300
308
};
301
309
302
310
struct ExpandShapeOpInterface
303
311
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
304
312
ExpandShapeOp> {
305
- void generateRuntimeVerification (Operation *op, OpBuilder &builder,
306
- Location loc) const {
313
+ void
314
+ generateRuntimeVerification (Operation *op, OpBuilder &builder, Location loc,
315
+ function_ref<std::string(Operation *, StringRef)>
316
+ generateErrorMessage) const {
307
317
auto expandShapeOp = cast<ExpandShapeOp>(op);
308
318
309
319
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -333,9 +343,9 @@ struct ExpandShapeOpInterface
333
343
Value isModZero = arith::CmpIOp::create (
334
344
builder, loc, arith::CmpIPredicate::eq, mod,
335
345
arith::ConstantIndexOp::create (builder, loc, 0 ));
336
- cf::AssertOp::create (builder, loc, isModZero,
337
- RuntimeVerifiableOpInterface::generateErrorMessage (
338
- op, " static result dims in reassoc group do not "
346
+ cf::AssertOp::create (
347
+ builder, loc, isModZero,
348
+ generateErrorMessage ( op, " static result dims in reassoc group do not "
339
349
" divide src dim evenly" ));
340
350
}
341
351
}
0 commit comments