99
1010#include " mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1111#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
12+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
1213#include " mlir/IR/DialectImplementation.h"
1314#include " llvm/ADT/TypeSwitch.h"
1415
@@ -18,9 +19,260 @@ using namespace xevm;
1819#include " gc/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
1920#include " gc/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
2021
21- // TODO
22- LogicalResult BlockLoad2dOp::verify () { return success (); }
23- LogicalResult BlockStore2dOp::verify () { return success (); }
22+ namespace {
// Divisor applied to the total tile size in bits (elem_size_in_bits *
// tile_height * tile_width * v_blocks) when checking the result vector size of
// a 2D block load.
23+ constexpr uint32_t subgroupSize = 16 ;
24+
25+ template <typename Op> LogicalResult verifyMatrixInput (Op op) {
26+ static_assert (llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp>::value,
27+ " Unexpected template parameter" );
28+
29+ std::optional<int64_t > width = getConstantIntValue (op.getBaseWidth ());
30+ std::optional<int64_t > pitch = getConstantIntValue (op.getBasePitch ());
31+ if (pitch && width && *pitch < *width)
32+ return op->emitOpError (
33+ " 4th operand (base pitch) should be >= 2nd operand (base width)" );
34+
35+ if (op.getElemSizeInBits () != 8 && op.getElemSizeInBits () != 16 &&
36+ op.getElemSizeInBits () != 32 )
37+ return op->emitOpError (" expecting 'elem_size_in_bits' to be 8, 16, or 32" );
38+
39+ uint32_t tileHeight = op.getTileHeight ();
40+ if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 &&
41+ tileHeight != 8 && tileHeight != 16 && tileHeight != 32 )
42+ return op->emitOpError (" expecting tile_height to be 1, 2, 4, 8, 16, or 32" );
43+
44+ uint32_t vBlocks = op.getVBlocks ();
45+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8 )
46+ return op->emitOpError (" expecting v_blocks to be 1, 2, 4, or 8" );
47+
48+ return success ();
49+ }
50+
51+ LogicalResult verify2DBlockLoadHWRestriction (BlockLoad2dOp op) {
52+ VectorType resTy = op.getRes ().getType ();
53+ unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
54+ unsigned resSize = resTy.getNumElements () * resElemTySize;
55+ unsigned expectedSize = op.getElemSizeInBits () * op.getTileHeight () *
56+ op.getTileWidth () * op.getVBlocks () / subgroupSize;
57+ if (resSize != expectedSize)
58+ return op.emitOpError () << " result size of " << resSize
59+ << " bits does not match the expected size of "
60+ << expectedSize << " bits" ;
61+
62+ if (op.getTranspose () && op.getVnniTransform ())
63+ return op.emitOpError (
64+ " transpose and vnni_transform are mutually exclusive" );
65+
66+ if (!op.getTranspose () && !op.getVnniTransform ()) {
67+ uint32_t tileHeight = op.getTileHeight ();
68+ if (tileHeight < 1 || tileHeight > 32 )
69+ return op.emitOpError (" expecting tile_height to be between 1 and 32" );
70+
71+ uint32_t tileWidth = op.getTileWidth ();
72+ uint32_t vBlocks = op.getVBlocks ();
73+ switch (op.getElemSizeInBits ()) {
74+ case 8 :
75+ if (tileWidth < 4 || tileWidth > 64 )
76+ return op.emitOpError (" expecting tile_width to be between 4 and 64" );
77+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
78+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
79+ if (tileWidth * vBlocks > 64 )
80+ return op.emitOpError (
81+ " tile_width * v_blocks should be less than or equal "
82+ " to 64 for 8 bit elements" );
83+ break ;
84+ case 16 :
85+ if (tileWidth < 2 || tileWidth > 32 )
86+ return op.emitOpError (" expecting tile_width to be between 2 and 32" );
87+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
88+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
89+ if (tileWidth * vBlocks > 32 )
90+ return op.emitOpError (
91+ " tile_width * v_blocks should be less than or equal "
92+ " to 32 for 16 bit elements" );
93+ break ;
94+ case 32 :
95+ if (tileWidth < 1 || tileWidth > 16 )
96+ return op.emitOpError (" expecting tile_width to be between 1 and 16" );
97+ if (vBlocks != 1 && vBlocks != 2 )
98+ return op.emitOpError (" expecting v_blocks to be 1 or 2" );
99+ if (tileWidth * vBlocks > 16 )
100+ return op.emitOpError (
101+ " tile_width * v_blocks should be less than or equal "
102+ " to 16 for 32 bit elements" );
103+ break ;
104+ case 64 :
105+ if (tileWidth < 1 || tileWidth > 8 )
106+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
107+ if (vBlocks != 1 )
108+ return op.emitOpError (" expecting v_blocks to be 1" );
109+ break ;
110+ default :
111+ return op.emitOpError (
112+ " expecting elem_size_in_bits to be 8, 16, 32, or 64" );
113+ }
114+
115+ return success ();
116+ }
117+
118+ if (op.getTranspose ()) {
119+ assert (!op.getVnniTransform () &&
120+ " Expecting vnni_transform should be false" );
121+
122+ uint32_t vBlocks = op.getVBlocks ();
123+ if (vBlocks != 1 )
124+ return op.emitOpError (" expecting v_blocks to be 1" );
125+
126+ uint32_t tileHeight = op.getTileHeight ();
127+ uint32_t tileWidth = op.getTileWidth ();
128+ switch (op.getElemSizeInBits ()) {
129+ case 32 :
130+ if (tileHeight < 1 || tileHeight > 32 )
131+ return op.emitOpError (" expecting tile_height to be between 1 and 32" );
132+ if (tileWidth < 1 || tileWidth > 8 )
133+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
134+ break ;
135+ case 64 :
136+ if (tileHeight != 8 )
137+ return op.emitOpError (
138+ " expecting tile_height to be 8 for 64 bit elements" );
139+ if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4 )
140+ return op.emitOpError (" expecting tile_width to be 1, 2, or 4" );
141+ break ;
142+ default :
143+ return op.emitOpError (" transpose is only supported for 32 and 64 bit "
144+ " elements" );
145+ }
146+
147+ return success ();
148+ }
149+
150+ assert (op.getVnniTransform () && !op.getTranspose () &&
151+ " Expecting vnni_transform should be true and transpose should be "
152+ " false" );
153+
154+ uint32_t vBlocks = op.getVBlocks ();
155+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
156+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
157+
158+ uint32_t tileHeight = op.getTileHeight ();
159+ uint32_t tileWidth = op.getTileWidth ();
160+ switch (op.getElemSizeInBits ()) {
161+ case 8 :
162+ if (tileHeight < 4 || tileHeight > 32 )
163+ return op.emitOpError (" expecting tile_height to be between 4 and 32" );
164+ if (tileWidth < 4 || tileWidth > 16 )
165+ return op.emitOpError (" expecting tile_width to be between 4 and 16" );
166+ break ;
167+ case 16 :
168+ if (tileHeight < 2 || tileHeight > 32 )
169+ return op.emitOpError (" expecting tile_height to be between 2 and 32" );
170+ if (tileWidth < 2 || tileWidth > 16 )
171+ return op.emitOpError (" expecting tile_width to be between 2 and 16" );
172+ if (tileWidth * vBlocks > 32 )
173+ return op.emitOpError (
174+ " tile_width * v_blocks should be less than or equal "
175+ " to 32 for 16 bit elements" );
176+ break ;
177+ default :
178+ return op.emitOpError (" vnni_transform is only supported for 8 and 16 bit "
179+ " elements" );
180+ }
181+
182+ return success ();
183+ }
184+
185+ static LogicalResult verify2DBlockStoreHWRestriction (BlockStore2dOp op) {
186+ uint32_t tileHeight = op.getTileHeight ();
187+ if (tileHeight < 1 || tileHeight > 8 )
188+ return op.emitOpError (" expecting tile_height to be between 1 and 8" );
189+
190+ uint32_t tileWidth = op.getTileWidth ();
191+ switch (op.getElemSizeInBits ()) {
192+ case 8 :
193+ if (tileWidth < 4 || tileWidth > 64 )
194+ return op.emitOpError (" expecting tile_width to be between 4 and 64" );
195+ break ;
196+ case 16 :
197+ if (tileWidth < 2 || tileWidth > 32 )
198+ return op.emitOpError (" expecting tile_width to be between 2 and 32" );
199+ break ;
200+ case 32 :
201+ if (tileWidth < 1 || tileWidth > 16 )
202+ return op.emitOpError (" expecting tile_width to be between 1 and 16" );
203+ break ;
204+ case 64 :
205+ if (tileWidth < 1 || tileWidth > 8 )
206+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
207+ break ;
208+ default :
209+ return op.emitOpError (" expecting elem_size_in_bits to be 8, 16, 32, or 64" );
210+ }
211+
212+ uint32_t vBlocks = op.getVBlocks ();
213+ if (vBlocks != 1 )
214+ return op.emitOpError (" expecting v_blocks to be 1" );
215+ return success ();
216+ }
217+
218+ } // namespace
219+
220+ LogicalResult BlockLoad2dOp::verify () {
221+ if (verify2DBlockLoadHWRestriction (*this ).failed ())
222+ return failure ();
223+
224+ if (verifyMatrixInput (*this ).failed ())
225+ return failure ();
226+
227+ VectorType resTy = getRes ().getType ();
228+ unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
229+ if (getElemSizeInBits () == 32 || getVnniTransform ()) {
230+ if (resElemTySize != 32 )
231+ return emitOpError () << " expecting result element type to be 32 bits" ;
232+ }
233+
234+ uint32_t tileWidth = getTileWidth ();
235+ if (getVnniTransform ()) {
236+ if (tileWidth != 16 )
237+ return emitOpError (
238+ " tile_width when vnni_transform is true should be equal "
239+ " to subgroup size (16 elements)" );
240+ return success ();
241+ }
242+
243+ return success ();
244+ }
245+
246+ LogicalResult BlockStore2dOp::verify () {
247+ if (verify2DBlockStoreHWRestriction (*this ).failed ())
248+ return failure ();
249+
250+ if (verifyMatrixInput (*this ).failed ())
251+ return failure ();
252+
253+ uint32_t tileWidth = getTileWidth ();
254+ switch (getElemSizeInBits ()) {
255+ case 8 :
256+ if (tileWidth != 16 && tileWidth != 32 )
257+ return emitOpError (" tile_width for 8 bit elements should be equal to "
258+ " 16 or 32" );
259+ break ;
260+ case 16 :
261+ if (tileWidth != 16 )
262+ return emitOpError (" tile_width for 16 bit elements should be equal "
263+ " to 16" );
264+ break ;
265+ case 32 :
266+ if (tileWidth != 16 )
267+ return emitOpError (" tile_width for 32 bit elements should be equal "
268+ " to 16" );
269+ break ;
270+ default :
271+ llvm_unreachable (" unexpected element size" );
272+ }
273+
274+ return success ();
275+ }
24276
25277void XeVMDialect::initialize () {
26278 // NOLINTBEGIN
0 commit comments