@@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
98
98
return os.str ();
99
99
}
100
100
101
- template <bool isLoad, typename OpType>
102
- int32_t getL1CacheControl (OpType op) {
101
+ static int32_t getL1CacheControl (LoadCacheControl cc) {
103
102
int32_t control = 0 ;
104
- if constexpr (isLoad) {
105
- switch (*op.getCacheControl ()) {
106
- case LoadCacheControl::L1UC_L2UC_L3UC:
107
- case LoadCacheControl::L1UC_L2UC_L3C:
108
- case LoadCacheControl::L1UC_L2C_L3UC:
109
- case LoadCacheControl::L1UC_L2C_L3C:
110
- control = 1 ;
111
- break ;
112
- case LoadCacheControl::L1C_L2UC_L3UC:
113
- case LoadCacheControl::L1C_L2UC_L3C:
114
- case LoadCacheControl::L1C_L2C_L3UC:
115
- case LoadCacheControl::L1C_L2C_L3C:
116
- control = 2 ;
117
- break ;
118
- case LoadCacheControl::L1S_L2UC_L3UC:
119
- case LoadCacheControl::L1S_L2UC_L3C:
120
- case LoadCacheControl::L1S_L2C_L3UC:
121
- case LoadCacheControl::L1S_L2C_L3C:
122
- control = 3 ;
123
- break ;
124
- case LoadCacheControl::INVALIDATE_READ:
125
- control = 4 ;
126
- break ;
127
- }
128
- } else {
129
- switch (*op.getCacheControl ()) {
130
- case StoreCacheControl::L1UC_L2UC_L3UC:
131
- case StoreCacheControl::L1UC_L2UC_L3WB:
132
- case StoreCacheControl::L1UC_L2WB_L3UC:
133
- case StoreCacheControl::L1UC_L2WB_L3WB:
134
- control = 1 ;
135
- break ;
136
- case StoreCacheControl::L1WT_L2UC_L3UC:
137
- case StoreCacheControl::L1WT_L2UC_L3WB:
138
- case StoreCacheControl::L1WT_L2WB_L3UC:
139
- case StoreCacheControl::L1WT_L2WB_L3WB:
140
- control = 2 ;
141
- break ;
142
- case StoreCacheControl::L1S_L2UC_L3UC:
143
- case StoreCacheControl::L1S_L2UC_L3WB:
144
- case StoreCacheControl::L1S_L2WB_L3UC:
145
- case StoreCacheControl::L1S_L2WB_L3WB:
146
- control = 3 ;
147
- break ;
148
- case StoreCacheControl::L1WB_L2UC_L3UC:
149
- case StoreCacheControl::L1WB_L2WB_L3UC:
150
- case StoreCacheControl::L1WB_L2UC_L3WB:
151
- control = 4 ;
152
- break ;
153
- }
103
+ switch (cc) {
104
+ case LoadCacheControl::L1UC_L2UC_L3UC:
105
+ case LoadCacheControl::L1UC_L2UC_L3C:
106
+ case LoadCacheControl::L1UC_L2C_L3UC:
107
+ case LoadCacheControl::L1UC_L2C_L3C:
108
+ control = 1 ;
109
+ break ;
110
+ case LoadCacheControl::L1C_L2UC_L3UC:
111
+ case LoadCacheControl::L1C_L2UC_L3C:
112
+ case LoadCacheControl::L1C_L2C_L3UC:
113
+ case LoadCacheControl::L1C_L2C_L3C:
114
+ control = 2 ;
115
+ break ;
116
+ case LoadCacheControl::L1S_L2UC_L3UC:
117
+ case LoadCacheControl::L1S_L2UC_L3C:
118
+ case LoadCacheControl::L1S_L2C_L3UC:
119
+ case LoadCacheControl::L1S_L2C_L3C:
120
+ control = 3 ;
121
+ break ;
122
+ case LoadCacheControl::INVALIDATE_READ:
123
+ control = 4 ;
124
+ break ;
154
125
}
155
126
return control;
156
127
}
157
128
158
- template <bool isLoad, typename OpType>
159
- int32_t getL3CacheControl (OpType op) {
129
+ static int32_t getL1CacheControl (StoreCacheControl cc) {
160
130
int32_t control = 0 ;
161
- if constexpr (isLoad) {
162
- switch (*op.getCacheControl ()) {
163
- case LoadCacheControl::L1UC_L2UC_L3UC:
164
- case LoadCacheControl::L1UC_L2C_L3UC:
165
- case LoadCacheControl::L1C_L2UC_L3UC:
166
- case LoadCacheControl::L1C_L2C_L3UC:
167
- case LoadCacheControl::L1S_L2UC_L3UC:
168
- case LoadCacheControl::L1S_L2C_L3UC:
169
- control = 1 ;
170
- break ;
171
- case LoadCacheControl::L1UC_L2UC_L3C:
172
- case LoadCacheControl::L1UC_L2C_L3C:
173
- case LoadCacheControl::L1C_L2UC_L3C:
174
- case LoadCacheControl::L1C_L2C_L3C:
175
- case LoadCacheControl::L1S_L2UC_L3C:
176
- case LoadCacheControl::L1S_L2C_L3C:
177
- control = 2 ;
178
- break ;
179
- case LoadCacheControl::INVALIDATE_READ:
180
- control = 4 ;
181
- break ;
182
- }
183
- } else {
184
- switch (*op.getCacheControl ()) {
185
- case StoreCacheControl::L1UC_L2UC_L3UC:
186
- case StoreCacheControl::L1UC_L2WB_L3UC:
187
- case StoreCacheControl::L1WT_L2UC_L3UC:
188
- case StoreCacheControl::L1WT_L2WB_L3UC:
189
- case StoreCacheControl::L1S_L2UC_L3UC:
190
- case StoreCacheControl::L1S_L2WB_L3UC:
191
- case StoreCacheControl::L1WB_L2UC_L3UC:
192
- case StoreCacheControl::L1WB_L2WB_L3UC:
193
- control = 1 ;
194
- break ;
195
- case StoreCacheControl::L1UC_L2UC_L3WB:
196
- case StoreCacheControl::L1UC_L2WB_L3WB:
197
- case StoreCacheControl::L1WT_L2UC_L3WB:
198
- case StoreCacheControl::L1WT_L2WB_L3WB:
199
- case StoreCacheControl::L1S_L2UC_L3WB:
200
- case StoreCacheControl::L1S_L2WB_L3WB:
201
- case StoreCacheControl::L1WB_L2UC_L3WB:
202
- control = 2 ;
203
- break ;
204
- }
131
+ switch (cc) {
132
+ case StoreCacheControl::L1UC_L2UC_L3UC:
133
+ case StoreCacheControl::L1UC_L2UC_L3WB:
134
+ case StoreCacheControl::L1UC_L2WB_L3UC:
135
+ case StoreCacheControl::L1UC_L2WB_L3WB:
136
+ control = 1 ;
137
+ break ;
138
+ case StoreCacheControl::L1WT_L2UC_L3UC:
139
+ case StoreCacheControl::L1WT_L2UC_L3WB:
140
+ case StoreCacheControl::L1WT_L2WB_L3UC:
141
+ case StoreCacheControl::L1WT_L2WB_L3WB:
142
+ control = 2 ;
143
+ break ;
144
+ case StoreCacheControl::L1S_L2UC_L3UC:
145
+ case StoreCacheControl::L1S_L2UC_L3WB:
146
+ case StoreCacheControl::L1S_L2WB_L3UC:
147
+ case StoreCacheControl::L1S_L2WB_L3WB:
148
+ control = 3 ;
149
+ break ;
150
+ case StoreCacheControl::L1WB_L2UC_L3UC:
151
+ case StoreCacheControl::L1WB_L2WB_L3UC:
152
+ case StoreCacheControl::L1WB_L2UC_L3WB:
153
+ control = 4 ;
154
+ break ;
205
155
}
206
156
return control;
207
157
}
208
158
209
- template <bool isLoad, typename OpType>
159
+ static int32_t getL3CacheControl (LoadCacheControl cc) {
160
+ int32_t control = 0 ;
161
+ switch (cc) {
162
+ case LoadCacheControl::L1UC_L2UC_L3UC:
163
+ case LoadCacheControl::L1UC_L2C_L3UC:
164
+ case LoadCacheControl::L1C_L2UC_L3UC:
165
+ case LoadCacheControl::L1C_L2C_L3UC:
166
+ case LoadCacheControl::L1S_L2UC_L3UC:
167
+ case LoadCacheControl::L1S_L2C_L3UC:
168
+ control = 1 ;
169
+ break ;
170
+ case LoadCacheControl::L1UC_L2UC_L3C:
171
+ case LoadCacheControl::L1UC_L2C_L3C:
172
+ case LoadCacheControl::L1C_L2UC_L3C:
173
+ case LoadCacheControl::L1C_L2C_L3C:
174
+ case LoadCacheControl::L1S_L2UC_L3C:
175
+ case LoadCacheControl::L1S_L2C_L3C:
176
+ control = 2 ;
177
+ break ;
178
+ case LoadCacheControl::INVALIDATE_READ:
179
+ control = 4 ;
180
+ break ;
181
+ }
182
+ return control;
183
+ }
184
+
185
+ static int32_t getL3CacheControl (StoreCacheControl cc) {
186
+ int32_t control = 0 ;
187
+ switch (cc) {
188
+ case StoreCacheControl::L1UC_L2UC_L3UC:
189
+ case StoreCacheControl::L1UC_L2WB_L3UC:
190
+ case StoreCacheControl::L1WT_L2UC_L3UC:
191
+ case StoreCacheControl::L1WT_L2WB_L3UC:
192
+ case StoreCacheControl::L1S_L2UC_L3UC:
193
+ case StoreCacheControl::L1S_L2WB_L3UC:
194
+ case StoreCacheControl::L1WB_L2UC_L3UC:
195
+ case StoreCacheControl::L1WB_L2WB_L3UC:
196
+ control = 1 ;
197
+ break ;
198
+ case StoreCacheControl::L1UC_L2UC_L3WB:
199
+ case StoreCacheControl::L1UC_L2WB_L3WB:
200
+ case StoreCacheControl::L1WT_L2UC_L3WB:
201
+ case StoreCacheControl::L1WT_L2WB_L3WB:
202
+ case StoreCacheControl::L1S_L2UC_L3WB:
203
+ case StoreCacheControl::L1S_L2WB_L3WB:
204
+ case StoreCacheControl::L1WB_L2UC_L3WB:
205
+ control = 2 ;
206
+ break ;
207
+ }
208
+ return control;
209
+ }
210
+
211
+ static std::optional<LoadCacheControl> getCacheControl (PrefetchOp op) {
212
+ return op.getCacheControl ();
213
+ }
214
+
215
+ static std::optional<LoadCacheControl> getCacheControl (BlockLoad2dOp op) {
216
+ return op.getCacheControl ();
217
+ }
218
+
219
+ static std::optional<LoadCacheControl> getCacheControl (BlockPrefetch2dOp op) {
220
+ return op.getCacheControl ();
221
+ }
222
+
223
+ static std::optional<StoreCacheControl> getCacheControl (BlockStore2dOp op) {
224
+ return op.getCacheControl ();
225
+ }
226
+
227
+ static std::optional<LoadCacheControl> getCacheControl (LLVM::LoadOp op) {
228
+ if (op->hasAttr (" cache_control" )) {
229
+ auto attr = op->getAttrOfType <xevm::LoadCacheControlAttr>(" cache_control" );
230
+ if (!attr)
231
+ return std::nullopt ;
232
+ return std::optional<LoadCacheControl>(attr.getValue ());
233
+ }
234
+ return std::nullopt ;
235
+ }
236
+
237
+ static std::optional<StoreCacheControl> getCacheControl (LLVM::StoreOp op) {
238
+ if (op->hasAttr (" cache_control" )) {
239
+ auto attr = op->getAttrOfType <xevm::StoreCacheControlAttr>(" cache_control" );
240
+ if (!attr)
241
+ return std::nullopt ;
242
+ return std::optional<StoreCacheControl>(attr.getValue ());
243
+ }
244
+ return std::nullopt ;
245
+ }
246
+
247
/// Op-level convenience wrapper: fetches the op's cache-control hint and maps
/// it to its L1 decoration value.
///
/// Precondition: `getCacheControl(op)` holds a value; the optional is
/// dereferenced without a check.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL1CacheControl(*cacheControl);
}
251
+
252
/// Op-level convenience wrapper: fetches the op's cache-control hint and maps
/// it to its L3 decoration value.
///
/// Precondition: `getCacheControl(op)` holds a value; the optional is
/// dereferenced without a check.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL3CacheControl(*cacheControl);
}
256
+
257
+ template <typename OpType>
210
258
static std::optional<ArrayAttr>
211
259
getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op) {
212
- if (!op. getCacheControl ())
260
+ if (!getCacheControl (op ))
213
261
return {};
214
262
constexpr int32_t decorationCacheControlArity{4 };
215
263
constexpr int32_t loadCacheControlKey{6442 };
216
264
constexpr int32_t storeCacheControlKey{6443 };
265
+ constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
266
+ std::is_same_v<OpType, BlockPrefetch2dOp> ||
267
+ std::is_same_v<OpType, LLVM::LoadOp> ||
268
+ std::is_same_v<OpType, PrefetchOp>;
217
269
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
218
270
SmallVector<int32_t , decorationCacheControlArity> decorationsL1{
219
- controlKey, 0 , getL1CacheControl<isLoad, OpType>(op), 0 };
271
+ controlKey, 0 , getL1CacheControl<OpType>(op), 0 };
220
272
SmallVector<int32_t , decorationCacheControlArity> decorationsL3{
221
- controlKey, 1 , getL3CacheControl<isLoad, OpType>(op), 0 };
273
+ controlKey, 1 , getL3CacheControl<OpType>(op), 0 };
222
274
auto arrayAttrL1 = rewriter.getI32ArrayAttr (decorationsL1);
223
275
auto arrayAttrL3 = rewriter.getI32ArrayAttr (decorationsL3);
224
276
@@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
398
450
rewriter, fnName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
399
451
argTypes, args, {}, funcAttr, op.getOperation ());
400
452
if (std::optional<ArrayAttr> optCacheControls =
401
- getCacheControlMetadata< true > (rewriter, op))
453
+ getCacheControlMetadata (rewriter, op))
402
454
call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
403
455
rewriter.eraseOp (op);
404
456
return success ();
@@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
557
609
rewriter, funcName, LLVM::LLVMVoidType::get (rewriter.getContext ()),
558
610
argTypes, args, paramAttrs, funcAttr, op.getOperation ());
559
611
if (std::optional<ArrayAttr> optCacheControls =
560
- getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
612
+ getCacheControlMetadata (rewriter, op)) {
561
613
call->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
562
614
}
563
615
if constexpr (isLoad)
@@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
568
620
return success ();
569
621
}
570
622
};
623
+ template <typename OpType>
624
+ class LLVMLoadStoreToOCLPattern : public OpConversionPattern <OpType> {
625
+ using OpConversionPattern<OpType>::OpConversionPattern;
626
+ LogicalResult
627
+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
628
+ ConversionPatternRewriter &rewriter) const override {
629
+ if (!op->hasAttr (" cache_control" ))
630
+ return failure ();
631
+ std::optional<ArrayAttr> optCacheControls =
632
+ getCacheControlMetadata (rewriter, op);
633
+ op->setAttr (XeVMDialect::getCacheControlsAttrName (), *optCacheControls);
634
+ op->removeAttr (" cache_control" );
635
+ return success ();
636
+ }
637
+ };
571
638
572
639
// ===----------------------------------------------------------------------===//
573
640
// Pass Definition
@@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass
583
650
584
651
void runOnOperation () override {
585
652
ConversionTarget target (getContext ());
586
- target.addLegalDialect <LLVM::LLVMDialect>();
587
- target.addIllegalDialect <XeVMDialect>();
588
653
RewritePatternSet patterns (&getContext ());
589
- populateXeVMToLLVMConversionPatterns (patterns);
654
+ populateXeVMToLLVMConversionPatterns (target, patterns);
590
655
if (failed (applyPartialConversion (getOperation (), target,
591
656
std::move (patterns))))
592
657
signalPassFailure ();
@@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
611
676
void populateConvertToLLVMConversionPatterns (
612
677
ConversionTarget &target, LLVMTypeConverter &typeConverter,
613
678
RewritePatternSet &patterns) const final {
614
- populateXeVMToLLVMConversionPatterns (patterns);
679
+ populateXeVMToLLVMConversionPatterns (target, patterns);
615
680
}
616
681
};
617
682
} // namespace
@@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
620
685
// Pattern Population
621
686
// ===----------------------------------------------------------------------===//
622
687
623
- void ::mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
688
+ void ::mlir::populateXeVMToLLVMConversionPatterns (ConversionTarget &target,
689
+ RewritePatternSet &patterns) {
690
+ target.addDynamicallyLegalDialect <LLVM::LLVMDialect>(
691
+ [](Operation *op) { return !op->hasAttr (" cache_control" ); });
692
+ target.addIllegalDialect <XeVMDialect>();
624
693
patterns.add <LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
625
694
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
626
695
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
627
- MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
628
- patterns.getContext ());
696
+ MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
697
+ LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
698
+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext ());
629
699
}
630
700
631
701
void ::mlir::registerConvertXeVMToLLVMInterface (DialectRegistry ®istry) {
0 commit comments