Skip to content

Conversation

@Men-cotton
Copy link
Contributor

Fixes a crash in ReorderCastOpsOnBroadcast by ensuring the cast result is a VectorType before applying the pattern.
A regression test has been added to mlir/test/Dialect/Vector/vector-sink.mlir.

Fixes: #126371

@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Men-cotton (Men-cotton)

Changes

Fixes a crash in ReorderCastOpsOnBroadcast by ensuring the cast result is a VectorType before applying the pattern.
A regression test has been added to mlir/test/Dialect/Vector/vector-sink.mlir.

Fixes: #126371


Full diff: https://github.com/llvm/llvm-project/pull/170985.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+2)
  • (modified) mlir/test/Dialect/Vector/vector-sink.mlir (+17)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 726da1e9a3d14..ad16b80a732b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -453,6 +453,8 @@ struct ReorderCastOpsOnBroadcast
                                 PatternRewriter &rewriter) const override {
     if (op->getNumOperands() != 1)
       return failure();
+    if (!isa<VectorType>(op->getResult(0).getType()))
+      return failure();
     auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
     if (!bcastOp)
       return failure();
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index beaba52af1841..50ff63b44901a 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -850,3 +850,20 @@ func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg
   vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
   return %0 : vector<1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_cast_non_vector_result
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
+// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>>
+/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not
+/// attempt to reorder a cast operation that produces a non-vector result type.
+/// Previously, this would crash because the pattern assumed the result was a
+/// vector type when creating the new inner broadcast.
+func.func @broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {
+  %0 = vector.broadcast %arg0 : i64 to vector<26x7xi64>
+  %1 = builtin.unrealized_conversion_cast %0 : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
+  return %1 : !llvm.array<26 x vector<7xi64>>
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thank you! LGTM % minor comments

// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG]] : i64 to vector<26x7xi64>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[BCAST]] : vector<26x7xi64> to !llvm.array<26 x vector<7xi64>>
// CHECK: return %[[CAST]] : !llvm.array<26 x vector<7xi64>>
/// This test ensures that the `ReorderCastOpsOnBroadcast` pattern does not
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests for this pattern are located below this block comment:

// [Pattern: ReorderCastOpsOnBroadcast]

Please move it accordingly. Thanks!

Comment on lines 863 to 864
/// Previously, this would crash because the pattern assumed the result was a
/// vector type when creating the new inner broadcast.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] "Previously" is a very relative term ("previously to what?"). I would just drop this sentence, it doesn't add any new info.

/// attempt to reorder a cast operation that produces a non-vector result type.
/// Previously, this would crash because the pattern assumed the result was a
/// vector type when creating the new inner broadcast.
func.func @broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please follow our naming convention (as per https://mlir.llvm.org/getting_started/TestingGuide/#test-naming-convention)

Suggested change
func.func @broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {
func.func @negative_broadcast_cast_non_vector_result(%arg0: i64) -> !llvm.array<26 x vector<7xi64>> {

@Men-cotton
Copy link
Contributor Author

@banach-space
I updated the test docs based on your reviews. Thanks!
Please review again, and merge if there are no problems since I don't have write access.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@banach-space banach-space merged commit 94ebcfd into llvm:main Dec 9, 2025
10 checks passed
@Men-cotton Men-cotton deleted the users/Men-cotton/mlir/126371 branch December 9, 2025 16:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] Crash when using --test-vector-sink-patterns

3 participants