-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Vector] Add unroll pattern for vector.shape_cast #167738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cd8b818
73512fd
9b4191a
d4ea820
edf3dd3
1778d99
18ed975
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { | |
| vector::UnrollVectorOptions options; | ||
| }; | ||
|
|
||
| /// Checks whether extractShape is a contiguous slice of shape. | ||
| /// For extractShape to be contiguous in shape: | ||
| /// 1) All but the leading dimension of extractShape and shape must match | ||
| /// exactly. 2) The total number of elements in shape must be evenly divisible | ||
| /// by | ||
| /// the total number of elements in extractShape. | ||
| /// Examples: | ||
| /// isContiguous([4, 4], [8, 4]) == true | ||
| /// isContiguous([2, 4], [8, 4]) == true | ||
| /// isContiguous([2, 2], [8, 4]) == false | ||
| /// Removes leading unit dimensions to handle cases like: | ||
| /// isContiguous([1, 16], [1, 32]) == true | ||
| static bool isContiguous(ArrayRef<int64_t> extractShape, | ||
| ArrayRef<int64_t> shape) { | ||
|
|
||
| if (extractShape.size() > shape.size()) | ||
| return false; | ||
|
|
||
| while (!extractShape.empty() && extractShape.front() == 1) { | ||
| extractShape = extractShape.drop_front(); | ||
| } | ||
|
|
||
| while (!shape.empty() && shape.front() == 1) { | ||
| shape = shape.drop_front(); | ||
| } | ||
|
|
||
| size_t rankDiff = shape.size() - extractShape.size(); | ||
| if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) | ||
| return false; | ||
|
|
||
| int64_t extractElements = ShapedType::getNumElements(extractShape); | ||
| int64_t shapeElements = ShapedType::getNumElements(shape); | ||
| return shapeElements % extractElements == 0; | ||
| } | ||
|
|
||
| /// Determines what shape to use with `vector.extract_strided_slice` to extract | ||
| /// a contiguous memory region from a source vector. The extraction must be | ||
| /// contiguous and contain exactly the specified number of elements. If such an | ||
| /// extraction shape cannot be determined, returns std::nullopt. | ||
| /// EXAMPLE 1: | ||
| /// sourceShape = [16], targetElements = 8 | ||
| /// Working right-to-left: | ||
| /// - Take min(8, 16) = 8 from only dim → extractShape = [8], | ||
| /// remaining = 8/8 = 1 | ||
| /// Result: [8] | ||
| /// | ||
| /// EXAMPLE 2: | ||
| /// sourceShape = [4, 4], targetElements = 8 | ||
| /// Working right-to-left: | ||
| /// - Take min(8, 4) = 4 from last dim → extractShape = [4], | ||
| /// remaining = 8/4 = 2 | ||
| /// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], | ||
| /// remaining = 2/2 = 1 | ||
| /// Result: [2, 4] | ||
| static std::optional<SmallVector<int64_t>> | ||
| calculateSourceExtractShape(ArrayRef<int64_t> sourceShape, | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| int64_t targetElements) { | ||
| SmallVector<int64_t> extractShape; | ||
| int64_t remainingElements = targetElements; | ||
|
|
||
| // Build extract shape from innermost dimension outward to ensure contiguity. | ||
| for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { | ||
| int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); | ||
| extractShape.insert(extractShape.begin(), takeFromDim); | ||
|
|
||
| if (remainingElements % takeFromDim != 0) | ||
| return std::nullopt; // Not evenly divisible. | ||
| remainingElements /= takeFromDim; | ||
| } | ||
|
|
||
| // Fill remaining dimensions with 1. | ||
| while (extractShape.size() < sourceShape.size()) | ||
| extractShape.insert(extractShape.begin(), 1); | ||
|
|
||
| if (ShapedType::getNumElements(extractShape) != targetElements) | ||
| return std::nullopt; | ||
|
|
||
| return extractShape; | ||
| } | ||
|
|
||
| // Convert result offsets to source offsets via linear position. | ||
| static SmallVector<int64_t> | ||
| calculateSourceOffsets(ArrayRef<int64_t> resultOffsets, | ||
| ArrayRef<int64_t> sourceShape, | ||
| ArrayRef<int64_t> resultShape) { | ||
|
Comment on lines
+1087
to
+1090
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this assert that the number of elements in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't shape_cast op verifier take care of that?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Of making sure that no invalid inputs are ever passed to this method? I doubt that ;-)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant if the shapeCast is not maintaining the semantics of NumElements(src) == NumElements(dst) how is it even a valid instruction? I tried it locally and it does fail in the verifier here https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L6258
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
shapeCast verifier will indeed maintain that, but only for shapeCast Ops. However, how do you make sure that the inputs used in this method always come from shapeCast? Perhaps I am missing something, but what is stopping anyone/anything from using this method with some random arrays that don't come from shapeCast? |
||
| // Convert result offsets to linear position. | ||
| int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); | ||
| // Convert linear position to source offsets. | ||
| return delinearize(linearIndex, computeStrides(sourceShape)); | ||
| } | ||
|
|
||
| /// This pattern unrolls `vector.shape_cast` operations according to the | ||
| /// provided target unroll shape. It unrolls a large shape cast into smaller | ||
| /// shape casts by extracting contiguous slices from the source vector, casting | ||
| /// each slice to the target shape, and assembling the result by inserting each | ||
| /// computed segment into the appropriate offset of the result vector. | ||
| /// | ||
| /// This pattern only applies when contiguous slices can be extracted from the | ||
| /// source vector and inserted into the result vector such that each slice | ||
| /// remains a valid vector (and not decompose to scalars). In these cases, the | ||
| /// unrolling proceeds as: | ||
| /// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> | ||
| /// vector.insert_strided_slice. | ||
| /// | ||
| /// Example: | ||
| /// Given a shape cast operation: | ||
| /// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> | ||
| /// | ||
| /// and a target unroll shape of <2x4>, the pattern produces: | ||
| /// | ||
| /// %zero = arith.constant dense<0.0> : vector<4x4xf32> | ||
| /// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] | ||
| /// : vector<8x2xf32> to vector<4x2xf32> | ||
| /// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> | ||
| /// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] | ||
| /// : vector<2x4xf32> into vector<4x4xf32> | ||
| /// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] | ||
| /// : vector<8x2xf32> to vector<4x2xf32> | ||
| /// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> | ||
| /// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] | ||
| /// : vector<2x4xf32> into vector<4x4xf32> | ||
| /// | ||
| struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { | ||
| UnrollShapeCastPattern(MLIRContext *context, | ||
| const vector::UnrollVectorOptions &options, | ||
| PatternBenefit benefit = 1) | ||
| : OpRewritePattern<vector::ShapeCastOp>(context, benefit), | ||
| options(options) {} | ||
|
|
||
| LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, | ||
| PatternRewriter &rewriter) const override { | ||
| std::optional<SmallVector<int64_t>> targetShape = | ||
| getTargetShape(options, shapeCastOp); | ||
| if (!targetShape) | ||
| return failure(); | ||
|
|
||
| VectorType sourceType = shapeCastOp.getSourceVectorType(); | ||
| VectorType resultType = shapeCastOp.getResultVectorType(); | ||
| ArrayRef<int64_t> sourceShape = sourceType.getShape(); | ||
| ArrayRef<int64_t> resultShape = resultType.getShape(); | ||
|
|
||
| if (!isContiguous(*targetShape, resultShape)) | ||
| return rewriter.notifyMatchFailure( | ||
| shapeCastOp, "Only supports cases where target shape is " | ||
| "contiguous in result vector shape"); | ||
|
|
||
| int64_t targetElements = ShapedType::getNumElements(*targetShape); | ||
|
|
||
| // Calculate the shape to extract from source. | ||
| std::optional<SmallVector<int64_t>> extractShape = | ||
| calculateSourceExtractShape(sourceShape, targetElements); | ||
| if (!extractShape) | ||
| return rewriter.notifyMatchFailure( | ||
| shapeCastOp, | ||
| "cannot extract target number of elements contiguously from source"); | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| Location loc = shapeCastOp.getLoc(); | ||
|
|
||
| // Create result vector initialized to zero. | ||
| Value result = arith::ConstantOp::create(rewriter, loc, resultType, | ||
| rewriter.getZeroAttr(resultType)); | ||
|
|
||
| VectorType targetType = | ||
| VectorType::get(*targetShape, sourceType.getElementType()); | ||
|
|
||
| SmallVector<int64_t> extractStrides(extractShape->size(), 1); | ||
| SmallVector<int64_t> insertStrides(targetShape->size(), 1); | ||
|
|
||
| for (SmallVector<int64_t> resultOffsets : | ||
| StaticTileOffsetRange(resultShape, *targetShape)) { | ||
| SmallVector<int64_t> sourceOffsets = | ||
| calculateSourceOffsets(resultOffsets, sourceShape, resultShape); | ||
| Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>( | ||
| loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, | ||
| extractStrides); | ||
| Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>( | ||
| loc, targetType, sourceChunk); | ||
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( | ||
| loc, targetChunk, result, resultOffsets, insertStrides); | ||
| } | ||
|
|
||
| rewriter.replaceOp(shapeCastOp, result); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| vector::UnrollVectorOptions options; | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| void mlir::vector::populateVectorUnrollPatterns( | ||
|
|
@@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns( | |
| UnrollReductionPattern, UnrollMultiReductionPattern, | ||
| UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, | ||
| UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, | ||
| UnrollToElements, UnrollStepPattern>(patterns.getContext(), | ||
| options, benefit); | ||
| UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( | ||
| patterns.getContext(), options, benefit); | ||
| } | ||
|
|
||
| void mlir::vector::populateVectorToElementsUnrollPatterns( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is quite a corner case, please add a test.