Accuracy regression on INT8 Models #8712

Closed
mariecwhite opened this issue Mar 31, 2022 · 23 comments
Labels
bug 🐞 Something isn't working

Comments

@mariecwhite
Contributor

Seeing numerical differences of 5-17 in several INT8 models that were previously passing in iree-samples. Ran on x86 with dylib.

lit -a tflitehub/mobilenet_v2_int8_test.py
lit -a tflitehub/mobilenet_v35_int8_test.py
lit -a tflitehub/resnet_50_int8_test.py
@mariecwhite mariecwhite added the bug 🐞 Something isn't working label Mar 31, 2022
@mariecwhite
Contributor Author

MLIR files attached.

mobilenet_v2_int8_test.zip
mobilenet_v35_int8_test.zip

@hanhanW
Contributor

hanhanW commented Mar 31, 2022

Do you know when it started? Also, can you attach the log and the flagfile?

@mariecwhite
Contributor Author

It started failing at this release: https://github.com/google/iree/releases/tag/candidate-20220325.87

@hanhanW
Contributor

hanhanW commented Mar 31, 2022

Thank you! I'll try to see if there are suspicious commits between candidate-20220324.86 and candidate-20220325.87.

@mariecwhite
Contributor Author

mariecwhite commented Mar 31, 2022

Actually, candidate-20220324.86 didn't have any Python wheels that I could pip install, but candidate-20220323.85 passes. So you might want to expand the search to between 85 and 87.

@hanhanW
Contributor

hanhanW commented Mar 31, 2022

Okay, good to know this, thank you!

@hanhanW
Contributor

hanhanW commented Apr 2, 2022

This file includes MLIR and inputs.

repro.zip

@hanhanW
Contributor

hanhanW commented Apr 2, 2022

Note: the second number of the output should be -0.8531494, not -5.81693.

I'm going to bisect commits...

@hanhanW
Contributor

hanhanW commented Apr 2, 2022

Checking out f10b758, I'm able to get -0.77559.

@hanhanW
Contributor

hanhanW commented Apr 2, 2022

It starts failing at this integrate: 9240f1c

I'll try to narrow it down to the dispatch level next week.

@hanhanW
Contributor

hanhanW commented Apr 4, 2022

Got a smaller repro. It is a matmul+generic case.

To repro:

$ iree-translate -iree-mlir-to-vm-bytecode-module --iree-hal-target-backends=dylib-llvm-aot ~/repro.mlir -o /tmp/a.vmfb
$ iree-run-module --module_file=/tmp/a.vmfb --flagfile=$HOME/flagfile

Expected results:

1x1001xf32=[0 -0.77559 -1.08583 -0.698031 -0.85315 ...

Actual output:

1x1001xf32=[0 -5.81693 -5.04134 -4.49842 -0.232677 ...

@hanhanW
Contributor

hanhanW commented Apr 4, 2022

Looking into the IR, it looks like there is an overflow: 550829555712 × 1092190289 = 601610691642830880768, which is ~32× 2^64.

      %10 = arith.select %9, %c550829555712_i64, %c548682072064_i64 : i64
      %11 = arith.extsi %8 : i32 to i64
      %12 = arith.muli %11, %c1092190289_i64 : i64

@rsuderman I highly suspect that this comes from some apply_scale ops. I remember you made some changes recently. Could you verify whether this issue will be addressed by your recent TOSA changes?

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> ()>
module {
  func @main(%arg0: tensor<?x1280xi8>, %arg1: tensor<1280x1001xi8>, %arg2: tensor<1001xi32>, %arg3: tensor<1001xi32>) -> tensor<?x1001xf32> {
    %cst = arith.constant 0.0775590464 : f32
    %cst_0 = arith.constant -5.300000e+01 : f32
    %c127_i32 = arith.constant 127 : i32
    %c-128_i32 = arith.constant -128 : i32
    %c-53_i32 = arith.constant -53 : i32
    %c40_i64 = arith.constant 40 : i64
    %c1092190289_i64 = arith.constant 1092190289 : i64
    %c548682072064_i64 = arith.constant 548682072064 : i64
    %c550829555712_i64 = arith.constant 550829555712 : i64
    %c0_i32 = arith.constant 0 : i32
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?x1280xi8>
    %1 = linalg.init_tensor [%0, 1001] : tensor<?x1001xf32>
    %2 = linalg.init_tensor [%0, 1001] : tensor<?x1001xi32>
    %3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<?x1001xi32>) -> tensor<?x1001xi32>
    %4 = linalg.matmul ins(%arg0, %arg1 : tensor<?x1280xi8>, tensor<1280x1001xi8>) outs(%3 : tensor<?x1001xi32>) -> tensor<?x1001xi32>
    %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2, %4, %arg3, %c-128_i32 : tensor<1001xi32>, tensor<?x1001xi32>, tensor<1001xi32>, i32) outs(%1 : tensor<?x1001xf32>) {
    ^bb0(%arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: f32):
      %6 = arith.muli %arg6, %arg7 : i32
      %7 = arith.subi %arg5, %6 : i32
      %8 = arith.addi %arg4, %7 : i32
      %9 = arith.cmpi sge, %8, %c0_i32 : i32
      %10 = arith.select %9, %c550829555712_i64, %c548682072064_i64 : i64
      %11 = arith.extsi %8 : i32 to i64
      %12 = arith.muli %11, %c1092190289_i64 : i64
      %13 = arith.addi %12, %10 : i64
      %14 = arith.shrsi %13, %c40_i64 : i64
      %15 = arith.trunci %14 : i64 to i32
      %16 = arith.addi %15, %c-53_i32 : i32
      %17 = arith.cmpi slt, %16, %c-128_i32 : i32
      %18 = arith.select %17, %c-128_i32, %16 : i32
      %19 = arith.cmpi slt, %c127_i32, %16 : i32
      %20 = arith.select %19, %c127_i32, %18 : i32
      %21 = arith.trunci %20 : i32 to i8
      %22 = arith.sitofp %21 : i8 to f32
      %23 = arith.subf %22, %cst_0 : f32
      %24 = arith.mulf %23, %cst : f32
      linalg.yield %24 : f32
    } -> tensor<?x1001xf32>
    return %5 : tensor<?x1001xf32>
  }
}
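For reference, here is the fixed-point rescale that the generic region above implements, reconstructed as a small Python sketch (the constants map directly to the IR; the helper itself is my reconstruction, not code from the issue):

```python
# Reconstruction of the rescale arithmetic in the generic region above.
MULTIPLIER = 1092190289            # %c1092190289_i64
SHIFT = 40                         # %c40_i64
ROUND_POS = (1 << 39) + (1 << 30)  # 550829555712, used when the input is >= 0
ROUND_NEG = (1 << 39) - (1 << 30)  # 548682072064, used when the input is < 0

def rescale(x: int) -> int:
    """x is the i32 accumulator value (%8 in the IR)."""
    acc = x * MULTIPLIER + (ROUND_POS if x >= 0 else ROUND_NEG)
    return acc >> SHIFT            # arithmetic shift right, as in arith.shrsi

# Bound check: |x| < 2^31 and MULTIPLIER < 2^31, so |x * MULTIPLIER| < 2^62,
# and adding the ~2^39 rounding constant keeps the sum well inside i64.
```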

@hanhanW
Contributor

hanhanW commented Apr 4, 2022

I read the wrong IRs; I'll take the overflow claim back and read it again...

@hanhanW
Contributor

hanhanW commented Apr 4, 2022

550829555712 is greater than 2^32, so I suspect the i64 constants could be an issue in IREE.

Judging from #8731 (comment), I think this issue can be addressed in the next integrate?

We may want to add a verification pass that detects unsupported types, though. My guess is that the i64 constants get demoted to i32 with weird values in this case (see the sketch below).
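To illustrate that guess (a hypothetical sketch of the suspected demotion, not IREE's actual pass), truncating the rounding constant to its low 32 bits yields a very different value:

```python
# 550829555712 = 2^39 + 2^30; a naive demotion to i32 keeps only the low 32 bits.
c_i64 = 550829555712
c_i32 = c_i64 & 0xFFFFFFFF   # -> 1073741824 (2^30); the 2^39 part is lost
print(c_i64, "->", c_i32)    # the "weird value" the rescale would then use
```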

@benvanik
Collaborator

benvanik commented Apr 4, 2022

Yep - if you want to run on Vulkan (or small systems), you'll need to ensure there are no i64 values that actually need 64 bits.

@hanhanW
Copy link
Contributor

hanhanW commented Apr 7, 2022

Still seeing correctness issues after bumping MLIR. The final dispatch changed: there is now a tosa.apply_scale op in the dispatch. Here is the new repro: repro.zip

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1) -> ()>
module {
  func @main(%arg0: tensor<?x1280xi8>, %arg1: tensor<1280x1001xi8>, %arg2: tensor<1001xi32>, %arg3: tensor<1001xi32>) -> tensor<?x1001xf32> {
    %c0_i32 = arith.constant 0 : i32
    %c1092190289_i32 = arith.constant 1092190289 : i32
    %c40_i8 = arith.constant 40 : i8
    %c-53_i32 = arith.constant -53 : i32
    %c-128_i32 = arith.constant -128 : i32
    %c127_i32 = arith.constant 127 : i32
    %cst = arith.constant -5.300000e+01 : f32
    %cst_0 = arith.constant 0.0775590464 : f32
    %c0 = arith.constant 0 : index
    %0 = tensor.dim %arg0, %c0 : tensor<?x1280xi8>
    %1 = linalg.init_tensor [%0, 1001] : tensor<?x1001xf32>
    %2 = linalg.init_tensor [%0, 1001] : tensor<?x1001xi32>
    %3 = linalg.fill ins(%c0_i32 : i32) outs(%2 : tensor<?x1001xi32>) -> tensor<?x1001xi32>
    %4 = linalg.matmul ins(%arg0, %arg1 : tensor<?x1280xi8>, tensor<1280x1001xi8>) outs(%3 : tensor<?x1001xi32>) -> tensor<?x1001xi32>
    %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map2, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2, %4, %arg3, %c-128_i32 : tensor<1001xi32>, tensor<?x1001xi32>, tensor<1001xi32>, i32) outs(%1 : tensor<?x1001xf32>) {
    ^bb0(%arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: f32):
      %6 = arith.muli %arg6, %arg7 : i32
      %7 = arith.subi %arg5, %6 : i32
      %8 = arith.addi %arg4, %7 : i32
      %9 = "tosa.apply_scale"(%8, %c1092190289_i32, %c40_i8) {double_round = true} : (i32, i32, i8) -> i32
      %10 = arith.addi %9, %c-53_i32 : i32
      %11 = arith.cmpi slt, %10, %c-128_i32 : i32
      %12 = arith.select %11, %c-128_i32, %10 : i32
      %13 = arith.cmpi slt, %c127_i32, %10 : i32
      %14 = arith.select %13, %c127_i32, %12 : i32
      %15 = arith.trunci %14 : i32 to i8
      %16 = arith.sitofp %15 : i8 to f32
      %17 = arith.subf %16, %cst : f32
      %18 = arith.mulf %17, %cst_0 : f32
      linalg.yield %18 : f32
    } -> tensor<?x1001xf32>
    return %5 : tensor<?x1001xf32>
  }
}

It looks like there is a correctness issue in tosa.apply_scale. If I remove the op (and use %8 in the %10 op), the results are identical between 5e5cbd4 and main.

The results differ whenever apply_scale ops are present. @rsuderman, could you take a look at what's happening?
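For context, generalizing the sketch from the earlier comment, here is my understanding of tosa.apply_scale's 32-bit semantics with double rounding (a hedged sketch based on the TOSA spec, not IREE's actual lowering code):

```python
def apply_scale_32(value: int, multiplier: int, shift: int,
                   double_round: bool) -> int:
    # Round-half-up constant for the arithmetic right shift.
    round_const = 1 << (shift - 1)
    # Double rounding nudges the constant by 2^30, sign-dependently,
    # when the shift is large enough for it to matter.
    if double_round and shift > 31:
        round_const += (1 << 30) if value >= 0 else -(1 << 30)
    return (value * multiplier + round_const) >> shift

# multiplier=1092190289, shift=40, double_round=True reproduces the two i64
# rounding constants from the first repro: 550829555712 and 548682072064.
```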

@hanhanW hanhanW assigned rsuderman and unassigned rsuderman Apr 7, 2022
@hanhanW
Contributor

hanhanW commented Apr 13, 2022

I took a quick look and found that there is a stack buffer. The stack buffer is needed because this is a quantized matmul and it isn't vectorized.

I have a local "always-vectorize" patch that fixes the issue. This makes me suspect the problem is in bufferization, because there is no stack buffer with "always-vectorize". However, the IR looks good to me; I don't know the root cause yet.

@hanhanW
Contributor

hanhanW commented Apr 13, 2022

I understand the issue now. It looks like an upstream bug: the CSE pass (after tileAndFuse) hoists the linalg.fill op out of the loop, which changes the initialization behavior.

The left hand side is correct, and the right hand side is incorrect:

[screenshot: side-by-side IR comparison]

which results in

[screenshot: the resulting IR]

The initialization no longer happens right before the matmul.

@gysit @matthias-springer any ideas how to fix it?

(I'll send out the "always-vectorize" patch, which should fix the correctness issue, but the upstream bug still needs to be fixed.)
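In Python terms, the failure mode looks roughly like this (a hypothetical model of the hoisted fill plus in-place buffer reuse, not actual IREE output; assumes m is a multiple of the tile size):

```python
import numpy as np

def tiled_matmul_buggy(a, b, tile=2):
    m, k = a.shape
    _, n = b.shape
    out = np.empty((m, n), dtype=np.int32)
    acc = np.zeros((tile, n), dtype=np.int32)  # the "fill", hoisted out of the loop
    for i in range(0, m, tile):
        # BUG: acc still holds the previous tile's result; the zero-fill
        # should happen here, right before accumulating this tile.
        for kk in range(k):
            acc += np.outer(a[i:i + tile, kk], b[kk, :]).astype(np.int32)
        out[i:i + tile, :] = acc  # later tiles include earlier tiles' sums
    return out
```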

@hanhanW
Contributor

hanhanW commented Apr 13, 2022

@gysit
Contributor

gysit commented Apr 13, 2022

That is an interesting case!

Let me share my initial understanding. There is more running than just CSE; I would guess the change is a combination of canonicalization and hoisting. Additionally, it seems the change is actually correct, but it triggers a problem in bufferization afterwards.

What canonicalization does is folding:

%19 = linalg.init_tensor [%16, 13] : tensor<?x13xi32>
// ...
%34 = tensor.extract_slice %19[%arg2, %arg4] [%33, %c1] [1, 1] : tensor<?x13xi32> to tensor<?x?xi32>
%35 = linalg.fill {__internal_linalg_transform__ = "1"} ins(%c0_i32 : i32) outs(%34 : tensor<?x?xi32>) -> tensor<?x?xi32>

into

%23 = linalg.init_tensor [%22, 1] : tensor<?x1xi32>
%24 = linalg.fill {__internal_linalg_transform__ = "1"} ins(%c0_i32 : i32) outs(%23 : tensor<?x1xi32>) -> tensor<?x1xi32>

After folding the extract_slice of the init_tensor into a smaller init_tensor, the code snippet no longer depends on the inner loop, so the fill hoists out of the inner loop. That seems correct?

It seems, though, that bufferization afterwards bufferizes the matmul output / fill result in place (@matthias-springer and I discussed this, and he already has a smaller repro). Fixing this requires extending the bufferization analysis and likely a larger change.

@matthias-springer
Contributor

I'm working on a fix for the bufferization. I have an idea of how to fix it; will keep you posted.

@matthias-springer
Contributor

I have a potential fix: https://reviews.llvm.org/D123791

@hanhanW
Contributor

hanhanW commented Apr 15, 2022

Verified that it's fixed in #8888.

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
I0415 12:17:54.249106 140119498839296 test_util.py:82] Setting up iree runtime
I0415 12:17:54.284276 140119498839296 test_util.py:131] Setting up test inputs
I0415 12:17:54.488303 140119498839296 test_util.py:134] Invoking TFLite
I0415 12:17:55.265557 140119498839296 test_util.py:96] Invocation time: 0.7769 seconds
I0415 12:17:55.265695 140119498839296 test_util.py:137] Invoke IREE
I0415 12:17:55.326046 140119498839296 test_util.py:111] Invocation time: 0.0602 seconds
I0415 12:17:55.326839 140119498839296 test_util.py:72] Max error (0): 0.465354
[       OK ] MobilenetV2Int8Test.test_compile_tflite
I0415 12:18:58.260412 140671293609216 test_util.py:82] Setting up iree runtime
I0415 12:18:58.325455 140671293609216 test_util.py:131] Setting up test inputs
I0415 12:18:58.554241 140671293609216 test_util.py:134] Invoking TFLite
I0415 12:18:59.649587 140671293609216 test_util.py:96] Invocation time: 1.0946 seconds
I0415 12:18:59.649970 140671293609216 test_util.py:137] Invoke IREE
I0415 12:18:59.728529 140671293609216 test_util.py:111] Invocation time: 0.0783 seconds
I0415 12:18:59.729651 140671293609216 test_util.py:72] Max error (0): 0.407153
[       OK ] MobilenetV35Int8Test.test_compile_tflite
I0415 12:19:37.313293 140295578084608 test_util.py:82] Setting up iree runtime
I0415 12:19:37.387068 140295578084608 test_util.py:131] Setting up test inputs
I0415 12:19:37.387548 140295578084608 test_util.py:57]  [  1 224 224   3], float32
I0415 12:19:37.387740 140295578084608 test_util.py:134] Invoking TFLite
I0415 12:19:47.684582 140295578084608 test_util.py:96] Invocation time: 10.2962 seconds
I0415 12:19:47.685064 140295578084608 test_util.py:137] Invoke IREE
I0415 12:19:48.200058 140295578084608 test_util.py:111] Invocation time: 0.5147 seconds
I0415 12:19:48.201171 140295578084608 test_util.py:72] Max error (0): 0.438422
[       OK ] ResNet50Int8Test.test_compile_tflite

@hanhanW hanhanW closed this as completed Apr 15, 2022