From 08655a9b64e846a5dcd1d77ec0481292392032e3 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 09:00:46 -0700 Subject: [PATCH 01/11] [mlir] Add vector.{to_elements,from_elements} unrolling to VectorToSPIRV --- .../SPIRV/Transforms/SPIRVConversion.cpp | 2 + .../ConvertToSPIRV/vector-unroll.mlir | 44 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 49f4ce8de7c76..98e294b40456f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1495,6 +1495,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { RewritePatternSet patterns(context); auto options = vector::UnrollVectorOptions().setNativeShapeFn( [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); + vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorToElementsLoweringPatterns(patterns); populateVectorUnrollPatterns(patterns, options); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir index c85f4334ff2e5..0957f67690b97 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir @@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) { %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32> return %0 : vector<3x2xi32> } + +// ----- + +// In order to verify that the pattern is applied, +// we need to make sure that the the 2d vector does not +// come from the parameters. Otherwise, the pattern +// in unrollVectorsInSignatures which splits the 2d vector +// parameter will take precedent. Similarly, let's avoid +// returning a vector as another pattern would take precendence. + +// CHECK-LABEL: @unroll_to_elements_2d +func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) { + %1 = "test.op"() : () -> (vector<2x2xf32>) + // CHECK: %[[VEC2D:.+]] = "test.op" + // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32> + // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32> + // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] + // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] + %2:4 = vector.to_elements %1 : vector<2x2xf32> + return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32 +} + +// ----- + +// In order to verify that the pattern is applied, +// we need to make sure that the the 2d vector is used +// by an operation and that extracts are not folded away. +// In other words we can't use "test.op" nor return the +// value `%0 = vector.from_elements` + +// CHECK-LABEL: @unroll_from_elements_2d +// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32) +func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) { + // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> + // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> + + // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]] + // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]] + %1 = arith.addf %0, %0 : vector<2x2xf32> + + // return %[[RES0]], %%[[RES1]] : vector<2xf32>, vector<2xf32> + return %1 : vector<2x2xf32> +} From c891a27212669a5d54f383a11ac3defb16a92f8f Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 11:25:02 -0700 Subject: [PATCH 02/11] populate patterns inside populateVectorUnrollPatterns --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2 -- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 98e294b40456f..49f4ce8de7c76 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1495,8 +1495,6 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { RewritePatternSet patterns(context); auto options = vector::UnrollVectorOptions().setNativeShapeFn( [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); - vector::populateVectorFromElementsLoweringPatterns(patterns); - vector::populateVectorToElementsLoweringPatterns(patterns); populateVectorUnrollPatterns(patterns, options); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index e8ecb0c0be846..a75e680afe1fb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/MapVector.h" @@ -814,6 +815,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern { void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { + populateVectorToElementsLoweringPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); patterns.add Date: Tue, 16 Sep 2025 11:43:01 -0700 Subject: [PATCH 03/11] Copy over UnrollToElements --- .../Vector/Transforms/VectorRewritePatterns.h | 6 ++++ .../Vector/Transforms/VectorUnroll.cpp | 32 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 0138f477cadea..32fcb948b9cf7 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollToElements] +void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index a75e680afe1fb..bcfaa843a306f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -810,13 +810,38 @@ struct UnrollBroadcastPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +struct UnrollToElements final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + + TypedValue source = op.getSource(); + FailureOr> result = + vector::unrollVectorValue(source, rewriter); + if (failed(result)) { + return failure(); + } + SmallVector vectors = *result; + + SmallVector results; + for (const Value &vector : vectors) { + auto subElements = + vector::ToElementsOp::create(rewriter, op.getLoc(), vector); + llvm::append_range(results, subElements.getResults()); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { - populateVectorToElementsLoweringPatterns(patterns); populateVectorFromElementsLoweringPatterns(patterns); + patterns.add(patterns.getContext(), benefit); patterns.add( patterns.getContext(), options, benefit); } + +void mlir::vector::populateVectorToElementsUnrollPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} From 5c3e7d5d60c1fdd4e3e09399c68d04a78ee8748a Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 11:55:05 -0700 Subject: [PATCH 04/11] Removes reference to previous UnrollToElements pattern --- .../Vector/Transforms/LoweringPatterns.h | 6 --- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +- .../TransformOps/VectorTransformOps.cpp | 2 +- .../Dialect/Vector/Transforms/CMakeLists.txt | 1 - .../Transforms/LowerVectorToElements.cpp | 53 ------------------- 5 files changed, 2 insertions(+), 62 deletions(-) delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index f56124cb4fb95..47f96112a9433 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -311,12 +311,6 @@ void populateVectorToFromElementsToShuffleTreePatterns( void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [UnrollToElements] -void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 0b44ca7ceee42..9cdfeea2b81bf 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -95,7 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); populateVectorFromElementsLoweringPatterns(patterns); - populateVectorToElementsLoweringPatterns(patterns); + populateVectorToElementsUnrollPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 18f105ef62e38..a3350a3332862 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -151,7 +151,7 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorToElementsLoweringPatterns(patterns); + vector::populateVectorToElementsUnrollPatterns(patterns); } void transform::ApplyLowerScanPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index d74007f13a95b..acbf2b746037b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp - LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp deleted file mode 100644 index a53a183ec31bc..0000000000000 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp +++ /dev/null @@ -1,53 +0,0 @@ -//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements target-independent rewrites and utilities to lower the -// 'vector.to_elements' operation. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" - -#define DEBUG_TYPE "lower-vector-to-elements" - -using namespace mlir; - -namespace { - -struct UnrollToElements final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ToElementsOp op, - PatternRewriter &rewriter) const override { - - TypedValue source = op.getSource(); - FailureOr> result = - vector::unrollVectorValue(source, rewriter); - if (failed(result)) { - return failure(); - } - SmallVector vectors = *result; - - SmallVector results; - for (const Value &vector : vectors) { - auto subElements = - vector::ToElementsOp::create(rewriter, op.getLoc(), vector); - llvm::append_range(results, subElements.getResults()); - } - rewriter.replaceOp(op, results); - return success(); - } -}; - -} // namespace - -void mlir::vector::populateVectorToElementsLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} From 04431c5f5755c7ed2b71f831014dc8d6f8563af7 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 12:05:05 -0700 Subject: [PATCH 05/11] Copy over UnrollFromElements --- .../Vector/Transforms/VectorRewritePatterns.h | 8 ++++ .../Vector/Transforms/VectorUnroll.cpp | 46 ++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 32fcb948b9cf7..c42b8748f60de 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -328,6 +328,14 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns, void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollFromElements] +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. +void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index bcfaa843a306f..ba82aa766180f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -835,13 +835,50 @@ struct UnrollToElements final : public OpRewritePattern { } }; +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. +struct UnrollFromElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { - populateVectorFromElementsLoweringPatterns(patterns); - patterns.add(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), + benefit); patterns.add(patterns.getContext(), benefit); } + +void mlir::vector::populateVectorFromElementsUnrollPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} From 70a667e035a6de981e7b21e759337f8bb47728bb Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 12:10:48 -0700 Subject: [PATCH 06/11] Removes reference to previous UnrollFromElements pattern --- .../Vector/Transforms/LoweringPatterns.h | 8 --- .../GPUCommon/GPUToLLVMConversion.cpp | 2 +- .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +- .../TransformOps/VectorTransformOps.cpp | 2 +- .../Dialect/Vector/Transforms/CMakeLists.txt | 1 - .../Transforms/LowerVectorFromElements.cpp | 65 ------------------- 7 files changed, 4 insertions(+), 78 deletions(-) delete mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 47f96112a9433..e03f0dabece52 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -303,14 +303,6 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); void populateVectorToFromElementsToShuffleTreePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [UnrollFromElements] -/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the -/// outermost dimension. -void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index e516118f75207..5994b64f3d9a5 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -534,7 +534,7 @@ void GpuToLLVMConversionPass::runOnOperation() { /*maxTransferRank=*/1); // Transform N-D vector.from_elements to 1-D vector.from_elements before // conversion. - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 76a7e0f3831a2..a95263bb55f69 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -372,7 +372,7 @@ struct LowerGpuOpsToNVVMOpsPass final populateGpuRewritePatterns(patterns); // Transform N-D vector.from_elements to 1-D vector.from_elements before // conversion. - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 9cdfeea2b81bf..cae490e5f03e7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); - populateVectorFromElementsLoweringPatterns(patterns); + populateVectorFromElementsUnrollPatterns(patterns); populateVectorToElementsUnrollPatterns(patterns); if (armI8MM) { if (armNeon) diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index a3350a3332862..7faa222a9e574 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -146,7 +146,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorFromElementsLoweringPatterns(patterns); + vector::populateVectorFromElementsUnrollPatterns(patterns); } void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index acbf2b746037b..9e287fc109990 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp - LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp deleted file mode 100644 index c22fd54cef46b..0000000000000 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp +++ /dev/null @@ -1,65 +0,0 @@ -//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements target-independent rewrites and utilities to lower the -// 'vector.from_elements' operation. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" - -#define DEBUG_TYPE "lower-vector-from-elements" - -using namespace mlir; - -namespace { - -/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the -/// outermost dimension. For example: -/// ``` -/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> -/// -/// ==> -/// -/// %0 = ub.poison : vector<2x3xf32> -/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> -/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> -/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> -/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> -/// ``` -/// -/// When applied exhaustively, this will produce a sequence of 1-d from_elements -/// ops. -struct UnrollFromElements : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::FromElementsOp op, - PatternRewriter &rewriter) const override { - ValueRange allElements = op.getElements(); - - auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, - VectorType subTy, int64_t index) { - size_t subTyNumElements = subTy.getNumElements(); - assert((index + 1) * subTyNumElements <= allElements.size() && - "out of bounds"); - ValueRange subElements = - allElements.slice(index * subTyNumElements, subTyNumElements); - return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); - }; - - return unrollVectorOp(op, rewriter, unrollFromElementsFn); - } -}; - -} // namespace - -void mlir::vector::populateVectorFromElementsLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} From c054f16b16be2d8e3cee1234c6b1443b81ba70d3 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 16 Sep 2025 12:26:11 -0700 Subject: [PATCH 07/11] Adds UnrollVectorOptions to ToElements and ForElements patterns --- .../Vector/Transforms/VectorUnroll.cpp | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index ba82aa766180f..6f8c667d58f48 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -811,7 +811,11 @@ struct UnrollBroadcastPattern : public OpRewritePattern { }; struct UnrollToElements final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + UnrollToElements(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} LogicalResult matchAndRewrite(vector::ToElementsOp op, PatternRewriter &rewriter) const override { @@ -833,6 +837,9 @@ struct UnrollToElements final : public OpRewritePattern { rewriter.replaceOp(op, results); return success(); } + +private: + vector::UnrollVectorOptions options; }; /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the @@ -852,7 +859,11 @@ struct UnrollToElements final : public OpRewritePattern { /// When applied exhaustively, this will produce a sequence of 1-d from_elements /// ops. struct UnrollFromElements : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + UnrollFromElements(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} LogicalResult matchAndRewrite(vector::FromElementsOp op, PatternRewriter &rewriter) const override { @@ -870,6 +881,9 @@ struct UnrollFromElements : OpRewritePattern { return unrollVectorOp(op, rewriter, unrollFromElementsFn); } + +private: + vector::UnrollVectorOptions options; }; } // namespace @@ -877,22 +891,22 @@ struct UnrollFromElements : OpRewritePattern { void mlir::vector::populateVectorUnrollPatterns( RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); patterns.add( - patterns.getContext(), options, benefit); + UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, + UnrollToElements>(patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), UnrollVectorOptions(), + benefit); } void mlir::vector::populateVectorFromElementsUnrollPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), UnrollVectorOptions(), + benefit); } From 5a266fbbfc6977b8f8bf21234133f2a4b30c45c5 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Wed, 17 Sep 2025 13:09:50 -0400 Subject: [PATCH 08/11] Address review comments --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 0af1882a66d30..1c80c6b6a39aa 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -829,7 +829,7 @@ struct UnrollToElements final : public OpRewritePattern { SmallVector vectors = *result; SmallVector results; - for (const Value &vector : vectors) { + for (Value vector : vectors) { auto subElements = vector::ToElementsOp::create(rewriter, op.getLoc(), vector); llvm::append_range(results, subElements.getResults()); From f0283ca5f4ed806fe5cb88cf768b7b21d3049710 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 18 Sep 2025 08:43:13 -0400 Subject: [PATCH 09/11] Update documentation --- .../Dialect/Vector/Transforms/VectorRewritePatterns.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index b1effc8642383..69438011d2287 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -322,15 +322,11 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [UnrollToElements] +/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the +/// outermost dimension of the operand. void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Populate the pattern set with the following patterns: -/// -/// [UnrollFromElements] /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the /// outermost dimension. void populateVectorFromElementsUnrollPatterns(RewritePatternSet &patterns, From 5d7afb1efd4a2b05918044c1bf0b46773b41619c Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 18 Sep 2025 08:44:46 -0400 Subject: [PATCH 10/11] Reword documentation 'exhaustive' -> 'fixed-point' --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 1c80c6b6a39aa..e7dd5958cf72a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -931,7 +931,8 @@ struct UnrollStepPattern : public OpRewritePattern { /// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> /// ``` /// -/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// When this pattern is applied until a fixed-point is reached, +/// this will produce a sequence of 1-d from_elements /// ops. struct UnrollFromElements : OpRewritePattern { UnrollFromElements(MLIRContext *context, From 44145b5c525a179ef117c17d89c4fa6bd37481a8 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Thu, 18 Sep 2025 08:48:06 -0400 Subject: [PATCH 11/11] Add documentation to UnrollToElements --- .../Dialect/Vector/Transforms/VectorUnroll.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index e7dd5958cf72a..14639c5f1cdd3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -810,6 +810,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern { vector::UnrollVectorOptions options; }; +/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the +/// outermost dimension of the operand. For example: +/// +/// ``` +/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// +/// ==> +/// +/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32> +/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32> +/// %0:4 = vector.to_elements %v0 : vector<2x2xf32> +/// %1:4 = vector.to_elements %v1 : vector<2x2xf32> +/// ``` +/// +/// When this pattern is applied until a fixed-point is reached, +/// this will produce a sequence of 1-d from_elements +/// ops. struct UnrollToElements final : public OpRewritePattern { UnrollToElements(MLIRContext *context, const vector::UnrollVectorOptions &options,