
[Flow] Make sink reshapes changes less conservative. #17706

Conversation

MaheshRavishankar
Contributor

While deciding if a reshape needs "sinking" for a `tensor.expand_shape` -> `linalg.*` sequence, the first check was whether the `linalg.*` operation could already fuse with one of its existing producers. That check was overly aggressive: the fusion only kicks in when the iteration domains match. Eventually the actual dispatch formation logic needs to be consolidated into a single place to do this better, but that is left to a follow-up.
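The tightened check described above can be sketched in a toy model (this is a hedged Python sketch of the idea, not the actual IREE C++ implementation; `should_sink_reshape`, `loop_ranges`, and the dict-based ops are hypothetical names for illustration):

```python
# Toy model of the tightened check: a reshape is worth sinking only when
# the consumer cannot already fuse with an existing producer, and that
# producer fusion only counts when the iteration domains match.

def iteration_domain(op):
    """Return the loop-range tuple of a toy linalg-like op."""
    return tuple(op["loop_ranges"])

def can_fuse_with_producer(consumer, producer):
    # Elementwise fusion in this model only kicks in when the two ops
    # iterate over the same domain.
    return iteration_domain(consumer) == iteration_domain(producer)

def should_sink_reshape(consumer, producers):
    # Old (overly aggressive) check: any existing producer blocked sinking.
    # New check: only a producer with a *matching* iteration domain does.
    for producer in producers:
        if can_fuse_with_producer(consumer, producer):
            return False  # consumer already has a profitable fusion
    return True  # sink the expand_shape past the consumer

# Toy example: the consumer iterates over (1, 4, 768); one candidate
# producer matches that domain, the other does not.
consumer = {"loop_ranges": [1, 4, 768]}
matching = {"loop_ranges": [1, 4, 768]}
mismatched = {"loop_ranges": [4, 768]}

assert should_sink_reshape(consumer, [mismatched]) is True
assert should_sink_reshape(consumer, [matching]) is False
```

Under the old behavior, the `mismatched` producer alone would have been enough to block sinking even though no real fusion was possible there.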

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_sink_reshape_fixes branch from 3a50abd to e3066d6 Compare June 19, 2024 20:32
@MaheshRavishankar MaheshRavishankar added benchmarks:cuda Run default CUDA benchmarks benchmarks:x86_64 Run default x86_64 benchmarks benchmarks:comp-stats Run default compilation statistics benchmarks benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU labels Jun 19, 2024

github-actions bot commented Jun 19, 2024

Abbreviated Benchmark Summary

@ commit 319360d130b43a73ace1fba52c56f4b90db2b5f0 (vs. base 7090f64b7bd60a597ee6c61b8ff0d624153cb7f0)

Data-Tiling Comparison Table

Name No-DT (baseline) DT-Only DT-UK
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 730.322 (1.0X) N/A 223.605 (3.3X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 7.005 (1.0X) N/A 8.493 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.863 (1.0X) N/A 34.345 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.815 (1.0X) N/A 5.019 (1.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.126 (1.0X) N/A 8.438 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.013 (1.0X) N/A 8.887 (1.2X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.713 (1.0X) N/A 13.952 (0.8X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.390 (1.0X) N/A 61.378 (0.6X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.865 (1.0X) N/A 61.736 (0.6X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 68.659 (1.0X) N/A 64.516 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.641 (1.0X) N/A 4.569 (1.0X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.715 (1.0X) N/A 4.960 (0.7X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.837 (1.0X) N/A 5.429 (1.1X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.920 (1.0X) N/A 2.838 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.467 (1.0X) N/A 9.945 (0.9X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.762 (1.0X) N/A 0.585 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.107 (1.0X) N/A 5.281 (0.8X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.581 (1.0X) N/A 7.605 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.647 (1.0X) N/A 1.811 (3.7X)
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 216.242 (1.0X) N/A 106.568 (2.0X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 32.105 (1.0X) N/A 29.782 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 273.629 (1.0X) N/A 230.326 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.973 (1.0X) N/A 13.123 (2.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 69.650 (1.0X) N/A 39.601 (1.8X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 88.173 (1.0X) N/A 41.863 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 79.414 (1.0X) N/A 56.655 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 180.533 (1.0X) N/A 185.795 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 180.533 (1.0X) N/A 189.816 (1.0X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 516.670 (1.0X) N/A 240.772 (2.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 25.273 (1.0X) N/A 17.780 (1.4X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.943 (1.0X) N/A 12.290 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.588 (1.0X) N/A 11.806 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.794 (1.0X) N/A 2.663 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.349 (1.0X) N/A 31.471 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.698 (1.0X) N/A 0.522 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.841 (1.0X) N/A 19.751 (0.9X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.054 (1.0X) N/A 0.054 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.042 (1.0X) N/A 0.021 (2.0X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 48.850 (1.0X) N/A 43.231 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 49.915 (1.0X) N/A 43.708 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 30.041 (1.0X) N/A 27.063 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.877 (1.0X) N/A 22.147 (4.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 93.378 (1.0X) N/A 22.607 (4.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 52.476 (1.0X) N/A 22.091 (2.4X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 126.762 (1.0X) N/A 27.145 (4.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 128.063 (1.0X) N/A 28.541 (4.5X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 70.865 (1.0X) N/A 26.528 (2.7X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 703.885 (1.0X) N/A 352.091 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 701.687 (1.0X) N/A 355.356 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 392.740 (1.0X) N/A 213.765 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1059.705 (1.0X) N/A 276.894 (3.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1060.070 (1.0X) N/A 274.291 (3.9X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 548.528 (1.0X) N/A 158.360 (3.5X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2060.557 (1.0X) N/A 297.914 (6.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2061.551 (1.0X) N/A 296.252 (7.0X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1097.600 (1.0X) N/A 176.258 (6.2X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.080 (1.0X) N/A 0.016 (5.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.071 (1.0X) N/A 0.016 (4.3X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 11.985 (1.0X) N/A 1.309 (9.2X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 16.541 (1.0X) N/A 1.088 (15.2X)

Regressed Latencies 🚩

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad_fp16(tflite) [arm-valhall-vulkan_android31-vulkan_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 109.247 (vs. 94.876, 15.15%↑) 110.828 3.575
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 216.242 (vs. 191.491, 12.93%↑) 211.222 15.908

Improved Latencies 🎉

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.016 (vs. 0.019, 12.02%↓) 0.017 0.000
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 28.541 (vs. 31.259, 8.70%↓) 28.533 0.660
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 158.360 (vs. 173.127, 8.53%↓) 159.559 3.329

[Top 3 out of 13 results shown]

Regressed Stream IR Dispatch Count (# of cmd.dispatch ops) 🚩

Benchmark Name Stream IR Dispatch Count (# of cmd.dispatch ops)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu][experimental-flags,no-dt,compile-stats] 255 (vs. 243, 4.94%↑)

For more information:

Source Workflow Run

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_sink_reshape_fixes branch 2 times, most recently from 6cae1e9 to 5c9dac6 Compare June 22, 2024 21:27
Contributor

@qedawkins qedawkins left a comment


Approving but dispatch regression of 330 -> 380 needs investigation before landing.

@MaheshRavishankar
Contributor Author

Have been trying to repro it. Can't repro locally.

@IanWood1
Contributor

IanWood1 commented Jun 27, 2024

I think the regression can be seen here:

After the change:

- full after change
- ir before commit

  %expanded_181 = tensor.expand_shape %43 [[0, 1], [2]] output_shape [1, 4, 768] : tensor<4x768xf32> into tensor<1x4x768xf32>
  %expanded_182 = tensor.expand_shape %46 [[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
  %expanded_183 = tensor.expand_shape %45 [[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
  %47 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%expanded_181, %expanded_182, %cst_60, %cst_59, %expanded_183 
      : tensor<1x4x768xf32>, tensor<1x4xf32>, tensor<768xf32>, tensor<768xf32>, tensor<1x4xf32>)
    outs(%7 : tensor<1x4x768xf32>) {
  ^bb0(%in: f32, %in_584: f32, %in_585: f32, %in_586: f32, %in_587: f32, %out: f32):
    %355 = arith.divf %in_584, %cst_9 : f32
    %356 = arith.addf %355, %cst_8 : f32
    %357 = math.rsqrt %356 : f32
    %358 = arith.mulf %357, %in_585 : f32
    %359 = arith.mulf %in_587, %358 : f32
    %360 = arith.subf %in_586, %359 : f32
    %361 = arith.mulf %in, %358 : f32
    %362 = arith.addf %361, %360 : f32
    linalg.yield %362 : f32
  } -> tensor<1x4x768xf32>
  %collapsed_184 = tensor.collapse_shape %47 [[0, 1], [2]] : tensor<1x4x768xf32> into tensor<4x768xf32>

It should be possible to sink the expanded dims and fold them with the collapse_shape, but since each input indexing map is a projection, the consumer op isn't considered fusable with the defining op of the `tensor.expand_shape` source (see https://github.com/MaheshRavishankar/iree/blob/5c9dac66000f312854a016228633b8b3b52a1282/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp#L73C1-L77C2).
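The projection issue Ian describes can be illustrated with a toy model (a hedged sketch, simplified from the linked SinkReshapes.cpp logic; `is_projection` and the tuple encoding of indexing maps are hypothetical names for illustration):

```python
# Toy model: an indexing map is encoded as the tuple of iteration dims
# an operand reads. A "projection" drops some of the loop dims, e.g.
# (d0, d1, d2) -> (d0, d1) is encoded as (0, 1).

NUM_LOOPS = 3  # (d0, d1, d2) in the linalg.generic above

def is_projection(indexing_map):
    # A map that uses fewer dims than the iteration space is a
    # projected permutation rather than a full identity.
    return len(indexing_map) < NUM_LOOPS

# Operand maps from the linalg.generic in the IR above.
operand_maps = [
    (0, 1, 2),  # %expanded_181 : tensor<1x4x768xf32>
    (0, 1),     # %expanded_182 : tensor<1x4xf32>
    (2,),       # %cst_60       : tensor<768xf32>
    (2,),       # %cst_59       : tensor<768xf32>
    (0, 1),     # %expanded_183 : tensor<1x4xf32>
]

# All but the first operand map are projections, so a fusability check
# that only recognizes operands with full-rank maps fails to see the
# producer fusion opportunity, and the expand_shape is not sunk.
projected = [m for m in operand_maps if is_projection(m)]
assert len(projected) == 4
```

In this model, four of the five input maps are projections; a check keyed only on non-projected operand maps therefore misses the chance to sink the expanded dims and fold them with the trailing `tensor.collapse_shape`.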

@MaheshRavishankar
Contributor Author

Thanks Ian... let me look into this a little bit more.

@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_sink_reshape_fixes branch 2 times, most recently from 41ed0e1 to 9bb15c6 Compare June 28, 2024 07:02
Contributor

@Max191 Max191 left a comment


Seems mostly okay to me, but I'm not sure about some of the logic; it seems like there could be some edge cases that won't be handled correctly.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_sink_reshape_fixes branch from 9bb15c6 to 253aecc Compare June 28, 2024 23:26
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@MaheshRavishankar MaheshRavishankar force-pushed the sdxl_quantized_sink_reshape_fixes branch from b808b9b to f8163b2 Compare June 29, 2024 01:53
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@MaheshRavishankar
Contributor Author

OK, folks, I am landing this with the one regression in the number of dispatches. I am checking some out-of-tree models, and this change seems to do better on them.

@MaheshRavishankar MaheshRavishankar merged commit 4ad00ef into iree-org:main Jun 29, 2024
60 checks passed
LLITCHEV pushed a commit to LLITCHEV/iree that referenced this pull request Jul 30, 2024
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: Lubo Litchev <lubol@google.com>

4 participants