[Flow] Change the definition of "dequantization" recognizer. #17711

MaheshRavishankar

The recognizer for dequantization operations currently requires the input indexing map to be an identity. This is overly conservative for newer quantization schemes. This change relaxes the check so that it looks only at operand ranks when deciding whether an operation is a dequantization operation.
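To illustrate the difference between the two kinds of check, here is a minimal Python sketch. It is not the IREE implementation: the `Operand` type, the tuple encoding of indexing maps, and both function names are hypothetical, and the exact rank condition is an assumption (here, that no input has a higher rank than the output).

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class Operand:
    """Hypothetical stand-in for an operand of a generic elementwise op."""
    rank: int
    # Loop dimensions used to index this operand, e.g. (0, 1) for (d0, d1).
    indexing_map: Tuple[int, ...]


def requires_identity_input(inputs: Sequence[Operand], num_loops: int) -> bool:
    """Old-style check (sketch): the first input must use an identity map."""
    return inputs[0].indexing_map == tuple(range(num_loops))


def looks_like_dequant_by_rank(inputs: Sequence[Operand], output: Operand) -> bool:
    """Rank-only check (sketch, assumed condition): no input out-ranks the output."""
    return all(op.rank <= output.rank for op in inputs)


# Hypothetical case: the quantized input is read through a transposed map
# and the scales are a rank-1 operand broadcast along one dimension.
weights = Operand(rank=2, indexing_map=(1, 0))
scales = Operand(rank=1, indexing_map=(0,))
output = Operand(rank=2, indexing_map=(0, 1))

identity_result = requires_identity_input([weights, scales], num_loops=2)
rank_result = looks_like_dequant_by_rank([weights, scales], output)
```

In this made-up case the identity-based check rejects the operation because the first input's map is a transpose, while the rank-based check accepts it. A real recognizer would likely check more than this, for example iterator types and element types.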

MaheshRavishankar added the following labels on Jun 20, 2024: benchmarks:cuda (Run default CUDA benchmarks), benchmarks:x86_64 (Run default x86_64 benchmarks), benchmarks:comp-stats (Run default compilation statistics benchmarks), benchmarks:android-cpu (Run default Android CPU benchmarks), benchmarks:android-gpu (Run default Android GPU benchmarks), benchmarks:vulkan-nvidia (Run default Vulkan benchmarks on NVIDIA GPU).

Abbreviated Benchmark Summary

@ commit 88f32fab4c2503c4b8cf1d43f1480ae1467e5788 (vs. base 90f29a66d5bbd58167d84b2011d27c7ffb9a1ee1)

Data-Tiling Comparison Table

| Name | No-DT (baseline) | DT-Only | DT-UK |
| --- | --- | --- | --- |
| BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 218.814 (1.0X) | N/A | 106.242 (2.1X) |
| BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 780.636 (1.0X) | N/A | 222.751 (3.5X) |
| DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 6.905 (1.0X) | N/A | 8.525 (0.8X) |
| DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 32.027 (1.0X) | N/A | 30.118 (1.1X) |
| EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 35.762 (1.0X) | N/A | 34.370 (1.0X) |
| EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 276.144 (1.0X) | N/A | 236.681 (1.2X) |
| EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 5.804 (1.0X) | N/A | 5.033 (1.2X) |
| EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 26.879 (1.0X) | N/A | 13.231 (2.0X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 9.085 (1.0X) | N/A | 8.749 (1.0X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 70.114 (1.0X) | N/A | 37.586 (1.9X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 10.986 (1.0X) | N/A | 8.569 (1.3X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 87.114 (1.0X) | N/A | 39.165 (2.2X) |
| MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 12.215 (1.0X) | N/A | 13.035 (0.9X) |
| MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 80.271 (1.0X) | N/A | 56.198 (1.4X) |
| MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.045 (1.0X) | N/A | 61.171 (0.6X) |
| MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 181.271 (1.0X) | N/A | 185.715 (1.0X) |
| MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.229 (1.0X) | N/A | 61.197 (0.6X) |
| MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 181.289 (1.0X) | N/A | 189.493 (1.0X) |
| MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 68.484 (1.0X) | N/A | 63.960 (1.1X) |
| MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 517.199 (1.0X) | N/A | 240.747 (2.1X) |
| MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 4.830 (1.0X) | N/A | 4.569 (1.1X) |
| MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 24.879 (1.0X) | N/A | 17.823 (1.4X) |
| MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 3.759 (1.0X) | N/A | 4.880 (0.8X) |
| MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 11.550 (1.0X) | N/A | 11.545 (1.0X) |
| MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 5.832 (1.0X) | N/A | 5.419 (1.1X) |
| MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 21.571 (1.0X) | N/A | 11.802 (1.8X) |
| MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 2.848 (1.0X) | N/A | 2.815 (1.0X) |
| MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 2.843 (1.0X) | N/A | 2.724 (1.0X) |
| MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 8.462 (1.0X) | N/A | 9.823 (0.9X) |
| MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.376 (1.0X) | N/A | 31.267 (1.1X) |
| PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 0.767 (1.0X) | N/A | 0.631 (1.2X) |
| PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 0.699 (1.0X) | N/A | 0.568 (1.2X) |
| PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 4.128 (1.0X) | N/A | 5.156 (0.8X) |
| PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 17.644 (1.0X) | N/A | 19.645 (0.9X) |
| matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 7.556 (1.0X) | N/A | 7.571 (1.0X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 49.011 (1.0X) | N/A | 44.412 (1.1X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 50.220 (1.0X) | N/A | 44.950 (1.1X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 30.156 (1.0X) | N/A | 27.873 (1.1X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 92.587 (1.0X) | N/A | 21.878 (4.2X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 91.902 (1.0X) | N/A | 22.881 (4.0X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 52.074 (1.0X) | N/A | 22.052 (2.4X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 138.227 (1.0X) | N/A | 27.606 (5.0X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 124.005 (1.0X) | N/A | 29.448 (4.2X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 76.970 (1.0X) | N/A | 26.520 (2.9X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 708.956 (1.0X) | N/A | 356.168 (2.0X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 704.307 (1.0X) | N/A | 361.714 (1.9X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 397.287 (1.0X) | N/A | 216.870 (1.8X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 1059.008 (1.0X) | N/A | 266.730 (4.0X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 1058.553 (1.0X) | N/A | 265.371 (4.0X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 552.999 (1.0X) | N/A | 157.294 (3.5X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 2104.567 (1.0X) | N/A | 305.475 (6.9X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 2105.228 (1.0X) | N/A | 308.456 (6.8X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 1134.590 (1.0X) | N/A | 185.222 (6.1X) |
| matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 12.274 (1.0X) | N/A | 1.323 (9.3X) |

Regressed Latencies 🚩

| Benchmark Name | Average Latency (ms) | Median Latency (ms) | Latency Standard Deviation (ms) |
| --- | --- | --- | --- |
| MobileBertSquad\_int8(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][experimental-flags,fuse-padding,max-concurrency] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] | 103.830 (vs. 75.807, 36.97%↑) | 104.044 | 0.845 |
| MobileBertSquad\_int8(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][default-flags] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] | 113.657 (vs. 87.071, 30.53%↑) | 113.557 | 0.731 |
| MobileBertSquad\_fp16(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] | 108.362 (vs. 93.941, 15.35%↑) | 110.493 | 5.122 |

[Top 3 out of 6 results shown]

Improved Latencies 🎉

| Benchmark Name | Average Latency (ms) | Median Latency (ms) | Latency Standard Deviation (ms) |
| --- | --- | --- | --- |
| MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 157.294 (vs. 191.253, 17.76%↓) | 158.182 | 4.144 |
| MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_sync(embedded\_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 266.730 (vs. 322.032, 17.17%↓) | 267.317 | 2.030 |
| MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 265.371 (vs. 313.399, 15.32%↓) | 266.121 | 2.363 |

[Top 3 out of 12 results shown]

Regressed Total Dispatch Sizes 🚩

| Benchmark Name | Total Dispatch Size (bytes) |
| --- | --- |
| MobileBertSquad\_int8(tflite) [riscv\_32-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 3810668 (vs. 2116860, 80.02%↑) |
| MobileBertSquad\_int8(tflite) [riscv\_64-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 3602552 (vs. 2108296, 70.88%↑) |
| MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk,compile-stats] | 1807936 (vs. 1372976, 31.68%↑) |

[Top 3 out of 8 results shown]

Regressed Total Artifact Sizes 🚩

| Benchmark Name | Total Artifact Size (bytes) |
| --- | --- |
| MobileBertSquad\_int8(tflite) [riscv\_32-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 30105159 (vs. 28415879, 5.94%↑) |
| MobileBertSquad\_int8(tflite) [riscv\_64-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 29897031 (vs. 28407303, 5.24%↑) |

Improved Stream IR Dispatch Count (# of cmd.dispatch ops) 🎉

| Benchmark Name | Stream IR Dispatch Count (# of cmd.dispatch ops) |
| --- | --- |
| MobileBertSquad\_int8(tflite) [x86\_64-cascadelake-linux\_gnu-llvm\_cpu][experimental-flags,no-dt,compile-stats] | 1078 (vs. 1102, 2.18%↓) |
| MobileBertSquad\_int8(tflite) [riscv\_64-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 1078 (vs. 1102, 2.18%↓) |
| MobileBertSquad\_int8(tflite) [riscv\_32-generic-linux\_gnu-llvm\_cpu][default-flags,compile-stats] | 1078 (vs. 1102, 2.18%↓) |

[Top 3 out of 11 results shown]


hanhanW left a comment

LGTM. @pashu123 you can check if the PR addresses your fusion issue, FYI.

MaheshRavishankar force-pushed the sdxl_quantized_fix_dequant_op_def branch from 1e4b977 to abd591b on June 25, 2024 05:24
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar force-pushed the sdxl_quantized_fix_dequant_op_def branch from 27533a3 to d8a3fe5 on June 25, 2024 17:53
MaheshRavishankar enabled auto-merge (squash) on June 25, 2024 17:54
MaheshRavishankar merged commit 22cf0b0 into iree-org:main on Jun 25, 2024
58 checks passed
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
…g#17711)

The dequantization operation today is trying to enforce that the input
indexing map is an identity. This is overly conservative for newer
quantization schemes. This changes the logic to just look at operand
ranks to check if the operation is a dequantization operation.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: Lubo Litchev <lubol@google.com>