Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3848,8 +3848,12 @@ struct Conv1DGenerator

const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
const Type dstType =
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
// Handle both shaped as well as scalar types.
Type dstType;
if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
dstType = shapedType.cloneWith(std::nullopt, dstElementType);
else
dstType = dstElementType;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? Is there a test case for it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For filter we use extractConvFilterSlices - this extracts a scalar type for a non-chanelled conv op.

We didn't require this earlier because none of the non-chanelled lit tests were of mixed precision. Example: @conv1d_8_tensor which is already checked in demonstrates the scalar case :

// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : f32 from vector<4xf32>

but since the conv example doesn't have mismatch in the element types of input/filter and output, the need never arose.

This PR is therefore trying to add the case of mixed precision with the proposed fix (and the lit test to demonstrate the same).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks


if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
return arith::SIToFPOp::create(rewriter, loc, dstType, val);
Expand Down Expand Up @@ -3888,6 +3892,8 @@ struct Conv1DGenerator
// convolution.
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
lhs = promote(rewriter, loc, lhs, res.getType());
rhs = promote(rewriter, loc, rhs, res.getType());
return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
rhs, res, vector::CombiningKind::ADD);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,59 @@ module attributes {transform.with_named_sequence} {

// -----

// Test for mixed precision hanlding of 1D non-channeled convolution.
func.func @conv1d_mixed_precision_bf16_f32(%input: tensor<5xbf16>, %filter: tensor<2xbf16>, %output: tensor<4xf32>) -> tensor<4xf32> {
%0 = linalg.conv_1d ins(%input, %filter : tensor<5xbf16>, tensor<2xbf16>)
outs(%output : tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// CHECK: func @conv1d_mixed_precision_bf16_f32
// CHECK-SAME: (%[[INPUT:.+]]: tensor<5xbf16>, %[[FILTER:.+]]: tensor<2xbf16>, %[[OUTPUT:.+]]: tensor<4xf32>)

// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[BF0:.+]] = arith.constant 0.000000e+00 : bf16

/// Read the whole data in one shot.
// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[BF0]]
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[BF0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]

// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>

// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : bf16 from vector<2xbf16>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : bf16 from vector<2xbf16>

/// Extend input and filter to f32 and then perform outerproduct.
/// kw == 0
// CHECK: %[[V_INPUT_0_F32:.+]] = arith.extf %[[V_INPUT_0]] : vector<4xbf16> to vector<4xf32>
// CHECK: %[[V_FILTER_0_F32:.+]] = arith.extf %[[V_FILTER_0]] : bf16 to f32
// CHECK: %[[RES_0:.+]] = vector.outerproduct %[[V_INPUT_0_F32]], %[[V_FILTER_0_F32]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
// CHECK-SAME: : vector<4xf32>, f32
/// kw == 1
// CHECK: %[[V_INPUT_1_F32:.+]] = arith.extf %[[V_INPUT_1]] : vector<4xbf16> to vector<4xf32>
// CHECK: %[[V_FILTER_1_F32:.+]] = arith.extf %[[V_FILTER_1]] : bf16 to f32
// CHECK: %[[RES_1:.+]] = vector.outerproduct %[[V_INPUT_1_F32]], %[[V_FILTER_1_F32]], %[[RES_0]] {kind = #vector.kind<add>}
// CHECK-SAME: : vector<4xf32>, f32

// Write the result back in one shot.
// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
Expand Down Expand Up @@ -801,8 +854,10 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16>
// CHECK: %[[CONT:.*]] = vector.contract
// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
// CHECK: %[[V_INPUT_F32:.+]] = arith.extf %[[V_INPUT_R]] : vector<1x2x3xf16> to vector<1x2x3xf32>
// CHECK: %[[V_FILTER_F32:.+]] = arith.extf %[[V_FILTER_1]] : vector<3x2xf16> to vector<3x2xf32>
// CHECK: %[[CONT:.+]] = vector.contract
// CHECK-SAME: %[[V_INPUT_F32]], %[[V_FILTER_F32]], %[[V_OUTPUT_R]] : vector<1x2x3xf32>, vector<3x2xf32> into vector<1x2x2xf32>
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

module attributes {transform.with_named_sequence} {
Expand Down