[mlir][VectorToLLVM] Add support for unrolling and lowering multi-dim… #160405
…ensional vector.scatter operations
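@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: None (tyb0807)
Changes
…ensional vector.scatter operations
This PR adds comprehensive support for lowering multi-dimensional `vector.scatter` operations to LLVM, matching the existing functionality for `vector.gather`.
Modify `unrollVectorOp` to support unrolling of ops that do not have results.
Full diff: https://github.com/llvm/llvm-project/pull/160405.diff
7 Files Affected: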
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..83c08d4103177 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [UnrollGather]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..c6a1c62afc92f 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -254,7 +254,8 @@ using UnrollVectorOpFn =
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
- UnrollVectorOpFn unrollFn);
+ UnrollVectorOpFn unrollFn,
+ VectorType vectorTy = nullptr);
/// Generic utility for unrolling values of type vector<NxAxBx...>
/// to N values of type vector<AxBx...> using vector.extract. If the input
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..1750e0430fd2b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorScatterLoweringPatterns(patterns);
populateVectorFromElementsUnrollPatterns(patterns);
populateVectorToElementsUnrollPatterns(patterns);
if (armI8MM) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4e0f07af95984..cd7d4f5d1c69c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
+ LowerVectorScatter.cpp
LowerVectorShapeCast.cpp
LowerVectorShuffle.cpp
LowerVectorStep.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
new file mode 100644
index 0000000000000..d236c2d23b3b9
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
@@ -0,0 +1,101 @@
+//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.scatter' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-scatter-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// vector.scatter %base[%c0][%idx], %mask, %value :
+/// memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
+/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
+/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i0], %m0, %v0 :
+/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+///
+/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
+/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
+/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i1], %m1, %v1 :
+/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ Value indexVec = op.getIndices();
+ Value maskVec = op.getMask();
+ Value valueVec = op.getValueToStore();
+
+ // Get the vector type from one of the vector operands
+ VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
+ if (!vectorTy)
+ return failure();
+
+ auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
+
+ Value indexSubVec =
+ vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
+ Value maskSubVec =
+ vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
+ Value valueSubVec =
+ vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);
+
+ rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getOffsets(),
+ indexSubVec, maskSubVec, valueSubVec,
+ op.getAlignmentAttr());
+
+ // Return a dummy value since unrollVectorOp expects a Value
+ return rewriter.create<ub::PoisonOp>(loc, subTy);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorScatterLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollScatter>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..53ac3d50e1d21 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -431,27 +431,51 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
}
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
- vector::UnrollVectorOpFn unrollFn) {
- assert(op->getNumResults() == 1 && "expected single result");
- assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
- VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
- if (resultTy.getRank() < 2)
+ vector::UnrollVectorOpFn unrollFn,
+ VectorType vectorTy) {
+ // If vector type is not provided, get it from the result
+ if (!vectorTy) {
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "expected single result when vector type not provided");
+
+ vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!vectorTy)
+ return rewriter.notifyMatchFailure(op, "expected vector type");
+ }
+
+ if (vectorTy.getRank() < 2)
return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
+ if (vectorTy.getScalableDims().front())
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
Location loc = op->getLoc();
- Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ // Only create result value if the operation produces results
+ Value result;
+ if (op->getNumResults() > 0) {
+ result = ub::PoisonOp::create(rewriter, loc, vectorTy);
+ }
+
+ VectorType subTy = VectorType::Builder(vectorTy).dropDim(0);
+
+ for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
Value subVector = unrollFn(rewriter, loc, subTy, i);
- result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+
+ // Only insert if we have a result to build
+ if (op->getNumResults() > 0) {
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+ }
+
+ if (op->getNumResults() > 0) {
+ rewriter.replaceOp(op, result);
+ } else {
+ rewriter.eraseOp(op);
}
- rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2d33888854ea7..6ba37bd56083f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// vector.scatter
//===----------------------------------------------------------------------===//
-// Multi-Dimensional scatters are not supported yet. Check that we do not lower
-// them.
-
func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
%0 = arith.constant 0: index
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
@@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
}
// CHECK-LABEL: func @scatter_with_mask
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK-NOT: vector.scatter
// -----
@@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
}
// CHECK-LABEL: func @scatter_with_mask_scalable
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK-NOT: vector.scatter
// -----
    vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!vectorTy)
      return rewriter.notifyMatchFailure(op, "expected vector type");
  }
Should we assert that, conversely, when `vectorTy` is provided, there is exactly one result?
No, when `vectorTy` is provided, it is because there's no result, so `unrollVectorOp` cannot infer the vector type. Let me add this to the function docstring.
But then when `vectorTy` is provided we should probably assert there's no result?
I may have misspoken, but the main point is to prefer assertions in code to verbiage in documentation.
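To make the contract under discussion concrete, here is a minimal sketch in plain C++ (no MLIR dependency). It is not the PR's code: `FakeOp` and `unrollLeadingDim` are hypothetical stand-ins for `Operation` and `unrollVectorOp`, shapes are modeled as `std::vector<int64_t>`, and the scalable-dimension check is left out. It only illustrates the suggestion that an explicitly provided type should go together with an op that has no results, enforced by an assertion rather than a docstring.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical stand-in for an operation: only the facts the helper needs.
struct FakeOp {
  int numResults;                                   // 0 for scatter-like ops
  std::optional<std::vector<int64_t>> resultShape;  // set when numResults == 1
};

// Hypothetical model of the unrolling helper. Calls unrollFn once per slice of
// the leading dimension and returns the slice count, or -1 on match failure.
int64_t unrollLeadingDim(
    const FakeOp &op, const std::function<void(int64_t)> &unrollFn,
    std::optional<std::vector<int64_t>> shape = std::nullopt) {
  // Contract suggested in review: an explicit shape is only meant for ops
  // without results, so state that as an assertion.
  assert((!shape.has_value() || op.numResults == 0) &&
         "explicit shape is only expected for ops without results");

  if (!shape) {
    // No explicit shape: infer it from the single result, if there is one.
    if (op.numResults != 1 || !op.resultShape)
      return -1;
    shape = op.resultShape;
  }

  if (shape->size() < 2)
    return -1; // already 1-D, nothing to unroll

  for (int64_t i = 0; i < shape->front(); ++i)
    unrollFn(i);
  return shape->front();
}
```

With this shape of API, a scatter-like caller passes the index vector's shape explicitly, while a gather-like caller omits it and lets the helper read the result type.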
void mlir::vector::populateVectorScatterLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollScatter>(patterns.getContext(), benefit);
I'm not keen on the proliferation of `populate*Patterns` functions. Is there any scenario where one would want scatter unrolling while not simultaneously wanting gather patterns? If not, I'd suggest merging the two into `populateVectorGatherScatterUnrollPatterns`, which also has the benefit of making it clearer that these are unrolling patterns and not all of the lowering.
There is already an approved PR for this: #132227. I think there are some comments that need to be addressed on tests there, but the implementation is how it should be. I can just land it for you today if you like.
Oh I didn't know about that PR. Yeah please land it, I'll take care of the refactoring once it's landed. Thank you!
Could we please land either of these?
Thanks! As pointed out by Kunwar, this is a duplicate of #132227, which contains a request for more testing. I don't mind which PR lands, but let's make sure that there are more tests for this. IMO,
…ensional vector.scatter operations
This PR adds comprehensive support for lowering multi-dimensional `vector.scatter` operations to LLVM, matching the existing functionality for `vector.gather`.
Modify `unrollVectorOp` to support unrolling of ops that do not have results.
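As a rough illustration of what unrolling the outermost dimension means for a masked scatter, the sketch below models it in plain C++ with `std::vector` instead of MLIR values. The names `scatter1D` and `scatter2DUnrolled` are hypothetical and do not appear in the PR; the model assumes non-negative, in-bounds indices and a single base offset.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a 1-D masked scatter:
// base[offset + idx[i]] = value[i] for every lane i whose mask bit is set.
void scatter1D(std::vector<float> &base, std::size_t offset,
               const std::vector<int32_t> &idx, const std::vector<bool> &mask,
               const std::vector<float> &value) {
  assert(idx.size() == mask.size() && idx.size() == value.size());
  for (std::size_t i = 0; i < idx.size(); ++i)
    if (mask[i])
      base[offset + static_cast<std::size_t>(idx[i])] = value[i];
}

// Hypothetical model of the unrolled 2-D scatter: extract row `r` of the
// index, mask, and value operands, then issue one 1-D scatter per row.
void scatter2DUnrolled(std::vector<float> &base, std::size_t offset,
                       const std::vector<std::vector<int32_t>> &idx,
                       const std::vector<std::vector<bool>> &mask,
                       const std::vector<std::vector<float>> &value) {
  assert(idx.size() == mask.size() && idx.size() == value.size());
  for (std::size_t r = 0; r < idx.size(); ++r)
    scatter1D(base, offset, idx[r], mask[r], value[r]);
}
```

For a 2x3 operand this issues two 3-lane scatters, which corresponds to the two `llvm.intr.masked.scatter` lines the updated test checks for.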