diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td index 4f7a8421c07b9..2dd612139fa2d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td @@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr def XeVM_BlockLoadOp : XeVM_Op<"blockload">, - Results<( - outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>, + Results<(outs AnyTypeOf< + [XeVM_1DBlockElemType, + FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>, Arguments<(ins Arg:$ptr, OptionalAttr:$cache_control)> { let summary = "subgroup block load"; @@ -228,7 +229,9 @@ def XeVM_BlockLoadOp def XeVM_BlockStoreOp : XeVM_Op<"blockstore">, Arguments<(ins Arg:$ptr, - FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val, + AnyTypeOf<[XeVM_1DBlockElemType, + FixedVectorOfRankAndType<[1], + [XeVM_1DBlockElemType]>]>:$val, OptionalAttr:$cache_control)> { let summary = "subgroup block store"; let description = [{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 8295492ad73a8..04e8836c00359 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() { template ::value>> LogicalResult verify1DBlockArg(OpType op) { - VectorType vTy; + Type srcOrDstTy; if constexpr (std::is_same_v) - vTy = op.getResult().getType(); + srcOrDstTy = op.getResult().getType(); else - vTy = op.getVal().getType(); + srcOrDstTy = op.getVal().getType(); + VectorType vTy = dyn_cast(srcOrDstTy); + // scalar case is always valid + if (!vTy) + return success(); int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8; if (elemTySize == 1) { - llvm::SmallSet validSizes{1, 2, 4, 8, 16}; + llvm::SmallSet validSizes{2, 4, 8, 16}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type"); + "vector size must be 2, 4, 8 or 16 for 8-bit element type"); } else { - llvm::SmallSet validSizes{1, 2, 4, 8}; + llvm::SmallSet validSizes{2, 4, 8}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4 or 8 for element type > 8 bits"); + "vector size must be 2, 4 or 8 for element type > 8 bits"); } } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 627abd0665d8c..7ef56b52f1d83 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1943,14 +1943,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) { // ----- llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) { - // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}} + // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}} %0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16> llvm.return } // ----- llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) { - // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}} + // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}} xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>) llvm.return }