diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 07a4117a37b2c..72a69a056c46e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op]> { + let description = [{ + Indicates that vector to_elements operations should be unrolled + along the outermost dimension. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerScanPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 47f96112a9433..f56124cb4fb95 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns( void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollToElements] +void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index ace26990601c8..97163c4532378 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -255,6 +255,12 @@ using UnrollVectorOpFn = LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn); +/// Generic utility for unrolling values of type vector +/// to N values of type vector using vector.extract. If the input +/// is rank-1 or has leading scalable dimension, failure is returned. +FailureOr> unrollVectorValue(TypedValue, + RewriterBase &); + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 9852df6970fdc..0b44ca7ceee42 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); populateVectorFromElementsLoweringPatterns(patterns); + populateVectorToElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index fe066dc04ad55..6bb390aa09d3e 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( vector::populateVectorFromElementsLoweringPatterns(patterns); } +void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorToElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index acbf2b746037b..d74007f13a95b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp + LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp new file mode 100644 index 0000000000000..a53a183ec31bc --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -0,0 +1,53 @@ +//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.to_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-to-elements" + +using namespace mlir; + +namespace { + +struct UnrollToElements final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + + TypedValue source = op.getSource(); + FailureOr> result = + vector::unrollVectorValue(source, rewriter); + if (failed(result)) { + return failure(); + } + SmallVector vectors = *result; + + SmallVector results; + for (const Value &vector : vectors) { + auto subElements = + vector::ToElementsOp::create(rewriter, op.getLoc(), vector); + llvm::append_range(results, subElements.getResults()); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +} // namespace + +void mlir::vector::populateVectorToElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 841e1384e03b3..39dc7a4f284a6 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef shape, return success(); } +/// Takes a 2+ dimensional vector as an input +/// returns n vector values produced by n vector.extract operations. +/// I.e. calling unrollVectorValue([[%v]], rewriter) such that +/// +/// %v : vector +/// +/// will produce the following IR changes +/// +/// %v0 = vector.extract %v[0] : vector from vector +/// %v1 = vector.extract %v[1] : vector from vector +/// ... +/// %vnminusone = vector.extract %v[n-1] : vector from ... +/// +/// and returns SmallVector r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]} +FailureOr> +vector::unrollVectorValue(TypedValue vector, + RewriterBase &rewriter) { + SmallVector subvectors; + VectorType ty = cast(vector.getType()); + Location loc = vector.getLoc(); + if (ty.getRank() < 2) + return rewriter.notifyMatchFailure(loc, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (ty.getScalableDims().front()) + return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim"); + + for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) { + subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i)); + } + + return subvectors; +} + LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, vector::UnrollVectorOpFn unrollFn) { assert(op->getNumResults() == 1 && "expected single result"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 07d335117de01..2d33888854ea7 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> return %0 : vector<2x1x2xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +// vector.to_elements +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[V0]], %[[V1]] +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// NOTE: We unroll multi-dimensional to_elements ops with pattern +// `UnrollToElements` and then convert the 1-D to_elements ops to llvm. + +// CHECK-LABEL: func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg new file mode 100644 index 0000000000000..3e9e8f8497624 --- /dev/null +++ b/mlir/test/Dialect/Vector/lit.local.cfg @@ -0,0 +1,2 @@ +# Skip the directory with input TD sequences. +config.excludes = ["td"] diff --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir new file mode 100644 index 0000000000000..40a90a33b0ac4 --- /dev/null +++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir @@ -0,0 +1,11 @@ +module attributes {transform.with_named_sequence} { + transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.unroll_to_elements + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir new file mode 100644 index 0000000000000..9ec0d76599c41 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \ +// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32> +// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32> +// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index bb1598ee3efe5..d6596cd341df7 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements } }; +struct TestUnrollVectorToElements + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements) + + StringRef getArgument() const final { + return "test-unroll-vector-to-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for to_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 5a648fe073315..28902b012f7cb 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -48,6 +48,8 @@ def non_configurable_patterns(): vector.ApplyLowerGatherPatternsOp() # CHECK: transform.apply_patterns.vector.unroll_from_elements vector.ApplyUnrollFromElementsPatternsOp() + # CHECK: transform.apply_patterns.vector.unroll_to_elements + vector.ApplyUnrollToElementsPatternsOp() # CHECK: transform.apply_patterns.vector.lower_scan vector.ApplyLowerScanPatternsOp() # CHECK: transform.apply_patterns.vector.lower_shape_cast