Skip to content

Commit 3b6bd49

Browse files
authored
[MLIR][XeVM] Add lowering for llvm load store ops with XeVM cache control (#156768)
Add lowering support for LLVM load / store ops with XeVM cache control attributes.
1 parent 53bcfe2 commit 3b6bd49

File tree

5 files changed

+235
-116
lines changed

5 files changed

+235
-116
lines changed

mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <memory>
1212

1313
namespace mlir {
14+
class ConversionTarget;
1415
class DialectRegistry;
1516
class LLVMTypeConverter;
1617
class RewritePatternSet;
@@ -19,7 +20,8 @@ class Pass;
1920
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
2021
#include "mlir/Conversion/Passes.h.inc"
2122

22-
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
23+
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
24+
RewritePatternSet &patterns);
2325

2426
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
2527
} // namespace mlir

mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 181 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
9898
return os.str();
9999
}
100100

101-
template <bool isLoad, typename OpType>
102-
int32_t getL1CacheControl(OpType op) {
101+
static int32_t getL1CacheControl(LoadCacheControl cc) {
103102
int32_t control = 0;
104-
if constexpr (isLoad) {
105-
switch (*op.getCacheControl()) {
106-
case LoadCacheControl::L1UC_L2UC_L3UC:
107-
case LoadCacheControl::L1UC_L2UC_L3C:
108-
case LoadCacheControl::L1UC_L2C_L3UC:
109-
case LoadCacheControl::L1UC_L2C_L3C:
110-
control = 1;
111-
break;
112-
case LoadCacheControl::L1C_L2UC_L3UC:
113-
case LoadCacheControl::L1C_L2UC_L3C:
114-
case LoadCacheControl::L1C_L2C_L3UC:
115-
case LoadCacheControl::L1C_L2C_L3C:
116-
control = 2;
117-
break;
118-
case LoadCacheControl::L1S_L2UC_L3UC:
119-
case LoadCacheControl::L1S_L2UC_L3C:
120-
case LoadCacheControl::L1S_L2C_L3UC:
121-
case LoadCacheControl::L1S_L2C_L3C:
122-
control = 3;
123-
break;
124-
case LoadCacheControl::INVALIDATE_READ:
125-
control = 4;
126-
break;
127-
}
128-
} else {
129-
switch (*op.getCacheControl()) {
130-
case StoreCacheControl::L1UC_L2UC_L3UC:
131-
case StoreCacheControl::L1UC_L2UC_L3WB:
132-
case StoreCacheControl::L1UC_L2WB_L3UC:
133-
case StoreCacheControl::L1UC_L2WB_L3WB:
134-
control = 1;
135-
break;
136-
case StoreCacheControl::L1WT_L2UC_L3UC:
137-
case StoreCacheControl::L1WT_L2UC_L3WB:
138-
case StoreCacheControl::L1WT_L2WB_L3UC:
139-
case StoreCacheControl::L1WT_L2WB_L3WB:
140-
control = 2;
141-
break;
142-
case StoreCacheControl::L1S_L2UC_L3UC:
143-
case StoreCacheControl::L1S_L2UC_L3WB:
144-
case StoreCacheControl::L1S_L2WB_L3UC:
145-
case StoreCacheControl::L1S_L2WB_L3WB:
146-
control = 3;
147-
break;
148-
case StoreCacheControl::L1WB_L2UC_L3UC:
149-
case StoreCacheControl::L1WB_L2WB_L3UC:
150-
case StoreCacheControl::L1WB_L2UC_L3WB:
151-
control = 4;
152-
break;
153-
}
103+
switch (cc) {
104+
case LoadCacheControl::L1UC_L2UC_L3UC:
105+
case LoadCacheControl::L1UC_L2UC_L3C:
106+
case LoadCacheControl::L1UC_L2C_L3UC:
107+
case LoadCacheControl::L1UC_L2C_L3C:
108+
control = 1;
109+
break;
110+
case LoadCacheControl::L1C_L2UC_L3UC:
111+
case LoadCacheControl::L1C_L2UC_L3C:
112+
case LoadCacheControl::L1C_L2C_L3UC:
113+
case LoadCacheControl::L1C_L2C_L3C:
114+
control = 2;
115+
break;
116+
case LoadCacheControl::L1S_L2UC_L3UC:
117+
case LoadCacheControl::L1S_L2UC_L3C:
118+
case LoadCacheControl::L1S_L2C_L3UC:
119+
case LoadCacheControl::L1S_L2C_L3C:
120+
control = 3;
121+
break;
122+
case LoadCacheControl::INVALIDATE_READ:
123+
control = 4;
124+
break;
154125
}
155126
return control;
156127
}
157128

158-
template <bool isLoad, typename OpType>
159-
int32_t getL3CacheControl(OpType op) {
129+
static int32_t getL1CacheControl(StoreCacheControl cc) {
160130
int32_t control = 0;
161-
if constexpr (isLoad) {
162-
switch (*op.getCacheControl()) {
163-
case LoadCacheControl::L1UC_L2UC_L3UC:
164-
case LoadCacheControl::L1UC_L2C_L3UC:
165-
case LoadCacheControl::L1C_L2UC_L3UC:
166-
case LoadCacheControl::L1C_L2C_L3UC:
167-
case LoadCacheControl::L1S_L2UC_L3UC:
168-
case LoadCacheControl::L1S_L2C_L3UC:
169-
control = 1;
170-
break;
171-
case LoadCacheControl::L1UC_L2UC_L3C:
172-
case LoadCacheControl::L1UC_L2C_L3C:
173-
case LoadCacheControl::L1C_L2UC_L3C:
174-
case LoadCacheControl::L1C_L2C_L3C:
175-
case LoadCacheControl::L1S_L2UC_L3C:
176-
case LoadCacheControl::L1S_L2C_L3C:
177-
control = 2;
178-
break;
179-
case LoadCacheControl::INVALIDATE_READ:
180-
control = 4;
181-
break;
182-
}
183-
} else {
184-
switch (*op.getCacheControl()) {
185-
case StoreCacheControl::L1UC_L2UC_L3UC:
186-
case StoreCacheControl::L1UC_L2WB_L3UC:
187-
case StoreCacheControl::L1WT_L2UC_L3UC:
188-
case StoreCacheControl::L1WT_L2WB_L3UC:
189-
case StoreCacheControl::L1S_L2UC_L3UC:
190-
case StoreCacheControl::L1S_L2WB_L3UC:
191-
case StoreCacheControl::L1WB_L2UC_L3UC:
192-
case StoreCacheControl::L1WB_L2WB_L3UC:
193-
control = 1;
194-
break;
195-
case StoreCacheControl::L1UC_L2UC_L3WB:
196-
case StoreCacheControl::L1UC_L2WB_L3WB:
197-
case StoreCacheControl::L1WT_L2UC_L3WB:
198-
case StoreCacheControl::L1WT_L2WB_L3WB:
199-
case StoreCacheControl::L1S_L2UC_L3WB:
200-
case StoreCacheControl::L1S_L2WB_L3WB:
201-
case StoreCacheControl::L1WB_L2UC_L3WB:
202-
control = 2;
203-
break;
204-
}
131+
switch (cc) {
132+
case StoreCacheControl::L1UC_L2UC_L3UC:
133+
case StoreCacheControl::L1UC_L2UC_L3WB:
134+
case StoreCacheControl::L1UC_L2WB_L3UC:
135+
case StoreCacheControl::L1UC_L2WB_L3WB:
136+
control = 1;
137+
break;
138+
case StoreCacheControl::L1WT_L2UC_L3UC:
139+
case StoreCacheControl::L1WT_L2UC_L3WB:
140+
case StoreCacheControl::L1WT_L2WB_L3UC:
141+
case StoreCacheControl::L1WT_L2WB_L3WB:
142+
control = 2;
143+
break;
144+
case StoreCacheControl::L1S_L2UC_L3UC:
145+
case StoreCacheControl::L1S_L2UC_L3WB:
146+
case StoreCacheControl::L1S_L2WB_L3UC:
147+
case StoreCacheControl::L1S_L2WB_L3WB:
148+
control = 3;
149+
break;
150+
case StoreCacheControl::L1WB_L2UC_L3UC:
151+
case StoreCacheControl::L1WB_L2WB_L3UC:
152+
case StoreCacheControl::L1WB_L2UC_L3WB:
153+
control = 4;
154+
break;
205155
}
206156
return control;
207157
}
208158

209-
template <bool isLoad, typename OpType>
159+
/// Maps the L3 component of a load cache-control hint to the integer code
/// that getCacheControlMetadata places in the L3 decoration.
///
/// Returns 1 for every `*_L3UC` hint, 2 for every `*_L3C` hint, 4 for
/// `INVALIDATE_READ`, and 0 for any value not listed in the switch.
static int32_t getL3CacheControl(LoadCacheControl cc) {
  // No `default:` label on purpose, so the switch stays exhaustive over the
  // enumerators it names and unlisted values fall through to the final return.
  switch (cc) {
  case LoadCacheControl::L1UC_L2UC_L3UC:
  case LoadCacheControl::L1UC_L2C_L3UC:
  case LoadCacheControl::L1C_L2UC_L3UC:
  case LoadCacheControl::L1C_L2C_L3UC:
  case LoadCacheControl::L1S_L2UC_L3UC:
  case LoadCacheControl::L1S_L2C_L3UC:
    return 1;
  case LoadCacheControl::L1UC_L2UC_L3C:
  case LoadCacheControl::L1UC_L2C_L3C:
  case LoadCacheControl::L1C_L2UC_L3C:
  case LoadCacheControl::L1C_L2C_L3C:
  case LoadCacheControl::L1S_L2UC_L3C:
  case LoadCacheControl::L1S_L2C_L3C:
    return 2;
  case LoadCacheControl::INVALIDATE_READ:
    return 4;
  }
  return 0;
}
184+
185+
/// Maps the L3 component of a store cache-control hint to the integer code
/// that getCacheControlMetadata places in the L3 decoration.
///
/// Returns 1 for every `*_L3UC` hint, 2 for every `*_L3WB` hint, and 0 for
/// any value not listed in the switch.
static int32_t getL3CacheControl(StoreCacheControl cc) {
  // No `default:` label on purpose; unlisted values fall through to the final
  // return.
  switch (cc) {
  case StoreCacheControl::L1UC_L2UC_L3UC:
  case StoreCacheControl::L1UC_L2WB_L3UC:
  case StoreCacheControl::L1WT_L2UC_L3UC:
  case StoreCacheControl::L1WT_L2WB_L3UC:
  case StoreCacheControl::L1S_L2UC_L3UC:
  case StoreCacheControl::L1S_L2WB_L3UC:
  case StoreCacheControl::L1WB_L2UC_L3UC:
  case StoreCacheControl::L1WB_L2WB_L3UC:
    return 1;
  case StoreCacheControl::L1UC_L2UC_L3WB:
  case StoreCacheControl::L1UC_L2WB_L3WB:
  case StoreCacheControl::L1WT_L2UC_L3WB:
  case StoreCacheControl::L1WT_L2WB_L3WB:
  case StoreCacheControl::L1S_L2UC_L3WB:
  case StoreCacheControl::L1S_L2WB_L3WB:
  case StoreCacheControl::L1WB_L2UC_L3WB:
    return 2;
  }
  return 0;
}
210+
211+
static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
212+
return op.getCacheControl();
213+
}
214+
215+
static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
216+
return op.getCacheControl();
217+
}
218+
219+
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
220+
return op.getCacheControl();
221+
}
222+
223+
static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
224+
return op.getCacheControl();
225+
}
226+
227+
/// Reads the load cache-control hint from an `llvm.load`'s "cache_control"
/// attribute.
///
/// Returns std::nullopt when the attribute is absent or is not an
/// xevm::LoadCacheControlAttr.
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
  // getAttrOfType yields a null attribute both when the name is missing and
  // when the stored attribute has a different type, so one lookup covers both.
  if (auto hint =
          op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"))
    return hint.getValue();
  return std::nullopt;
}
236+
237+
/// Reads the store cache-control hint from an `llvm.store`'s "cache_control"
/// attribute.
///
/// Returns std::nullopt when the attribute is absent or is not an
/// xevm::StoreCacheControlAttr.
static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
  // getAttrOfType yields a null attribute both when the name is missing and
  // when the stored attribute has a different type, so one lookup covers both.
  if (auto hint =
          op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control"))
    return hint.getValue();
  return std::nullopt;
}
246+
247+
/// Looks up the op's cache-control hint and returns its L1 integer code.
///
/// The hint is dereferenced unconditionally: the caller must already have
/// established that getCacheControl(op) holds a value.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  const auto hint = getCacheControl(op);
  return getL1CacheControl(*hint);
}
251+
252+
/// Looks up the op's cache-control hint and returns its L3 integer code.
///
/// The hint is dereferenced unconditionally: the caller must already have
/// established that getCacheControl(op) holds a value.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  const auto hint = getCacheControl(op);
  return getL3CacheControl(*hint);
}
256+
257+
template <typename OpType>
210258
static std::optional<ArrayAttr>
211259
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
212-
if (!op.getCacheControl())
260+
if (!getCacheControl(op))
213261
return {};
214262
constexpr int32_t decorationCacheControlArity{4};
215263
constexpr int32_t loadCacheControlKey{6442};
216264
constexpr int32_t storeCacheControlKey{6443};
265+
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
266+
std::is_same_v<OpType, BlockPrefetch2dOp> ||
267+
std::is_same_v<OpType, LLVM::LoadOp> ||
268+
std::is_same_v<OpType, PrefetchOp>;
217269
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
218270
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
219-
controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
271+
controlKey, 0, getL1CacheControl<OpType>(op), 0};
220272
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
221-
controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
273+
controlKey, 1, getL3CacheControl<OpType>(op), 0};
222274
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
223275
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
224276

@@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
398450
rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
399451
argTypes, args, {}, funcAttr, op.getOperation());
400452
if (std::optional<ArrayAttr> optCacheControls =
401-
getCacheControlMetadata<true>(rewriter, op))
453+
getCacheControlMetadata(rewriter, op))
402454
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
403455
rewriter.eraseOp(op);
404456
return success();
@@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
557609
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
558610
argTypes, args, paramAttrs, funcAttr, op.getOperation());
559611
if (std::optional<ArrayAttr> optCacheControls =
560-
getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
612+
getCacheControlMetadata(rewriter, op)) {
561613
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
562614
}
563615
if constexpr (isLoad)
@@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
568620
return success();
569621
}
570622
};
623+
template <typename OpType>
624+
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
625+
using OpConversionPattern<OpType>::OpConversionPattern;
626+
LogicalResult
627+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
628+
ConversionPatternRewriter &rewriter) const override {
629+
if (!op->hasAttr("cache_control"))
630+
return failure();
631+
std::optional<ArrayAttr> optCacheControls =
632+
getCacheControlMetadata(rewriter, op);
633+
op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
634+
op->removeAttr("cache_control");
635+
return success();
636+
}
637+
};
571638

572639
//===----------------------------------------------------------------------===//
573640
// Pass Definition
@@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass
583650

584651
void runOnOperation() override {
585652
ConversionTarget target(getContext());
586-
target.addLegalDialect<LLVM::LLVMDialect>();
587-
target.addIllegalDialect<XeVMDialect>();
588653
RewritePatternSet patterns(&getContext());
589-
populateXeVMToLLVMConversionPatterns(patterns);
654+
populateXeVMToLLVMConversionPatterns(target, patterns);
590655
if (failed(applyPartialConversion(getOperation(), target,
591656
std::move(patterns))))
592657
signalPassFailure();
@@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
611676
void populateConvertToLLVMConversionPatterns(
612677
ConversionTarget &target, LLVMTypeConverter &typeConverter,
613678
RewritePatternSet &patterns) const final {
614-
populateXeVMToLLVMConversionPatterns(patterns);
679+
populateXeVMToLLVMConversionPatterns(target, patterns);
615680
}
616681
};
617682
} // namespace
@@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
620685
// Pattern Population
621686
//===----------------------------------------------------------------------===//
622687

623-
void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
688+
void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
689+
RewritePatternSet &patterns) {
690+
target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
691+
[](Operation *op) { return !op->hasAttr("cache_control"); });
692+
target.addIllegalDialect<XeVMDialect>();
624693
patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
625694
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
626695
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
627-
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
628-
patterns.getContext());
696+
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
697+
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
698+
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
629699
}
630700

631701
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {

mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ class XeVMDialectLLVMIRTranslationInterface
5555
return handleDecorationCacheControl(instructions.front(),
5656
cacheControlsArray.getValue());
5757
}
58-
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
59-
if (!func)
60-
return failure();
61-
6258
return success();
6359
}
6460

0 commit comments

Comments
 (0)