From cd8b818297287afbed0c675d9bf491bfb296f385 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 11 Nov 2025 00:14:40 +0000 Subject: [PATCH 1/7] Add unroll pattern for vector.shape_cast --- .../mlir/Dialect/Vector/IR/VectorOps.td | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 + .../Vector/Transforms/VectorUnroll.cpp | 170 +++++++++++++++++- .../Dialect/Vector/vector-unroll-options.mlir | 34 ++++ .../Dialect/Vector/TestVectorTransforms.cpp | 6 + 5 files changed, 213 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff2082df..6ad179349f90f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp : def Vector_ShapeCastOp : Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index daef0ba02100a..4cac137478fab 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..a4830809aaac8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; +static bool isContiguousExtract(ArrayRef targetShape, + ArrayRef 
resultShape) { + if (targetShape.size() > resultShape.size()) { + return false; + } + + size_t rankDiff = resultShape.size() - targetShape.size(); + // Inner dimensions must match exactly & total resultElements should be + // evenly divisible by targetElements. + for (size_t i = 1; i < targetShape.size(); ++i) { + if (targetShape[i] != resultShape[rankDiff + i]) { + return false; + } + } + + int64_t targetElements = ShapedType::getNumElements(targetShape); + int64_t resultElements = ShapedType::getNumElements(resultShape); + if (resultElements % targetElements != 0) { + return false; + } + return true; +} + +// Calculate the shape to extract from source +static std::optional> +calculateSourceExtractShape(ArrayRef sourceShape, + int64_t targetElements) { + SmallVector extractShape; + int64_t remainingElements = targetElements; + + // Build extract shape from innermost dimension outward to ensure contiguity + for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { + int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); + extractShape.insert(extractShape.begin(), takeFromDim); + + if (remainingElements % takeFromDim != 0) { + return std::nullopt; // Not evenly divisible + } + remainingElements /= takeFromDim; + } + + // Fill remaining dimensions with 1 + while (extractShape.size() < sourceShape.size()) { + extractShape.insert(extractShape.begin(), 1); + } + + if (ShapedType::getNumElements(extractShape) != targetElements) { + return std::nullopt; + } + + return extractShape; +} + +// Convert result offsets to source offsets via linear position +static SmallVector +calculateSourceOffsets(ArrayRef resultOffsets, + ArrayRef sourceStrides, + ArrayRef resultStrides) { + // Convert result offsets to linear position + int64_t linearIndex = linearize(resultOffsets, resultStrides); + // Convert linear position to source offsets + SmallVector sourceOffsets = delinearize(linearIndex, sourceStrides); + return sourceOffsets; +} + +/// This pattern 
unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It unrolls a large shape cast into smaller +/// shape casts by extracting contiguous slices from the source vector, casting +/// each slice to the target shape, and assembling the result by inserting each +/// computed segment into the appropriate offset of the result vector. +/// +/// This pattern only applies when contiguous slices can be extracted from the +/// source vector and inserted into the result vector such that each slice +/// remains a valid vector (and not decompose to scalars). In these cases, the +/// unrolling proceeds as: +/// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> +/// vector.insert_strided_slice +/// +/// Example: +/// Given a shape cast operation: +/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x4>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> +/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> +/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, shapeCastOp); + 
if (!targetShape) + return failure(); + + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + ArrayRef sourceShape = sourceType.getShape(); + ArrayRef resultShape = resultType.getShape(); + + if (!isContiguousExtract(*targetShape, resultShape)) { + return rewriter.notifyMatchFailure(shapeCastOp, + "Only supports cases where contiguous " + "extraction is possible"); + } + + int64_t targetElements = ShapedType::getNumElements(*targetShape); + + // Calculate the shape to extract from source + auto extractShape = + calculateSourceExtractShape(sourceShape, targetElements); + if (!extractShape) { + return rewriter.notifyMatchFailure( + shapeCastOp, + "cannot extract target number of elements contiguously from source"); + } + + Location loc = shapeCastOp.getLoc(); + + // Create result vector initialized to zero + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + VectorType targetType = + VectorType::get(*targetShape, sourceType.getElementType()); + + SmallVector extractStrides(extractShape->size(), 1); + SmallVector insertStrides(targetShape->size(), 1); + SmallVector sourceStrides = computeStrides(sourceShape); + SmallVector resultStrides = computeStrides(resultShape); + + for (SmallVector resultOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + SmallVector sourceOffsets = + calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides); + Value sourceChunk = rewriter.createOrFold( + loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, + extractStrides); + Value targetChunk = rewriter.createOrFold( + loc, targetType, sourceChunk); + result = rewriter.createOrFold( + loc, targetChunk, result, resultOffsets, insertStrides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ 
-1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f33..c94a502fa3654 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + + +func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { + %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> + return %0 : vector<2x2x4xf32> +} + +// CHECK-LABEL: func @shape_cast_1D +// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = 
vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: return %[[I1]] : vector<2x2x4xf32> + + +func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: return %[[I1]] : vector<4x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda71..0ab4e451d544d 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef{2, 4}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() 
.setNativeShape(ArrayRef{1, 3, 4, 2}) From 73512fd722ea836ea96ec31d55f55e893c6f9b14 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 12 Nov 2025 19:38:26 +0000 Subject: [PATCH 2/7] Address feedback --- .../Vector/Transforms/VectorUnroll.cpp | 59 ++++++++----------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index a4830809aaac8..7afc83bb8a876 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1005,67 +1005,57 @@ struct UnrollFromElements : OpRewritePattern { static bool isContiguousExtract(ArrayRef targetShape, ArrayRef resultShape) { - if (targetShape.size() > resultShape.size()) { + if (targetShape.size() > resultShape.size()) return false; - } size_t rankDiff = resultShape.size() - targetShape.size(); // Inner dimensions must match exactly & total resultElements should be // evenly divisible by targetElements. - for (size_t i = 1; i < targetShape.size(); ++i) { - if (targetShape[i] != resultShape[rankDiff + i]) { - return false; - } - } + if (!llvm::equal(targetShape.drop_front(), + resultShape.drop_front(rankDiff + 1))) + return false; int64_t targetElements = ShapedType::getNumElements(targetShape); int64_t resultElements = ShapedType::getNumElements(resultShape); - if (resultElements % targetElements != 0) { - return false; - } - return true; + return resultElements % targetElements == 0; } -// Calculate the shape to extract from source +// Calculate the shape to extract from source. static std::optional> calculateSourceExtractShape(ArrayRef sourceShape, int64_t targetElements) { SmallVector extractShape; int64_t remainingElements = targetElements; - // Build extract shape from innermost dimension outward to ensure contiguity + // Build extract shape from innermost dimension outward to ensure contiguity. 
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); extractShape.insert(extractShape.begin(), takeFromDim); - if (remainingElements % takeFromDim != 0) { - return std::nullopt; // Not evenly divisible - } + if (remainingElements % takeFromDim != 0) + return std::nullopt; // Not evenly divisible. remainingElements /= takeFromDim; } - // Fill remaining dimensions with 1 - while (extractShape.size() < sourceShape.size()) { + // Fill remaining dimensions with 1. + while (extractShape.size() < sourceShape.size()) extractShape.insert(extractShape.begin(), 1); - } - if (ShapedType::getNumElements(extractShape) != targetElements) { + if (ShapedType::getNumElements(extractShape) != targetElements) return std::nullopt; - } return extractShape; } -// Convert result offsets to source offsets via linear position +// Convert result offsets to source offsets via linear position. static SmallVector calculateSourceOffsets(ArrayRef resultOffsets, ArrayRef sourceStrides, ArrayRef resultStrides) { - // Convert result offsets to linear position + // Convert result offsets to linear position. int64_t linearIndex = linearize(resultOffsets, resultStrides); - // Convert linear position to source offsets - SmallVector sourceOffsets = delinearize(linearIndex, sourceStrides); - return sourceOffsets; + // Convert linear position to source offsets. + return delinearize(linearIndex, sourceStrides); } /// This pattern unrolls `vector.shape_cast` operations according to the @@ -1079,7 +1069,7 @@ calculateSourceOffsets(ArrayRef resultOffsets, /// remains a valid vector (and not decompose to scalars). In these cases, the /// unrolling proceeds as: /// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> -/// vector.insert_strided_slice +/// vector.insert_strided_slice. 
/// /// Example: /// Given a shape cast operation: @@ -1108,7 +1098,8 @@ struct UnrollShapeCastPattern : public OpRewritePattern { LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { - auto targetShape = getTargetShape(options, shapeCastOp); + std::optional> targetShape = + getTargetShape(options, shapeCastOp); if (!targetShape) return failure(); @@ -1117,26 +1108,24 @@ struct UnrollShapeCastPattern : public OpRewritePattern { ArrayRef sourceShape = sourceType.getShape(); ArrayRef resultShape = resultType.getShape(); - if (!isContiguousExtract(*targetShape, resultShape)) { + if (!isContiguousExtract(*targetShape, resultShape)) return rewriter.notifyMatchFailure(shapeCastOp, "Only supports cases where contiguous " "extraction is possible"); - } int64_t targetElements = ShapedType::getNumElements(*targetShape); - // Calculate the shape to extract from source - auto extractShape = + // Calculate the shape to extract from source. + std::optional> extractShape = calculateSourceExtractShape(sourceShape, targetElements); - if (!extractShape) { + if (!extractShape) return rewriter.notifyMatchFailure( shapeCastOp, "cannot extract target number of elements contiguously from source"); - } Location loc = shapeCastOp.getLoc(); - // Create result vector initialized to zero + // Create result vector initialized to zero. 
Value result = arith::ConstantOp::create(rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); From 9b4191a1c63c033fbf8f88dc9b227c1db1a936db Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 13 Nov 2025 00:40:23 +0000 Subject: [PATCH 3/7] Fix isContiguousExtract --- .../Vector/Transforms/VectorUnroll.cpp | 50 ++++++++++++++++--- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 7afc83bb8a876..885fcf835c1a3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1008,16 +1008,50 @@ static bool isContiguousExtract(ArrayRef targetShape, if (targetShape.size() > resultShape.size()) return false; - size_t rankDiff = resultShape.size() - targetShape.size(); - // Inner dimensions must match exactly & total resultElements should be - // evenly divisible by targetElements. - if (!llvm::equal(targetShape.drop_front(), - resultShape.drop_front(rankDiff + 1))) - return false; - int64_t targetElements = ShapedType::getNumElements(targetShape); int64_t resultElements = ShapedType::getNumElements(resultShape); - return resultElements % targetElements == 0; + + // Result must be evenly divisible by target. + if (resultElements % targetElements != 0) + return false; + + // For contiguous extraction, we need to be able to + // extract targetElements contiguously from the result shape. + // This means we can "consume" dimensions from the innermost outward + // until we have exactly targetElements. + + int64_t remainingElements = targetElements; + int targetDimIdx = targetShape.size() - 1; + + // Work backwards through result dimensions. 
+ for (int resultDimIdx = resultShape.size() - 1; + resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0; + --resultDimIdx) { + + int64_t resultDimSize = resultShape[resultDimIdx]; + int64_t targetDimSize = targetShape[targetDimIdx]; + + if (targetDimSize > resultDimSize) + return false; + + if (targetDimSize == resultDimSize) { + if (remainingElements % targetDimSize != 0) + return false; + remainingElements /= targetDimSize; + --targetDimIdx; + } else { + if (remainingElements != targetDimSize) + return false; + remainingElements = 1; + --targetDimIdx; + } + } + + // Check remaining target dimensions are all 1 and we consumed all elements + return remainingElements == 1 && + (targetDimIdx < 0 || llvm::all_of( + targetShape.take_front(targetDimIdx + 1), + [](int64_t d) { return d == 1; })); } // Calculate the shape to extract from source. From d4ea820d64c74a829225de31715be50a96045fa7 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 18 Nov 2025 16:52:35 +0000 Subject: [PATCH 4/7] Address feedback --- .../Vector/Transforms/VectorUnroll.cpp | 110 +++++++++--------- .../Dialect/Vector/vector-unroll-options.mlir | 39 ++++++- 2 files changed, 88 insertions(+), 61 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 885fcf835c1a3..0a1d86109beea 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,58 +1003,60 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; -static bool isContiguousExtract(ArrayRef targetShape, - ArrayRef resultShape) { - if (targetShape.size() > resultShape.size()) - return false; - - int64_t targetElements = ShapedType::getNumElements(targetShape); - int64_t resultElements = ShapedType::getNumElements(resultShape); +/// Checks whether targetShape is contiguous in resultShape. 
+/// For targetShape to be contiguous in resultShape: +/// 1) The inner dimensions of targetShape and resultShape must match exactly. +/// 2) The total number of elements in resultShape must be evenly divisible by +/// the total number of elements in targetShape. +/// Examples: +/// isContiguous([4, 4], [8, 4]) == true +/// isContiguous([2, 4], [8, 4]) == true +/// isContiguous([2, 2], [8, 4]) == false +/// Removes leading unit dimensions to handle cases like: +/// isContiguous([1, 16], [1, 32]) == true +static bool isContiguous(ArrayRef targetShape, + ArrayRef resultShape) { - // Result must be evenly divisible by target. - if (resultElements % targetElements != 0) + if (targetShape.size() > resultShape.size()) return false; - // For contiguous extraction, we need to be able to - // extract targetElements contiguously from the result shape. - // This means we can "consume" dimensions from the innermost outward - // until we have exactly targetElements. + while (!targetShape.empty() && targetShape.front() == 1) { + targetShape = targetShape.drop_front(); + } - int64_t remainingElements = targetElements; - int targetDimIdx = targetShape.size() - 1; - - // Work backwards through result dimensions. 
- for (int resultDimIdx = resultShape.size() - 1; - resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0; - --resultDimIdx) { - - int64_t resultDimSize = resultShape[resultDimIdx]; - int64_t targetDimSize = targetShape[targetDimIdx]; - - if (targetDimSize > resultDimSize) - return false; - - if (targetDimSize == resultDimSize) { - if (remainingElements % targetDimSize != 0) - return false; - remainingElements /= targetDimSize; - --targetDimIdx; - } else { - if (remainingElements != targetDimSize) - return false; - remainingElements = 1; - --targetDimIdx; - } + while (!resultShape.empty() && resultShape.front() == 1) { + resultShape = resultShape.drop_front(); } - // Check remaining target dimensions are all 1 and we consumed all elements - return remainingElements == 1 && - (targetDimIdx < 0 || llvm::all_of( - targetShape.take_front(targetDimIdx + 1), - [](int64_t d) { return d == 1; })); + size_t rankDiff = resultShape.size() - targetShape.size(); + if (!llvm::equal(targetShape.drop_front(), + resultShape.drop_front(rankDiff + 1))) + return false; + + int64_t targetElements = ShapedType::getNumElements(targetShape); + int64_t resultElements = ShapedType::getNumElements(resultShape); + return resultElements % targetElements == 0; } -// Calculate the shape to extract from source. +/// This function determines what shape to use with +/// `vector.extract_strided_slice` to extract a contiguous memory region from a +/// source vector. The extraction must be contiguous and contain exactly the +/// specified number of elements. If such an extraction shape cannot be +/// determined, the function returns std::nullopt. 
+/// Examples: +/// sourceShape = [16], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 16) = 8 from only dim → extractShape = [8], +/// remaining = 8/8 = 1 +/// Result: [8] +/// +/// sourceShape = [4, 4], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 4) = 4 from last dim → extractShape = [4], +/// remaining = 8/4 = 2 +/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], +/// remaining = 2/2 = 1 +/// Result: [2, 4] static std::optional> calculateSourceExtractShape(ArrayRef sourceShape, int64_t targetElements) { @@ -1084,12 +1086,12 @@ calculateSourceExtractShape(ArrayRef sourceShape, // Convert result offsets to source offsets via linear position. static SmallVector calculateSourceOffsets(ArrayRef resultOffsets, - ArrayRef sourceStrides, - ArrayRef resultStrides) { + ArrayRef sourceShape, + ArrayRef resultShape) { // Convert result offsets to linear position. - int64_t linearIndex = linearize(resultOffsets, resultStrides); + int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); // Convert linear position to source offsets. 
- return delinearize(linearIndex, sourceStrides); + return delinearize(linearIndex, computeStrides(sourceShape)); } /// This pattern unrolls `vector.shape_cast` operations according to the @@ -1142,10 +1144,10 @@ struct UnrollShapeCastPattern : public OpRewritePattern { ArrayRef sourceShape = sourceType.getShape(); ArrayRef resultShape = resultType.getShape(); - if (!isContiguousExtract(*targetShape, resultShape)) - return rewriter.notifyMatchFailure(shapeCastOp, - "Only supports cases where contiguous " - "extraction is possible"); + if (!isContiguous(*targetShape, resultShape)) + return rewriter.notifyMatchFailure( + shapeCastOp, "Only supports cases where target shape is " + "contiguous in result vector shape"); int64_t targetElements = ShapedType::getNumElements(*targetShape); @@ -1168,13 +1170,11 @@ struct UnrollShapeCastPattern : public OpRewritePattern { SmallVector extractStrides(extractShape->size(), 1); SmallVector insertStrides(targetShape->size(), 1); - SmallVector sourceStrides = computeStrides(sourceShape); - SmallVector resultStrides = computeStrides(resultShape); for (SmallVector resultOffsets : StaticTileOffsetRange(resultShape, *targetShape)) { SmallVector sourceOffsets = - calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides); + calculateSourceOffsets(resultOffsets, sourceShape, resultShape); Value sourceChunk = rewriter.createOrFold( loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, extractStrides); diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index c94a502fa3654..8e2caa39696cb 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -504,12 +504,12 @@ func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { } // CHECK-LABEL: func @shape_cast_1D -// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> { +// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> 
vector<2x2x4xf32> { // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32> -// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> // CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32> // CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> -// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> // CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32> // CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> // CHECK: return %[[I1]] : vector<2x2x4xf32> @@ -521,12 +521,39 @@ func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> { } // CHECK-LABEL: func @shape_cast_2D -// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { +// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> -// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> // CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32> // CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides 
= [1, 1]} : vector<2x4xf32> into vector<4x4xf32> -// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> // CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32> // CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> // CHECK: return %[[I1]] : vector<4x4xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because the targetShape (2x4) is not contiguous in result vector +func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> { + %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous +// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32> +// CHECK: return %[[SC]] : vector<8x8xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because it cannot determine the extractShape from source vector (8x3) +// to extract contiguous targetShape (2x4) +func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> { + %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32> + return %0 : vector<6x4xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable +// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32> +// CHECK: return %[[SC]] : vector<6x4xf32> From edf3dd30bc9b60b88bf34001b7fa68cc4191a465 Mon Sep 17 00:00:00 2001 From: 
nbpatel Date: Tue, 18 Nov 2025 21:34:14 +0000 Subject: [PATCH 5/7] Change variable name --- .../Vector/Transforms/VectorUnroll.cpp | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 0a1d86109beea..42293c88c410e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,39 +1003,38 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; -/// Checks whether targetShape is contiguous in resultShape. -/// For targetShape to be contiguous in resultShape: -/// 1) The inner dimensions of targetShape and resultShape must match exactly. -/// 2) The total number of elements in resultShape must be evenly divisible by -/// the total number of elements in targetShape. +/// Checks whether extractShape is contiguous in shape. +/// For extractShape to be contiguous in shape: +/// 1) The inner dimensions of extractShape and shape must match exactly. +/// 2) The total number of elements in shape must be evenly divisible by +/// the total number of elements in extractShape. 
/// Examples: /// isContiguous([4, 4], [8, 4]) == true /// isContiguous([2, 4], [8, 4]) == true /// isContiguous([2, 2], [8, 4]) == false /// Removes leading unit dimensions to handle cases like: /// isContiguous([1, 16], [1, 32]) == true -static bool isContiguous(ArrayRef targetShape, - ArrayRef resultShape) { +static bool isContiguous(ArrayRef extractShape, + ArrayRef shape) { - if (targetShape.size() > resultShape.size()) + if (extractShape.size() > shape.size()) return false; - while (!targetShape.empty() && targetShape.front() == 1) { - targetShape = targetShape.drop_front(); + while (!extractShape.empty() && extractShape.front() == 1) { + extractShape = extractShape.drop_front(); } - while (!resultShape.empty() && resultShape.front() == 1) { - resultShape = resultShape.drop_front(); + while (!shape.empty() && shape.front() == 1) { + shape = shape.drop_front(); } - size_t rankDiff = resultShape.size() - targetShape.size(); - if (!llvm::equal(targetShape.drop_front(), - resultShape.drop_front(rankDiff + 1))) + size_t rankDiff = shape.size() - extractShape.size(); + if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) return false; - int64_t targetElements = ShapedType::getNumElements(targetShape); - int64_t resultElements = ShapedType::getNumElements(resultShape); - return resultElements % targetElements == 0; + int64_t extractElements = ShapedType::getNumElements(extractShape); + int64_t shapeElements = ShapedType::getNumElements(shape); + return shapeElements % extractElements == 0; } /// This function determines what shape to use with From 1778d99859a14cc4f36043e91c71bf0f98a72cf9 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 19 Nov 2025 18:09:42 +0000 Subject: [PATCH 6/7] Update comments --- .../Vector/Transforms/VectorUnroll.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 
42293c88c410e..b60f80534bfb6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,10 +1003,11 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; -/// Checks whether extractShape is contiguous in shape. +/// Checks whether extractShape is a contiguous slice of shape. /// For extractShape to be contiguous in shape: -/// 1) The inner dimensions of extractShape and shape must match exactly. -/// 2) The total number of elements in shape must be evenly divisible by +/// 1) All but the leading dimension of extractShape and shape must match +/// exactly. 2) The total number of elements in shape must be evenly divisible +/// by /// the total number of elements in extractShape. /// Examples: /// isContiguous([4, 4], [8, 4]) == true @@ -1037,18 +1038,18 @@ static bool isContiguous(ArrayRef extractShape, return shapeElements % extractElements == 0; } -/// This function determines what shape to use with -/// `vector.extract_strided_slice` to extract a contiguous memory region from a -/// source vector. The extraction must be contiguous and contain exactly the -/// specified number of elements. If such an extraction shape cannot be -/// determined, the function returns std::nullopt. -/// Examples: +/// Determines what shape to use with `vector.extract_strided_slice` to extract +/// a contiguous memory region from a source vector. The extraction must be +/// contiguous and contain exactly the specified number of elements. If such an +/// extraction shape cannot be determined, returns std::nullopt. 
+/// EXAMPLE 1: /// sourceShape = [16], targetElements = 8 /// Working right-to-left: /// - Take min(8, 16) = 8 from only dim → extractShape = [8], /// remaining = 8/8 = 1 /// Result: [8] /// +/// EXAMPLE 2: /// sourceShape = [4, 4], targetElements = 8 /// Working right-to-left: /// - Take min(8, 4) = 4 from last dim → extractShape = [4], From 18ed975e3d2df8ed7e150b68ab7a2854082cabda Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 19 Nov 2025 18:54:13 +0000 Subject: [PATCH 7/7] Add test --- .../Dialect/Vector/vector-unroll-options.mlir | 18 +++++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 26 +++++++++++++++---- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 8e2caa39696cb..dec32e1c61a9b 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -557,3 +557,21 @@ func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32> // CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> { // CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32> // CHECK: return %[[SC]] : vector<6x4xf32> + + +// TargetShape is [1x16] +func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> { + %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32> + return %0 : vector<1x32xf32> +} + +// CHECK-LABEL: func @shape_cast_leading_unit_dim +// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} 
: vector<1x16xf32> into vector<1x32xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: return %[[I1]] : vector<1x32xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 0ab4e451d544d..e8ea0cc02d7f6 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -179,11 +179,27 @@ struct TestVectorUnrollingPatterns return success(isa(op)); })); populateVectorUnrollPatterns( - patterns, UnrollVectorOptions() - .setNativeShape(ArrayRef{2, 4}) - .setFilterConstraint([](Operation *op) { - return success(isa(op)); - })); + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional> { + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector{1, 16}; + } + // Default case: [2,4] for all tests. + return SmallVector{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{1, 3, 4, 2})