Compilation error for SHARK-TestSuite (onnx/models/RAFT_vaiq_int8) #17455

Open
IanWood1 opened this issue May 21, 2024 · 1 comment
Labels: bug 🐞 Something isn't working · integrations/pytorch PyTorch integration work

Comments

@IanWood1 (Contributor) commented May 21, 2024

EDIT (also added to reproduction steps):
The problem occurs in canonicalization during LLVMCPUVectorTransferLowering and can be reproduced by running the IR in https://gist.github.com/IanWood1/59153bb58858c69b0569a6a6f39e3289.

What happened?

Compilation fails due to excessive stack allocations:

SHARK-TestSuite/e2eshark/test-run/onnx/models/RAFT_vaiq_int8/RAFT_vaiq_int8.default.onnx.linalg.mlir:6453:13: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 401408 bytes
    %1024 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_440 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
            ^

A long run of consecutive vector.extract and vector.store ops is generated, apparently from fully unrolling a vector transfer whose vector type spans the entire 1024x7x7x2 tensor.
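For intuition, here is a minimal sketch of the kind of pre-lowering IR that would unroll this way. This is a hypothetical reconstruction from the extract/store pattern below, not taken from the actual dump; the names and the exact op mix are assumptions:

    // Hypothetical input: a transfer pair whose vector type covers the whole
    // 1024x7x7x2 buffer. Unrolling to rank-1 slices yields one vector.extract
    // + vector.store pair per (d0, d1, d2) index: 1024 * 7 * 7 = 50176 pairs.
    // Materializing the vector itself needs 1024 * 7 * 7 * 2 * 4 = 401408
    // bytes, exactly the figure in the error above.
    func.func @sketch(%src: memref<1024x7x7x2xf32>, %dst: memref<1024x7x7x2xf32>) {
      %c0 = arith.constant 0 : index
      %pad = arith.constant 0.0 : f32
      %big = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
          {in_bounds = [true, true, true, true]}
          : memref<1024x7x7x2xf32>, vector<1024x7x7x2xf32>
      vector.transfer_write %big, %dst[%c0, %c0, %c0, %c0]
          {in_bounds = [true, true, true, true]}
          : vector<1024x7x7x2xf32>, memref<1024x7x7x2xf32>
      return
    }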

Verbose output
    "vector.store"(%151559, %1030, %0, %1015, %1019, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151560 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 3>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151560, %1030, %0, %1015, %1018, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151561 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 4>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151561, %1030, %0, %1015, %1017, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151562 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 5>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151562, %1030, %0, %1015, %1016, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151563 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 6>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151563, %1030, %0, %1015, %1015, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151564 = "memref.subview"(%1030) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 1>, static_sizes = array<i64: 1024, 7, 7, 1>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<1024x7x7x2xf32>) -> memref<1024
x7x7xf32, strided<[98, 14, 2], offset: 1>>
    %151565 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
    %151566 = "affine.apply"(%151565) <{map = affine_map<()[s0] -> (s0 * 128)>}> : (index) -> index
    %151567 = "memref.subview"(%1032, %151566) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0>, static_sizes = array<i64: 128, 7, 7>, static_strides = array<i64: 1, 1, 1>}> : (memref<1024x7x7xi8>,
index) -> memref<128x7x7xi8, strided<[49, 7, 1], offset: ?>>
    "scf.for"(%1025, %1026, %1027) ({
    ^bb0(%arg0: index):
      "scf.for"(%1025, %1028, %1027) ({
      ^bb0(%arg1: index):
        %151568 = "arith.addi"(%arg0, %151566) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
        %151569 = "scf.for"(%1025, %1028, %1027, %1023) ({
        ^bb0(%arg2: index, %arg3: vector<7xf32>):
          %151578 = "memref.load"(%151564, %151568, %arg1, %arg2) <{nontemporal = false}> : (memref<1024x7x7xf32, strided<[98, 14, 2], offset: 1>>, index, index, index) -> f32
          %151579 = "vector.insertelement"(%151578, %arg3, %arg2) : (f32, vector<7xf32>, index) -> vector<7xf32>
          "scf.yield"(%151579) : (vector<7xf32>) -> ()
        }) : (index, index, index, vector<7xf32>) -> vector<7xf32>
        %151570 = "arith.divf"(%151569, %1024) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151571 = "math.round"(%151570) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>) -> vector<7xf32>
        %151572 = "arith.addf"(%151571, %1023) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151573 = "arith.cmpf"(%151572, %1022) <{fastmath = #arith.fastmath<none>, predicate = 11 : i64}> : (vector<7xf32>, vector<7xf32>) -> vector<7xi1>
        %151574 = "arith.cmpf"(%151572, %1021) <{fastmath = #arith.fastmath<none>, predicate = 9 : i64}> : (vector<7xf32>, vector<7xf32>) -> vector<7xi1>
        %151575 = "arith.select"(%151573, %1022, %151572) : (vector<7xi1>, vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151576 = "arith.select"(%151574, %1021, %151575) : (vector<7xi1>, vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151577 = "arith.fptosi"(%151576) : (vector<7xf32>) -> vector<7xi8>
        "vector.store"(%151577, %151567, %arg0, %arg1, %1025) <{nontemporal = false}> : (vector<7xi8>, memref<128x7x7xi8, strided<[49, 7, 1], offset: ?>>, index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "scf.yield"() : () -> ()
    }) : (index, index, index) -> ()
    "func.return"() : () -> ()
  }) {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} : () -> ()
}) : () -> ()
"hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+x
save,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movd
iri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwai
tx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shs
tk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
  %1081 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_458 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
          ^
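For local triage it can help to raise the limit so compilation proceeds far enough to inspect the offending IR. Assuming the check is still controlled by the --iree-llvmcpu-stack-allocation-limit option (the flag name is an assumption; verify against iree-compile --help):

iree-compile --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-stack-allocation-limit=524288 RAFT_vaiq_int8.default.onnx.linalg.mlir

This only sidesteps the guard; the unrolled extract/store sequence is still generated.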

Steps to reproduce your issue

Note that this issue also occurs with some other models. In the SHARK-TestSuite, onnx/models/RAFT_vaiq_int8 encounters a similar issue. To reproduce, set up the test suite and run

python run.py --cachedir=/path/to/.cache/ -t onnx/models/RAFT_vaiq_int8/ -m onnx -c /path/to/torch-mlir/build/ -i /path/to/iree-build/ --torchtolinalg

with an up-to-date torch-mlir and iree build.

Originally posted by @zjgarvey in #17226 (comment)

Minimal repro

#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map20 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
func.func public @jit_eval_174(%1104 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x2xi8> {
    %cst_9 = arith.constant 2.00 : f32
    %cst_4 = arith.constant 4.00 : f32
    %cst_0 = arith.constant 0.00 : f32
    %cst_1 = arith.constant 1.00 : f32

    %1015 = tensor.empty() : tensor<1024x7x7x2xf32>
    %1105 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1104 : tensor<1024x7x7x2xi8>) outs(%1015 : tensor<1024x7x7x2xf32>) {
    ^bb0(%in: i8, %out: f32):
      %3555 = arith.extsi %in : i8 to i32
      %3556 = arith.sitofp %3555 : i32 to f32
      %3557 = arith.mulf %3556, %cst_9 : f32
      linalg.yield %3557 : f32
    } -> tensor<1024x7x7x2xf32>

    %cst_218 = arith.constant dense<1.000000e+00> : tensor<f32>

    %1020 = tensor.empty() : tensor<1024x7x7x1xi8>
    %1022 = tensor.empty() : tensor<1024x7x7x1xf32>
    %extracted_slice_466 = tensor.extract_slice %1105[0, 0, 0, 0] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>
    %extracted_slice_467 = tensor.extract_slice %1105[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>

    %1106 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_466 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
    ^bb0(%in: f32, %out: i8):
      %3555 = arith.divf %in, %cst_9 : f32
      %3556 = math.round %3555 : f32
      %3557 = arith.addf %3556, %cst_4 : f32
      %3558 = arith.cmpf ult, %3557, %cst_0 : f32
      %3559 = arith.cmpf ugt, %3557, %cst_1 : f32
      %3560 = arith.select %3558, %cst_0, %3557 : f32
      %3561 = arith.select %3559, %cst_1, %3560 : f32
      %3562 = arith.fptosi %3561 : f32 to i8
      linalg.yield %3562 : i8
    } -> tensor<1024x7x7x1xi8>

    %1108 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_467 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
    ^bb0(%in: f32, %out: i8):
      %3555 = arith.divf %in, %cst_9 : f32
      %3556 = math.round %3555 : f32
      %3557 = arith.addf %3556, %cst_4 : f32
      %3558 = arith.cmpf ult, %3557, %cst_0 : f32
      %3559 = arith.cmpf ugt, %3557, %cst_1 : f32
      %3560 = arith.select %3558, %cst_0, %3557 : f32
      %3561 = arith.select %3559, %cst_1, %3560 : f32
      %3562 = arith.fptosi %3561 : f32 to i8
      linalg.yield %3562 : i8
    } -> tensor<1024x7x7x1xi8>

 
    %concat_468 = tensor.concat dim(3) %1108, %1106 : (tensor<1024x7x7x1xi8>, tensor<1024x7x7x1xi8>) -> tensor<1024x7x7x2xi8>
    return %concat_468: tensor<1024x7x7x2xi8>
} 

run with:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-demote-i64-to-i32 path/to/mlir.mlir
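To see where the extract/store explosion first appears, the per-pass IR can also be dumped using MLIR's generic pass-manager flags, which iree-compile accepts (--mlir-disable-threading keeps the dump ordered; the dump goes to stderr):

iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-demote-i64-to-i32 --mlir-print-ir-after-all --mlir-disable-threading path/to/mlir.mlir 2> ir-dump.txt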

Additional context

#17226
#17341

@IanWood1 (Contributor, Author) commented Jun 3, 2024

Here is a second (more concise) example and the corresponding iree-compile logs:
MLIR
logs

Also, here are the logs from the original repro: https://gist.github.com/IanWood1/cc3e732c49796b4ce9e0300824b57b3e
