1212#include " mlir/Dialect/Utils/StaticValueUtils.h"
1313#include " mlir/IR/DialectImplementation.h"
1414#include " llvm/ADT/TypeSwitch.h"
15+ #include " llvm/Support/MathExtras.h"
1516
1617using namespace mlir ;
1718using namespace xevm ;
@@ -32,24 +33,26 @@ template <typename Op> LogicalResult verifyMatrixInput(Op op) {
3233 return op->emitOpError (
3334 " 4th operand (base pitch) should be >= 2nd operand (base width)" );
3435
35- if (op. getElemSizeInBits () != 8 && op.getElemSizeInBits () != 16 &&
36- op. getElemSizeInBits () != 32 )
36+ uint32_t elemSize = op.getElemSizeInBits ();
37+ if (elemSize < 8 || ! llvm::isPowerOf2_32 (elemSize) || elemSize > 32 )
3738 return op->emitOpError (" expecting 'elem_size_in_bits' to be 8, 16, or 32" );
3839
3940 uint32_t tileHeight = op.getTileHeight ();
40- if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 &&
41- tileHeight != 8 && tileHeight != 16 && tileHeight != 32 )
41+ if (tileHeight > 32 || !llvm::isPowerOf2_32 (tileHeight))
4242 return op->emitOpError (" expecting tile_height to be 1, 2, 4, 8, 16, or 32" );
4343
4444 uint32_t vBlocks = op.getVBlocks ();
45- if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8 )
45+ if (vBlocks > 8 || ! llvm::isPowerOf2_32 ( vBlocks) )
4646 return op->emitOpError (" expecting v_blocks to be 1, 2, 4, or 8" );
4747
4848 return success ();
4949}
5050
5151LogicalResult verify2DBlockLoadHWRestriction (BlockLoad2dOp op) {
5252 VectorType resTy = op.getRes ().getType ();
53+ if (!resTy.getElementType ().isIntOrFloat ())
54+ return op.emitOpError ()
55+ << " expecting result element type to be int or float" ;
5356 unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
5457 unsigned resSize = resTy.getNumElements () * resElemTySize;
5558 unsigned expectedSize = op.getElemSizeInBits () * op.getTileHeight () *
@@ -225,6 +228,8 @@ LogicalResult BlockLoad2dOp::verify() {
225228 return failure ();
226229
227230 VectorType resTy = getRes ().getType ();
231+ if (!resTy.getElementType ().isIntOrFloat ())
232+ return emitOpError () << " expecting result element type to be int of float" ;
228233 unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
229234 if (getElemSizeInBits () == 32 || getVnniTransform ()) {
230235 if (resElemTySize != 32 )
0 commit comments