Skip to content

[mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct#179883

Merged
Abhishek-Varma merged 1 commit into
llvm:mainfrom
Abhishek-Varma:fix_vectorization_promotion
Feb 7, 2026
Merged

[mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct#179883
Abhishek-Varma merged 1 commit into
llvm:mainfrom
Abhishek-Varma:fix_vectorization_promotion

Conversation

@Abhishek-Varma
Copy link
Copy Markdown
Contributor

-- vector.outerproduct requires lhs/rhs to have same element type as the
result.
-- This commit adds a fix to promote lhs/rhs to have result's element
type when vectorizing conv1D slice to vector.outerproduct.
-- This is along the similar lines of what happens when we are
vectorizing conv1D slice to vector.contract - the corresponding
CHECK line was incorrect and this commit fixes that too.

Signed-off-by: Abhishek Varma abhvarma@amd.com

-- vector.outerproduct requires lhs/rhs to have same element type as the
   result.
-- This commit adds a fix to promote lhs/rhs to have result's element
   type when vectorizing conv1D slice to vector.outerproduct.
-- This is along the similar lines of what happens when we are
   vectorizing conv1D slice to vector.contract - the corresponding
   CHECK line was incorrect and this commit fixes that too.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 5, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- vector.outerproduct requires lhs/rhs to have same element type as the
result.
-- This commit adds a fix to promote lhs/rhs to have result's element
type when vectorizing conv1D slice to vector.outerproduct.
-- This is along the similar lines of what happens when we are
vectorizing conv1D slice to vector.contract - the corresponding
CHECK line was incorrect and this commit fixes that too.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>


Full diff: https://github.com/llvm/llvm-project/pull/179883.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+8-2)
  • (modified) mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir (+57-2)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 24579f6aa0217..7ac759f635f87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3848,8 +3848,12 @@ struct Conv1DGenerator
 
     const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
     const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
-    const Type dstType =
-        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+    // Handle both shaped as well as scalar types.
+    Type dstType;
+    if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
+      dstType = shapedType.cloneWith(std::nullopt, dstElementType);
+    else
+      dstType = dstElementType;
 
     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
       return arith::SIToFPOp::create(rewriter, loc, dstType, val);
@@ -3888,6 +3892,8 @@ struct Conv1DGenerator
   // convolution.
   Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                   Value lhs, Value rhs, Value res) {
+    lhs = promote(rewriter, loc, lhs, res.getType());
+    rhs = promote(rewriter, loc, rhs, res.getType());
     return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                           rhs, res, vector::CombiningKind::ADD);
   }
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index f8781ff5452d9..97b27befd44e2 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -678,6 +678,59 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Test for mixed precision hanlding of 1D non-channeled convolution.
+func.func @conv1d_mixed_precision_bf16_f32(%input: tensor<5xbf16>, %filter: tensor<2xbf16>, %output: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = linalg.conv_1d ins(%input, %filter : tensor<5xbf16>, tensor<2xbf16>)
+                     outs(%output : tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+//      CHECK: func @conv1d_mixed_precision_bf16_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<5xbf16>, %[[FILTER:.+]]: tensor<2xbf16>, %[[OUTPUT:.+]]: tensor<4xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[BF0:.+]] = arith.constant 0.000000e+00 : bf16
+
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[BF0]]
+//  CHECK-DAG:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[BF0]]
+//  CHECK-DAG:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [1], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+
+//      CHECK:   %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : bf16 from vector<2xbf16>
+//      CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : bf16 from vector<2xbf16>
+
+/// Extend input and filter to f32 and then perform outerproduct.
+/// kw == 0
+//      CHECK:   %[[V_INPUT_0_F32:.+]] = arith.extf %[[V_INPUT_0]] : vector<4xbf16> to vector<4xf32>
+//      CHECK:   %[[V_FILTER_0_F32:.+]] = arith.extf %[[V_FILTER_0]] : bf16 to f32
+//      CHECK:   %[[RES_0:.+]] = vector.outerproduct %[[V_INPUT_0_F32]], %[[V_FILTER_0_F32]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<4xf32>, f32
+/// kw == 1
+//      CHECK:   %[[V_INPUT_1_F32:.+]] = arith.extf %[[V_INPUT_1]] : vector<4xbf16> to vector<4xf32>
+//      CHECK:   %[[V_FILTER_1_F32:.+]] = arith.extf %[[V_FILTER_1]] : bf16 to f32
+//      CHECK:   %[[RES_1:.+]] = vector.outerproduct %[[V_INPUT_1_F32]], %[[V_FILTER_1_F32]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME:     : vector<4xf32>, f32
+
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
@@ -801,8 +854,10 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
 //      CHECK:   %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
 //      CHECK:   %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 //      CHECK:   %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16>
-//      CHECK:   %[[CONT:.*]] = vector.contract
-//        {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
+//      CHECK:   %[[V_INPUT_F32:.+]] = arith.extf %[[V_INPUT_R]] : vector<1x2x3xf16> to vector<1x2x3xf32>
+//      CHECK:   %[[V_FILTER_F32:.+]] = arith.extf %[[V_FILTER_1]] : vector<3x2xf16> to vector<3x2xf32>
+//      CHECK:   %[[CONT:.+]] = vector.contract
+// CHECK-SAME:     %[[V_INPUT_F32]], %[[V_FILTER_F32]], %[[V_OUTPUT_R]] : vector<1x2x3xf32>, vector<3x2xf32> into vector<1x2x2xf32>
 //      CHECK:   vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 module attributes {transform.with_named_sequence} {

if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
dstType = shapedType.cloneWith(std::nullopt, dstElementType);
else
dstType = dstElementType;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? Is there a test case for it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For filter we use extractConvFilterSlices - this extracts a scalar type for a non-chanelled conv op.

We didn't require this earlier because none of the non-chanelled lit tests were of mixed precision. Example: @conv1d_8_tensor which is already checked in demonstrates the scalar case :

// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : f32 from vector<4xf32>

but since the conv example doesn't have mismatch in the element types of input/filter and output, the need never arose.

This PR is therefore trying to add the case of mixed precision with the proposed fix (and the lit test to demonstrate the same).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks

if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
dstType = shapedType.cloneWith(std::nullopt, dstElementType);
else
dstType = dstElementType;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks

@Abhishek-Varma Abhishek-Varma merged commit 060f325 into llvm:main Feb 7, 2026
13 checks passed
rishabhmadan19 pushed a commit to rishabhmadan19/llvm-project that referenced this pull request Feb 9, 2026
llvm#179883)

-- vector.outerproduct requires lhs/rhs to have same element type as the
   result.
-- This commit adds a fix to promote lhs/rhs to have result's element
   type when vectorizing conv1D slice to vector.outerproduct.
-- This is along the similar lines of what happens when we are
   vectorizing conv1D slice to vector.contract - the corresponding
   CHECK line was incorrect and this commit fixes that too.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
hanhanW pushed a commit to iree-org/iree that referenced this pull request Feb 11, 2026
…er dim ops (#23294)" (#23383) (#23455)

#23294 was reverted by
#23383 to resolve
#23382

But now with llvm/llvm-project#179883 merged we
can reapply #23294

Refer to [this
thread](#23382 (comment))
for more detail.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
kimm240 pushed a commit to kimm240/iree that referenced this pull request May 8, 2026
…er dim ops (iree-org#23294)" (iree-org#23383) (iree-org#23455)

iree-org#23294 was reverted by
iree-org#23383 to resolve
iree-org#23382

But now with llvm/llvm-project#179883 merged we
can reapply iree-org#23294

Refer to [this
thread](iree-org#23382 (comment))
for more detail.

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants