Skip to content

Conversation

amd-eochoalo
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/157740.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+45-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final
   }
 };
 
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+///     %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+///     %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+///     %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorToElements(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType vecType = toElementsOp.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "the rank is already less than or equal to 1");
+
+    assert(vecType.getNumScalableDims() == 0 &&
+           "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(
+        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+        toElementsOp, toElementsOp.getResultTypes(), shapeCast);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
   %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
   return %1 : vector<2x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}

@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/157740.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+45-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde6311fa809..54eb182a9680f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,49 @@ struct LinearizeVectorFromElements final
   }
 };
 
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the result type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+///     %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+///     %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+///     %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+    : public OpConversionPattern<vector::ToElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorToElements(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType vecType = toElementsOp.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          toElementsOp, "the rank is already less than or equal to 1");
+
+    assert(vecType.getNumScalableDims() == 0 &&
+           "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(
+        rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(
+        toElementsOp, toElementsOp.getResultTypes(), shapeCast);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +933,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore, LinearizeVectorFromElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0698b33..fe697c8b9c057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
   %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
   return %1 : vector<2x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}

: OpConversionPattern(typeConverter, context, benefit) {}

LogicalResult
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to use the adaptor, but I noticed that in the pattern above adaptor is not used. To continue the style of the pattern above I also did not use it here.

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only minor comments. LGTM; but please give others some time to review

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@amd-eochoalo amd-eochoalo merged commit b812e3d into llvm:main Sep 11, 2025
9 checks passed
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
amd-eochoalo added a commit to amd-eochoalo/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
kuhar pushed a commit to iree-org/iree that referenced this pull request Sep 12, 2025
After:

* llvm/llvm-project#157740
* llvm/llvm-project#157142

the linearization of vector.to_elements pattern
can be changed to either the one now upstream
or to the unrolling version.

This commit changes the strategy from linearizing
to unrolling.

Signed-off-by: Erick Ochoa <erick.ochoalopez@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants