Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3018,6 +3018,7 @@ def Vector_ScanOp :

def Vector_StepOp : Vector_Op<"step", [
Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let summary = "A linear sequence of values from 0 to N";
Expand Down
77 changes: 76 additions & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,81 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.step` operations according to the provided
/// target unroll shape. It decomposes a large step vector into smaller step
/// vectors (segments) and assembles the result by inserting each computed
/// segment into the appropriate offset of the original vector.
///
/// The pattern does not support scalable vectors and will fail to match them.
///
/// For each segment, it adds the base step vector and the segment's offset,
/// then inserts the result into the output vector at the corresponding
/// position.
///
/// Example:
/// Given a step operation:
/// %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
/// %base = vector.step : vector<4xindex>
/// %zero = arith.constant dense<0> : vector<8xindex>
/// %result0 = vector.insert_strided_slice %base, %zero
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
/// %offset = arith.constant dense<4> : vector<4xindex>
/// %segment1 = arith.addi %base, %offset : vector<4xindex>
/// %result1 = vector.insert_strided_slice %segment1, %result0
/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
UnrollStepPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
std::optional<SmallVector<int64_t>> targetShape =
getTargetShape(options, stepOp);
if (!targetShape)
return failure();

VectorType vecType = stepOp.getType();
if (vecType.isScalable()) {
// Scalable vectors are not supported by this pattern.
return failure();
}
int64_t originalSize = vecType.getShape()[0];
Location loc = stepOp.getLoc();
SmallVector<int64_t> strides(1, 1);

Value result = arith::ConstantOp::create(rewriter, loc, vecType,
rewriter.getZeroAttr(vecType));

auto targetVecType =
VectorType::get(*targetShape, vecType.getElementType());
Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
for (const SmallVector<int64_t> &offsets :
StaticTileOffsetRange({originalSize}, *targetShape)) {
Value bcastOffset = arith::ConstantOp::create(
rewriter, loc, targetVecType,
DenseElementsAttr::get(
targetVecType,
IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
Value tileStep =
arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);

result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tileStep, result, offsets, strides);
}
rewriter.replaceOp(stepOp, result);
return success();
}

private:
vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
Expand All @@ -818,6 +893,6 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern>(
UnrollStorePattern, UnrollBroadcastPattern, UnrollStepPattern>(
patterns.getContext(), options, benefit);
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,23 @@ func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>


func.func @vector_step() -> vector<32xindex> {
%0 = vector.step : vector<32xindex>
return %0 : vector<32xindex>
}
// CHECK-LABEL: func @vector_step
// CHECK: %[[CST:.*]] = arith.constant dense<24> : vector<8xindex>
// CHECK: %[[CST0:.*]] = arith.constant dense<16> : vector<8xindex>
// CHECK: %[[CST1:.*]] = arith.constant dense<8> : vector<8xindex>
// CHECK: %[[CST2:.*]] = arith.constant dense<0> : vector<32xindex>
// CHECK: %[[STEP:.*]] = vector.step : vector<8xindex>
// CHECK: %[[INS0:.*]] = vector.insert_strided_slice %[[STEP]], %[[CST2]] {offsets = [0], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD1:.*]] = arith.addi %[[STEP]], %[[CST1]] : vector<8xindex>
// CHECK: %[[INS1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[INS0]] {offsets = [8], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD2:.*]] = arith.addi %[[STEP]], %[[CST0]] : vector<8xindex>
// CHECK: %[[INS2:.*]] = vector.insert_strided_slice %[[ADD2]], %[[INS1]] {offsets = [16], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
populateVectorUnrollPatterns(patterns,
UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{8})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
Expand Down