diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 7bd96c8a6d1a1..83c08d4103177 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollScatter] +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost dimension. +void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [UnrollGather] diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index a57aadcdcc5b0..8f609acd2fdb7 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -250,11 +250,16 @@ LogicalResult isValidMaskedInputVector(ArrayRef shape, /// create sub vectors. /// 5. Insert the sub vectors back into the final vector. /// 6. Replace the original op with the new result. +/// +/// Expects the operation to be unrolled to have at most 1 result. When there's +/// no result, expects the caller to pass in the `vectorTy` to be able to get +/// the unroll factor. using UnrollVectorOpFn = function_ref; LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, - UnrollVectorOpFn unrollFn); + UnrollVectorOpFn unrollFn, + VectorType vectorTy = nullptr); /// Generic utility for unrolling values of type vector /// to N values of type vector using vector.extract. If the input diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index f958edf2746e9..1750e0430fd2b 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorScatterLoweringPatterns(patterns); populateVectorFromElementsUnrollPatterns(patterns); populateVectorToElementsUnrollPatterns(patterns); if (armI8MM) { diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 4e0f07af95984..cd7d4f5d1c69c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorMask.cpp LowerVectorMultiReduction.cpp LowerVectorScan.cpp + LowerVectorScatter.cpp LowerVectorShapeCast.cpp LowerVectorShuffle.cpp LowerVectorStep.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp new file mode 100644 index 0000000000000..af17136f21da0 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp @@ -0,0 +1,101 @@ +//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.scatter' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" + +#define DEBUG_TYPE "vector-scatter-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// vector.scatter %base[%c0][%idx], %mask, %value : +/// memref, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> +/// +/// ==> +/// +/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32> +/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1> +/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32> +/// vector.scatter %base[%c0][%i0], %m0, %v0 : +/// memref, vector<3xi32>, vector<3xi1>, vector<3xf32> +/// +/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32> +/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1> +/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32> +/// vector.scatter %base[%c0][%i1], %m1, %v1 : +/// memref, vector<3xi32>, vector<3xi1>, vector<3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d scatter ops. +/// +/// Supports vector types with a fixed leading dimension. +struct UnrollScatter : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ScatterOp op, + PatternRewriter &rewriter) const override { + Value indexVec = op.getIndices(); + Value maskVec = op.getMask(); + Value valueVec = op.getValueToStore(); + + // Get the vector type from one of the vector operands. + VectorType vectorTy = dyn_cast(indexVec.getType()); + if (!vectorTy) + return failure(); + + auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; + + Value indexSubVec = + vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); + Value maskSubVec = + vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); + Value valueSubVec = + vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx); + + rewriter.create(loc, op.getBase(), op.getOffsets(), + indexSubVec, maskSubVec, valueSubVec, + op.getAlignmentAttr()); + + // Return a dummy value since unrollVectorOp expects a Value. + return Value(); + }; + + return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy); + } +}; + +} // namespace + +void mlir::vector::populateVectorScatterLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 025ee9a04a1de..02f9382c760be 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -431,27 +431,48 @@ vector::unrollVectorValue(TypedValue vector, } LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, - vector::UnrollVectorOpFn unrollFn) { - assert(op->getNumResults() == 1 && "expected single result"); - assert(isa(op->getResult(0).getType()) && "expected vector type"); - VectorType resultTy = cast(op->getResult(0).getType()); - if (resultTy.getRank() < 2) + vector::UnrollVectorOpFn unrollFn, + VectorType vectorTy) { + // If vector type is not provided, get it from the result. + if (!vectorTy) { + assert(op->getNumResults() == 1 && + "expected single result when vector type not provided"); + vectorTy = dyn_cast(op->getResult(0).getType()); + assert(vectorTy && "expected result to have vector type"); + } + + if (vectorTy.getRank() < 2) return rewriter.notifyMatchFailure(op, "already 1-D"); // Unrolling doesn't take vscale into account. Pattern is disabled for // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) + if (vectorTy.getScalableDims().front()) return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); Location loc = op->getLoc(); - Value result = ub::PoisonOp::create(rewriter, loc, resultTy); - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + // Only create result value if the operation produces results. + Value result; + if (op->getNumResults() > 0) { + result = ub::PoisonOp::create(rewriter, loc, vectorTy); + } + + VectorType subTy = VectorType::Builder(vectorTy).dropDim(0); + + for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) { Value subVector = unrollFn(rewriter, loc, subTy, i); - result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + + // Only insert if we have a result to build. + if (op->getNumResults() > 0) { + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + } + + if (op->getNumResults() > 0) { + rewriter.replaceOp(op, result); + } else { + rewriter.eraseOp(op); } - rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888854ea7..6ba37bd56083f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref, %arg1: vector<2x // vector.scatter //===----------------------------------------------------------------------===// -// Multi-Dimensional scatters are not supported yet. Check that we do not lower -// them. - func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) { %0 = arith.constant 0: index %1 = vector.constant_mask [2, 2] : vector<2x3xi1> @@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref, %arg1: vector<2x3xi32>, %arg2 } // CHECK-LABEL: func @scatter_with_mask -// CHECK: vector.scatter +// CHECK: llvm.extractvalue {{.*}}[0] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// CHECK: llvm.extractvalue {{.*}}[1] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr> +// CHECK-NOT: vector.scatter // ----- @@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref, %arg1: vector<2x[3]x } // CHECK-LABEL: func @scatter_with_mask_scalable -// CHECK: vector.scatter +// CHECK: llvm.extractvalue {{.*}}[0] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr> +// CHECK: llvm.extractvalue {{.*}}[1] +// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr> +// CHECK-NOT: vector.scatter // -----