Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollScatter]
/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
/// outermost dimension.
void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollGather]
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
/// create sub vectors.
/// 5. Insert the sub vectors back into the final vector.
/// 6. Replace the original op with the new result.
///
/// The operation being unrolled must have at most one result. If it has no
/// result, the caller must pass `vectorTy` explicitly so that the unroll
/// factor (the size of the leading dimension) can be determined.
using UnrollVectorOpFn =
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;

LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
UnrollVectorOpFn unrollFn,
VectorType vectorTy = nullptr);

/// Generic utility for unrolling values of type vector<NxAxBx...>
/// to N values of type vector<AxBx...> using vector.extract. If the input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorScatterLoweringPatterns(patterns);
populateVectorFromElementsUnrollPatterns(patterns);
populateVectorToElementsUnrollPatterns(patterns);
if (armI8MM) {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
LowerVectorScatter.cpp
LowerVectorShapeCast.cpp
LowerVectorShuffle.cpp
LowerVectorStep.cpp
Expand Down
101 changes: 101 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scatter' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-scatter-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {

/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
/// outermost dimension. For example:
/// ```
///   vector.scatter %base[%c0][%idx], %mask, %value :
///     memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
///
/// ==>
///
///   %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
///   %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
///   %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
///   vector.scatter %base[%c0][%i0], %m0, %v0 :
///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
///
///   %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
///   %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
///   %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
///   vector.scatter %base[%c0][%i1], %m1, %v1 :
///     memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
///
/// Supports vector types with a fixed leading dimension. The rank and
/// scalability checks are delegated to `unrollVectorOp`, which reports a match
/// failure for 1-D vectors and for vectors whose leading dimension is scalable.
struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp op,
                                PatternRewriter &rewriter) const override {
    Value indexVec = op.getIndices();
    Value maskVec = op.getMask();
    Value valueVec = op.getValueToStore();

    // `vector.scatter` produces no result here, so the type that drives the
    // unrolling is taken from one of the vector operands instead.
    VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
    if (!vectorTy)
      return failure();

    // Emits the scatter for slice `index` of the outermost dimension. The
    // sub-vector type is not needed because each slice is obtained with
    // `vector.extract`, which infers its own result type.
    auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
                               VectorType /*subTy*/, int64_t index) -> Value {
      int64_t thisIdx[1] = {index};

      Value indexSubVec =
          vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
      Value maskSubVec =
          vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
      Value valueSubVec =
          vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);

      // Base, offsets and alignment are shared by every unrolled scatter; only
      // the index, mask and value operands are sliced.
      vector::ScatterOp::create(rewriter, loc, op.getBase(), op.getOffsets(),
                                indexSubVec, maskSubVec, valueSubVec,
                                op.getAlignmentAttr());

      // Return a dummy value since unrollVectorOp expects a Value; it is not
      // used when the unrolled op has no results.
      return Value();
    };

    return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
  }
};

} // namespace

/// Adds the `UnrollScatter` pattern to `patterns`, registered with the given
/// `benefit`.
void mlir::vector::populateVectorScatterLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollScatter>(patterns.getContext(), benefit);
Comment on lines +98 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not keen on the proliferation of populate*Patterns functions. Is there any scenario where one would want scatter unrolling while not simultaneously wanting gather patterns? If not, I'd suggest to merge the two in populateVectorGatherScatterUnrollPatterns, which also has the benefit of being more clear about it being unrolling patterns and not all lowering.

}
43 changes: 32 additions & 11 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,27 +431,48 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
}

LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
if (resultTy.getRank() < 2)
vector::UnrollVectorOpFn unrollFn,
VectorType vectorTy) {
// If vector type is not provided, get it from the result.
if (!vectorTy) {
assert(op->getNumResults() == 1 &&
"expected single result when vector type not provided");
vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
assert(vectorTy && "expected result to have vector type");
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that, conversely, when vectorTy is provided, there is exactly one result?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, when vectorTy is provided, it is because there's no result so unrollVectorOp cannot infer the vector type. Let me add this to the function docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then when vectorTy is provided we should probably assert there's no result?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have misspoken, but the main point is to prefer assertions in code to verbiage in documentation.


if (vectorTy.getRank() < 2)
return rewriter.notifyMatchFailure(op, "already 1-D");

// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
if (resultTy.getScalableDims().front())
if (vectorTy.getScalableDims().front())
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");

Location loc = op->getLoc();
Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);

for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
// Only create result value if the operation produces results.
Value result;
if (op->getNumResults() > 0) {
result = ub::PoisonOp::create(rewriter, loc, vectorTy);
}

VectorType subTy = VectorType::Builder(vectorTy).dropDim(0);

for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
Value subVector = unrollFn(rewriter, loc, subTy, i);
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);

// Only insert if we have a result to build.
if (op->getNumResults() > 0) {
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
}
}

if (op->getNumResults() > 0) {
rewriter.replaceOp(op, result);
} else {
rewriter.eraseOp(op);
}

rewriter.replaceOp(op, result);
return success();
}
15 changes: 10 additions & 5 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// vector.scatter
//===----------------------------------------------------------------------===//

// Multi-Dimensional scatters are not supported yet. Check that we do not lower
// them.

func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
%0 = arith.constant 0: index
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
Expand All @@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
}

// CHECK-LABEL: func @scatter_with_mask
// CHECK: vector.scatter
// CHECK: llvm.extractvalue {{.*}}[0]
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// CHECK: llvm.extractvalue {{.*}}[1]
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// CHECK-NOT: vector.scatter

// -----

Expand All @@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
}

// CHECK-LABEL: func @scatter_with_mask_scalable
// CHECK: vector.scatter
// CHECK: llvm.extractvalue {{.*}}[0]
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
// CHECK: llvm.extractvalue {{.*}}[1]
// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
// CHECK-NOT: vector.scatter

// -----

Expand Down