Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr

def XeVM_BlockLoadOp
: XeVM_Op<"blockload">,
Results<(
outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
Results<(outs AnyTypeOf<
[XeVM_1DBlockElemType,
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
let summary = "subgroup block load";
Expand Down Expand Up @@ -228,7 +229,9 @@ def XeVM_BlockLoadOp
def XeVM_BlockStoreOp
: XeVM_Op<"blockstore">,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
AnyTypeOf<[XeVM_1DBlockElemType,
FixedVectorOfRankAndType<[1],
[XeVM_1DBlockElemType]>]>:$val,
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
let summary = "subgroup block store";
let description = [{
Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
VectorType vTy;
Type srcOrDstTy;
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
vTy = op.getResult().getType();
srcOrDstTy = op.getResult().getType();
else
vTy = op.getVal().getType();
srcOrDstTy = op.getVal().getType();
VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
// scalar case is always valid
if (!vTy)
return success();
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
if (elemTySize == 1) {
llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
"vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
"vector size must be 2, 4, 8 or 16 for 8-bit element type");
} else {
llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
llvm::SmallSet<int, 3> validSizes{2, 4, 8};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
"vector size must be 1, 2, 4 or 8 for element type > 8 bits");
"vector size must be 2, 4 or 8 for element type > 8 bits");
}
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1943,14 +1943,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {

// -----
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
// expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
// expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
llvm.return
}

// -----
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
// expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
// expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
llvm.return
}
Expand Down