[mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct#179883
Conversation
-- vector.outerproduct requires lhs/rhs to have same element type as the result. -- This commit adds a fix to promote lhs/rhs to have result's element type when vectorizing conv1D slice to vector.outerproduct. -- This is along the similar lines of what happens when we are vectorizing conv1D slice to vector.contract - the corresponding CHECK line was incorrect and this commit fixes that too. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Abhishek Varma (Abhishek-Varma) Changes-- vector.outerproduct requires lhs/rhs to have same element type as the Signed-off-by: Abhishek Varma <abhvarma@amd.com> Full diff: https://github.com/llvm/llvm-project/pull/179883.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 24579f6aa0217..7ac759f635f87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3848,8 +3848,12 @@ struct Conv1DGenerator
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
- const Type dstType =
- cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+ // Handle both shaped as well as scalar types.
+ Type dstType;
+ if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
+ dstType = shapedType.cloneWith(std::nullopt, dstElementType);
+ else
+ dstType = dstElementType;
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
return arith::SIToFPOp::create(rewriter, loc, dstType, val);
@@ -3888,6 +3892,8 @@ struct Conv1DGenerator
// convolution.
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
+ lhs = promote(rewriter, loc, lhs, res.getType());
+ rhs = promote(rewriter, loc, rhs, res.getType());
return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
rhs, res, vector::CombiningKind::ADD);
}
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index f8781ff5452d9..97b27befd44e2 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -678,6 +678,59 @@ module attributes {transform.with_named_sequence} {
// -----
+// Test for mixed precision hanlding of 1D non-channeled convolution.
+func.func @conv1d_mixed_precision_bf16_f32(%input: tensor<5xbf16>, %filter: tensor<2xbf16>, %output: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = linalg.conv_1d ins(%input, %filter : tensor<5xbf16>, tensor<2xbf16>)
+ outs(%output : tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// CHECK: func @conv1d_mixed_precision_bf16_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<5xbf16>, %[[FILTER:.+]]: tensor<2xbf16>, %[[OUTPUT:.+]]: tensor<4xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[BF0:.+]] = arith.constant 0.000000e+00 : bf16
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[BF0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[BF0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : bf16 from vector<2xbf16>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : bf16 from vector<2xbf16>
+
+/// Extend input and filter to f32 and then perform outerproduct.
+/// kw == 0
+// CHECK: %[[V_INPUT_0_F32:.+]] = arith.extf %[[V_INPUT_0]] : vector<4xbf16> to vector<4xf32>
+// CHECK: %[[V_FILTER_0_F32:.+]] = arith.extf %[[V_FILTER_0]] : bf16 to f32
+// CHECK: %[[RES_0:.+]] = vector.outerproduct %[[V_INPUT_0_F32]], %[[V_FILTER_0_F32]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<4xf32>, f32
+/// kw == 1
+// CHECK: %[[V_INPUT_1_F32:.+]] = arith.extf %[[V_INPUT_1]] : vector<4xbf16> to vector<4xf32>
+// CHECK: %[[V_FILTER_1_F32:.+]] = arith.extf %[[V_FILTER_1]] : bf16 to f32
+// CHECK: %[[RES_1:.+]] = vector.outerproduct %[[V_INPUT_1_F32]], %[[V_FILTER_1_F32]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<4xf32>, f32
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
@@ -801,8 +854,10 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16>
-// CHECK: %[[CONT:.*]] = vector.contract
-// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
+// CHECK: %[[V_INPUT_F32:.+]] = arith.extf %[[V_INPUT_R]] : vector<1x2x3xf16> to vector<1x2x3xf32>
+// CHECK: %[[V_FILTER_F32:.+]] = arith.extf %[[V_FILTER_1]] : vector<3x2xf16> to vector<3x2xf32>
+// CHECK: %[[CONT:.+]] = vector.contract
+// CHECK-SAME: %[[V_INPUT_F32]], %[[V_FILTER_F32]], %[[V_OUTPUT_R]] : vector<1x2x3xf32>, vector<3x2xf32> into vector<1x2x2xf32>
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
module attributes {transform.with_named_sequence} {
|
| if (auto shapedType = dyn_cast<ShapedType>(val.getType())) | ||
| dstType = shapedType.cloneWith(std::nullopt, dstElementType); | ||
| else | ||
| dstType = dstElementType; |
There was a problem hiding this comment.
Why do we need this? Is there a test case for it?
There was a problem hiding this comment.
For filter we use extractConvFilterSlices - this extracts a scalar type for a non-chanelled conv op.
We didn't require this earlier because none of the non-chanelled lit tests were of mixed precision. Example: @conv1d_8_tensor which is already checked in demonstrates the scalar case :
but since the conv example doesn't have mismatch in the element types of input/filter and output, the need never arose.
This PR is therefore trying to add the case of mixed precision with the proposed fix (and the lit test to demonstrate the same).
| if (auto shapedType = dyn_cast<ShapedType>(val.getType())) | ||
| dstType = shapedType.cloneWith(std::nullopt, dstElementType); | ||
| else | ||
| dstType = dstElementType; |
llvm#179883) -- vector.outerproduct requires lhs/rhs to have same element type as the result. -- This commit adds a fix to promote lhs/rhs to have result's element type when vectorizing conv1D slice to vector.outerproduct. -- This is along the similar lines of what happens when we are vectorizing conv1D slice to vector.contract - the corresponding CHECK line was incorrect and this commit fixes that too. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
…er dim ops (#23294)" (#23383) (#23455) #23294 was reverted by #23383 to resolve #23382 But now with llvm/llvm-project#179883 merged we can reapply #23294 Refer to [this thread](#23382 (comment)) for more detail. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
…er dim ops (iree-org#23294)" (iree-org#23383) (iree-org#23455) iree-org#23294 was reverted by iree-org#23383 to resolve iree-org#23382 But now with llvm/llvm-project#179883 merged we can reapply iree-org#23294 Refer to [this thread](iree-org#23382 (comment)) for more detail. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
-- vector.outerproduct requires lhs/rhs to have same element type as the
result.
-- This commit adds a fix to promote lhs/rhs to have result's element
type when vectorizing conv1D slice to vector.outerproduct.
-- This is along the similar lines of what happens when we are
vectorizing conv1D slice to vector.contract - the corresponding
CHECK line was incorrect and this commit fixes that too.
Signed-off-by: Abhishek Varma abhvarma@amd.com