From 610c8c60549c48258e3df4cc62f20c625375f313 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Tue, 23 Sep 2025 23:48:50 +0200 Subject: [PATCH 1/2] [mlir][VectorToLLVM] Add support for unrolling and lowering multi-dimensional vector.scatter operations --- .../Vector/Transforms/LoweringPatterns.h | 8 ++ .../mlir/Dialect/Vector/Utils/VectorUtils.h | 3 +- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 + .../Dialect/Vector/Transforms/CMakeLists.txt | 1 + .../Vector/Transforms/LowerVectorScatter.cpp | 101 ++++++++++++++++++ mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 46 ++++++-- .../VectorToLLVM/vector-to-llvm.mlir | 15 ++- 7 files changed, 158 insertions(+), 17 deletions(-) create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 7bd96c8a6d1a1..83c08d4103177 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollScatter] +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost dimension. +void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [UnrollGather] diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index a57aadcdcc5b0..c6a1c62afc92f 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -254,7 +254,8 @@ using UnrollVectorOpFn = function_ref; LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, - UnrollVectorOpFn unrollFn); + UnrollVectorOpFn unrollFn, + VectorType vectorTy = nullptr); /// Generic utility for unrolling values of type vector /// to N values of type vector using vector.extract. If the input diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index f958edf2746e9..1750e0430fd2b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorScatterLoweringPatterns(patterns); populateVectorFromElementsUnrollPatterns(patterns); populateVectorToElementsUnrollPatterns(patterns); if (armI8MM) { diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 4e0f07af95984..cd7d4f5d1c69c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorMask.cpp LowerVectorMultiReduction.cpp LowerVectorScan.cpp + LowerVectorScatter.cpp LowerVectorShapeCast.cpp LowerVectorShuffle.cpp LowerVectorStep.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp new file mode 100644 index 0000000000000..d236c2d23b3b9 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp @@ -0,0 +1,101 @@ +//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.scatter' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + +#define DEBUG_TYPE "vector-scatter-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// vector.scatter %base[%c0][%idx], %mask, %value : +/// memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> +/// +/// ==> +/// +/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32> +/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1> +/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32> +/// vector.scatter %base[%c0][%i0], %m0, %v0 : +/// memref, vector<3xi32>, vector<3xi1>, vector<3xf32> +/// +/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32> +/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1> +/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32> +/// vector.scatter %base[%c0][%i1], %m1, %v1 : +/// memref, vector<3xi32>, vector<3xi1>, vector<3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d scatter ops. +/// +/// Supports vector types with a fixed leading dimension. +struct UnrollScatter : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ScatterOp op, + PatternRewriter &rewriter) const override { + Value indexVec = op.getIndices(); + Value maskVec = op.getMask(); + Value valueVec = op.getValueToStore(); + + // Get the vector type from one of the vector operands + VectorType vectorTy = dyn_cast(indexVec.getType()); + if (!vectorTy) + return failure(); + + auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; + + Value indexSubVec = + vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); + Value maskSubVec = + vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); + Value valueSubVec = + vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx); + + rewriter.create(loc, op.getBase(), op.getOffsets(), + indexSubVec, maskSubVec, valueSubVec, + op.getAlignmentAttr()); + + // Return a dummy value since unrollVectorOp expects a Value + return rewriter.create(loc, subTy); + }; + + return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy); + } +}; + +} // namespace + +void mlir::vector::populateVectorScatterLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 025ee9a04a1de..53ac3d50e1d21 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -431,27 +431,51 @@ vector::unrollVectorValue(TypedValue vector, } LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, - vector::UnrollVectorOpFn unrollFn) { - assert(op->getNumResults() == 1 && "expected single result"); - assert(isa(op->getResult(0).getType()) && "expected vector type"); - VectorType resultTy = cast(op->getResult(0).getType()); - if (resultTy.getRank() < 2) + vector::UnrollVectorOpFn unrollFn, + VectorType vectorTy) { + // If vector type is not provided, get it from the result + if (!vectorTy) { + if (op->getNumResults() != 1) + return rewriter.notifyMatchFailure( + op, "expected single result when vector type not provided"); + + vectorTy = dyn_cast(op->getResult(0).getType()); + if (!vectorTy) + return rewriter.notifyMatchFailure(op, "expected vector type"); + } + + if (vectorTy.getRank() < 2) return rewriter.notifyMatchFailure(op, "already 1-D"); // Unrolling doesn't take vscale into account. Pattern is disabled for // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) + if (vectorTy.getScalableDims().front()) return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); Location loc = op->getLoc(); - Value result = ub::PoisonOp::create(rewriter, loc, resultTy); - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + // Only create result value if the operation produces results + Value result; + if (op->getNumResults() > 0) { + result = ub::PoisonOp::create(rewriter, loc, vectorTy); + } + + VectorType subTy = VectorType::Builder(vectorTy).dropDim(0); + + for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) { Value subVector = unrollFn(rewriter, loc, subTy, i); - result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + + // Only insert if we have a result to build + if (op->getNumResults() > 0) { + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + } + + if (op->getNumResults() > 0) { + rewriter.replaceOp(op, result); + } else { + rewriter.eraseOp(op); } - rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888854ea7..6ba37bd56083f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref, %arg1: vector<2x // vector.scatter //===----------------------------------------------------------------------===// -// Multi-Dimensional scatters are not supported yet. Check that we do not lower -// them. - func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) { %0 = arith.constant 0: index %1 = vector.constant_mask [2, 2] : vector<2x3xi1> @@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2 } // CHECK-LABEL: func @scatter_with_mask -// CHECK: vector.scatter +// CHECK: llvm.extractvalue {{.*}}[0] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// CHECK: llvm.extractvalue {{.*}}[1] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// CHECK-NOT: vector.scatter // ----- @@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref, %arg1: vector<2x[3]x } // CHECK-LABEL: func @scatter_with_mask_scalable -// CHECK: vector.scatter +// CHECK: llvm.extractvalue {{.*}}[0] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr> +// CHECK: llvm.extractvalue {{.*}}[1] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr> +// CHECK-NOT: vector.scatter // ----- From f72fb10ff5d13bc496b4f247a0cb830eea387ac9 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Wed, 24 Sep 2025 12:12:32 +0200 Subject: [PATCH 2/2] Address comments --- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 4 ++++ .../Vector/Transforms/LowerVectorScatter.cpp | 6 +++--- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 15 ++++++--------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index c6a1c62afc92f..8f609acd2fdb7 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -250,6 +250,10 @@ LogicalResult isValidMaskedInputVector(ArrayRef shape, /// create sub vectors. /// 5. Insert the sub vectors back into the final vector. /// 6. Replace the original op with the new result. +/// +/// Expects the operation to be unrolled to have at most 1 result. When there's +/// no result, expects the caller to pass in the `vectorTy` to be able to get +/// the unroll factor. using UnrollVectorOpFn = function_ref; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp index d236c2d23b3b9..af17136f21da0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp @@ -65,7 +65,7 @@ struct UnrollScatter : OpRewritePattern { Value maskVec = op.getMask(); Value valueVec = op.getValueToStore(); - // Get the vector type from one of the vector operands + // Get the vector type from one of the vector operands. VectorType vectorTy = dyn_cast(indexVec.getType()); if (!vectorTy) return failure(); @@ -85,8 +85,8 @@ struct UnrollScatter : OpRewritePattern { indexSubVec, maskSubVec, valueSubVec, op.getAlignmentAttr()); - // Return a dummy value since unrollVectorOp expects a Value - return rewriter.create(loc, subTy); + // Return a dummy value since unrollVectorOp expects a Value. + return Value(); }; return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy); diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 53ac3d50e1d21..02f9382c760be 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -433,15 +433,12 @@ vector::unrollVectorValue(TypedValue vector, LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, vector::UnrollVectorOpFn unrollFn, VectorType vectorTy) { - // If vector type is not provided, get it from the result + // If vector type is not provided, get it from the result. if (!vectorTy) { - if (op->getNumResults() != 1) - return rewriter.notifyMatchFailure( - op, "expected single result when vector type not provided"); - + assert(op->getNumResults() == 1 && + "expected single result when vector type not provided"); vectorTy = dyn_cast(op->getResult(0).getType()); - if (!vectorTy) - return rewriter.notifyMatchFailure(op, "expected vector type"); + assert(vectorTy && "expected result to have vector type"); } if (vectorTy.getRank() < 2) @@ -454,7 +451,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, Location loc = op->getLoc(); - // Only create result value if the operation produces results + // Only create result value if the operation produces results. Value result; if (op->getNumResults() > 0) { result = ub::PoisonOp::create(rewriter, loc, vectorTy); @@ -465,7 +462,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) { Value subVector = unrollFn(rewriter, loc, subTy, i); - // Only insert if we have a result to build + // Only insert if we have a result to build. if (op->getNumResults() > 0) { result = vector::InsertOp::create(rewriter, loc, subVector, result, i); }