diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cca1607f..c60fa3b85b396 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3018,6 +3018,7 @@ def Vector_ScanOp :
 
 def Vector_StepOp : Vector_Op<"step", [
     Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]> {
   let summary = "A linear sequence of values from 0 to N";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index e8ecb0c0be846..79786f33a2d46 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.step` operations according to the provided
+/// target unroll shape. It decomposes a large step vector into smaller step
+/// vectors (segments) and assembles the result by inserting each computed
+/// segment into the appropriate offset of the original vector.
+///
+/// The pattern does not support scalable vectors and will fail to match them.
+///
+/// For each segment, it adds the base step vector and the segment's offset,
+/// then inserts the result into the output vector at the corresponding
+/// position.
+///
+/// Example:
+///   Given a step operation:
+///     %0 = vector.step : vector<8xindex>
+///
+///   and a target unroll shape of <4>, the pattern produces:
+///
+///     %base = vector.step : vector<4xindex>
+///     %zero = arith.constant dense<0> : vector<8xindex>
+///     %result0 = vector.insert_strided_slice %base, %zero
+///       {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
+///     %offset = arith.constant dense<4> : vector<4xindex>
+///     %segment1 = arith.addi %base, %offset : vector<4xindex>
+///     %result1 = vector.insert_strided_slice %segment1, %result0
+///       {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
+///
+struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
+  UnrollStepPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, stepOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType vecType = stepOp.getType();
+    if (vecType.isScalable()) {
+      // Scalable vectors are not supported by this pattern.
+      return failure();
+    }
+    int64_t originalSize = vecType.getShape()[0];
+    Location loc = stepOp.getLoc();
+    SmallVector<int64_t> strides(1, 1);
+
+    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
+                                             rewriter.getZeroAttr(vecType));
+
+    auto targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType());
+    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
+    for (const SmallVector<int64_t> &offsets :
+         StaticTileOffsetRange({originalSize}, *targetShape)) {
+      Value bcastOffset = arith::ConstantOp::create(
+          rewriter, loc, targetVecType,
+          DenseElementsAttr::get(
+              targetVecType,
+              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
+      Value tileStep =
+          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tileStep, result, offsets, strides);
+    }
+    rewriter.replaceOp(stepOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollContractionPattern, UnrollElementwisePattern,
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
-               UnrollStorePattern, UnrollBroadcastPattern>(
+               UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
       patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e129cd5c40b9c..35db14e0f7f1d 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -420,3 +420,23 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
 // CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
 // CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
 // CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+
+
+func.func @vector_step() -> vector<32xindex> {
+  %0 = vector.step : vector<32xindex>
+  return %0 : vector<32xindex>
+}
+// CHECK-LABEL: func @vector_step
+// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
+// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
+// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
+// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
+// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
+// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
+// CHECK: return %[[INS3]] : vector<32xindex>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..1cd092cec2b81 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
                       .setFilterConstraint([](Operation *op) {
                         return success(isa<vector::ReductionOp>(op));
                       }));
+    populateVectorUnrollPatterns(patterns,
+                                 UnrollVectorOptions()
+                                     .setNativeShape(ArrayRef<int64_t>{8})
+                                     .setFilterConstraint([](Operation *op) {
+                                       return success(isa<vector::StepOp>(op));
+                                     }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})