diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index b896506f29eef..7bd96c8a6d1a1 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -306,20 +306,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); void populateVectorToFromElementsToShuffleTreePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [UnrollFromElements] -/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the -/// outermost dimension. -void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Populate the pattern set with the following patterns: -/// -/// [UnrollToElements] -void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 08f439222a9a0..69438011d2287 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -322,6 +322,16 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit = 1); +/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the +/// outermost dimension of the operand. +void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. +void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index e516118f75207..5994b64f3d9a5 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() { /*maxTransferRank=*/1); // Transform N-D vector.from_elements to 1-D vector.from_elements before // conversion. - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 76a7e0f3831a2..a95263bb55f69 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final populateGpuRewritePatterns(patterns); // Transform N-D vector.from_elements to 1-D vector.from_elements before // conversion. - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0b44ca7ceee42..cae490e5f03e7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,8 +94,8 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); - populateVectorFromElementsLoweringPatterns(patterns); - populateVectorToElementsLoweringPatterns(patterns); + populateVectorFromElementsUnrollPatterns(patterns); + populateVectorToElementsUnrollPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 18f105ef62e38..7faa222a9e574 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -146,12 +146,12 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); } void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorToElementsLoweringPatterns(patterns); + vector::populateVectorToElementsUnrollPatterns(patterns); } void transform::ApplyLowerScanPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index fecf445720173..4e0f07af95984 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp - LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp @@ -12,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorShapeCast.cpp LowerVectorShuffle.cpp LowerVectorStep.cpp - LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp deleted file mode 100644 index c22fd54cef46b..0000000000000 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp +++ /dev/null @@ -1,65 +0,0 @@ -//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements target-independent rewrites and utilities to lower the -// 'vector.from_elements' operation. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" - -#define DEBUG_TYPE "lower-vector-from-elements" - -using namespace mlir; - -namespace { - -/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the -/// outermost dimension. For example: -/// ``` -/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> -/// -/// ==> -/// -/// %0 = ub.poison : vector<2x3xf32> -/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> -/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> -/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> -/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> -/// ``` -/// -/// When applied exhaustively, this will produce a sequence of 1-d from_elements -/// ops. -struct UnrollFromElements : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::FromElementsOp op, - PatternRewriter &rewriter) const override { - ValueRange allElements = op.getElements(); - - auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, - VectorType subTy, int64_t index) { - size_t subTyNumElements = subTy.getNumElements(); - assert((index + 1) * subTyNumElements <= allElements.size() && - "out of bounds"); - ValueRange subElements = - allElements.slice(index * subTyNumElements, subTyNumElements); - return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); - }; - - return unrollVectorOp(op, rewriter, unrollFromElementsFn); - } -}; - -} // namespace - -void mlir::vector::populateVectorFromElementsLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp deleted file mode 100644 index a53a183ec31bc..0000000000000 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ /dev/null @@ -1,53 +0,0 @@ -//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements target-independent rewrites and utilities to lower the -// 'vector.to_elements' operation. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" - -#define DEBUG_TYPE "lower-vector-to-elements" - -using namespace mlir; - -namespace { - -struct UnrollToElements final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ToElementsOp op, - PatternRewriter &rewriter) const override { - - TypedValue source = op.getSource(); - FailureOr> result = - vector::unrollVectorValue(source, rewriter); - if (failed(result)) { - return failure(); - } - SmallVector vectors = *result; - - SmallVector results; - for (const Value &vector : vectors) { - auto subElements = - vector::ToElementsOp::create(rewriter, op.getLoc(), vector); - llvm::append_range(results, subElements.getResults()); - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -} // namespace - -void mlir::vector::populateVectorToElementsLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 79786f33a2d46..14639c5f1cdd3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/MapVector.h" @@ -809,6 +810,55 @@ struct UnrollBroadcastPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the +/// outermost dimension of the operand. For example: +/// +/// ``` +/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// +/// ==> +/// +/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32> +/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32> +/// %0:4 = vector.to_elements %v0 : vector<2x2xf32> +/// %1:4 = vector.to_elements %v1 : vector<2x2xf32> +/// ``` +/// +/// When this pattern is applied until a fixed-point is reached, +/// this will produce a sequence of 1-d from_elements +/// ops. +struct UnrollToElements final : public OpRewritePattern { + UnrollToElements(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + + TypedValue source = op.getSource(); + FailureOr> result = + vector::unrollVectorValue(source, rewriter); + if (failed(result)) { + return failure(); + } + SmallVector vectors = *result; + + SmallVector results; + for (Value vector : vectors) { + auto subElements = + vector::ToElementsOp::create(rewriter, op.getLoc(), vector); + llvm::append_range(results, subElements.getResults()); + } + rewriter.replaceOp(op, results); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + /// This pattern unrolls `vector.step` operations according to the provided /// target unroll shape. It decomposes a large step vector into smaller step /// vectors (segments) and assembles the result by inserting each computed @@ -884,6 +934,51 @@ struct UnrollStepPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When this pattern is applied until a fixed-point is reached, +/// this will produce a sequence of 1-d from_elements +/// ops. +struct UnrollFromElements : OpRewritePattern { + UnrollFromElements(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -893,6 +988,19 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollContractionPattern, UnrollElementwisePattern, UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, - UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>( - patterns.getContext(), options, benefit); + UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, + UnrollToElements, UnrollStepPattern>(patterns.getContext(), + options, benefit); +} + +void mlir::vector::populateVectorToElementsUnrollPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), UnrollVectorOptions(), + benefit); +} + +void mlir::vector::populateVectorFromElementsUnrollPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), UnrollVectorOptions(), + benefit); } diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir index c85f4334ff2e5..0957f67690b97 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir @@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) { %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32> return %0 : vector<3x2xi32> } + +// ----- + +// In order to verify that the pattern is applied, +// we need to make sure that the the 2d vector does not +// come from the parameters. Otherwise, the pattern +// in unrollVectorsInSignatures which splits the 2d vector +// parameter will take precedent. Similarly, let's avoid +// returning a vector as another pattern would take precendence. + +// CHECK-LABEL: @unroll_to_elements_2d +func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) { + %1 = "test.op"() : () -> (vector<2x2xf32>) + // CHECK: %[[VEC2D:.+]] = "test.op" + // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32> + // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32> + // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] + // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] + %2:4 = vector.to_elements %1 : vector<2x2xf32> + return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32 +} + +// ----- + +// In order to verify that the pattern is applied, +// we need to make sure that the the 2d vector is used +// by an operation and that extracts are not folded away. +// In other words we can't use "test.op" nor return the +// value `%0 = vector.from_elements` + +// CHECK-LABEL: @unroll_from_elements_2d +// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32) +func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) { + // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> + // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> + + // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]] + // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]] + %1 = arith.addf %0, %0 : vector<2x2xf32> + + // return %[[RES0]], %%[[RES1]] : vector<2xf32>, vector<2xf32> + return %1 : vector<2x2xf32> +}