error: 'vector.outerproduct' op failed to verify that lhs operand and result have same element type
%13 = linalg.generic {indexing_maps = [#map4, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%1, %0 : tensor<1x1x32x32xbf16>, tensor<1x1x29x29xbf16>) outs(%12 : tensor<1x1x4x4xf32>) {
// NOTE(review): excerpt presumably from before #23294 -- the tiled reduction stays a
// linalg.generic whose body explicitly extends both bf16 operands to f32 (arith.extf)
// before multiplying and accumulating into the f32 output. Confirm against the issue's
// "Before the change" section.
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%extracted_slice_1, %extracted_slice_0 : tensor<1x29xbf16>, tensor<1x29xbf16>) outs(%arg7 : tensor<1x1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 1, 29], thread = [1, 1, 0, 0], workgroup = [8, 8, 0, 0]}>} {
^bb0(%in: bf16, %in_2: bf16, %out: f32):
%17 = arith.extf %in : bf16 to f32
%18 = arith.extf %in_2 : bf16 to f32
%19 = arith.mulf %17, %18 : f32
%20 = arith.addf %out, %19 : f32
linalg.yield %20 : f32
} -> tensor<1x1xf32>
// NOTE(review): excerpt presumably from after #23294 -- the same computation appears as
// a named linalg.conv_1d with bf16 inputs and an f32 output. Vectorizing this form is what
// emits a vector.outerproduct whose bf16 LHS does not match its f32 result type.
%16 = linalg.conv_1d ins(%extracted_slice_2, %extracted_slice_3 : tensor<29xbf16>, tensor<29xbf16>) outs(%extracted_slice_4 : tensor<1xf32>) -> tensor<1xf32>
When a mixed-precision `linalg.conv_1d` (bf16 inputs, f32 output) is vectorized, the vectorizer incorrectly generates a `vector.outerproduct` with a bf16 LHS and an f32 result. This violates the op's requirement that the LHS operand and the result have the same element type.
// Affine map aliases referenced by the linalg.generic ops in @module below.
// #map: 4-D identity map, used for both operands of the elementwise generics.
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// #map1 / #map2 / #map3: first input, second input and output maps of the 7-D
// generic that produces the 1x1x32x32 result. d4-d6 are its reduction dims; the
// (d2 + d5, d3 + d6) access in #map1 is a sliding window over the last two dims.
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
// #map4 / #map5: first and second input maps of the 7-D generic that produces the
// 1x1x4x4 result (its output also uses #map3). #map4 again slides a window over
// the last two dims of its operand.
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)>
#map5 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>
// Reproducer module: an async entry point that computes two bf16 results from three
// bf16 buffer views, plus a synchronous wrapper that calls it and blocks on its fence.
module @module {
// Async entry point. It waits on %arg3 before importing the three inputs and signals
// %arg4 (via hal.tensor.barrier) once both results are ready.
util.func public @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<1x1x29x29xbf16>
%1 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<1x1x32x32xbf16>
%2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1x1x4x4xbf16>
%3 = tensor.empty() : tensor<1x1x4x4xbf16>
%4 = linalg.fill ins(%cst : bf16) outs(%3 : tensor<1x1x4x4xbf16>) -> tensor<1x1x4x4xbf16>
// %5: reverses the last two dims of %2. Element (i, j) reads %2 at (3 - i, 3 - j).
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1x1x4x4xbf16>) outs(%4 : tensor<1x1x4x4xbf16>) {
^bb0(%in: bf16, %out: bf16):
%18 = linalg.index 2 : index
%19 = linalg.index 3 : index
%20 = arith.subi %c3, %18 : index
%21 = arith.subi %c3, %19 : index
%extracted = tensor.extract %2[%c0, %c0, %20, %21] : tensor<1x1x4x4xbf16>
linalg.yield %extracted : bf16
} -> tensor<1x1x4x4xbf16>
// Zero-pads the last two dims of %0 by 3 on each side (29x29 -> 35x35).
%padded = tensor.pad %0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index):
tensor.yield %cst : bf16
} : tensor<1x1x29x29xbf16> to tensor<1x1x35x35xbf16>
%6 = tensor.empty() : tensor<1x1x32x32xf32>
%7 = linalg.fill ins(%cst_0 : f32) outs(%6 : tensor<1x1x32x32xf32>) -> tensor<1x1x32x32xf32>
// %8: windowed multiply-accumulate of %padded with the reversed %5. The bf16
// operands are extended to f32 and accumulated into a zero-filled f32 tensor
// (35 - 4 + 1 = 32 outputs per spatial dim).
%8 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%padded, %5 : tensor<1x1x35x35xbf16>, tensor<1x1x4x4xbf16>) outs(%7 : tensor<1x1x32x32xf32>) {
^bb0(%in: bf16, %in_1: bf16, %out: f32):
%18 = arith.extf %in : bf16 to f32
%19 = arith.extf %in_1 : bf16 to f32
%20 = arith.mulf %18, %19 : f32
%21 = arith.addf %out, %20 : f32
linalg.yield %21 : f32
} -> tensor<1x1x32x32xf32>
%9 = tensor.empty() : tensor<1x1x32x32xbf16>
// %10: truncates the f32 accumulator %8 back to bf16 (first result).
%10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<1x1x32x32xf32>) outs(%9 : tensor<1x1x32x32xbf16>) {
^bb0(%in: f32, %out: bf16):
%18 = arith.truncf %in : f32 to bf16
linalg.yield %18 : bf16
} -> tensor<1x1x32x32xbf16>
%11 = tensor.empty() : tensor<1x1x4x4xf32>
%12 = linalg.fill ins(%cst_0 : f32) outs(%11 : tensor<1x1x4x4xf32>) -> tensor<1x1x4x4xf32>
// %13: windowed multiply-accumulate of %1 (32x32) against %0 (29x29) with bf16 inputs
// extended to f32. It yields a 4x4 f32 result (32 - 29 + 1 = 4 per spatial dim).
// NOTE(review): per the issue text, this bf16-in/f32-out reduction is the op whose
// tiled form (as linalg.conv_1d) triggers the vector.outerproduct verifier error.
%13 = linalg.generic {indexing_maps = [#map4, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%1, %0 : tensor<1x1x32x32xbf16>, tensor<1x1x29x29xbf16>) outs(%12 : tensor<1x1x4x4xf32>) {
^bb0(%in: bf16, %in_1: bf16, %out: f32):
%18 = arith.extf %in : bf16 to f32
%19 = arith.extf %in_1 : bf16 to f32
%20 = arith.mulf %18, %19 : f32
%21 = arith.addf %out, %20 : f32
linalg.yield %21 : f32
} -> tensor<1x1x4x4xf32>
// %14: truncates %13 back to bf16 (second result). It reuses the empty tensor %3 as its init.
%14 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x1x4x4xf32>) outs(%3 : tensor<1x1x4x4xbf16>) {
^bb0(%in: f32, %out: bf16):
%18 = arith.truncf %in : f32 to bf16
linalg.yield %18 : bf16
} -> tensor<1x1x4x4xbf16>
// Joins both results on the signal fence %arg4, then exports them as buffer views.
%15:2 = hal.tensor.barrier join(%10, %14 : tensor<1x1x32x32xbf16>, tensor<1x1x4x4xbf16>) => %arg4 : !hal.fence
%16 = hal.tensor.export %15#0 : tensor<1x1x32x32xbf16> -> !hal.buffer_view
%17 = hal.tensor.export %15#1 : tensor<1x1x4x4xbf16> -> !hal.buffer_view
util.return %16, %17 : !hal.buffer_view, !hal.buffer_view
}
// Synchronous wrapper. It passes a null wait fence, creates a signal fence on device 0,
// calls the $async variant and blocks on that fence before returning both results.
util.func public @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
%1:2 = util.call @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)
// Timeout of -1 ms (presumably "wait indefinitely" -- confirm). %status is not checked.
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return %1#0, %1#1 : !hal.buffer_view, !hal.buffer_view
}
}
# Compile the reproducer for the ROCm backend (MI300X target) at -O3, with padding fusion,
# split reduction and the channels-last conv-filter preprocessing pipeline enabled.
iree-compile \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=mi300x \
  --iree-opt-level=O3 \
  --iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops \
  --iree-dispatch-creation-enable-split-reduction \
  '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' \
  test.mlir
What happened?
Compilation fails with the following error:
The issue was introduced by #23294.
Before that change, the tiled operation looks like this:
After that change, the same operation becomes:
When a mixed-precision `linalg.conv_1d` (bf16 inputs, f32 output) is vectorized, the vectorizer incorrectly generates a `vector.outerproduct` with a bf16 LHS and an f32 result. This violates the op's requirement that the LHS operand and the result have the same element type.
Steps to reproduce your issue
Input IR:
Command:
What component(s) does this issue relate to?
No response
Version information
No response
Additional context
No response