diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index f56124cb4fb95..b896506f29eef 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -293,6 +293,9 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank = 1, PatternBenefit benefit = 1); +void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where /// n > 1. void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index d74007f13a95b..8e36ead6993a8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorMultiReduction.cpp LowerVectorScan.cpp LowerVectorShapeCast.cpp + LowerVectorShuffle.cpp LowerVectorStep.cpp LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp new file mode 100644 index 0000000000000..78102f7325b9f --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -0,0 +1,110 @@ +//===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lowering of complex `vector.shuffle` operations to a +// set of simpler operations supported by LLVM/SPIR-V. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/PatternMatch.h" + +#define DEBUG_TYPE "vector-shuffle-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +/// Lowers a `vector.shuffle` operation with mixed-size inputs to a new +/// `vector.shuffle` which promotes the smaller input to the larger vector size +/// and an updated version of the original `vector.shuffle`. +/// +/// Example: +/// +/// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32> +/// +/// is lowered to: +/// +/// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] : +/// vector<2xf32>, vector<2xf32> +/// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] : +/// vector<4xf32>, vector<4xf32> +/// +/// Note: This transformation helps legalize vector.shuffle ops when lowering +/// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size +/// inputs. +/// Only 1-D shuffles whose inputs differ in size are rewritten. +struct MixedSizeInputShuffleOpRewrite final + : OpRewritePattern<vector::ShuffleOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, + PatternRewriter &rewriter) const override { + auto v1Type = shuffleOp.getV1VectorType(); + auto v2Type = shuffleOp.getV2VectorType(); + + // Only support 1-D shuffle for now. + if (v1Type.getRank() != 1 || v2Type.getRank() != 1) + return failure(); + + // Bail out if inputs don't have mixed sizes. 
+ int64_t v1OrigNumElems = v1Type.getNumElements(); + int64_t v2OrigNumElems = v2Type.getNumElements(); + if (v1OrigNumElems == v2OrigNumElems) + return failure(); + + // Determine which input needs promotion. + bool promoteV1 = v1OrigNumElems < v2OrigNumElems; + Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2(); + VectorType promotedType = promoteV1 ? v2Type : v1Type; + int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems; + int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems; + + // Create a shuffle with a mask that preserves existing elements and fills + // up with poison. + SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex); + for (int64_t i = 0; i < origNumElems; ++i) + promoteMask[i] = i; + + Value promotedInput = rewriter.create<vector::ShuffleOp>( + shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote, + promoteMask); + + // Create the final shuffle with the promoted inputs. + Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1(); + Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput; + + SmallVector<int64_t> newMask; + if (!promoteV1) { + newMask = to_vector(shuffleOp.getMask()); + } else { + // Shift V2 indices by V1's growth; V1 and poison indices are unchanged. 
+ for (auto idx : shuffleOp.getMask()) { + int64_t newIdx = idx; + if (idx >= v1OrigNumElems) { + newIdx += promotedNumElems - v1OrigNumElems; + } + newMask.push_back(newIdx); + } + } + + rewriter.replaceOpWithNewOp<vector::ShuffleOp>( + shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2, + newMask); + return success(); + } +}; +} // namespace + +void mlir::vector::populateVectorShuffleLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<MixedSizeInputShuffleOpRewrite>(patterns.getContext(), benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir new file mode 100644 index 0000000000000..a137811fa367c --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-shuffle-lowering.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s --test-vector-shuffle-lowering --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @shuffle_smaller_lhs_arbitrary +// CHECK-SAME: %[[LHS:.*]]: vector<2xf32>, %[[RHS:.*]]: vector<4xf32> +func.func @shuffle_smaller_lhs_arbitrary(%lhs: vector<2xf32>, %rhs: vector<4xf32>) -> vector<5xf32> { + // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32> + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [1, 5, 0, 6, 7] : vector<4xf32>, vector<4xf32> + // CHECK: return %[[RESULT]] : vector<5xf32> + %0 = vector.shuffle %lhs, %rhs [1, 3, 0, 4, 5] : vector<2xf32>, vector<4xf32> + return %0 : vector<5xf32> +} + +// ----- + +// CHECK-LABEL: func.func @shuffle_smaller_rhs_arbitrary +// CHECK-SAME: %[[LHS:.*]]: vector<4xi32>, %[[RHS:.*]]: vector<2xi32> +func.func @shuffle_smaller_rhs_arbitrary(%lhs: vector<4xi32>, %rhs: vector<2xi32>) -> vector<6xi32> { + // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi32>, vector<2xi32> + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<4xi32> + // CHECK: return 
%[[RESULT]] : vector<6xi32> + %0 = vector.shuffle %lhs, %rhs [3, 5, 1, 4, 0, 2] : vector<4xi32>, vector<2xi32> + return %0 : vector<6xi32> +} + +// ----- + +// CHECK-LABEL: func.func @shuffle_smaller_lhs_concat +// CHECK-SAME: %[[LHS:.*]]: vector<3xf64>, %[[RHS:.*]]: vector<5xf64> +func.func @shuffle_smaller_lhs_concat(%lhs: vector<3xf64>, %rhs: vector<5xf64>) -> vector<8xf64> { + // CHECK: %[[PROMOTE_LHS:.*]] = vector.shuffle %[[LHS]], %[[LHS]] [0, 1, 2, -1, -1] : vector<3xf64>, vector<3xf64> + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[PROMOTE_LHS]], %[[RHS]] [0, 1, 2, 5, 6, 7, 8, 9] : vector<5xf64>, vector<5xf64> + // CHECK: return %[[RESULT]] : vector<8xf64> + %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5, 6, 7] : vector<3xf64>, vector<5xf64> + return %0 : vector<8xf64> +} + +// ----- + +// CHECK-LABEL: func.func @shuffle_smaller_rhs_concat +// CHECK-SAME: %[[LHS:.*]]: vector<4xi16>, %[[RHS:.*]]: vector<2xi16> +func.func @shuffle_smaller_rhs_concat(%lhs: vector<4xi16>, %rhs: vector<2xi16>) -> vector<6xi16> { + // CHECK: %[[PROMOTE_RHS:.*]] = vector.shuffle %[[RHS]], %[[RHS]] [0, 1, -1, -1] : vector<2xi16>, vector<2xi16> + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[PROMOTE_RHS]] [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<4xi16> + // CHECK: return %[[RESULT]] : vector<6xi16> + %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3, 4, 5] : vector<4xi16>, vector<2xi16> + return %0 : vector<6xi16> +} + +// ----- + +// Test that shuffles with same size inputs are not modified. 
+ +// CHECK-LABEL: func.func @negative_shuffle_same_input_sizes +// CHECK-SAME: %[[LHS:.*]]: vector<4xf32>, %[[RHS:.*]]: vector<4xf32> +func.func @negative_shuffle_same_input_sizes(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<6xf32> { + // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]] + // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]] + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32> + // CHECK: return %[[RESULT]] : vector<6xf32> + %0 = vector.shuffle %lhs, %rhs [0, 1, 4, 5, 2, 6] : vector<4xf32>, vector<4xf32> + return %0 : vector<6xf32> +} + +// ----- + +// Test that multi-dimensional shuffles are not modified. + +// CHECK-LABEL: func.func @negative_shuffle_2d_vectors +// CHECK-SAME: %[[LHS:.*]]: vector<2x4xf32>, %[[RHS:.*]]: vector<3x4xf32> +func.func @negative_shuffle_2d_vectors(%lhs: vector<2x4xf32>, %rhs: vector<3x4xf32>) -> vector<4x4xf32> { + // CHECK-NOT: vector.shuffle %[[LHS]], %[[LHS]] + // CHECK-NOT: vector.shuffle %[[RHS]], %[[RHS]] + // CHECK: %[[RESULT:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32> + // CHECK: return %[[RESULT]] : vector<4x4xf32> + %0 = vector.shuffle %lhs, %rhs [0, 1, 2, 3] : vector<2x4xf32>, vector<3x4xf32> + return %0 : vector<4x4xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 72dd103b33f75..79bfc9bbcda71 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -994,6 +994,26 @@ struct TestEliminateVectorMasks VscaleRange{vscaleMin, vscaleMax}); } }; + +/// Test pass that greedily applies the patterns populated by +/// `populateVectorShuffleLoweringPatterns` to the operation it runs on. +/// NOTE(review): the `OperationPass` anchor op was not visible in this patch; +/// `func::FuncOp` is assumed here - confirm against the sibling test passes. +struct TestVectorShuffleLowering + : public PassWrapper<TestVectorShuffleLowering, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorShuffleLowering) + + StringRef getArgument() const final { return "test-vector-shuffle-lowering"; } + StringRef getDescription() const final { + return "Test lowering patterns for vector.shuffle 
with mixed-size inputs"; + } + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorShuffleLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -1023,6 +1043,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration<TestVectorShuffleLowering>(); + PassRegistration(); PassRegistration();