Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.aten.cumsum to linalg #684

Closed
AmosLewis opened this issue May 15, 2024 · 4 comments
Closed

torch.aten.cumsum to linalg #684

AmosLewis opened this issue May 15, 2024 · 4 comments
Assignees

Comments

@AmosLewis
Copy link
Contributor

AmosLewis commented May 15, 2024

The opt-125M and opt-350M models both fail to compile with the same `iree_linalg_ext.scan` error:
python ./run.py --torchmlirbuild ../../torch-mlir/build --tolerance 0.001 0.001 --cachedir ./huggingface_cache --ireebuild ../../iree-build -f pytorch -g models --mode onnx --tests pytorch/models/opt-125M

failed to translate executables
opt-125M.default.pytorch.torch.mlir:239:12: error: 'iree_linalg_ext.scan' op expected type of operand #1 ('tensor<1x8xi64>') to match type of corresponding result ('tensor<1x?xi64>')
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-125M.default.pytorch.torch.mlir:239:12: note: called from
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-125M.default.pytorch.torch.mlir:239:12: note: see current operation: 
%12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg0: i64, %arg1: i64):
  %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  "iree_linalg_ext.yield"(%13) : (i64) -> ()
}) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-125M.default.pytorch.torch.mlir:239:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwaitx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-125M.default.pytorch.torch.mlir:239:12: note: called from
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-125M.default.pytorch.torch.mlir:239:12: note: see current operation: 
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg6: !hal.device):
    %14 = "arith.constant"() <{value = 1 : index}> : () -> index
    "hal.return"(%14, %14, %14) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"}> ({
      %0 = "arith.constant"() <{value = 8 : index}> : () -> index
      %1 = "arith.constant"() <{value = 0 : i64}> : () -> i64
      %2 = "arith.constant"() <{value = 0 : index}> : () -> index
      %3 = "arith.constant"() <{value = 64 : index}> : () -> index
      %4 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<readonly:tensor<1x8xi64>>
      %5 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 1 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
      %6 = "hal.interface.binding.subspan"(%3) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1xi64>>
      %7 = "flow.dispatch.tensor.load"(%4, %2, %0) <{operandSegmentSizes = array<i32: 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (!flow.dispatch.tensor<readonly:tensor<1x8xi64>>, index, index) -> tensor<1x?xi64>
      %8 = "tensor.empty"() : () -> tensor<1x8xi64>
      %9 = "linalg.fill"(%1, %8) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg4: i64, %arg5: i64):
        "linalg.yield"(%arg4) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 0], [0, 0], [0, 0]]>} : (i64, tensor<1x8xi64>) -> tensor<1x8xi64>
      %10 = "tensor.empty"() : () -> tensor<1xi64>
      %11 = "linalg.fill"(%1, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg2: i64, %arg3: i64):
        "linalg.yield"(%arg2) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0], [0], [0]]>} : (i64, tensor<1xi64>) -> tensor<1xi64>
      %12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
      ^bb0(%arg0: i64, %arg1: i64):
        %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
        "iree_linalg_ext.yield"(%13) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
      "flow.dispatch.tensor.store"(%12#0, %5, %2, %0) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (tensor<1x?xi64>, !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>, index, index) -> ()
      "flow.dispatch.tensor.store"(%12#1, %6, %2) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (tensor<1xi64>, !flow.dispatch.tensor<writeonly:tensor<1xi64>>, index) -> ()
      "func.return"() : () -> ()
    }) {translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwaitx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
    %220 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^

python ./run.py --torchmlirbuild ../../torch-mlir/build --tolerance 0.001 0.001 --cachedir ./huggingface_cache --ireebuild ../../iree-build -f pytorch -g models --mode onnx --tests pytorch/models/opt-350M

failed to translate executables
opt-350m.default.pytorch.torch.mlir:385:12: error: 'iree_linalg_ext.scan' op expected type of operand #1 ('tensor<1x8xi64>') to match type of corresponding result ('tensor<1x?xi64>')
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-350m.default.pytorch.torch.mlir:385:12: note: called from
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-350m.default.pytorch.torch.mlir:385:12: note: see current operation: 
%12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg0: i64, %arg1: i64):
  %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
  "iree_linalg_ext.yield"(%13) : (i64) -> ()
}) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-350m.default.pytorch.torch.mlir:385:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwaitx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-350m.default.pytorch.torch.mlir:385:12: note: called from
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^
opt-350m.default.pytorch.torch.mlir:385:12: note: see current operation: 
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg6: !hal.device):
    %14 = "arith.constant"() <{value = 1 : index}> : () -> index
    "hal.return"(%14, %14, %14) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"}> ({
      %0 = "arith.constant"() <{value = 8 : index}> : () -> index
      %1 = "arith.constant"() <{value = 0 : i64}> : () -> i64
      %2 = "arith.constant"() <{value = 0 : index}> : () -> index
      %3 = "arith.constant"() <{value = 64 : index}> : () -> index
      %4 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<readonly:tensor<1x8xi64>>
      %5 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 1 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
      %6 = "hal.interface.binding.subspan"(%3) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1xi64>>
      %7 = "flow.dispatch.tensor.load"(%4, %2, %0) <{operandSegmentSizes = array<i32: 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (!flow.dispatch.tensor<readonly:tensor<1x8xi64>>, index, index) -> tensor<1x?xi64>
      %8 = "tensor.empty"() : () -> tensor<1x8xi64>
      %9 = "linalg.fill"(%1, %8) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg4: i64, %arg5: i64):
        "linalg.yield"(%arg4) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 0], [0, 0], [0, 0]]>} : (i64, tensor<1x8xi64>) -> tensor<1x8xi64>
      %10 = "tensor.empty"() : () -> tensor<1xi64>
      %11 = "linalg.fill"(%1, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg2: i64, %arg3: i64):
        "linalg.yield"(%arg2) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0], [0], [0]]>} : (i64, tensor<1xi64>) -> tensor<1xi64>
      %12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
      ^bb0(%arg0: i64, %arg1: i64):
        %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
        "iree_linalg_ext.yield"(%13) : (i64) -> ()
      }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
      "flow.dispatch.tensor.store"(%12#0, %5, %2, %0) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (tensor<1x?xi64>, !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>, index, index) -> ()
      "flow.dispatch.tensor.store"(%12#1, %6, %2) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (tensor<1xi64>, !flow.dispatch.tensor<writeonly:tensor<1xi64>>, index) -> ()
      "func.return"() : () -> ()
    }) {translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwaitx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
    %365 = torch.aten.cumsum %4, %int1, %none : !torch.vtensor<[1,8],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,8],si64>
           ^

@vivekkhandelwal1
Copy link
Contributor

Hi @AmosLewis, after spending some time debugging this issue, it seems that it's not a Torch-MLIR issue but an IREE issue: the Linalg IR produced by the Torch-MLIR lowering doesn't contain any dynamic dim; the dynamic dim is introduced during the IREE compilation.

@vivekkhandelwal1
Copy link
Contributor

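IR comparison across two different IREE passes (the shapes are all static `1x8` in the first dump; the dynamic `1x?` dim first appears in the dump after TileAndDistributeToWorkgroups):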

// -----// IR Dump After LowerExecutableUsingTransformDialect (iree-codegen-lower-executable-using-transform-dialect) //----- //
module {
  func.func @jit_eval_2_dispatch_0_scan_1x8xi64() attributes {translation_info = #iree_codegen.translation_info<CPUDefault>} {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x8xi64>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : !flow.dispatch.tensor<writeonly:tensor<1xi64>>
    %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x8xi64>> -> tensor<1x8xi64>
    %4 = tensor.empty() : tensor<1xi64>
    %5 = tensor.empty() : tensor<1x8xi64>
    %6 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 0], [0, 0], [0, 0]]>} ins(%c0_i64 : i64) outs(%5 : tensor<1x8xi64>) -> tensor<1x8xi64>
    %7 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0], [0], [0]]>} ins(%c0_i64 : i64) outs(%4 : tensor<1xi64>) -> tensor<1xi64>
    %8:2 = iree_linalg_ext.scan {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} dimension(1) inclusive(true) ins(%3 : tensor<1x8xi64>) outs(%6, %7 : tensor<1x8xi64>, tensor<1xi64>) {
    ^bb0(%arg0: i64, %arg1: i64):
      %9 = arith.addi %arg0, %arg1 : i64
      iree_linalg_ext.yield %9 : i64
    } -> tensor<1x8xi64>, tensor<1xi64>
    flow.dispatch.tensor.store %8#0, %1, offsets = [0, 0], sizes = [1, 8], strides = [1, 1] : tensor<1x8xi64> -> !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
    flow.dispatch.tensor.store %8#1, %2, offsets = [0], sizes = [1], strides = [1] : tensor<1xi64> -> !flow.dispatch.tensor<writeonly:tensor<1xi64>>
    return
  }
}
// -----// IR Dump After TileAndDistributeToWorkgroups Failed (iree-codegen-tile-and-distribute-to-workgroups) //----- //
"func.func"() <{function_type = () -> (), sym_name = "jit_eval_2_dispatch_0_scan_1x8xi64"}> ({
  %0 = "arith.constant"() <{value = 8 : index}> : () -> index
  %1 = "arith.constant"() <{value = 0 : i64}> : () -> i64
  %2 = "arith.constant"() <{value = 0 : index}> : () -> index
  %3 = "arith.constant"() <{value = 64 : index}> : () -> index
  %4 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<readonly:tensor<1x8xi64>>
  %5 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 1 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>
  %6 = "hal.interface.binding.subspan"(%3) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 0>, set = 0 : index} : (index) -> !flow.dispatch.tensor<writeonly:tensor<1xi64>>
  %7 = "flow.dispatch.tensor.load"(%4, %2, %0) <{operandSegmentSizes = array<i32: 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (!flow.dispatch.tensor<readonly:tensor<1x8xi64>>, index, index) -> tensor<1x?xi64>
  %8 = "tensor.empty"() : () -> tensor<1x8xi64>
  %9 = "linalg.fill"(%1, %8) <{operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg4: i64, %arg5: i64):
    "linalg.yield"(%arg4) : (i64) -> ()
  }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [0, 0], [0, 0], [0, 0]]>} : (i64, tensor<1x8xi64>) -> tensor<1x8xi64>
  %10 = "tensor.empty"() : () -> tensor<1xi64>
  %11 = "linalg.fill"(%1, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg2: i64, %arg3: i64):
    "linalg.yield"(%arg2) : (i64) -> ()
  }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0], [0], [0]]>} : (i64, tensor<1xi64>) -> tensor<1xi64>
  %12:2 = "iree_linalg_ext.scan"(%7, %9, %11) <{dimension = 1 : i64, inclusive = true, operandSegmentSizes = array<i32: 1, 2>}> ({
  ^bb0(%arg0: i64, %arg1: i64):
    %13 = "arith.addi"(%arg0, %arg1) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
    "iree_linalg_ext.yield"(%13) : (i64) -> ()
  }) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0]]>} : (tensor<1x?xi64>, tensor<1x8xi64>, tensor<1xi64>) -> (tensor<1x?xi64>, tensor<1xi64>)
  "flow.dispatch.tensor.store"(%12#0, %5, %2, %0) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: 1, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (tensor<1x?xi64>, !flow.dispatch.tensor<writeonly:tensor<1x8xi64>>, index, index) -> ()
  "flow.dispatch.tensor.store"(%12#1, %6, %2) <{operandSegmentSizes = array<i32: 1, 1, 0, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (tensor<1xi64>, !flow.dispatch.tensor<writeonly:tensor<1xi64>>, index) -> ()
  "func.return"() : () -> ()
}) {translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()

@AmosLewis
Copy link
Contributor Author

> Hi @AmosLewis, after spending some time debugging this issue it seems that it's not a Torch-MLIR issue. It's an IREE issue. Because the Linalg lowering through Torch-MLIR doesn't consist of dynamic dim but it's introduced during the IREE compilation.

Could you file an IREE issue and post it in the IREE PyTorch channel?

@vivekkhandelwal1
Copy link
Contributor

vivekkhandelwal1 commented May 20, 2024

Filed an issue here: iree-org/iree#17441

CC: @AmosLewis

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants