
Compilation error for iree_linalg_ext.scan #17441

Closed
vivekkhandelwal1 opened this issue May 20, 2024 · 7 comments
Labels
bug 🐞 Something isn't working

vivekkhandelwal1 commented May 20, 2024

What happened?

While compiling an IR, I get the compilation failure below. The failure is caused by dynamic dims, even though the input IR has only static dims: the dynamic dims are introduced by TileAndDistributeToWorkgroupsPass, which then results in the failure.
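For reference, the scan in the failing dispatch (`dimension(1) inclusive(true)` with an `arith.addi` body) is an inclusive cumulative sum along dimension 1, producing both the scanned tensor and a final accumulator. A minimal Python sketch of those semantics (illustrative only, not IREE code):

```python
def inclusive_scan_add(row):
    """Inclusive +-scan over a 1-D sequence, returning both the
    scanned values and the final accumulator -- mirroring the two
    results of the iree_linalg_ext.scan op in the dispatch."""
    out, acc = [], 0
    for x in row:
        acc = acc + x
        out.append(acc)
    return out, acc
```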

Error:

../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: error: 'iree_linalg_ext.scan' op expected type of operand #1 ('tensor<1x8xi64>') to match type of corresponding result ('tensor<1x?xi64>')
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {
            ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: note: called from
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {
            ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:24:3: note: called from
  func.func @main_graph(%arg0: tensor<1x8xi64>) -> (tensor<1x8x50272xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>) {
  ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: note: see current operation: 
%12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg0: i64, %arg1: i64):
  %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  "iree_linalg_ext.yield"(%13) : (i64) -> ()
}) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {
            ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "broadwell", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,-xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,-xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,+rtm,+adx,+avx2,-hreset,-movdiri,-serialize,-vpclmulqdq,-avx512vl,-uintr,-cf,-clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,-clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,-rdpid,-fma4,-avx512vbmi,-shstk,-vaes,-waitpkg,-sgx,+fxsr,-avx512dq,-sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {
            ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: note: called from
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {
            ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:24:3: note: called from
  func.func @main_graph(%arg0: tensor<1x8xi64>) -> (tensor<1x8x50272xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>, tensor<1x12x8x64xf32>) {
  ^
../torch-mlir-vivek/opt-125M-linalg-elided-ir.mlir:294:13: note: see current operation: 
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg6: !hal.device):
    %14 = "arith.constant"() <{value = 1 : index}> : () -> index
    "hal.return"(%14, %14, %14) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"}> ({
      %0 = "arith.constant"() <{value = 8 : index}> : () -> index
      %1 = "arith.constant"() <{value = 0 : i64}> : () -> i64
      %2 = "arith.constant"() <{value = 0 : index}> : () -> index
      %3 = "arith.constant"() <{value = 64 : index}> : () -> index
      %4 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<readonly:tensor<1x8xi64>>
      %5 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 1 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
      %6 = "hal.interface.binding.subspan"(%3) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1xi64>>
      %7 = "flow.dispatch.tensor.load"(%4, %2, %0) <{operandSegmentSizes = array<i32: 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (!flow.dispatch.tensor<readonly:tensor<1x8xi64>>, index, index) -> tensor<1x?xi64>
      %8 = "tensor.empty"() : () -> tensor<1x8xi64>
      %9 = "linalg.fill"(%1, %8) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg4: i64, %arg5: i64):
        "linalg.yield"(%arg4) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 0], [0, 0], [0, 0]]>} : (i64, tensor<1x8xi64>) -> tensor<1x8xi64>
      %10 = "tensor.empty"() : () -> tensor<1xi64>
      %11 = "linalg.fill"(%1, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg2: i64, %arg3: i64):
        "linalg.yield"(%arg2) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0], [0], [0]]>} : (i64, tensor<1xi64>) -> tensor<1xi64>
      %12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
      ^bb0(%arg0: i64, %arg1: i64):
        %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
        "iree_linalg_ext.yield"(%13) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
      "flow.dispatch.tensor.store"(%12#0, %5, %2, %0) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (tensor<1x?xi64>, !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>, index, index) -> ()
      "flow.dispatch.tensor.store"(%12#1, %6, %2) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (tensor<1xi64>, !flow.dispatch.tensor<writeonly:tensor<1xi64>>, index) -> ()
      "func.return"() : () -> ()
    }) {translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "broadwell", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,-xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,-xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,+rtm,+adx,+avx2,-hreset,-movdiri,-serialize,-vpclmulqdq,-avx512vl,-uintr,-cf,-clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,-clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,-rdpid,-fma4,-avx512vbmi,-shstk,-vaes,-waitpkg,-sgx,+fxsr,-avx512dq,-sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
    %17:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%cst : tensor<1x8xi64>) outs(%14, %16 : tensor<1x8xi64>, tensor<1xi64>) {

Steps to reproduce your issue

Download the IR from https://gist.github.com/vivekkhandelwal1/5b07ce3c403b99dfda8ed64f5174595b

And run:

iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu ir.mlir

What component(s) does this issue relate to?

MLIR, Compiler

Version information

No response

Additional context

No response

@vivekkhandelwal1
@AmosLewis Please assign this issue to the concerned person.

@MaheshRavishankar MaheshRavishankar self-assigned this May 21, 2024
@AmosLewis
#17500

AmosLewis commented Jun 25, 2024

This can be fixed by @pashu123's patch https://github.com/pashu123/iree/tree/scan_tile_issue, which needs to be merged.

@AmosLewis
Can be fixed by @pashu123 's patch https://github.com/pashu123/iree/tree/scan_tile_issue, need to be merged

PR #17761

pashu123 added a commit that referenced this issue Jun 28, 2024
Update the tile sizes to contain i64Attrs instead of arith.constant.
Somehow it's giving dynamic shapes in tensor.extract_slice since the
arith.constant op isn't folded or seen as a constant.

To fix Issue: #17441
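The commit message above points at the root cause: a slice size carried as an SSA value (an `arith.constant` result) is inferred as a dynamic dimension, while the same size carried as an attribute stays static. A hedged Python model of that shape-inference distinction (names like `SSAValue` and `infer_slice_shape` are illustrative, not IREE or MLIR APIs):

```python
DYNAMIC = "?"

class SSAValue:
    """Stands in for a size produced by an arith.constant op."""
    def __init__(self, value):
        self.value = value

def infer_slice_shape(sizes):
    """Toy model of MLIR-style slice result-shape inference: a size
    given as a plain int (an attribute) stays static, while a size
    given as an SSAValue -- even a constant one -- is inferred as
    the dynamic dimension '?'."""
    return [s if isinstance(s, int) else DYNAMIC for s in sizes]

# Tile sizes encoded as attributes keep the slice static...
print(infer_slice_shape([1, 8]))            # [1, 8]
# ...but the same sizes as SSA constants go dynamic, which is how
# tensor<1x8xi64> became tensor<1x?xi64> in the error above.
print(infer_slice_shape([1, SSAValue(8)]))  # [1, '?']
```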
@ScottTodd
Are the tests in https://github.com/iree-org/iree/blob/main/tests/e2e/linalg_ext_ops/scan.mlir representative enough?

LLITCHEV pushed a commit to LLITCHEV/iree that referenced this issue Jul 30, 2024

Update the tile sizes to contain i64Attrs instead of arith.constant.
Somehow it's giving dynamic shapes in tensor.extract_slice since the
arith.constant op isn't folded or seen as a constant.

To fix Issue: iree-org#17441

Signed-off-by: Lubo Litchev <lubol@google.com>