From a9a794fc8103d5546d9abf2e446293b9c447d4c5 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 5 Feb 2026 08:19:21 +0000 Subject: [PATCH] [mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct -- vector.outerproduct requires lhs/rhs to have same element type as the result. -- This commit adds a fix to promote lhs/rhs to have result's element type when vectorizing conv1D slice to vector.outerproduct. -- This is along the similar lines of what happens when we are vectorizing conv1D slice to vector.contract - the corresponding CHECK line was incorrect and this commit fixes that too. Signed-off-by: Abhishek Varma --- .../Linalg/Transforms/Vectorization.cpp | 10 +++- .../convolution-with-patterns.mlir | 59 ++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 24579f6aa0217..7ac759f635f87 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -3848,8 +3848,12 @@ struct Conv1DGenerator const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth(); const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth(); - const Type dstType = - cast(val.getType()).cloneWith(std::nullopt, dstElementType); + // Handle both shaped as well as scalar types. + Type dstType; + if (auto shapedType = dyn_cast(val.getType())) + dstType = shapedType.cloneWith(std::nullopt, dstElementType); + else + dstType = dstElementType; if (isa(srcElementType) && isa(dstElementType)) { return arith::SIToFPOp::create(rewriter, loc, dstType, val); @@ -3888,6 +3892,8 @@ struct Conv1DGenerator // convolution. Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, Value lhs, Value rhs, Value res) { + lhs = promote(rewriter, loc, lhs, res.getType()); + rhs = promote(rewriter, loc, rhs, res.getType()); return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); } diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir index f8781ff5452d9..97b27befd44e2 100644 --- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir @@ -678,6 +678,59 @@ module attributes {transform.with_named_sequence} { // ----- +// Test for mixed precision hanlding of 1D non-channeled convolution. +func.func @conv1d_mixed_precision_bf16_f32(%input: tensor<5xbf16>, %filter: tensor<2xbf16>, %output: tensor<4xf32>) -> tensor<4xf32> { + %0 = linalg.conv_1d ins(%input, %filter : tensor<5xbf16>, tensor<2xbf16>) + outs(%output : tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK: func @conv1d_mixed_precision_bf16_f32 +// CHECK-SAME: (%[[INPUT:.+]]: tensor<5xbf16>, %[[FILTER:.+]]: tensor<2xbf16>, %[[OUTPUT:.+]]: tensor<4xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[BF0:.+]] = arith.constant 0.000000e+00 : bf16 + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[BF0]] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[BF0]] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : bf16 from vector<2xbf16> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : bf16 from vector<2xbf16> + +/// Extend input and filter to f32 and then perform outerproduct. +/// kw == 0 +// CHECK: %[[V_INPUT_0_F32:.+]] = arith.extf %[[V_INPUT_0]] : vector<4xbf16> to vector<4xf32> +// CHECK: %[[V_FILTER_0_F32:.+]] = arith.extf %[[V_FILTER_0]] : bf16 to f32 +// CHECK: %[[RES_0:.+]] = vector.outerproduct %[[V_INPUT_0_F32]], %[[V_FILTER_0_F32]], %[[V_OUTPUT_R]] {kind = #vector.kind} +// CHECK-SAME: : vector<4xf32>, f32 +/// kw == 1 +// CHECK: %[[V_INPUT_1_F32:.+]] = arith.extf %[[V_INPUT_1]] : vector<4xbf16> to vector<4xf32> +// CHECK: %[[V_FILTER_1_F32:.+]] = arith.extf %[[V_FILTER_1]] : bf16 to f32 +// CHECK: %[[RES_1:.+]] = vector.outerproduct %[[V_INPUT_1_F32]], %[[V_FILTER_1_F32]], %[[RES_0]] {kind = #vector.kind} +// CHECK-SAME: : vector<4xf32>, f32 + +// Write the result back in one shot. +// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) { linalg.depthwise_conv_1d_nwc_wc {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} @@ -801,8 +854,10 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: // CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] // CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16> -// CHECK: %[[CONT:.*]] = vector.contract -// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32> +// CHECK: %[[V_INPUT_F32:.+]] = arith.extf %[[V_INPUT_R]] : vector<1x2x3xf16> to vector<1x2x3xf32> +// CHECK: %[[V_FILTER_F32:.+]] = arith.extf %[[V_FILTER_1]] : vector<3x2xf16> to vector<3x2xf32> +// CHECK: %[[CONT:.+]] = vector.contract +// CHECK-SAME: %[[V_INPUT_F32]], %[[V_FILTER_F32]], %[[V_OUTPUT_R]] : vector<1x2x3xf32>, vector<3x2xf32> into vector<1x2x2xf32> // CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] module attributes {transform.with_named_sequence} {