@nbpatel
Contributor

@nbpatel nbpatel commented Nov 12, 2025

This PR adds a pattern for unrolling shape_cast given a targetShape. It is a follow-up to #164010, which was very general and used inserts and extracts on each element (which is also what LowerVectorShapeCast.cpp does).
After more research on use cases, we (@Jianhui-Li and I) realized that the previous version in #164010 is unnecessarily generic and doesn't fit our performance needs.

Our use case requires that targetShape is contiguous in both the source and the result vector.

This pattern only applies when contiguous slices can be extracted from the source vector and inserted into the result vector such that each slice remains in vector form with targetShape (and does not decompose to scalars). In these cases, the unrolling proceeds as:

vector.extract_strided_slice -> vector.shape_cast (on the unrolled slice) -> vector.insert_strided_slice
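
To make the three-step pipeline above concrete, here is a minimal standalone C++ sketch of the index arithmetic behind it. It uses plain std::vector rather than MLIR types and is modeled on calculateSourceExtractShape and calculateSourceOffsets from the diff. The names Shape, rowMajorStrides, sourceExtractShape and sourceOffsetsFor are illustrative assumptions, not MLIR APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Row-major strides for a shape, e.g. {8, 2} -> {2, 1}.
Shape rowMajorStrides(const Shape &shape) {
  Shape strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Greedily take elements from the innermost source dimension outward so the
// extracted slice stays contiguous. Returns std::nullopt when targetElements
// cannot be covered evenly.
std::optional<Shape> sourceExtractShape(const Shape &sourceShape,
                                        int64_t targetElements) {
  Shape extract(sourceShape.size(), 1); // untouched outer dims stay 1
  int64_t remaining = targetElements;
  for (int i = static_cast<int>(sourceShape.size()) - 1;
       i >= 0 && remaining > 1; --i) {
    int64_t take = std::min(remaining, sourceShape[i]);
    if (remaining % take != 0)
      return std::nullopt; // not evenly divisible
    extract[i] = take;
    remaining /= take;
  }
  if (remaining != 1)
    return std::nullopt; // source too small to hold targetElements
  return extract;
}

// Map the offset of a result tile to the offset of the matching source slice
// by going through the shared linear (row-major) element position.
Shape sourceOffsetsFor(const Shape &resultOffsets, const Shape &sourceShape,
                       const Shape &resultShape) {
  assert(resultOffsets.size() == resultShape.size());
  Shape resultStrides = rowMajorStrides(resultShape);
  Shape sourceStrides = rowMajorStrides(sourceShape);
  int64_t linear = 0;
  for (size_t i = 0; i < resultOffsets.size(); ++i)
    linear += resultOffsets[i] * resultStrides[i];
  Shape sourceOffsets(sourceStrides.size(), 0);
  for (size_t i = 0; i < sourceStrides.size(); ++i) {
    sourceOffsets[i] = linear / sourceStrides[i];
    linear %= sourceStrides[i];
  }
  return sourceOffsets;
}
```

For vector<8x2xf32> to vector<4x4xf32> with a target shape of 2x4, sourceExtractShape({8, 2}, 8) gives {4, 2}, and the result tile at offset {2, 0} maps to source offset {4, 0}. Both values agree with the shape_cast_2D test in the diff.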

@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Full diff: https://github.com/llvm/llvm-project/pull/167738.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+168-2)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+34)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..6ad179349f90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]>,
     Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba02100a..4cac137478fab 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 LogicalResult ShapeCastOp::verify() {
 
   VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..a4830809aaac8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+                                ArrayRef<int64_t> resultShape) {
+  if (targetShape.size() > resultShape.size()) {
+    return false;
+  }
+
+  size_t rankDiff = resultShape.size() - targetShape.size();
+  // Inner dimensions must match exactly & total resultElements should be
+  // evenly divisible by targetElements.
+  for (size_t i = 1; i < targetShape.size(); ++i) {
+    if (targetShape[i] != resultShape[rankDiff + i]) {
+      return false;
+    }
+  }
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+  if (resultElements % targetElements != 0) {
+    return false;
+  }
+  return true;
+}
+
+// Calculate the shape to extract from source
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+                            int64_t targetElements) {
+  SmallVector<int64_t> extractShape;
+  int64_t remainingElements = targetElements;
+
+  // Build extract shape from innermost dimension outward to ensure contiguity
+  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+    extractShape.insert(extractShape.begin(), takeFromDim);
+
+    if (remainingElements % takeFromDim != 0) {
+      return std::nullopt; // Not evenly divisible
+    }
+    remainingElements /= takeFromDim;
+  }
+
+  // Fill remaining dimensions with 1
+  while (extractShape.size() < sourceShape.size()) {
+    extractShape.insert(extractShape.begin(), 1);
+  }
+
+  if (ShapedType::getNumElements(extractShape) != targetElements) {
+    return std::nullopt;
+  }
+
+  return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+                       ArrayRef<int64_t> sourceStrides,
+                       ArrayRef<int64_t> resultStrides) {
+  // Convert result offsets to linear position
+  int64_t linearIndex = linearize(resultOffsets, resultStrides);
+  // Convert linear position to source offsets
+  SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
+  return sourceOffsets;
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (and not decompose to scalars). In these cases, the
+/// unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice
+///
+/// Example:
+///   Given a shape cast operation:
+///     %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///     %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+///     %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///     %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+///     %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+
+    if (!isContiguousExtract(*targetShape, resultShape)) {
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         "Only supports cases where contiguous "
+                                         "extraction is possible");
+    }
+
+    int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+    // Calculate the shape to extract from source
+    auto extractShape =
+        calculateSourceExtractShape(sourceShape, targetElements);
+    if (!extractShape) {
+      return rewriter.notifyMatchFailure(
+          shapeCastOp,
+          "cannot extract target number of elements contiguously from source");
+    }
+
+    Location loc = shapeCastOp.getLoc();
+
+    // Create result vector initialized to zero
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+
+    VectorType targetType =
+        VectorType::get(*targetShape, sourceType.getElementType());
+
+    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+    SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
+    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      SmallVector<int64_t> sourceOffsets =
+          calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
+      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+          extractStrides);
+      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, targetType, sourceChunk);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, targetChunk, result, resultOffsets, insertStrides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..c94a502fa3654 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+
+func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
+  %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
+  return %0 : vector<2x2x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK:   return %[[I1]] : vector<2x2x4xf32>
+
+
+func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK:   return %[[I1]] : vector<4x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..0ab4e451d544d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
                                      .setFilterConstraint([](Operation *op) {
                                        return success(isa<vector::StepOp>(op));
                                      }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2, 4})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::ShapeCastOp>(op));
+                      }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

@nbpatel
Contributor Author

nbpatel commented Nov 12, 2025

@Jianhui-Li @silee2 please take a look as well

Contributor

@charithaintc charithaintc left a comment

Some initial comments.

@nbpatel
Contributor Author

nbpatel commented Nov 18, 2025

@newling can you take a look as well?

Contributor

@banach-space banach-space left a comment

Thanks!

Compared to the previous PR, this immediately feels like "unrolling", nice! Given how complex this is, there should be more tests. In particular, negative tests that demonstrate when the new pattern should bail out.
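
As one possible starting point for such negative tests, the sketch below mirrors the result-side check isContiguousExtract as it appears in the diff in this thread, in plain C++ without MLIR types, and exercises shapes for which it returns false. The names Shape, numElements and isContiguousExtractSketch are hypothetical, and later revisions of the PR may use a different check.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Shape = std::vector<int64_t>;

// Product of all dimensions (1 for an empty shape).
int64_t numElements(const Shape &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Modeled on isContiguousExtract from the diff: every target dimension except
// the outermost must equal the corresponding trailing result dimension, and
// the result element count must be a multiple of the target element count.
bool isContiguousExtractSketch(const Shape &targetShape,
                               const Shape &resultShape) {
  if (targetShape.size() > resultShape.size())
    return false; // target rank exceeds result rank
  size_t rankDiff = resultShape.size() - targetShape.size();
  for (size_t i = 1; i < targetShape.size(); ++i)
    if (targetShape[i] != resultShape[rankDiff + i])
      return false; // inner dimension mismatch
  return numElements(resultShape) % numElements(targetShape) == 0;
}
```

Under this check, a 2x4 target is accepted for 4x4 and 2x2x4 results. It is rejected for a 4x8 result (inner dimension mismatch), a 3x4 target against a 4x4 result is rejected (16 is not a multiple of 12), and a 2x2x4 target against a 4x4 result is rejected on rank.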

@nbpatel
Contributor Author

nbpatel commented Nov 18, 2025

@Jianhui-Li @charithaintc @banach-space @kuhar Thanks for the feedback. I have addressed all of it; please take another look and let me know if it looks good.

@github-actions

github-actions bot commented Nov 18, 2025

🐧 Linux x64 Test Results

  • 7101 tests passed
  • 594 tests skipped

Contributor

@charithaintc charithaintc left a comment

LGTM. Thanks.

Contributor

@Jianhui-Li Jianhui-Li left a comment

LGTM

Contributor

@banach-space banach-space left a comment

LGTM % nits, thank you!

Comment on lines +1015 to +1016
/// Removes leading unit dimensions to handle cases like:
/// isContiguous([1, 16], [1, 32]) == true
Contributor

This is quite a corner case, please add a test.
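
The PR's actual `isContiguous` is not shown in this thread, so the following is only a minimal sketch of one plausible reading of the quoted comment. It assumes row-major layout, strips leading unit dimensions from both shapes, right-aligns the slice against the full shape, and requires every slice dimension except the outermost to span the full extent. The names `dropLeadingUnitDims` and `isContiguousSketch`, and the divisibility requirement on the outermost dimension, are assumptions made for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not the PR's code): drop leading dimensions of size 1,
// always keeping at least one dimension of a non-empty shape.
static std::vector<int64_t>
dropLeadingUnitDims(const std::vector<int64_t> &shape) {
  std::size_t i = 0;
  while (i + 1 < shape.size() && shape[i] == 1)
    ++i;
  return std::vector<int64_t>(shape.begin() + static_cast<std::ptrdiff_t>(i),
                              shape.end());
}

// Hypothetical check (not the PR's code): can a slice of shape `slice` be
// taken from a row-major vector of shape `full` as one contiguous run?
// Assumes all dimensions are positive.
static bool isContiguousSketch(std::vector<int64_t> slice,
                               std::vector<int64_t> full) {
  slice = dropLeadingUnitDims(slice);
  full = dropLeadingUnitDims(full);
  if (slice.size() > full.size())
    return false;
  // Right-align the slice against the trailing dimensions of `full`.
  const std::size_t offset = full.size() - slice.size();
  // Every slice dimension except the outermost must cover the full extent.
  for (std::size_t i = 1; i < slice.size(); ++i)
    if (slice[i] != full[offset + i])
      return false;
  // Assumption: the outermost slice dimension must tile the full dimension.
  return slice.empty() || full[offset] % slice[0] == 0;
}
```

Under these assumptions, `isContiguousSketch({1, 16}, {1, 32})` holds because both shapes reduce to one dimension, while `isContiguousSketch({2, 16}, {4, 32})` does not, since each row of the slice covers only half of a source row. The second case is the kind of input a negative test could exercise.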

Comment on lines +1086 to +1089
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                       ArrayRef<int64_t> sourceShape,
                       ArrayRef<int64_t> resultShape) {
Contributor

Shouldn't this assert that the number of elements in sourceShape and resultShape are identical?

Contributor Author

wouldn't shape_cast op verifier take care of that?

Contributor

> wouldn't shape_cast op verifier take care of that?

Of making sure that no invalid inputs are ever passed to this method? I doubt that ;-)

Contributor Author

@nbpatel nbpatel Nov 19, 2025

I meant if the shapeCast is not maintaining the semantics of NumElements(src) == NumElements(dst) how is it even a valid instruction? I tried it locally and it does fail in the verifier here

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L6258

Contributor

> I meant if the shapeCast is not maintaining the semantics of NumElements(src) == NumElements(dst) how is it even a valid instruction?

shapeCast verifier will indeed maintain that, but only for shapeCast Ops. However, how do you make sure that the inputs used in this method always come from shapeCast? Perhaps I am missing something, but what is stopping anyone/anything from using this method with some random arrays that don't come from shapeCast?
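
To make the point about guarding the helper concrete, here is a minimal sketch of a free function that checks its own preconditions instead of relying on the op verifier. The body of `calculateSourceOffsets` is not quoted in this thread, so the linearize-then-delinearize logic below is an assumption about what such a helper might do. The name `sourceOffsetsSketch` and the use of `std::vector` in place of `ArrayRef`/`SmallVector` are illustrative choices that keep the example standalone.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Product of all dimensions; 1 for an empty shape.
static int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical stand-in for the helper under review (not the PR's code):
// map row-major offsets in the result shape to offsets in the source shape.
static std::vector<int64_t>
sourceOffsetsSketch(const std::vector<int64_t> &resultOffsets,
                    const std::vector<int64_t> &sourceShape,
                    const std::vector<int64_t> &resultShape) {
  // Preconditions are checked here because callers need not be shape_cast ops.
  assert(resultOffsets.size() == resultShape.size() &&
         "one offset per result dimension expected");
  assert(numElements(sourceShape) == numElements(resultShape) &&
         "source and result must hold the same number of elements");

  // Linearize the result offsets (row-major).
  int64_t linear = 0;
  for (std::size_t i = 0; i < resultShape.size(); ++i)
    linear = linear * resultShape[i] + resultOffsets[i];

  // Delinearize into the source shape, innermost dimension first.
  std::vector<int64_t> sourceOffsets(sourceShape.size(), 0);
  for (std::size_t i = sourceShape.size(); i-- > 0;) {
    sourceOffsets[i] = linear % sourceShape[i];
    linear /= sourceShape[i];
  }
  return sourceOffsets;
}
```

With the assertions in place, a call such as `sourceOffsetsSketch({3, 4}, {2, 16}, {4, 8})` is accepted because both shapes hold 32 elements, whereas mismatched shapes trip the assertion in debug builds regardless of where the arrays came from.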

@nbpatel nbpatel merged commit af73aea into llvm:main Nov 20, 2025
10 checks passed
aadeshps-mcw pushed a commit to aadeshps-mcw/llvm-project that referenced this pull request Nov 26, 2025
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Nov 26, 2025

6 participants