Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :

def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}

/// Returns the result vector's shape as the native shape for unrolling.
std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
}

LogicalResult ShapeCastOp::verify() {

VectorType sourceType = getSourceVectorType();
Expand Down
193 changes: 191 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};

/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
///    exactly.
/// 2) The total number of elements in shape must be evenly divisible by
///    the total number of elements in extractShape.
/// Examples:
/// isContiguous([4, 4], [8, 4]) == true
/// isContiguous([2, 4], [8, 4]) == true
/// isContiguous([2, 2], [8, 4]) == false
/// Removes leading unit dimensions to handle cases like:
/// isContiguous([1, 16], [1, 32]) == true
Comment on lines +1016 to +1017
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite a corner case, please add a test.

static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {
  // Strip leading unit dimensions from both shapes so that, e.g.,
  // isContiguous([1, 16], [1, 32]) == true.
  while (!extractShape.empty() && extractShape.front() == 1)
    extractShape = extractShape.drop_front();

  while (!shape.empty() && shape.front() == 1)
    shape = shape.drop_front();

  // Compare ranks only AFTER stripping unit dims: dropping a different number
  // of leading 1s from each shape can leave extractShape with a higher rank
  // than shape (e.g. [2, 2] vs [1, 4]), and the size_t subtraction below
  // would underflow.
  if (extractShape.size() > shape.size())
    return false;

  // An all-unit (now empty) extract shape selects a single element, which is
  // trivially contiguous; also guards the drop_front below against an empty
  // ArrayRef assertion.
  if (extractShape.empty())
    return true;

  // All but the leading extracted dimension must match the trailing
  // dimensions of shape exactly.
  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;

  // The extraction must tile shape evenly.
  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}

/// Determines what shape to use with `vector.extract_strided_slice` to extract
/// a contiguous memory region from a source vector. The extraction must be
/// contiguous and contain exactly the specified number of elements. If such an
/// extraction shape cannot be determined, returns std::nullopt.
/// EXAMPLE 1:
/// sourceShape = [16], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
/// remaining = 8/8 = 1
/// Result: [8]
///
/// EXAMPLE 2:
/// sourceShape = [4, 4], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
/// remaining = 8/4 = 2
/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
/// remaining = 2/2 = 1
/// Result: [2, 4]
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  // Accumulate dimension sizes innermost-first (so each partial take stays
  // contiguous), then reverse into source dimension order at the end.
  SmallVector<int64_t> reversedShape;
  int64_t elementsLeft = targetElements;

  for (int64_t dimSize : llvm::reverse(sourceShape)) {
    if (elementsLeft <= 1)
      break;
    int64_t take = std::min(elementsLeft, dimSize);
    reversedShape.push_back(take);
    if (elementsLeft % take != 0)
      return std::nullopt; // Not evenly divisible.
    elementsLeft /= take;
  }

  // Pad untouched (outer) dimensions with 1.
  reversedShape.resize(sourceShape.size(), 1);
  SmallVector<int64_t> extractShape =
      llvm::to_vector(llvm::reverse(reversedShape));

  // The accumulated shape must account for exactly targetElements.
  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;

  return extractShape;
}

/// Converts result offsets to the corresponding source offsets via the shared
/// linearized element position. `sourceShape` and `resultShape` must describe
/// the same number of elements; the shape_cast verifier guarantees this for
/// op operands/results, but this helper can be called with arbitrary shapes,
/// so assert it here (as requested in review).
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                       ArrayRef<int64_t> sourceShape,
                       ArrayRef<int64_t> resultShape) {
  assert(ShapedType::getNumElements(sourceShape) ==
             ShapedType::getNumElements(resultShape) &&
         "source and result shapes must have the same number of elements");
  // Convert result offsets to a linear position.
  int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
  // Convert the linear position back to source offsets.
  return delinearize(linearIndex, computeStrides(sourceShape));
}

/// This pattern unrolls `vector.shape_cast` operations according to the
/// provided target unroll shape. It unrolls a large shape cast into smaller
/// shape casts by extracting contiguous slices from the source vector, casting
/// each slice to the target shape, and assembling the result by inserting each
/// computed segment into the appropriate offset of the result vector.
///
/// This pattern only applies when contiguous slices can be extracted from the
/// source vector and inserted into the result vector such that each slice
/// remains a valid vector (and not decompose to scalars). In these cases, the
/// unrolling proceeds as:
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
/// vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x4>, the pattern produces:
///
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Ask the unroll options for the per-iteration target shape; bail out if
    // this op is not selected for unrolling.
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();

    VectorType sourceType = shapeCastOp.getSourceVectorType();
    VectorType resultType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    ArrayRef<int64_t> resultShape = resultType.getShape();

    // Each unrolled chunk is inserted with insert_strided_slice, so the
    // target shape must cover a contiguous region of the result vector.
    if (!isContiguous(*targetShape, resultShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t targetElements = ShapedType::getNumElements(*targetShape);

    // Calculate the shape to extract from source. This may fail when the
    // target element count cannot be carved contiguously out of the source
    // dims (e.g. extracting 8 elements from vector<8x3>).
    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(sourceShape, targetElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();

    // Create result vector initialized to zero; unrolled chunks are inserted
    // into it one offset at a time.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));

    VectorType targetType =
        VectorType::get(*targetShape, sourceType.getElementType());

    // Unit strides: all slices are dense extractions/insertions.
    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    // Walk all target-shaped tiles of the result; for each, locate the
    // matching contiguous chunk in the source via the shared linear index,
    // shape_cast it to the target shape, and insert it at the tile offset.
    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      SmallVector<int64_t> sourceOffsets =
          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
          extractStrides);
      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, targetType, sourceChunk);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, targetChunk, result, resultOffsets, insertStrides);
    }

    rewriter.replaceOp(shapeCastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
Expand All @@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
UnrollToElements, UnrollStepPattern>(patterns.getContext(),
options, benefit);
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,82 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return


func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
return %0 : vector<2x2x4xf32>
}

// CHECK-LABEL: func @shape_cast_1D
// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
// CHECK: return %[[I1]] : vector<2x2x4xf32>


func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
%0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
return %0 : vector<4x4xf32>
}

// CHECK-LABEL: func @shape_cast_2D
// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
// CHECK: return %[[I1]] : vector<4x4xf32>


// This is a negative test case to ensure that such shape casts are not unrolled
// because the targetShape (2x4) is not contiguous in result vector
func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
%0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
return %0 : vector<8x8xf32>
}

// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
// CHECK: return %[[SC]] : vector<8x8xf32>


// This is a negative test case to ensure that such shape casts are not unrolled
// because it cannot determine the extractShape from source vector (8x3)
// to extract contiguous targetShape (2x4)
func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
%0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
return %0 : vector<6x4xf32>
}

// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
// CHECK: return %[[SC]] : vector<6x4xf32>


// TargetShape is [1x16]
func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
%0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32>
return %0 : vector<1x32xf32>
}

// CHECK-LABEL: func @shape_cast_leading_unit_dim
// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
// CHECK: return %[[I1]] : vector<1x32xf32>
22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,28 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()
.setNativeShapeFn(
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
if (!shapeCast)
return std::nullopt;

auto resultShape = shapeCast.getResultVectorType().getShape();
// Special case with leading unit dims and different inner dim
// for result and target shape.
if (resultShape.size() == 2 && resultShape[0] == 1 &&
resultShape[1] == 32) {
return SmallVector<int64_t>{1, 16};
}
// Default case: [2,4] for all tests.
return SmallVector<int64_t>{2, 4};
})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ShapeCastOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
Expand Down