
[CPU] Understand why IREE is 2x slower than TFLite on ViT INT8 on ARM64 #15399

Closed
mariecwhite opened this issue Nov 3, 2023 · 57 comments
Labels
bug 🐞 Something isn't working p2


@mariecwhite
Contributor

What happened?

On Pixel 8 Pro CPU, IREE latency on ViT is 236ms whereas TFLite is 118ms. Let's understand why.

Steps to reproduce your issue

Download https://storage.googleapis.com/iree-model-artifacts/tflite/tflite_models_1698315913/VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8/tosa.mlirbc

Build a version of IREE with #15387 patched.

Compile for Android

iree-compile tosa.mlirbc \
    --iree-hal-target-backends=llvm-cpu \
    --iree-input-type="tosa" \
    --iree-input-demote-f64-to-f32=false \
    --iree-input-demote-i64-to-i32=false \
    --iree-input-promote-bf16-to-f32=false \
    --iree-llvmcpu-debug-symbols=true \
    --iree-vm-bytecode-module-strip-source-map=true \
    --iree-vm-emit-polyglot-zip=false \
    --iree-llvmcpu-target-cpu="cortex-a715" \
    --iree-llvmcpu-target-triple="aarch64-none-linux-android33" \
    --iree-opt-data-tiling \
    --iree-llvmcpu-enable-microkernels \
    -o vit.vmfb

Run on device:

taskset 1F0 iree-benchmark-module --module=vit.vmfb --task_topology_group_count=5 --task_topology_cpu_ids=0,1,2,3,4 --device=local-task --function=main --input=1x3x224x224xi8=0

What component(s) does this issue relate to?

Compiler

Version information

d32d8ce

Additional context

No response

@mariecwhite mariecwhite added the bug 🐞 Something isn't working label Nov 3, 2023
@mariecwhite
Contributor Author

Note that, with the patch, the compile command ends up using the CPU features +reserve-x18,+bf16,+crc,+dotprod,+flagm,+fp-armv8,+fullfp16,+fp16fml,+i8mm,+lse,+mte,+pauth,+perfmon,+predres,+spe,+ras,+rcpc,+rdm,+sb,+neon,+ssbs,+sve,+sve2-bitperm,+sve2

@bjacob
Contributor

bjacob commented Nov 3, 2023

Interesting. This is a 5-thread benchmark across 2 tiers of cores (taskset 1f0).
How does it look on 1 thread with a taskset that selects 1 tier of cores (e.g. taskset 100 or taskset f0)?
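For reference, taskset masks are hexadecimal bitmasks in which bit i selects CPU i. The short Python sketch below is a helper written only for illustration (it is not part of any tool mentioned in this thread); it decodes the masks used here and says nothing about which core type each CPU id corresponds to on the device.

```python
def cpus_from_taskset_mask(mask_hex):
    """Return the CPU ids whose bits are set in a taskset-style hex mask."""
    mask = int(mask_hex, 16)
    cpus, cpu = [], 0
    while mask:
        if mask & 1:  # bit `cpu` is set, so that CPU is allowed
            cpus.append(cpu)
        mask >>= 1
        cpu += 1
    return cpus


for mask in ("1F0", "F0", "100"):
    print(mask, cpus_from_taskset_mask(mask))
```

For example, 1F0 decodes to CPUs 4 through 8, while F0 and 100 each select a subset of those (CPUs 4-7 and CPU 8, respectively).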

@mariecwhite
Copy link
Contributor Author

On 1 thread with taskset 100, IREE has latency 730ms and TFLite is 196ms.

I ran with 1, 4, and 5 threads, using tasksets 100, F0, and 1F0 respectively:

TFLite:

Model                                           Threads  Latency (ms)
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   1        196.9
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   4        156.7
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   5        118.5

IREE:

Model                                           Threads  Latency (ms)
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   1        729
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   4        330
VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8   5        236
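A quick way to read the two tables side by side is to compute the IREE/TFLite latency ratio per thread count, plus each runtime's speedup over its own 1-thread latency. The sketch below only does arithmetic on the numbers reported above.

```python
# Latencies (ms) from the two tables above, keyed by thread count.
TFLITE_MS = {1: 196.9, 4: 156.7, 5: 118.5}
IREE_MS = {1: 729.0, 4: 330.0, 5: 236.0}

for threads in sorted(TFLITE_MS):
    ratio = IREE_MS[threads] / TFLITE_MS[threads]
    print(f"{threads} thread(s): IREE/TFLite latency ratio = {ratio:.2f}")

# Each runtime's speedup relative to its own 1-thread latency.
for name, lat in (("TFLite", TFLITE_MS), ("IREE", IREE_MS)):
    scaling = {t: round(lat[1] / lat[t], 2) for t in sorted(lat)}
    print(name, scaling)
```

With these numbers the gap is about 3.7x on 1 thread and about 2x on 5 threads.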

@qcolombet
Contributor

Do you have a profile handy, to see if we can narrow down the issue to a few dispatches without having to run/reproduce the full thing?

@qcolombet
Contributor

What are the steps to run with TFLite?

@mariecwhite
Contributor Author

I'll try and get Tracy up and running in the next couple of days and circle back with a profile.

@mariecwhite
Contributor Author

Setting up Tracy ended up being straightforward. Here is the trace: https://storage.googleapis.com/iree-shared-files/mariewhite/profiles/vit.tracy. Not sure what to look for here, but it looks like Softmax and 3 matmuls take the longest time.

@qcolombet
Contributor

As far as I can tell, the 3 heaviest matmuls in the profile that you share, @mariecwhite, use µkernels.

At this point I don't know if our implementation for these kernels is inferior to what TFLite uses or if we just spent our time differently. Ditto for the softmax.

Could you share the repro steps with TFLite and/or a profile for TFLite?

@mariecwhite
Contributor Author

Here is a nice article on performance measurement of TFLite models, including ways to profile them: https://www.tensorflow.org/lite/performance/measurement

In that article is a link to the prebuilt benchmark tool for Android ARM: https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/android_aarch64_benchmark_model

Once downloaded, adb push the tool to the device. Also download and push the TFLite flatbuffer: https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.20_1699872537/VIT_CLASSIFICATION_JAX_3X224X224XF32/model_int8.tflite

Then on the device run:

./android_aarch64_benchmark_model --graph=<path to tflite model> --num_threads=1

@mariecwhite
Contributor Author

You can also quickly get a profile by adding --enable_op_profiling=true to the benchmark run.

@qcolombet
Contributor

Reporting a few facts here, not much analysis yet:

  • In this case, TFLite uses XNNPack ukernels
  • GEMM kernels account for ~75% of the runtime, of which 85% comes from just 2 GEMM kernels (instead of 3 for us).
  • Softmax barely registers: it accounts for about 4% of runtime, which suggests we are probably doing a really bad job on it.
Number of nodes executed: 203
============================== Summary by node type ==============================
	                             [Node type]	  [count]	  [avg ms]	    [avg %]	    [cdf %]	  [mem KB]	[times called]
	          Fully Connected (NC, QS8) GEMM	        5	   140.139	    74.294%	    74.294%	     0.000	       73
	                                     SUB	       33	    18.728	     9.929%	    84.223%	     0.000	       33
	                            BATCH_MATMUL	       24	     8.169	     4.331%	    88.554%	     0.000	       24
	                                 SOFTMAX	       12	     7.513	     3.983%	    92.537%	     0.000	       12
	                           Add (ND, QS8)	        5	     4.527	     2.400%	    94.937%	     0.000	       62
	                      Multiply (ND, QS8)	        6	     3.387	     1.796%	    96.732%	     0.000	      150
	           Convolution (NHWC, QS8) IGEMM	        1	     2.241	     1.188%	    97.920%	     0.000	        1
	                                     SUM	       50	     2.025	     1.074%	    98.994%	     0.000	       50
	            Transpose (ND, X8) Transpose	        5	     1.033	     0.548%	    99.541%	     0.000	       49
	                      Subtract (ND, QS8)	        3	     0.375	     0.199%	    99.740%	     0.000	       17
	                                   RSQRT	       25	     0.236	     0.125%	    99.865%	     0.000	       25
	                           Copy (NC, X8)	        6	     0.133	     0.071%	    99.936%	     0.000	      100
	                                     ADD	       25	     0.054	     0.029%	    99.964%	     0.000	       25
	                  Convert (NC, F32, QS8)	        1	     0.037	     0.020%	    99.984%	     0.000	        1
	                          Slice (ND, X8)	        1	     0.028	     0.015%	    99.999%	     0.000	        1
	                  Convert (NC, QS8, F32)	        1	     0.002	     0.001%	   100.000%	     0.000	        1

@qcolombet
Contributor

Interestingly disabling XNNPack yields pretty much the same performance (176 ms vs. 173 ms; min of 50 runs).

============================== Summary by node type ==============================
	                             [Node type]	  [count]	  [avg ms]	    [avg %]	    [cdf %]	  [mem KB]	[times called]
	                         FULLY_CONNECTED	       73	   111.416	    63.090%	    63.090%	     0.000	       73
	                                     SUB	       50	    31.273	    17.708%	    80.798%	     0.000	       50
	                                     ADD	       87	     8.722	     4.939%	    85.737%	     0.000	       87
	                            BATCH_MATMUL	       24	     7.863	     4.452%	    90.190%	     0.000	       24
	                                 SOFTMAX	       12	     7.318	     4.144%	    94.333%	     0.000	       12
	                                     MUL	      150	     5.400	     3.058%	    97.391%	     0.000	      150
	                                     SUM	       50	     2.075	     1.175%	    98.566%	     0.000	       50
	                                 CONV_2D	        1	     1.278	     0.724%	    99.290%	     0.000	        1
	                               TRANSPOSE	       49	     0.802	     0.454%	    99.744%	     0.000	       49
	                                   RSQRT	       25	     0.202	     0.114%	    99.858%	     0.000	       25
	                                 RESHAPE	       99	     0.184	     0.104%	    99.963%	     0.000	       99
	                                QUANTIZE	        1	     0.038	     0.022%	    99.984%	     0.000	        1
	                           STRIDED_SLICE	        1	     0.018	     0.010%	    99.994%	     0.000	        1
	                           CONCATENATION	        1	     0.009	     0.005%	    99.999%	     0.000	        1
	                              DEQUANTIZE	        1	     0.001	     0.001%	   100.000%	     0.000	        1

@qcolombet
Contributor

These numbers are with just 1 thread.

@dcaballe
Contributor

That's interesting! I saw softmax in the profiles of other models, but it was too far from the top to be worth investing in. We do know there are some improvements to be made, though: #15210 (comment). Some of these issues seem specific to IREE, perhaps due to changes needed on the GPU side. We can work on that. Feel free to create issues for this!

Regarding matmuls, it's very likely that DT+UK (data tiling plus ukernels) is just not enough for these large matmuls. We probably need multiple levels of tiling targeting the different cache levels...

@stellaraccident
Collaborator

If looking for a 2x, ignore the inefficient 4% thing for now?

@bjacob knows the XNNPACK and ukernel story well. Unless something has changed, I don't think this is doing anything particularly advanced; it probably just needs some catch-up. Any advice to keep the analysis on the most profitable path, Benoit?

@bjacob
Contributor

bjacob commented Nov 17, 2023

If looking for a 2x, ignore the inefficient 4% thing for now?

Yes. Since we are looking for a 2x, and the XNNPACK and non-XNNPACK profiles show 74% and 63% of time, respectively, spent in FULLY_CONNECTED, let's focus on that.
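The reasoning here is Amdahl's law: if a fraction f of the runtime is made s times faster, the overall speedup is 1 / ((1 - f) + f / s). A minimal Python sketch, using the approximate 4% (softmax) and 74% (FULLY_CONNECTED, XNNPACK profile) fractions from the TFLite profiles above purely as illustrative inputs:

```python
def amdahl_speedup(fraction, local_speedup):
    """Overall speedup when `fraction` of the runtime is made `local_speedup` times faster."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


def required_local_speedup(fraction, target):
    """Local speedup needed on `fraction` of the runtime to reach an overall `target`."""
    remaining = 1.0 / target - (1.0 - fraction)
    if remaining <= 0:
        return float("inf")  # unreachable even if that part became free
    return fraction / remaining


print(round(amdahl_speedup(0.04, float("inf")), 3))  # → 1.042 (softmax-sized slice made free)
print(round(amdahl_speedup(0.74, float("inf")), 3))  # → 3.846 (FULLY_CONNECTED-sized slice made free)
print(required_local_speedup(0.04, 2.0))  # → inf
```

For fractions of this size, even making the softmax slice free gains only about 4%, while a 2x overall gain is reachable only through the matmul slice.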

Getting us this far was what the coarse-grained TFLite "op profiling" was good for. Now that we know we're down to profiling inside matrix multiplication implementations, we need another profiler.

You could always skip straight to simpleperf, but since you observe the same performance with and without XNNPACK, and the non-XNNPACK run actually spends less time in FULLY_CONNECTED (implying it has a more optimized implementation of it), maybe let's focus on that. Conveniently, that more optimized implementation is going to be ruy, and there is a dedicated instrumentation profiler that can give us more information here.

Rebuild TFLite (specifically that :benchmark_model binary) with this Bazel command-line flag: --define=ruy_profiler=true. Then just rerun the benchmark (again, with XNNPACK disabled). This should print to stdout a treeview profile of matmuls. This would tell us in particular the matmul shapes, the SIMD code paths taken, the number of worker threads effectively used, and the breakdown of % of time spent in packing vs the arithmetic kernel.

@bjacob
Contributor

bjacob commented Nov 17, 2023

I have this old script to break down TOSA ops/shapes... here is what it says about this model:

   1196 tosa.rescale
    393 tosa.reshape
    186 tosa.mul
     99 tosa.add
     74 tosa.sub
     74 tosa.conv_quant
     73 tosa.fully_connected
            breakdown by output type:
                60 tensor<197x768xi32>
                12 tensor<197x3072xi32>
                 1 tensor<1x1000xi32>
     62 tosa.reduce_sum
     61 tosa.transpose
     25 tosa.table
     24 tosa.matmul_quant
     24 tosa.matmul
     24 tosa.cast
     12 tosa.reduce_max
     12 tosa.reciprocal
     12 tosa.exp
      1 tosa.slice
      1 tosa.conv2d
            breakdown by filter:
                 1 (kernel 16x16)
            breakdown by dilation:
                 1 (dilation 1x1)
      1 tosa.concat

So these are pretty normal i8 x i8 -> i32 matmul shapes.

On the IREE side:

  1. Could you compile with --iree-hal-dump-executable-intermediates-to= and share the resulting .s file? Just to confirm this is using the right ukernel path (we should see smmla instructions, since you enabled +i8mm, which is supported on Pixel 8 and was part of the CPU features enumerated above).
  2. Could you get a Tracy profile, ideally with sampling and the full source mapping working (doc). It's OK to elide weights to make the source mapping lighter if needed.
    • I would like to see if the % of time we spend in packing is comparable to what Ruy gets.

@bjacob
Contributor

bjacob commented Nov 17, 2023

Also, it's interesting that Ruy is 2x faster than IREE even though Ruy only knows how to use the +dotprod extension, not +i8mm, which would suggest it should actually be 2x slower.
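The "should be 2x slower" expectation comes from how much work each instruction does. A 128-bit SDOT computes four int32 lanes, each accumulating a 4-element int8 dot product (16 multiply-accumulates), while SMMLA multiplies a 2x8 int8 matrix by an 8x2 int8 matrix into a 2x2 int32 tile (32 multiply-accumulates). The Python sketch below models only the arithmetic semantics of the two instructions (not their register layouts) and counts MACs; turning that into a 2x throughput claim additionally assumes both instructions issue at the same rate on this core, which is not measured here.

```python
# Arithmetic semantics (not register layout) of two AArch64 int8 instructions,
# counting multiply-accumulates (MACs).

def sdot(acc4, a16, b16):
    """SDOT Vd.4S, Vn.16B, Vm.16B: each int32 lane adds a 4-element int8 dot product."""
    out, macs = list(acc4), 0
    for lane in range(4):
        for k in range(4):
            out[lane] += a16[4 * lane + k] * b16[4 * lane + k]
            macs += 1
    return out, macs


def smmla(acc2x2, a2x8, b8x2):
    """SMMLA Vd.4S, Vn.16B, Vm.16B: 2x2 int32 tile += (2x8 int8) @ (8x2 int8)."""
    out, macs = [row[:] for row in acc2x2], 0
    for i in range(2):
        for j in range(2):
            for k in range(8):
                out[i][j] += a2x8[i][k] * b8x2[k][j]
                macs += 1
    return out, macs


_, sdot_macs = sdot([0] * 4, [1] * 16, [1] * 16)
_, smmla_macs = smmla([[0, 0], [0, 0]], [[1] * 8 for _ in range(2)], [[1, 1] for _ in range(8)])
print(sdot_macs, smmla_macs, smmla_macs / sdot_macs)  # → 16 32 2.0
```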

@qcolombet
Contributor

Here is the assembly: vit_iree_ukernels.s.txt.

For tracy, I have to rework my settings, I don't get the full source mapping.

@qcolombet
Contributor

Here's the tracy profile
vit_feat.tracy.tgz

@qcolombet
Copy link
Contributor

@mariecwhite could you try to build the Android benchmark utility with the option that @bjacob recommended (--define=ruy_profiler=true)?
(Or point out what I'm doing wrong?)

I don't know what I'm doing wrong, but bazel fails pretty fast for me:

$ bazel build -c opt  --config=android_arm64    tensorflow/lite/tools/benchmark/android:benchmark_model
[...]
ERROR: [...]/.cache/bazel/_bazel_qcolombet/cfcb28938d908500dfc3e05e7d22907a/external/hexagon_nn/BUILD.bazel:62:11: Compiling hexagon/rpcmem_stub.c failed: undeclared inclusion(s) in rule '@hexagon_nn//:rpcmem':
this rule is missing dependency declarations for the following files included by 'hexagon/rpcmem_stub.c':
  'external/androidndk/toolchains/llvm/prebuilt/linux-x86_64/lib/clang/17/include/stdint.h'
  'external/androidndk/toolchains/llvm/prebuilt/linux-x86_64/lib/clang/17/include/stddef.h'
  'external/androidndk/toolchains/llvm/prebuilt/linux-x86_64/lib/clang/17/include/__stddef_max_align_t.h'
  'external/androidndk/toolchains/llvm/prebuilt/linux-x86_64/lib/clang/17/include/stdarg.h'
Target //tensorflow/lite/tools/benchmark/android:benchmark_model failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 0.637s, Critical Path: 0.35s
INFO: 129 processes: 129 internal.
FAILED: Build did NOT complete successfully

I configured my workspace with the following environment:

    ANDROID_BUILD_TOOLS_VERSION=34.0.0
    ANDROID_NDK_API_LEVEL=26
    ANDROID_NDK_HOME=/usr/local/google/home/qcolombet/workdir/NDK/android-ndk-r26b
    ANDROID_NDK_VERSION=26
    ANDROID_SDK_API_LEVEL=34
    ANDROID_SDK_HOME=/usr/local/google/home/qcolombet/Android/Sdk
    CLANG_COMPILER_PATH=/usr/local/google/home/qcolombet/workdir/NDK/android-ndk-r26b/toolchains/llvm/prebuilt/linux-x86_64/bin/clang-17

@mariecwhite
Contributor Author

I ran into the same error. The configure script does warn that NDK version 26 is not supported. Using NDK version 25 and SDK version 33 works, so I suggest downloading them, running ./configure, running bazel clean --expunge, and then building the TFLite benchmark tool.

@qcolombet
Contributor

Ah thanks @mariecwhite!
For some reason, I missed the warning message.

Trying with the older versions.

@qcolombet
Contributor

Thanks again @mariecwhite that fixed it.

@bjacob here is the output of the benchmarking tool when built with --define=ruy_profiler=true.

With 1 thread:

$ adb shell "/data/local/tmp/benchmark_model --graph=/data/local/tmp/model_int8.tflite --num_threads=1 --use_xnnpack=false"
[...]
Profile (1 threads):

Thread 0 (7605 samples)

* 66.02% FullyConnectedInt8/8bit
  * 66.01% cpu_backend_gemm::Gemm
    * 65.97% Mul
      * 22.37% matmul shape: 768x768x197
        * 22.37% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
          * 22.37% TrMulImpl, general case
            * 20.97% Kernel (kNeonDotprod)
            * 1.22% Pack (kNeonDotprod)
            * 0.16% [other]
            * 0.01% MakeBlockMap
      * 21.97% matmul shape: 768x3072x197
        * 21.97% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
          * 21.97% TrMulImpl, general case
            * 19.72% Kernel (kNeonDotprod)
            * 2.12% Pack (kNeonDotprod)
            * 0.13% [other]
      * 21.63% matmul shape: 3072x768x197
        * 21.63% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
          * 21.63% TrMulImpl, general case
            * 19.91% Kernel (kNeonDotprod)
            * 1.53% Pack (kNeonDotprod)
            * 0.20% [other]
    * 0.04% cpu_backend_gemm::Gemm: CustomGemv
  * 0.01% [other]
* 16.86% BroadcastQuantSubSlow/T
* 5.21% cpu_backend_gemm::Gemm
  * 5.21% Mul
    * 3.18% matmul shape: 197x64x197
      * 3.17% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
        * 3.16% TrMulImpl, simple loop
          * 2.97% Kernel (kNeonDotprod)
          * 0.11% [other]
          * 0.04% Pack (kNeonDotprod, from row-major)
          * 0.04% Pack (kNeonDotprod)
        * 0.01% [other]
      * 0.01% [other]
    * 2.02% matmul shape: 197x197x64
      * 2.01% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
        * 2.01% TrMulImpl, general case
          * 1.68% Kernel (kNeonDotprod)
          * 0.22% Pack (kNeonDotprod)
          * 0.11% [other]
      * 0.01% [other]
* 3.75% SoftmaxInt8LUT
* 3.54% AddElementwiseInt8/8bit
* 1.81% MulInt8/8bit
  * 1.81% MulElementwiseInt8/8bit
* 1.14% BroadMulSimpleBroadcastInt8/8bit
* 1.04% AddInt8/8bit
  * 1.04% AddElementwiseInt8/8bit
* 0.49% Conv/8bit
  * 0.46% cpu_backend_gemm::Gemm
    * 0.46% Mul
      * 0.46% matmul shape: 768x768x196
        * 0.46% TrMul (Path=0x20, max_num_threads=1, is_prepacked=(0,0))
          * 0.46% TrMulImpl, general case
            * 0.42% Kernel (kNeonDotprod)
            * 0.03% Pack (kNeonDotprod)
            * 0.01% [other]
  * 0.03% Im2col
* 0.13% Sum
  * 0.08% QuantizedSum
  * 0.05% [other]
* 0.01% Quantize/Int8
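As a rough cross-check of these shapes, the sketch below counts multiply-accumulates (MACs) per shape group. The MAC count of a matmul is the product of its three dimensions, so the order in which ruy prints them does not matter. The per-shape call counts (48, 12, 12) are an assumption: they are consistent with the 60 + 12 fully_connected ops in the TOSA breakdown above, but the profiler did not report them. The 111.416 ms figure is the FULLY_CONNECTED total from the earlier 1-thread, non-XNNPACK op profile, i.e. a separate run.

```python
import math

# Shapes printed by the ruy profiler above.
SHAPES = {
    "768x768x197": (768, 768, 197),
    "768x3072x197": (768, 3072, 197),
    "3072x768x197": (3072, 768, 197),
}
# ASSUMPTION: calls per inference for each shape (not reported by the profiler).
CALLS = {"768x768x197": 48, "768x3072x197": 12, "3072x768x197": 12}

# FULLY_CONNECTED time from the 1-thread, non-XNNPACK op profile (separate run).
FC_MS = 111.416

group_macs = {name: math.prod(dims) * CALLS[name] for name, dims in SHAPES.items()}
total_macs = sum(group_macs.values())

for name, macs in group_macs.items():
    print(f"{name}: {macs / 1e9:.2f} GMAC")
print(f"total: {total_macs / 1e9:.2f} GMAC, ~{total_macs / (FC_MS / 1e3) / 1e9:.0f} GMAC/s")
```

Under that assumption each group performs the same ~5.58 GMAC, which fits the near-equal time shares (about 22% each) in the profile, and the implied single-thread throughput is on the order of 150 GMAC/s.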

And with 5 threads:

$ adb shell "/data/local/tmp/benchmark_model --graph=/data/local/tmp/model_int8.tflite --num_threads=5 --use_xnnpack=false"
[...]
Profile (5 threads):

Thread 0 (4798 samples)

* 46.52% FullyConnectedInt8/8bit
  * 46.52% cpu_backend_gemm::Gemm
    * 46.50% Mul
      * 16.69% matmul shape: 768x768x197
        * 16.69% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
          * 16.69% TrMulImpl, general case
            * 11.13% Kernel (kNeonDotprod)
            * 4.11% [other]
            * 1.42% Pack (kNeonDotprod)
            * 0.04% MakeBlockMap
      * 15.21% matmul shape: 3072x768x197
        * 15.21% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
          * 15.21% TrMulImpl, general case
            * 9.63% Kernel (kNeonDotprod)
            * 4.38% [other]
            * 1.21% Pack (kNeonDotprod)
      * 14.59% matmul shape: 768x3072x197
        * 14.59% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
          * 14.59% TrMulImpl, general case
            * 10.44% Kernel (kNeonDotprod)
            * 3.02% [other]
            * 1.13% Pack (kNeonDotprod)
    * 0.02% cpu_backend_gemm::Gemm: CustomGemv
* 27.82% BroadcastQuantSubSlow/T
* 6.77% cpu_backend_gemm::Gemm
  * 6.67% Mul
    * 3.52% matmul shape: 197x197x64
      * 3.50% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
        * 3.48% TrMulImpl, general case
          * 1.73% [other]
          * 1.35% Kernel (kNeonDotprod)
          * 0.33% Pack (kNeonDotprod)
          * 0.04% GetBlockMatrixCoords
          * 0.02% MakeBlockMap
        * 0.02% [other]
      * 0.02% [other]
    * 3.13% matmul shape: 197x64x197
      * 3.11% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
        * 3.11% TrMulImpl, general case
          * 1.83% [other]
          * 1.10% Kernel (kNeonDotprod)
          * 0.10% Pack (kNeonDotprod)
          * 0.06% Pack (kNeonDotprod, from row-major)
      * 0.02% [other]
    * 0.02% [other]
  * 0.10% [other]
* 6.25% AddElementwiseInt8/8bit
* 6.19% SoftmaxInt8LUT
* 2.79% MulInt8/8bit
  * 2.79% MulElementwiseInt8/8bit
* 1.94% AddInt8/8bit
  * 1.94% AddElementwiseInt8/8bit
* 1.04% BroadMulSimpleBroadcastInt8/8bit
* 0.40% Sum
  * 0.40% QuantizedSum
* 0.21% Conv/8bit
  * 0.21% cpu_backend_gemm::Gemm
    * 0.21% Mul
      * 0.21% matmul shape: 768x768x196
        * 0.21% TrMul (Path=0x20, max_num_threads=5, is_prepacked=(0,0))
          * 0.21% TrMulImpl, general case
            * 0.15% [other]
            * 0.06% Kernel (kNeonDotprod)
* 0.06% Quantize/Int8

Thread 1 (1850 samples)

* 96.00% Kernel (kNeonDotprod)
* 3.95% Pack (kNeonDotprod)
* 0.05% GetBlockMatrixCoords

Thread 2 (1940 samples)

* 95.00% Kernel (kNeonDotprod)
* 4.85% Pack (kNeonDotprod)
* 0.15% GetBlockByIndex

Thread 3 (1918 samples)

* 95.83% Kernel (kNeonDotprod)
* 4.01% Pack (kNeonDotprod)
* 0.10% GetBlockMatrixCoords
* 0.05% Pack (kNeonDotprod, from row-major)

Thread 4 (2037 samples)

* 95.53% Kernel (kNeonDotprod)
* 4.37% Pack (kNeonDotprod)
* 0.10% GetBlockByIndex

@qcolombet
Contributor

@mariecwhite, I just realized that we were not using the same CPU features.
I had:
+v9a,+fullfp16,+fp-armv8,+neon,+aes,+sha2,+crc,+lse,+rdm,+complxnum,+rcpc,+sha3,+sm4,+dotprod,+fp16fml,+dit,+flagm,+ssbs,+sb,+altnzcv,+fptoint,+bf16,+i8mm,+bti

and you have:
+reserve-x18,+bf16,+crc,+dotprod,+flagm,+fp-armv8,+fullfp16,+fp16fml,+i8mm,+lse,+mte,+pauth,+perfmon,+predres,+spe,+ras,+rcpc,+rdm,+sb,+neon,+ssbs,+sve,+sve2-bitperm,+sve2

I tried to use your target features and I hit a compiler error:

<unknown>:0: error: 'arith.extsi' op requires the same shape for all operands and results
<unknown>:0: note: see current operation: %46 = "arith.extsi"(%42) : (vector<16x[32]xi8>) -> vector<16x32xi32>

Any idea what I may be missing?

@mariecwhite
Contributor Author

We did find that using any of the sve flags resulted in compiler errors for many workloads. The flags used in this test were auto-generated from this patch, which we've found to be infeasible. There is an ongoing task to figure out exactly which feature flags to include for Pixel 8 Pro, but I think the ones you're using are fine.

@bjacob
Contributor

bjacob commented Nov 22, 2023

Yes, avoid +sve and +sve2 as they disable data-tiling ( @banach-space FYI ). No need to specify +reserve-x18, IREE does that internally.

Thanks for gathering the detailed profiles above. Now the next step is to compare that to Tracy profiles of how IREE is spending its time. I see the tracy trace and disassembly that you shared above 2 days ago, so now the next step is to look into that.

@banach-space
Collaborator

Any idea what I may be missing?

@mariecwhite Could you check whether you have this change:

(see the update in LowerVectorContract.cpp).

We did find that using any of the sve flags resulted in compiler errors for many workloads.

sve and "scalable vectors" support is new and still incomplete. We are slowly patching various bits that make invalid assumptions about "scalability", but it will take a few more months. ATM I wouldn't use sve for anything beyond a plain Matmul.

Having said that, bug reports and repros are very welcome :)

@qcolombet
Contributor

The input TFLite model has the batch matmul and so has the converted tosa IR.
This is probably just a matter of how the numbers are reported.

Anyhow, looking at the TFLite graph, I see only one quantize and one dequantize node for the full graph, whereas in the TOSA input we have such nodes (tosa.rescale) between every single operation. In particular, we do this i32 -> i8 -> i32 dance between every single computation, and I wonder if TFLite has to deal with that too.
(And this is not something we can technically optimize away since these steps could affect the precision).

@qcolombet
Contributor

Nah, matmul vs batch_matmul doesn't make a performance difference in itself.

However I have just checked out the tracy profile now, and this is the top dispatches:

[Screenshot: top dispatches in the Tracy profile]

The quantized_batch_matmul dispatches here are suspicious. We have an early rewrite of quantized_matmul to matmul which is key to letting quantized_matmul ride our optimized matmul paths.

https://github.com/openxla/iree/blob/1f6c162fb000c4cd983637eb68eb6cda7a9ec75b/compiler/src/iree/compiler/InputConversion/Common/QuantizedMatmulToMatmul.cpp#L37

It only handles quantized_matmul, not quantized_batch_matmul.

Ah thanks!
I'll take a look.

@bjacob
Contributor

bjacob commented Nov 24, 2023

Anyhow, looking at the TFLite graph I see only one quantize and one dequantize node for the full graph whereas in the tosa input, we have such nodes (tosa.rescale) between every single operations. In particular we do this i32 -> i8 -> i32 dance between every single computation and I wonder if TFLite has to deal with that too.

TFLite does avoid that: it keeps i32 accumulators an internal detail of its fused quantized-matmul-and-rescale kernels (with both the ruy and XNNPACK backends), held in registers. So that's one way in which TFLite is going to be fundamentally more efficient here. However, that won't explain a 2x difference by itself. I think the bulk of the 2x gap has to be explained by something else, such as quantized_batch_matmul not being picked up by the above-mentioned rewrite and thus staying on a default naive lowering path.
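To make the fusion point concrete: in the unfused form, the whole int32 result is materialized and then rescaled in a second pass; in the fused form, each int32 accumulator is rescaled as soon as it is complete, so no int32 tensor is ever stored. This Python sketch is a simplified model, not TFLite's kernel: it assumes symmetric int8 inputs (zero point 0) and uses a float scale with Python's round-half-to-even, whereas real int8 kernels use a fixed-point multiplier and shift.

```python
def requantize(acc, scale, zero_point):
    """Simplified int32 -> int8: scale, round (half to even), add zero point, clamp."""
    q = round(acc * scale) + zero_point
    return max(-128, min(127, q))


def matmul_then_rescale(a, b, scale, zero_point):
    """Unfused: materialize the full int32 result, then rescale it in a second pass."""
    k, n = len(b), len(b[0])
    acc = [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]
    return [[requantize(x, scale, zero_point) for x in row] for row in acc]


def fused_matmul_rescale(a, b, scale, zero_point):
    """Fused: rescale each int32 accumulator as soon as it is done; nothing int32 is stored."""
    k, n = len(b), len(b[0])
    out = []
    for row in a:
        out_row = []
        for j in range(n):
            acc = 0  # lives only in a local here (a register, in a real kernel)
            for p in range(k):
                acc += row[p] * b[p][j]
            out_row.append(requantize(acc, scale, zero_point))
        out.append(out_row)
    return out


a = [[1, -2, 3], [4, 5, -6]]
b = [[7, 8], [-9, 10], [11, -12]]
unfused = matmul_then_rescale(a, b, scale=0.05, zero_point=3)
fused = fused_matmul_rescale(a, b, scale=0.05, zero_point=3)
print(unfused, fused)
```

Both versions produce identical int8 outputs; what the fused form saves is the memory traffic of writing and re-reading the int32 intermediate.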

@bjacob
Contributor

bjacob commented Nov 24, 2023

Yep, it seems like a nice generally useful PR to write to let QuantizedMatmulToMatmul.cpp handle QuantizedBatchMatmul. I'll let your team handle it and return to other things, but let me know if there's any question. There's a comment at the top linking to the paper explaining the math that's going on here and the idea of that rewrite.
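The algebra behind the quantized_matmul -> matmul rewrite, and why it extends to the batch variant: with scalar zero points za and zb, sum_k (A[i,k] - za)(B[k,j] - zb) = sum_k A[i,k]B[k,j] - zb * rowsum(A)[i] - za * colsum(B)[j] + K*za*zb. The expensive part thus becomes a plain integer matmul that can take the optimized path, plus cheap row/column-sum corrections, and for quantized_batch_matmul the same identity holds independently for each batch. The Python sketch below checks this numerically; it illustrates the math only and is not the code in QuantizedMatmulToMatmul.cpp.

```python
import random


def batch_matmul(a, b):
    """Plain integer batch matmul: a is [B][M][K], b is [B][K][N]."""
    return [
        [[sum(ab[i][p] * bb[p][j] for p in range(len(bb))) for j in range(len(bb[0]))]
         for i in range(len(ab))]
        for ab, bb in zip(a, b)
    ]


def quantized_batch_matmul_reference(a, b, za, zb):
    """Definition: subtract the scalar zero points first, then multiply."""
    a0 = [[[x - za for x in row] for row in ab] for ab in a]
    b0 = [[[x - zb for x in row] for row in bb] for bb in b]
    return batch_matmul(a0, b0)


def quantized_batch_matmul_rewritten(a, b, za, zb):
    """Raw batch_matmul plus per-batch row-sum / column-sum corrections."""
    k_dim = len(b[0])
    out = []
    for ab, bb, cb in zip(a, b, batch_matmul(a, b)):
        row_sums = [sum(row) for row in ab]        # sum_k A[i, k]
        col_sums = [sum(col) for col in zip(*bb)]  # sum_k B[k, j]
        out.append([
            [cb[i][j] - zb * row_sums[i] - za * col_sums[j] + k_dim * za * zb
             for j in range(len(col_sums))]
            for i in range(len(row_sums))
        ])
    return out


rng = random.Random(0)
B, M, K, N = 2, 3, 5, 4
a = [[[rng.randint(-128, 127) for _ in range(K)] for _ in range(M)] for _ in range(B)]
b = [[[rng.randint(-128, 127) for _ in range(N)] for _ in range(K)] for _ in range(B)]
za, zb = 7, -3
assert quantized_batch_matmul_reference(a, b, za, zb) == quantized_batch_matmul_rewritten(a, b, za, zb)
print("rewrite matches the reference")
```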

@qcolombet
Contributor

However, that won't explain a 2x difference by itself. I think the bulk of the 2x gap has to be explained by something else, such as quantized_batch_matmul not being picked up by the above-mentioned rewrite and thus staying on a default naive lowering path.

Agree, but that's something that is likely to be more problematic on GPUs if we end up with unfused consumer <- dequantize and/or quantize <- producer, since we'll have to copy the full blown tensor (the i32 version instead of the i8 version) around.
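For a sense of scale, here is the size of one activation tensor at the fully_connected output shapes from the TOSA breakdown above, stored as int8 versus int32. This is a minimal arithmetic sketch; actual dispatches may tile or pad these buffers differently.

```python
def tensor_bytes(shape, bytes_per_element):
    """Size in bytes of a dense tensor with the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_element


for shape in ((197, 768), (197, 3072)):
    i8 = tensor_bytes(shape, 1)
    i32 = tensor_bytes(shape, 4)
    print(f"{shape[0]}x{shape[1]}: int8 {i8:,} B, int32 {i32:,} B ({i32 // i8}x)")
```

The int32 copy is 4x larger, e.g. 605,184 bytes versus 151,296 bytes for 197x768.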

@bjacob
Contributor

bjacob commented Nov 24, 2023

Would anything prevent the fusion (of the consumer rescaling into the producer matmul) on GPU?

The only thing preventing that fusion here on CPU is the use of a ukernel for the matmul. Even then, we have a plan for eventually getting the fusion: get the ukernel bitcode to fully inline (that's already the case on x86, just not yet on Arm, a superficial issue), tile outer loops to size 1 so they vanish, and ensure the consumer rescaling goes in the same dispatch. At that point, there should be nothing preventing late (post-linking) LLVM IR passes from coalescing the i32 stores to memory at the end of the ukernel with their reloads at the start of the rescaling.

@qcolombet
Contributor

qcolombet commented Nov 24, 2023

Would anything prevent the fusion (of the consumer rescaling into the producer matmul) on GPU?

Not that I can think of, but I don't know how IREE does fusion :).

@qcolombet
Contributor

Yep, it seems like a nice generally useful PR to write to let QuantizedMatmulToMatmul.cpp handle QuantizedBatchMatmul. I'll let your team handle it and return to other things, but let me know if there's any question. There's a comment at the top linking to the paper explaining the math that's going on here and the idea of that rewrite.

@LLITCHEV, when you're done with your softmax improvements or if the softmax improvements take too long, could you look at this?

@LLITCHEV
Contributor

@qcolombet I'll follow up once the softmax is done. Thanks!

@hanhanW
Contributor

hanhanW commented Dec 2, 2023

The quantized_batch_matmul is going through CPUDefault, and it's not vectorized. We need to decompose it.

side note:
I think we can have an option to signal an error when

  1. A dispatch is going through CPUDefault, and
  2. It is not a LinalgExt op.
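The proposed check could look roughly like the following. This is a hypothetical Python sketch: DispatchInfo, its fields, the "data_tiled_pipeline" name, and the dispatch names are made up for illustration and are not IREE APIs; only the rule itself (flag a dispatch that goes through CPUDefault and is not a LinalgExt op) comes from the comment above.

```python
from dataclasses import dataclass


@dataclass
class DispatchInfo:
    """HYPOTHETICAL summary of a compiled dispatch (not an IREE data structure)."""
    name: str
    pipeline: str           # lowering pipeline name, e.g. "CPUDefault"
    is_linalg_ext_op: bool  # whether the root op comes from the LinalgExt dialect


def dispatches_to_flag(dispatches):
    """Names of dispatches that go through CPUDefault and are not LinalgExt ops."""
    return [
        d.name
        for d in dispatches
        if d.pipeline == "CPUDefault" and not d.is_linalg_ext_op
    ]


# Made-up example input for illustration only.
example = [
    DispatchInfo("quantized_batch_matmul", "CPUDefault", False),
    DispatchInfo("some_linalg_ext_op", "CPUDefault", True),
    DispatchInfo("mmt4d", "data_tiled_pipeline", False),  # hypothetical pipeline name
]
print(dispatches_to_flag(example))  # → ['quantized_batch_matmul']
```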

@mariecwhite
Contributor Author

The i8mm s8s8s32 ukernel doesn't seem to be effective here. I compiled ViT INT8 with only the +i8mm or +dotprod CPU feature and compared the e2e latencies. dotprod is slightly faster.

Variant   Latency, 1 thread (ms)   Latency, 5 threads (ms)
dotprod   752                      243
i8mm      807                      250

This does not correlate with the microbenchmarks:

-----------------------------------------------------------------------------------------------------------
Benchmark                                                 Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------
BM_mmt4d_s8s8s32_tile_1x8x4_dotprod/real_time         0.129 us        0.128 us      8388607 items_per_second=127.338G/s
BM_mmt4d_s8s8s32_tile_2x8x4_dotprod/real_time         0.164 us        0.163 us      8388607 items_per_second=200.269G/s
BM_mmt4d_s8s8s32_tile_4x8x4_dotprod/real_time         0.253 us        0.253 us      4194303 items_per_second=258.707G/s
BM_mmt4d_s8s8s32_tile_8x8x4_dotprod/real_time         0.363 us        0.362 us      2097151 items_per_second=361.478G/s
BM_mmt4d_s8s8s32_tile_1x8x8_i8mm/real_time            0.191 us        0.191 us      4194303 items_per_second=171.266G/s
BM_mmt4d_s8s8s32_tile_2x8x8_i8mm/real_time            0.182 us        0.182 us      4194303 items_per_second=359.516G/s
BM_mmt4d_s8s8s32_tile_4x8x8_i8mm/real_time            0.274 us        0.273 us      4194303 items_per_second=479.116G/s
BM_mmt4d_s8s8s32_tile_8x8x8_i8mm/real_time            0.452 us        0.451 us      2097151 items_per_second=579.566G/s
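Putting the two sets of numbers side by side makes the mismatch explicit: for each tile row count (M0), the i8mm microkernels report roughly 1.3x to 1.9x the throughput of the dotprod ones, yet end to end i8mm is slower. The sketch below only re-derives those ratios from the numbers above, and it assumes items_per_second measures the same unit of work for both variants.

```python
# items_per_second (G/s) from the mmt4d microbenchmarks above, keyed by tile shape.
DOTPROD = {"1x8x4": 127.338, "2x8x4": 200.269, "4x8x4": 258.707, "8x8x4": 361.478}
I8MM = {"1x8x8": 171.266, "2x8x8": 359.516, "4x8x8": 479.116, "8x8x8": 579.566}

# End-to-end ViT latencies (ms) from the table above, as (dotprod, i8mm).
E2E_MS = {"1 thread": (752, 807), "5 threads": (243, 250)}

micro_ratios = {}
for (d_tile, d_rate), (i_tile, i_rate) in zip(DOTPROD.items(), I8MM.items()):
    m0 = d_tile.split("x")[0]
    assert m0 == i_tile.split("x")[0]  # compare tiles with the same M0
    micro_ratios[m0] = i_rate / d_rate
    print(f"M0={m0}: i8mm throughput = {micro_ratios[m0]:.2f}x dotprod")

for setting, (d_ms, i_ms) in E2E_MS.items():
    # Values above 1 would mean i8mm is faster end to end (lower latency is better).
    print(f"{setting}: i8mm end-to-end speedup = {d_ms / i_ms:.2f}x")
```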

TOSA MLIR definitely contains i8*i8=i32 matmuls and it is hitting the expected tile size paths in CPUMaterializeEncoding.

I tried running with --iree-llvmcpu-disable-distribution but i8mm exceeds the stack allocation limit and disabling that check leads to a segmentation fault when benchmarking.

@mariecwhite
Contributor Author

To reproduce:

  1. Download https://storage.googleapis.com/iree-model-artifacts/tflite/tflite_models_1706739936/VIT_CLASSIFICATION_INT8_TFLITE_3X224X224XINT8/tosa.mlirbc

  2. Build

#CPU_FEATURES=+i8mm
CPU_FEATURES=+dotprod
INPUT_MLIR_PATH=tosa.mlirbc
OUTPUT_VMFB_PATH=vit.vmfb

iree-compile --iree-input-type=tosa \
    --iree-hal-target-backends=llvm-cpu \
    --iree-llvmcpu-target-triple=aarch64-none-linux-android34 \
    --iree-llvmcpu-target-cpu-features=${CPU_FEATURES} \
    --iree-opt-data-tiling=true \
    --iree-llvmcpu-enable-ukernels=all \
    ${INPUT_MLIR_PATH} -o ${OUTPUT_VMFB_PATH}

adb push ${OUTPUT_VMFB_PATH} /data/local/tmp
  3. Benchmark on device
iree-benchmark-module --module=vit.vmfb \
    --device=local-task \
    --function=main \
    --input=1x3x224x224xf32=0 \
    --task_topology_cpu_ids=0,1,2,3,4

@mariecwhite mariecwhite changed the title [CPU] Understand why IREE is 2x slower than TFLite on ViT on ARM64 [CPU] Understand why IREE is 2x slower than TFLite on ViT INT8 on ARM64 Feb 27, 2024
@mariecwhite
Contributor Author

mariecwhite commented Feb 27, 2024

Simpleperf profiles attached, with L1/L2 cache counters:
vit_i8mm_profile.zip
vit_dotprod_profile.zip

There appear to be 2 issues:

  • quantized_batch_matmul is not using any of the ukernels (neither sdot nor smmla appears in the disassembly). Note, however, that the quantized_batch_matmul dispatches are not among the top dispatches (those that take the most CPU cycles); the mmt4d ukernel dispatches make up most of the top dispatches.
  • If the i8mm variant correlated with the microbenchmarks, we would expect the quantized_batch_matmul dispatches to bubble up and appear more often among the top dispatches, but this is not the case. The distribution is very similar to the dotprod version. Disassembly shows that smmla is being used.

L1 cache misses are mostly due to quantized_batch_matmul. Until this is fixed, it's hard to debug the second issue.

@bjacob
Contributor

bjacob commented Feb 27, 2024

Yes, the quantized_batch_matmul issue was identified in the earlier round of comments above already: #15399 (comment)

It really isn't hard to fix: one existing quantized_matmul -> matmul rewrite pattern needs to be slightly generalized to handle the batch_ variant, that's all.

@hanhanW
Contributor

hanhanW commented Feb 27, 2024

What Benoit said. We are not data-tiling the quantized_batch_matmul. It should be converted to batch_matmul plus something in frontend/preprocessing.

@bjacob
Contributor

bjacob commented Feb 28, 2024

Filed #16599 for that specifically. Assuming you're not already working on this, @mariecwhite, we might staff this on our end; let's sync there.

@dcaballe
Contributor

I think @LLITCHEV is not currently looking into it but I'm not sure if he made any progress. Hopefully we can update.

@bjacob
Contributor

bjacob commented Feb 28, 2024

Ah - well on our side, @pashu123 has started looking into it. Whoever gets there first :-P

@LLITCHEV
Contributor

@dcaballe The priority of this was reduced and I'm looking at something else at the moment.

@mariecwhite
Contributor Author

Looks like it's over to you @pashu123. Thanks!

@pashu123
Contributor

Cool, I have added the support in #16615. Meanwhile, I will be adding regression tests. Thanks.

@hanhanW
Contributor

hanhanW commented Jul 30, 2024

Closing the issue because the performance concern is addressed. According to the benchmark report, IREE got 2x faster, which should make it competitive with TFLite.

@hanhanW hanhanW closed this as completed Jul 30, 2024
10 participants