Skip to content

Backward convolution compilation failed due to iree-org/iree@e1f38110 #23382

@yzhang93

Description

@yzhang93

What happened?

Compilation fails with the following error:

error: 'vector.outerproduct' op failed to verify that lhs operand and result have same element type
    %13 = linalg.generic {indexing_maps = [#map4, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%1, %0 : tensor<1x1x32x32xbf16>, tensor<1x1x29x29xbf16>) outs(%12 : tensor<1x1x4x4xf32>) {

The issue is caused by #23294.

Before that change, the IR was:

%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%extracted_slice_1, %extracted_slice_0 : tensor<1x29xbf16>, tensor<1x29xbf16>) outs(%arg7 : tensor<1x1xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 1, 29], thread = [1, 1, 0, 0], workgroup = [8, 8, 0, 0]}>} {
  ^bb0(%in: bf16, %in_2: bf16, %out: f32):
    %17 = arith.extf %in : bf16 to f32
    %18 = arith.extf %in_2 : bf16 to f32
    %19 = arith.mulf %17, %18 : f32
    %20 = arith.addf %out, %19 : f32
    linalg.yield %20 : f32
  } -> tensor<1x1xf32>

After the change:

%16 = linalg.conv_1d ins(%extracted_slice_2, %extracted_slice_3 : tensor<29xbf16>, tensor<29xbf16>) outs(%extracted_slice_4 : tensor<1xf32>) -> tensor<1xf32>

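In the linalg.generic form, the bf16 operands are explicitly extended to f32 with arith.extf before the multiply; the linalg.conv_1d form no longer carries those casts explicitly. When this mixed-precision linalg.conv_1d (bf16 inputs, f32 output) is vectorized, the vectorizer incorrectly generates a vector.outerproduct with a bf16 LHS and an f32 result, which violates the op's type constraints.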

Steps to reproduce your issue

Input IR:

// Reproducer module: two linalg.generic convolution-style contractions with
// bf16 inputs and f32 accumulation, wrapped in IREE HAL import/export plumbing.

// Identity map over a 4-D iteration space (used by the elementwise generics).
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Maps over a 7-D iteration space (4 parallel + 3 reduction dims, per the
// iterator_types on the generics that use them).
// #map1/#map2/#map3: operand/operand/result accesses of the first contraction (%8).
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
// #map4/#map5 (with #map3 as result map): accesses of the second contraction (%13).
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)>
#map5 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>
module @module {
  // Async entry point: takes three buffer views plus a wait fence (%arg3) and a
  // signal fence (%arg4), and returns two buffer views.
  // NOTE(review): going by the function name, the two results are presumably the
  // input gradient and the weight gradient of a 2-D convolution — confirm.
  util.func public @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant 0.000000e+00 : bf16
    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %cst_0 = arith.constant 0.000000e+00 : f32
    // Import the three arguments as bf16 tensors, each waiting on %arg3.
    %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<1x1x29x29xbf16>
    %1 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<1x1x32x32xbf16>
    %2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<1x1x4x4xbf16>
    %3 = tensor.empty() : tensor<1x1x4x4xbf16>
    %4 = linalg.fill ins(%cst : bf16) outs(%3 : tensor<1x1x4x4xbf16>) -> tensor<1x1x4x4xbf16>
    // Build a copy of %2 with its last two dimensions reversed: the body ignores
    // %in and instead extracts %2[0, 0, 3 - i2, 3 - i3] for each output index.
    %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<1x1x4x4xbf16>) outs(%4 : tensor<1x1x4x4xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      %18 = linalg.index 2 : index
      %19 = linalg.index 3 : index
      %20 = arith.subi %c3, %18 : index
      %21 = arith.subi %c3, %19 : index
      %extracted = tensor.extract %2[%c0, %c0, %20, %21] : tensor<1x1x4x4xbf16>
      linalg.yield %extracted : bf16
    } -> tensor<1x1x4x4xbf16>
    // Zero-pad %0 by 3 on both sides of the last two dimensions (29x29 -> 35x35).
    %padded = tensor.pad %0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg5: index, %arg6: index, %arg7: index, %arg8: index):
      tensor.yield %cst : bf16
    } : tensor<1x1x29x29xbf16> to tensor<1x1x35x35xbf16>
    %6 = tensor.empty() : tensor<1x1x32x32xf32>
    %7 = linalg.fill ins(%cst_0 : f32) outs(%6 : tensor<1x1x32x32xf32>) -> tensor<1x1x32x32xf32>
    // First contraction: padded %0 against the reversed filter %5. Both bf16
    // operands are extended to f32 before the multiply-accumulate into an f32
    // accumulator initialized to zero.
    %8 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%padded, %5 : tensor<1x1x35x35xbf16>, tensor<1x1x4x4xbf16>) outs(%7 : tensor<1x1x32x32xf32>) {
    ^bb0(%in: bf16, %in_1: bf16, %out: f32):
      %18 = arith.extf %in : bf16 to f32
      %19 = arith.extf %in_1 : bf16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<1x1x32x32xf32>
    %9 = tensor.empty() : tensor<1x1x32x32xbf16>
    // Truncate the f32 result of the first contraction back to bf16.
    %10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<1x1x32x32xf32>) outs(%9 : tensor<1x1x32x32xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %18 = arith.truncf %in : f32 to bf16
      linalg.yield %18 : bf16
    } -> tensor<1x1x32x32xbf16>
    %11 = tensor.empty() : tensor<1x1x4x4xf32>
    %12 = linalg.fill ins(%cst_0 : f32) outs(%11 : tensor<1x1x4x4xf32>) -> tensor<1x1x4x4xf32>
    // Second contraction: %1 (32x32) against %0 (29x29) into a 4x4 f32
    // accumulator, again with explicit bf16 -> f32 extension of both operands.
    // This is the op quoted in the reported 'vector.outerproduct' verifier error.
    %13 = linalg.generic {indexing_maps = [#map4, #map5, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%1, %0 : tensor<1x1x32x32xbf16>, tensor<1x1x29x29xbf16>) outs(%12 : tensor<1x1x4x4xf32>) {
    ^bb0(%in: bf16, %in_1: bf16, %out: f32):
      %18 = arith.extf %in : bf16 to f32
      %19 = arith.extf %in_1 : bf16 to f32
      %20 = arith.mulf %18, %19 : f32
      %21 = arith.addf %out, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<1x1x4x4xf32>
    // Truncate the f32 result of the second contraction to bf16; the destination
    // is the empty tensor %3 and the body does not read %out.
    %14 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<1x1x4x4xf32>) outs(%3 : tensor<1x1x4x4xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %18 = arith.truncf %in : f32 to bf16
      linalg.yield %18 : bf16
    } -> tensor<1x1x4x4xbf16>
    // Join both results on the signal fence %arg4, then export them as buffer views.
    %15:2 = hal.tensor.barrier join(%10, %14 : tensor<1x1x32x32xbf16>, tensor<1x1x4x4xbf16>) => %arg4 : !hal.fence
    %16 = hal.tensor.export %15#0 : tensor<1x1x32x32xbf16> -> !hal.buffer_view
    %17 = hal.tensor.export %15#1 : tensor<1x1x4x4xbf16> -> !hal.buffer_view
    util.return %16, %17 : !hal.buffer_view, !hal.buffer_view
  }
  // Synchronous wrapper: calls the $async variant with a null wait fence and a
  // freshly created signal fence, awaits that fence, then returns the results.
  util.func public @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    // NOTE(review): a timeout of -1 presumably means "wait indefinitely" — confirm.
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1:2 = util.call @conv_2d_bfloat16_input_weight_backward_1x1x32x32_nchw_1x1x4x4_fchw_nfhw_1x1s_0x0p_1x1d_1g$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)
    // The await status is not checked before returning.
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1#0, %1#1 : !hal.buffer_view, !hal.buffer_view
  }
}

Command:

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=mi300x --iree-opt-level=O3 --iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops --iree-dispatch-creation-enable-split-reduction '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' test.mlir

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Metadata

Metadata

Labels

bug 🐞Something isn't working

Type

No type
No fields configured for issues without a type.

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions