diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index e215be49b74ef..e7226f4a6ac0f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -128,19 +128,6 @@ struct UnrollVectorOptions {
     };
     return *this;
   }
-
-  /// Function that returns the traversal order (in terms of "for loop order",
-  /// i.e. slowest varying dimension to fastest varying dimension) that should
-  /// be used when unrolling the given operation into units of the native vector
-  /// size.
-  using UnrollTraversalOrderFnType =
-      std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
-  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
-  UnrollVectorOptions &
-  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
-    traversalOrderCallback = std::move(traversalOrderFn);
-    return *this;
-  }
 };
 
 //===----------------------------------------------------------------------===//
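For context on the removed hook: traversalOrderCallback let a client pick, per operation, the loop order used to enumerate the unrolled slices. A minimal usage sketch against the pre-revert API, modeled on the reverse-order lambda deleted from TestVectorTransforms.cpp at the end of this patch (the lambda body is that deleted code, slightly condensed):

    // Sketch only: wires up the removed setUnrollTraversalOrderFn hook.
    vector::UnrollVectorOptions opts;
    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
        .setUnrollTraversalOrderFn(
            [](Operation *op) -> Optional<SmallVector<int64_t>> {
              auto readOp = dyn_cast<vector::TransferReadOp>(op);
              if (!readOp)
                return None;
              // Unroll the innermost vector dimension first, i.e. reversed
              // "for loop order".
              return llvm::to_vector(llvm::reverse(
                  llvm::seq<int64_t>(0, readOp.getVectorType().getRank())));
            });
    populateVectorUnrollPatterns(patterns, opts);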
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
index dbca588f87f81..7f00788d888b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
@@ -15,11 +15,8 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
-#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
-#include <numeric>
 
 #define DEBUG_TYPE "vector-unrolling"
 
@@ -39,78 +36,20 @@ static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
   return elementOffsets;
 }
 
-/// A functor that accomplishes the same thing as `getVectorOffset` but allows
-/// for reordering the traversal of the dimensions. The order of traversal is
-/// given in "for loop order" (outer to inner).
-namespace {
-class DecomposeShapeIterator {
-private:
-  SmallVector<int64_t> vectorShape;
-  SmallVector<int64_t> loopOrder;
-  SmallVector<int64_t> sliceStrides;
-  int64_t maxIndexVal{1};
-
-public:
-  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> targetShape,
-                         ArrayRef<int64_t> loopOrder)
-      : vectorShape(targetShape.begin(), targetShape.end()),
-        loopOrder(loopOrder.begin(), loopOrder.end()),
-        sliceStrides(originalShape.size()) {
-    // Compute the count for each dimension.
-    SmallVector<int64_t> sliceDimCounts(originalShape.size());
-    for (unsigned r = 0; r < originalShape.size(); ++r) {
-      sliceDimCounts[r] = ceilDiv(originalShape[r], targetShape[r]);
-      maxIndexVal *= sliceDimCounts[r];
-    }
-
-    // Reversing "loop order" gives dimensions from fastest varying to slowest
-    // varying (smallest stride to largest stride).
-    int64_t accum = 1;
-    for (auto idx : llvm::reverse(loopOrder)) {
-      sliceStrides[idx] = accum;
-      accum *= sliceDimCounts[idx];
-    }
-  }
-
-  // Turn the linear index into a d-tuple based on units of vectors of size
-  // `vectorShape`. The linear index is assumed to represent traversal of the
-  // dimensions based on `order`.
-  SmallVector<int64_t> delinearize(int64_t index) const {
-    // Traverse in for loop order (largest stride to smallest stride).
-    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
-    for (auto idx : loopOrder) {
-      vectorOffsets[idx] = index / sliceStrides[idx];
-      index %= sliceStrides[idx];
-    }
-    return vectorOffsets;
-  }
-
-  int64_t maxIndex() const { return maxIndexVal; }
-
-  /// Return the offset within d-tuple based on the ordering given by
-  /// `loopOrder`.
-  SmallVector<int64_t> getVectorOffset(int64_t index) const {
-    SmallVector<int64_t> vectorOffsets = delinearize(index);
-    SmallVector<int64_t> elementOffsets =
-        computeElementOffsetsFromVectorSliceOffsets(vectorShape, vectorOffsets);
-    return elementOffsets;
-  }
-};
-} // namespace
-
 /// Compute the indices of the slice `index` for a transfer op.
-static SmallVector<Value, 4> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
-                                                  ArrayRef<Value> indices,
-                                                  AffineMap permutationMap,
-                                                  Location loc,
-                                                  OpBuilder &builder) {
+static SmallVector<Value, 4>
+sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
+                     ArrayRef<int64_t> targetShape, ArrayRef<Value> indices,
+                     AffineMap permutationMap, Location loc,
+                     OpBuilder &builder) {
   MLIRContext *ctx = builder.getContext();
   auto isBroadcast = [](AffineExpr expr) {
     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
       return constExpr.getValue() == 0;
     return false;
   };
+  SmallVector<int64_t, 4> elementOffsets =
+      getVectorOffset(originalShape, targetShape, index);
   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
   SmallVector<Value, 4> slicedIndices(indices.begin(), indices.end());
   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
@@ -160,20 +99,6 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
   return targetShape;
 }
 
-static SmallVector<int64_t>
-getUnrollOrder(unsigned numLoops, Operation *op,
-               const vector::UnrollVectorOptions &options) {
-  SmallVector<int64_t> loopOrder =
-      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
-  if (options.traversalOrderCallback != nullptr) {
-    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
-    if (order.hasValue()) {
-      loopOrder = std::move(*order);
-    }
-  }
-  return loopOrder;
-}
-
 namespace {
 
 struct UnrollTransferReadPattern
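With DecomposeShapeIterator gone, slice enumeration falls back to the fixed row-major scheme of getVectorOffset above: the slice count is the product of the per-dimension shape ratios, and the linear slice index is delinearized with the innermost dimension varying fastest. A self-contained sketch of that arithmetic in plain C++ (names are mine, not the MLIR API):

    #include <cstdint>
    #include <vector>

    // Element offsets of slice `index` when unrolling `original` into units
    // of `target`, enumerated in row-major (innermost-fastest) order.
    std::vector<int64_t> sliceOffsets(const std::vector<int64_t> &original,
                                      const std::vector<int64_t> &target,
                                      int64_t index) {
      std::vector<int64_t> offsets(original.size());
      for (int64_t d = static_cast<int64_t>(original.size()) - 1; d >= 0; --d) {
        int64_t ratio = (original[d] + target[d] - 1) / target[d]; // ceil div
        offsets[d] = (index % ratio) * target[d]; // element offset of slice
        index /= ratio;
      }
      return offsets;
    }

For original = {4, 4} and target = {2, 2}, indices 0..3 yield offsets {0, 0}, {0, 2}, {2, 0}, {2, 2}, which is the order the default CHECK expectations in the transfer tests below rely on.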
@@ -197,7 +122,8 @@ struct UnrollTransferReadPattern
     Location loc = readOp.getLoc();
     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
-
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
@@ -205,22 +131,17 @@
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value, 4> originalIndices(readOp.getIndices().begin(),
                                           readOp.getIndices().end());
-
-    SmallVector<int64_t> loopOrder =
-        getUnrollOrder(ratio.size(), readOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
+    for (int64_t i = 0; i < sliceCount; i++) {
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(elementOffsets, originalIndices,
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
           loc, targetType, readOp.getSource(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
+      SmallVector<int64_t, 4> elementOffsets =
+          getVectorOffset(originalSize, *targetShape, i);
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
@@ -253,21 +174,20 @@ struct UnrollTransferWritePattern
     SmallVector<int64_t, 4> strides(targetShape->size(), 1);
     Location loc = writeOp.getLoc();
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
+    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
     SmallVector<Value, 4> originalIndices(writeOp.getIndices().begin(),
                                           writeOp.getIndices().end());
-
-    SmallVector<int64_t> loopOrder =
-        getUnrollOrder(originalIndices.size(), writeOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
+    for (int64_t i = 0; i < sliceCount; i++) {
       SmallVector<int64_t, 4> elementOffsets =
-          indexToOffsets.getVectorOffset(i);
+          getVectorOffset(originalSize, *targetShape, i);
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
+
       SmallVector<Value, 4> indices =
-          sliceTransferIndices(elementOffsets, originalIndices,
+          sliceTransferIndices(i, originalSize, *targetShape, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
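Note that the reworked sliceTransferIndices keeps the broadcast special case: an offset is added only for permutation-map results that are genuine dimensions, while constant-0 (broadcast) results leave the corresponding source index untouched. A plain-C++ sketch of that bookkeeping (hypothetical names, not the MLIR API):

    #include <cstdint>
    #include <vector>

    // baseIndices:    per-dimension start indices of the original transfer.
    // elementOffsets: offsets of the current slice, one per vector dimension.
    // permutation[i]: source dimension read by vector dimension i, or -1 for
    //                 a broadcast (constant-0 map result).
    std::vector<int64_t> sliceIndices(const std::vector<int64_t> &baseIndices,
                                      const std::vector<int64_t> &elementOffsets,
                                      const std::vector<int64_t> &permutation) {
      std::vector<int64_t> indices = baseIndices;
      for (size_t i = 0; i < permutation.size(); ++i) {
        if (permutation[i] < 0)
          continue; // broadcast dimension: index stays as-is
        indices[permutation[i]] += elementOffsets[i];
      }
      return indices;
    }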
@@ -318,6 +238,8 @@ struct UnrollContractionPattern
     SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
     SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
+    // Compute shape ratio of 'shape' and 'sizes'.
+    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = contractOp.getLoc();
     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
     AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
@@ -325,14 +247,9 @@
     llvm::MapVector<
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
-
-    SmallVector<int64_t> loopOrder = getUnrollOrder(
-        contractOp.getIndexingMaps().size(), contractOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    const int64_t sliceCount = indexToOffsets.maxIndex();
     for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
+      SmallVector<int64_t, 4> offsets =
+          getVectorOffset(originalSize, *targetShape, i);
       SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
 
       // Helper to compute the new shape of each operand and extract the slice.
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index f6f218b6e39eb..3d0affd2a4be0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns=reverse-unroll-order --split-input-file | FileCheck %s --check-prefix=ORDER
 
 // CHECK-LABEL: func @transfer_read_unroll
 // CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
@@ -14,19 +13,6 @@
 // CHECK-NEXT:  %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 // CHECK-NEXT:  return %[[VEC3]] : vector<4x4xf32>
 
-// ORDER-LABEL: func @transfer_read_unroll
-// ORDER-DAG:  %[[C2:.*]] = arith.constant 2 : index
-// ORDER-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// ORDER:  %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// ORDER-NEXT:  %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// ORDER-NEXT:  %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// ORDER-NEXT:  %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
-// ORDER-NEXT:  return %[[VEC3]] : vector<4x4xf32>
-
 func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -47,19 +33,6 @@ func.func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
 // CHECK-NEXT:  vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
 // CHECK-NEXT:  return
 
-// ORDER-LABEL: func @transfer_write_unroll
-// ORDER-DAG:  %[[C2:.*]] = arith.constant 2 : index
-// ORDER-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// ORDER:  %[[S0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// ORDER-NEXT:  vector.transfer_write %[[S0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-// ORDER-NEXT:  %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// ORDER-NEXT:  vector.transfer_write %[[S1]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-// ORDER-NEXT:  %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// ORDER-NEXT:  vector.transfer_write %[[S2]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-// ORDER-NEXT:  %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// ORDER-NEXT:  vector.transfer_write %[[S3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
-// ORDER-NEXT:  return
-
 func.func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -249,25 +222,6 @@ func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) ->
 // CHECK-NEXT:  %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
 // CHECK-NEXT:  %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
 // CHECK-NEXT:  return %[[VEC5]] : vector<6x4xf32>
-
-// ORDER-LABEL: func @transfer_read_unroll_different_rank
-// ORDER-DAG:  %[[C4:.*]] = arith.constant 4 : index
-// ORDER-DAG:  %[[C2:.*]] = arith.constant 2 : index
-// ORDER-DAG:  %[[C0:.*]] = arith.constant 0 : index
-// ORDER:  %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
-// ORDER-NEXT:  %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32>
-// ORDER-NEXT:  return %[[VEC5]] : vector<6x4xf32>
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
 func.func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 272c04a34e8ea..3b0aeb48665f4 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -1,156 +1,50 @@
 // RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
-// RUN: mlir-opt %s -test-vector-unrolling-patterns="unroll-based-on-type unroll-order=2,0,1" --split-input-file | FileCheck %s --check-prefix=ORDER
 
-func.func @vector_contract_f32(%lhs : vector<8x4xf32>, %rhs : vector<8x4xf32>,
+func.func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
                                %init : vector<8x8xf32>) -> vector<8x8xf32> {
   %0 = vector.contract
          {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
                            affine_map<(i, j, k) -> (j, k)>,
                            affine_map<(i, j, k) -> (i, j)>],
          iterator_types = ["parallel", "parallel", "reduction"]}
-       %lhs, %rhs, %init : vector<8x4xf32>, vector<8x4xf32> into vector<8x8xf32>
+       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
   return %0 : vector<8x8xf32>
 }
 // CHECK-LABEL: func @vector_contract_f32
-// CHECK-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [0, 0]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [0, 0]
-// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// CHECK-SAME: offsets = [0, 0]
-// CHECK: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [0, 2]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [0, 2]
-// CHECK: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [0, 0]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [4, 0]
-// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// CHECK-SAME: offsets = [0, 4]
-// CHECK: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [0, 2]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [4, 2]
-// CHECK: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [4, 0]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [0, 0]
-// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// CHECK-SAME: offsets = [4, 0]
-// CHECK: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [4, 2]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [0, 2]
-// CHECK: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum5]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [4, 0]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [4, 0]
-// CHECK: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// CHECK-SAME: offsets = [4, 4]
-// CHECK: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// CHECK: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// CHECK-SAME: offsets = [4, 2]
-// CHECK: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// CHECK-SAME: offsets = [4, 2]
-// CHECK: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum7]]
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
 // CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
 // CHECK: return
 
-// ORDER-LABEL: func @vector_contract_f32
-// ORDER-SAME: [[arg0:%.+]]: vector<8x4xf32>, [[arg1:%.+]]: vector<8x4xf32>, [[arg2:%.+]]: vector<8x8xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [0, 0]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [0, 0]
-// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// ORDER-SAME: offsets = [0, 0]
-// ORDER: [[accum1:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [0, 0]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [4, 0]
-// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// ORDER-SAME: offsets = [0, 4]
-// ORDER: [[accum2:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [4, 0]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [0, 0]
-// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// ORDER-SAME: offsets = [4, 0]
-// ORDER: [[accum3:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [4, 0]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [4, 0]
-// ORDER: [[c:%.+]] = vector.extract_strided_slice [[arg2]]
-// ORDER-SAME: offsets = [4, 4]
-// ORDER: [[accum4:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[c]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [0, 2]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [0, 2]
-// ORDER: [[accum5:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum1]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [0, 2]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [4, 2]
-// ORDER: [[accum6:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum2]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [4, 2]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [0, 2]
-// ORDER: [[accum7:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum3]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: [[a:%.+]] = vector.extract_strided_slice [[arg0]]
-// ORDER-SAME: offsets = [4, 2]
-// ORDER: [[b:%.+]] = vector.extract_strided_slice [[arg1]]
-// ORDER-SAME: offsets = [4, 2]
-// ORDER: [[accum8:%.+]] = vector.contract {{{.*}}} [[a]], [[b]], [[accum4]]
-// ORDER-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
-
-// ORDER: return
-
-
-
 func.func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
                                %init : vector<8x8xf16>) -> vector<8x8xf16> {
   %0 = vector.contract
@@ -264,4 +158,3 @@ func.func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
 // CHECK:  %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
 // CHECK:  %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
 // CHECK:  return %[[V7]] : vector<2x3x8x4xf32>
-
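As a sanity check on the rewritten vector_contract_f32 expectations above: the contraction is now 8x8x8 over (i, j, k), and the 4x2 operand and 4x4 accumulator slice types suggest the test pass unrolls f32 contracts to a native 4x4x2 shape, so the shape ratio is {2, 2, 4} and the default traversal emits 2 * 2 * 4 = 16 vector.contract ops, one per CHECK. A hedged one-liner mirroring what computeMaxLinearIndex computes (standalone C++, not the MLIR helper itself):

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Number of unrolled slices = product of per-dimension shape ratios.
    int64_t sliceCount(const std::vector<int64_t> &ratio) {
      return std::accumulate(ratio.begin(), ratio.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }
    // sliceCount({2, 2, 4}) == 16 for the f32 contract test above.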
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index cbec267734eba..a81aa536df4ad 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
@@ -323,18 +322,12 @@ struct TestVectorUnrollingPatterns
       }
       return nativeShape;
     };
-
-    UnrollVectorOptions opts;
-    opts.setNativeShapeFn(nativeShapeFn)
-        .setFilterConstraint(
-            [](Operation *op) { return success(isa<ContractionOp>(op)); });
-    if (!unrollOrder.empty()) {
-      opts.setUnrollTraversalOrderFn([this](Operation *op)
-                                         -> Optional<SmallVector<int64_t>> {
-        return SmallVector<int64_t>{unrollOrder.begin(), unrollOrder.end()};
-      });
-    }
-    populateVectorUnrollPatterns(patterns, opts);
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShapeFn(nativeShapeFn)
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<ContractionOp>(op));
+                                     }));
   } else {
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
@@ -347,10 +340,6 @@ struct TestVectorUnrollingPatterns
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
-  ListOption<int64_t> unrollOrder{*this, "unroll-order",
-                                  llvm::cl::desc("set the unroll order"),
-                                  llvm::cl::ZeroOrMore};
-
   Option<bool> unrollBasedOnType{
       *this, "unroll-based-on-type",
       llvm::cl::desc("Set the unroll factor based on type of the operation"),
@@ -483,11 +472,6 @@ struct TestVectorTransferUnrollingPatterns
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestVectorTransferUnrollingPatterns)
 
-  TestVectorTransferUnrollingPatterns() = default;
-  TestVectorTransferUnrollingPatterns(
-      const TestVectorTransferUnrollingPatterns &pass)
-      : PassWrapper(pass) {}
-
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
   }
@@ -501,36 +485,17 @@ struct TestVectorTransferUnrollingPatterns
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
-    UnrollVectorOptions opts;
-    opts.setNativeShape(ArrayRef<int64_t>{2, 2})
-        .setFilterConstraint([](Operation *op) {
-          return success(
-              isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
-        });
-    if (reverseUnrollOrder.getValue()) {
-      opts.setUnrollTraversalOrderFn(
-          [](Operation *op) -> Optional<SmallVector<int64_t>> {
-            int64_t numLoops = 0;
-            if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
-              numLoops = readOp.getVectorType().getRank();
-            else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
-              numLoops = writeOp.getVectorType().getRank();
-            else
-              return None;
-            auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
-            return llvm::to_vector(order);
-          });
-    }
-    populateVectorUnrollPatterns(patterns, opts);
+    populateVectorUnrollPatterns(
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+            }));
     populateVectorToVectorCanonicalizationPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
-
-  Option<bool> reverseUnrollOrder{
-      *this, "reverse-unroll-order",
-      llvm::cl::desc(
-          "reverse the order of unrolling of vector transfer operations"),
-      llvm::cl::init(false)};
 };
 
 struct TestVectorTransferFullPartialSplitPatterns