Skip to content

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Oct 2, 2025

instead of single element vectors.

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

instead of single element vectors.


Full diff: https://github.com/llvm/llvm-project/pull/161708.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td (+6-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp (+11-7)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 4f7a8421c07b9..2dd612139fa2d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr
 
 def XeVM_BlockLoadOp
     : XeVM_Op<"blockload">,
-      Results<(
-          outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
+      Results<(outs AnyTypeOf<
+          [XeVM_1DBlockElemType,
+           FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>,
       Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
           OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
   let summary = "subgroup block load";
@@ -228,7 +229,9 @@ def XeVM_BlockLoadOp
 def XeVM_BlockStoreOp
     : XeVM_Op<"blockstore">,
       Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
-          FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
+          AnyTypeOf<[XeVM_1DBlockElemType,
+                     FixedVectorOfRankAndType<[1],
+                                              [XeVM_1DBlockElemType]>]>:$val,
           OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
   let summary = "subgroup block store";
   let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 8295492ad73a8..04e8836c00359 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
 template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
                                OpType, BlockLoadOp, BlockStoreOp>::value>>
 LogicalResult verify1DBlockArg(OpType op) {
-  VectorType vTy;
+  Type srcOrDstTy;
   if constexpr (std::is_same_v<OpType, BlockLoadOp>)
-    vTy = op.getResult().getType();
+    srcOrDstTy = op.getResult().getType();
   else
-    vTy = op.getVal().getType();
+    srcOrDstTy = op.getVal().getType();
+  VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
+  // scalar case is always valid
+  if (!vTy)
+    return success();
   int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
   if (elemTySize == 1) {
-    llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
+    llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
     if (validSizes.contains(vTy.getNumElements()))
       return success();
     else
       return op.emitOpError(
-          "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
+          "vector size must be 2, 4, 8 or 16 for 8-bit element type");
   } else {
-    llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
+    llvm::SmallSet<int, 3> validSizes{2, 4, 8};
     if (validSizes.contains(vTy.getNumElements()))
       return success();
     else
       return op.emitOpError(
-          "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
+          "vector size must be 2, 4 or 8 for element type > 8 bits");
   }
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 627abd0665d8c..7ef56b52f1d83 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1943,14 +1943,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
 
 // -----
 llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
-  // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
+  // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
   %0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
   llvm.return
 }
 
 // -----
 llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
-  // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
+  // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
   xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
   llvm.return
 }

@silee2 silee2 requested review from Jianhui-Li and nbpatel October 2, 2025 17:37
Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@silee2 silee2 merged commit 0d758de into llvm:main Oct 7, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants