From b4321ea20b0e45c4179f800a1b65df44bef9ee48 Mon Sep 17 00:00:00 2001
From: saienduri <77521230+saienduri@users.noreply.github.com>
Date: Thu, 13 Jun 2024 09:39:37 -0700
Subject: [PATCH] Update CI to include benchmarking changes in Test Suite.
 (#17655)

This commit updates the SHARK-TestSuite ref, config files, and YAML files
so that they carry the most up-to-date flags and benchmarking support. As a
follow-up, I will also work on a Python implementation for pulling in
configs in the Test Suite, so that we don't have to rely on all of these
config files; a rough sketch of that idea is included below. I checked the
golden values over 15 times in the CI, so they should be stable.

This commit also adds support to the CI for generating a job summary of the
benchmark mean times for e2e and all the submodels. Developers can see this
in the summary tab of the pkgci runs.
Example: https://github.com/iree-org/iree/actions/runs/9501523985

Side note: the build_test_all_bazel job failed the first couple of times
and then passed, so it seems to be flaky.

---------

Signed-off-by: saienduri
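For reference, here is a minimal sketch of what the planned Python-based
config handling could look like. Everything in it (`TestConfig`, `CONFIGS`,
`get_config_flags`) is hypothetical and not part of this change; the real
implementation would live in SHARK-TestSuite.

```python
# Hypothetical replacement for the per-backend JSON config files: keep the
# compile/run flags in Python dicts and look them up by config name.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TestConfig:
    config_name: str
    iree_compile_flags: List[str]
    iree_run_module_flags: List[str]
    skip_compile_tests: List[str] = field(default_factory=list)
    skip_run_tests: List[str] = field(default_factory=list)


# Flags below are copied from the existing JSON configs in this PR; the
# registry itself is an illustration, not the final design.
CONFIGS = {
    "gpu_rocm_gfx90a": TestConfig(
        config_name="gpu_rocm",
        iree_compile_flags=[
            "--iree-hal-target-backends=rocm",
            "--iree-rocm-target-chip=gfx90a",
        ],
        iree_run_module_flags=["--device=hip"],
        skip_compile_tests=["sdxl-scheduled-unet-3-tank"],
    ),
}


def get_config_flags(name: str) -> TestConfig:
    """Look up a config by name instead of reading a JSON file from disk."""
    return CONFIGS[name]
```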
---
 .github/workflows/pkgci_regression_test.yml   |  63 +-
 .../attention_and_matmul_spec.mlir            | 621 +++++++++++++++++-
 .../onnx_cpu_llvm_sync.json                   |   2 +
 .../external_test_suite/onnx_gpu_cuda.json    |   2 +
 .../onnx_gpu_rocm_rdna3.json                  |   2 +
 .../external_test_suite/onnx_gpu_vulkan.json  |   2 +
 .../pytorch_models_cpu_llvm_task.json         |   4 +-
 .../pytorch_models_gpu_rocm_gfx90a.json       |   4 +-
 ...dels_gpu_rocm_gfx90a_additional_flags.json |   4 +-
 .../sdxl_prompt_encoder_cpu_llvm_task.json    |  22 +
 .../sdxl_prompt_encoder_gpu_rocm_gfx90a.json  |  34 +
 .../sdxl_scheduled_unet_gpu_rocm_gfx90a.json  |   7 +-
 .../sdxl_vae_decode_cpu_llvm_task.json        |  20 +
 .../sdxl_vae_decode_gpu_rocm_gfx90a.json      |  27 +
 14 files changed, 797 insertions(+), 17 deletions(-)
 create mode 100644 build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_cpu_llvm_task.json
 create mode 100644 build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_gpu_rocm_gfx90a.json
 create mode 100644 build_tools/pkgci/external_test_suite/sdxl_vae_decode_cpu_llvm_task.json
 create mode 100644 build_tools/pkgci/external_test_suite/sdxl_vae_decode_gpu_rocm_gfx90a.json

diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index 9921cf0ecda8..cbb0fa56702a 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -90,7 +90,7 @@ jobs:
         uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
         with:
           repository: nod-ai/SHARK-TestSuite
-          ref: c9b3337e1f754c83d178568be1339aaef5f08045
+          ref: ab932cc54f1e460ccd9b4a4f1efa07d0ee069eb5
           path: SHARK-TestSuite
           submodules: false
           lfs: false
@@ -138,15 +138,19 @@ jobs:
           # CPU
           - name: cpu_llvm_task
             models-config-file: pytorch_models_cpu_llvm_task.json
-            sdxl-config-file: sdxl_scheduled_unet_cpu_llvm_task.json
+            sdxl-unet-config-file: sdxl_scheduled_unet_cpu_llvm_task.json
+            sdxl-vae-config-file: sdxl_vae_decode_cpu_llvm_task.json
+            sdxl-clip-config-file: sdxl_prompt_encoder_cpu_llvm_task.json
             runs-on: nodai-amdgpu-w7900-x86-64

           # AMD GPU
           - name: amdgpu_rocm_gfx90a
             models-config-file: pytorch_models_gpu_rocm_gfx90a.json
             models-extra-flags-config-file: pytorch_models_gpu_rocm_gfx90a_additional_flags.json
-            sdxl-config-file: sdxl_scheduled_unet_gpu_rocm_gfx90a.json
-            runs-on: nodai-amdgpu-mi250-x86-64
+            sdxl-unet-config-file: sdxl_scheduled_unet_gpu_rocm_gfx90a.json
+            sdxl-vae-config-file: sdxl_vae_decode_gpu_rocm_gfx90a.json
+            sdxl-clip-config-file: sdxl_prompt_encoder_gpu_rocm_gfx90a.json
+            runs-on: nodai-amdgpu-mi210-x86-64

           - name: amdgpu_vulkan
             models-config-file: pytorch_models_gpu_vulkan.json
             runs-on: nodai-amdgpu-w7900-x86-64
@@ -166,7 +170,9 @@ jobs:
       IREE_TEST_PATH_EXTENSION: ${{ github.workspace }}/build_tools/pkgci/external_test_suite
       MODELS_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.models-config-file }}
       MODELS_EXTRA_FLAGS_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.models-extra-flags-config-file }}
-      SDXL_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.sdxl-config-file }}
+      SDXL_UNET_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.sdxl-unet-config-file }}
+      SDXL_CLIP_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.sdxl-clip-config-file }}
+      SDXL_VAE_CONFIG_FILE_PATH: build_tools/pkgci/external_test_suite/${{ matrix.sdxl-vae-config-file }}
       VENV_DIR: ${{ github.workspace }}/venv
     steps:
       - name: Checking out IREE repository
@@ -201,7 +207,7 @@ jobs:
         uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
         with:
           repository: nod-ai/SHARK-TestSuite
-          ref: c9b3337e1f754c83d178568be1339aaef5f08045
+          ref: ab932cc54f1e460ccd9b4a4f1efa07d0ee069eb5
           path: SHARK-TestSuite
           submodules: false
           lfs: true
@@ -243,7 +249,7 @@ jobs:
             --config-files=${MODELS_EXTRA_FLAGS_CONFIG_FILE_PATH}

       - name: "Run external tests - SDXL scheduled unet"
-        if: "matrix.sdxl-config-file != '' && !cancelled()"
+        if: "matrix.sdxl-unet-config-file != '' && !cancelled()"
         run: |
           source ${VENV_DIR}/bin/activate
           pytest SHARK-TestSuite/iree_tests/pytorch/models/sdxl-scheduled-unet-3-tank \
@@ -254,10 +260,49 @@ jobs:
            --log-cli-level=info \
            --timeout=1200 \
            --durations=0 \
-           --config-files=${SDXL_CONFIG_FILE_PATH}
+           --config-files=${SDXL_UNET_CONFIG_FILE_PATH}
+
+      - name: "Run external tests - SDXL prompt encoder"
+        if: "matrix.sdxl-clip-config-file != '' && !cancelled()"
+        run: |
+          source ${VENV_DIR}/bin/activate
+          pytest SHARK-TestSuite/iree_tests/pytorch/models/sdxl-prompt-encoder-tank \
+            -rpfE \
+            -k real_weights \
+            --no-skip-tests-missing-files \
+            --capture=no \
+            --log-cli-level=info \
+            --timeout=1200 \
+            --durations=0 \
+            --config-files=${SDXL_CLIP_CONFIG_FILE_PATH}
+
+      - name: "Run external tests - SDXL vae decode"
+        if: "matrix.sdxl-vae-config-file != '' && !cancelled()"
+        run: |
+          source ${VENV_DIR}/bin/activate
+          pytest SHARK-TestSuite/iree_tests/pytorch/models/sdxl-vae-decode-tank \
+            -rpfE \
+            -k real_weights \
+            --no-skip-tests-missing-files \
+            --capture=no \
+            --log-cli-level=info \
+            --timeout=1200 \
+            --durations=0 \
+            --config-files=${SDXL_VAE_CONFIG_FILE_PATH}

       - name: "Running SDXL rocm pipeline benchmark"
         if: contains(matrix.name, 'rocm')
         run: |
           source ${VENV_DIR}/bin/activate
-          bash SHARK-TestSuite/iree_tests/benchmarks/benchmark_sdxl_rocm.sh
+          pytest SHARK-TestSuite/iree_tests/benchmarks/benchmark_sdxl_rocm.py \
+            --goldentime-rocm-e2e-ms 1636 \
+            --goldentime-rocm-unet-ms 442 \
+            --goldentime-rocm-clip-ms 16.5 \
+            --goldentime-rocm-vae-ms 285 \
+            --gpu-number 3 \
+            --rocm-chip gfx90a \
+            --log-cli-level=info \
+            --retries 7
+          echo "### SDXL Benchmark Summary:" >> $GITHUB_STEP_SUMMARY
+          echo "" >> $GITHUB_STEP_SUMMARY # this is a blank line
+          echo "$(<job_summary.md)" >> $GITHUB_STEP_SUMMARY
diff --git a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir
index 4a3b309b3841..ff5878ffb55c 100644
--- a/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir
+++ b/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir
@@ -432,7 +432,7 @@ module attributes { transform.with_named_sequence } {
   // Send it down a custom transform dialect pipeline.
   transform.named_sequence @custom_attention_len_512(%attention: !transform.any_op {transform.readonly}) {
     %func = transform.get_parent_op %attention {op_name = "func.func"} : (!transform.any_op) -> !transform.any_op
-    %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param
+    %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param
     transform.annotate %func "translation_info" = %attn : !transform.any_op, !transform.any_param
     transform.yield
   }
@@ -447,7 +447,7 @@ module attributes { transform.with_named_sequence } {
   // Send it down a custom transform dialect pipeline.
   transform.named_sequence @custom_attention(%attention: !transform.any_op {transform.readonly}) {
     %func = transform.get_parent_op %attention {op_name = "func.func"} : (!transform.any_op) -> !transform.any_op
-    %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param
+    %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param
     transform.annotate %func "translation_info" = %attn : !transform.any_op, !transform.any_param
     transform.yield
  }
@@ -460,6 +460,591 @@ module attributes { transform.with_named_sequence } {
     transform.yield %attention : !transform.any_op
   }
 
+//===----------------------------------------------------------------------===//
+// Matmul tuning
+//===----------------------------------------------------------------------===//
+
+  transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+    // transform.print %root {name = "Generic"} : !transform.any_op
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+    ^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
+      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                            affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
+                           iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %8 = arith.extf %in : f16 to f32
+        %9 = arith.extf %in_0 : f16 to f32
+        %10 = arith.mulf %8, %9 : f32
+        %11 = arith.addf %acc, %10 : f32
+        linalg.yield %11 : f32
+      } -> tensor<?x?xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    transform.yield %root : !transform.any_op
+  }
+
+  transform.named_sequence @match_mmt_f16_f16_f16(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
+    // transform.print %root {name = "Generic"} : !transform.any_op
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
+    ^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf16>):
+      %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                            affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
+                           iterator_types = ["parallel", "parallel", "reduction"]}
+          ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf16>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f16):
+        %10 = arith.mulf %in, %in_0 : f16
+        %11 = arith.addf %acc, %10 : f16
+        linalg.yield %11 : f16
+      } -> tensor<?x?xf16>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    transform.yield %root : !transform.any_op
+  }
+
+  transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) {
+    transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
+    // transform.print %op {name = "Applied"} : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @match_mmt_2048x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 2>
+      , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_mmt_2048x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 1>
+      , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 1>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_mmt_8192x5120x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 1>
+      }>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_mmt_8192x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_mmt_8192x640x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
+    %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
+    %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
+    %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
+    transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xf16> : !transform.any_value
+    transform.iree.match.cast_compatible_type %rhs = tensor<640x640xf16> : !transform.any_value
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+//===----------------------------------------------------------------------===//
+// Convolution tuning
+//===----------------------------------------------------------------------===//
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x640(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x1280xf16>, %out: tensor<2x32x32x1280xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x1280xf16>)
+        outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 5>
+      }>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x1280xf16>, %out: tensor<2x32x32x1280xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x1280xf16>)
+        outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 4>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x1920xf16>, %rhs: tensor<3x3x1920x1280xf16>, %out: tensor<2x32x32x1280xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x1920xf16>, tensor<3x3x1920x1280xf16>)
+        outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 5>
+      , no_reorder_workgroups}>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x2560xf16>, %rhs: tensor<3x3x2560x1280xf16>, %out: tensor<2x32x32x1280xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x2560xf16>, tensor<3x3x2560x1280xf16>)
+        outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 5>
+      }>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x320xf16>, %rhs: tensor<3x3x320x320xf16>, %out: tensor<2x128x128x320xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x320xf16>, tensor<3x3x320x320xf16>)
+        outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 2>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640(%conv: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {
+    ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x640xf16>, %out: tensor<2x64x64x640xf32>):
+      %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> }
+        ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x640xf16>)
+        outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 1, subgroup_n_count = 5>
+      , no_reorder_workgroups}>
+      > -> !transform.any_param
+    transform.yield %conv, %config : !transform.any_op, !transform.any_param
+  }
+
+//===----------------------------------------------------------------------===//
+// Batch matmul tuning
+//===----------------------------------------------------------------------===//
+
+  transform.named_sequence @match_batch_matmul_64x968x320x640(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x968x640xf16>, %rhs: tensor<64x640x320xf16>, %out: tensor<64x968x320xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x968x640xf16>, tensor<64x640x320xf16>)
+        outs(%out : tensor<64x968x320xf32>) -> tensor<64x968x320xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x968x640x640(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x968x640xf16>, %rhs: tensor<64x640x640xf16>, %out: tensor<64x968x640xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x968x640xf16>, tensor<64x640x640xf16>)
+        outs(%out : tensor<64x968x640xf32>) -> tensor<64x968x640xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , no_reorder_workgroups}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x968x320x960(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x968x960xf16>, %rhs: tensor<64x960x320xf16>, %out: tensor<64x968x320xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x968x960xf16>, tensor<64x960x320xf16>)
+        outs(%out : tensor<64x968x320xf32>) -> tensor<64x968x320xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x242x640x960(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x242x960xf16>, %rhs: tensor<64x960x640xf16>, %out: tensor<64x242x640xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x242x960xf16>, tensor<64x960x640xf16>)
+        outs(%out : tensor<64x242x640xf32>) -> tensor<64x242x640xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x242x1280x1280(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x242x1280xf16>, %rhs: tensor<64x1280x1280xf16>, %out: tensor<64x242x1280xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x242x1280xf16>, tensor<64x1280x1280xf16>)
+        outs(%out : tensor<64x242x1280xf32>) -> tensor<64x242x1280xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      , no_reorder_workgroups}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x242x640x1280(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x242x1280xf16>, %rhs: tensor<64x1280x640xf16>, %out: tensor<64x242x640xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x242x1280xf16>, tensor<64x1280x640xf16>)
+        outs(%out : tensor<64x242x640xf32>) -> tensor<64x242x640xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      , no_reorder_workgroups}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_batch_matmul_64x242x640x1920(%batch_matmul: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {
+    ^bb0(%lhs: tensor<64x242x1920xf16>, %rhs: tensor<64x1920x640xf16>, %out: tensor<64x242x640xf32>):
+      %13 = linalg.batch_matmul
+        ins(%lhs, %rhs : tensor<64x242x1920xf16>, tensor<64x1920x640xf16>)
+        outs(%out : tensor<64x242x640xf32>) -> tensor<64x242x640xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
+  }
+
+//===----------------------------------------------------------------------===//
+// Contraction tuning
+//===----------------------------------------------------------------------===//
+
+  transform.named_sequence @match_contract_3x2x20x1024x64x1280(%contract: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract {
+    ^bb0(%lhs: tensor<2x1024x1280xf16>, %rhs: tensor<3x20x64x1280xf16>, %out: tensor<3x2x20x1024x64xf32>):
+      %20 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>,
+                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>,
+                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+      } ins(%lhs, %rhs : tensor<2x1024x1280xf16>, tensor<3x20x64x1280xf16>)
+        outs(%out : tensor<3x2x20x1024x64xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %22 = arith.extf %in : f16 to f32
+        %23 = arith.extf %in_0 : f16 to f32
+        %24 = arith.mulf %22, %23 : f32
+        %25 = arith.addf %acc, %24 : f32
+        linalg.yield %25 : f32
+      } -> tensor<3x2x20x1024x64xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %contract, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_contract_3x2x10x4096x64x640(%contract: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract {
+    ^bb0(%lhs: tensor<2x4096x640xf16>, %rhs: tensor<3x10x64x640xf16>, %out: tensor<3x2x10x4096x64xf32>):
+      %20 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>,
+                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>,
+                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
+      } ins(%lhs, %rhs : tensor<2x4096x640xf16>, tensor<3x10x64x640xf16>)
+        outs(%out : tensor<3x2x10x4096x64xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %22 = arith.extf %in : f16 to f32
+        %23 = arith.extf %in_0 : f16 to f32
+        %24 = arith.mulf %22, %23 : f32
+        %25 = arith.addf %acc, %24 : f32
+        linalg.yield %25 : f32
+      } -> tensor<3x2x10x4096x64xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      , llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}}>
+      > -> !transform.any_param
+    transform.yield %contract, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_contract_2x10x64x64x2048(%contract: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract {
+    ^bb0(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<10x64x2048xf16>, %out: tensor<2x10x64x64xf32>):
+      %14 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+      } ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<10x64x2048xf16>)
+        outs(%out : tensor<2x10x64x64xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %16 = arith.extf %in : f16 to f32
+        %17 = arith.extf %in_0 : f16 to f32
+        %18 = arith.mulf %16, %17 : f32
+        %19 = arith.addf %acc, %18 : f32
+        linalg.yield %19 : f32
+      } -> tensor<2x10x64x64xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      }>
+      > -> !transform.any_param
+    transform.yield %contract, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_contract_2x20x64x64x2048(%contract: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract {
+    ^bb0(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<20x64x2048xf16>, %out: tensor<2x20x64x64xf32>):
+      %14 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+      } ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<20x64x2048xf16>)
+        outs(%out : tensor<2x20x64x64xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %16 = arith.extf %in : f16 to f32
+        %17 = arith.extf %in_0 : f16 to f32
+        %18 = arith.mulf %16, %17 : f32
+        %19 = arith.addf %acc, %18 : f32
+        linalg.yield %19 : f32
+      } -> tensor<2x20x64x64xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 2, subgroup_n_count = 2>
+      }>
+      > -> !transform.any_param
+    transform.yield %contract, %config : !transform.any_op, !transform.any_param
+  }
+
+  transform.named_sequence @match_contract_2x20x1024x64x1280(%contract: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_param) {
+    %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract {
+    ^bb0(%lhs: tensor<2x1024x1280xf16>, %rhs: tensor<20x64x1280xf16>, %out: tensor<2x20x1024x64xf32>):
+      %20 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                         affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+      } ins(%lhs, %rhs : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
+        outs(%out : tensor<2x20x1024x64xf32>) {
+      ^bb0(%in: f16, %in_0: f16, %acc: f32):
+        %22 = arith.extf %in : f16 to f32
+        %23 = arith.extf %in_0 : f16 to f32
+        %24 = arith.mulf %22, %23 : f32
+        %25 = arith.addf %acc, %24 : f32
+        linalg.yield %25 : f32
+      } -> tensor<2x20x1024x64xf32>
+    } : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
+    %config = transform.param.constant #iree_codegen.compilation_info<
+      lowering_config = #iree_codegen.lowering_config,
+      translation_info = #iree_codegen.translation_info,
+        subgroup_m_count = 4, subgroup_n_count = 1>
+      }>
+      > -> !transform.any_param
+    transform.yield %contract, %config : !transform.any_op, !transform.any_param
+  }
+
 //===----------------------------------------------------------------------===//
 // Entry point
 //===----------------------------------------------------------------------===//
@@ -469,6 +1054,38 @@ module attributes { transform.with_named_sequence } {
         // Attention.
         @match_attention_len_512 -> @custom_attention_len_512,
         @match_attention -> @custom_attention
+
+        // Matmul.
+        , @match_mmt_2048x10240x1280 -> @apply_op_config
+        , @match_mmt_2048x1280x5120 -> @apply_op_config
+        , @match_mmt_2048x1280x1280 -> @apply_op_config
+        , @match_mmt_8192x5120x640 -> @apply_op_config
+        , @match_mmt_8192x640x2560 -> @apply_op_config
+        , @match_mmt_8192x640x640 -> @apply_op_config
+
+        // Convolution.
+        , @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x640 -> @apply_op_config
+        , @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280 -> @apply_op_config
+        , @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920 -> @apply_op_config
+        , @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560 -> @apply_op_config
+        , @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640 -> @apply_op_config
+        , @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320 -> @apply_op_config
+
+        // Batch matmul.
+        , @match_batch_matmul_64x968x320x640 -> @apply_op_config
+        , @match_batch_matmul_64x968x640x640 -> @apply_op_config
+        , @match_batch_matmul_64x968x320x960 -> @apply_op_config
+        , @match_batch_matmul_64x242x1280x1280 -> @apply_op_config
+        , @match_batch_matmul_64x242x640x960 -> @apply_op_config
+        , @match_batch_matmul_64x242x640x1280 -> @apply_op_config
+        , @match_batch_matmul_64x242x640x1920 -> @apply_op_config
+
+        // Contraction.
+        , @match_contract_3x2x20x1024x64x1280 -> @apply_op_config
+        , @match_contract_3x2x10x4096x64x640 -> @apply_op_config
+        , @match_contract_2x10x64x64x2048 -> @apply_op_config
+        , @match_contract_2x20x64x64x2048 -> @apply_op_config
+        , @match_contract_2x20x1024x64x1280 -> @apply_op_config
         : (!transform.any_op) -> (!transform.any_op)
     transform.yield
   }
diff --git a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
index 8f9615459ab5..258ca1c13c55 100644
--- a/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
+++ b/build_tools/pkgci/external_test_suite/onnx_cpu_llvm_sync.json
@@ -131,6 +131,7 @@
     "test_dequantizelinear_axis",
     "test_dequantizelinear_blocked",
     "test_dequantizelinear_e4m3fn",
+    "test_dequantizelinear_e4m3fn_float16",
     "test_dequantizelinear_e4m3fn_zero_point",
     "test_dequantizelinear_e5m2",
     "test_dequantizelinear_int16",
@@ -583,6 +584,7 @@
     "test_hardsigmoid_example",
     "test_hardswish_expanded",
     "test_max_float64",
+    "test_maxpool_2d_ceil_output_size_reduce_by_one",
     "test_min_float64",
     "test_mod_mixed_sign_int16",
     "test_mod_mixed_sign_int32",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
index 5f1d3b518d9c..bead5fe62684 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_cuda.json
@@ -137,6 +137,7 @@
     "test_dequantizelinear_axis",
     "test_dequantizelinear_blocked",
     "test_dequantizelinear_e4m3fn",
+    "test_dequantizelinear_e4m3fn_float16",
     "test_dequantizelinear_e4m3fn_zero_point",
     "test_dequantizelinear_e5m2",
     "test_dequantizelinear_int16",
@@ -585,6 +586,7 @@
     "test_hardsigmoid_example",
     "test_hardswish_expanded",
     "test_max_float64",
+    "test_maxpool_2d_ceil_output_size_reduce_by_one",
     "test_min_float64",
     "test_mod_mixed_sign_float64",
     "test_mod_mixed_sign_int16",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
index bac1369fd991..89715d74d95e 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_rocm_rdna3.json
@@ -132,6 +132,7 @@
     "test_dequantizelinear_axis",
     "test_dequantizelinear_blocked",
     "test_dequantizelinear_e4m3fn",
+    "test_dequantizelinear_e4m3fn_float16",
     "test_dequantizelinear_e4m3fn_zero_point",
     "test_dequantizelinear_e5m2",
     "test_dequantizelinear_int16",
@@ -590,6 +591,7 @@
     "test_hardsigmoid_example",
     "test_hardswish_expanded",
     "test_max_float64",
+    "test_maxpool_2d_ceil_output_size_reduce_by_one",
     "test_min_float64",
     "test_mod_mixed_sign_float64",
     "test_mod_mixed_sign_int16",
diff --git a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
index 21c894bb774e..db27108dc5d9 100644
--- a/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
+++ b/build_tools/pkgci/external_test_suite/onnx_gpu_vulkan.json
@@ -147,6 +147,7 @@
     "test_dequantizelinear_axis",
"test_dequantizelinear_blocked", "test_dequantizelinear_e4m3fn", + "test_dequantizelinear_e4m3fn_float16", "test_dequantizelinear_e4m3fn_zero_point", "test_dequantizelinear_e5m2", "test_dequantizelinear_int16", @@ -647,6 +648,7 @@ "test_max_float64", "test_max_int16", "test_max_int8", + "test_maxpool_2d_ceil_output_size_reduce_by_one", "test_min_float16", "test_min_float64", "test_min_int16", diff --git a/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json b/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json index 422f9ab19621..0e9793064250 100644 --- a/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json +++ b/build_tools/pkgci/external_test_suite/pytorch_models_cpu_llvm_task.json @@ -8,7 +8,9 @@ "--device=local-task" ], "skip_compile_tests": [ - "sdxl-scheduled-unet-3-tank" + "sdxl-scheduled-unet-3-tank", + "sdxl-prompt-encoder-tank", + "sdxl-vae-decode-tank" ], "skip_run_tests": [], "expected_compile_failures": [ diff --git a/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a.json b/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a.json index a478db3dc07d..c3ba3502518f 100644 --- a/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a.json +++ b/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a.json @@ -9,7 +9,9 @@ "--device=hip" ], "skip_compile_tests": [ - "sdxl-scheduled-unet-3-tank" + "sdxl-scheduled-unet-3-tank", + "sdxl-prompt-encoder-tank", + "sdxl-vae-decode-tank" ], "skip_run_tests": [], "expected_compile_failures": [ diff --git a/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a_additional_flags.json b/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a_additional_flags.json index 7939b899f291..59abf04d3165 100644 --- a/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a_additional_flags.json +++ b/build_tools/pkgci/external_test_suite/pytorch_models_gpu_rocm_gfx90a_additional_flags.json @@ -11,7 +11,9 @@ "--device=hip" ], "skip_compile_tests": [ - "sdxl-scheduled-unet-3-tank" + "sdxl-scheduled-unet-3-tank", + "sdxl-prompt-encoder-tank", + "sdxl-vae-decode-tank" ], "skip_run_tests": [], "expected_compile_failures": [ diff --git a/build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_cpu_llvm_task.json b/build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_cpu_llvm_task.json new file mode 100644 index 000000000000..cc39c2d53d9e --- /dev/null +++ b/build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_cpu_llvm_task.json @@ -0,0 +1,22 @@ +{ + "config_name": "cpu_llvm_task", + "iree_compile_flags" : [ + "--iree-hal-target-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu-features=host" + ], + "iree_run_module_flags": [ + "--device=local-task", + "--parameters=model=real_weights.irpa", + "--input=1x64xi64=@inference_input.0.bin", + "--input=1x64xi64=@inference_input.1.bin", + "--input=1x64xi64=@inference_input.2.bin", + "--input=1x64xi64=@inference_input.3.bin", + "--expected_output=2x64x2048xf16=@inference_output.0.bin", + "--expected_output=2x1280xf16=@inference_output.1.bin", + "--expected_f16_threshold=1.0f" + ], + "skip_compile_tests": [], + "skip_run_tests": [], + "expected_compile_failures": [], + "expected_run_failures": [] +} diff --git a/build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_gpu_rocm_gfx90a.json b/build_tools/pkgci/external_test_suite/sdxl_prompt_encoder_gpu_rocm_gfx90a.json new file mode 100644 index 000000000000..186a05488402 --- /dev/null +++ 
@@ -0,0 +1,34 @@
+{
+  "config_name": "gpu_rocm",
+  "iree_compile_flags": [
+    "--iree-hal-target-backends=rocm",
+    "--iree-rocm-target-chip=gfx90a",
+    "--iree-input-type=torch",
+    "--iree-opt-const-eval=false",
+    "--iree-global-opt-propagate-transposes=true",
+    "--iree-opt-outer-dim-concat=true",
+    "--iree-rocm-waves-per-eu=2",
+    "--iree-llvmgpu-enable-prefetch",
+    "--iree-flow-enable-aggressive-fusion",
+    "--iree-global-opt-enable-fuse-horizontal-contractions=true",
+    "--iree-opt-aggressively-propagate-transposes=true",
+    "--iree-codegen-llvmgpu-use-vector-distribution=true",
+    "--iree-execution-model=async-external",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))"
+  ],
+  "iree_run_module_flags": [
+    "--device=hip",
+    "--parameters=model=real_weights.irpa",
+    "--input=1x64xi64=@inference_input.0.bin",
+    "--input=1x64xi64=@inference_input.1.bin",
+    "--input=1x64xi64=@inference_input.2.bin",
+    "--input=1x64xi64=@inference_input.3.bin",
+    "--expected_output=2x64x2048xf16=@inference_output.0.bin",
+    "--expected_output=2x1280xf16=@inference_output.1.bin",
+    "--expected_f16_threshold=1.0f"
+  ],
+  "skip_compile_tests": [],
+  "skip_run_tests": [],
+  "expected_compile_failures": [],
+  "expected_run_failures": []
+}
diff --git a/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_gpu_rocm_gfx90a.json b/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_gpu_rocm_gfx90a.json
index 743b20bd2857..5731caef95ee 100644
--- a/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_gpu_rocm_gfx90a.json
+++ b/build_tools/pkgci/external_test_suite/sdxl_scheduled_unet_gpu_rocm_gfx90a.json
@@ -1,9 +1,8 @@
 {
   "config_name": "gpu_rocm",
-  "iree_compile_flags": [
+  "iree_compile_flags" : [
     "--iree-hal-target-backends=rocm",
     "--iree-rocm-target-chip=gfx90a",
-    "--iree-input-demote-f64-to-f32",
     "--iree-opt-const-eval=false",
     "--iree-codegen-transform-dialect-library=${IREE_TEST_PATH_EXTENSION}/attention_and_matmul_spec.mlir",
     "--iree-global-opt-propagate-transposes=true",
@@ -16,6 +15,8 @@
     "--iree-opt-data-tiling=false",
     "--iree-codegen-gpu-native-math-precision=true",
     "--iree-codegen-llvmgpu-use-vector-distribution",
+    "--iree-rocm-waves-per-eu=2",
+    "--iree-execution-model=async-external",
     "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))"
   ],
   "iree_run_module_flags": [
@@ -27,7 +28,7 @@
     "--input=2x1280xf16=@inference_input.2.bin",
     "--input=1xf16=@inference_input.3.bin",
     "--expected_output=1x4x128x128xf16=@inference_output.0.bin",
-    "--expected_f16_threshold=0.8f"
+    "--expected_f16_threshold=0.7f"
   ],
   "skip_compile_tests": [],
   "skip_run_tests": [],
diff --git a/build_tools/pkgci/external_test_suite/sdxl_vae_decode_cpu_llvm_task.json b/build_tools/pkgci/external_test_suite/sdxl_vae_decode_cpu_llvm_task.json
new file mode 100644
index 000000000000..a6f517f8b805
--- /dev/null
+++ b/build_tools/pkgci/external_test_suite/sdxl_vae_decode_cpu_llvm_task.json
@@ -0,0 +1,20 @@
+{
+  "config_name": "cpu_llvm_task",
+  "iree_compile_flags" : [
+    "--iree-hal-target-backends=llvm-cpu",
+    "--iree-llvmcpu-target-cpu-features=host"
+  ],
+  "iree_run_module_flags": [
+    "--device=local-task",
+    "--parameters=model=real_weights.irpa",
+    "--input=1x4x128x128xf16=@inference_input.0.bin",
"--expected_output=1x3x1024x1024xf16=@inference_output.0.bin", + "--expected_f16_threshold=0.02f" + ], + "skip_compile_tests": [], + "skip_run_tests": [], + "expected_compile_failures": [], + "expected_run_failures": [ + "sdxl-vae-decode-tank" + ] +} diff --git a/build_tools/pkgci/external_test_suite/sdxl_vae_decode_gpu_rocm_gfx90a.json b/build_tools/pkgci/external_test_suite/sdxl_vae_decode_gpu_rocm_gfx90a.json new file mode 100644 index 000000000000..57a82e98a03e --- /dev/null +++ b/build_tools/pkgci/external_test_suite/sdxl_vae_decode_gpu_rocm_gfx90a.json @@ -0,0 +1,27 @@ +{ + "config_name": "gpu_rocm", + "iree_compile_flags" : [ + "--iree-hal-target-backends=rocm", + "--iree-rocm-target-chip=gfx90a", + "--iree-opt-const-eval=false", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-rocm-waves-per-eu=2", + "--iree-flow-enable-aggressive-fusion", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-execution-model=async-external", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))" + ], + "iree_run_module_flags": [ + "--device=hip", + "--parameters=model=real_weights.irpa", + "--input=1x4x128x128xf16=@inference_input.0.bin", + "--expected_output=1x3x1024x1024xf16=@inference_output.0.bin", + "--expected_f16_threshold=0.4f" + ], + "skip_compile_tests": [], + "skip_run_tests": [], + "expected_compile_failures": [], + "expected_run_failures": [] +}