tyb0807
Contributor

@tyb0807 tyb0807 commented Sep 23, 2025

…ensional vector.scatter operations

This PR adds comprehensive support for lowering multi-dimensional vector.scatter operations to LLVM, matching the existing functionality for vector.gather.

Modify unrollVectorOp to support unrolling of ops that do not have results.
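
To illustrate what unrolling the outermost dimension means for an op without results, here is a minimal standalone sketch. It is not MLIR code: `scatter1D` and `scatter2DUnrolled` are hypothetical names, plain `std::vector`s stand in for the memref and vector operands, and indices are assumed to be non-negative and in bounds.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model of a 1-D masked scatter: for each lane i whose mask bit is set,
// store value[i] at base[offset + idx[i]]. Lanes with a cleared mask bit
// leave memory untouched. Assumes idx[i] >= 0 and in bounds.
void scatter1D(std::vector<float> &base, std::size_t offset,
               const std::vector<int32_t> &idx, const std::vector<bool> &mask,
               const std::vector<float> &value) {
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (mask[i])
      base[offset + static_cast<std::size_t>(idx[i])] = value[i];
}

// Toy model of the unrolled 2-D case: peel the outermost dimension and issue
// one 1-D scatter per row. Nothing is returned or reassembled, because a
// scatter only has side effects on memory.
void scatter2DUnrolled(std::vector<float> &base, std::size_t offset,
                       const std::vector<std::vector<int32_t>> &idx,
                       const std::vector<std::vector<bool>> &mask,
                       const std::vector<std::vector<float>> &value) {
  for (std::size_t row = 0; row < idx.size(); ++row)
    scatter1D(base, offset, idx[row], mask[row], value[row]);
}
```

The contrast with gather is that there is no result vector to rebuild with inserts after the loop, which is why the unrolling helper has to cope with ops that produce no results.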

@llvmbot
Member

llvmbot commented Sep 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: None (tyb0807)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/160405.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+8)
  • (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+2-1)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp (+101)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+35-11)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10-5)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..83c08d4103177 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [UnrollGather]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..c6a1c62afc92f 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -254,7 +254,8 @@ using UnrollVectorOpFn =
     function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
 
 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
-                             UnrollVectorOpFn unrollFn);
+                             UnrollVectorOpFn unrollFn,
+                             VectorType vectorTy = nullptr);
 
 /// Generic utility for unrolling values of type vector<NxAxBx...>
 /// to N values of type vector<AxBx...> using vector.extract. If the input
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..1750e0430fd2b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    populateVectorScatterLoweringPatterns(patterns);
     populateVectorFromElementsUnrollPatterns(patterns);
     populateVectorToElementsUnrollPatterns(patterns);
     if (armI8MM) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4e0f07af95984..cd7d4f5d1c69c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
+  LowerVectorScatter.cpp
   LowerVectorShapeCast.cpp
   LowerVectorShuffle.cpp
   LowerVectorStep.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
new file mode 100644
index 0000000000000..d236c2d23b3b9
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
@@ -0,0 +1,101 @@
+//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.scatter' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-scatter-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// vector.scatter %base[%c0][%idx], %mask, %value :
+///        memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
+/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
+/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i0], %m0, %v0 :
+///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+///
+/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
+/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
+/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i1], %m1, %v1 :
+///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Value indexVec = op.getIndices();
+    Value maskVec = op.getMask();
+    Value valueVec = op.getValueToStore();
+
+    // Get the vector type from one of the vector operands
+    VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
+    if (!vectorTy)
+      return failure();
+
+    auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
+                               VectorType subTy, int64_t index) {
+      int64_t thisIdx[1] = {index};
+
+      Value indexSubVec =
+          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
+      Value maskSubVec =
+          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
+      Value valueSubVec =
+          vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);
+
+      rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getOffsets(),
+                                         indexSubVec, maskSubVec, valueSubVec,
+                                         op.getAlignmentAttr());
+
+      // Return a dummy value since unrollVectorOp expects a Value
+      return rewriter.create<ub::PoisonOp>(loc, subTy);
+    };
+
+    return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorScatterLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollScatter>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..53ac3d50e1d21 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -431,27 +431,51 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
 }
 
 LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
-                                     vector::UnrollVectorOpFn unrollFn) {
-  assert(op->getNumResults() == 1 && "expected single result");
-  assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
-  VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
-  if (resultTy.getRank() < 2)
+                                     vector::UnrollVectorOpFn unrollFn,
+                                     VectorType vectorTy) {
+  // If vector type is not provided, get it from the result
+  if (!vectorTy) {
+    if (op->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected single result when vector type not provided");
+
+    vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!vectorTy)
+      return rewriter.notifyMatchFailure(op, "expected vector type");
+  }
+
+  if (vectorTy.getRank() < 2)
     return rewriter.notifyMatchFailure(op, "already 1-D");
 
   // Unrolling doesn't take vscale into account. Pattern is disabled for
   // vectors with leading scalable dim(s).
-  if (resultTy.getScalableDims().front())
+  if (vectorTy.getScalableDims().front())
     return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
 
   Location loc = op->getLoc();
-  Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
-  VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
 
-  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+  // Only create result value if the operation produces results
+  Value result;
+  if (op->getNumResults() > 0) {
+    result = ub::PoisonOp::create(rewriter, loc, vectorTy);
+  }
+
+  VectorType subTy = VectorType::Builder(vectorTy).dropDim(0);
+
+  for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
     Value subVector = unrollFn(rewriter, loc, subTy, i);
-    result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+
+    // Only insert if we have a result to build
+    if (op->getNumResults() > 0) {
+      result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+    }
+  }
+
+  if (op->getNumResults() > 0) {
+    rewriter.replaceOp(op, result);
+  } else {
+    rewriter.eraseOp(op);
   }
 
-  rewriter.replaceOp(op, result);
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2d33888854ea7..6ba37bd56083f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
 // vector.scatter
 //===----------------------------------------------------------------------===//
 
-// Multi-Dimensional scatters are not supported yet. Check that we do not lower
-// them.
-
 func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
   %0 = arith.constant 0: index
   %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
@@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
 }
 
 // CHECK-LABEL: func @scatter_with_mask
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK-NOT: vector.scatter
 
 // -----
 
@@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
 }
 
 // CHECK-LABEL: func @scatter_with_mask_scalable
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK-NOT: vector.scatter
 
 // -----
 

vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
if (!vectorTy)
return rewriter.notifyMatchFailure(op, "expected vector type");
}
Member

Should we assert that, conversely, when vectorTy is provided, there is exactly one result?

Contributor Author

No. When vectorTy is provided, it is because the op has no result, so unrollVectorOp cannot infer the vector type. Let me add this to the function docstring.

Contributor Author

But then when vectorTy is provided we should probably assert there's no result?

Member

I may have misspoken, but the main point is to prefer assertions in code to verbiage in documentation.
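
One way to read the contract discussed in this thread, as a standalone sketch rather than the actual patch: an explicit vector type is accepted only for ops with no results, and otherwise the type comes from the single vector result. `ToyOp`, `Shape`, `ShapeSource`, and `pickShapeSource` are hypothetical stand-ins, not MLIR APIs, and the `Invalid` return marks the spot where real code would presumably use an `assert` or `notifyMatchFailure`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not MLIR types.
using Shape = std::vector<int64_t>;

struct ToyOp {
  int numResults;                   // number of results the op produces
  std::optional<Shape> resultShape; // shape of the single vector result, if any
};

enum class ShapeSource { FromResult, FromCaller, Invalid };

// Decide where the shape to unroll comes from.
ShapeSource pickShapeSource(const ToyOp &op,
                            const std::optional<Shape> &callerShape) {
  if (callerShape) {
    // Contract from the thread: a caller-provided type is only meaningful
    // when the op has no result to infer it from.
    return op.numResults == 0 ? ShapeSource::FromCaller : ShapeSource::Invalid;
  }
  // No explicit type: require exactly one result carrying a vector shape.
  return (op.numResults == 1 && op.resultShape) ? ShapeSource::FromResult
                                                : ShapeSource::Invalid;
}
```

Encoding the rule as a check in code, rather than only in the docstring, is the preference expressed in the comment above.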

Comment on lines +98 to +100
void mlir::vector::populateVectorScatterLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollScatter>(patterns.getContext(), benefit);
Member

I'm not keen on the proliferation of populate*Patterns functions. Is there any scenario where one would want scatter unrolling without also wanting the gather patterns? If not, I'd suggest merging the two into populateVectorGatherScatterUnrollPatterns, which also makes it clearer that these are unrolling patterns and not all of the lowering.

@Groverkss
Member

There is already an approved PR for this: #132227. I think there are some comments on the tests there that still need to be addressed, but the implementation is how it should be. I can just land it for you today if you like.

@tyb0807
Contributor Author

tyb0807 commented Sep 24, 2025

Oh I didn't know about that PR. Yeah please land it, I'll take care of the refactoring once it's landed. Thank you!

@ftynse
Member

ftynse commented Oct 1, 2025

Could we please land either of these?

@banach-space
Contributor

Thanks!

As pointed out by Kunwar, this is a duplicate of #132227, which contains a request for more testing. I don't mind which PR lands, but let's make sure that there are more tests for this.

IMO, Vector patterns should be tested more rigorously (as opposed to relying on ConvertVectorToLLVM).
